In [1]:
from Models.modl import modl
from torch.utils.data import DataLoader
from Transforms import (pad, trim_coils, combine_coil, toTensor, permute, 
                        view_as_real, remove_slice_dim, fft_2d, normalize, addChannels)
from Dataset.undersampled_dataset import UndersampledKSpaceDataset
from torchvision.transforms import Compose
import numpy as np
import torch

In [2]:
%load_ext autoreload
%autoreload 2 

In [3]:
def collate_fn(data):
    undersampled = [d['undersampled'] for d in data]
    sampled = [d['k_space'] for d in data]
    ismrmrd_header = [d['ismrmrd_header'] for d in data]
    mask = [d['mask'] for d in data]
    recon_rss = [d['reconstruction_rss'] for d in data]

    undersampled = torch.concat(undersampled, dim=0)
    sampled = torch.concat(sampled, dim=0)
    mask = torch.concat(mask, dim=0)

    data = {
        'undersampled': undersampled, 
        'sampled': sampled,
        'ismrmrd_header': ismrmrd_header,
        'mask': mask, 
        'recon': recon_rss,
    }
    return data

In [4]:
transforms = Compose(
    (
        trim_coils(12),
        pad((640, 320)), 
        fft_2d(axes=[2,3]),
        combine_coil(),
        normalize(),
        toTensor(),
    )
)
dataset = UndersampledKSpaceDataset('D:/multicoil_train', transforms=transforms)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)
    

In [5]:
data = (next(iter(dataloader)))

In [6]:
import cProfile
cProfile.run('next(iter(dataloader))', 'dataloader.profile')

In [7]:
model = modl(1, 3)

In [8]:
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), momentum=0.99, lr=0.0001)

In [15]:
def train(model, loss_function, optimizer, dataloader):
    cur_loss = 0
    current_index = 0
    for data in dataloader:
        
        sampled = data['sampled']
        mask = data['mask']
        undersampled = data['undersampled']
        for i in range(sampled.shape[0]):
            optimizer.zero_grad()
            sampled_slice = sampled[[i],...]
            mask_slice = mask[[i],...]
            undersampled_slice = undersampled[[i],...]

            predicted_sampled = model(undersampled_slice, mask_slice, 1)
            loss = loss_function(torch.view_as_real(predicted_sampled), torch.view_as_real(sampled_slice))

            loss.backward()
            optimizer.step()
            cur_loss += loss.item()
            if current_index % 10 == 9:
                print(f"Iteration: {current_index + 1:>d}, Loss: {cur_loss:>7f}")
                cur_loss = 0
            current_index += 1


In [16]:
train(model, loss_fn, optimizer, dataloader)

Iteration: 10, Loss: 2.502711
Iteration: 4, Loss: 5.851493
Iteration: 14, Loss: 9.766493
Iteration: 8, Loss: 8.772163
Iteration: 2, Loss: 8.291339
Iteration: 12, Loss: 6.630379
Iteration: 6, Loss: 5.387787


KeyboardInterrupt: 