In [1]:
%load_ext autoreload
%autoreload 2
import os
from os.path import splitext
from os import listdir
from glob import glob
import gc
import random
import logging
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import Dataset, DataLoader, random_split

import albumentations as A
from torchvision.models.segmentation import fcn_resnet101
from torchvision.models.segmentation.fcn import FCNHead
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from scipy.ndimage import morphology

from classes import Material, BasicDataset

import pandas as pd




In [2]:
inputs = {
"materials" :[
             Material("background", [85,85,85], 30, 0.5),
             Material("epidermis", [170,170,170], 150, 0.5),
             Material("mesophyll", [255,255,255], 255, 0.5),
             Material("air_space", [0,0,0], 1, 0.5),
             Material("bundle_sheath_extension", [103,103,103], 100, 0.5),
             Material("vein", (35,35,35), 180, 0.5)
            ],
#Various input/output directories
"training_image_directory" : "train/train_images/",
"training_mask_directory" : "train/train_masks/",
#Fraction of total annotations you want to leave for validating the model.
"validation_fraction": 0.2,
#Model Performance varies, make multiple models to have the best chance at success.
"num_models" : 1,
#Model Performance improves with increasing epochs, to a point.
"num_epochs" : 100,
"batch_size" : 1,
#Decrease scale to decrease VRAM usage; if you run out of VRAM during traing, restart your runtime and down scale your images
"scale" : 1,
"seed" : 0,
"models_directory" : "best_models/",
"model_group" : 'test/',
"current_model_name" : 'test',
"test_images" : "test/test_images/",
"test_masks": "test/test_masks/",
"csv_directory" : "other/",
#Input the directory of the data you want to segment here.
"inference_directory": "other/",
#Input the 5 alpha-numeric characters proceding the file number of your images
  #EX. Jmic3111_S0_GRID image_0.tif ----->mage_
"proceeding":"lice_",
#Input the 4 or mor alpha-numeric characters following the file number
  #EX. Jmic3111_S0_GRID image_0.tif ----->.tif
"following" : ".png",
"output_directory": "out/"
}


In [3]:
class SegModel():
    def __init__(self, init_dict):
        self.__dict__ = init_dict
        self.num_materials = len(init_dict["materials"])
        self.dir_checkpoint = self.models_directory
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        os.makedirs(self.dir_checkpoint+self.model_group, exist_ok=True )
        
    def set_up_model(self):
        self.model = fcn_resnet101(pretrained=True, progress=True)
        self.model.classifier=FCNHead(2048, self.num_materials)
        self.criterion = nn.BCEWithLogitsLoss()
        return self.model
    
    def get_data(self):
        return BasicDataset(self.training_image_directory, self.training_mask_directory, self.materials, scale=self.scale, transform=False)
    
    def get_loader(self, dataset):
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=0, pin_memory=True)
    
    def trainval_split(self, dataset):
        validation_size = int(len(dataset) * self.validation_fraction)
        train_size = len(dataset) - validation_size
        train, val = torch.utils.data.random_split(dataset, [train_size, validation_size], generator=torch.Generator().manual_seed(self.seed))
 
        return train, val

    def setup_data(self):
        dataset = self.get_data()
        train_loader = self.get_loader(dataset)
        nimages = 0
        mean = 0. 
        std = 0.
        for batch, _ in train_loader:
            # Rearrange batch to be the shape of [B, C, W * H]
            batch = batch.view(batch.size(0), batch.size(1), -1)
            # Update total number of images
            nimages += batch.size(0)
            # Compute mean and std here
            mean += batch.mean(2).sum(0) 
            std += batch.std(2).sum(0)

        # Final step
        mean /= nimages
        std /= nimages

        print(mean)
        print(std)

        dataset.means=mean
        dataset.stds=std 
        
        return dataset
    
    def train_dataloader(self):
        return self.get_loader(self.dataset_train)

    def val_dataloader(self):
        return self.get_loader(self.dataset_val)

    def test_dataloader(self):
        return self.get_loader(self.dataset_val)

    def forward(self, x):
        return self.model(x)
    
    def backward(self, loss, optimizer, optimizer_idx):
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        return optimizer
    
    def training_step(self, batch, batch_idx):
        img, mask = batch
        pred = self.model(img)
        loss = self.criterion(pred, mask)
        return loss
    
    def validation_step(self, batch, batch_idx):
        img, mask = batch
        pred = self.model(img)
        loss = self.criterion(pred, mask)
        return loss
    
    def test_step(self, batch, batch_idx):
        img, mask = batch
        pred = self.model(img)
        loss = self.criterion(pred, mask)
        return loss
    
    def predict_step(self, batch, batch_idx):
        img, mask = batch
        pred = self.model(img)
        return pred
    
    def train(self,verbose = True):
        self.model = self.set_up_model() 
        self.dataset = self.get_data()
        self.dataset_train, self.dataset_val = self.trainval_split(self.dataset)

        self.train_loader = self.get_loader(self.dataset_train)
        self.val_loader = self.get_loader(self.dataset_val)

        self.model.to(self.device)

        num_epochs= self.num_epochs
        optimizer = self.configure_optimizers()

        best_loss=999

        criterion = nn.BCEWithLogitsLoss()

        #this is the train loop
        for epoch in range(num_epochs):
            #print(psutil.virtual_memory().percent)
            if verbose:
                print('Epoch: ', str(epoch))
        #add back if doing fractional training
            self.train_loader.dataset.dataset.transform=True
            self.model.train()
            for images, masks in self.train_loader:

                images = images.to(device=self.device, dtype=torch.float32)
                masks = masks.to(device=self.device, dtype=torch.float32)

                #forward pass
                preds=self.model(images)['out'].cuda()

                #compute loss
                loss=criterion(preds, masks)

                #reset the optimizer gradients to 0
                optimizer.zero_grad()

                #backward pass (compute gradients)
                loss.backward()

                #use the computed gradients to update model weights
                optimizer.step()

            if verbose:
                print('Train loss: '+str(loss.to('cpu').detach()))

        self.val_loader.dataset.dataset.transform=False
        current_loss=0

        #test on val set and save the best checkpoint
        self.model.eval()
        with torch.no_grad():
            for images, masks in self.val_loader:
                images = images.to(device=self.device, dtype=torch.float32)
                masks = masks.to(device=self.device, dtype=torch.float32)
                preds = self.model(images)['out'].cuda()

                loss = criterion(preds, masks)
                current_loss+=loss.to('cpu').detach()
                del images, masks, preds, loss
 
        if best_loss>current_loss:
            best_loss=current_loss
            print('Best Model Saved!, loss: '+ str(best_loss))
            torch.save(self.model.state_dict(), self.dir_checkpoint + self.model_group + self.current_model_name+".pth")
        else:
            print('Model is bad!, Current loss: '+ str(current_loss) + ' Best loss: '+str(best_loss))
            print('\n')
    def validation(self): 
        prop_list = []
        for mat in self.materials:
            prop_list.append([[],[],[],[]])

        for images, target in self.val_loader:
            images = images.to(device=self.device, dtype=torch.float32)
            target = target.to(device=self.device, dtype=torch.float32)

            with torch.no_grad():
                pred=self.model(images)['out'].cuda()
                pred=nn.Sigmoid()(pred)

            for i, mat in enumerate(self.materials):
                material_target = target[:,i,:,:]
                material_pred = pred[:, i, :, :]
                material_pred[material_pred >=mat.confidence_threshold] = 1
                material_pred[material_pred <=mat.confidence_threshold] = 0
                pred[:, i, :, :]=material_pred

                material_tp=torch.sum(material_target*material_pred, (1,2))
                material_fp=torch.sum((1-material_target)*material_pred, (1,2))
                material_fn=torch.sum(material_target*(1-material_pred), (1,2))
                material_tn=torch.sum((1-material_target)*(1-material_pred), (1,2))

                material_precision=torch.mean((material_tp+0.000000001)/(material_tp+material_fp+0.000000001))
                material_recall=torch.mean((material_tp+0.000000001)/(material_tp+material_fn+0.000000001))
                material_accuracy=torch.mean((material_tp+material_tn+0.000000001)/(material_tp+material_tn+material_fp+material_fn+0.000000001))
                material_f1=torch.mean(((material_tp+0.000000001))/(material_tp++0.000000001+0.5*(material_fp+material_fn)))

                prop_list[i][0].append(material_precision.cpu().detach().numpy())
                prop_list[i][1].append(material_recall.cpu().detach().numpy())
                prop_list[i][2].append(material_accuracy.cpu().detach().numpy())
                prop_list[i][3].append(material_f1.cpu().detach().numpy())

        properties = {"name" : [mat.name for mat in self.materials],
                "precision" : [str(np.mean(prop_list[i][0])) for i in range(self.num_materials)],
                "recall" : [str(np.mean(prop_list[i][1])) for i in range(self.num_materials)],
                "accuracy" : [str(np.mean(prop_list[i][2])) for i in range(self.num_materials)],
                "f1" : [str(np.mean(prop_list[i][3])) for i in range(self.num_materials)]}
        self.modeldata = pd.DataFrame(properties, columns = ["name", "precision", "recall", "accuracy", "f1"])



In [4]:
%%timeit
Leaf = SegModel(inputs)
Leaf.num_epochs = 5
Leaf.train(verbose = True)
Leaf.validation()

Epoch:  0
Train loss: tensor(0.1634)
Epoch:  1
Train loss: tensor(0.1381)
Epoch:  2
Train loss: tensor(0.1311)
Epoch:  3
Train loss: tensor(0.1219)
Epoch:  4
Train loss: tensor(0.1494)
Best Model Saved!, loss: tensor(4.2514)
Epoch:  0
Train loss: tensor(0.1693)
Epoch:  1
Train loss: tensor(0.1390)
Epoch:  2
Train loss: tensor(0.1708)
Epoch:  3
Train loss: tensor(0.1475)
Epoch:  4
Train loss: tensor(0.1258)
Best Model Saved!, loss: tensor(2.6964)
Epoch:  0
Train loss: tensor(0.1659)
Epoch:  1
Train loss: tensor(0.1527)
Epoch:  2
Train loss: tensor(0.1554)
Epoch:  3
Train loss: tensor(0.1194)
Epoch:  4
Train loss: tensor(0.1191)
Best Model Saved!, loss: tensor(3.3628)
Epoch:  0
Train loss: tensor(0.1680)
Epoch:  1
Train loss: tensor(0.1510)
Epoch:  2
Train loss: tensor(0.1427)
Epoch:  3
Train loss: tensor(0.1180)
Epoch:  4
Train loss: tensor(0.1267)
Best Model Saved!, loss: tensor(8.9043)
Epoch:  0
Train loss: tensor(0.1499)
Epoch:  1
Train loss: tensor(0.2032)
Epoch:  2
Train loss: tens

In [5]:
Leaf.modeldata

NameError: name 'Leaf' is not defined