Attacking notebook

In [None]:
import logging
import pathlib
import random
import shutil
import time
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "6"
import sys
# sys.path.insert(0, '/home/tomerweiss/multiPILOT2')
from tqdm import tqdm  
import numpy as np
# np.seterr('raise')
import torch
import torchvision
from tensorboardX import SummaryWriter
from torch.nn import functional as F
from torch.utils.data import DataLoader
from common.args import Args
from data import transforms
from data.mri_data import SliceData
import matplotlib

matplotlib.use('Agg')
import matplotlib.pyplot as plt
from models.subsampling_model import Subsampling_Model
from scipy.spatial import distance_matrix
from tsp_solver.greedy import solve_tsp
import scipy.io as sio
from common.utils import get_vel_acc
from common.evaluate import psnr, ssim
from fastmri.losses import SSIMLoss
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

import matplotlib.pyplot as plt

In [None]:
import torch
from torch import nn
from torch.nn import functional as F

import fastmri.models
from models.rec_models.unet_model import UnetModel
from models.rec_models.complex_unet import ComplexUnetModel
import data.transforms as transforms
from pytorch_nufft.nufft import nufft, nufft_adjoint
import numpy as np
from WaveformProjection.run_projection import proj_handler
import matplotlib.pylab as P
from models.rec_models.vision_transformer import VisionTransformer
from models.rec_models.recon_net import ReconNet
from models.rec_models.humus_net import HUMUSNet, HUMUSBlock
from  models.VarBlock import VarNet
from typing import Tuple
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri.data.transforms import apply_mask

import torch
import numpy as np


In [None]:
import torch

# Report the name of CUDA device 0 when a GPU is visible to PyTorch.
print(
    f"CUDA Device Name: {torch.cuda.get_device_name(0)}"
    if torch.cuda.is_available()
    else "CUDA device not available"
)

In [None]:
# Release unoccupied memory cached by PyTorch's CUDA allocator back to the device.
torch.cuda.empty_cache()

In [None]:
class DataTransform:
    """Turn one raw slice into an ``(input, target, mean, std, norm)`` sample.

    The input image is obtained from ``kspace`` with ``transforms.ifft2_regular``,
    center-cropped to ``resolution`` x ``resolution`` and instance-normalized.
    The target is instance-normalized on its own and center-cropped to the same
    size when needed. The per-instance statistics are not propagated: ``mean``
    and ``std`` are always returned as ``0``.
    """

    def __init__(self, resolution):
        # Side length (in pixels) of the square crop applied to input and target.
        self.resolution = resolution

    def __call__(self, kspace, target, attrs, fname, slice):
        """Transform a single slice.

        Args:
            kspace: raw k-space array, converted with ``transforms.to_tensor``.
            target: reference image array, converted with ``transforms.to_tensor``.
            attrs: mapping with a ``'norm'`` entry (numpy value, cast to float32).
            fname: unused; part of the transform call signature.
            slice: unused; part of the transform call signature (shadows the
                builtin, kept for interface compatibility).

        Returns:
            ``(image.mean(0), target, 0, 0, norm)``.
        """
        crop = (self.resolution, self.resolution)

        kspace = transforms.to_tensor(kspace)
        image = transforms.ifft2_regular(kspace)
        image = transforms.complex_center_crop(image, crop)
        # Input and target are each normalized with their own statistics; the
        # statistics themselves are discarded (see the 0, 0 in the return).
        image, _, _ = transforms.normalize_instance(image, eps=1e-11)

        target = transforms.to_tensor(target)
        target, _, _ = transforms.normalize_instance(target, eps=1e-11)

        # Compare both trailing (spatial) dimensions: checking only one of them
        # let targets with a mismatched height through uncropped.
        if tuple(target.shape[-2:]) != crop:
            target = transforms.center_crop(target, crop)

        # NOTE(review): dim 0 of the cropped image is presumably the coil axis,
        # so this averages coil images — confirm against complex_center_crop.
        return image.mean(0), target, 0, 0, attrs['norm'].astype(np.float32)


def create_test_dataset(args):
    """Return the ``multicoil_val`` split under ``args.data_path`` as a SliceData.

    Samples are produced by ``DataTransform(args.resolution)`` and the split is
    subsampled according to ``args.sample_rate``.
    """
    return SliceData(
        root=args.data_path / 'multicoil_val',
        transform=DataTransform(args.resolution),
        sample_rate=args.sample_rate,
    )


def create_test_loader(args, shuffle=False, num_workers=20):
    """Build a DataLoader over the validation dataset from ``create_test_dataset``.

    Args:
        args: namespace used by ``create_test_dataset`` plus ``batch_size``.
        shuffle: reshuffle the data on every pass. Defaults to False: this
            loader is used for evaluation, and with shuffling the batch
            composition (and hence per-batch averaged metrics) may change
            between otherwise identical runs.
        num_workers: number of DataLoader worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` with pinned memory.
    """
    test_data = create_test_dataset(args)
    return DataLoader(
        dataset=test_data,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )

In [None]:
def evaluate(args, model, data_loader):
    """Evaluate ``model`` on ``data_loader`` and report L1 loss, PSNR and SSIM.

    Each batch's input is moved to ``args.device`` and fed to the model as
    ``model(input.unsqueeze(1))``; the output is reshaped to the target's shape
    before computing the metrics. PSNR/SSIM are computed per batch with
    ``psnr``/``ssim`` and then averaged over batches.

    Args:
        args: namespace providing ``device``.
        model: network to evaluate; it is switched to ``eval()`` mode.
        data_loader: iterable of ``(input, target, mean, std, norm)`` batches.

    Returns:
        ``(mean_loss, mean_psnr, mean_ssim)``. ``mean_loss`` is the L1 loss
        averaged over all batches, as a Python float.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    losses = []
    psnr_l = []
    ssim_l = []
    progress_bar = tqdm(data_loader, total=len(data_loader), desc="Processing Batches")
    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for inputs, target, _mean, _std, _norm in progress_bar:
            inputs = inputs.to(args.device)
            target = target.to(args.device)

            # Bring the output into the target's layout once; the same tensor
            # is used for the loss and for the image metrics.
            output = model(inputs.unsqueeze(1)).reshape(target.shape)

            loss = F.l1_loss(output, target).item()
            losses.append(loss)

            target_np = target.cpu().numpy()
            recons_np = output.cpu().numpy()
            psnr_value = psnr(target_np, recons_np)
            ssim_value = ssim(target_np, recons_np)
            psnr_l.append(psnr_value)
            ssim_l.append(ssim_value)

            progress_bar.set_postfix({
                "Loss": f"{loss:.4f}",
                "PSNR": f"{psnr_value:.2f}",
                "SSIM": f"{ssim_value:.4f}"
            })

    if not losses:
        raise ValueError("data_loader yielded no batches; nothing to evaluate")

    print(f'PSNR: {np.mean(psnr_l):.2f} +- {np.std(psnr_l):.2f}, SSIM: {np.mean(ssim_l):.4f} +- {np.std(ssim_l):.4f}')
    return float(np.mean(losses)), np.mean(psnr_l), np.mean(ssim_l)

In [None]:
def build_model(args):
    """Instantiate a Subsampling_Model from the hyper-parameters in ``args``.

    Single-channel input/output; every other setting is read from ``args``.
    The model is moved to ``args.device`` before being returned.
    """
    print(f"reconstructing : {args.model}")
    model_kwargs = dict(
        in_chans=1,
        out_chans=1,
        chans=args.num_chans,
        num_pool_layers=args.num_pools,
        drop_prob=args.drop_prob,
        decimation_rate=args.decimation_rate,
        res=args.resolution,
        trajectory_learning=args.trajectory_learning,
        initialization=args.initialization,
        SNR=args.SNR,
        n_shots=args.n_shots,
        interp_gap=args.interp_gap,
        type=args.model,
    )
    return Subsampling_Model(**model_kwargs).to(args.device)

def load_model(checkpoint_file):
    """Restore a trained model and its optimizer from ``checkpoint_file``.

    The checkpoint must contain ``'args'``, ``'model'`` and ``'optimizer'``
    entries. The model is rebuilt from the saved args and wrapped in
    ``torch.nn.DataParallel`` when ``args.data_parallel`` is set, before the
    saved weights are loaded.

    Returns:
        ``(checkpoint, model, optimizer, args)``.
    """
    print(checkpoint_file)
    # NOTE(review): torch.load unpickles arbitrary objects (the checkpoint stores
    # an args object), so only load trusted files. Recent PyTorch releases
    # default to weights_only=True, which may reject such checkpoints — confirm
    # against the installed version.
    checkpoint = torch.load(checkpoint_file)
    saved_args = checkpoint['args']

    model = build_model(saved_args)
    model = torch.nn.DataParallel(model) if saved_args.data_parallel else model
    model.load_state_dict(checkpoint['model'])

    optimizer = build_optim(saved_args, model)
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint, model, optimizer, saved_args

def build_optim(args, model):
    """Create the Adam optimizer for a subsampling + reconstruction model.

    The subsampling parameters get their own learning rate ``args.sub_lr``;
    the reconstruction model's parameters use the default rate ``args.lr``.

    Args:
        args: namespace with ``lr`` and ``sub_lr``.
        model: the model, optionally wrapped in ``torch.nn.DataParallel`` or
            ``DistributedDataParallel``.

    Returns:
        A ``torch.optim.Adam`` with two parameter groups (subsampling first).
    """
    # Unwrap only when the model actually is a data-parallel wrapper: a bare
    # model has no ``.module`` attribute, so accessing it unconditionally
    # raised AttributeError whenever data_parallel was disabled.
    wrappers = (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)
    base = model.module if isinstance(model, wrappers) else model
    return torch.optim.Adam(
        [
            {'params': base.subsampling.parameters(), 'lr': args.sub_lr},
            {'params': base.reconstruction_model.parameters()},
        ],
        args.lr,
    )

In [None]:
# Checkpoints compared in this notebook: U-Net and HUMUS-Net reconstructions,
# each with and without trajectory learning (per the dictionary keys).
checkpoint_files = {
    "Unet_with_trajectory_learning": "summary/16/spiral_high_1e-05_0.01_0.01/best_model.pt",
    "Unet_without_trajectory_learning": "summary/16/spiral_high_0_0.01_0.01/best_model.pt",
    "Humus_with_trajectory_learning": "summary/16/spiral_high_1e-05_0.01_0.01_humus/best_model.pt",
    "Humus_without_trajectory_learning": "summary/16/spiral_high_0_0.01_0.01_humus/best_model.pt",
}
_, Unet_with_trajectory_learning, _, args = load_model(checkpoint_files["Unet_with_trajectory_learning"])
_, Unet_without_trajectory_learning, _, _ = load_model(checkpoint_files["Unet_without_trajectory_learning"])
_, Humus_with_trajectory_learning, _, _ = load_model(checkpoint_files["Humus_with_trajectory_learning"])
_, Humus_without_trajectory_learning, _, _ = load_model(checkpoint_files["Humus_without_trajectory_learning"])


def _trajectory_diff_norm(model_a, model_b):
    """Return ``||x_a - x_b||`` (2-norm over all elements) of two models' trajectories as a float.

    Both models are expected to be DataParallel-wrapped (``.module``) with a
    ``subsampling.x`` trajectory parameter.
    """
    with torch.no_grad():
        diff = model_a.module.subsampling.x - model_b.module.subsampling.x
        return diff.norm().item()


# Computed here, after the checkpoints are loaded, so a top-to-bottom run does
# not hit a NameError. Plain floats keep no autograd graph or GPU references
# alive once the models are deleted.
unet_t_norm = _trajectory_diff_norm(Unet_with_trajectory_learning, Unet_without_trajectory_learning)
humus_t_norm = _trajectory_diff_norm(Humus_with_trajectory_learning, Humus_without_trajectory_learning)

In [None]:
from models.subsampling_model import Subsampling_Layer
# Stand-alone CPU instances for interactive experiments; `subsampling` is used by
# the first-batch probe cell, while `humusBlock` is not referenced in the cells
# visible here.
# NOTE(review): Subsampling_Layer's arguments are positional and undocumented
# here; "radial" presumably selects the initial trajectory type — check the
# constructor signature for the meaning of 10, 1, 0, 16, 20 and False.
# Keep the construction order: both constructors may initialize weights from the
# global RNG, so swapping them could change the resulting parameters.
subsampling = Subsampling_Layer(10, 1, 0, "radial", 16, 20, False).to("cpu")
humusBlock = HUMUSBlock(use_checkpoint=False, num_cascades=4, img_size=[320, 320], window_size=4,
                                mask_center=False, num_adj_slices=1, in_chans=2).to("cpu")

In [None]:
loader = create_test_loader(args)
# Take only the first batch and push its input through the stand-alone CPU
# subsampling layer; the result is kept in `s`.
for data in loader:
    input, target, mean, std, norm = data
    # unsqueeze(1) inserts a singleton dim at position 1, the same way
    # evaluate() prepares inputs for the full model.
    s = subsampling(input.unsqueeze(1))
    break  # one batch is enough; if the loader is empty nothing is bound

In [None]:
# Move every model's parameters to CPU before dropping the names: Module.to()
# converts the parameters in place, so GPU memory is freed even if other
# references to these modules or their parameters are still alive.
for _released_model in (
    Unet_with_trajectory_learning,
    Unet_without_trajectory_learning,
    Humus_with_trajectory_learning,
    Humus_without_trajectory_learning,
):
    _released_model.to("cpu")
del _released_model  # the loop variable would otherwise keep the last model alive
del (
    Unet_with_trajectory_learning,
    Unet_without_trajectory_learning,
    Humus_with_trajectory_learning,
    Humus_without_trajectory_learning,
)

torch.cuda.empty_cache()

In [None]:
# One line per architecture: norm of (trained-with - trained-without) trajectory.
print(
    f"the diff in trajectory for the unet: {unet_t_norm}",
    f"the diff in trajectory for the humus: {humus_t_norm}",
    sep="\n",
)

In [None]:
import numpy as np
import torch
import json

def _scaled_noise(reference, noise_level, p, noise_type):
    """Return a perturbation shaped like ``reference`` whose ``p``-norm equals ``noise_level``.

    ``noise_type == "random"`` uses a standard-normal direction
    (``torch.randn_like``); any other value uses an all-ones direction.
    """
    if noise_type == "random":
        direction = torch.randn_like(reference)
    else:
        direction = torch.ones_like(reference)
    return float(noise_level) * direction / torch.norm(direction, p=p)


def apply_noise(noise_type = "random", noise_levels=None):
    """Measure how perturbing each model's learned trajectory affects reconstruction.

    For every checkpoint in the module-level ``checkpoint_files`` and every
    noise level, three perturbations of ``model.module.subsampling.x`` are
    built. Each is scaled to have that L1, L2 or L-infinity norm, added to the
    trained trajectory and clamped to [-160, 160]. The model is evaluated with
    each perturbed trajectory. Per-level PSNR/SSIM values are written to
    ``results_<model_name>_<noise_type>.json``.

    Args:
        noise_type: ``"random"`` for Gaussian perturbation directions; any
            other value uses an all-ones direction.
        noise_levels: perturbation norms to test; defaults to
            ``np.logspace(0, 15, 16)`` (1 to 1e15).
    """
    if noise_levels is None:
        noise_levels = np.logspace(0, 15, 16)
    # (result-key suffix, norm order) for each perturbation family.
    norm_specs = (("l1", 1), ("l2", 2), ("linf", float('inf')))

    for model_name, checkpoint_file in checkpoint_files.items():
        checkpoint, model, optimizer, args = load_model(checkpoint_file)
        # Only the model is needed; drop the loaded checkpoint dict and the
        # optimizer right away so their tensors do not stay resident.
        del checkpoint, optimizer
        # NOTE(review): sample_rate is forwarded to SliceData, presumably the
        # fraction of validation data used — confirm.
        args.sample_rate = 0.2
        test_loader = create_test_loader(args)

        trajectory = model.module.subsampling.x
        # Detached copy: the perturbations need no autograd history.
        initial_trajectory = trajectory.detach().clone()
        results = []

        for noise_level in noise_levels:
            # Draw all three perturbations before any evaluation, in the fixed
            # order l1, l2, linf, so the RNG consumption order stays fixed.
            # NOTE(review): ±160 presumably bounds the k-space coordinates
            # (±resolution/2 for 320-px images) — confirm.
            noisy_trajectories = {
                suffix: torch.clamp(
                    initial_trajectory + _scaled_noise(initial_trajectory, noise_level, p, noise_type),
                    min=-160,
                    max=160,
                )
                for suffix, p in norm_specs
            }

            entry = {'noise_level': float(noise_level)}
            for suffix, _ in norm_specs:
                with torch.no_grad():
                    trajectory.copy_(noisy_trajectories[suffix])
                _, noisy_psnr, noisy_ssim = evaluate(args, model, test_loader)
                # Cast to plain floats so json.dump accepts numpy scalar types.
                entry[f'psnr_{suffix}'] = float(noisy_psnr)
                entry[f'ssim_{suffix}'] = float(noisy_ssim)
            results.append(entry)

        with open(f'results_{model_name}_{noise_type}.json', 'w') as f:
            json.dump(results, f, indent=4)

        # Free this model before the next checkpoint is loaded.
        del model, test_loader, trajectory, initial_trajectory
        torch.cuda.empty_cache()

In [None]:
%matplotlib inline
import json
import os
import matplotlib.pyplot as plt

# Noise pattern whose results are plotted; must match the `noise_type` given to
# apply_noise, which writes `results_<model_name>_<noise_type>.json`.
noise_type = "random"

# Legend label -> model name used in the results file names.
model_keys = {
    "Humus + Trajectory": "Humus_with_trajectory_learning",
    "Humus - Trajectory": "Humus_without_trajectory_learning",
    "UNet + Trajectory":  "Unet_with_trajectory_learning",
    "UNet - Trajectory":  "Unet_without_trajectory_learning",
}

metric_labels = ["psnr_l1", "psnr_l2", "psnr_linf", "ssim_l1", "ssim_l2", "ssim_linf"]
titles       = ["PSNR (L1)", "PSNR (L2)", "PSNR (L∞)", "SSIM (L1)", "SSIM (L2)", "SSIM (L∞)"]


def _results_path(model_key, noise):
    """Return the results file for ``model_key``.

    Prefers the ``results_<model>_<noise>.json`` name that apply_noise writes
    and falls back to the un-suffixed ``results_<model>.json`` name.
    """
    suffixed = f"results_{model_key}_{noise}.json"
    return suffixed if os.path.exists(suffixed) else f"results_{model_key}.json"


# ----------------------------
# Load and aggregate the results
# ----------------------------
# Each file holds a list of dicts, one per noise level, with keys
# "noise_level" plus every entry of `metric_labels`.
results = {}
for label, model_key in model_keys.items():
    with open(_results_path(model_key, noise_type), 'r') as f:
        records = json.load(f)
    records = sorted(records, key=lambda rec: rec["noise_level"])

    series = {"noise_levels": [rec["noise_level"] for rec in records]}
    series.update({metric: [rec[metric] for rec in records] for metric in metric_labels})
    results[label] = series

# ----------------------------
# Now plot the aggregated data
# ----------------------------
fig, axs = plt.subplots(2, 3, figsize=(12, 6), sharex=True)
axs = axs.ravel()

for ax, metric_label, title in zip(axs, metric_labels, titles):
    for label, series in results.items():
        ax.plot(series["noise_levels"], series[metric_label], marker='o', label=label)

    ax.set_title(title)
    ax.set_xlabel("Noise level")
    ax.set_ylabel(metric_label.upper())
    ax.set_xscale("log")
    ax.grid(True)
    ax.legend(fontsize=8)

plt.tight_layout()
plt.show()