In [1]:
import os
import torch
from torch import optim, nn, utils, Tensor
from torchvision.transforms import ToTensor
import pytorch_lightning as pl
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import monai
from unet import UNet

In [2]:
from torch.utils.data import Dataset
import pandas as pd
import SimpleITK as sitk
from torchvision.transforms.functional import resize, center_crop
from torchvision.transforms import InterpolationMode
from vicreg import VICReg, train_vicreg

class CamusDataset(Dataset):
    def __init__(self, data_path, image_size=(512, 512)):
        super().__init__()
        self.root = data_path

        self.data_list = []
        self.image_size = image_size

        self.num_imgs = len(os.listdir(self.root))

    def __len__(self):
        return self.num_imgs

    def __getitem__(self, idx):

        chambers = '2CH' if torch.rand(1) > 0.5 else '4CH'

        path = os.path.join(self.root, f'patient{idx+1:04d}', f'patient{idx+1:04d}_{chambers}_sequence')

        image_sitk = sitk.ReadImage(f'{path}.mhd', sitk.sitkFloat32)

        # get pixel spacing to correct aspect ratio
        spacing = image_sitk.GetSpacing()
        aspect_ratio = spacing[1]/spacing[0]

        # convert to numpy

        first_frame_index = torch.randint(0, image_sitk.GetSize()[2] - 2, (1,))

        image_1 = sitk.GetArrayFromImage(image_sitk)[first_frame_index] / 255
        image_2 = sitk.GetArrayFromImage(image_sitk)[first_frame_index + 1] / 255



        # compute aspect ratio of pixel(mm) and image(pixels)
        pixel_aspect = spacing[1] / spacing[0]
        image_aspect = image_sitk.GetHeight() / image_sitk.GetWidth()

        # preprocess image and mask

        image_1, image_2 = torch.Tensor(image_1).unsqueeze(0), torch.Tensor(image_2).unsqueeze(0)
        size =  (self.image_size[0], int(image_1.shape[2]*image_aspect*pixel_aspect))

        image_1  = resize(image_1, size, interpolation=InterpolationMode.BICUBIC)
        image_2 = resize(image_2, size, interpolation=InterpolationMode.NEAREST)

        image_1, image_2 = center_crop(image_1, self.image_size), center_crop(image_2, self.image_size)

        image = torch.cat((image_1, image_2), dim=0)

        return image

In [3]:
dataset = CamusDataset('../data/training')
model = UNet(n_channels=1, n_classes=4, bilinear=False, scaling=4)
vicreg = VICReg(model)
loader = utils.data.DataLoader(dataset, batch_size=5, shuffle=True)
optimizer = optim.Adam(vicreg.parameters(), lr=5e-5)

In [4]:
vicreg, losses = train_vicreg(vicreg, loader, optimizer, 10)

Starting epoch 0
Batch 0 loss: 94.20459747314453
Batch 5 loss: 73.82655334472656
Batch 10 loss: 62.79039001464844
Batch 15 loss: 55.23118591308594
Batch 20 loss: 44.032386779785156
Batch 25 loss: 39.59077453613281
Batch 30 loss: 33.03168869018555
Batch 35 loss: 31.54857635498047
Batch 40 loss: 31.135936737060547
Batch 45 loss: 26.35281753540039
Batch 50 loss: 26.801044464111328
Batch 55 loss: 23.837892532348633
Batch 60 loss: 24.75250244140625
Batch 65 loss: 27.436601638793945
Batch 70 loss: 24.447341918945312
Batch 75 loss: 24.478166580200195
Batch 80 loss: 24.483457565307617
Batch 85 loss: 23.345046997070312
Starting epoch 1
Batch 0 loss: 23.533626556396484
Batch 5 loss: 23.666913986206055
Batch 10 loss: 23.509546279907227
Batch 15 loss: 23.921009063720703
Batch 20 loss: 23.253942489624023
Batch 25 loss: 23.20338249206543
Batch 30 loss: 23.162216186523438
Batch 35 loss: 23.9178524017334
Batch 40 loss: 23.756580352783203
Batch 45 loss: 23.528400421142578
Batch 50 loss: 23.062162399291

In [5]:
torch.save(vicreg.encoder.state_dict(), 'models/vicreg_encoder.pth')