Welcome to my contrastive learning demo using MONAI's modules.

In [1]:
import monai
import monai.transforms as M
import torch
from torchvision import transforms as T
from PIL import Image
from numpy import asarray
from monai.data import Dataset, DataLoader, CacheDataset, PersistentDataset, SmartCacheDataset
from monai_model import CompleteNet, CNNBackbone
from monai_train import predict, train_one_step, train_one_epoch, train
from monai_dataloader import custom_collate

VERY IMPORTANT: MONAI image readers swap axes 0 and 1 after loading the array, because ``reverse_indexing`` defaults to ``True`` ("because the spatial axes definition for non-medical specific file formats is different from other common medical packages"). That is why I set reverse_indexing=False below. I also set image_only=True so that LoadImage returns only the image rather than an (image, metadata) tuple. Finally, we have to remember to add a channel dimension: MONAI transforms expect the channel before the spatial dimensions, and MONAI's LoadImage does not add this channel for a grayscale image, so we add it with EnsureChannelFirst().
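To make the two axis conventions concrete, here is a small NumPy sketch that needs no image file. It uses a dummy array shaped like the OCT scan loaded below (496 x 512) and only imitates the effects described above: reversed indexing amounts to swapping the two spatial axes, and adding a channel means inserting a length-1 axis at the front. The array and variable names are illustrative, not MONAI internals.

```python
import numpy as np

# Dummy grayscale image in NumPy convention: (height, width) = (rows, columns)
image = np.arange(496 * 512, dtype=np.float32).reshape(496, 512)

# reverse_indexing=True behaves like swapping axes 0 and 1: (width, height)
reversed_view = image.T
print(reversed_view.shape)  # (512, 496)

# reverse_indexing=False keeps the NumPy convention
print(image.shape)  # (496, 512)

# A leading channel axis gives the channel-first layout MONAI transforms expect
channel_first = image[np.newaxis, ...]
print(channel_first.shape)  # (1, 496, 512)

# The pixel at row r, column c ends up at [c, r] after the swap
r, c = 10, 300
assert reversed_view[c, r] == image[r, c]
```

With the swap, every later crop or resize would operate on a transposed image, which is easy to miss on a roughly square OCT scan.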

In [2]:
# These two functions help load in the image and ensure it has the extra channel dimension.
load_image_monai = M.LoadImage(reader='pilreader', image_only=True, reverse_indexing=False) 
ensureChannel = M.EnsureChannelFirst()

# These are just some of the many MONAI transformations we can play with
resizeCrop = M.ResizeWithPadOrCrop((50, 50))
randWeightCrop = M.RandWeightedCrop((50,50))
randSpatialCrop = M.RandSpatialCrop(roi_size=(200,200), random_size=False)
randRotate = M.RandRotate(range_x = 0.5, prob = 1.0)
resize = M.Resize((200, 200))
normalize = M.NormalizeIntensity()

# Here are the transformations I decided on using in my code
composedTransform = M.Compose([randSpatialCrop, randRotate])
identityTransform = M.Compose([resize])

In [3]:
y = load_image_monai('/Users/hairanliang/Downloads/NORMAL-6477461-4.jpeg') 

In [4]:
y.shape # No channel dimension yet; we add it later with MONAI's EnsureChannelFirst.

torch.Size([496, 512])

In [5]:
y.type # Note: this displays the bound method MetaTensor.type; type(y) would show the class itself.

<function MetaTensor.type>

One advantage of MONAI's LoadImage: it returns a tensor (a MetaTensor) directly. In my previous code, I loaded the image with PIL and then had to convert it with ToTensor(). This saves a step, but I have to remember to add the channel dimension with ensureChannel, because ToTensor() added it for me in my original code. The cells below show the difference: ToTensor(), used in my original code, provides that extra dimension.
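The shape difference, plus one more difference worth knowing, can be reproduced with plain PyTorch. For an 8-bit grayscale (mode "L") image, torchvision's ToTensor() returns a float tensor of shape (1, H, W) rescaled to [0, 1], while LoadImage with its default float32 dtype keeps the raw 0 to 255 intensities (the values printed later in this notebook are in that range) and, for a 2D image, has no channel axis. The sketch below imitates both on a synthetic uint8 array; `totensor_like` and `loadimage_like` are hypothetical helpers, not library functions.

```python
import numpy as np
import torch

def totensor_like(gray_uint8: np.ndarray) -> torch.Tensor:
    """Hypothetical stand-in for ToTensor() on a mode-"L" image:
    adds a channel axis and rescales 0..255 to 0..1."""
    return torch.from_numpy(gray_uint8).unsqueeze(0).float().div(255.0)

def loadimage_like(gray_uint8: np.ndarray) -> torch.Tensor:
    """Hypothetical stand-in for LoadImage(image_only=True, reverse_indexing=False):
    float32 cast, raw intensities, no channel axis."""
    return torch.from_numpy(gray_uint8.astype(np.float32))

# Synthetic 8-bit image with the same size as the OCT scan (496 x 512)
rng = np.random.default_rng(0)
gray = rng.integers(0, 256, size=(496, 512), dtype=np.uint8)

a = totensor_like(gray)
b = loadimage_like(gray)
print(a.shape)  # torch.Size([1, 496, 512])
print(b.shape)  # torch.Size([496, 512])

# Adding the channel back (EnsureChannelFirst's job here) and rescaling makes them match
assert torch.allclose(b.unsqueeze(0) / 255.0, a)
```

So switching loaders changes the input scale as well as the shape; if the original pipeline was tuned on [0, 1] inputs, a ScaleIntensity or NormalizeIntensity step may be worth adding.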

In [6]:
def augment_data(data, augmentation=True):
    if augmentation:
        augmented_data = composedTransform(data) 
    else:
        augmented_data = identityTransform(data)
    return augmented_data # This should be a tensor
    
def load_image(image_link):
    image = Image.open(image_link)
    image.show()
    
def load_data(image_link):
    image = Image.open(image_link)
    return image

In [7]:
y_orig = load_data('/Users/hairanliang/Downloads/NORMAL-6477461-4.jpeg')

In [8]:
toTens = T.ToTensor()
y_tens = toTens(y_orig)

In [9]:
y_tens.shape # Here, we see the channel dimension that we need, since it comes from ToTensor()

torch.Size([1, 496, 512])

Now I implement my new get_item, which loads one image from its link, adds the channel dimension with ensureChannel, and applies one round of augmentation. The dataset class further down calls it twice per image and stacks the two augmented tensors. The key addition is ensureChannel, which guarantees the channel dimension; the code for augment_data stays the same.
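The division of labour (one augmentation per get_item call, two calls stacked by the dataset) can be checked without any image files. Below is a minimal pure-PyTorch sketch: `random_crop` is a hypothetical stand-in for the MONAI RandSpatialCrop + RandRotate pipeline (only the crop is imitated), and `two_views` mirrors what the OCTDataset class below does with its two get_item calls.

```python
import torch

def random_crop(img: torch.Tensor, size: int, generator: torch.Generator) -> torch.Tensor:
    """Hypothetical stand-in for RandSpatialCrop(roi_size=(size, size), random_size=False).
    img is channel-first: (C, H, W)."""
    _, h, w = img.shape
    top = int(torch.randint(0, h - size + 1, (1,), generator=generator))
    left = int(torch.randint(0, w - size + 1, (1,), generator=generator))
    return img[:, top:top + size, left:left + size]

def two_views(img: torch.Tensor, size: int = 200, seed: int = 0) -> torch.Tensor:
    """Augment the same channel-first image twice and stack: (2, C, size, size)."""
    g = torch.Generator().manual_seed(seed)
    view1 = random_crop(img, size, g)
    view2 = random_crop(img, size, g)
    return torch.stack((view1, view2), dim=0)

# A (1, 496, 512) tensor stands in for load_image_monai + ensureChannel
img = torch.rand(1, 496, 512)
pair = two_views(img)
print(pair.shape)  # torch.Size([2, 1, 200, 200])
```

Because the two calls draw independent random crops, the stacked tensor holds two different views of the same image, which is exactly the positive pair a contrastive loss needs.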

In [10]:
def get_item(link, augmentation=True):
    x = load_image_monai(link)
    x = ensureChannel(x)
    aug_x = augment_data(x, augmentation)
    return aug_x

What makes MONAI different: its Dataset has a "_transform" method, which is a neat place to apply a transform. My case is a bit unusual because I need two augmented views of each image, so instead of passing a transform in, I moved the code that used to live in my __getitem__ into _transform (the line data_i = self.data[index] inside _transform is the important part).

A MONAI-specific quirk: __getitem__ should not do the indexing; that is left to _transform.
When I tried indexing within __getitem__, I got the first character of the first link instead of the first link.
This is likely because _transform is responsible for retrieving the item (the link) at that index,
whereas PyTorch's convention is that __getitem__ itself fetches the item at the given index.
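To make the indexing issue concrete, here is a simplified pure-Python model of how monai.data.Dataset dispatches: its __getitem__ forwards an integer index to _transform(index), which does data_i = self.data[index] and then applies self.transform if one was passed to the constructor. The classes and paths below are illustrative placeholders, not MONAI's code. One way to get exactly the "first character" symptom is to index a second time on top of that, which yields self.data[0][0].

```python
class MiniMonaiDataset:
    """Simplified model of monai.data.Dataset's dispatch (not MONAI's actual code)."""

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def _transform(self, index):
        data_i = self.data[index]  # the indexing happens here
        return self.transform(data_i) if self.transform is not None else data_i

    def __getitem__(self, index):
        return self._transform(index)  # __getitem__ only forwards the index


class DoubleIndexed(MiniMonaiDataset):
    """Hypothetical buggy subclass: indexes again after _transform already did."""

    def __getitem__(self, index):
        return self._transform(index)[index]


links = ["/data/NORMAL-a.jpeg", "/data/NORMAL-b.jpeg"]  # placeholder paths
print(MiniMonaiDataset(links)[0])                        # /data/NORMAL-a.jpeg
print(DoubleIndexed(links)[0])                           # / (first character of the first link)
print(MiniMonaiDataset(links, transform=str.upper)[1])   # /DATA/NORMAL-B.JPEG
```

The last line shows the usual MONAI pattern: pass a callable as transform= and let _transform do the indexing. OCTDataset instead overrides _transform directly, which also works as long as nothing indexes twice.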

In [11]:
class OCTDataset(Dataset):
    def __init__(self, image_list, augmentation_mode=False):
        self.data = image_list # data = image_list
        self.transform = None
        self.augmentation_mode = augmentation_mode
        
    def __len__(self):
        return len(self.data)
    
    def _transform(self, index):
        data_i = self.data[index]
        if self.augmentation_mode == True:
            aug_x1 = get_item(data_i, self.augmentation_mode)
            aug_x2 = get_item(data_i, self.augmentation_mode)
            aug_stack = torch.stack((aug_x1, aug_x2), dim=0)
            return aug_stack
        else:
            aug_x = get_item(data_i, self.augmentation_mode)
            return aug_x
    
    def __getitem__(self, index):
        return self._transform(index)
    


Now I begin to define my links, datasets, and dataloader.

In [12]:
image_links_1 = ['/Users/hairanliang/Downloads/NORMAL-6477461-4.jpeg', 
               '/Users/hairanliang/Downloads/NORMAL-3767173-12.jpeg']
image_links_2 = ['/Users/hairanliang/Downloads/NORMAL-9453329-20.jpeg',
              '/Users/hairanliang/Downloads/NORMAL-7021113-21.jpeg']

batch_test = ['/Users/hairanliang/Downloads/NORMAL-6477461-4.jpeg', 
               '/Users/hairanliang/Downloads/NORMAL-3767173-12.jpeg', 
             '/Users/hairanliang/Downloads/NORMAL-9453329-20.jpeg',
              '/Users/hairanliang/Downloads/NORMAL-7021113-21.jpeg',]

In [13]:
dataset = OCTDataset(batch_test, True) # Remember, True means augmentation_mode is on.

In [14]:
len(dataset)

4

In [15]:
dataset[0]

metatensor([[[[69.3208, 67.4907, 66.0000,  ..., 51.6992, 51.1892, 50.6792],
          [68.8108, 68.5106, 66.0000,  ..., 50.3521, 51.1170, 51.8820],
          [68.3008, 69.5306, 66.0000,  ..., 53.1686, 53.6786, 54.1885],
          ...,
          [14.0000, 14.0000, 14.0000,  ...,  7.3257,  2.7041,  2.0000],
          [11.3913,  9.6064,  7.8215,  ...,  5.2858,  4.2341,  2.0000],
          [ 7.6792,  8.1892,  8.6992,  ...,  3.2458,  5.7640,  2.0000]]],


        [[[38.1337, 57.6485, 44.3408,  ..., 70.0793, 65.9265, 63.2725],
          [39.2272, 47.9264, 51.2621,  ..., 63.2012, 64.1240, 65.5155],
          [45.2256, 37.3139, 57.7191,  ..., 64.9756, 74.8823, 85.0335],
          ...,
          [27.2123, 17.9841,  9.0489,  ...,  9.9389,  5.5029, 27.8111],
          [ 9.4684, 10.7520, 12.5977,  ...,  9.1016,  7.8101, 14.4302],
          [12.4550, 13.9755, 15.3598,  ...,  8.1788,  9.9532,  5.6812]]]])

In [16]:
# This checks that my __getitem__/_transform work within the Dataset. This was surprisingly non-trivial until I
# realized that _transform, not __getitem__ as in PyTorch's Dataset class, should do the indexing.
for item in dataset:
    print(item)

metatensor([[[[ 71.7970,  79.1828,  86.0196,  ...,  39.7430,  40.7276,  41.7122],
          [ 70.4186,  77.6075,  84.8381,  ...,  43.7420,  43.3481,  42.9543],
          [ 69.0401,  76.0320,  83.6565,  ...,  42.2188,  42.6127,  43.0065],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.5471,  20.9041,  17.1887],
          [  2.3856,   3.3704,   4.3550,  ...,   1.5318,  16.1777,  18.7640],
          [ 13.6937,  17.4354,  21.1768,  ...,   2.5163,  11.4514,  20.3395]]],


        [[[ 66.3475,  70.3470,  73.7599,  ..., 160.5477, 166.0163, 165.3516],
          [ 66.2842,  70.2837,  73.7124,  ..., 138.1484, 132.5853, 132.1580],
          [ 65.2970,  68.1041,  69.8062,  ..., 137.9533, 127.2787, 127.4527],
          ...,
          [  2.0000,   2.0000,   7.7365,  ...,  13.8628,  19.0148,  12.9724],
          [  1.5867,   1.5709,   5.4858,  ...,  11.4380,  21.2872,  16.5657],
          [  0.5869,   0.5710,   3.8145,  ...,  11.2009,  21.2398,  16.8033]]]])
metatensor([[[[ 19.4002

First, I check that my MONAI-based dataset works with PyTorch's DataLoader. It should, since monai.data.Dataset is a subclass of torch.utils.data.Dataset.

In [17]:
data_train = torch.utils.data.DataLoader(dataset, 2, shuffle=True, collate_fn=custom_collate)

In [18]:
batch = next(iter(data_train))

In [19]:
print(batch.shape) # The shape is what I expect: batch size 2, so 2 original images yield 4 augmented images.
# The 1 is the channel dimension, and each image is 200x200 after our transformations.

torch.Size([4, 1, 200, 200])
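custom_collate comes from monai_dataloader and its code isn't shown here. The printed shape is consistent with a collate function that concatenates each item's (2, 1, 200, 200) pair along dim 0; PyTorch's default collate would instead stack the items into (2, 2, 1, 200, 200). A minimal sketch of such a concatenating collate, with a toy dataset in place of OCTDataset (both hypothetical, not the author's implementation):

```python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate

def concat_collate(batch):
    """Hypothetical collate: each item is (2, C, H, W); concatenate -> (2 * B, C, H, W).
    Views of the same image stay adjacent: rows 2k and 2k + 1."""
    return torch.cat(batch, dim=0)

class PairDataset(Dataset):
    """Toy dataset returning a (2, 1, 200, 200) pair per item, filled with its index."""
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, index):
        return torch.full((2, 1, 200, 200), float(index))

loader = DataLoader(PairDataset(4), batch_size=2, shuffle=False, collate_fn=concat_collate)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([4, 1, 200, 200])

default_batch = default_collate([PairDataset(4)[i] for i in range(2)])
print(default_batch.shape)  # torch.Size([2, 2, 1, 200, 200])
```

Concatenating keeps the batch in the (N, C, H, W) layout the CNN expects, at the cost of making the pairing implicit in the row order.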


In [20]:
for batch_idx, samples in enumerate(data_train):
    print(batch_idx, samples)

0 metatensor([[[[7.1167e+01, 7.5904e+01, 8.4355e+01,  ..., 5.0818e+01,
           4.9859e+01, 4.8933e+01],
          [7.2146e+01, 7.4306e+01, 8.1478e+01,  ..., 4.8317e+01,
           4.7678e+01, 4.7038e+01],
          [7.3424e+01, 7.2707e+01, 7.8600e+01,  ..., 4.6711e+01,
           4.6391e+01, 4.6072e+01],
          ...,
          [2.4495e+01, 1.7781e+01, 1.1067e+01,  ..., 1.6445e+01,
           1.5659e+01, 1.2969e+01],
          [5.1343e+00, 7.3722e+00, 9.6103e+00,  ..., 1.8043e+01,
           1.5339e+01, 1.4568e+01],
          [1.1767e+01, 1.4005e+01, 1.6243e+01,  ..., 1.9642e+01,
           1.5019e+01, 1.5967e+01]]],


        [[[1.0700e+02, 1.0700e+02, 1.0700e+02,  ..., 8.4158e-01,
           1.0328e+01, 5.0186e-01],
          [1.1263e+02, 1.1158e+02, 1.1053e+02,  ..., 2.7671e+00,
           8.4027e+00, 2.4272e+00],
          [1.1023e+02, 1.1075e+02, 1.1128e+02,  ..., 4.6927e+00,
           6.4770e+00, 4.3531e+00],
          ...,
          [1.1519e+01, 9.8883e+00, 1.8014e+01,  ...

In [21]:
from monai.networks.blocks import Convolution, MaxAvgPool

In [22]:
import torch.nn as nn

conv1 = Convolution(
    spatial_dims=2,
    in_channels=1,
    out_channels=64,
    kernel_size = (5,5),
    adn_ordering="NDA",
    act=("prelu", {"init": 0.2}),
    dropout=0.1
)
print(conv1)

conv2 = Convolution(
    spatial_dims=2,
    in_channels=64,
    out_channels=64,
    kernel_size = (5,5),
    adn_ordering="NDA",
    act=("prelu", {"init": 0.2}),
    dropout=0.1
)
print(conv2)

pool = nn.MaxPool2d(kernel_size=2, stride=2)

global_pool = nn.AdaptiveAvgPool2d(1)

Convolution(
  (conv): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (adn): ADN(
    (N): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (D): Dropout(p=0.1, inplace=False)
    (A): PReLU(num_parameters=1)
  )
)
Convolution(
  (conv): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (adn): ADN(
    (N): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (D): Dropout(p=0.1, inplace=False)
    (A): PReLU(num_parameters=1)
  )
)
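The printed structure shows what the MONAI block wraps: a 5 x 5 convolution with padding 2, followed, in "NDA" order, by instance normalization, dropout, and a PReLU activation. For reference, a plain-PyTorch equivalent of conv1 looks roughly like this; it is a sketch reproducing the printed modules, not MONAI's code. With a 5 x 5 kernel, stride 1, and padding 2, the spatial size is unchanged.

```python
import torch
import torch.nn as nn

# Plain-PyTorch counterpart of the printed conv1 block (ordering "NDA")
conv1_plain = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2),            # "same" padding for 5x5
    nn.InstanceNorm2d(64, affine=False, track_running_stats=False),  # N
    nn.Dropout(p=0.1),                                               # D
    nn.PReLU(num_parameters=1, init=0.2),                            # A: act=("prelu", {"init": 0.2})
)

x = torch.rand(2, 1, 64, 64)  # small dummy batch, channel-first
print(conv1_plain(x).shape)   # torch.Size([2, 64, 64, 64])

# Trainable parameters: 64*1*5*5 conv weights + 64 biases + 1 PReLU slope
print(sum(p.numel() for p in conv1_plain.parameters()))  # 1665
```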


In [23]:
import torch.nn.functional as F

class CompleteNet(nn.Module):
    def __init__(self, backbone):
        super(CompleteNet, self).__init__()
        self.backbone = backbone # This is the CNN
        self.fc1 = nn.Linear(64, 80) # Fully connected head: 64 pooled backbone features -> 80
        self.fc2 = nn.Linear(80, 40)
    
    def forward(self, x):
        x = self.backbone(x)
        x = x.view(-1, 64) # (N, 64, 1, 1) -> (N, 64)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
class CNNBackbone(nn.Module):
    def __init__(self):
        super(CNNBackbone, self).__init__()
        self.conv1 = Convolution(
                spatial_dims=2,
                in_channels=1,
                out_channels=64,
                kernel_size = (5,5),
                adn_ordering="NDA",
                act=("prelu", {"init": 0.2}),
                dropout=0.1
                )
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.conv2 = Convolution(
                spatial_dims=2,
                in_channels=64,
                out_channels=64,
                kernel_size = (5,5),
                adn_ordering="NDA",
                act=("prelu", {"init": 0.2}),
                dropout=0.1
                )
    def forward(self, x):
        x = self.pool(self.conv1(x))
        x = self.pool(self.conv2(x))
        x = self.global_pool(x)
        return x
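CompleteNet's x.view(-1, 64) works because the backbone ends in a global average pool that leaves an (N, 64, 1, 1) tensor. The spatial sizes along the way follow the standard output-size formula, out = floor((in + 2*padding - kernel) / stride) + 1. The small helper below (hypothetical, for checking by hand) walks a 200 x 200 input through conv1, pool, conv2, pool.

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of a conv/pool along one spatial axis (dilation 1, floor mode)."""
    return (size + 2 * padding - kernel) // stride + 1

size = 200                                            # after RandSpatialCrop / Resize
size = out_size(size, kernel=5, stride=1, padding=2)  # conv1 -> 200
size = out_size(size, kernel=2, stride=2)             # pool  -> 100
size = out_size(size, kernel=5, stride=1, padding=2)  # conv2 -> 100
size = out_size(size, kernel=2, stride=2)             # pool  -> 50
print(size)  # 50

# AdaptiveAvgPool2d(1) reduces 64 x 50 x 50 to 64 x 1 x 1,
# so each image flattens to 64 features, matching nn.Linear(64, 80).
features = 64 * 1 * 1
print(features)  # 64
```

One design note: if the backbone ever produced more than one spatial position, view(-1, 64) would silently fold those positions into the batch dimension, while torch.flatten(x, 1) would keep the batch size and make the mismatch with nn.Linear(64, 80) raise an error instead.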

In [35]:
# Initializing the test model 

backboneTest = CNNBackbone()
modelTest = CompleteNet(backboneTest)
learning_rate = 0.5
optimizer = torch.optim.SGD(modelTest.parameters(), lr=learning_rate)

In [36]:
samples.shape

torch.Size([4, 1, 200, 200])

In [55]:
train_one_step(samples, modelTest, optimizer)

metatensor(1.0977, grad_fn=<AliasBackward0>)

Training works with the MONAI Convolution layers swapped in. Next, I can try a ResNet model, although that requires changing other parameters as well.

MONAI seems to offer a lot of customization within its Convolution block, for instance: you have the option of adding normalization, activation, dropout, etc.
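MONAI's Convolution exposes these options as constructor arguments (norm, act, dropout, and adn_ordering, which sets the order of the Norm/Dropout/Activation layers after the convolution). To illustrate what the ordering string controls, here is a hypothetical plain-PyTorch builder, `build_block`, that assembles the same kinds of layers in whatever order a string such as "NDA" or "ADN" specifies. It is a sketch of the idea, not MONAI's implementation.

```python
import torch
import torch.nn as nn

def build_block(in_ch: int, out_ch: int, ordering: str = "NDA",
                dropout: float = 0.1, prelu_init: float = 0.2) -> nn.Sequential:
    """Hypothetical helper: 5x5 conv followed by N/D/A layers in the given order."""
    layers = {
        "N": lambda: nn.InstanceNorm2d(out_ch),
        "D": lambda: nn.Dropout(p=dropout),
        "A": lambda: nn.PReLU(init=prelu_init),
    }
    modules = [nn.Conv2d(in_ch, out_ch, kernel_size=5, padding=2)]
    modules += [layers[letter]() for letter in ordering]
    return nn.Sequential(*modules)

nda = build_block(1, 64, "NDA")
adn = build_block(1, 64, "ADN")
print([type(m).__name__ for m in nda])  # ['Conv2d', 'InstanceNorm2d', 'Dropout', 'PReLU']
print([type(m).__name__ for m in adn])  # ['Conv2d', 'PReLU', 'Dropout', 'InstanceNorm2d']
print(adn(torch.rand(1, 1, 32, 32)).shape)  # torch.Size([1, 64, 32, 32])
```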

Current bugs in the original framework: the training functions don't work if augmentation mode is False. This is fine for now, because we never train with augmentation off anyway, but the problem seems to come from the tensor shapes returned by __getitem__ and handled by custom_collate.
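One plausible cause, assuming custom_collate concatenates items along dim 0 (its code isn't shown here): with augmentation on, each item is (2, 1, 200, 200) and the batch becomes (2B, 1, 200, 200); with augmentation off, each item is a single (1, 200, 200) image, so concatenation gives (B, 200, 200) and the channel axis disappears. A hypothetical fix is to give both modes the same rank by adding a leading "views" axis in the non-augmented path (for example, returning aug_x.unsqueeze(0)). The sketch below checks those shapes with a hypothetical concatenating collate.

```python
import torch

def concat_collate(batch):
    """Hypothetical collate that concatenates items along dim 0."""
    return torch.cat(batch, dim=0)

aug_item = torch.zeros(2, 1, 200, 200)  # two stacked views (augmentation on)
plain_item = torch.zeros(1, 200, 200)   # single resized image (augmentation off)

print(concat_collate([aug_item, aug_item]).shape)      # torch.Size([4, 1, 200, 200])
print(concat_collate([plain_item, plain_item]).shape)  # torch.Size([2, 200, 200]): channel axis lost

# Hypothetical fix: add a leading "views" axis so both modes are (V, C, H, W)
fixed_item = plain_item.unsqueeze(0)                    # (1, 1, 200, 200)
print(concat_collate([fixed_item, fixed_item]).shape)  # torch.Size([2, 1, 200, 200])
```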