In [1]:
from ml_recon.Models.modl import modl
from torch.utils.data import DataLoader
from ml_recon.Transforms import (pad, trim_coils, combine_coil, toTensor, permute, 
                        view_as_real, remove_slice_dim, fft_2d, normalize, addChannels)
from ml_recon.Dataset.undersampled_dataset import UndersampledKSpaceDataset
from torchvision.transforms import Compose
import numpy as np

import torch
from ml_recon.Utils import image_slices, save_model
from ml_recon.Utils.collate_function import collate_fn
from ml_recon.Models.varnet import VarNet

In [2]:
torch.manual_seed(0)
np.random.seed(0)

In [3]:
transforms = Compose(
    (
        pad((640, 320)), 
        toTensor(),
        normalize(),
    )
)
dataset = UndersampledKSpaceDataset('/home/kadotab/projects/def-mchiew/kadotab/Datasets/fastMRI/multicoil_train', transforms=transforms, R=4)
dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn)
    

In [4]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [5]:
model = VarNet(2, 2, num_cascades=5, use_norm=True)
model.to(device)

VarNet(
  (cascade): ModuleList(
    (0): VarnetBlock(
      (unet): Unet(
        (down_sample_layers): ModuleList(
          (0): double_conv(
            (conv1): Conv2d(2, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (conv2): Conv2d(18, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (activation): LeakyReLU(negative_slope=0.2, inplace=True)
            (instance_norm1): InstanceNorm2d(18, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (instance_norm2): InstanceNorm2d(18, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
            (drop_out1): Dropout2d(p=0, inplace=False)
            (drop_out2): Dropout2d(p=0, inplace=False)
          )
          (1): Unet_down(
            (down): down(
              (max_pool): AvgPool2d(kernel_size=2, stride=(2, 2), padding=0)
            )
            (conv): double_conv(
              (conv1): Conv2d(18, 36, kernel_size=(3, 3), st

In [6]:
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [7]:
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
writer = SummaryWriter('/home/kadotab/scratch/runs' +  datetime.now().strftime("%Y%m%d-%H%M%S"))

In [12]:
path = '/home/kadotab/python/ml/ml_recon/Model_Weights/'
def train(model, loss_function, optimizer, dataloader, epoch=7):
    cur_loss = 0
    current_index = 0
    try:
        for e in range(epoch):
            for data in dataloader:
                sampled = data['k_space']
                mask = data['mask']
                delta_mask = data['delta_mask']
                undersampled = data['undersampled']
                for i in range(sampled.shape[0]):
                    optimizer.zero_grad()
                    sampled_slice = sampled[[i],...]
                    mask_slice = mask[[i],...]
                    undersampled_slice = undersampled[[i],...]
                    mask_slice = mask_slice.to(device)
                    mask_slice = mask_slice.bool()
                    undersampled_slice = undersampled_slice.to(device)
                    sampled_slice = sampled_slice.to(device)

                    predicted_sampled = model(undersampled_slice, mask_slice)
                    loss = loss_function(torch.view_as_real(predicted_sampled), torch.view_as_real(sampled_slice))

                    loss.backward()
                    optimizer.step()
                    cur_loss += loss.item()
                    current_index += 1
                    if current_index % 1000 == 999:
                        writer.add_histogram('sens/weights1', next(model.sens_model.model.conv1d.parameters()), current_index)
                        writer.add_histogram('castcade0/weights1', next(model.cascade[0].unet.conv1d.parameters()), current_index)
                        writer.add_histogram('castcade0/weights2', next(model.cascade[1].unet.conv1d.parameters()), current_index)
                        writer.add_histogram('castcade0/weights11', next(model.cascade[-2].unet.conv1d.parameters()), current_index)
                        writer.add_histogram('castcade0/weights12', next(model.cascade[-1].unet.conv1d.parameters()), current_index)
                        writer.add_histogram('varnet/regularizer', model.lambda_reg.data, current_index)
                        writer.add_scalar('Loss/train', cur_loss, current_index)
                        print(f"Iteration: {current_index + 1:>d}, Loss: {cur_loss:>7f}")
                        cur_loss = 0
                        save_model(path, model, optimizer, current_index) 
    except KeyboardInterrupt:
        pass

    save_model(path, model, optimizer, -1)

In [9]:
writer.add_histogram('test', model.cascade[-1].unet.conv1d.weight.flatten(), 1)

In [13]:
train(model, loss_fn, optimizer, dataloader)

Iteration: 10, Loss: 0.072625
Iteration: 20, Loss: 0.084318
Iteration: 30, Loss: 0.064502
Iteration: 40, Loss: 0.258372
Iteration: 50, Loss: 0.156658
Iteration: 60, Loss: 0.107559
Iteration: 70, Loss: 0.156748
Iteration: 80, Loss: 0.115576
Iteration: 90, Loss: 0.073983
Iteration: 100, Loss: 0.086453
Iteration: 110, Loss: 0.050611
Iteration: 120, Loss: 0.045982
Iteration: 130, Loss: 0.042751
Iteration: 140, Loss: 0.043112
Iteration: 150, Loss: 0.028080
Iteration: 160, Loss: 0.014854
Iteration: 170, Loss: 0.032200
Iteration: 180, Loss: 0.027171
Iteration: 190, Loss: 0.027857
Iteration: 200, Loss: 0.029766
Iteration: 210, Loss: 0.038177
Iteration: 220, Loss: 0.021150
