In [2]:
# Import pytorch modules 

import torch
import torch.nn as nn
import torch.nn.functional as F

# Import pytorch lightning modules 

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# Import other modules

# import numpy as np
# import pandas as pd
# import matplotlib.pyplot as plt
# import seaborn as sns
# import os
# import glob
# from PIL import Image
from torchvision.models import resnet18
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
class Encoder(torch.nn.Module): 
    def __init__(self): 
        super().__init__()
        self.resnet = resnet18(pretrained=True)
        self.resnet.avgpool = nn.Identity()
        self.resnet.fc = nn.Identity()
        self.resnet.eval()

    def forward(self, x):
        return self.resnet(x).reshape(-1, 512, 8, 8)
    
class Decoder(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(512, 256, 3, padding=1)
        self.conv2 = nn.Conv2d(256, 128, 3, padding=1)
        self.conv3 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv4 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv5 = nn.Conv2d(32, 16, 3, padding=1)
        self.conv6 = nn.Conv2d(16, 3, 3, padding=1)
        self.conv7 = nn.Conv2d(3, 3, 3, padding=1)

    def forward(self, x):
        '''
            Output shape: (batch_size, 3, 256, 256)
        '''
        x = F.interpolate(x, scale_factor=2)
        x = F.relu(self.conv1(x))
        x = F.interpolate(x, scale_factor=2)
        x = F.relu(self.conv2(x))
        x = F.interpolate(x, scale_factor=2)
        x = F.relu(self.conv3(x))
        x = F.interpolate(x, scale_factor=2)
        x = F.relu(self.conv4(x))
        # x = F.interpolate(x, scale_factor=2)
        x = F.relu(self.conv5(x))
        # x = F.interpolate(x, scale_factor=2)
        x = F.relu(self.conv6(x))
        x = F.interpolate(x, scale_factor=2)
        x = F.relu(self.conv7(x))
        return x
    

In [8]:
class SSLModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.criterion = nn.MSELoss()

    def forward(self, x):
        x = self.corrupt(x)
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    def corrupt(self, x): 
        '''Corrupt the image by choosing random patches and interchanging their values.

        Args:
            x: (3, 256, 256)

        Returns:
            x: (3, 256, 256)S
        '''

        for i in range(10):
            x = self.swap_patches(x)

        return x

    def swap_patches(self, x):
        '''Swap two random patches of size 32x32.

        Args:
            x: (3, 256, 256)

        Returns:
            x: (3, 256, 256)
        '''

        x1, y1 = torch.randint(0, 256-32, (2,))
        x2, y2 = torch.randint(0, 256-32, (2,))
        x[:, x1:x1+32, y1:y1+32], x[:, x2:x2+32, y2:y2+32] = x[:, x2:x2+32, y2:y2+32], x[:, x1:x1+32, y1:y1+32]

        return x

In [6]:
image = torch.rand(1, 3, 256, 256)

encoder = Encoder()
decoder = Decoder()

encoding = encoder(image).reshape(1, 512, 8, 8)
print('Image: ', image.shape)
print('Encoding: ', encoding.shape)
decoding = decoder(encoding)
print('Decoding: ', decoding.shape)

Image:  torch.Size([1, 3, 256, 256])
Encoding:  torch.Size([1, 512, 8, 8])
Decoding:  torch.Size([1, 3, 256, 256])


In [9]:
ssl_module = SSLModule()
ssl_module(image)


tensor([[[[0.0393, 0.0357, 0.0344,  ..., 0.0357, 0.0359, 0.0409],
          [0.0454, 0.0408, 0.0384,  ..., 0.0408, 0.0409, 0.0429],
          [0.0457, 0.0412, 0.0381,  ..., 0.0410, 0.0413, 0.0429],
          ...,
          [0.0464, 0.0400, 0.0370,  ..., 0.0398, 0.0397, 0.0428],
          [0.0466, 0.0404, 0.0381,  ..., 0.0404, 0.0401, 0.0432],
          [0.0498, 0.0477, 0.0459,  ..., 0.0475, 0.0472, 0.0429]],

         [[0.1313, 0.1251, 0.1255,  ..., 0.1268, 0.1264, 0.1258],
          [0.1339, 0.1285, 0.1290,  ..., 0.1309, 0.1304, 0.1297],
          [0.1340, 0.1286, 0.1290,  ..., 0.1312, 0.1307, 0.1300],
          ...,
          [0.1336, 0.1281, 0.1273,  ..., 0.1280, 0.1283, 0.1279],
          [0.1337, 0.1286, 0.1280,  ..., 0.1285, 0.1287, 0.1281],
          [0.1406, 0.1428, 0.1424,  ..., 0.1427, 0.1426, 0.1413]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0

In [None]:
#create a dataloader to load the data
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
dataset = ImageFolder(root='data', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

# train the model
logger = TensorBoardLogger('logs', name='ssl')

In [13]:
# import cv2
# import matplotlib
# matplotlib.use("TkAgg")
# # from matplotlib import pyplot as plt

# # import wx
# # Import matplotlib
# # matplotlib.use('wxAgg')
# from matplotlib import pyplot as plt
# # your scripts
# plt.close('all')

In [10]:
!pip install kaggle

Collecting kaggle
  Downloading kaggle-1.5.13.tar.gz (63 kB)
Collecting python-slugify
  Downloading python_slugify-8.0.1-py2.py3-none-any.whl (9.7 kB)
Collecting text-unidecode>=1.3
  Downloading text_unidecode-1.3-py2.py3-none-any.whl (78 kB)
Building wheels for collected packages: kaggle
  Building wheel for kaggle (setup.py): started
  Building wheel for kaggle (setup.py): finished with status 'done'
  Created wheel for kaggle: filename=kaggle-1.5.13-py3-none-any.whl size=77733 sha256=3651e09721a5f4eb084ead1f9a05262ad1f65f7a950e2d68ff1596ea24f33d31
  Stored in directory: c:\users\megde\appdata\local\pip\cache\wheels\9c\45\15\6d6d116cd2539fb8f450d64b0aee4a480e5366bb11b42ac763
Successfully built kaggle
Installing collected packages: text-unidecode, python-slugify, kaggle
Successfully installed kaggle-1.5.13 python-slugify-8.0.1 text-unidecode-1.3


Error processing line 1 of C:\Users\megde\miniconda3\lib\site-packages\vision-1.0.0-py3.9-nspkg.pth:

  Traceback (most recent call last):
    File "C:\Users\megde\miniconda3\lib\site.py", line 169, in addpackage
      exec(line)
    File "<string>", line 1, in <module>
    File "<frozen importlib._bootstrap>", line 562, in module_from_spec
  AttributeError: 'NoneType' object has no attribute 'loader'

Remainder of file ignored


In [12]:
!pip install opendatasets

Collecting opendatasets
  Downloading opendatasets-0.1.22-py3-none-any.whl (15 kB)
Installing collected packages: opendatasets
Successfully installed opendatasets-0.1.22


Error processing line 1 of C:\Users\megde\miniconda3\lib\site-packages\vision-1.0.0-py3.9-nspkg.pth:

  Traceback (most recent call last):
    File "C:\Users\megde\miniconda3\lib\site.py", line 169, in addpackage
      exec(line)
    File "<string>", line 1, in <module>
    File "<frozen importlib._bootstrap>", line 562, in module_from_spec
  AttributeError: 'NoneType' object has no attribute 'loader'

Remainder of file ignored


In [13]:
import opendatasets
opendatasets.download('https://www.kaggle.com/input/brats20-dataset-training-validation')

Please provide your Kaggle credentials to download this dataset. Learn more: http://bit.ly/kaggle-creds
Your Kaggle username:Your Kaggle Key:

ApiException: (401)
Reason: Unauthorized
HTTP response headers: HTTPHeaderDict({'Content-Length': '0', 'Date': 'Tue, 21 Mar 2023 14:32:26 GMT', 'Access-Control-Allow-Credentials': 'true', 'Set-Cookie': 'ka_sessionid=d151bb6b127738b512a1ea6a0ab7d603; max-age=2626560; path=/, GCLB=CMWvmfjPqZOYPQ; path=/; HttpOnly', 'Turbolinks-Location': 'https://www.kaggle.com/api/v1/datasets/download/input/brats20-dataset-training-validation', 'Strict-Transport-Security': 'max-age=63072000; includeSubDomains; preload', 'Content-Security-Policy': "object-src 'none'; script-src 'nonce-mS684sV0gU1gvLVLfQBWDg==' 'report-sample' 'unsafe-inline' 'unsafe-eval' 'strict-dynamic' https: http:; frame-src 'self' https://www.kaggleusercontent.com https://www.youtube.com/embed/ https://polygraph-cool.github.io https://www.google.com/recaptcha/ https://form.jotform.com https://submit.jotform.us https://submit.jotformpro.com https://submit.jotform.com https://www.docdroid.com https://www.docdroid.net https://kaggle-static.storage.googleapis.com https://kaggle-static-staging.storage.googleapis.com https://kkb-dev.jupyter-proxy.kaggle.net https://kkb-staging.jupyter-proxy.kaggle.net https://kkb-production.jupyter-proxy.kaggle.net https://kkb-dev.firebaseapp.com https://kkb-staging.firebaseapp.com https://kkb-production.firebaseapp.com https://kaggle-metastore-test.firebaseapp.com https://kaggle-metastore.firebaseapp.com https://apis.google.com https://content-sheets.googleapis.com/ https://accounts.google.com/ https://storage.googleapis.com https://docs.google.com https://drive.google.com https://calendar.google.com/; base-uri 'none'; report-uri https://csp.withgoogle.com/csp/kaggle/20201130;", 'X-Content-Type-Options': 'nosniff', 'Referrer-Policy': 'strict-origin-when-cross-origin', 'Via': '1.1 google', 'Alt-Svc': 'h3=":443"; ma=2592000,h3-29=":443"; ma=2592000'})


In [None]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import os
from PIL import Image

root = '/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData'

class BraTS2020_Dataset(Dataset): 
    def __init__(self): 
        self.modality_dict = {0: 'flair', 1: 't1', 2: 't1ce', 3: 't2'}
        pass 

    def __len__(self):
        return len(os.listdir(root))*4

    def __getitem__(self, idx):
        patient_id, modality = divmod(idx, 4)
        patient_id = str(patient_id).zfill(3)
        modality = self.modality_dict[modality]
        path = os.path.join(root, f'BraTS20_Training_{patient_id}', f'BraTS20_Training_{patient_id}_{modality}.nii')
        
        img = Image.open(path)
        return img


train_dataset = BraTS2020_Dataset()
print('train_dataset',train_dataset)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)