In [None]:
import torch;

from train import normalize

# Print the installed torch and torchvision versions.
print(torch.__version__)
import torchvision; print(torchvision.__version__)

In [None]:
# pytorch
import torch
from torchvision.utils import make_grid
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as tvtransformss

# fastmri
import fastmri
from fastmri.data import subsample
from fastmri.data import transforms, mri_data
from fastmri.evaluate import ssim, psnr, nmse
from fastmri.losses import SSIMLoss
from fastmri.models import Unet

# other
import random
import PIL.Image as Image
from glob import glob
from myutils import SSIM, PSNR
from models.rec_models.vit_model import VisionTransformer
from models.rec_models.recon_net import ReconNet

# Device selection: expose only GPU 0 to this process. CUDA_VISIBLE_DEVICES only
# takes effect if it is set before CUDA is first initialised in this kernel.
import os

os.environ.update({"CUDA_VISIBLE_DEVICES": '0'})
device = 'cuda'

In [None]:
# Leftover debug cell: only prints a fixed marker string.
print("hello")

In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
import os
from torchvision.transforms import Resize, ToPILImage
%matplotlib inline
def show_image(source):
    """Display a tensor (or a batch of tensors) as a min-max normalised image grid.

    Best-effort debugging helper: an ordinary error is reported and the function
    returns ``None`` instead of interrupting the caller. ``KeyboardInterrupt``
    (and other non-``Exception`` signals) still propagate, so a long loop that
    calls this can be stopped.

    Args:
        source: anything ``torchvision.utils.make_grid`` accepts, e.g. a single
            ``(H, W)`` / ``(C, H, W)`` image or a ``(B, C, H, W)`` batch.
            The caller's tensor is never modified.
    """
    try:
        # Detach first so the normalisation below is not recorded by autograd
        # when `source` is a model output; clone so the caller's tensor is untouched.
        image = source.detach().clone()
        # Min-max normalise the whole tensor to [0, 1]; a constant image stays at 0.
        image -= image.min()
        max_val = image.max()
        if max_val > 0:
            image /= max_val
        grid = torchvision.utils.make_grid(image, nrow=4, pad_value=1)
        numpy_image = grid.permute(1, 2, 0).cpu().numpy()

        plt.imshow(numpy_image)
        plt.axis('off')
        plt.show()
    except Exception as exc:  # stay best-effort, but say why nothing was shown
        print(f"show_image: could not display image ({type(exc).__name__}: {exc})")

In [None]:
class fastMRIDataset(Dataset):
    """Wrap fastMRI's ``SliceDataset`` so each item is a 4x-undersampled slice.

    Items are the tuples produced by :meth:`data_transform`.
    """

    def __init__(self, challenge, path, isval, sample_rate = 0.1):
        """
        Dataloader for 4x acceleration with an equispaced Cartesian mask
        (8% fully-sampled centre fraction).

        challenge: 'multicoil' or 'singlecoil'
        path: path to dataset
        isval: whether dataset is fastMRI's validation set or training set;
            for validation the mask is seeded from the filename so it is
            deterministic per file, for training it is drawn fresh each time
        sample_rate: fraction of the dataset to load, forwarded to SliceDataset
        """
        self.challenge = challenge
        self.data_path = path
        self.isval = isval

        # SliceDataset only stores the transform here; data_transform (and thus
        # self.mask_func) is used lazily when items are fetched.
        self.data = mri_data.SliceDataset(
            root=self.data_path,
            transform=self.data_transform,
            challenge=self.challenge,
            use_dataset_cache=True,
            sample_rate = sample_rate
            )

        self.mask_func = subsample.EquispacedMaskFunc( # RandomMaskFunc for knee, EquispacedMaskFunc for brain
            center_fractions=[0.08],
            accelerations=[4],
            )

    def data_transform(self, kspace, mask, target, data_attributes, filename, slice_num):
        """Turn one raw fastMRI slice into network-ready tensors.

        Called by ``SliceDataset`` with its standard transform arguments.

        Returns:
            ``(x_zero_fill, target, max, zero_fill)`` where
            x_zero_fill: magnitude of the zero-filled reconstruction, cropped to
                the target's size (RSS coil-combined for multicoil), with a
                leading channel dim added;
            target: reference image with a leading channel dim added;
            max: ``data_attributes['max']``, passed through unchanged;
            zero_fill: cropped zero-filled image (real/imag in the last dim,
                RSS-combined for multicoil) with a leading dim added.
        """
        # Filename-derived seed for validation (reproducible mask); no seed for
        # training so every pass sees a fresh mask.
        seed = tuple(map(ord, filename)) if self.isval else None
        kspace = transforms.to_tensor(kspace)
        # Pass `seed` by keyword and keep only the masked data: newer fastMRI
        # releases added an `offset` parameter before `seed` and return
        # (masked, mask, num_low_frequencies) instead of (masked, mask).
        masked_kspace = transforms.apply_mask(kspace, self.mask_func, seed=seed)[0]

        target = transforms.to_tensor(target)
        zero_fill = fastmri.ifft2c(masked_kspace)
        zero_fill = transforms.complex_center_crop(zero_fill, target.shape)
        x_zero_fill = fastmri.complex_abs(zero_fill)

        if self.challenge == 'multicoil':
            x_zero_fill = fastmri.rss(x_zero_fill)
            # NOTE(review): RSS over the coil dim squares the real and imaginary
            # channels separately, so this is not a true complex coil combination;
            # confirm this is the intended representation.
            zero_fill = fastmri.rss(zero_fill)

        x_zero_fill = x_zero_fill.unsqueeze(0)
        zero_fill = zero_fill.unsqueeze(0)
        target = target.unsqueeze(0)

        return (x_zero_fill, target, data_attributes['max'], zero_fill)

    def __len__(self):
        """Number of slices in the wrapped SliceDataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the transformed item at `idx` (see :meth:`data_transform`)."""
        return self.data[idx]

In [None]:
# Data configuration.
challenge = 'multicoil'  # 'multicoil' or 'singlecoil'
# NOTE(review): absolute, machine-specific paths; adjust for your environment.
train_path = '/mnt/walkure_public/users/mohammedw/fastmri_downloads/multicoil_train/'  # fastMRI training split
val_path = '/mnt/walkure_public/users/mohammedw/fastmri_downloads/multicoil_val/'  # fastMRI validation split

dataset = fastMRIDataset(path=train_path, challenge=challenge, isval=False)
val_dataset = fastMRIDataset(path=val_path, challenge=challenge, isval=True)

# Number of training examples to keep. With ntrain == len(dataset) the split keeps
# every example and only applies a seeded (42) permutation of the indices.
ntrain = len(dataset)
split_generator = torch.Generator().manual_seed(42)
train_dataset, _ = torch.utils.data.random_split(
    dataset,
    [ntrain, len(dataset) - ntrain],
    generator=split_generator,
)
print(len(train_dataset))

In [None]:
# Training batches are shuffled with a seeded generator so the epoch order is
# reproducible; validation is read in fixed order with batch size 1.
batch_size = 1
shuffle_generator = torch.Generator().manual_seed(42)
loader_kwargs = dict(num_workers=0)
trainloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=shuffle_generator,
    **loader_kwargs,
)
valloader = DataLoader(val_dataset, batch_size=1, shuffle=False, **loader_kwargs)

In [None]:
# Number of validation batches (valloader uses batch_size=1).
print(len(valloader))

In [None]:
# Show the training DataLoader object (its repr) as the cell output.
trainloader

In [None]:
# Validate model
def validate(model, loader=None, show=True):
    """Evaluate `model` on validation data and print PSNR / SSIM (mean ± 2 std).

    Args:
        model: network mapping a cropped, normalized input image to a reconstruction.
        loader: optional iterable of ``(inputs, targets, maxval, _)`` batches.
            Defaults to a fresh batch-size-1 DataLoader over ``val_dataset``.
        show: if True (default), plot every reconstruction with ``show_image``.

    Returns:
        ``1 - mean SSIM`` as a Python float (lower is better).

    Leaves ``model`` in eval mode.
    """
    if loader is None:
        loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
    model.eval()
    ssim_ = SSIM().to(device)
    psnr_ = PSNR().to(device)
    psnrs = []
    ssims = []

    with torch.no_grad():
        for inputs, targets, maxval, _ in loader:
            # NOTE(review): inputs/targets are normalized, but `maxval` is the
            # unnormalized 'max' attribute from the file; confirm the metrics
            # expect that data range.
            inputs = normalize(center_crop(inputs, 320, 320))
            targets = normalize(center_crop(targets, 320, 320)).to(device)
            maxval = maxval.to(device)
            outputs = model(inputs.to(device))
            if show:
                show_image(outputs)
            ssims.append(ssim_(outputs, targets, maxval))
            psnrs.append(psnr_(outputs, targets, maxval))

    # Concatenate once instead of re-concatenating for every statistic.
    ssim_all = torch.cat(ssims)
    psnr_all = torch.cat(psnrs)

    print(' Recon. PSNR: {:0.3f} pm {:0.2f}'.format(psnr_all.mean(), 2*psnr_all.std()))
    print(' Recon. SSIM: {:0.4f} pm {:0.3f}'.format(ssim_all.mean(), 2*ssim_all.std()))

    return (1 - ssim_all.mean()).item()

# Save model
def save_model(path, model, train_hist, val_hist, optimizer, scheduler=None):
    """Write loss histories and a training checkpoint into directory `path`.

    Files written: ``train_hist.pt``, ``val_hist.pt`` and ``checkpoint.pth``.
    The checkpoint holds a ``ReconNet`` wrapping ``model.net`` (a pickled module
    object) under ``'model'``, the model and optimizer state dicts, and — when
    `scheduler` is given — the scheduler state dict under ``'scheduler'``.

    `path` is joined with ``os.path.join``, so it may be given with or without
    a trailing separator.
    """
    checkpoint = {
        'model': ReconNet(model.net),
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    if scheduler is not None:
        checkpoint['scheduler'] = scheduler.state_dict()

    torch.save(train_hist, os.path.join(path, 'train_hist.pt'))
    torch.save(val_hist, os.path.join(path, 'val_hist.pt'))
    torch.save(checkpoint, os.path.join(path, 'checkpoint.pth'))

In [None]:
# Step 1: Re-initialize everything.
# These architecture hyper-parameters must match the run that produced the
# checkpoint, otherwise load_state_dict below will fail on mismatched keys/shapes.
avrg_img_size = 340
patch_size = 10
depth = 10
num_heads = 16
embed_dim = 44

net = VisionTransformer(
    avrg_img_size=avrg_img_size,
    patch_size=patch_size,
    in_chans=1, embed_dim=embed_dim,
    depth=depth, num_heads=num_heads,
    )

model = ReconNet(net)  # replace `net` as needed
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # match optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)  # match scheduler

path = "models/vit-l_equidist_acc4/checkpoint.pth"

# Step 2: Load the checkpoint.
# map_location='cpu': the freshly built model is still on the CPU, and
# load_state_dict copies tensors into its parameters anyway; this also lets the
# checkpoint load on a machine without the GPU it was saved from.
# weights_only=False: the checkpoint pickles a whole ReconNet object under
# 'model', which torch >= 2.6 refuses to unpickle by default. Only load trusted files.
import inspect

load_kwargs = {'map_location': 'cpu'}
if 'weights_only' in inspect.signature(torch.load).parameters:
    load_kwargs['weights_only'] = False
checkpoint = torch.load(path, **load_kwargs)

# Step 3: Restore state dicts
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler'])

# Step 4: Set model to eval or train
model.train()  # or model.eval() for inference only

In [None]:
"""Optimizer"""
# Loss (other options: CompositeMRILoss(15, 0.0025, 0.1).to(device),
# torch.nn.MSELoss(), SSIMLoss().to(device)).
criterion = torch.nn.L1Loss()

# Adam; its initial lr is overwritten by OneCycleLR when the scheduler is built.
optimizer = optim.Adam(model.parameters(), lr=2 * 0.0001)

# Bookkeeping for the training loop.
train_hist = []
val_hist = []
best_val = float("inf")
path = './'  # directory for model checkpoint and loss history
num_epochs = 100000

# One-cycle schedule over `num_epochs` steps: linear warm-up from
# max_lr / div_factor to max_lr over the first 10% of steps, then linear
# annealing down to (max_lr / div_factor) / final_div_factor.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.0003,
    total_steps=num_epochs,
    pct_start=0.1,
    anneal_strategy='linear',
    cycle_momentum=False,
    base_momentum=0.,
    max_momentum=0.,
    div_factor=0.1 * num_epochs,
    final_div_factor=9,
)

In [None]:
# Re-create the optimizer with a fixed learning rate of 1e-5.
# NOTE(review): an existing `scheduler` stays bound to the optimizer it was built
# with, so scheduler.step() will not change this new optimizer's learning rate.
adam_config = dict(
    lr=1e-5,              # learning rate
    betas=(0.9, 0.999),   # default Adam settings
    eps=1e-8,             # numerical stability
    weight_decay=0,       # no regularization (set to 0 for overfitting)
)
optimizer = torch.optim.Adam(model.parameters(), **adam_config)

In [None]:
def center_crop(tensor, *target_sizes):
    """
    Center crop the last N dimensions to match target_sizes.

    Works for any object with ``.shape`` and Ellipsis/slice indexing (torch
    tensors, NumPy arrays). When a dimension's surplus is odd, the extra element
    is removed from the end. With no target sizes the input is returned as-is.

    Raises:
        ValueError: if more target sizes than dimensions are given, or a target
            size is negative or larger than its dimension (a negative start index
            would otherwise wrap around and silently return the wrong region).
    """
    ndim = len(tensor.shape)
    if len(target_sizes) > ndim:
        raise ValueError(f"cannot crop {len(target_sizes)} dims of a {ndim}-d tensor")
    # Explicit start index avoids the shape[-0:] pitfall when no sizes are given.
    spatial_dims = tensor.shape[ndim - len(target_sizes):]
    slices = []
    for dim, target in zip(spatial_dims, target_sizes):
        if not 0 <= target <= dim:
            raise ValueError(f"target size {target} does not fit in dimension of size {dim}")
        start = (dim - target) // 2
        slices.append(slice(start, start + target))
    return tensor[(...,) + tuple(slices)]

def top_left_crop(tensor, *target_sizes):
    """
    Keep the leading ``target_sizes[i]`` elements of each of the last N dimensions.

    A size larger than its dimension keeps the whole dimension (ordinary slice clipping).
    """
    index = [Ellipsis]
    for size in target_sizes:
        index.append(slice(0, size))
    return tensor[tuple(index)]

def crop_at(tensor, starts, sizes):
    """
    Crop a window over the last N dimensions.

    starts: starting index for each cropped dimension
    sizes:  window length for each cropped dimension
    N is the length of the shorter of `starts` and `sizes`.
    """
    window = tuple(map(lambda begin, length: slice(begin, begin + length), starts, sizes))
    return tensor[(Ellipsis, *window)]

In [None]:
# Evaluate the loaded model on the validation set after moving it to `device`.
validate(model.to(device))

In [None]:
"""Train Model"""
from common.evaluate import psnr, ssim
from train import normalize

# Plot input/reconstruction for every `show_every`-th batch; 0 disables plotting.
# Plotting every batch floods the notebook output and slows training considerably.
show_every = 0

model = model.to(device)
for epoch in range(10):
    model.train()
    train_loss = 0.0

    # Iterate the *training* loader: the epoch loss below is averaged over
    # len(trainloader), and validate() scores on the validation set, so
    # training on valloader would leak validation data into model selection.
    for step, (inputs, targets, _maxval, _inputs_c) in enumerate(trainloader):
        inputs = center_crop(inputs, 320, 320)
        targets = center_crop(targets, 320, 320)
        # NOTE(review): validate() normalizes inputs/targets but this loop does
        # not; confirm the model is meant to train on unnormalized images.
        optimizer.zero_grad()
        outputs = model(inputs.to(device))
        if show_every and step % show_every == 0:
            show_image(inputs[0])
            show_image(outputs[0])
        loss = criterion(outputs, targets.to(device))
        loss.backward()
        # Clip the total L1 norm of all gradients to at most 1.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1, norm_type=1.)
        optimizer.step()

        train_loss += loss.item()

    # NOTE(review): an LR scheduler stays bound to the optimizer it was built
    # with; if `optimizer` was re-created after `scheduler`, this step does not
    # change the learning rate actually used above.
    if scheduler:
        scheduler.step()

    train_hist.append(train_loss / len(trainloader))
    print('Epoch {}, Train loss.: {:0.10f}'.format(epoch + 1, train_hist[-1]))

    # Validate every 5 epochs and checkpoint whenever the validation score improves.
    if (epoch + 1) % 5 == 0:
        print('Validation:')
        val_hist.append(validate(model))
        if val_hist[-1] < best_val:
            save_model(path, model, train_hist, val_hist, optimizer, scheduler=scheduler)
            best_val = val_hist[-1]

In [None]:
# Leftover debug cell: only prints a fixed marker string.
print("hello")