Importing packages

In [3]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset, Dataset
from torchvision import datasets
from torchvision import transforms
import os
from os import listdir
from os.path import isfile, join
import cv2
from PIL import Image
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms.functional as fn

Loading the data

In [1]:
!unzip 'ScientificPython.zip'

wget: missing URL
Usage: wget [OPTION]... [URL]...

Try `wget --help' for more options.


Preprocess

In [None]:
# Preprocessing pipeline applied to every loaded image:
#   1. ToTensor turns the HxWxC image array into a CxHxW float tensor
#   2. Resize brings every image to the same 128x128 spatial size
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((128, 128)),
])

In [None]:
# Only the first two folders (in sorted order) are needed: they contain the training data
train_datasource = os.listdir(path = 'ScientificPython/')
train_datasource.sort()
train_datasource = train_datasource[:2]

idx_neuron = 0  # label assigned to every image of the folder currently being read

train_img = []
train_cat = []  # category = 0 or 1 (non-neuron or neuron image), one label per folder
for filename in train_datasource:
  folder_path = os.path.join('ScientificPython', filename)  # we iterate through the folders 0 and 1
  for img_name in sorted(os.listdir(folder_path)):
    img_path = os.path.join(folder_path, img_name)
    img = cv2.imread(img_path)
    if img is None:
      # cv2.imread returns None (it does not raise) when a file cannot be decoded as an image
      print(f'Skipping unreadable file: {img_path}')
      continue
    train_img.append(preprocess(img))  # tensor and resize conversion
    train_cat.append(torch.tensor(idx_neuron))
  idx_neuron += 1

# Stack the per-image tensors into one batch tensor each
train_img = torch.stack(train_img, dim=0)
train_cat = torch.stack(train_cat, dim=0)
train_img.shape, train_cat.shape

# Network

Convolution

In [None]:
class Conv(nn.Module):
  """3x3 convolution followed by ReLU.

  padding=1 with kernel_size=3 keeps the spatial size of the input unchanged,
  only the number of channels goes from in_channels to out_channels.
  """

  def __init__(self, in_channels, out_channels):
    super(Conv, self).__init__()

    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True)
    )

  # forward must be a method of the class (not nested inside __init__),
  # otherwise nn.Module has no forward implementation and calling the module fails
  def forward(self, x):
    """Apply the convolution + ReLU to x and return the result."""
    return self.conv(x)

Downscaling

In [None]:
class Down(nn.Module):
    """Downscaling block: 2x2 max pooling with stride 2, followed by a Conv block."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)  # the stride defaults to the kernel size, i.e. 2
        conv = Conv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool x, then run it through the Conv block."""
        downscaled = self.maxpool_conv(x)
        return downscaled

Upscaling

In [None]:
class Up(nn.Module):
    """Upscaling block: transposed convolution, skip connection, then a Conv block.

    The transposed convolution halves the channel count (in_channels // 2); the
    result is concatenated with the skip tensor along dim 1 and passed to
    Conv(in_channels, out_channels).
    """

    def __init__(self, in_channels, out_channels):
        super(Up, self).__init__()

        half_channels = in_channels // 2
        self.up = nn.ConvTranspose2d(in_channels, half_channels, kernel_size=2, stride=2)
        self.conv = Conv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample x1, pad it to the size of x2 (dims 2 and 3), concatenate and convolve."""
        upsampled = self.up(x1)

        # size difference between the skip tensor and the upsampled tensor
        height_gap = x2.shape[2] - upsampled.shape[2]
        width_gap = x2.shape[3] - upsampled.shape[3]

        # F.pad takes the padding of the last dimension first: [left, right, top, bottom]
        left = width_gap // 2
        top = height_gap // 2
        padding = [left, width_gap - left, top, height_gap - top]
        upsampled = F.pad(upsampled, padding)

        merged = torch.cat((x2, upsampled), dim=1)
        return self.conv(merged)

In [None]:
# constructing the UNet from the building blocks (Conv, Down, Up) defined above
class Unet(nn.Module):
  """Small U-Net: three downscaling levels, three upscaling levels with skip connections.

  Args:
    i_ch_n: number of channels of the input images.
    o_ch_n: number of channels of the returned logits.

  forward(x) returns logits with o_ch_n channels and the same height and
  width as x (every Up block pads its output to the size of its skip tensor).
  """

  def __init__(self, i_ch_n, o_ch_n):
    super(Unet, self).__init__()

    # parameters:
    self.input_channel_num = i_ch_n
    self.output_channel_num = o_ch_n

    # downscaling layers
    self.inc = Conv(self.input_channel_num, 16)
    # down1 has to downscale (Down, not a plain Conv): otherwise x2 keeps the
    # resolution of x1, and up3 upsamples to twice the input size and only gets
    # back to the right size by cropping through negative padding
    self.down1 = Down(16, 32)
    self.down2 = Down(32, 64)
    self.down3 = Down(64, 128)

    # upscaling layers
    self.up1 = Up(128, 64)
    self.up2 = Up(64, 32)
    self.up3 = Up(32, 16)
    self.outconv = nn.Conv2d(16, self.output_channel_num, kernel_size=1)

  def forward(self, x):
      """Run the encoder, then the decoder with skip connections, and return the logits."""
      # encoder: every level keeps its output for the matching skip connection
      x1 = self.inc(x)
      x2 = self.down1(x1)
      x3 = self.down2(x2)
      x4 = self.down3(x3)

      # decoder: upsample and merge with the encoder output of the same level
      x5 = self.up1(x4, x3)
      x6 = self.up2(x5, x2)
      x7 = self.up3(x6, x1)
      logits = self.outconv(x7)

      return logits

In [None]:
# Sanity check: push one random 3-channel 446x446 image through the network
# and display the shape of the output
model = Unet(i_ch_n=3, o_ch_n=3)
test_input = torch.randn(1, 3, 446, 446)
output = model(test_input)
output.shape

In [None]:
#print(model)

# Training

In [None]:
def get_next_batch(img, cat, batch_indexes, size=(450, 450)):
    """Select one batch of images and labels and return them as torch tensors.

    Args:
        img: images stacked along the first dimension, channel-first
            (N, C, H, W); a torch tensor or anything torch.as_tensor accepts.
        cat: labels, one entry per image along the first dimension.
        batch_indexes: integer indexes of the samples that form the batch.
        size: (height, width) the images are resized to with bicubic
            interpolation; None keeps the original size.

    Returns:
        (imgs, cats): float32 images of shape (len(batch_indexes), C, *size)
        and the matching labels, selected in the order of batch_indexes.
    """
    assert len(img) == len(cat)

    index = torch.as_tensor(batch_indexes, dtype=torch.long)

    # index only the first dimension, so 1-D label tensors work as well
    imgs = torch.as_tensor(img)[index]
    cats = torch.as_tensor(cat)[index]

    # interpolation needs a floating point tensor
    imgs = imgs.float()
    if size is not None and tuple(imgs.shape[-2:]) != tuple(size):
        imgs = F.interpolate(imgs, size=tuple(size), mode='bicubic', align_corners=False)

    return imgs, cats

In [None]:
# Hyperparameter set 1 of 3 (the two cells below assign the same names again)
epochs, batch_size, lr = 2, 8, 1e-3

In [None]:
# Hyperparameter set 2: smaller batches and a smaller learning rate
epochs, batch_size, lr = 2, 4, 1e-4

In [None]:
# Hyperparameter set 3: the smallest batches and learning rate
epochs, batch_size, lr = 2, 2, 1e-5

In [None]:
# Train on the first GPU when CUDA is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Build the network on the selected device; .to(device) also works on machines
# without a GPU, where .cuda() would raise
net = Unet(3,3).to(device)

# RMSprop with momentum and a very small weight decay
optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)

# mean squared error between the prediction and the target
criterion = nn.MSELoss()

In [None]:
# Training loop: compares the prediction of the network with the true labels
losses = []  # summed batch loss of every epoch

n_batches = len(train_img) // batch_size  # a trailing incomplete batch is dropped

for epoch in range(epochs):
  epoch_loss = 0
  net.train()

  # visit the samples in a new random order in every epoch
  indexes = np.arange(0, len(train_img))
  np.random.shuffle(indexes)

  # batch_idx instead of iter, so the builtin iter() is not shadowed
  for batch_idx in range(n_batches):

    print("Iter: {0}/{1}".format(batch_idx, n_batches), end='\r')

    batch_start = batch_idx * batch_size
    batch_indexes = indexes[batch_start:batch_start + batch_size]
    imgs, true_cat = get_next_batch(train_img, train_cat, batch_indexes)

    assert imgs.shape[1] == net.input_channel_num, \
      f'Network has been defined with {net.input_channel_num} input channels, ' \
      f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
      'the images are loaded correctly.'

    imgs = imgs.to(device=device, dtype=torch.float32)
    cat_type = torch.float32 #if net.input_channel_num == 1 else torch.long
    # move the labels of this batch (true_cat) to the device
    true_cat = true_cat.to(device=device, dtype=cat_type)

    cat_pred = net(imgs)

    # NOTE(review): the network returns one value per pixel and channel, while
    # train_cat holds one label per image - confirm the intended target shape,
    # MSELoss needs prediction and target to be broadcastable
    loss = criterion(cat_pred, true_cat)
    epoch_loss += loss.item()

    optimizer.zero_grad()
    loss.backward()
    # clip every gradient value to [-0.1, 0.1] before the update
    nn.utils.clip_grad_value_(net.parameters(), 0.1)
    optimizer.step()

  # report and store the loss of every epoch
  print("loss: " + str(epoch_loss))
  losses.append(epoch_loss)