In [None]:
# -nc (no-clobber): re-running this cell reuses dataset.zip instead of saving dataset.zip.1
!wget -nc https://ml.gan4x4.ru/msu/students/dubrovin/dataset.zip

In [None]:
# -n: never overwrite existing files, so a re-run does not block on unzip's interactive
#     "replace?" prompt; -q: keep the cell output short.
!unzip -n -q dataset.zip

In [None]:
# %pip installs into the running kernel's environment (!pip may target another interpreter).
# albumentations is imported below, so it is installed here as well.
%pip install segmentation-models-pytorch lightning torchmetrics albumentations

In [None]:
from math import nan
from torch.utils.data import Dataset
from glob import glob
import os
from pathlib import Path
import torch
import pandas as pd
import matplotlib.pyplot as plt
import json
import cv2
import numpy as np
import warnings
import ast
import pickle
import torchvision.transforms.functional as F
from statistics import mode
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
from torchvision.transforms import ToTensor


class CustomDataset2chdm(Dataset):
    """Two-channel height dataset: raw microscope heights + binary globule mask.

    ``root_dir`` must contain ``Images/<name>.txt`` (raw microscope export) and
    ``Masks/_<name>.txt`` (continuous globule height map stored as a Python literal).
    Only names present in both folders and not listed in ``exclude`` are used.

    Parsed arrays are memoised in ``self.cache`` (a dict keyed by file path), which
    may be shared between several datasets and persisted with ``save_cache``.
    """

    def __init__(self, root_dir, transform=None, target_transform=None, exclude=(), cache=None):
        images = glob(f"{root_dir}{os.sep}Images{os.sep}*")
        labels = glob(f"{root_dir}{os.sep}Masks{os.sep}_*")  # read only continuous heights
        # Sample ids: "Images/<id>.txt" and "Masks/_<id>.txt".
        im = {Path(p).stem for p in images}
        lab = {Path(p).stem[1:] for p in labels}
        img_without_masks = im - lab
        if img_without_masks:
            warnings.warn(f"Found images without masks {','.join(sorted(img_without_masks))}")
        self.items = sorted((im & lab) - set(exclude))
        self.root_dir = root_dir
        self.transform = transform
        self.target_transform = target_transform
        self.max_height = 100
        # The given dict is used as-is (not copied), so datasets can share one cache.
        self.cache = {} if cache is None else cache

    def get_all_scales(self):
        """Return an array with the pixel size (real width / pixel width) of every item."""
        scales = []
        for name in self.items:
            image, real_w = self.txt2pil(self.get_im_path(name))
            scales.append(self.get_scale(image, real_w))
        return np.array(scales)

    def get_scale(self, img, real_w):
        """Real width per pixel column of `img`."""
        return real_w / img.shape[1]

    def get_im_path(self, name):
        """Path of the raw height file for sample `name`."""
        return f"{self.root_dir}{os.sep}Images{os.sep}{name}.txt"

    def get_mask_path(self, name):
        """Path of the continuous height-map mask for sample `name`."""
        return f"{self.root_dir}{os.sep}Masks{os.sep}_{name}.txt"

    def save_cache(self, filename="cache.pickle"):
        """Pickle the parsed-file cache to `filename`."""
        with open(filename, 'wb') as f:
            pickle.dump(self.cache, f, pickle.HIGHEST_PROTOCOL)

    def load_cache(self, filename="cache.pickle"):
        """Replace the cache with the one pickled in `filename`, if that file exists.

        Only load trusted files: unpickling can execute arbitrary code.
        """
        if os.path.isfile(filename):
            with open(filename, 'rb') as f:
                self.cache = pickle.load(f)
            print(f"Loaded cache from {filename}")

    def __len__(self):
        return len(self.items)

    def line2tensor(self, line):
        """Parse one tab-separated row of numbers into a 1-D float tensor.

        Empty fields are ignored; rows with two or fewer fields yield None
        (treated as non-data).
        """
        parts = [p for p in line.strip().split("\t") if p]
        if len(parts) <= 2:
            return None
        return torch.tensor([float(p) for p in parts])

    def txt2pil(self, filename):
        """Parse a raw microscope export into a 2-D height array.

        Returns ``(heights, real_w)``: a float32 numpy array shifted so its minimum
        is 0, and the X extent converted to microns ("A°" divided by 1e4, "nm" by 1e3).
        A fresh copy is returned on every call, so callers may modify it freely.

        Raises ValueError if the header has no X coordinates or no data rows exist.
        """
        if filename in self.cache and "image" in self.cache[filename]:
            return self.cache[filename]["image"].copy(), self.cache[filename]["real_w"]

        with open(filename, encoding='unicode_escape') as file:
            x_line = file.readline()  # X header: units at [3:5], coordinates from [6:]
            x = self.line2tensor(x_line[6:])
            if x is None:
                raise ValueError(f"{filename}: no X coordinates found in the header line")
            real_w = (x.max() - x.min()).item()
            units = x_line[3:5]
            if units == "A°":
                real_w = real_w / 10000
            if units == "nm":
                real_w = real_w / 1000
            file.readline()  # Y/Z header line - skipped
            rows = []
            for line in file:
                if line == '\n':  # trailing empty line
                    continue
                pos = line.index('\t')  # end of the Y coordinate
                row = self.line2tensor(line[pos + 2:])  # skip Y and the two tabs after it
                if row is not None:
                    rows.append(row)
        if not rows:
            raise ValueError(f"{filename}: no height rows found")
        heights = torch.stack(rows)
        # Shift to zero: heights are only differences relative to an arbitrary point.
        heights = (heights - heights.min()).numpy()
        self.cache[filename] = {"image": heights, "real_w": real_w}
        return heights.copy(), real_w

    def load_heights(self, path):
        """Get heights of some points marked by a human (Excel sheet)."""
        df = pd.read_excel(path)
        # NOTE(review): `fix_format` is not defined in this class, so this call
        # raises AttributeError until such a method is provided.
        return self.fix_format(df)

    def get_height_map(self, path):
        """Load the continuous height map at `path` (a Python literal); returns a copy.

        The parsed array is cached under `path`; the copy keeps in-place transforms
        applied by callers from altering the cached data.
        """
        if not (path in self.cache and "mask" in self.cache[path]):
            with open(path, 'r') as file:
                content = file.read()
            self.cache[path] = {"mask": np.array(ast.literal_eval(content))}
        return self.cache[path]["mask"].copy()

    def __getitem__(self, n):
        """Return ``(im_mask, image, mask, meta)`` for sample `n`.

        image   - raw heights from the microscope as a (1, H, W) tensor
        mask    - continuous globule height map (after ``target_transform`` if set)
        im_mask - ``image`` concatenated with the binary (mask != 0) map on dim 0
        meta    - {"w": real width in microns, "name": sample id, "scale_factor": 0}
        """
        name = self.items[n]
        image, real_w = self.txt2pil(self.get_im_path(name))
        mask = self.get_height_map(self.get_mask_path(name))

        scale_factor = 0

        if self.transform:
            output = self.transform(image=image, mask=mask)
            image = output['image']
            mask = output['mask']
        # Without a tensor-producing transform the data is still numpy; convert it so
        # the torch ops below work (no-op for tensors).
        image = torch.as_tensor(image)
        if image.ndim == 2:
            image = image.unsqueeze(0)  # (H, W) -> (1, H, W)
        mask = torch.as_tensor(mask)
        if self.target_transform:
            mask = self.target_transform(mask)
        meta = {"w": real_w, 'name': name, "scale_factor": scale_factor}
        binary_mask = (mask != 0).to(image.dtype).unsqueeze(0)
        im_mask = torch.cat((image, binary_mask), 0)  # channel 0: image, channel 1: mask
        return im_mask, image, mask, meta

    def rescale(self, img, real_w, target_scale=0.00389862060546875):
        """Resize `img` so that its pixel size (real width per pixel) equals `target_scale`.

        The default is the most common pixel size in the data set.
        Returns ``(resized_img, original_size, resize_coeff)`` where ``original_size``
        is ``(h, w)`` and the new size is ``original_size / resize_coeff``.
        """
        h, w = img.shape[:2]
        original_size = (h, w)
        scale = self.get_scale(img, real_w)
        resize_coeff = 1
        if target_scale != scale:
            resize_coeff = target_scale / scale
        new_h, new_w = (np.array(original_size) / resize_coeff).astype(int).tolist()
        # cv2.resize takes dsize as (width, height) - the reverse of numpy's shape order.
        img = cv2.resize(img, (new_w, new_h))
        return img, original_size, resize_coeff


In [None]:
from tqdm import tqdm
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [None]:
# Mean/std of the input height images (presumably computed over the min-shifted heights
# returned by CustomDataset2chdm.txt2pil).
mean, std = [8.489298], [9.06547]

# albumentations' Normalize computes (img - mean * max_pixel_value) / (std * max_pixel_value)
# and max_pixel_value defaults to 255 (8-bit images). These inputs are raw heights, so
# max_pixel_value=1.0 is required for an actual z-score.
# NOTE: this changes the network inputs; checkpoints trained with the previous pipeline
# are not interchangeable with models trained after this change.
train_transforms = A.Compose(
    [
        A.Normalize(mean, std, max_pixel_value=1.0),
        A.RandomCrop(192, 192),
        ToTensorV2(),
    ]
)

val_transforms = A.Compose(
    [
        A.Normalize(mean, std, max_pixel_value=1.0),
        A.CenterCrop(192, 192),
        ToTensorV2(),
    ]
)

class NormalizeNonZero:
    """Standardise a height tensor while keeping zero (background) pixels at zero.

    ``__call__`` maps non-zero entries to ``(x - mean) / std`` and returns float32;
    ``denorm`` maps non-zero entries back with ``x * std + mean``.
    Neither method modifies its argument (the earlier in-place version mutated the
    caller's tensor and failed on integer tensors).

    The mapping is not perfectly invertible: a non-zero value equal to ``mean``
    normalises to 0 and is then treated as background by ``denorm``.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, x):
        background = x == 0
        normalized = (x - self.mean) / self.std
        return normalized.masked_fill(background, 0).to(torch.float32)

    def denorm(self, x):
        background = x == 0
        restored = x * self.std + self.mean
        return restored.masked_fill(background, 0)

# Target-normalisation statistics; two variants were computed:
#   whole (non-dotted) mask: mean=1.7816125217702499, std=2.170761470939972
#   dotted mask (in use):    mean=3.016509424749255,  std=2.452459479074767
dot_target_mean = 3.016509424749255
dot_target_std = 2.452459479074767

# Shared instance: also used by the Lightning module to de-normalise predictions.
nnz = NormalizeNonZero(dot_target_mean, dot_target_std)
target_transform = T.Compose([nnz])


In [None]:
%load_ext autoreload
%autoreload 2
def load_cache(filename="cachedm.pickle"):
    """Return the object pickled in `filename`, or None if that file does not exist.

    Prints a confirmation when something is loaded. Only use on trusted files:
    unpickling can execute arbitrary code.
    """
    if not os.path.isfile(filename):
        return None
    with open(filename, 'rb') as handle:
        loaded = pickle.load(handle)
    print(f"Loaded cache from {filename}")
    return loaded


cache = load_cache()

In [None]:
DATA_ROOT = "Splitted Dataset"


def _make_split(split, transform, batch_size, shuffle):
    """Build the dataset for one split (sharing `cache`) and its 2-worker DataLoader.

    NOTE(review): cache entries filled inside DataLoader worker processes presumably
    do not propagate back to the main process's `cache` dict.
    """
    dataset = CustomDataset2chdm(
        f"{DATA_ROOT}/{split}",
        transform=transform,
        target_transform=target_transform,
        cache=cache,
    )
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=2)
    return dataset, loader


ds_train, loader_train = _make_split("Train", train_transforms, batch_size=16, shuffle=True)
ds_val, loader_val = _make_split("Val", val_transforms, batch_size=4, shuffle=False)
ds_test, loader_test = _make_split("Test", val_transforms, batch_size=4, shuffle=False)

In [None]:
from torchmetrics import MeanSquaredError
from typing import Any, Optional, Sequence, Union
from torch import Tensor

class ZeroAwareMSE(MeanSquaredError):
    """Mean squared error restricted to pixels whose target is non-zero.

    Zero target pixels are background and are ignored. A batch with no non-zero
    target pixel adds nothing to the metric state. The previous implementation fed
    all-ones preds/targets instead, which added ``numel()`` zero-error observations
    and biased the MSE downwards. It also detected "empty" via ``target.sum() == 0``,
    which misfires on normalised targets that contain negative values.
    If every update is empty, ``compute()`` divides by zero observations.
    """

    def update(self, preds: Tensor, target: Tensor) -> None:
        labelled = target != 0
        if not bool(labelled.any()):
            return
        super().update(preds[labelled], target[labelled])

In [None]:
def MSELoss_mask(pred, target):
    """Squared error summed over all pixels, divided by the number of non-zero target pixels.

    Predictions outside the mask are expected to be zeroed already (Lit.postprocess
    does this), so in practice this is the MSE over the labelled pixels.

    Always returns a 0-dim tensor. If `target` has no non-zero pixel the result is
    ``pred.sum() * 0``: zero, yet still attached to the autograd graph, so
    ``.item()`` and ``.backward()`` keep working (the old version returned the int 0).
    """
    # torch.nn.functional is used directly, so this no longer depends on `MSELoss`
    # being imported in a later cell.
    sum_sq_err = torch.nn.functional.mse_loss(pred, target, reduction='sum')
    n_labelled = torch.count_nonzero(target)
    if n_labelled == 0:
        return pred.sum() * 0.0
    return sum_sq_err / n_labelled

import torch.nn as nn

class CustomLoss(nn.Module):
    """nn.Module wrapper around MSELoss_mask (squared error normalised by labelled pixels)."""

    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        loss = MSELoss_mask(pred, target)
        return loss

In [None]:
import lightning as L
from torchmetrics import MeanSquaredError
from torch.nn import MSELoss
import torch

class Lit(L.LightningModule):
    """LightningModule that trains `model` to regress normalised globule heights.

    Batches are ``(im_mask, image, mask, meta)`` tuples as produced by
    CustomDataset2chdm. The network sees the 2-channel ``im_mask``, and its output is
    zeroed wherever the target ``mask`` is 0, so loss and metrics reflect labelled
    pixels only. The ``*_dn`` metrics are computed after ``nnz.denorm`` (module-level
    NormalizeNonZero instance), i.e. in de-normalised target units.
    """

    def __init__(self, model, lr=0.0025):
        super().__init__()
        self.model = model
        self.lr = lr
        self.criterion = CustomLoss()
        # Also records `model` among the hyperparameters. Lightning warns about that, but
        # it lets `Lit.load_from_checkpoint(path)` rebuild the network without an
        # explicit `model` argument.
        self.save_hyperparameters()
        self.metric_train = ZeroAwareMSE()
        self.metric_train_dn = ZeroAwareMSE()  # currently never updated
        self.metric_val = ZeroAwareMSE()
        self.metric_val_dn = ZeroAwareMSE()
        self.metric_test = ZeroAwareMSE()
        self.metric_test_dn = ZeroAwareMSE()

    def configure_optimizers(self):
        """AdamW over all parameters with learning rate `self.lr`."""
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

    def training_step(self, batch, batch_idx):
        _, _, masks, _ = batch
        predicted_masks = self.get_prediction(batch)
        loss = self.criterion(predicted_masks, masks)
        if not torch.is_tensor(loss):
            # A plain number (e.g. 0 for a batch without labelled pixels) has no graph to
            # back-propagate; returning None makes Lightning skip this batch.
            return None
        # Lightning accepts (and detaches) tensors, so no .item() is needed here.
        self.log("Loss/", loss, prog_bar=False)
        self.metric_train.update(predicted_masks, masks)
        return loss

    def get_prediction(self, batch):
        """Run the model on the 2-channel input and mask its output with the target."""
        im_masks, _, masks, _ = batch
        predicted_masks = self.model(im_masks).squeeze(1)
        return self.postprocess(predicted_masks, masks)

    def postprocess(self, pred, mask):
        """Drop a singleton channel dim and set `pred` to 0 wherever `mask` is 0."""
        pred = pred.squeeze(1)
        # Out-of-place, instead of index assignment on the network output.
        return pred.masked_fill(mask == 0, 0)

    def validation_step(self, batch, batch_idx):
        _, _, masks, _ = batch
        predicted_masks = self.get_prediction(batch)
        self.metric_val.update(predicted_masks, masks)
        # denorm gets copies so it cannot alter the batch tensors even if it works in place.
        self.metric_val_dn.update(nnz.denorm(predicted_masks.clone()), nnz.denorm(masks.clone()))

    def on_validation_epoch_end(self):
        self.log("MSE/val", self.metric_val.compute(), prog_bar=True)
        self.log("MSE/val_dn", self.metric_val_dn.compute(), prog_bar=True)
        self.metric_val.reset()
        self.metric_val_dn.reset()

    def on_train_epoch_end(self):
        self.log("MSE/train", self.metric_train.compute())
        self.metric_train.reset()

    def test_step(self, batch, batch_idx):
        # Metrics are only accumulated here; on_test_epoch_end computes and logs them.
        # (Previously compute() was called before the first update.)
        _, _, masks, _ = batch
        predicted_masks = self.get_prediction(batch)
        self.metric_test.update(predicted_masks, masks)
        self.metric_test_dn.update(nnz.denorm(predicted_masks.clone()), nnz.denorm(masks.clone()))

    def on_test_epoch_end(self):
        self.log("MSE/test", self.metric_test.compute(), prog_bar=True)
        self.log("MSE/test_dn", self.metric_test_dn.compute(), prog_bar=True)
        self.metric_test.reset()
        self.metric_test_dn.reset()

In [None]:
import segmentation_models_pytorch as smp

# U-Net with an EfficientNet-B0 encoder, trained from scratch (no pretrained weights).
# Input: 2 channels (image + binary mask, see CustomDataset2chdm); output: 1 channel.
smp_unet = smp.Unet(
    classes=1,
    in_channels=2,
    encoder_weights=None,
    encoder_name="efficientnet-b0",
)

In [None]:
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint

# Keep the 5 checkpoints with the lowest validation MSE (normalised units).
checkpoint_callback = ModelCheckpoint(
    dirpath='lightning_logs',
    monitor='MSE/val',
    mode='min',
    save_top_k=5,
)

lit_model = Lit(smp_unet)
# To resume from a saved checkpoint instead, replace the line above with:
#   lit_model = Lit.load_from_checkpoint("<path/to/checkpoint.ckpt>")

logger = TensorBoardLogger("lightning_logs", name="SMPUnet")
trainer = L.Trainer(
    callbacks=[checkpoint_callback],
    logger=logger,
    log_every_n_steps=5,
    max_epochs=1000,
)

In [None]:
# Train; the logged "MSE/val" metric drives checkpoint_callback.
trainer.fit(lit_model, train_dataloaders=loader_train, val_dataloaders=loader_val)

In [None]:
# NOTE(review): this evaluates the weights left after the final training epoch, not the
# best "MSE/val" checkpoint kept by checkpoint_callback.
trainer.test(lit_model, dataloaders=loader_test)

In [None]:
# Persist the current (last-epoch) weights next to the top-k checkpoints.
trainer.save_checkpoint(os.path.join("lightning_logs", "last_epoch_checkpoint.ckpt"))