# Check GPU 

In [None]:
# Print the GPU(s) visible to this runtime (IPython shell escape, not plain Python).
!nvidia-smi

# Cellpose flow images as training labels

In [None]:

import numpy as np 
import matplotlib.pyplot as plt
import skimage
from skimage import io
from PIL import Image
from sklearn.metrics import jaccard_score
import os

# Convert every Cellpose ``.npy`` result in ``cellpose_dir`` into a grayscale
# PNG built from the first entry of its 'flows', saved under ``save_path``
# (the folder the Dataset later reads as masks).
cellpose_dir = '/content/drive/MyDrive/DACO/CellPose'
cellpose = [x for x in os.listdir(cellpose_dir) if x.endswith('.npy')]
save_path = '/content/drive/MyDrive/DACO/CellPose_Flow_Png'
# Image.save does not create missing directories.
os.makedirs(save_path, exist_ok=True)

for x in cellpose:
    # The file stores a pickled object (unwrapped with .item()); allow_pickle
    # is required, so only load files you trust.
    seg = np.load(os.path.join(cellpose_dir, x), allow_pickle=True).item()
    # Index *before* converting: np.array(seg['flows']) raises on NumPy >= 1.24
    # when the entries of 'flows' have different shapes (ragged list).
    array = np.asarray(seg['flows'][0])[0]
    # Saturate to the uint8 range so the cast below cannot wrap around.
    array = array.clip(0, 255)
    seg_img = Image.fromarray(np.uint8(array))
    gray = seg_img.convert('L')  # conversion to gray scale
    # NOTE(review): x[:8] keeps only the first 8 characters of the file name
    # as the output id — confirm this matches the intended image ids.
    gray.save(os.path.join(save_path, x[:8] + '_flow.png'))

    del seg



# Dataset Class

In [None]:
import numpy as np
import csv
from PIL import Image
import os
# Build ``ids_classes``: one (image_id, cell_type) row per training image,
# read from train.csv (column 0 = image id, column 4 = cell type).
with open(r"../input/sartorius-cell-instance-segmentation/train.csv", newline='') as csvfile:
    # Named csv_rows (not ``csv``) so the csv module is not shadowed.
    csv_rows = list(csv.reader(csvfile))

# (id, cell_type) for every CSV row, header row included; the header's value
# never equals one of the known cell types, so it is dropped below.
classes = np.array([[row[0], row[4]] for row in csv_rows]).reshape(-1, 2)

# Training file names with their last 4 characters (presumably ".png") removed.
train = [x[:-4] for x in os.listdir('../input/sartorius-cell-instance-segmentation/train')]

# Unique ids of known cell types, in O(n); setdefault keeps the first type
# seen for an id.
id_to_class = {}
for image_id, cell_type in classes:
    if cell_type in ('cort', 'astro', 'shsy5y'):
        id_to_class.setdefault(str(image_id), str(cell_type))

# Sorted by id so row i lines up with the i-th file of
# sorted(os.listdir(<root>/train)) for any consumer that indexes it by
# position (previously rows were grouped by class, which misaligned them).
ids_classes = np.array(sorted(id_to_class.items())).reshape(-1, 2)




In [None]:
import os
import numpy as np
import torch
import skimage
from skimage import io
from PIL import Image

root_path = '../input/sartorius-cell-instance-segmentation'

# Cell types, in the order of the one-hot planes appended after the image.
CELL_TYPES = ('cort', 'astro', 'shsy5y')


def add_class_channels(img, cell_type):
    """Append one-hot cell-type planes to a single-channel image.

    Args:
        img: 2-D image array of shape (H, W).
        cell_type: one of ``CELL_TYPES``.

    Returns:
        float64 array of shape (H, W, 4): the image, then one constant plane
        per entry of ``CELL_TYPES`` (1 for ``cell_type``, 0 otherwise).
        If ``cell_type`` is not a known type, ``img`` is returned unchanged.
    """
    if cell_type not in CELL_TYPES:
        return img
    img = np.expand_dims(img, axis=-1)
    # Size the planes from the image itself instead of a fixed 520x704.
    planes = [np.full(img.shape[:2] + (1,), float(name == cell_type)) for name in CELL_TYPES]
    return np.concatenate([img, *planes], axis=-1)


def _class_lookup(id_classes):
    """Map image id -> cell type from an iterable of (id, cell_type) rows."""
    return {str(image_id): str(cell_type) for image_id, cell_type in id_classes}


#Class which allow iterate through images from the dataset
class SartoriusDataset(torch.utils.data.Dataset):
    """Training pairs (image + one-hot cell type, Cellpose flow mask).

    Images come from ``<root>/train`` and masks from
    ``<mask_root>/CellPose_Flow_Png``. Both lists are sorted and paired by
    position. The cell type of each image is looked up by id in
    ``id_classes`` (rows of (image_id, cell_type)).
    """

    def __init__(self, root, mask_root, transforms, id_classes):
        self.root = root
        self.mask_root = mask_root
        self.transforms = transforms
        self.id_classes = id_classes
        #Load all training images and masks, sorting them to ensure that they are aligned
        self.imgs = list(sorted(os.listdir(os.path.join(root, "train"))))
        # NOTE(review): the last sorted image is dropped — presumably because the
        # mask folder holds one file fewer; confirm.
        self.imgs = self.imgs[:-1]
        self.masks = list(sorted(os.listdir(os.path.join(mask_root, "CellPose_Flow_Png"))))
        # Lookup by id: positional indexing into id_classes mismatched images
        # and classes whenever the table was not in sorted-filename order.
        self._class_of = _class_lookup(id_classes)

    def __getitem__(self, idx):
        img_name = self.imgs[idx]
        img_path = os.path.join(self.root, "train", img_name)
        mask_path = os.path.join(self.mask_root, "CellPose_Flow_Png", self.masks[idx])

        img = io.imread(img_path)
        mask = io.imread(mask_path)

        image_id = os.path.splitext(img_name)[0]
        try:
            cell_type = self._class_of[image_id]
        except KeyError:
            raise KeyError(f'no cell type for training image {image_id!r} in id_classes') from None
        img = add_class_channels(img, cell_type)

        if self.transforms is not None:
            img = self.transforms(img)
            mask = self.transforms(mask)
        return img, mask

    def __len__(self):
        return len(self.imgs)


#Class which allow iterate through images from the dataset
class SartoriusTestDataset(torch.utils.data.Dataset):
    """Test images from ``<root>/test`` with appended one-hot cell-type planes.

    ``id_classes`` (optional rows of (image_id, cell_type)) supplies each
    image's cell type by id.
    """

    def __init__(self, root, transforms, id_classes=None):
        self.root = root
        self.transforms = transforms
        self.id_classes = id_classes
        self.imgs = list(sorted(os.listdir(os.path.join(root, "test"))))
        self._class_of = {} if id_classes is None else _class_lookup(id_classes)

    def __getitem__(self, idx):
        img_name = self.imgs[idx]
        img = io.imread(os.path.join(self.root, "test", img_name))

        cell_type = self._class_of.get(os.path.splitext(img_name)[0])
        if cell_type is None:
            # NOTE(review): legacy fallback kept for backward compatibility. It
            # reads the module-level *training* table by position, which does
            # not describe test images — pass id_classes for real cell types.
            cell_type = ids_classes[idx][1]
        img = add_class_channels(img, cell_type)

        if self.transforms is not None:
            img = self.transforms(img)
        return img

    def __len__(self):
        return len(self.imgs)

# U-Net Model

In [None]:
import torch
from torch import nn
from torchvision import models


class UNET(nn.Module):
    """U-Net with four pooling stages and a single-channel output.

    Args:
        in_channels: channels of the input tensor.
        init_features: width of the first encoder stage; each deeper stage
            doubles it (the bottleneck has ``init_features * 16`` channels).

    ``forward`` maps an (N, in_channels, H, W) tensor to raw logits of shape
    (N, 1, H, W). H and W need not be multiples of 16, because ``copy_crop``
    pads upsampled maps to their skip connection's size.
    """

    def __init__(self, in_channels, init_features):
        super().__init__()

        features = init_features

        # Encoder: each stage doubles the channel count.
        self.enc1 = double_conv(in_channels, features)
        self.enc2 = double_conv(features, features * 2)
        self.enc3 = double_conv(features * 2, features * 4)
        self.enc4 = double_conv(features * 4, features * 8)
        self.bottleneck = b_conv(features * 8, features * 16)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Decoder: attribute names/order kept so saved state_dicts still load.
        self.dec1 = dec(features * 8)
        self.upconv1 = upscale(features * 16, features * 8)
        self.dec2 = dec(features * 4)
        self.upconv2 = upscale(features * 8, features * 4)
        self.dec3 = dec(features * 2)
        self.upconv3 = upscale(features * 4, features * 2)
        self.dec4 = dec(features)
        self.upconv4 = upscale(features * 2, features)
        self.conv = nn.Conv2d(features, 1, kernel_size=1)  # out

    def forward(self, x):  # x input image
        # encoder
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))

        bottleneck = self.bottleneck(self.pool(enc4))

        # decoder: upsample, concatenate the skip connection, convolve
        dec1 = self.dec1(copy_crop(self.upconv1(bottleneck), enc4))
        dec2 = self.dec2(copy_crop(self.upconv2(dec1), enc3))
        dec3 = self.dec3(copy_crop(self.upconv3(dec2), enc2))
        dec4 = self.dec4(copy_crop(self.upconv4(dec3), enc1))

        return self.conv(dec4)

    @staticmethod
    def to_class(output):
        """Binarise an iterable of scalar scores: 1 if > 0.5, else 0.

        Static so that both ``UNET.to_class(x)`` and ``model.to_class(x)`` work;
        previously the missing ``self`` made the instance call raise TypeError.
        """
        return [1 if x > 0.5 else 0 for x in output]


def dec(feat):
    """Decoder block: reduce the 2*feat concatenated channels to feat.

    Structurally identical to ``double_conv(feat * 2, feat)``.
    """
    return double_conv(feat * 2, feat)


def upscale(in_f, out_f):
    """2x learnable upsampling (transposed conv, kernel 2, stride 2)."""
    return nn.ConvTranspose2d(in_f, out_f, kernel_size=2, stride=2)


def copy_crop(dec, enc):
    """Concatenate decoder map ``dec`` with skip map ``enc`` along channels.

    When an odd spatial size was floored by max-pooling, the upsampled ``dec``
    is smaller than ``enc``. It is zero-padded (split as evenly as possible,
    extra row/column on the bottom/right) so ``torch.cat`` does not fail.
    """
    diff_h = enc.size(2) - dec.size(2)
    diff_w = enc.size(3) - dec.size(3)
    if diff_h or diff_w:
        dec = nn.functional.pad(
            dec,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )
    return torch.cat((dec, enc), dim=1)


def double_conv(in_c, out_c):
    """Two (3x3 conv without bias -> BatchNorm -> ReLU) layers; size-preserving."""
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_c),
        nn.ReLU(),
        nn.Conv2d(out_c, out_c, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_c),
        nn.ReLU(),
    )


def b_conv(in_c, out_c):
    """Bottleneck: a single biased 3x3 conv + ReLU (no BatchNorm)."""
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
        nn.ReLU(),
    )

# Training Loop

In [None]:
import numpy as np
import torch
from PIL import Image
from torch import nn
import skimage
from skimage import io
from skimage import transform
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from torch.utils.data import SubsetRandomSampler
from tqdm import tqdm
import copy

# Tensor -> PIL.Image converter for ad-hoc visualisation of predictions
# (mode=None is the default: the mode is inferred from the input).
pil = transforms.ToPILImage(mode=None)

def IoU(outputs, labels):
    """Return the intersection-over-union of two binary masks.

    Inputs are treated as booleans (non-zero = foreground) and must broadcast
    together. If both masks are empty (union of 0 pixels), return 1.0 for
    perfect agreement. Previously this case computed 0/0 = nan, which poisoned
    every IoU sum it was added to.
    """
    intersection = np.logical_and(labels, outputs).sum()
    union = np.logical_or(labels, outputs).sum()
    if union == 0:
        return 1.0
    return intersection / union

# Training script: fit UNET to predict the Cellpose flow PNGs from
# (grayscale image + one-hot cell type) with MSE on sigmoid outputs. Tracks
# loss and IoU on a random 80/20 train/validation split and keeps the weights
# with the lowest validation loss.

#Path to dataset
root_path = '../input/sartorius-cell-instance-segmentation'
save_path = './'
mask_root = '../input/cellpose-flow/content/drive/MyDrive/DACO'
# ToTensor (HWC ndarray -> CHW tensor; only uint8 input is scaled to [0, 1]),
# then resize both image and mask to 512x512.
composed_transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((512, 512))])

#Dataset settings
dataset = SartoriusDataset(root_path, mask_root, composed_transform, ids_classes)

idx = [*range(len(dataset))]
training_idx, validation_idx = train_test_split(idx, test_size=0.2)

training_sampler = SubsetRandomSampler(training_idx)
validation_sampler = SubsetRandomSampler(validation_idx)

batch_size = 2
training_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=training_sampler, num_workers=0)
validation_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=validation_sampler, num_workers=0)

# 4 input channels: image + 3 one-hot cell-type planes.
model = UNET(4, 32)

criterion = nn.MSELoss()

# NOTE(review): momentum=0.01 is unusually small for SGD — confirm it is intended.
optimizer = torch.optim.SGD(model.parameters(), 0.01, momentum=0.01)
# Each scheduler.step() halves the learning rate.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.5)
lr_step_every = 10  # epochs between learning-rate halvings

#Check if GPU is available for training
train_on_gpu = torch.cuda.is_available()
if train_on_gpu:
    print("Cuda is available! Training on GPU")
    model.cuda()
else:
    print("Cuda isn't available! Training on CPU")

#Number of epochs
n_epoch = 50

training_IoU = []
training_loss = []

validation_IoU = []
validation_loss = []
lrs = []
min_valid_loss = 1e9
# Initialised up front so load_state_dict below never hits a NameError, even
# if the validation loss never improves (e.g. it is nan).
best_model = copy.deepcopy(model.state_dict())


def _batch_iou(probs, masks):
    """Return the sum of per-sample IoU over a batch, using channel 0.

    Both maps are binarised with np.round(x + 0.3), i.e. values above ~0.2
    become 1 for inputs in [0, 1].
    """
    probs = probs[:, 0].detach().cpu().numpy()
    masks = masks[:, 0].detach().cpu().numpy()
    return sum(IoU(np.round(p + 0.3), np.round(m + 0.3)) for p, m in zip(probs, masks))


def _save_history():
    """Write the metric histories to ``save_path``.

    NOTE: np.save appends '.npy' to names that do not already end in it, so
    these files are actually written as e.g. 'training_loss.npz.npy'.
    """
    np.save(os.path.join(save_path, 'training_loss.npz'), training_loss)
    np.save(os.path.join(save_path, 'val_loss.npz'), validation_loss)
    np.save(os.path.join(save_path, 'training_IoU.npz'), training_IoU)
    np.save(os.path.join(save_path, 'validation_IoU.npz'), validation_IoU)


for epoch in range(n_epoch):
    print(f'[Epoch: {epoch+1}]')

    #Training loop
    model.train()
    print('Training model...')

    t_loss = 0
    t_IoU = 0

    for img, mask in tqdm(training_dataloader):
        if train_on_gpu:
            img, mask = img.cuda(), mask.cuda()

        optimizer.zero_grad()
        img = img.float()
        output = model(img)
        probs = torch.sigmoid(output)  # computed once, reused for loss and IoU

        loss = criterion(probs, mask)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight it by the batch size so that
        # dividing by the number of samples below gives a per-sample average.
        t_loss += loss.item() * img.size(0)
        t_IoU += _batch_iou(probs, mask)
        # To visualise a prediction: display(pil(torch.round(probs[0])))

    #Validation loop
    model.eval()
    print('Validating model...')

    v_loss = 0
    v_IoU = 0

    # No autograd graph is needed for validation: saves memory and time.
    with torch.no_grad():
        for img, mask in tqdm(validation_dataloader):
            if train_on_gpu:
                img, mask = img.cuda(), mask.cuda()

            img = img.float()
            probs = torch.sigmoid(model(img))

            loss = criterion(probs, mask)
            v_loss += loss.item() * img.size(0)
            v_IoU += _batch_iou(probs, mask)

    # Learning-rate update after every ``lr_step_every`` completed epochs
    # (the old ``epoch % 10 == 0`` also fired right after the first epoch).
    if (epoch + 1) % lr_step_every == 0:
        scheduler.step()
        lrs.append(optimizer.param_groups[0]["lr"])
        print(f'Learning Rate Updated to {lrs[-1]}')

    #Average and save losses and IoU metric
    t_loss = t_loss / len(training_dataloader.sampler)
    v_loss = v_loss / len(validation_dataloader.sampler)
    t_IoU = t_IoU / len(training_dataloader.sampler)
    v_IoU = v_IoU / len(validation_dataloader.sampler)

    training_loss.append(t_loss)
    validation_loss.append(v_loss)
    training_IoU.append(t_IoU)
    validation_IoU.append(v_IoU)

    #Save the model state if validation loss has decreased
    if v_loss < min_valid_loss:
        min_valid_loss = v_loss
        print("Validation loss decreased! Saving model...")
        best_model = copy.deepcopy(model.state_dict())
        torch.save(model.state_dict(), os.path.join(save_path, 'model.pth'))

    print(f'Training loss: {t_loss}\tValidation loss: {v_loss}\tTraining IoU: {t_IoU}\tValidation IoU: {v_IoU}')
    # Saved every epoch so partial results survive an interrupted run.
    _save_history()

_save_history()
model.load_state_dict(best_model)
torch.save(model.state_dict(), os.path.join(save_path, 'model.pth'))

