In [None]:
import os, glob, torch
import numpy as np
import scipy.io as spio
from torch.utils.data import Dataset

class LungDataset(Dataset):
    """Dataset yielding one image series per sample directory.

    Every entry below ``root`` is treated as a sample. Within each sample the
    files matching ``Series*/dicoms.mat`` are collected and the middle one
    (after sorting) is kept; samples without any such file are skipped.

    Args:
        root: Directory containing the sample folders. Defaults to
            ``DEFAULT_ROOT`` (the location used originally).
        num_frames: Maximum number of entries kept along the first axis of the
            stored ``dcm`` array (192 by default).
    """

    DEFAULT_ROOT = "/media/agjvc_rad3/_TESTKOLLEKTIV/Daten/Daten"

    def __init__(self, root=None, num_frames=192):
        root = self.DEFAULT_ROOT if root is None else root

        # glob returns paths in arbitrary file-system order; sorting makes the
        # index -> file mapping (and any seeded split built on top of it)
        # reproducible. Note the sort is lexicographic ("Series10" < "Series2").
        samples = sorted(glob.glob(os.path.join(root, "*")))

        slices = []
        for sample in samples:
            series = sorted(glob.glob(os.path.join(sample, "Series*/dicoms.mat")))
            if series:
                slices.append(series[len(series) // 2])

        self.num_frames = num_frames
        self.slices = slices

    def __len__(self):
        """Number of samples that contain at least one series file."""
        return len(self.slices)

    def __getitem__(self, idx):
        """Load the series at ``idx`` as a float32 tensor with a leading singleton axis.

        The stored 3-D ``dcm`` array is cut to ``num_frames`` entries along its
        first axis, divided by 255, and that first axis is moved to the end.
        """
        path = self.slices[idx]
        # must be int32 as uint16 is not supported and 32bit required for safe upcast
        dcm = spio.loadmat(path)["dcm"][: self.num_frames].astype(np.int32)
        dcm = (dcm / 255).astype(np.float32)
        tensor = torch.from_numpy(dcm)
        # C, W, H, t (axis naming as in the original comment — TODO confirm W/H order)
        return torch.permute(tensor, (1, 2, 0))[None, :, :, :]

In [None]:
# Build the dataset once as a smoke test; construction globs the data directory
# for series files. `ds` is not referenced by the later cells shown here.
ds = LungDataset()

In [None]:
%%script echo "SKIP"
import matplotlib.pyplot as plt

dataset = LungDataset()
sample = dataset[10]
print(sample.shape)
print(sample.dtype)
fig, ax = plt.subplots(3, 10, figsize=(30, 10))
for i in range(3):
    for j in range(10):
        f = sample[:,:,:, i * j + j]
        ax[i][j].imshow(f.view(256,256), cmap="gray")

In [None]:
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split

class LungDataModule(pl.LightningDataModule):
    """Lightning data module that splits ``LungDataset`` 80/10/10 into train/val/test.

    Args:
        batch_size: Batch size used by all three data loaders.
        num_workers: Worker process count passed to every ``DataLoader``.
        pin_memory: ``pin_memory`` flag passed to every ``DataLoader``.
    """

    def __init__(self, batch_size=1, num_workers=1, pin_memory=True):
        super().__init__()

        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.num_workers = num_workers

        # Populated by setup().
        self.dataset = None
        self.train_data = None
        self.val_data = None
        self.test_data = None

    def prepare_data(self):
        """No-op: the data is read from local disk, nothing is downloaded."""
        pass

    # Backward-compatible alias for the original, misspelled method name.
    prepare_date = prepare_data

    def setup(self, stage=None):
        """Create the dataset and its seeded train/val/test split.

        The split is built only once; later calls (e.g. for another stage)
        reuse it instead of re-scanning the data directory.
        """
        if self.dataset is not None:
            return

        self.dataset = LungDataset()
        generator = torch.Generator().manual_seed(42)

        total_len = len(self.dataset)
        train_len = int(total_len * 0.8)
        val_len = int(total_len * 0.1)
        # The test split absorbs the rounding remainder so the lengths sum to total_len.
        test_len = total_len - (train_len + val_len)

        self.train_data, self.val_data, self.test_data = random_split(
            dataset=self.dataset,
            lengths=[train_len, val_len, test_len],
            generator=generator,
        )

    def _loader(self, data, shuffle=False):
        """Build a DataLoader for ``data`` with the module-wide loader settings."""
        return DataLoader(
            data,
            batch_size=self.batch_size,
            shuffle=shuffle,
            pin_memory=self.pin_memory,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Training loader; reshuffled every epoch."""
        return self._loader(self.train_data, shuffle=True)

    def val_dataloader(self):
        """Validation loader (fixed order)."""
        return self._loader(self.val_data)

    def test_dataloader(self):
        """Test loader (fixed order)."""
        return self._loader(self.test_data)

In [None]:
import torch
import math
import numpy as np

class NCC_vxm(torch.nn.Module):
    """
    Local (over window) normalized cross correlation loss.

    Args:
        win: Window size per spatial axis (sequence of ints). When ``None``,
            a window of 9 is used along every spatial axis.
    """

    def __init__(self, win=None):
        super(NCC_vxm, self).__init__()
        self.win = win

    def forward(self, y_true, y_pred):
        """Return the negated mean squared local NCC between the two inputs.

        Both inputs are laid out as ``[batch, 1, *vol_shape]``: the summation
        filter has a single input channel, so the convolutions below require
        exactly one channel.
        """
        Ii = y_true
        Ji = y_pred

        # number of spatial dimensions (everything after batch and channel)
        ndims = Ii.dim() - 2
        assert ndims in [1, 2, 3], "volumes should be 1 to 3 dimensions. found: %d" % ndims

        # set window size
        win = [9] * ndims if self.win is None else list(self.win)

        # Box filter created on the input's device and dtype so the loss works
        # on CPU as well as GPU and with non-float32 inputs.
        sum_filt = torch.ones([1, 1, *win], dtype=Ii.dtype, device=Ii.device)

        stride = (1,) * ndims
        # Pad each axis by half of its own window so anisotropic windows keep
        # the spatial size too.
        padding = tuple(w // 2 for w in win)

        # get convolution function
        conv_fn = getattr(torch.nn.functional, 'conv%dd' % ndims)

        def window_sum(t):
            # Sum of `t` over the local window around every position.
            return conv_fn(t, sum_filt, stride=stride, padding=padding)

        # local sums of the images, their squares and their product
        I_sum = window_sum(Ii)
        J_sum = window_sum(Ji)
        I2_sum = window_sum(Ii * Ii)
        J2_sum = window_sum(Ji * Ji)
        IJ_sum = window_sum(Ii * Ji)

        win_size = float(np.prod(win))
        u_I = I_sum / win_size
        u_J = J_sum / win_size

        cross = IJ_sum - u_J * I_sum - u_I * J_sum + u_I * u_J * win_size
        I_var = I2_sum - 2 * u_I * I_sum + u_I * u_I * win_size
        J_var = J2_sum - 2 * u_J * J_sum + u_J * u_J * win_size

        # small constant keeps the ratio finite in flat regions
        cc = cross * cross / (I_var * J_var + 1e-5)

        # negated so that higher correlation means lower loss
        return -torch.mean(cc)
    

In [None]:
import torch

class SpatialTransformer(torch.nn.Module):
    """
    N-D Spatial Transformer

    Resamples a source tensor at the locations given by an identity grid plus
    a displacement field.

    Args:
        size: Spatial shape of the volumes to be warped.
        mode: Interpolation mode forwarded to ``grid_sample``.
    """

    def __init__(self, size, mode='bilinear'):
        super().__init__()

        self.mode = mode

        # create sampling grid of voxel indices, shape [1, ndims, *size]
        vectors = [torch.arange(0, s) for s in size]
        # explicit 'ij' keeps the matrix-style ordering the implicit default
        # produced, without the deprecation warning
        grids = torch.meshgrid(*vectors, indexing='ij')
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        grid = grid.float()

        # Registering the grid as a buffer lets it follow the module across
        # devices (.to() / .cuda() / Lightning), so it is no longer forced onto
        # the GPU here. It is kept in the state dict on purpose: passing
        # persistent=False would shrink checkpoints, but checkpoints that
        # already contain a 'grid' entry would then hit an unexpected key on
        # strict loading.
        self.register_buffer('grid', grid)

    def forward(self, src, flow):
        """Warp ``src`` by ``flow`` (displacements in voxel units, one channel per axis)."""
        # new locations
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # need to normalize grid values to [-1, 1] for resampler
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)

        # move channels dim to last position; the coordinate channels are then
        # reversed before being handed to grid_sample
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        return torch.nn.functional.grid_sample(src, new_locs, align_corners=True, mode=self.mode)

class register_model(torch.nn.Module):
    """Thin wrapper that applies a ``SpatialTransformer`` to an ``(image, flow)`` pair.

    Args:
        img_size: Spatial shape handed to the ``SpatialTransformer``.
        mode: Interpolation mode handed to the ``SpatialTransformer``.
    """

    def __init__(self, img_size=(64, 256, 256), mode='bilinear'):
        super(register_model, self).__init__()
        self.spatial_trans = SpatialTransformer(img_size, mode)

    def forward(self, x):
        """Warp ``x[0]`` (image) with ``x[1]`` (flow) and return the result."""
        # Follow whichever device the transformer's grid buffer lives on
        # instead of unconditionally moving the inputs to CUDA.
        device = self.spatial_trans.grid.device
        img = x[0].to(device)
        flow = x[1].to(device)
        return self.spatial_trans(img, flow)

In [None]:
import pytorch_lightning as pl
import torch

# Training configuration shared by the cells below.
max_epoch = 10             # passed to pl.Trainer(max_epochs=...)
lr = 1e-4                  # Adam learning rate, read in configure_optimizers
weights = [1]              # weights[n] scales criterions[n] in training_step
criterions = [NCC_vxm()]   # criterions[n] is applied to the n-th model output

# Seed all RNGs through Lightning and request deterministic cuDNN kernels
# (benchmark auto-tuning is switched off alongside).
pl.seed_everything(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class TransMorphModel(pl.LightningModule):
    """Lightning wrapper that trains ``model`` with the module-level loss configuration.

    Reads the notebook globals ``criterions``, ``weights`` and ``lr``.

    Args:
        model: Network called on each batch; its n-th output is scored by
            ``criterions[n]``.
        stn: Spatial transformer module. It is stored but not used by any
            method of this class.
    """

    def __init__(self, model, stn):
        super().__init__()
        self.model = model
        self.stn = stn

    def training_step(self, batch, batch_idx):
        """Return the weighted sum of all configured losses for one batch."""
        # Always use the last image in a sequence as the target. The sequence
        # axis is the trailing one, so index it with an ellipsis: the previous
        # `batch[:, :, :, -1]` picked axis 3, which is no longer the sequence
        # axis once the DataLoader prepends the batch dimension.
        # NOTE(review): confirm the resulting target shape matches what the
        # model outputs.
        target = batch[..., -1]
        output = self.model(batch)

        loss = 0
        for n, loss_fn in enumerate(criterions):
            loss = loss + loss_fn(output[n], target) * weights[n]
        return loss

    def validation_step(self, batch, batch_idx):
        # TODO: track val loss
        pass

    def configure_optimizers(self):
        """Adam over the wrapped model's parameters with the global learning rate."""
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        return optimizer

In [None]:
from transmorph.models import TransMorph
from transmorph.models.TransMorph import CONFIGS as CONFIGS_TM

# Build the network from the "TransMorph-Tiny" preset and a warping module
# sized from the same config (nearest-neighbour sampling).
config = CONFIGS_TM["TransMorph-Tiny"]
model = TransMorph.TransMorph(config)
reg_model = register_model(config.img_size, 'nearest')

# Wrap both in the Lightning module and create a trainer limited to max_epoch epochs.
plmodel = TransMorphModel(model, reg_model)
trainer = pl.Trainer(max_epochs=max_epoch)

In [None]:
# Data module with single-sample batches and four loader workers.
datamodule = LungDataModule(batch_size=1, num_workers=4, pin_memory=True)
# Lower float32 matmul precision setting; presumably chosen for speed — confirm
# the accuracy trade-off is acceptable.
torch.set_float32_matmul_precision('medium')
trainer.fit(plmodel, datamodule=datamodule)

In [None]:
# Persist the trainer state after the first fit.
trainer.save_checkpoint("first_fit.ckpt")
# To restore later (sketch; TransMorphModel does not save its constructor
# arguments, so they presumably have to be passed again — verify):
# new_model = TransMorphModel.load_from_checkpoint("first_fit.ckpt", model=model, stn=reg_model)