In [1]:
import os
import torch
from torch import optim, nn, utils, Tensor
from torchvision.transforms import ToTensor
import pytorch_lightning as pl
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import monai
from unet import UNet

In [2]:
from torch.utils.data import Dataset
import pandas as pd
import SimpleITK as sitk
from torchvision.transforms.functional import resize, center_crop
from torchvision.transforms import InterpolationMode
from vicreg import VICReg, train_vicreg

class CamusDataset(Dataset):
    def __init__(self, data_path, image_size=(512, 512)):
        super().__init__()
        self.root = data_path

        self.data_list = []
        self.image_size = image_size

        self.num_imgs = len(os.listdir(self.root))

    def __len__(self):
        return self.num_imgs

    def __getitem__(self, idx):

        chambers = '2CH' if torch.rand(1) > 0.5 else '4CH'

        path = os.path.join(self.root, f'patient{idx+1:04d}', f'patient{idx+1:04d}_{chambers}_sequence')

        image_sitk = sitk.ReadImage(f'{path}.mhd', sitk.sitkFloat32)

        # get pixel spacing to correct aspect ratio
        spacing = image_sitk.GetSpacing()
        aspect_ratio = spacing[1]/spacing[0]

        # convert to numpy

        first_frame_index = torch.randint(0, image_sitk.GetSize()[2] - 2, (1,))

        image_1 = sitk.GetArrayFromImage(image_sitk)[first_frame_index] / 255
        image_2 = sitk.GetArrayFromImage(image_sitk)[first_frame_index + 1] / 255



        # compute aspect ratio of pixel(mm) and image(pixels)
        pixel_aspect = spacing[1] / spacing[0]
        image_aspect = image_sitk.GetHeight() / image_sitk.GetWidth()

        # preprocess image and mask

        image_1, image_2 = torch.Tensor(image_1).unsqueeze(0), torch.Tensor(image_2).unsqueeze(0)
        size =  (self.image_size[0], int(image_1.shape[2]*image_aspect*pixel_aspect))

        image_1  = resize(image_1, size, interpolation=InterpolationMode.BICUBIC)
        image_2 = resize(image_2, size, interpolation=InterpolationMode.NEAREST)

        image_1, image_2 = center_crop(image_1, self.image_size), center_crop(image_2, self.image_size)

        image = torch.cat((image_1, image_2), dim=0)

        return image

In [3]:
dataset = CamusDataset('../data/training')
model = UNet(n_channels=1, n_classes=4, bilinear=False, scaling=2)
vicreg = VICReg(model)
loader = utils.data.DataLoader(dataset, batch_size=3, shuffle=True)
optimizer = optim.Adam(vicreg.parameters(), lr=5e-5)

In [4]:
vicreg, losses = train_vicreg(vicreg, loader, optimizer, 10)

Starting epoch 0
Batch 0 loss: 198.98565673828125
Batch 5 loss: 152.1251220703125
Batch 10 loss: 134.63430786132812
Batch 15 loss: 103.71627807617188
Batch 20 loss: 96.26701354980469
Batch 25 loss: 69.46277618408203
Batch 30 loss: 78.04834747314453
Batch 35 loss: 48.51445007324219
Batch 40 loss: 53.67090606689453
Batch 45 loss: 91.96554565429688
Batch 50 loss: 37.54883575439453
Batch 55 loss: 35.15755081176758
Batch 60 loss: 28.968894958496094
Batch 65 loss: 42.093666076660156
Batch 70 loss: 40.1080322265625
Batch 75 loss: 31.390186309814453
Batch 80 loss: 26.47576332092285
Batch 85 loss: 28.208036422729492
Batch 90 loss: 26.250574111938477
Batch 95 loss: 29.65609359741211
Batch 100 loss: 27.12096405029297
Batch 105 loss: 25.896150588989258
Batch 110 loss: 25.378459930419922
Batch 115 loss: 25.56060791015625
Batch 120 loss: 35.1664924621582
Batch 125 loss: 27.330699920654297
Batch 130 loss: 27.975772857666016
Batch 135 loss: 24.657821655273438
Batch 140 loss: 23.884981155395508
Batch 1

In [5]:
torch.save(vicreg.encoder.state_dict(), 'models/vicreg_encoder_big.pth')