<a href="https://colab.research.google.com/github/dtdat16/Vess-net-Segmentation/blob/main/VessNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive inside the Colab runtime so that the dataset, checkpoint
# and result paths under /content/gdrive used by the later cells resolve.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
import os

from torch.utils import data
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image
import torch
import numpy as np


class ImageFolder(data.Dataset):
    """Dataset pairing each image with a binarised ground-truth mask.

    ``root`` must contain two sub-directories, ``image`` and ``GT``. The mask
    of an image is looked up in ``GT`` under the same file name.
    """

    def __init__(self, root, mode='train'):
        """Collect the file names found in ``root + '/image'``.

        Args:
            root: Dataset directory holding the ``image`` and ``GT`` folders.
            mode: Name of the split; it is stored and used only in the log
                line printed below.
        """
        self.root = root
        self.image_dir = root + '/image'
        # GT : Ground Truth
        self.GT_dir = root + '/GT'

        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs.
        self.image_idx = sorted(os.listdir(self.image_dir))

        self.mode = mode

        print("image count in {} path :{}".format(
            self.mode, len(self.image_idx)))

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A tuple ``(image, GT)``: ``image`` is the output of
            ``torchvision.transforms.ToTensor`` for the image file and ``GT``
            is a 2-D ``torch.long`` tensor of 0/1 labels.
        """
        file_name = self.image_idx[index]
        image_path = self.image_dir + '/' + file_name
        GT_path = self.GT_dir + '/' + file_name

        # ToTensor reads the pixel data, so the file can be closed right after.
        with Image.open(image_path) as img:
            image = T.ToTensor()(img)

        # The mask is already an integer label map; convert it directly
        # instead of routing it through ToTensor + squeeze, which would also
        # drop a spatial dimension of size 1.
        GT = torch.from_numpy(self.get_gt(GT_path)).long()

        return image, GT

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.image_idx)

    def get_gt(self, gt_path):
        """Read the mask at ``gt_path`` and binarise it.

        Returns:
            Integer ndarray with the mask's shape: 1 where the pixel value is
            greater than 50, otherwise 0.

        Raises:
            ValueError: If the mask image is not single-channel (2-D).
        """
        with Image.open(gt_path) as img:
            img_array = np.array(img)
        return self._binarize(img_array)

    @staticmethod
    def _binarize(img_array, threshold=50):
        """Vectorised threshold: 1 where ``img_array > threshold``, else 0."""
        if img_array.ndim != 2:
            raise ValueError(
                "ground-truth mask must be a 2-D array, got shape {}".format(
                    img_array.shape))
        return (img_array > threshold).astype(int)


def get_loader(image_path, batch_size, num_workers=1, mode='train'):
    """Create an ``ImageFolder`` for ``image_path`` and wrap it in a shuffling ``DataLoader``.

    ``mode`` is only forwarded to the dataset; shuffling is enabled for every mode.
    """
    loader_options = {
        'batch_size': batch_size,
        'shuffle': True,
        'num_workers': num_workers,
    }
    return data.DataLoader(
        dataset=ImageFolder(root=image_path, mode=mode), **loader_options)

In [None]:
import torch
import torch.nn as nn

class conv_block(nn.Module):
    """3x3 convolution (stride 1, padding 1) followed by batch normalisation.

    No activation is applied inside this block.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        layers = [
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
        ]
        # Kept as a Sequential named ``conv`` so state_dict keys stay ``conv.0.*`` / ``conv.1.*``.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply convolution then batch norm to ``x``."""
        return self.conv(x)

class IRSP(nn.Module):
    """1x1 convolution (stride 1, no padding) followed by batch normalisation.

    Changes the channel count from ``ch_in`` to ``ch_out`` without touching the spatial size.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        pointwise = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0, bias=True)
        norm = nn.BatchNorm2d(ch_out)
        # Kept as a Sequential named ``irsp`` so state_dict keys stay ``irsp.0.*`` / ``irsp.1.*``.
        self.irsp = nn.Sequential(pointwise, norm)

    def forward(self, x):
        """Apply the 1x1 convolution then batch norm to ``x``."""
        return self.irsp(x)

class Vess_Net(nn.Module):
    """Encoder/decoder segmentation network with pooling-index unpooling.

    The encoder has four stages, each ending in a 2x2 max-pool that also
    returns its indices. The decoder mirrors it, restoring resolution with
    ``MaxUnpool2d`` driven by those stored indices. Stages 2-4 on both sides
    add a 1x1 ``IRSP`` projection of the stage input to the stage output, and
    every decoder stage adds the activation of the first convolution of the
    matching encoder stage.

    Sub-module names and their creation order are kept exactly as they were
    (including the ``irps_*`` spelling) because they determine the keys of
    the state dict.
    """

    def __init__(self, input_ch=3, output_ch=2):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU(inplace=True)

        # ---- encoder: input_ch -> 64 -> 128 -> 256 -> 512 ----
        self.econv_1_1 = conv_block(input_ch, 64)
        self.econv_1_2 = conv_block(64, 64)

        self.econv_2_1 = conv_block(64, 128)
        self.irps_1 = IRSP(64, 128)
        self.econv_2_2 = conv_block(128, 128)

        self.econv_3_1 = conv_block(128, 256)
        self.irps_2 = IRSP(128, 256)
        self.econv_3_2 = conv_block(256, 256)

        self.econv_4_1 = conv_block(256, 512)
        self.irps_3 = IRSP(256, 512)
        self.econv_4_2 = conv_block(512, 512)

        # ---- decoder: 512 -> 256 -> 128 -> 64 -> output_ch ----
        self.dconv_4_2 = conv_block(512, 512)
        self.irps_4 = IRSP(512, 256)
        self.dconv_4_1 = conv_block(512, 256)

        self.dconv_3_2 = conv_block(256, 256)
        self.irps_5 = IRSP(256, 128)
        self.dconv_3_1 = conv_block(256, 128)

        self.dconv_2_2 = conv_block(128, 128)
        self.irps_6 = IRSP(128, 64)
        self.dconv_2_1 = conv_block(128, 64)

        self.dconv_1_2 = conv_block(64, 64)
        self.dconv_1_1 = conv_block(64, output_ch)

    def forward(self, x):
        """Run the network and return the ReLU-activated ``output_ch``-channel map.

        No softmax is applied here.
        """
        act = self.relu

        # Encoder stage 1
        enc1_a = act(self.econv_1_1(x))
        enc1_b = act(self.econv_1_2(enc1_a))
        size_1 = enc1_b.size()
        down_1, idx_1 = self.pool(enc1_b)

        # Encoder stage 2 (1x1 residual projection of the pooled input)
        res_1 = self.irps_1(down_1)
        enc2_a = act(self.econv_2_1(down_1))
        enc2_b = act(self.econv_2_2(enc2_a) + res_1)
        size_2 = enc2_b.size()
        down_2, idx_2 = self.pool(enc2_b)

        # Encoder stage 3
        res_2 = self.irps_2(down_2)
        enc3_a = act(self.econv_3_1(down_2))
        enc3_b = act(self.econv_3_2(enc3_a) + res_2)
        size_3 = enc3_b.size()
        down_3, idx_3 = self.pool(enc3_b)

        # Encoder stage 4
        res_3 = self.irps_3(down_3)
        enc4_a = act(self.econv_4_1(down_3))
        enc4_b = act(self.econv_4_2(enc4_a) + res_3)
        size_4 = enc4_b.size()
        down_4, idx_4 = self.pool(enc4_b)

        # Decoder stage 4: unpool, add encoder skip, then residual projection
        up_4 = self.unpool(down_4, idx_4, output_size=size_4)
        res_4 = self.irps_4(up_4)
        dec_4 = act(self.dconv_4_2(up_4)) + enc4_a
        dec_4 = act(self.dconv_4_1(dec_4) + res_4)

        # Decoder stage 3
        up_3 = self.unpool(dec_4, idx_3, output_size=size_3)
        res_5 = self.irps_5(up_3)
        dec_3 = act(self.dconv_3_2(up_3)) + enc3_a
        dec_3 = act(self.dconv_3_1(dec_3) + res_5)

        # Decoder stage 2
        up_2 = self.unpool(dec_3, idx_2, output_size=size_2)
        res_6 = self.irps_6(up_2)
        dec_2 = act(self.dconv_2_2(up_2)) + enc2_a
        dec_2 = act(self.dconv_2_1(dec_2) + res_6)

        # Decoder stage 1 (no residual projection at full resolution)
        up_1 = self.unpool(dec_2, idx_1, output_size=size_1)
        dec_1 = act(self.dconv_1_2(up_1)) + enc1_a
        return act(self.dconv_1_1(dec_1))

In [None]:
import torch

# SR : Segmentation Result
# GT : Ground Truth

def get_accuracy(SR,GT):
    """Return the fraction of elements for which ``SR`` equals ``GT``.

    Args:
        SR: Segmentation result tensor (predicted labels).
        GT: Ground-truth tensor compared element-wise against ``SR``.

    Returns:
        A float in [0, 1]; 0.0 when ``SR`` has no elements (instead of
        raising ``ZeroDivisionError``).
    """
    total = SR.numel()
    if total == 0:
        return 0.0

    corr = torch.sum(SR == GT)
    return float(corr) / float(total)

def get_sensitivity(SR,GT,threshold=0.5):
    """Return sensitivity (recall): TP / (TP + FN) for 0/1 label tensors.

    Args:
        SR: Predicted labels; elements equal to 1 count as positive, 0 as negative.
        GT: Ground-truth labels with the same encoding.
        threshold: Unused; kept so existing call sites keep working.

    Returns:
        A float in [0, 1]; 0.0 when ``GT`` contains no positive element
        (instead of raising ``ZeroDivisionError``).
    """
    # TP : True Positive
    # FN : False Negative
    TP = (SR == 1) & (GT == 1)
    FN = (SR == 0) & (GT == 1)

    # The two masks are disjoint, so summing their counts equals counting
    # their union and avoids arithmetic on bool tensors.
    true_pos = float(torch.sum(TP))
    positives = true_pos + float(torch.sum(FN))
    if positives == 0:
        return 0.0

    return true_pos / positives

def get_specificity(SR,GT,threshold=0.5):
    """Return specificity: TN / (TN + FP) for 0/1 label tensors.

    Args:
        SR: Predicted labels; elements equal to 1 count as positive, 0 as negative.
        GT: Ground-truth labels with the same encoding.
        threshold: Unused; kept so existing call sites keep working.

    Returns:
        A float in [0, 1]; 0.0 when ``GT`` contains no negative element
        (instead of raising ``ZeroDivisionError``).
    """
    # TN : True Negative
    # FP : False Positive
    TN = (SR == 0) & (GT == 0)
    FP = (SR == 1) & (GT == 0)

    # The two masks are disjoint, so summing their counts equals counting
    # their union and avoids arithmetic on bool tensors.
    true_neg = float(torch.sum(TN))
    negatives = true_neg + float(torch.sum(FP))
    if negatives == 0:
        return 0.0

    return true_neg / negatives

def get_JS(SR,GT,threshold=0.5):
    """Return the Jaccard similarity |SR & GT| / |SR | GT| of two 0/1 masks.

    Args:
        SR: Predicted mask (integer or bool tensor; combined with ``&`` / ``|``).
        GT: Ground-truth mask of the same kind.
        threshold: Unused; kept so existing call sites keep working.

    Returns:
        A float in [0, 1]; 0.0 when the union is empty (instead of raising
        ``ZeroDivisionError``).
    """
    Inter = float(torch.sum(SR & GT))
    Union = float(torch.sum(SR | GT))
    if Union == 0:
        return 0.0

    return Inter / Union

def get_DC(SR,GT,threshold=0.5):
    """Return the Dice coefficient 2*|SR & GT| / (|SR| + |GT|) of two 0/1 masks.

    Args:
        SR: Predicted mask (integer or bool tensor; combined with ``&``).
        GT: Ground-truth mask of the same kind.
        threshold: Unused; kept so existing call sites keep working.

    Returns:
        A float in [0, 1]; 0.0 when both masks are empty (instead of raising
        ``ZeroDivisionError``).
    """
    Inter = torch.sum(SR & GT)
    total = float(torch.sum(SR) + torch.sum(GT))
    if total == 0:
        return 0.0

    return float(2 * Inter) / total

In [None]:
import os
import torch
import torchvision
from torch import optim
import csv
import pickle
from pathlib import Path
from PIL import Image
from matplotlib import cm
import numpy as np

def save_model(model, epoch, best_model_score,lr, optimizer, state_dict_dir, checkpoint_dir):
    """Write the model weights and a pickled resume checkpoint to fixed file names.

    The weights go to ``state_dict_without_norm_50.pth`` inside
    ``state_dict_dir``; the epoch, best score, learning rate and optimizer
    state go to ``checkpoint_without_norm_50.pickle`` inside
    ``checkpoint_dir``. Both files are overwritten on every call.
    """
    print("save model in the each epoch")

    resume_state = {
        'epoch': epoch,
        'best_model_score': best_model_score,
        'lr': lr,
        'optimizer': optimizer.state_dict()
    }

    weights_path = f'{state_dict_dir}/state_dict_without_norm_50.pth'
    torch.save(model.state_dict(), weights_path)

    file_name = checkpoint_dir + '/' + 'checkpoint_without_norm_50.pickle'
    with open(file_name, 'wb+') as handle:
        pickle.dump(resume_state, handle)
          
def load_model_checkpoint(file_name, checkpoint_dir):
    """Read a pickled checkpoint and return its resume fields.

    Returns:
        Dict with the keys ``epoch``, ``best_model_score``, ``lr`` and ``optimizer``.

    Raises:
        FileNotFoundError: If ``checkpoint_dir + '/' + file_name`` does not exist.
    """
    file_path = checkpoint_dir + '/' + file_name
    print('loading model contained in the file', file_path)

    if not Path(file_path).exists():
        raise FileNotFoundError('Model file not found')

    # NOTE: pickle.load can execute arbitrary code from the file; only load
    # checkpoints that were written by this notebook.
    with open(file_path, 'rb') as handle:
        stored = pickle.load(handle)

    wanted_keys = ('epoch', 'best_model_score', 'lr', 'optimizer')
    return {key: stored[key] for key in wanted_keys}
        
def optimizer_to(optim, device):
    """Move every tensor held in ``optim.state`` (and its gradient, if any) to ``device``.

    Handles tensors stored directly in the state mapping as well as tensors
    nested one level down inside per-parameter dicts. Works in place and
    returns ``None``.
    """
    def _relocate(tensor):
        # Re-point the storage rather than replacing the tensor object, so
        # references held elsewhere stay valid.
        tensor.data = tensor.data.to(device)
        if tensor._grad is not None:
            tensor._grad.data = tensor._grad.data.to(device)

    for entry in optim.state.values():
        if isinstance(entry, torch.Tensor):
            _relocate(entry)
        elif isinstance(entry, dict):
            for value in entry.values():
                if isinstance(value, torch.Tensor):
                    _relocate(value)

def save_img(img_array, save_path):
    """Colour ``img_array`` with matplotlib's ``gist_earth`` colormap and save it to ``save_path``.

    The colormap output is scaled by 255 and cast to ``uint8`` before being handed to PIL.
    """
    coloured = cm.gist_earth(img_array)
    pixels = np.uint8(coloured * 255)
    Image.fromarray(pixels).save(save_path)


In [None]:
import torch

def _batch_metrics(preds, GT):
    """Return (acc, SE, SP, JS, DC) for one batch of predicted and reference labels."""
    preds = torch.squeeze(preds)
    GT = torch.squeeze(GT)
    return (
        get_accuracy(preds, GT),
        get_sensitivity(preds, GT),
        get_specificity(preds, GT),
        get_JS(preds, GT),
        get_DC(preds, GT),
    )


def _average(totals, count):
    """Divide each accumulated metric by ``count``; all zeros when ``count`` is 0."""
    if count == 0:
        return [0.0 for _ in totals]
    return [total / count for total in totals]


def train(model, optimizer, train_loader,valid_loader, state_dict_dir,checkpoint_dir, lr, current_epoch, num_epochs, device, criterion, best_model_score, net_path, result_path):
    """Train ``model``, validating and checkpointing after every epoch.

    For each epoch from ``current_epoch`` up to ``num_epochs - 1`` this runs
    one pass over ``train_loader`` with back-propagation, one pass over
    ``valid_loader`` without gradients, saves the weights to ``net_path``
    whenever the validation score (JS + DC) beats ``best_model_score``, and
    finally calls ``save_model`` so the run can be resumed.

    Args:
        model: Network returning per-class scores with the class axis at dim 1.
        optimizer: Optimizer already bound to ``model``'s parameters.
        train_loader: Iterable of ``(images, GT)`` training batches.
        valid_loader: Iterable of ``(images, GT)`` validation batches.
        state_dict_dir: Directory passed through to ``save_model``.
        checkpoint_dir: Directory passed through to ``save_model``.
        lr: Learning rate; only recorded in the checkpoint.
        current_epoch: Zero-based epoch to start from.
        num_epochs: Total number of epochs.
        device: Device the batches are moved to.
        criterion: Loss called as ``criterion(scores, GT)`` on the raw model
            output (the file uses ``torch.nn.CrossEntropyLoss``).
        best_model_score: Best validation score seen so far.
        net_path: File the best weights are written to.
        result_path: Directory of ``result-validate.csv``.
    """
    while current_epoch < num_epochs:
        epoch = current_epoch

        #====================================== Training ===========================================#
        model.train(True)
        epoch_loss = 0
        totals = [0.0] * 5   # acc, SE, SP, JS, DC
        num_batches = 0

        for images, GT in train_loader:
            # GT : Ground Truth
            images = images.to(device)
            GT = GT.to(device)

            # SR : Segmentation Result (raw per-class scores)
            SR = model(images)

            # The loss receives the raw scores: CrossEntropyLoss applies
            # log-softmax itself, so feeding it softmax output would
            # normalise twice and flatten the gradients.
            loss = criterion(SR, GT)
            epoch_loss += loss.item()

            # Backprop + optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Softmax is monotonic, so the arg-max of the scores is the
            # arg-max of the probabilities.
            _, preds = torch.max(SR.detach(), 1)
            for k, value in enumerate(_batch_metrics(preds, GT)):
                totals[k] += value
            # Each metric is computed once per batch, so average over
            # batches (identical to the image count when batch_size is 1).
            num_batches += 1

        acc, SE, SP, JS, DC = _average(totals, num_batches)

        # Print the log info
        print('Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, JS: %.4f, DC: %.4f' % (
          epoch+1, num_epochs,
          epoch_loss,
          acc, SE, SP, JS, DC))

        #===================================== Validation ====================================#
        model.train(False)
        model.eval()

        totals = [0.0] * 5
        num_batches = 0

        # No gradients are needed here; skipping them avoids building the
        # autograd graph for every validation image.
        with torch.no_grad():
            for images, GT in valid_loader:
                images = images.to(device)
                GT = GT.to(device)

                SR = model(images)
                SR_prob = torch.softmax(SR, dim= 1)
                _, preds = torch.max(SR_prob, 1)

                for k, value in enumerate(_batch_metrics(preds, GT)):
                    totals[k] += value
                num_batches += 1

        acc, SE, SP, JS, DC = _average(totals, num_batches)
        model_score = JS + DC

        print('[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, JS: %.4f, DC: %.4f' % (
                acc, SE, SP, JS, DC))

        # Save the best model seen so far
        if model_score > best_model_score:
            best_model_score = model_score # float type
            best_epoch = epoch
            best_net = model.state_dict()
            print('Best model score : %.4f' %
                  (best_model_score))
            torch.save(best_net, net_path)

            # Context manager closes the file even if the write fails.
            with open(os.path.join(result_path, 'result-validate.csv'),
                      'a', encoding='utf-8', newline='') as f:
                csv.writer(f).writerow([best_epoch, num_epochs])

        save_model(model,current_epoch,best_model_score,lr,optimizer,state_dict_dir,checkpoint_dir)
        current_epoch += 1

In [None]:
import torch.optim as optim
import torch
# Build the model, optimizer and loss for a 3-channel input / 2-class output.
model = Vess_Net(input_ch= 3, output_ch= 2)
lr = 0.0005
print('build_model')
optimizer = optim.Adam(model.parameters(), lr= lr, eps= 0.000001)
criterion = torch.nn.CrossEntropyLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Output locations on the mounted Drive.
state_dict_dir= '/content/gdrive/MyDrive/train_data_hrf_vess_net/state_dicts'
result_path = '/content/gdrive/MyDrive/train_data_hrf_vess_net/result'
net_path =  '/content/gdrive/MyDrive/train_data_hrf_vess_net/models/best_model_without_norm_50.pth'
checkpoint_dir= '/content/gdrive/MyDrive/train_data_hrf_vess_net/checkpoints'

# Defaults for a fresh run; overwritten below when a checkpoint exists.
current_epoch= 0
num_epochs= 100
best_model_score= 0

train_path = '/content/gdrive/MyDrive/train_data_hrf_vess_net/training'
train_loader = get_loader(train_path, 1, num_workers= 4)
val_path = '/content/gdrive/MyDrive/train_data_hrf_vess_net/validate'
valid_loader = get_loader(val_path, 1, num_workers= 4, mode='valid')

# Resume from the last checkpoint if there is one.
try:
    checkpoint = load_model_checkpoint('checkpoint_without_norm_50.pickle', checkpoint_dir)
    print('checkpoint found. loading state...')
    # The checkpoint stores the last finished epoch, so continue with the next one.
    current_epoch = checkpoint['epoch'] + 1
    best_model_score = checkpoint['best_model_score']
    lr = checkpoint['lr']
    optimizer.load_state_dict(checkpoint['optimizer'])
    optimizer_to(optimizer,device)
    state_dict_path = f'{state_dict_dir}/state_dict_without_norm_50.pth'
    # map_location lets weights saved on a GPU runtime load on a CPU-only
    # runtime (and the other way round) instead of failing in torch.load.
    model.load_state_dict(torch.load(state_dict_path, map_location=device))
    print('checkpoint state loaded successfully')

except FileNotFoundError:
    print('the first training')
print(f'current_epoch: {current_epoch}, best_unet_score: {best_model_score}, lr: {lr}')

model.to(device)
train(model, optimizer, train_loader,valid_loader, state_dict_dir,checkpoint_dir, lr, current_epoch, num_epochs, device, criterion, best_model_score, net_path, result_path)

In [None]:
def test(model, test_loader,result_path,device):
    """Evaluate ``model`` on ``test_loader``, save per-batch images and log the metrics.

    For batch ``i`` the input image, the predicted mask and the ground-truth
    mask are written into ``result_path/<i>/``. The averaged metrics are
    printed and appended as one row to ``result_path/result.csv``.

    Args:
        model: Network returning per-class scores with the class axis at dim 1.
        test_loader: Iterable of ``(images, GT)`` batches. The mask export
            squeezes the batch axis, so it presumably expects batch size 1
            - confirm before using larger batches.
        result_path: Directory receiving the images and ``result.csv``.
        device: Device the batches are moved to.
    """
    model.train(False)
    model.eval()

    acc = 0.  # Accuracy
    SE = 0.		# Sensitivity (Recall)
    SP = 0.		# Specificity
    PC = 0. 	# Precision
    F1 = 0.		# F1 Score
    JS = 0.		# Jaccard Similarity
    DC = 0.		# Dice Coefficient
    num_batches = 0

    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        for i, (images, GT) in enumerate(test_loader):

            images = images.to(device)
            GT = GT.to(device)
            SR = model(images)
            SR_prob = torch.softmax(SR, dim= 1)
            _,preds = torch.max(SR_prob, 1)
            preds = torch.squeeze(preds)
            GT = torch.squeeze(GT)
            acc += get_accuracy(preds, GT)
            SE += get_sensitivity(preds, GT)
            SP += get_specificity(preds, GT)
            # NOTE(review): get_precision and get_F1 are not defined in the
            # visible part of this file; they must be provided elsewhere.
            PC += get_precision(preds, GT)
            F1 += get_F1(preds, GT)
            JS += get_JS(preds, GT)
            DC += get_DC(preds, GT)

            # Each metric is computed once per batch, so average over
            # batches (identical to the image count when batch_size is 1).
            num_batches += 1

            # .type(torch.FloatTensor) also brings the tensors back to the CPU.
            preds = preds.type(torch.FloatTensor)
            GT = GT.type(torch.FloatTensor)

            # Create the per-sample folder first; saving into a missing
            # directory would fail.
            sample_dir = result_path + f'/{i}'
            os.makedirs(sample_dir, exist_ok=True)
            torchvision.utils.save_image(images.data.cpu(),os.path.join(sample_dir,f'test_image_{i}_50.png'))
            save_img(preds.numpy(),os.path.join(sample_dir,f'test_SR_{i}_50.png'))
            save_img(GT.numpy(),os.path.join(sample_dir,f'test_GT_{i}_50.png'))

    if num_batches:
        acc = acc/num_batches
        SE = SE/num_batches
        SP = SP/num_batches
        PC = PC/num_batches
        F1 = F1/num_batches
        JS = JS/num_batches
        DC = DC/num_batches
    print('[TEST] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % (
                acc, SE, SP, PC, F1, JS, DC))

    # Context manager closes the file even if the write fails.
    with open(os.path.join(result_path, 'result.csv'),
              'a', encoding='utf-8', newline='') as f:
        csv.writer(f).writerow([acc, SE, SP, PC, F1, JS, DC])

In [None]:
# Evaluate the saved weights on the STARE test folder.
test_loader = get_loader(
    image_path='/content/gdrive/MyDrive/STARE/test',
    batch_size=1,
    num_workers=0,
)

# Weights are first loaded onto the CPU, then the whole model is moved to the
# device that is actually available.
model = Vess_Net(input_ch= 3, output_ch= 2)
model.load_state_dict(
    torch.load(
        '/content/gdrive/MyDrive/STARE/state_dicts/state_dict_without_norm_50.pth',
        map_location=torch.device('cpu'),
    )
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
result_path = '/content/gdrive/MyDrive/STARE/result'
model.to(device)

test(model, test_loader, result_path,device)

image count in train path :20
[TEST] Acc: 0.9698, SE: 0.7632, SP: 0.9883, PC: 0.8843, F1: 0.8004, JS: 0.6803, DC: 0.8004
