In [1]:
from h5_utils import read_sim_data, basic_grouping, catalog_exams
import glob
import torch
import h5py
import matplotlib.pyplot as plt
import os
import numpy as np
from scipy import ndimage
from tqdm import tqdm
import pandas as pd
import pytorch_lightning as pl
import monai
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

In [4]:
# define the LightningModule
class UnetModel(pl.LightningModule):
    """MONAI ``BasicUNet`` wrapped as a LightningModule for map regression.

    The network receives ``batch["data"]`` (per the notebook: ``[N, 60, 64, 64]``)
    and predicts a single-channel map that is compared against ``batch["kPL"]``
    (``[N, 64, 64]``). The optimised objective is the per-sample summed L1 error,
    averaged over the batch; the analogous squared-error value is logged only.

    Args:
        input_size: Spatial size of the input; stored on the module, not used
            by the network itself.
        spatial_dims: Number of spatial dimensions passed to ``BasicUNet``.
        num_channels: Number of input channels passed to ``BasicUNet``.
    """

    def __init__(self, input_size, spatial_dims, num_channels):
        super().__init__()
        self.input_size = input_size
        self.num_channels = num_channels
        self.model = monai.networks.nets.BasicUNet(
            spatial_dims=spatial_dims,
            in_channels=num_channels,
            out_channels=1,
            features=(32, 32, 64, 128, 256, 32),
        )

    def forward(self, data):
        """Run the U-Net on ``data`` (cast to float32); output keeps the channel axis."""
        return self.model(data.float())

    def _shared_step(self, batch, prefix):
        """Compute and log the L1 / L2 losses for one batch.

        Args:
            batch: Mapping with the keys ``"data"`` (network input) and
                ``"kPL"`` (target map).
            prefix: Metric-name prefix, ``"train"`` or ``"val"``.

        Returns:
            The batch-averaged, per-sample summed L1 loss.
        """
        imgs = batch["data"]
        label = batch["kPL"]
        # Drop only the channel axis. A bare squeeze() would also remove the
        # batch axis when N == 1 and break the per-sample reduction below.
        prediction = self.forward(imgs).squeeze(1)

        abs_err = torch.nn.functional.l1_loss(prediction, label, reduction="none")
        # torch.nn has no L2Loss; the squared-error loss is the MSE loss.
        sq_err = torch.nn.functional.mse_loss(prediction, label, reduction="none")

        # Sum over every non-batch axis, then average over the batch.
        per_sample_dims = tuple(range(1, abs_err.dim()))
        loss_l1 = abs_err.sum(dim=per_sample_dims).mean()
        loss_l2 = sq_err.sum(dim=per_sample_dims).mean()

        # self.log() takes a single name/value pair; a dict goes through log_dict().
        # Passing batch_size avoids Lightning guessing it from the batch dict.
        self.log_dict(
            {f"{prefix}_loss": loss_l1, f"{prefix}_l2": loss_l2},
            batch_size=imgs.shape[0],
        )
        return loss_l1

    def training_step(self, batch, batch_idx):
        """Training loop body: returns the L1 loss that is back-propagated."""
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Validation loop body: logs ``val_loss`` / ``val_l2``, returns nothing."""
        self._shared_step(batch, "val")

    def configure_optimizers(self):
        """Adam over the U-Net parameters with a fixed learning rate of 1e-3."""
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)


# Build the kPL-map regression U-Net plus the logger / checkpoint callback used for training.
input_size = (64, 64)
spatial_dims = 2
num_channels = 60  # metabolites * time points (tpts)

unet_model = UnetModel(
    input_size=input_size,
    spatial_dims=spatial_dims,
    num_channels=num_channels,
)

# TensorBoard event files are written under <cwd>/lightning_logs/version_2.
logger = TensorBoardLogger(
    save_dir=os.getcwd(),
    version=2,
    name="lightning_logs",
)

# Keep the three checkpoints with the best monitored "val_loss".
checkpoint_callback = ModelCheckpoint(
    dirpath="./checkpoints",
    save_top_k=3,
    monitor="val_loss",
)

BasicUNet features: (32, 32, 64, 128, 256, 32).


In [3]:
# dataloader
# NOTE(review): DataLoader, BrainWebDataset and read_brainweb_sim_data are not
# imported in the visible import cell -- confirm they are defined before this
# cell runs on a fresh kernel.
BATCH_SIZE = 2
NUM_WORKERS = 8

train_dataset = BrainWebDataset("groups/train_TEST", read_brainweb_sim_data)
val_dataset = BrainWebDataset("groups/val_TEST", read_brainweb_sim_data)

dl_train = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
# Validation must not be shuffled: a fixed order keeps per-epoch metrics
# comparable and avoids Lightning's shuffled-val-dataloader warning.
dl_val = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

training data


100%|██████████| 800/800 [00:02<00:00, 329.18it/s]

	 all (800, 4)
--> # steps in epoch: 400





val data


100%|██████████| 100/100 [00:00<00:00, 314.90it/s]

	 all (100, 4)
--> # steps in epoch: 50





In [6]:
# Fit the U-Net on a single GPU, logging to TensorBoard and checkpointing via the callback.
trainer = pl.Trainer(
    max_epochs=500,
    accelerator="gpu",
    devices=1,  # number of devices to use
    logger=logger,
    callbacks=[checkpoint_callback],
)
trainer.fit(
    model=unet_model,
    train_dataloaders=dl_train,
    val_dataloaders=dl_val,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type      | Params
------------------------------------
0 | model | BasicUNet | 2.0 M 
------------------------------------
2.0 M     Trainable params
0         Non-trainable params
2.0 M     Total params
7.982     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"


Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [7]:
# Sanity check: push a few training batches through the model by hand and
# print the tensor shapes and the resulting training loss.
N_SANITY_BATCHES = 2

for i_batch, sample_batched in enumerate(dl_train):
    print('Batch index: ', i_batch)
    print('Data size: ', sample_batched["data"].size())
    print('kPL map size: ', sample_batched["kPL"].size())
    # Shape inspection only, so no autograd graph is needed for this pass.
    with torch.no_grad():
        predicted_map = unet_model(sample_batched["data"])
    print(predicted_map.shape)
    # Called outside a Trainer, so the logging call inside training_step only
    # emits a warning instead of recording metrics.
    loss = unet_model.training_step(sample_batched, i_batch)
    print(loss)
    if i_batch + 1 >= N_SANITY_BATCHES:
        break

Batch index:  0
Data size:  torch.Size([2, 60, 64, 64])
kPL map size:  torch.Size([2, 64, 64])
forward step
torch.Size([2, 1, 64, 64])
forward step
tensor(1997.7825, grad_fn=<MeanBackward0>)
Batch index:  1
Data size:  torch.Size([2, 60, 64, 64])
kPL map size:  torch.Size([2, 64, 64])
forward step


  "You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet."


torch.Size([2, 1, 64, 64])
forward step
tensor(1962.2917, grad_fn=<MeanBackward0>)


In [6]:
# Demo: with reduction='none', L1Loss returns the element-wise absolute error
# with the same shape as its inputs (no averaging or summing).
# A dedicated seeded generator makes the printed values reproducible without
# touching the global RNG state; the names avoid shadowing the builtin `input`
# and the `loss` tensor produced by the training sanity-check cell.
demo_rng = torch.Generator().manual_seed(0)
l1_elementwise = torch.nn.L1Loss(reduction='none')
demo_pred = torch.randn(2, 3, 5, generator=demo_rng, requires_grad=True)
demo_target = torch.randn(2, 3, 5, generator=demo_rng)
demo_loss_map = l1_elementwise(demo_pred, demo_target)
print(demo_loss_map)
print(demo_loss_map.shape)

tensor([[[2.1736, 2.1025, 0.4572, 2.3601, 0.2952],
         [0.5978, 1.0092, 1.6555, 1.3653, 0.2777],
         [0.7864, 0.4777, 0.3644, 3.0574, 1.4337]],

        [[1.9992, 0.5427, 0.4297, 1.5851, 0.7495],
         [1.3727, 1.3939, 0.1706, 2.7125, 0.1682],
         [0.6284, 0.4092, 2.5690, 1.2286, 1.2483]]], grad_fn=<L1LossBackward0>)
torch.Size([2, 3, 5])
