<a href="https://colab.research.google.com/github/Noy-Bo/Cancer-Cell-Segmentation/blob/master/Cancer_Cell_Segmentation_Nir.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mounting Google Drive

Mount Google Drive, set the working directory to the PanNuke dataset root, select the compute device (CUDA when available), and inspect the GPU.


In [2]:

# Colab-only: mount the user's Google Drive under /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')


Mounted at /content/gdrive


In [3]:
import os
# PanNuke folder on the mounted Drive; after chdir, relative paths resolve against it.
root_path = '/content/gdrive/MyDrive/PanNuke' 
os.chdir(root_path)


In [4]:
import torch
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [5]:
!nvidia-smi

Thu Jul 22 13:08:20 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.42.01    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P0    27W / 250W |      2MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# **Dataset**
Loading the dataset files, creating the data loaders, and helper functions for normalization statistics and visualization.



In [6]:
import os
import numpy as np
import matplotlib.pyplot as plt
import random
from torch.utils.data import Dataset
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision.transforms import transforms

# calc normalization values of given data loader
def calc_normalization(data_loader):
    """Estimate per-channel normalization statistics over a data loader.

    For every batch the mean and standard deviation are reduced over axes
    (0, 2, 3) -- i.e. everything except axis 1 -- divided by 255, and then
    averaged across batches. Note that the reported std is therefore the
    average of the per-batch stds, not the std of the whole population.

    Args:
        data_loader: iterable yielding ``(images, masks)`` pairs where
            ``images`` is a 4-D CPU tensor exposing ``.numpy()``; the masks
            are ignored.

    Returns:
        Tuple ``(mean, std_ddof0, std_ddof1)`` of arrays with one entry per
        index of axis 1. The three arrays are also printed, as before;
        previously they were only printed and the function returned None.
    """
    reduce_axes = (0, 2, 3)
    batch_means = []
    batch_stds_ddof0 = []
    batch_stds_ddof1 = []
    for images, _masks in data_loader:
        numpy_image = images.numpy()

        batch_means.append(np.mean(numpy_image, axis=reduce_axes) / 255)
        batch_stds_ddof0.append(np.std(numpy_image, axis=reduce_axes) / 255)
        batch_stds_ddof1.append(np.std(numpy_image, axis=reduce_axes, ddof=1) / 255)

    # (num_batches, channels) -> average over the batch axis -> (channels,)
    pop_mean = np.array(batch_means).mean(axis=0)
    print(pop_mean)
    pop_std0 = np.array(batch_stds_ddof0).mean(axis=0)
    print(pop_std0)
    pop_std1 = np.array(batch_stds_ddof1).mean(axis=0)
    print(pop_std1)

    return pop_mean, pop_std0, pop_std1


# VISUALIZING SAMPLES - designed for input that is the dataloader output, for raw input cancel moveaxis
def vis_sample(image, masks):
    """Plot one sample next to its five ground-truth class masks.

    Both arguments are channels-first tensors (as produced by the data
    loader); they are converted to uint8 and moved to channels-last before
    being shown. The figure has seven panels; the last one is left empty.
    """
    figure, panels = plt.subplots(1, 7)
    figure.set_size_inches(18.5, 3.2)
    figure.suptitle('Ground Truth', fontsize=18)

    panels[0].imshow(np.moveaxis(image.numpy().astype(np.uint8), 0, -1))
    panels[0].set_title("Sample")

    masks = np.moveaxis(masks.numpy().astype(np.uint8), 0, -1)
    class_titles = (
        "Neoplastic cells",
        "Inflammatory",
        "Connective/Soft tissue cells",
        "Dead Cells",
        "Epithelial",
    )
    # Mask channel k is drawn in panel k + 1 (panel 0 holds the image).
    for channel, title in enumerate(class_titles):
        panel = panels[channel + 1]
        panel.imshow(masks[:, :, channel], cmap="gray")
        panel.set_title(title)
    plt.show()

# VISUALIZING PREDICTIONS - visualizing network output on a random sample.
def vis_predictions():
    """Plot the network output for one batch next to the ground truth.

    Relies on the notebook globals ``training_loader_1`` (data loader) and
    ``model`` (network); both must be defined before calling. For each
    sample in the batch a "Prediction" figure with the first five predicted
    channels is shown, followed by the ground-truth figure from
    ``vis_sample``.
    """
    # Use the built-in next(): the iterator's .next() method is the Python 2
    # spelling and is not part of the Python 3 iterator protocol.
    image_gt_batch, masks_gt_batch = next(iter(training_loader_1))
    pred_masks_batch = model(image_gt_batch)

    class_titles = (
        "Neoplastic cells",
        "Inflammatory",
        "Connective/Soft tissue cells",
        "Dead Cells",
        "Epithelial",
    )
    for i in range(image_gt_batch.shape[0]):  # one figure pair per sample in the batch
        print("\n\n")
        print("\t\t\t\t\t\t\t\t SAMPLE: {}".format(str(i)))
        image = image_gt_batch[i, ...]        # channels-first image
        pred_masks = pred_masks_batch[i, ...]  # channels-first predicted masks

        fig, axs = plt.subplots(1, 7)
        fig.set_size_inches(18.5, 3.3)
        fig.suptitle('Prediction', fontsize=18)
        axs[0].imshow(np.moveaxis(image.numpy().astype(np.uint8), 0, -1))
        axs[0].set_title("Sample")
        # Predicted channel k goes to panel k + 1; the seventh panel stays empty.
        for channel, title in enumerate(class_titles):
            axs[channel + 1].imshow(pred_masks[channel, :, :].cpu().detach().numpy(), cmap="gray")
            axs[channel + 1].set_title(title)
        plt.show()

        vis_sample(image, masks_gt_batch[i, ...])  # matching ground truth


# LOADING DATASET
def load_dataset(dir_root, dir_images, dir_masks, training_size=0.8, batch_size=1, num_workers=4):
    """Build training and validation loaders over one PanNuke dataset.

    The dataset indices are shuffled with ``torch.randperm`` (no seed is set
    here, so the split differs between runs unless the caller seeds torch)
    and divided into a training part and a validation part.

    Args:
        dir_root: prefix that is string-concatenated with the two paths below.
        dir_images: path (relative to ``dir_root``) of the images ``.npy`` file.
        dir_masks: path (relative to ``dir_root``) of the masks ``.npy`` file.
        training_size: fraction of the samples assigned to the training split.
        batch_size: batch size of both loaders (default 1, the previously
            hard-coded value).
        num_workers: worker processes of both loaders (default 4, the
            previously hard-coded value).

    Returns:
        ``(train_loader, validation_loader)``.
    """
    train_set = PanNuke(dir_root, dir_images, dir_masks, train=True)

    # Random, disjoint train/validation index sets.
    permutations = torch.randperm(len(train_set))
    split = int(np.floor(training_size * len(train_set)))
    training_subset = SubsetRandomSampler(permutations[:split])
    validation_subset = SubsetRandomSampler(permutations[split:])

    # Both loaders share the same dataset object and differ only in the sampler.
    train_loader = DataLoader(train_set, sampler=training_subset,
                              batch_size=batch_size, num_workers=num_workers)
    validation_loader = DataLoader(train_set, sampler=validation_subset,
                                   batch_size=batch_size, num_workers=num_workers)

    return train_loader, validation_loader



# DATASET CLASS
class PanNuke(Dataset):
    """Dataset over PanNuke image/mask ``.npy`` files.

    Both arrays are memory-mapped read-only and viewed channels-first
    (the last axis is moved to position 1). Only the first five mask
    channels are kept.

    Args:
        dir_root: prefix string-concatenated with the two paths below.
        dir_images: path of the images ``.npy`` file.
        dir_masks: path of the masks ``.npy`` file.
        val, train, test: accepted for API compatibility; not used here.
    """
    def __init__(self, dir_root, dir_images, dir_masks, val=False, train=False, test=False):
        self.images = np.load(dir_root+dir_images, mmap_mode='r')
        self.images = np.moveaxis(self.images, -1, 1)
        self.masks = np.load(dir_root+dir_masks, mmap_mode='r')
        self.masks = np.moveaxis(self.masks, -1, 1)
        self.masks = self.masks[:,:5,...]

        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            #transforms.Normalize(mean=[0.5,0.5,0.5], std=[0.5,0.5,0.5])
            #transforms.Normalize(mean=[0.73952988, 0.5701334, 0.702605], std=[0.18024648, 0.21097612, 0.16465892 ])
        ])

    def __len__(self):
        """Number of samples (size of the first axis of the image array)."""
        return self.images.shape[0]

    def __getitem__(self, idx):
        """Return ``(image_tensor, mask_array)`` for sample ``idx``.

        The image is stored channels-first; ``ToTensor`` expects a
        channels-last (H, W, C) array and converts it back to (C, H, W).
        Only the channel axis is moved: the previous full ``transpose()``
        reversed every axis, which swapped height and width and left the
        image spatially transposed relative to its masks.
        """
        img = np.copy(self.images[idx, ...])  # copy: the memmap is read-only
        img = self.transforms(np.moveaxis(img, 0, -1))
        masks = np.copy(self.masks[idx, ...])
        # ceil(v / 1000) maps 0 -> 0 and any value in (0, 1000] -> 1.
        # NOTE(review): values above 1000 would map to 2 or more -- confirm the id range.
        masks = np.ceil(masks/1000)
        return img, masks


# **Model**
Two U-Net variants: a basic `Unet`, and a `UNet` with batch normalization and dropout.

In [7]:
from collections import OrderedDict

import torch
import torch.nn as nn


def double_conv(in_channels, out_channels):
    """Return two size-preserving 3x3 convolutions, each followed by an in-place ReLU."""
    first_conv = nn.Conv2d(in_channels, out_channels, 3, padding=1)
    second_conv = nn.Conv2d(out_channels, out_channels, 3, padding=1)
    stages = [
        first_conv,
        nn.ReLU(inplace=True),
        second_conv,
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*stages)


class Unet(nn.Module):
    """Three-level U-Net built from ``double_conv`` blocks.

    The encoder goes 3 -> 128 -> 256 -> 512 -> 1024 channels with 2x
    max-pooling between levels; the decoder upsamples bilinearly and
    concatenates the matching encoder output before each block. Two 1x1
    convolutions (128 -> 512 -> 5) produce the output; no activation is
    applied to it.
    """
    def __init__(self):
        super(Unet, self).__init__()
        # contracting path
        self.dblock1 = double_conv(3, 128)
        self.dblock2 = double_conv(128, 256)
        self.dblock3 = double_conv(256, 512)
        self.dblock4 = double_conv(512, 1024)

        # shared resampling layers
        self.pool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # expanding path: input channels = skip channels + upsampled channels
        self.dblock5 = double_conv(512 + 1024, 512)
        self.dblock6 = double_conv(256 + 512, 256)
        self.dblock7 = double_conv(256 + 128, 128)

        # relu and sigmoid are registered but not used in forward
        self.relu = nn.ReLU()
        self.last_layer = nn.Conv2d(128, 512, 1)
        self.last_layer_rly = nn.Conv2d(512, 5, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map an image batch to a 5-channel output of the same spatial size."""
        skip1 = self.dblock1(x)
        skip2 = self.dblock2(self.pool(skip1))
        skip3 = self.dblock3(self.pool(skip2))
        deepest = self.dblock4(self.pool(skip3))

        merged = torch.cat([self.upsample(deepest), skip3], dim=1)
        merged = self.dblock5(merged)

        merged = torch.cat([self.upsample(merged), skip2], dim=1)
        merged = self.dblock6(merged)

        merged = torch.cat([self.upsample(merged), skip1], dim=1)
        merged = self.dblock7(merged)

        return self.last_layer_rly(self.last_layer(merged))


class UNet(nn.Module):
    """Four-level U-Net with batch normalization and channel dropout.

    Every encoder/decoder stage is a ``_block`` (two bias-free 3x3
    conv + BatchNorm + ReLU); the same ``Dropout2d`` module (default drop
    probability) is applied after every stage, including the tensors later
    reused as skip connections. Upsampling uses 2x2 stride-2 transposed
    convolutions. The head is a 1x1 convolution to 5 channels followed by a
    sigmoid.

    Args:
        in_channels: number of input image channels.
        out_channels: accepted for API compatibility; the head always
            produces 5 channels and this value is not used.
        init_features: channel width of the first stage; doubled at each
            deeper level.
    """

    def __init__(self, in_channels=3, out_channels=1, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        self.dropout = nn.Dropout2d()
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        # Output head: one channel per predicted mask.
        self.conv_masks = nn.Conv2d(in_channels=features, out_channels=5, kernel_size=1)

    def forward(self, x):
        """Return per-pixel sigmoid scores of shape (N, 5, H, W).

        The input is moved to the device that holds this module's weights.
        The previous code moved it to a notebook-global ``device`` instead,
        which fails whenever that global is missing or does not match the
        device the model was placed on.
        """
        x = x.to(next(self.parameters()).device)
        enc1 = self.encoder1(x)
        enc1 = self.dropout(enc1)
        enc2 = self.encoder2(self.pool1(enc1))
        enc2 = self.dropout(enc2)
        enc3 = self.encoder3(self.pool2(enc2))
        enc3 = self.dropout(enc3)
        enc4 = self.encoder4(self.pool3(enc3))
        enc4 = self.dropout(enc4)

        bottleneck = self.bottleneck(self.pool4(enc4))
        bottleneck = self.dropout(bottleneck)

        # Decoder: upsample, concatenate the matching encoder output, refine.
        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec4 = self.dropout(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec3 = self.dropout(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec2 = self.dropout(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        dec1 = self.dropout(dec1)
        return torch.sigmoid(self.conv_masks(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        """Build a named conv-BN-ReLU x2 stage.

        Args:
            in_channels: channels entering the first convolution.
            features: channels produced by both convolutions.
            name: prefix for the layer names inside the returned Sequential.
        """
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )

# **Train Utilities**
Loss functions (Dice and BCE variants) and the per-epoch training and evaluation loops.


In [8]:
import sys
import torch.nn as nn
import torch
import tqdm

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

def dice_loss(predict, target):
    """Sum of per-channel soft Dice losses.

    For every index ``c`` of axis 1 the channel is flattened over the batch
    and all remaining axes, and ``1 - (2*|P.T| + 1) / (|P| + |T| + 1)`` is
    added to the total (the ``+ 1`` is the smoothing term).

    Args:
        predict: tensor with channels on axis 1.
        target: tensor laid out like ``predict``.

    Returns:
        The accumulated loss; plain ``0.`` when there are no channels.
    """
    smooth = 1.
    total = 0.
    num_channels = predict.shape[1]
    for channel in range(num_channels):
        pred_flat = predict[:, channel, ...].contiguous().view(-1)
        true_flat = target[:, channel, ...].contiguous().view(-1)
        overlap = (pred_flat * true_flat).sum()

        dice = (2. * overlap + smooth) / (pred_flat.sum() + true_flat.sum() + smooth)
        total += (1 - dice)
    return total
def BCE(predict, target):
    """Binary cross-entropy on channel-wise softmax probabilities.

    The scores are rearranged to ``(C, N*H*W)`` so that every column holds
    the C class scores of one pixel, a softmax is taken over the class axis,
    and ``nn.BCELoss`` (mean reduction) is applied against the target
    rearranged the same way.

    The previous implementation flattened an ``(N, H, W, C)`` view and
    reshaped it to ``(5, -1)``; that cuts the buffer into five contiguous
    chunks rather than grouping by class, so the softmax normalized over
    unrelated pixels. It also hard-coded five classes.

    Args:
        predict: ``(N, C, H, W)`` tensor of class scores.
        target: ``(N, C, H, W)`` tensor of per-class targets in [0, 1].

    Returns:
        Scalar loss tensor.
    """
    assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
    loss_func = nn.BCELoss()
    softmax = torch.nn.Softmax(dim=0)
    num_classes = predict.shape[1]

    # Class axis first, then flatten everything else: row c = all pixels of class c.
    predict = predict.permute(1, 0, 2, 3).reshape(num_classes, -1)
    predict = softmax(predict)
    target = target.permute(1, 0, 2, 3).reshape(num_classes, -1)
    # BCELoss needs matching dtypes (masks may arrive as float64).
    target = target.to(dtype=predict.dtype)

    return loss_func(predict, target)
class BinaryDiceLoss(nn.Module):
    r"""Soft Dice loss computed independently for every (sample, channel) pair.

    For each pair the loss is

        1 - (2 * \sum{x * y} + smooth) / (\sum{x^p} + \sum{y^p} + smooth)

    and the channel losses of one sample are added up, giving one value per
    sample before the reduction is applied.

    Changes from the previous version: the numerator now carries the factor
    2 (without it a perfect prediction could never reach zero loss, and the
    result disagreed with ``dice_loss`` in this file); ``reduction`` now
    acts on the per-sample losses instead of on an already-summed scalar;
    the docstring is a raw string so ``\sum`` is not an invalid escape.

    Args:
        smooth: A float number to smooth loss, and avoid NaN error, default: 1
        p: Exponent used in the denominator, default: 2
        reduction: 'mean' returns the mean over the batch, 'sum' (default)
            the sum over the batch, 'none' a tensor of shape [N,].

    Shapes:
        predict: [N, C, *] (at least 3 dimensions)
        target: same shape as predict

    Raises:
        Exception: if ``reduction`` is not one of 'mean', 'sum', 'none'.
    """
    def __init__(self, smooth=1, p=2, reduction='sum'):
        super(BinaryDiceLoss, self).__init__()
        self.smooth = smooth
        self.p = p
        self.reduction = reduction

    def forward(self, predict, target):
        assert predict.shape[0] == target.shape[0], "predict & target batch size don't match"
        if self.reduction not in ('mean', 'sum', 'none'):
            raise Exception('Unexpected reduction {}'.format(self.reduction))

        # [N, C, *] -> [N, C, L]: one flat vector per (sample, channel).
        predict = predict.flatten(start_dim=2)
        target = target.flatten(start_dim=2)

        num = 2 * torch.sum(torch.mul(predict, target), dim=2) + self.smooth
        den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=2) + self.smooth

        # Add up the channel losses of every sample -> shape [N].
        loss = (1 - num / den).sum(dim=1)

        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss


def train_epoch(training_loader, model, optimizer, loss_function):
    """Run one optimization pass over ``training_loader``.

    Each batch is moved to the global ``device``, pushed through ``model``,
    scored with ``loss_function(prediction, masks)`` and used for one
    optimizer step. A tqdm bar on stdout shows the latest batch loss and,
    at the end, the epoch mean.

    Returns:
        A one-element list holding the mean batch loss of the epoch.
    """
    model.train()
    batch_losses = []

    with tqdm.tqdm(total=len(training_loader), file=sys.stdout) as pbar:
        for images, masks in training_loader:
            images, masks = images.to(device), masks.to(device)

            # forward pass
            y_hat = model(images)

            # clear old gradients, then backpropagate this batch's loss
            optimizer.zero_grad()
            batch_loss = loss_function(y_hat, masks)
            batch_loss.backward()
            optimizer.step()

            # progress reporting
            batch_losses.append(batch_loss.detach())
            pbar.update()
            pbar.set_description(f'train loss={batch_losses[-1]:.3f}')

        mean_loss = torch.mean(torch.FloatTensor(batch_losses))
        pbar.set_description(f'train loss={mean_loss:.3f}')

        # drop the references to the last batch
        images = None
        masks = None
    return [mean_loss]

    
def eval_loss_epoch(training_loader, model, loss_function):
    """Compute the mean loss of ``model`` over a loader without training.

    The model is switched to eval mode and the whole pass runs under
    ``torch.no_grad()``: no gradients are needed here, and without it every
    forward pass would keep an autograd graph alive and waste memory.

    Args:
        training_loader: loader yielding ``(images, masks)`` batches (despite
            the name, typically the validation loader).
        model: network to evaluate.
        loss_function: callable ``(prediction, masks) -> scalar tensor``.

    Returns:
        A one-element list holding the mean batch loss.
    """
    losses = []
    model.eval()
    with torch.no_grad(), tqdm.tqdm(total=len(training_loader), file=sys.stdout) as pbar:
        for idx_batch, (images, masks) in enumerate(training_loader, start=1):

            images = images.to(device)
            masks = masks.to(device)

            # forward pass and loss only -- no backward pass, no optimizer step
            y_hat = model(images)
            loss = loss_function(y_hat, masks)

            # update loss bar
            losses.append(loss.detach())
            pbar.update()
            pbar.set_description(f'val loss={losses[-1]:.3f}')

            # drop the batch references before the next iteration
            images = None
            masks = None
        mean_loss = torch.mean(torch.FloatTensor(losses))
        pbar.set_description(f'val loss={mean_loss:.3f}')

    return [mean_loss]


cuda


# **Main (training)**


In [19]:
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as TNF
import torch.nn.modules
from torch import optim
import time
#import dataset as pnk
from torch.utils.data import DataLoader
from torchsummary import summary
import tqdm

# Small useful components:
# this class is to be able to use TNF.interpole within nn.Sequential()
class Interpolate(nn.Module):
    """Module wrapper around ``TNF.interpolate`` so it can sit inside ``nn.Sequential``.

    Args:
        size: target output size passed to ``interpolate``.
        mode: interpolation mode passed to ``interpolate``; the call always
            sets ``align_corners=False``.
    """
    def __init__(self, size, mode):
        super().__init__()
        self.interp = TNF.interpolate
        self.size = size
        self.mode = mode

    def forward(self, x):
        resized = self.interp(
            x,
            size=self.size,
            mode=self.mode,
            align_corners=False,
        )
        return resized

# Create the Micro-Net components:
class Group1_B1(nn.Module):
    """Micro-Net group 1, block B1: two parallel branches concatenated on channels.

    Branch 1 filters the input at full resolution (two unpadded 5x5
    conv + BN + tanh stages, then 2x2 max-pooling). Branch 2 first resizes
    the input to 128x128 (bicubic) and applies two unpadded 3x3 convolutions.
    For a 256x256 input both branches end at 124x124 with 64 channels each,
    so the output has 128 channels.
    """
    def __init__(self, in_channels=3):
        super().__init__()

        # 5x5 kernels are used because the inputs are 256x256 (the paper uses 252x252).
        full_res_layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=5, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.Tanh(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.Tanh(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.sub_block1 = nn.Sequential(*full_res_layers)

        resized_layers = [
            Interpolate(size=(128, 128), mode='bicubic'),
            nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.Tanh(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0, bias=False),
            nn.Tanh(),
        ]
        self.sub_block2 = nn.Sequential(*resized_layers)

    def forward(self, orig_input):
        """Return the channel-wise concatenation of both branches (B, C, H, W)."""
        full_res = self.sub_block1(orig_input)
        resized = self.sub_block2(orig_input)
        return torch.cat((full_res, resized), dim=1)

class Group1_B2(nn.Module):
    """Micro-Net group 1, block B2.

    Branch 1 processes the previous block's features (two unpadded 3x3
    conv + tanh stages with bias, then 2x2 max-pooling: 124 -> 60 spatially).
    Branch 2 resizes the original 3-channel image to 64x64 (bicubic) and
    applies two unpadded 3x3 convolutions (64 -> 60). Each branch yields 128
    channels; the output is their concatenation (256 channels).
    """
    def __init__(self, in_channels=128):
        super().__init__()

        # No batch norm in this branch, hence the convolutions keep their bias.
        feature_layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=128, kernel_size=3, stride=1, padding=0, bias=True),
            nn.Tanh(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=0, bias=True),
            nn.Tanh(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.sub_block1 = nn.Sequential(*feature_layers)

        image_layers = [
            Interpolate(size=(64, 64), mode='bicubic'),
            nn.Conv2d(in_channels=3, out_channels=128, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.Tanh(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=0, bias=False),
            nn.Tanh(),
        ]
        self.sub_block2 = nn.Sequential(*image_layers)

    def forward(self, B1_input, orig_input):
        """Combine features from B1 with the resized original image on the channel axis."""
        from_features = self.sub_block1(B1_input)
        from_image = self.sub_block2(orig_input)
        return torch.cat((from_features, from_image), dim=1)

class Group1_B3(nn.Module):
    """Micro-Net group 1, block B3.

    Branch 1 processes the previous block's features (two unpadded 3x3
    conv + tanh stages with bias, then 2x2 max-pooling: 60 -> 28 spatially).
    Branch 2 resizes the original 3-channel image to 32x32 (bicubic) and
    applies two unpadded 3x3 convolutions (32 -> 28). Each branch yields 256
    channels; the output is their concatenation (512 channels).
    """
    def __init__(self, in_channels=256):
        super().__init__()

        # No batch norm in this branch, hence the convolutions keep their bias.
        feature_layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=256, kernel_size=3, stride=1, padding=0, bias=True),
            nn.Tanh(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=0, bias=True),
            nn.Tanh(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.sub_block1 = nn.Sequential(*feature_layers)

        image_layers = [
            Interpolate(size=(32, 32), mode='bicubic'),
            nn.Conv2d(in_channels=3, out_channels=256, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(256),
            nn.Tanh(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=0, bias=False),
            nn.Tanh(),
        ]
        self.sub_block2 = nn.Sequential(*image_layers)

    def forward(self, B2_input, orig_input):
        """Combine features from B2 with the resized original image on the channel axis."""
        from_features = self.sub_block1(B2_input)
        from_image = self.sub_block2(orig_input)
        return torch.cat((from_features, from_image), dim=1)

class Group1_B4(nn.Module):
    """Micro-Net group 1, block B4.

    Branch 1 processes the previous block's features (two unpadded 3x3
    conv + tanh stages with bias, then 2x2 max-pooling: 28 -> 12 spatially).
    Branch 2 resizes the original 3-channel image to 16x16 (bicubic) and
    applies two unpadded 3x3 convolutions (16 -> 12). Each branch yields 512
    channels; the output is their concatenation (1024 channels).
    """
    def __init__(self, in_channels=512):
        super().__init__()

        # No batch norm in this branch, hence the convolutions keep their bias.
        feature_layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=512, kernel_size=3, stride=1, padding=0, bias=True),
            nn.Tanh(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=0, bias=True),
            nn.Tanh(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.sub_block1 = nn.Sequential(*feature_layers)

        image_layers = [
            Interpolate(size=(16, 16), mode='bicubic'),
            nn.Conv2d(in_channels=3, out_channels=512, kernel_size=3, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.Tanh(),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=0, bias=False),
            nn.Tanh(),
        ]
        self.sub_block2 = nn.Sequential(*image_layers)

    def forward(self, B3_input, orig_input):
        """Combine features from B3 with the resized original image on the channel axis."""
        from_features = self.sub_block1(B3_input)
        from_image = self.sub_block2(orig_input)
        return torch.cat((from_features, from_image), dim=1)

class Group2_B5(nn.Module):
    """Micro-Net group 2 (bottleneck): two unpadded 3x3 conv + tanh stages to 2048 channels.

    Each convolution trims two pixels per spatial dimension (12 -> 10 -> 8
    for the 12x12 input the network feeds it).
    """
    def __init__(self, in_channels=1024):
        super().__init__()
        # No batch norm here, so both convolutions keep their bias.
        bottleneck_layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=2048, kernel_size=3, stride=1, padding=0, bias=True),
            nn.Tanh(),
            nn.Conv2d(in_channels=2048, out_channels=2048, kernel_size=3, stride=1, padding=0, bias=True),
            nn.Tanh(),
        ]
        self.sub_block = nn.Sequential(*bottleneck_layers)

    def forward(self, B4_input):
        return self.sub_block(B4_input)

class Group3_Bi(nn.Module):
    """Micro-Net group 3 decoder block.

    Args:
        in_channels_prev_b: channels of the previous (deeper) block's output.
        in_channels_g1: channels of the group-1 feature map merged in.

    Three parts:
      1. ``sub_block1`` doubles the previous block's resolution with a
         stride-2 transposed conv while halving its channels, applies two
         unpadded 3x3 convs (losing 4 pixels per dimension) and restores
         the size with a 5x5 transposed conv.
      2. ``sub_block2`` grows the group-1 feature map by 4 pixels per
         dimension with a 5x5 transposed conv.
      3. ``sub_block3`` fuses the channel-wise concatenation of both with a
         padded 3x3 conv + tanh down to ``in_channels_g1`` channels.
    """
    def __init__(self, in_channels_prev_b, in_channels_g1):
        super().__init__()
        # NOTE (kept from the original author): unsure whether this
        # up -> conv -> conv -> up arrangement of part 1 matches the paper.
        half_channels = round(in_channels_prev_b / 2)

        upsample_path = [
            nn.ConvTranspose2d(in_channels=in_channels_prev_b, out_channels=half_channels, kernel_size=2, stride=2, padding=0),  # x -> 2x
            nn.Conv2d(in_channels=half_channels, out_channels=half_channels, kernel_size=3, stride=1, padding=0),  # 2x -> 2x - 2
            nn.Tanh(),
            nn.Conv2d(in_channels=half_channels, out_channels=half_channels, kernel_size=3, stride=1, padding=0),  # 2x - 2 -> 2x - 4
            nn.Tanh(),
            nn.ConvTranspose2d(in_channels=half_channels, out_channels=half_channels, kernel_size=5, stride=1, padding=0),  # back to 2x
        ]
        self.sub_block1 = nn.Sequential(*upsample_path)

        # Skip path: the 5x5 transposed conv adds 4 pixels per dimension.
        self.sub_block2 = nn.ConvTranspose2d(in_channels=in_channels_g1, out_channels=in_channels_g1, kernel_size=5, stride=1, padding=0)

        # Fusion: size-preserving conv over the concatenated channels.
        fusion_layers = [
            nn.Conv2d(in_channels=round(in_channels_g1 * 2), out_channels=in_channels_g1, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.sub_block3 = nn.Sequential(*fusion_layers)

    def forward(self, g1_input, prev_b_input):
        """Fuse the upsampled previous block with the enlarged group-1 features."""
        upsampled = self.sub_block1(prev_b_input)
        skip = self.sub_block2(g1_input)
        stacked = torch.cat((upsampled, skip), dim=1)
        return self.sub_block3(stacked)

class Group4_Pa1(nn.Module):
    """Micro-Net group 4, auxiliary head 1 (fed by B9).

    ``sub_block1`` upsamples 2x (128 -> 64 channels) and applies a padded 3x3
    conv + tanh; ``sub_block2`` applies channel dropout and a padded 3x3 conv
    to 6 output channels with no activation. Padded ("same") convolutions
    are used, unlike the paper, so the output keeps the 256x256 image size.
    """
    def __init__(self):
        super().__init__()
        upsample_layers = [
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2),  # 2x upsampling
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.sub_block1 = nn.Sequential(*upsample_layers)

        # Output layer: raw scores, no activation.
        head_layers = [
            nn.Dropout2d(p=0.5),
            nn.Conv2d(in_channels=64, out_channels=6, kernel_size=3, stride=1, padding=1),
        ]
        self.sub_block2 = nn.Sequential(*head_layers)

    def forward(self, b9_input):
        """Return ``(pa1, x1)``: the auxiliary prediction and the features passed on to group 5."""
        features = self.sub_block1(b9_input)
        prediction = self.sub_block2(features)
        return prediction, features

class Group4_Pa2(nn.Module):
    """Micro-Net group 4, auxiliary head 2 (fed by B8).

    ``sub_block1`` upsamples 4x (256 -> 128 channels) and applies a padded 3x3
    conv + tanh; ``sub_block2`` applies channel dropout and a padded 3x3 conv
    to 6 output channels with no activation. Padded ("same") convolutions
    are used, unlike the paper, so the output keeps the 256x256 image size.
    """
    def __init__(self):
        super().__init__()
        upsample_layers = [
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=4),  # 4x upsampling
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.sub_block1 = nn.Sequential(*upsample_layers)

        # Output layer: raw scores, no activation.
        head_layers = [
            nn.Dropout2d(p=0.5),
            nn.Conv2d(in_channels=128, out_channels=6, kernel_size=3, stride=1, padding=1),
        ]
        self.sub_block2 = nn.Sequential(*head_layers)

    def forward(self, b8_input):
        """Return ``(pa2, x2)``: the auxiliary prediction and the features passed on to group 5."""
        features = self.sub_block1(b8_input)
        prediction = self.sub_block2(features)
        return prediction, features

class Group4_Pa3(nn.Module):
    """Micro-Net group 4, auxiliary head 3 (fed by B7).

    ``sub_block1`` upsamples 8x (512 -> 256 channels) and applies a padded 3x3
    conv + tanh; ``sub_block2`` applies channel dropout and a padded 3x3 conv
    to 6 output channels with no activation. Padded ("same") convolutions
    are used, unlike the paper, so the output keeps the 256x256 image size.
    """
    def __init__(self):
        super().__init__()
        upsample_layers = [
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=8, stride=8),  # 8x upsampling
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        ]
        self.sub_block1 = nn.Sequential(*upsample_layers)

        # Output layer: raw scores, no activation.
        head_layers = [
            nn.Dropout2d(p=0.5),
            nn.Conv2d(in_channels=256, out_channels=6, kernel_size=3, stride=1, padding=1),
        ]
        self.sub_block2 = nn.Sequential(*head_layers)

    def forward(self, b7_input):
        """Return ``(pa3, x3)``: the auxiliary prediction and the features passed on to group 5."""
        features = self.sub_block1(b7_input)
        prediction = self.sub_block2(features)
        return prediction, features

class Group5(nn.Module):
    """Micro-Net group 5: main output head over the three group-4 feature maps.

    The inputs are concatenated on the channel axis (64 + 128 + 256 = 448
    channels), passed through channel dropout and a padded 3x3 conv to 6
    output channels. Unlike the paper, the input here is already at full
    resolution.
    """
    def __init__(self):
        super().__init__()
        head_layers = [
            nn.Dropout2d(p=0.5),
            nn.Conv2d(in_channels=448, out_channels=6, kernel_size=3, stride=1, padding=1),
        ]
        self.sub_block = nn.Sequential(*head_layers)

    def forward(self, x1, x2, x3):
        stacked = torch.cat((x1, x2, x3), dim=1)
        return self.sub_block(stacked)

class MicroNet(nn.Module):
    """Micro-Net assembled from the group 1-5 components.

    Group 1 (B1-B4) is the multi-resolution encoder, group 2 (B5) the
    bottleneck, group 3 (B6-B9) the decoder that merges the encoder
    outputs back in, group 4 the three auxiliary heads, and group 5 the
    main head.
    """
    def __init__(self):
        super().__init__()
        # encoder
        self.group1_b1 = Group1_B1()
        self.group1_b2 = Group1_B2()
        self.group1_b3 = Group1_B3()
        self.group1_b4 = Group1_B4()
        # bottleneck
        self.group2 = Group2_B5()
        # decoder
        self.group3_b6 = Group3_Bi(in_channels_prev_b=2048, in_channels_g1=1024)
        self.group3_b7 = Group3_Bi(in_channels_prev_b=1024, in_channels_g1=512)
        self.group3_b8 = Group3_Bi(in_channels_prev_b=512, in_channels_g1=256)
        self.group3_b9 = Group3_Bi(in_channels_prev_b=256, in_channels_g1=128)
        # output heads
        self.group4_pa1 = Group4_Pa1()
        self.group4_pa2 = Group4_Pa2()
        self.group4_pa3 = Group4_Pa3()
        self.group5 = Group5()

    def forward(self, x):
        """Return ``(p0, pa1, pa2, pa3)``: the main output followed by the three auxiliary outputs."""
        # Group 1: every block after the first also receives the raw input.
        enc1 = self.group1_b1(x)
        enc2 = self.group1_b2(enc1, x)
        enc3 = self.group1_b3(enc2, x)
        enc4 = self.group1_b4(enc3, x)

        # Group 2
        bottleneck = self.group2(enc4)

        # Group 3: each block merges one encoder output with the previous decoder output.
        dec6 = self.group3_b6(enc4, bottleneck)
        dec7 = self.group3_b7(enc3, dec6)
        dec8 = self.group3_b8(enc2, dec7)
        dec9 = self.group3_b9(enc1, dec8)

        # Group 4: auxiliary predictions plus the features reused by group 5.
        aux1, feat1 = self.group4_pa1(dec9)
        aux2, feat2 = self.group4_pa2(dec8)
        aux3, feat3 = self.group4_pa3(dec7)

        # Group 5: main prediction.
        main_out = self.group5(feat1, feat2, feat3)

        return main_out, aux1, aux2, aux3


# def load_dataset(batch_size, shuffle_flag, num_workers, data_dir, transforms=None):
#     dataset = pnk.PanNukeDataset(data_dir)
#     data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle_flag)#, num_workers=num_workers)
#     return data_loader

def train(model, loader, opt, criterion,scheduler, epoch):
    """Train ``model`` for one pass over ``loader`` using deep supervision.

    The model is expected to return four outputs ``(p0, pa1, pa2, pa3)``: the
    main prediction and three auxiliary predictions. The optimised loss is
    ``criterion(p0) + (criterion(pa1) + criterion(pa2) + criterion(pa3)) / epoch``,
    so the auxiliary terms are weighted less as ``epoch`` grows.

    Args:
        model: network returning ``(p0, pa1, pa2, pa3)`` for a batch of images.
        loader: iterable of ``(images, masks)`` batches that supports ``len()``.
            Masks are reduced with ``argmax`` over dim 1 before being passed to
            ``criterion`` (presumably they are one-hot / channel-first; confirm
            against the dataset).
        opt: optimizer updating ``model``'s parameters.
        criterion: loss called as ``criterion(prediction, class_index_targets)``.
        scheduler: learning-rate scheduler; stepped exactly once per call.
        epoch: positive epoch number used as the auxiliary-loss divisor
            (callers should pass a 1-based value).

    Returns:
        ``(mean_loss, mean_accuracy)`` averaged over the batches of ``loader``.
        Accuracy is the fraction of positions where ``argmax(p0, dim=1)``
        equals the target class index. An empty loader yields ``(0.0, 0.0)``.

    Raises:
        ValueError: if ``epoch`` is not positive (it is used as a divisor).

    Note:
        Batches are moved to the module-level ``device``.
    """
    # Fail fast instead of hitting ZeroDivisionError after a full forward pass.
    if epoch <= 0:
        raise ValueError(f"epoch must be positive (1-based), got {epoch!r}")

    num_batches = len(loader)
    epoch_loss = 0.0
    epoch_acc = 0.0

    # Train the model (turn training-mode on..)
    model.train()

    with tqdm.tqdm(total=num_batches, file=sys.stdout) as pbar:
        for (images, masks) in loader:
            images = images.to(device)
            masks = masks.to(device)

            # reinitialize gradients..
            opt.zero_grad()

            # Main output plus the three auxiliary outputs.
            p0, pa1, pa2, pa3 = model(images)

            # Collapse the channel dimension of the masks into class indices.
            targets = torch.argmax(masks, dim=1)

            # Deep-supervision loss: auxiliary terms are scaled by 1/epoch.
            l0 = criterion(p0, targets)
            l1 = criterion(pa1, targets)
            l2 = criterion(pa2, targets)
            l3 = criterion(pa3, targets)
            loss = l0+(l1+l2+l3)/epoch

            # Backpropagation
            loss.backward()

            # update weights according to gradients
            opt.step()

            # Per-position accuracy of the main output (no gradients needed).
            with torch.no_grad():
                predictions = torch.argmax(p0.detach(), dim=1)
                batch_acc = (predictions == targets).float().mean().item()

            epoch_loss += loss.item()
            epoch_acc += batch_acc

            pbar.update()
            pbar.set_description(f'train loss={loss.item():.3f}')

        # Guard the averages so an empty loader does not divide by zero.
        divisor = max(num_batches, 1)
        mean_loss = epoch_loss / divisor
        mean_acc = epoch_acc / divisor
        pbar.set_description(f'train loss={mean_loss:.3f}')
    scheduler.step()

    return mean_loss, mean_acc


# ================= MAIN =====================
# Hyperparameters:
BATCH_SIZE = 6
NUM_WORKERS = 4
EPOCHS = 50
LEARNING_RATE=0.001

# train_dir = './/Final_Dataset/train_pickled_data'
# val_dir = './/Final_Dataset/val_pickled_data'
# test_dir = './/Final_Dataset/test_pickled_data'

# Set up device: prefer the GPU when one is available.
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
print(device)

# Grab loaders:
# train_loader = load_dataset(BATCH_SIZE, shuffle_flag=True, num_workers=NUM_WORKERS, data_dir=train_dir)
# val_loader = load_dataset(BATCH_SIZE, shuffle_flag=True, num_workers=NUM_WORKERS, data_dir=val_dir)
# test_loader = load_dataset(BATCH_SIZE, shuffle_flag=True, num_workers=NUM_WORKERS, data_dir=test_dir)

# creating data loaders
# Folds 1 and 3 are used entirely for training (training_size=1, second loader discarded);
# fold 2 is split in half (training_size=0.5) into the validation and test loaders.
images_dir = 'Images 1/images.npy'
masks_dir = 'Masks 1/masks.npy'
training_loader_1, _ = load_dataset(dir_root='/content/gdrive/MyDrive/PanNuke/', dir_masks=masks_dir, dir_images=images_dir, training_size=1)
images_dir = 'Images 3/images.npy'
masks_dir = 'Masks 3/masks.npy'
training_loader_2, _ = load_dataset(dir_root='/content/gdrive/MyDrive/PanNuke/', dir_masks=masks_dir, dir_images=images_dir, training_size=1)
images_dir = 'Images 2/images.npy'
masks_dir = 'Masks 2/masks.npy'
validation_loader, test_loader = load_dataset(dir_root='/content/gdrive/MyDrive/PanNuke/', dir_masks=masks_dir, dir_images=images_dir, training_size=0.5)



# The summary is printed before the model is converted to double precision.
model = MicroNet().to(device)
summary(model, (3, 256, 256))
model.double()

# Define optimizer and criterion functions   - IMPROVE... LEARN HP'S
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999), eps=1e-08)
# NOTE(review): train() calls scheduler.step() once per call and is invoked once per
# training loader, i.e. twice per epoch, so with step_size=5 the learning rate is
# halved every 2.5 epochs rather than every 5 - confirm this is intended.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5, last_epoch=-1, verbose=False)
loss_func = nn.CrossEntropyLoss()

# Both training folds are visited every epoch, in this order.
train_loaders = (training_loader_1, training_loader_2)

for epoch in range(EPOCHS):
    print("Epoch-%d: " % (epoch))

    train_start_time = time.monotonic()
    # Accumulate a batch-weighted loss over both loaders so the reported value covers
    # the whole epoch (previously the first loader's loss was overwritten and lost).
    weighted_loss_sum = 0.0
    total_batches = 0
    for train_loader in train_loaders:
        # train() divides the auxiliary losses by its epoch argument, hence epoch+1.
        loader_loss, train_acc = train(model, train_loader, optimizer, loss_func, scheduler, epoch+1)
        weighted_loss_sum += loader_loss * len(train_loader)
        total_batches += len(train_loader)
    train_loss = weighted_loss_sum / max(total_batches, 1)
    train_end_time = time.monotonic()


    print(f"train loss={train_loss}, epoch={epoch}")

    # val_start_time = time.monotonic()
    # val_loss, val_acc = evaluate(model, val_loader, loss_func)
    # val_end_time = time.monotonic()
    #
    # print(f"train loss={train_loss}, epoch={epoch}")






cuda


  cpuset_checked))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 252, 252]           4,800
       BatchNorm2d-2         [-1, 64, 252, 252]             128
              Tanh-3         [-1, 64, 252, 252]               0
            Conv2d-4         [-1, 64, 248, 248]         102,400
       BatchNorm2d-5         [-1, 64, 248, 248]             128
              Tanh-6         [-1, 64, 248, 248]               0
         MaxPool2d-7         [-1, 64, 124, 124]               0
       Interpolate-8          [-1, 3, 128, 128]               0
            Conv2d-9         [-1, 64, 126, 126]           1,728
      BatchNorm2d-10         [-1, 64, 126, 126]             128
             Tanh-11         [-1, 64, 126, 126]               0
           Conv2d-12         [-1, 64, 124, 124]          36,864
             Tanh-13         [-1, 64, 124, 124]               0
        Group1_B1-14        [-1, 128, 1

KeyboardInterrupt: ignored

# **Visualizing Predictions vs Ground Truth**

In [None]:
# Optional checkpointing (disabled): pickle the whole model object under root_path.
# NOTE(review): the file name says "UNet" although the model trained above is
# MicroNet - confirm which model this checkpoint is meant to hold.
# PATH = root_path + "/model_UNet.pt"
# torch.save(model, PATH)

# Optional restore (disabled): rebuild a UNet and replace it with the saved object.
# model = UNet().to(device)
# PATH = root_path + "/model_UNet.pt"
# model = torch.load(PATH)

# Show predictions next to the ground truth. vis_predictions is defined elsewhere in
# the notebook and is called without arguments, so it presumably relies on globals
# such as `model` and a data loader - make sure those cells ran first.
vis_predictions()

  cpuset_checked))


RuntimeError: ignored

In [None]:
# Release GPU memory that PyTorch has cached but is no longer using
# (memory still referenced by live tensors is not freed).
torch.cuda.empty_cache()
