In [None]:
!pip install einops


In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torchsummary import summary



from torchvision import models
import torch.nn as nn

import numpy as np
from tqdm import tqdm_notebook as tqdm
import cv2
from sklearn import metrics
from sklearn.metrics import jaccard_score as js
from PIL import Image, ImageOps
from torch.autograd import Variable as v
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd
import os
import pickle



In [None]:
# Colab-only: mount Google Drive at /content/drive.
# NOTE(review): presumably the dataset root passed to Eye_Dataset lives under this mount — confirm.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
def randomHueSaturationValue(image, hue_shift_limit=(-180, 180),
                             sat_shift_limit=(-255, 255),
                             val_shift_limit=(-255, 255), u=0.5):
    """Randomly jitter hue, saturation and value of a BGR image.

    With probability ``u`` the image is converted to HSV, each channel is
    shifted by a random amount drawn from the matching ``*_limit`` range,
    and the result is converted back to BGR. Otherwise the input object is
    returned unchanged.

    Args:
        image: 8-bit BGR image as produced by ``cv2.imread``.
        hue_shift_limit: inclusive ``(low, high)`` integer range for the hue
            shift, expressed in OpenCV 8-bit hue units.
        sat_shift_limit: ``(low, high)`` range for the saturation shift.
        val_shift_limit: ``(low, high)`` range for the value shift.
        u: probability of applying the augmentation.

    Returns:
        The augmented BGR image, or ``image`` itself when no augmentation
        was applied.
    """
    if np.random.random() < u:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        h, s, v = cv2.split(image)

        hue_shift = np.random.randint(hue_shift_limit[0], hue_shift_limit[1] + 1)
        # 8-bit OpenCV hue is stored in [0, 180). Rotate on that circle in a
        # wider integer type: casting a negative shift to uint8 fails on
        # current NumPy, and uint8 wrap-around at 256 is not a hue rotation.
        h = ((h.astype(np.int16) + hue_shift % 180) % 180).astype(np.uint8)

        # cv2.add saturates at the uint8 bounds instead of wrapping.
        sat_shift = np.random.uniform(sat_shift_limit[0], sat_shift_limit[1])
        s = cv2.add(s, sat_shift)
        val_shift = np.random.uniform(val_shift_limit[0], val_shift_limit[1])
        v = cv2.add(v, val_shift)

        image = cv2.merge((h, s, v))
        image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
    return image

def randomShiftScaleRotate(image, mask,
                           shift_limit=(-0.0, 0.0),
                           scale_limit=(-0.0, 0.0),
                           rotate_limit=(-0.0, 0.0), 
                           aspect_limit=(-0.0, 0.0),
                           borderMode=cv2.BORDER_CONSTANT, u=0.5):
    """Apply one random shift/scale/rotate warp to an image and its mask.

    With probability ``u`` a single perspective transform is sampled and
    applied to both ``image`` and ``mask`` so they stay aligned. Otherwise
    both inputs are returned unchanged.

    Args:
        image: array of shape ``(height, width, channels)``.
        mask: array warped with the same transform as ``image``.
        shift_limit: ``(low, high)`` translation range as a fraction of the
            image width (x) and height (y).
        scale_limit: ``(low, high)`` range added to 1 to get the scale.
        rotate_limit: ``(low, high)`` rotation range in degrees.
        aspect_limit: ``(low, high)`` range added to 1 to get the aspect ratio.
        borderMode: OpenCV border mode used by ``cv2.warpPerspective``.
        u: probability of applying the augmentation.

    Returns:
        ``(image, mask)`` after the (possibly skipped) warp.
    """
    # ``np.math`` was only an alias of the stdlib module and no longer exists
    # in NumPy 2.x, so use ``math`` directly (numerically identical).
    import math

    if np.random.random() < u:
        height, width, _ = image.shape

        angle = np.random.uniform(rotate_limit[0], rotate_limit[1])
        scale = np.random.uniform(1 + scale_limit[0], 1 + scale_limit[1])
        aspect = np.random.uniform(1 + aspect_limit[0], 1 + aspect_limit[1])
        sx = scale * aspect / (aspect ** 0.5)
        sy = scale / (aspect ** 0.5)
        dx = round(np.random.uniform(shift_limit[0], shift_limit[1]) * width)
        dy = round(np.random.uniform(shift_limit[0], shift_limit[1]) * height)

        theta = angle / 180 * math.pi
        cc = math.cos(theta) * sx
        ss = math.sin(theta) * sy
        rotate_matrix = np.array([[cc, -ss], [ss, cc]])

        # Transform the four image corners about the image centre, then let
        # OpenCV derive the perspective matrix mapping old to new corners.
        box0 = np.array([[0, 0], [width, 0], [width, height], [0, height], ])
        box1 = box0 - np.array([width / 2, height / 2])
        box1 = np.dot(box1, rotate_matrix.T) + np.array([width / 2 + dx, height / 2 + dy])
        box0 = box0.astype(np.float32)
        box1 = box1.astype(np.float32)
        mat = cv2.getPerspectiveTransform(box0, box1)

        image = cv2.warpPerspective(image, mat, (width, height), flags=cv2.INTER_LINEAR,
                                    borderMode=borderMode, borderValue=(0, 0, 0,))
        mask = cv2.warpPerspective(mask, mat, (width, height), flags=cv2.INTER_LINEAR,
                                   borderMode=borderMode, borderValue=(0, 0, 0,))
    return image, mask

def randomFlip(image, mask, u=0.5):
    """Flip image and mask together with probability ``u``.

    Both arrays are passed to ``cv2.flip`` with flip code 1; when the flip is
    skipped the inputs are returned untouched.
    """
    if not (np.random.random() < u):
        return image, mask
    flipped_image = cv2.flip(image, 1)
    flipped_mask = cv2.flip(mask, 1)
    return flipped_image, flipped_mask

#def randomRotate90
#def randomMove

def randomRotate90(image, mask, u=0.5):
    """Rotate image and mask together by 90 degrees with probability ``u``.

    The rotation is ``np.rot90`` with its default arguments; when it is
    skipped the inputs are returned untouched.
    """
    if not (np.random.random() < u):
        return image, mask
    return np.rot90(image), np.rot90(mask)
      
def loader(img_path, mask_path, phase):
    """Load one image/mask pair from disk and convert it to network-ready arrays.

    The image is read with ``cv2.imread`` defaults and the mask as a
    single-channel grayscale image. Both are resized to 256x256 with area
    interpolation. The image is scaled to [0, 1] and moved to channel-first
    layout; the mask is scaled to [0, 1], given a leading channel axis and
    binarised at 0.5.

    Train-time augmentation (randomHueSaturationValue, randomShiftScaleRotate,
    randomFlip, randomRotate90) is currently disabled, so ``phase`` does not
    change the result; the parameter is kept for API compatibility.

    Args:
        img_path: path of the input image.
        mask_path: path of the ground-truth mask.
        phase: dataset phase string (currently unused).

    Returns:
        ``(img, mask)``: float32 arrays of shape ``(C, 256, 256)`` and
        ``(1, 256, 256)``; mask values are exactly 0 or 1.

    Raises:
        FileNotFoundError: if either file cannot be read by OpenCV.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError("Could not read image: {}".format(img_path))
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError("Could not read mask: {}".format(mask_path))

    # The third positional parameter of cv2.resize is `dst`, not the
    # interpolation flag, so the flag must be passed by keyword to take effect.
    img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_AREA)
    mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_AREA)

    img = np.array(img, np.float32).transpose(2, 0, 1) / 255.0

    mask = np.array(mask, np.float32)[np.newaxis, :, :] / 255.0
    # Binarise: resizing leaves intermediate gray values along mask edges.
    mask = (mask >= 0.5).astype(np.float32)
    return img, mask

def read_dataset(root_path, mode):
    """Collect sorted image and mask file paths for one dataset split.

    ``mode == 'Train'`` selects the ``Training`` sub-tree, any other value the
    ``Test`` sub-tree. Images and masks are listed and sorted independently,
    so pairing by index relies on both folders sorting into the same order.

    Returns:
        ``(images, masks)``: two lists of file paths.
    """
    if mode == 'Train':
        sub_dirs = ('Training/Images', 'Training/ODMask')
    else:
        sub_dirs = ('Test/Images', 'Test/ODMask')
    image_root = os.path.join(root_path, sub_dirs[0])
    gt_root = os.path.join(root_path, sub_dirs[1])

    images = [os.path.join(image_root, name) for name in sorted(os.listdir(image_root))]
    masks = [os.path.join(gt_root, name) for name in sorted(os.listdir(gt_root))]
    return images, masks

class Eye_Dataset(Dataset):
    """Torch dataset yielding (image, mask) float tensors for one split.

    File paths are gathered once at construction via ``read_dataset`` and
    printed; individual samples are loaded lazily through ``loader``.
    """

    def __init__(self, root_path, phase):
        self.root, self.phase = root_path, phase
        self.images, self.labels = read_dataset(self.root, self.phase)
        print('images: ', self.images)
        print('labels: ', self.labels)

    def __getitem__(self, index):
        # loader returns the (image, mask) pair as numpy arrays.
        arrays = loader(self.images[index], self.labels[index], self.phase)
        img, mask = (torch.tensor(arr, dtype=torch.float32) for arr in arrays)
        return img, mask

    def __len__(self):
        n_images = len(self.images)
        assert n_images == len(self.labels), 'The number of images must be equal to labels'
        return n_images

In [None]:
import torch
import torch.nn as nn


class FixedPositionalEncoding(nn.Module):
    """Sinusoidal (non-learned) positional encoding for batch-first tokens.

    A table of ``max_length`` encodings is precomputed and stored in the
    ``pe`` buffer with shape ``(max_length, 1, embedding_dim)``.

    Args:
        embedding_dim: size of the token embedding (odd values are supported).
        max_length: number of positions to precompute.
    """

    def __init__(self, embedding_dim, max_length=5000):
        super(FixedPositionalEncoding, self).__init__()

        pe = torch.zeros(max_length, embedding_dim)
        position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embedding_dim, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / embedding_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd embedding_dim the cosine half has one column fewer than
        # the sine half, so trim the frequencies to match.
        pe[:, 1::2] = torch.cos(position * div_term[: embedding_dim // 2])
        # Buffer layout (max_length, 1, dim) is kept unchanged so existing
        # state_dicts still load.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding for each position to ``x`` of shape (batch, seq, dim).

        The table is indexed by the sequence axis (dim 1). Indexing by dim 0,
        as before, gave every token in a sample the same vector and a
        different one per batch element.
        """
        return x + self.pe[: x.size(1)].transpose(0, 1)


class LearnedPositionalEncoding(nn.Module):
    """Trainable positional encoding backed by an embedding table.

    Args:
        max_position_embeddings: number of rows in the embedding table.
        embedding_dim: size of each positional vector.
        seq_length: number of leading positions used when ``forward`` is
            called without explicit ``position_ids``.
    """

    def __init__(self, max_position_embeddings, embedding_dim, seq_length):
        super().__init__()
        self.pe = nn.Embedding(max_position_embeddings, embedding_dim)
        self.seq_length = seq_length

        # Row vector 0..max-1, stored as a buffer so it follows the module's device.
        all_ids = torch.arange(max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", all_ids)

    def forward(self, x, position_ids=None):
        """Return ``x`` plus the embeddings looked up for ``position_ids``."""
        ids = position_ids
        if ids is None:
            ids = self.position_ids[:, : self.seq_length]
        return x + self.pe(ids)

In [None]:
class IntermediateSequential(nn.Sequential):
    """``nn.Sequential`` that can also return every child's output.

    With ``return_intermediate=True`` (the default) ``forward`` returns
    ``(final_output, outputs_by_child_name)``; otherwise it behaves exactly
    like ``nn.Sequential``.
    """

    def __init__(self, *args, return_intermediate=True):
        super().__init__(*args)
        self.return_intermediate = return_intermediate

    def forward(self, input):
        if self.return_intermediate:
            collected = {}
            current = input
            for child_name, layer in self.named_children():
                current = layer(current)
                collected[child_name] = current
            return current, collected

        return super().forward(input)

In [None]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over (batch, tokens, dim) input.

    Args:
        dim: token embedding size.
        heads: number of attention heads.
        qkv_bias: whether the fused q/k/v projection has a bias term.
        qk_scale: optional override of the score scale; a falsy value selects
            ``(dim // heads) ** -0.5``.
        dropout_rate: dropout applied to the attention weights and to the
            projected output.
    """

    def __init__(
        self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0
    ):
        super().__init__()
        self.num_heads = heads
        per_head = dim // heads
        self.scale = qk_scale or per_head ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(dropout_rate)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout_rate)

    def forward(self, x):
        batch, tokens, channels = x.shape
        per_head = channels // self.num_heads

        # Fused projection -> (3, batch, heads, tokens, per_head).
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, per_head)
        packed = packed.permute(2, 0, 3, 1, 4)
        # Plain indexing rather than tuple unpacking keeps this scriptable.
        q = packed[0]
        k = packed[1]
        v = packed[2]

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        # Merge the heads back into the channel axis.
        merged = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))


class Residual(nn.Module):
    """Wrap ``fn`` with a skip connection: ``forward(x) = fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        return branch + x


class PreNorm(nn.Module):
    """Apply ``LayerNorm(dim)`` to the input, then call ``fn`` on the result."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        normalized = self.norm(x)
        return self.fn(normalized)


class PreNormDrop(nn.Module):
    """LayerNorm the input, call ``fn``, then apply dropout to ``fn``'s output."""

    def __init__(self, dim, dropout_rate, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fn = fn

    def forward(self, x):
        transformed = self.fn(self.norm(x))
        return self.dropout(transformed)


class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size.
        hidden_dim: width of the hidden layer.
        dropout_rate: probability used by both dropout layers.
    """

    def __init__(self, dim, hidden_dim, dropout_rate):
        super().__init__()
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(p=dropout_rate),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)


class TransformerModel(nn.Module):
    """Stack of ``depth`` pre-norm transformer encoder layers.

    Each layer contributes two residual sub-blocks to the stack: a
    self-attention block (LayerNorm -> attention -> dropout) followed by an
    MLP block (LayerNorm -> feed-forward). The stack is an
    ``IntermediateSequential``, so ``forward`` returns whatever that
    container returns (with its defaults, the final output together with a
    dict of every sub-block's output).

    Args:
        dim: token embedding size.
        depth: number of encoder layers.
        heads: number of attention heads.
        mlp_dim: hidden width of the feed-forward block.
        dropout_rate: dropout after attention and inside the feed-forward block.
        attn_dropout_rate: dropout used inside the attention module.
    """

    def __init__(
        self,
        dim,
        depth,
        heads,
        mlp_dim,
        dropout_rate=0.1,
        attn_dropout_rate=0.1,
    ):
        super().__init__()
        blocks = []
        for _ in range(depth):
            attention = SelfAttention(
                dim, heads=heads, dropout_rate=attn_dropout_rate
            )
            blocks.append(Residual(PreNormDrop(dim, dropout_rate, attention)))

            mlp = FeedForward(dim, mlp_dim, dropout_rate)
            blocks.append(Residual(PreNorm(dim, mlp)))
        self.net = IntermediateSequential(*blocks)

    def forward(self, x):
        return self.net(x)

In [None]:
class SegmentationTransformer(nn.Module):
    """Base class of a ViT-style encoder for segmentation heads.

    An image is cut into ``patch_dim`` x ``patch_dim`` patches, each patch is
    embedded (linear projection, or a strided convolution when
    ``conv_patch_representation`` is set), positional encodings are added and
    the tokens go through a ``TransformerModel``. Subclasses supply the
    decoder via ``_init_decode`` and ``decode``.

    Note: a class with the same name is defined again in a later cell of this
    notebook; when cells run in order, that later definition replaces this one.

    Args:
        img_dim: side length of the (square) input image.
        patch_dim: side length of a patch; must divide ``img_dim``.
        num_channels: number of input image channels.
        embedding_dim: token embedding size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        num_layers: number of transformer layers.
        hidden_dim: hidden width of the transformer feed-forward blocks.
        dropout_rate: dropout after the positional encoding and in the transformer.
        attn_dropout_rate: dropout inside the attention modules.
        conv_patch_representation: embed patches with a strided Conv2d.
        positional_encoding_type: ``"learned"`` or ``"fixed"``.

    Raises:
        ValueError: if ``positional_encoding_type`` is not a supported value.
    """

    def __init__(
        self,
        img_dim,
        patch_dim,
        num_channels,
        embedding_dim,
        num_heads,
        num_layers,
        hidden_dim,
        dropout_rate=0.0,
        attn_dropout_rate=0.0,
        conv_patch_representation=False,
        positional_encoding_type="learned",
    ):
        super(SegmentationTransformer, self).__init__()

        assert embedding_dim % num_heads == 0
        assert img_dim % patch_dim == 0
        # Fail at construction: an unknown type used to leave
        # `position_encoding` unset and only surfaced later inside encode().
        if positional_encoding_type not in ("learned", "fixed"):
            raise ValueError(
                "positional_encoding_type must be 'learned' or 'fixed', got {!r}".format(
                    positional_encoding_type
                )
            )

        self.img_dim = img_dim
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.patch_dim = patch_dim
        self.num_channels = num_channels
        self.dropout_rate = dropout_rate
        self.attn_dropout_rate = attn_dropout_rate
        self.conv_patch_representation = conv_patch_representation

        self.num_patches = int((img_dim // patch_dim) ** 2)
        self.seq_length = self.num_patches
        self.flatten_dim = patch_dim * patch_dim * num_channels

        self.linear_encoding = nn.Linear(self.flatten_dim, embedding_dim)
        if positional_encoding_type == "learned":
            self.position_encoding = LearnedPositionalEncoding(
                self.seq_length, self.embedding_dim, self.seq_length
            )
        else:  # "fixed"
            self.position_encoding = FixedPositionalEncoding(
                self.embedding_dim,
            )

        self.pe_dropout = nn.Dropout(p=self.dropout_rate)

        self.transformer = TransformerModel(
            embedding_dim,
            num_layers,
            num_heads,
            hidden_dim,
            self.dropout_rate,
            self.attn_dropout_rate,
        )
        self.pre_head_ln = nn.LayerNorm(embedding_dim)

        if self.conv_patch_representation:
            self.conv_x = nn.Conv2d(
                self.num_channels,
                self.embedding_dim,
                kernel_size=(self.patch_dim, self.patch_dim),
                stride=(self.patch_dim, self.patch_dim),
                padding=self._get_padding(
                    'VALID', (self.patch_dim, self.patch_dim),
                ),
            )
        else:
            self.conv_x = None

    def _init_decode(self):
        """Create the decoder layers; must be provided by subclasses."""
        raise NotImplementedError("Should be implemented in child class!!")

    def encode(self, x):
        """Embed an image batch ``(n, c, h, w)`` and run the transformer.

        Returns:
            ``(tokens, intermediates)``: the layer-normalised final tokens and
            the second value returned by the transformer stack.
        """
        n, c, h, w = x.shape
        if self.conv_patch_representation:
            # combine embedding w/ conv patch distribution
            x = self.conv_x(x)
            x = x.permute(0, 2, 3, 1).contiguous()
            x = x.view(x.size(0), -1, self.embedding_dim)
        else:
            # Cut into non-overlapping patches, then flatten each patch to
            # one vector of length patch_dim * patch_dim * channels.
            x = (
                x.unfold(2, self.patch_dim, self.patch_dim)
                .unfold(3, self.patch_dim, self.patch_dim)
                .contiguous()
            )
            x = x.view(n, c, -1, self.patch_dim ** 2)
            x = x.permute(0, 2, 3, 1).contiguous()
            x = x.view(x.size(0), -1, self.flatten_dim)
            x = self.linear_encoding(x)

        x = self.position_encoding(x)
        x = self.pe_dropout(x)

        # apply transformer
        x, intmd_x = self.transformer(x)
        x = self.pre_head_ln(x)

        return x, intmd_x

    def decode(self, x, intmd_x=None, intmd_layers=None):
        """Turn encoder tokens into a prediction; must be provided by subclasses.

        The signature mirrors the three-argument call made by ``forward``.
        """
        raise NotImplementedError("Should be implemented in child class!!")

    def forward(self, x, auxillary_output_layers=None):
        """Encode, decode and optionally return selected intermediate outputs.

        When ``auxillary_output_layers`` is given, the result is
        ``(decoder_output, aux)`` where ``aux['Z<i>']`` is the intermediate
        encoder output stored under key ``str(2 * i - 1)``.
        """
        encoder_output, intmd_encoder_outputs = self.encode(x)
        decoder_output = self.decode(
            encoder_output, intmd_encoder_outputs, auxillary_output_layers
        )

        if auxillary_output_layers is not None:
            auxillary_outputs = {}
            for i in auxillary_output_layers:
                val = str(2 * i - 1)
                _key = 'Z' + str(i)
                auxillary_outputs[_key] = intmd_encoder_outputs[val]

            return decoder_output, auxillary_outputs

        return decoder_output

    def _get_padding(self, padding_type, kernel_size):
        """Return per-dimension padding for ``'SAME'`` or ``'VALID'`` convolutions."""
        assert padding_type in ['SAME', 'VALID']
        if padding_type == 'SAME':
            _list = [(k - 1) // 2 for k in kernel_size]
            return tuple(_list)
        return tuple(0 for _ in kernel_size)

    def _reshape_output(self, x):
        """Reshape tokens ``(n, patches, dim)`` into a feature map ``(n, dim, grid, grid)``."""
        grid = self.img_dim // self.patch_dim
        x = x.view(x.size(0), grid, grid, self.embedding_dim)
        x = x.permute(0, 3, 1, 2).contiguous()
        return x

In [None]:
class SegmentationTransformer(nn.Module):
    """Base class of a ViT-style encoder for segmentation heads.

    An image is cut into ``patch_dim`` x ``patch_dim`` patches, each patch is
    embedded (linear projection, or a strided convolution when
    ``conv_patch_representation`` is set), positional encodings are added and
    the tokens go through a ``TransformerModel``. Subclasses supply the
    decoder via ``_init_decode`` and ``decode``.

    Note: this redefines a class of the same name from an earlier cell of the
    notebook; compared with that one it additionally owns a ``sig`` sigmoid
    module.

    Args:
        img_dim: side length of the (square) input image.
        patch_dim: side length of a patch; must divide ``img_dim``.
        num_channels: number of input image channels.
        embedding_dim: token embedding size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        num_layers: number of transformer layers.
        hidden_dim: hidden width of the transformer feed-forward blocks.
        dropout_rate: dropout after the positional encoding and in the transformer.
        attn_dropout_rate: dropout inside the attention modules.
        conv_patch_representation: embed patches with a strided Conv2d.
        positional_encoding_type: ``"learned"`` or ``"fixed"``.

    Raises:
        ValueError: if ``positional_encoding_type`` is not a supported value.
    """

    def __init__(
        self,
        img_dim,
        patch_dim,
        num_channels,
        embedding_dim,
        num_heads,
        num_layers,
        hidden_dim,
        dropout_rate=0.0,
        attn_dropout_rate=0.0,
        conv_patch_representation=False,
        positional_encoding_type="learned",
    ):
        super(SegmentationTransformer, self).__init__()

        assert embedding_dim % num_heads == 0
        assert img_dim % patch_dim == 0
        # Fail at construction: an unknown type used to leave
        # `position_encoding` unset and only surfaced later inside encode().
        if positional_encoding_type not in ("learned", "fixed"):
            raise ValueError(
                "positional_encoding_type must be 'learned' or 'fixed', got {!r}".format(
                    positional_encoding_type
                )
            )

        self.sig = nn.Sigmoid()

        self.img_dim = img_dim
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.patch_dim = patch_dim
        self.num_channels = num_channels
        self.dropout_rate = dropout_rate
        self.attn_dropout_rate = attn_dropout_rate
        self.conv_patch_representation = conv_patch_representation

        self.num_patches = int((img_dim // patch_dim) ** 2)
        self.seq_length = self.num_patches
        self.flatten_dim = patch_dim * patch_dim * num_channels

        self.linear_encoding = nn.Linear(self.flatten_dim, embedding_dim)
        if positional_encoding_type == "learned":
            self.position_encoding = LearnedPositionalEncoding(
                self.seq_length, self.embedding_dim, self.seq_length
            )
        else:  # "fixed"
            self.position_encoding = FixedPositionalEncoding(
                self.embedding_dim,
            )

        self.pe_dropout = nn.Dropout(p=self.dropout_rate)

        self.transformer = TransformerModel(
            embedding_dim,
            num_layers,
            num_heads,
            hidden_dim,
            self.dropout_rate,
            self.attn_dropout_rate,
        )
        self.pre_head_ln = nn.LayerNorm(embedding_dim)

        if self.conv_patch_representation:
            self.conv_x = nn.Conv2d(
                self.num_channels,
                self.embedding_dim,
                kernel_size=(self.patch_dim, self.patch_dim),
                stride=(self.patch_dim, self.patch_dim),
                padding=self._get_padding(
                    'VALID', (self.patch_dim, self.patch_dim),
                ),
            )
        else:
            self.conv_x = None

    def _init_decode(self):
        """Create the decoder layers; must be provided by subclasses."""
        raise NotImplementedError("Should be implemented in child class!!")

    def encode(self, x):
        """Embed an image batch ``(n, c, h, w)`` and run the transformer.

        Returns:
            ``(tokens, intermediates)``: the layer-normalised final tokens and
            the second value returned by the transformer stack.
        """
        n, c, h, w = x.shape
        if self.conv_patch_representation:
            # combine embedding w/ conv patch distribution
            x = self.conv_x(x)
            x = x.permute(0, 2, 3, 1).contiguous()
            x = x.view(x.size(0), -1, self.embedding_dim)
        else:
            # Cut into non-overlapping patches, then flatten each patch to
            # one vector of length patch_dim * patch_dim * channels.
            x = (
                x.unfold(2, self.patch_dim, self.patch_dim)
                .unfold(3, self.patch_dim, self.patch_dim)
                .contiguous()
            )
            x = x.view(n, c, -1, self.patch_dim ** 2)
            x = x.permute(0, 2, 3, 1).contiguous()
            x = x.view(x.size(0), -1, self.flatten_dim)
            x = self.linear_encoding(x)

        x = self.position_encoding(x)
        x = self.pe_dropout(x)

        # apply transformer
        x, intmd_x = self.transformer(x)
        x = self.pre_head_ln(x)

        return x, intmd_x

    def decode(self, x, intmd_x=None, intmd_layers=None):
        """Turn encoder tokens into a prediction; must be provided by subclasses.

        The signature mirrors the three-argument call made by ``forward``.
        """
        raise NotImplementedError("Should be implemented in child class!!")

    def forward(self, x, auxillary_output_layers=None):
        """Encode, decode and optionally return selected intermediate outputs.

        When ``auxillary_output_layers`` is given, the result is
        ``(decoder_output, aux)`` where ``aux['Z<i>']`` is the intermediate
        encoder output stored under key ``str(2 * i - 1)``.
        """
        encoder_output, intmd_encoder_outputs = self.encode(x)
        decoder_output = self.decode(
            encoder_output, intmd_encoder_outputs, auxillary_output_layers
        )

        if auxillary_output_layers is not None:
            auxillary_outputs = {}
            for i in auxillary_output_layers:
                val = str(2 * i - 1)
                _key = 'Z' + str(i)
                auxillary_outputs[_key] = intmd_encoder_outputs[val]

            return decoder_output, auxillary_outputs

        return decoder_output

    def _get_padding(self, padding_type, kernel_size):
        """Return per-dimension padding for ``'SAME'`` or ``'VALID'`` convolutions."""
        assert padding_type in ['SAME', 'VALID']
        if padding_type == 'SAME':
            _list = [(k - 1) // 2 for k in kernel_size]
            return tuple(_list)
        return tuple(0 for _ in kernel_size)

    def _reshape_output(self, x):
        """Reshape tokens ``(n, patches, dim)`` into a feature map ``(n, dim, grid, grid)``."""
        grid = self.img_dim // self.patch_dim
        x = x.view(x.size(0), grid, grid, self.embedding_dim)
        x = x.permute(0, 3, 1, 2).contiguous()
        return x


class SETR_Naive(SegmentationTransformer):
    """SETR "naive" variant used here as a token-to-feature-map adapter.

    ``forward`` is overridden: it does not run the encoder. It takes a token
    tensor, reshapes it into a feature map and applies ``conv1`` and ``bn1``
    only. ``act1``, ``conv2`` and ``upsample`` are still constructed but are
    not applied in ``forward``.
    """

    def __init__(
        self,
        img_dim=256,
        patch_dim=16,
        num_channels=3,
        num_classes=1,
        embedding_dim=768,
        num_heads=8,
        num_layers=12,
        hidden_dim=3,
        dropout_rate=0.0,
        attn_dropout_rate=0.0,
        conv_patch_representation=False,
        positional_encoding_type="learned",
    ):
        super().__init__(
            img_dim=img_dim,
            patch_dim=patch_dim,
            num_channels=num_channels,
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            hidden_dim=hidden_dim,
            dropout_rate=dropout_rate,
            attn_dropout_rate=attn_dropout_rate,
            conv_patch_representation=conv_patch_representation,
            positional_encoding_type=positional_encoding_type,
        )

        self.num_classes = num_classes
        self._init_decode()

    def _init_decode(self):
        """Build the 1x1-conv decoder head layers."""
        self.conv1 = nn.Conv2d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=1,
            stride=1,
            padding=self._get_padding('VALID', (1, 1),),
        )
        self.bn1 = nn.BatchNorm2d(self.embedding_dim)
        self.act1 = nn.ReLU()
        self.conv2 = nn.Conv2d(
            self.embedding_dim,
            self.num_classes,
            kernel_size=1,
            stride=1,
            padding=self._get_padding('VALID', (1, 1),),
        )
        self.upsample = nn.Upsample(
            scale_factor=self.patch_dim, mode='bilinear'
        )

    def forward(self, x, intmd_layers=None):
        """Map tokens to a normalised feature map.

        ``intmd_layers`` is accepted but not used. The remaining head stages
        (activation, class projection, upsampling, sigmoid) are intentionally
        not applied here.
        """
        feature_map = self._reshape_output(x)
        feature_map = self.conv1(feature_map)
        return self.bn1(feature_map)


In [None]:
#RFB
class BasicConv2d(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` (no activation in forward).

    ``relu`` and ``convx`` are constructed but never applied in ``forward``;
    ``convx`` still contributes parameters to the module's ``state_dict``.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride,
                           padding=padding, dilation=dilation, bias=False)
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_kwargs)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

        # Fixed 64 -> 256 1x1 conv; not used by forward.
        self.convx = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=1)

    def forward(self, x):
        return self.bn(self.conv(x))
class BasicConv2d1(nn.Module):
    """Variant of ``BasicConv2d``: bias-free ``Conv2d`` followed by ``BatchNorm2d``.

    Differs only in the shape of the extra ``convx`` layer (32 -> 64). As in
    ``BasicConv2d``, ``relu`` and ``convx`` are constructed but not applied
    in ``forward``.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride,
                           padding=padding, dilation=dilation, bias=False)
        self.conv = nn.Conv2d(in_planes, out_planes, **conv_kwargs)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        # Fixed 32 -> 64 1x1 conv; not used by forward.
        self.convx = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1)

    def forward(self, x):
        return self.bn(self.conv(x))
class RFB_modified(nn.Module):
    """Receptive-field block: four parallel conv branches fused with a residual path.

    Branch 0 is a single 1x1 conv. Branches 1-3 each apply a 1x1 conv, a
    (1, k) conv, a (k, 1) conv and a 3x3 conv dilated by k, for k = 3, 5, 7.
    The branch outputs are concatenated on the channel axis, reduced by a 3x3
    conv, added to a 1x1 projection of the input and passed through ReLU.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()

        def make_branch(k):
            # Padding is chosen so every conv keeps the spatial size.
            half = k // 2
            return nn.Sequential(
                BasicConv2d(in_channel, out_channel, 1),
                BasicConv2d(out_channel, out_channel, kernel_size=(1, k), padding=(0, half)),
                BasicConv2d(out_channel, out_channel, kernel_size=(k, 1), padding=(half, 0)),
                BasicConv2d(out_channel, out_channel, 3, padding=k, dilation=k),
            )

        self.relu = nn.ReLU(True)
        self.branch0 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
        )
        self.branch1 = make_branch(3)
        self.branch2 = make_branch(5)
        self.branch3 = make_branch(7)
        self.conv_cat = BasicConv2d(4 * out_channel, out_channel, 3, padding=1)
        self.conv_res = BasicConv2d1(in_channel, out_channel, 1)

    def forward(self, x):
        branch_outputs = (
            self.branch0(x),
            self.branch1(x),
            self.branch2(x),
            self.branch3(x),
        )
        fused = self.conv_cat(torch.cat(branch_outputs, 1))
        return self.relu(fused + self.conv_res(x))




In [None]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        mid_channels: channels between the two stages; a falsy value means
            ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        stages = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.double_conv(x)

class Down(nn.Module):
    """Encoder step: 2x2 max-pool followed by a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Decoder step: upsample ``x1``, pad it to ``x2``'s size, concatenate, ``DoubleConv``.

    Args:
        in_channels: channels after concatenating the skip and upsampled maps.
        out_channels: channels produced by the block.
        bilinear: use bilinear upsampling (the conv block then uses
            ``in_channels // 2`` mid channels); otherwise use a transposed
            convolution that halves the channel count.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2

        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Upsampling keeps the channel count, so the conv block narrows instead.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # Tensors are (N, C, H, W): pad the upsampled map to the skip map's size.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)

class OutConv(nn.Module):
    """Final 1x1 convolution mapping features to ``out_channels`` maps."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

class BatchNorm(nn.Module):
    """``BatchNorm2d`` followed by ReLU.

    Args:
        out_channels: number of channels of the normalised feature map.
    """

    # The constructor was previously spelled `init`, so Python never ran it:
    # `bn` and `relu` were not created and the module could not be used.
    def __init__(self, out_channels):
        super(BatchNorm, self).__init__()
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.bn(x)
        x = self.relu(x)
        return x



In [None]:

class ChannelPool(nn.Module):
    """Reduce the channel axis to two maps: per-pixel max and per-pixel mean."""

    def forward(self, x):
        peak = x.max(dim=1, keepdim=True)[0]
        average = x.mean(dim=1, keepdim=True)
        return torch.cat((peak, average), dim=1)


class BiFusion_block(nn.Module):
    """Fuse two feature maps ``g`` and ``x`` with attention and a bilinear term.

    ``g`` is re-weighted by a spatial attention map (max/mean channel pooling
    followed by a 7x7 conv and sigmoid). ``x`` is re-weighted by a
    squeeze-and-excitation style channel gate. A third term is a 3x3 conv over
    the element-wise product of 1x1 projections of both inputs. The three
    tensors are concatenated and passed through a ``Residual1`` block.

    Args:
        ch_1: channels of ``g``.
        ch_2: channels of ``x``.
        r_2: reduction ratio of the channel gate applied to ``x``.
        ch_int: channels of the bilinear term.
        ch_out: channels of the fused output.
        drop_rate: ``Dropout2d`` probability; dropout is applied only if > 0.
    """

    def __init__(self, ch_1, ch_2, r_2, ch_int, ch_out, drop_rate=0.):
        super().__init__()

        # Channel gate for x (squeeze-and-excitation style).
        squeezed = ch_2 // r_2
        self.fc1 = nn.Conv2d(ch_2, squeezed, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(squeezed, ch_2, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

        # Spatial gate for g.
        self.compress = ChannelPool()
        self.spatial = Conv(2, 1, 7, bn=True, relu=False, bias=False)

        # Projections feeding the bilinear (element-wise product) term.
        self.W_g = Conv(ch_1, ch_int, 1, bn=True, relu=False)
        self.W_x = Conv(ch_2, ch_int, 1, bn=True, relu=False)
        self.W = Conv(ch_int, ch_int, 3, bn=True, relu=True)

        self.residual = Residual1(ch_1 + ch_2 + ch_int, ch_out)

        self.dropout = nn.Dropout2d(drop_rate)
        self.drop_rate = drop_rate

    def forward(self, g, x):
        # Bilinear term from the projected inputs.
        bilinear = self.W(self.W_g(g) * self.W_x(x))

        # Spatial attention on g.
        spatial_gate = self.sigmoid(self.spatial(self.compress(g)))
        g_att = spatial_gate * g

        # Channel attention on x.
        pooled = x.mean((2, 3), keepdim=True)
        channel_gate = self.fc1(pooled)
        channel_gate = self.relu(channel_gate)
        channel_gate = self.sigmoid(self.fc2(channel_gate))
        x_att = channel_gate * x

        fused = self.residual(torch.cat([g_att, x_att, bilinear], 1))

        if self.drop_rate > 0:
            return self.dropout(fused)
        return fused

In [None]:

class Residual1(nn.Module):
    """Pre-activation bottleneck residual block.

    Three (BatchNorm -> ReLU -> conv) stages with kernel sizes 1, 3, 1 and a
    bottleneck width of ``int(out_dim / 2)``. The shortcut is the identity
    when ``inp_dim == out_dim`` and a 1x1 conv otherwise; the 1x1
    ``skip_layer`` is created in both cases.
    """

    def __init__(self, inp_dim, out_dim):
        super().__init__()
        narrow = int(out_dim / 2)

        self.relu = nn.ReLU(inplace=True)
        self.bn1 = nn.BatchNorm2d(inp_dim)
        self.conv1 = Conv(inp_dim, narrow, 1, relu=False)
        self.bn2 = nn.BatchNorm2d(narrow)
        self.conv2 = Conv(narrow, narrow, 3, relu=False)
        self.bn3 = nn.BatchNorm2d(narrow)
        self.conv3 = Conv(narrow, out_dim, 1, relu=False)
        self.skip_layer = Conv(inp_dim, out_dim, 1, relu=False)
        self.need_skip = inp_dim != out_dim

    def forward(self, x):
        shortcut = self.skip_layer(x) if self.need_skip else x

        out = x
        stages = (
            (self.bn1, self.conv1),
            (self.bn2, self.conv2),
            (self.bn3, self.conv3),
        )
        for norm, conv in stages:
            out = conv(self.relu(norm(out)))

        out += shortcut
        return out


class Conv(nn.Module):
    """2-D convolution optionally followed by BatchNorm and ReLU.

    Args:
        inp_dim: number of input channels; checked on every forward call.
        out_dim: number of output channels.
        kernel_size: square kernel size; padding is ``(kernel_size - 1) // 2``.
        stride: convolution stride.
        bn: when true, apply ``BatchNorm2d(out_dim)`` after the convolution.
        relu: when true, apply an in-place ReLU last.
        bias: whether the convolution has a bias term.
    """

    def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True, bias=True):
        super(Conv, self).__init__()
        self.inp_dim = inp_dim
        self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=(kernel_size-1)//2, bias=bias)
        # Optional stages are stored as None when disabled.
        self.relu = nn.ReLU(inplace=True) if relu else None
        self.bn = nn.BatchNorm2d(out_dim) if bn else None

    def forward(self, x):
        """Apply conv -> (bn) -> (relu) to ``x``.

        Raises:
            AssertionError: if the channel dimension (dim 1) of ``x`` does not
                equal ``inp_dim``.
        """
        # Report the dimension that is actually being checked (dim 1); the
        # previous message printed dim 2, which made mismatches confusing.
        assert x.size()[1] == self.inp_dim, \
            "Conv expected {} input channels, got {}".format(self.inp_dim, x.size()[1])
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

In [None]:
#just playing around a bit...don't mind me
class UTNet(nn.Module):
    """U-shaped segmentation network with a transformer branch fused at the bottleneck.

    Data flow (see ``forward``):

    * encoder: ``inc`` then ``down1..down3``, each followed by a ``DoubleConv``;
    * bottleneck: ``down4`` output and the ``SETR_Naive`` output of the raw
      input are merged by ``BiFusion_block`` (1792 output channels) and
      reduced to 1024 channels by ``con4``;
    * decoder: ``up4..up1`` with encoder skip connections, each followed by an
      ``RFB_modified`` and a ``DoubleConv``;
    * head: ``OutConv`` followed by a sigmoid.

    Args:
        n_channels: channels of the input image fed to the first convolution.
            NOTE(review): the transformer branch is built as ``SETR_Naive()``
            with its own defaults - confirm it accepts the same channel count
            before using a value other than 3.
        n_classes: number of output channels of the head.
        bilinear: forwarded to the ``Up`` blocks and used to halve some
            channel counts. NOTE(review): the fusion block is fixed at
            ``ch_1=1024``, so ``bilinear=True`` looks inconsistent with
            ``down4`` producing ``1024 // 2`` channels - confirm.
    """

    def __init__(self, n_channels = 3, n_classes = 1, bilinear=False):
        super(UTNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        factor = 2 if bilinear else 1

        # Sub-modules are created in the original order so that parameter
        # ordering and random initialisation stay the same.
        self.act = nn.Sigmoid()

        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        self.up4 = Up(1024, 512 // factor, bilinear)
        self.up3 = Up(512, 256 // factor, bilinear)
        self.up2 = Up(256, 128, bilinear)
        self.up1 = Up(128, 64, bilinear)

        # The constructor arguments are now honoured instead of the
        # hard-coded 3 input channels / 1 output channel.
        self.inc = DoubleConv(n_channels, 64)
        self.con1 = DoubleConv(128, 128)
        self.con2 = DoubleConv(256, 256)
        self.con3 = DoubleConv(512, 512)
        self.con4 = DoubleConv(1792, 1024)
        self.con5 = DoubleConv(512, 512)
        self.con6 = DoubleConv(256, 256)
        self.con7 = DoubleConv(128, 128)
        self.con8 = DoubleConv(64, 64)
        self.outconv = OutConv(64, n_classes)
        # `filter`, `rectify` and `sig` are not used by forward(); they are
        # kept so that state dicts saved from this class keep the same keys.
        self.filter = OutConv(1,1)
        self.rectify = DoubleConv(768, 1024)

        self.rfb_u4 = RFB_modified(512, 512)
        self.rfb_u3 = RFB_modified(256, 256)
        self.rfb_u2 = RFB_modified(128, 128)
        self.rfb_u1 = RFB_modified(64, 64)

        self.tru = SETR_Naive()
        self.sig = nn.Sigmoid()

        self.up_c_2_1 = BiFusion_block(ch_1=1024, ch_2=768, r_2=1, ch_int=16, ch_out=1792, drop_rate=0.0)

    def forward(self, x):
        """Return the sigmoid-activated output map for input batch ``x``."""
        # Encoder: keep every stage output for the decoder skip connections.
        x1 = self.inc(x)
        x2 = self.con1(self.down1(x1))
        x3 = self.con2(self.down2(x2))
        x4 = self.con3(self.down3(x3))

        # Bottleneck: fuse the CNN features with the transformer features
        # computed from the raw input.
        xb = self.down4(x4)
        xt = self.tru(x)
        xb = self.con4(self.up_c_2_1(xb, xt))

        # Decoder: upsample with skip connection, then RFB and DoubleConv.
        y4 = self.con5(self.rfb_u4(self.up4(xb, x4)))
        y3 = self.con6(self.rfb_u3(self.up3(y4, x3)))
        y2 = self.con7(self.rfb_u2(self.up2(y3, x2)))
        y1 = self.con8(self.rfb_u1(self.up1(y2, x1)))

        return self.act(self.outconv(y1))

In [None]:
# --- Experiment configuration ---------------------------------------------
root_path = '/content/drive/MyDrive/Datasets/DRISHTI-GS/Drishti-GS1_files/Drishti-GS1_files'
input_size = (3, 256, 256)  # for kaggle 448
batch_size = 1
learning_rate = 1e-4
epochs = 300

# Constants read by the training loop: starting value for the best epoch
# loss, and the stale-epoch thresholds for early stopping / LR updates.
INITAL_EPOCH_LOSS = 10000
NUM_EARLY_STOP = 20
NUM_UPDATE_LR = 5

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def acc_sen_iou(pred, mask) :
    """Compute accuracy, sensitivity and IoU of a rounded prediction.

    ``pred`` is rounded to {0, 1} and compared with ``mask``; the confusion
    counts are summed over every element of the tensors (the whole batch is
    treated as one sample).

    Args:
        pred: tensor of predicted probabilities.
        mask: tensor of the same shape holding 0/1 labels.

    Returns:
        ``(acc, sen, iou)`` as 0-dim tensors. A metric whose denominator is
        zero (e.g. sensitivity on a mask without positives) is returned as 0
        instead of NaN, so one degenerate sample cannot turn an epoch average
        into NaN. This follows scikit-learn's default ``zero_division``
        behaviour.
    """
    def ratio(num, den):
        # 0/0 yields NaN for tensors; replace it with 0.
        value = num / den
        return torch.where(den > 0, value, torch.zeros_like(value))

    pred = torch.round(pred)
    TP = (mask * pred).sum()
    TN = ((1 - mask) * (1 - pred)).sum()
    FP = pred.sum() - TP
    FN = mask.sum() - TP

    acc = ratio(TP + TN, TP + TN + FP + FN)
    sen = ratio(TP, TP + FN)
    iou = ratio(TP, TP + FN + FP)
    return acc, sen, iou

def diceCoeff(pred, gt, smooth=1e-5, activation='none'):
    """Smoothed Dice coefficient, averaged over the batch dimension.

    For every sample (tensors flattened to ``(N, -1)``)::

        dice = (2 * sum(pred * gt) + smooth) / (sum(pred) + sum(gt) + smooth)

    Args:
        pred: predictions; passed through ``activation`` first.
        gt: ground truth; ``gt.size(0)`` is used as the batch size ``N``.
        smooth: constant added to numerator and denominator.
        activation: ``None``/``"none"``, ``"sigmoid"`` or ``"softmax2d"``.

    Raises:
        NotImplementedError: for any other ``activation`` value.
    """
    if activation == "sigmoid":
        pred = nn.Sigmoid()(pred)
    elif activation == "softmax2d":
        pred = nn.Softmax2d()(pred)
    elif not (activation is None or activation == "none"):
        raise NotImplementedError("Activation implemented for sigmoid and softmax2d activation function operation")

    batch = gt.size(0)
    flat_pred = pred.view(batch, -1)
    flat_gt = gt.view(batch, -1)

    overlap = (flat_pred * flat_gt).sum(1)
    total_mass = flat_pred.sum(1) + flat_gt.sum(1)
    per_sample = (2 * overlap + smooth) / (total_mass + smooth)

    return per_sample.sum() / batch

#calculation of precision and recall
def calc_prerec(mask, pred):
    """Compute precision and recall of a rounded prediction.

    ``pred`` is rounded to {0, 1}; counts are summed over every element of
    the tensors.

    Args:
        mask: tensor of 0/1 labels.
        pred: tensor of predicted probabilities, same shape as ``mask``.

    Returns:
        ``(precision, recall)`` as 0-dim tensors. When a denominator is zero
        (no predicted positives for precision, no true positives in the mask
        for recall) the metric is returned as 0 instead of NaN, following
        scikit-learn's default ``zero_division`` behaviour.
    """
    def ratio(num, den):
        # 0/0 yields NaN for tensors; replace it with 0.
        value = num / den
        return torch.where(den > 0, value, torch.zeros_like(value))

    pred = torch.round(pred)
    TP = (mask * pred).sum()
    FP = pred.sum() - TP
    FN = mask.sum() - TP

    prec = ratio(TP, TP + FP)
    recc = ratio(TP, TP + FN)
    return prec, recc

#calculate DSC
def dice_score(mask,pred):
    """Compute the Dice similarity coefficient of a rounded prediction.

    ``pred`` is rounded to {0, 1}; counts are summed over every element of
    the tensors and ``dice = 2*TP / (2*TP + FP + FN)``.

    Args:
        mask: tensor of 0/1 labels.
        pred: tensor of predicted probabilities, same shape as ``mask``.

    Returns:
        The Dice score as a 0-dim tensor. When both ``mask`` and the rounded
        prediction contain no positives the denominator is zero; 0 is
        returned in that case instead of NaN (scikit-learn's default
        ``zero_division`` convention).
    """
    pred = torch.round(pred)
    TP = (mask * pred).sum()
    FP = pred.sum() - TP
    FN = mask.sum() - TP

    denominator = 2 * TP + FP + FN
    dice = (2 * TP) / denominator
    # 0/0 yields NaN for tensors; replace it with 0.
    return torch.where(denominator > 0, dice, torch.zeros_like(dice))
'''
#calcuation of jaccard index between two random images
def dice_coef(img, img2):
        if img.shape != img2.shape:
            raise ValueError("Shape mismatch: img and img2 must have to be of the same shape.")
        else:

            lenIntersection=0

            for i in range(img.shape[0]):
                for j in range(img.shape[1]):
                    if ( np.array_equal(img[i][j],img2[i][j]) ):
                        lenIntersection+=1

            lenimg=img.shape[0]img.shape[1]
            lenimg2=img2.shape[0]img2.shape[1]
            value2 = (lenIntersection  / (lenimg + lenimg2 - lenIntersection))
        return value2
'''

'''
pred = (np.random.rand(1, 500,500))
mask = np.round_(np.random.randint(0,2,(1, 500,500)))
#print(.)
#print(mask)
a, s = acc_sen(torch.Tensor(pred), torch.Tensor(mask))
d = dice_score(torch.Tensor(mask), torch.Tensor(pred))
p, r = calc_prerec(torch.Tensor(mask), torch.Tensor(pred))
print(a, s)
print(d)
print(p, r)
print(pred)
print(mask)
'''

In [None]:
class MyFrame():
    """Small training wrapper bundling a network, its optimizer and the loss.

    Args:
        net: callable (typically the model class) invoked with no arguments
            to build the network.
        learning_rate: initial learning rate for Adam.
        device: device the network and loss are moved to.
        evalmode: accepted for compatibility. NOTE(review): this flag is not
            used anywhere in the class.
    """

    def __init__(self, net, learning_rate, device, evalmode=False):
        self.net = net().to(device)
        self.optimizer = torch.optim.Adam(params=self.net.parameters(), lr=learning_rate, weight_decay=0.0001)
        self.loss = proposed_loss().to(device)
        self.lr = learning_rate

    def set_input(self, img_batch, mask_batch=None):
        """Store the batch that the next ``optimize`` call will use."""
        self.img = img_batch
        self.mask = mask_batch

    def optimize(self):
        """Run one optimisation step on the stored batch.

        Returns:
            ``(loss, pred)``: the loss tensor and the network output.
        """
        self.optimizer.zero_grad()
        # Call the module itself rather than .forward() so hooks are honoured.
        pred = self.net(self.img)
        loss = self.loss(self.mask, pred)
        loss.backward()
        self.optimizer.step()
        return loss, pred

    def save(self, path):
        """Write the network's state dict to ``path``."""
        torch.save(self.net.state_dict(), path)

    def load(self, path, map_location=None):
        """Load a state dict from ``path`` into the network.

        Args:
            path: checkpoint file written by ``save``.
            map_location: forwarded to ``torch.load``; pass e.g. ``'cpu'`` to
                load a checkpoint saved on a GPU onto a CPU-only machine.
        """
        self.net.load_state_dict(torch.load(path, map_location=map_location))

    def update_lr(self, new_lr, factor=False):
        """Set a new learning rate on every optimizer parameter group.

        Args:
            new_lr: the new learning rate, or the divisor applied to the
                current rate when ``factor`` is true.
            factor: interpret ``new_lr`` as a divisor of the current rate.
        """
        if factor:
            new_lr = self.lr / new_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr

        # Announce the change once (it used to be printed twice).
        print ('update learning rate: %f -> %f' % (self.lr, new_lr))
        self.lr = new_lr

class proposed_loss(nn.Module):
    """Weighted sum of four segmentation losses.

    ``forward`` returns
    ``0.15 * L1 + 0.4 * soft-Dice loss + 0.15 * BCE + 0.3 * IoU loss``.

    Args:
        batch: when true the soft-Dice term is computed from sums over the
            whole batch; otherwise from per-sample sums over dims 1-3.
    """

    def __init__(self, batch=True):
        super().__init__()
        self.batch = batch
        self.mae_loss = torch.nn.L1Loss()
        self.bce_loss = torch.nn.BCELoss()

    @staticmethod
    def _sum_per_sample(tensor):
        """Sum a 4-D tensor over dims 1, 2 and 3, leaving one value per sample."""
        return tensor.sum(1).sum(1).sum(1)

    def soft_dice_coeff(self, y_true, y_pred):
        """Return the mean soft Dice coefficient ``2*|A.B| / (|A| + |B|)``."""
        smooth = 0.0  # may change
        reduce_fn = torch.sum if self.batch else self._sum_per_sample
        true_mass = reduce_fn(y_true)
        pred_mass = reduce_fn(y_pred)
        overlap = reduce_fn(y_true * y_pred)
        coeff = (2. * overlap + smooth) / (true_mass + pred_mass + smooth)
        return coeff.mean()

    def soft_dice_loss(self, y_true, y_pred):
        """Return ``1 - soft_dice_coeff``."""
        return 1 - self.soft_dice_coeff(y_true, y_pred)

    def iou_loss(self, inputs, targets):
        """Return ``1 - mean(IoU)`` with the IoU computed per sample."""
        smooth = 0.0
        overlap = self._sum_per_sample(inputs * targets)
        union = self._sum_per_sample(inputs + targets) - overlap
        per_sample_iou = (overlap + smooth) / (union + smooth)
        return (1 - per_sample_iou.mean())

    def forward(self, y_true, y_pred):
        """Combine the four loss terms; note the ``(y_true, y_pred)`` order."""
        mae_term = self.mae_loss(y_pred, y_true)
        dice_term = self.soft_dice_loss(y_true, y_pred)
        bce_term = self.bce_loss(y_pred, y_true)
        iou_term = self.iou_loss(y_pred, y_true)
        return 0.15*mae_term + 0.4*dice_term + 0.15*bce_term + 0.3*iou_term




In [None]:
def visualise(Loss, Accuracy, Dice, IOU, Precision, Recall, EP):
    """Plot six metric histories against ``EP`` on a 2x3 grid and show it.

    Each of the first six arguments is a sequence with one value per entry
    of ``EP`` (the x axis). Panels are filled row by row in the order
    Loss, Accuracy, Dice, IOU, Precision, Recall.
    """
    # (series, line colour, panel title) in row-major panel order.
    panels = (
        (Loss, 'purple', "Loss Function"),
        (Accuracy, 'black', "Accuracy"),
        (Dice, 'cyan', "Dice"),
        (IOU, 'red', "IOU"),
        (Precision, 'blue', "Precision"),
        (Recall, 'green', "Recall"),
    )

    figure, axis = plt.subplots(2, 3, figsize=(30, 6))
    for panel, (series, colour, title) in zip(axis.flat, panels):
        panel.plot(EP, series, color=colour)
        panel.set_title(title)

    plt.show()

In [None]:
# Training split: samples are served one at a time, reshuffled every epoch.
train_dataset = Eye_Dataset(root_path, 'Train')
train_loader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=True)

In [None]:
# Test split: samples are served one at a time in a fixed order.
test_dataset = Eye_Dataset(root_path=root_path, phase='Test')
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

In [None]:
def trainer_chan(epoch, epochs, train_loader, solver, logfile):
    """Train for one epoch, report the averaged metrics and update the schedule.

    Args:
        epoch: index of the epoch being run (used for display/logging).
        epochs: total number of epochs (used for display/logging).
        train_loader: iterable of ``(img, mask)`` batches.
        solver: object exposing ``net``, ``lr``, ``set_input``, ``optimize``,
            ``save``, ``load`` and ``update_lr``.
        logfile: open text file the epoch summary is appended to.

    Returns:
        ``(loss, acc, sen, pre, rec, dice, iou, keep_training)`` where the
        first seven values are per-batch averages and ``keep_training`` is
        False once training should stop.

    The best loss and the stale-epoch counters are stored on ``solver``
    between calls. They used to be local variables re-initialised on every
    call, so the loss was always "better than the best", the checkpoint was
    rewritten every epoch and neither the learning-rate decay nor the early
    stop could ever trigger.
    """
    checkpoint_path = '/content/drive/MyDrive/Saved Models/UTnet_300_Fuse.pth'
    print('Epoch {}/{}'.format(epoch, epochs))

    # Make sure layers such as BatchNorm/Dropout are in training mode.
    solver.net.train()

    # (console label, log-file label, metric key) in reporting order.
    report = (
        ('train_loss:', 'train_loss: ', 'loss'),
        ('train_accuracy:', 'train_accuracy: ', 'acc'),
        ('train_sensitivity:', 'train_sensitivity: ', 'sen'),
        ('train_dice:', 'train_dice: ', 'dice'),
        ('train_precision', 'train_precision: ', 'pre'),
        ('train_recall', 'train_recall: ', 'rec'),
        ('IOU:', 'train_iou: ', 'iou'),
    )
    totals = {key: 0 for _, _, key in report}

    num_batches = len(train_loader)
    iterator = tqdm(enumerate(train_loader), total=num_batches, leave=False, desc=f'Epoch {epoch}/{epochs}')
    for index, (img, mask) in iterator :
        img = img.to(device)
        mask = mask.to(device)
        solver.set_input(img, mask)
        train_loss, pred = solver.optimize()

        train_acc, train_sen, train_iou = acc_sen_iou(pred, mask)
        train_dice = dice_score(mask, pred)
        train_pre, train_rec = calc_prerec(mask, pred)

        batch_values = {
            'loss': train_loss, 'acc': train_acc, 'sen': train_sen,
            'dice': train_dice, 'pre': train_pre, 'rec': train_rec,
            'iou': train_iou,
        }
        for key, value in batch_values.items():
            totals[key] += value.detach().cpu().numpy()

    # Every accumulated value is one number per batch, so average over the
    # batches of the loader that was passed in (not a global dataset).
    means = {key: total / max(num_batches, 1) for key, total in totals.items()}

    for console_label, _, key in report:
        print(console_label, means[key])
    print('Learning rate: ', solver.lr)

    logfile.write('Epoch: '+str(epoch)+'/'+str(epochs)+'\n')
    for _, log_label, key in report:
        logfile.write(log_label+str(means[key])+'\n')
    logfile.write('Learning rate: '+str(solver.lr)+'\n')

    # --- schedule state, persisted on the solver across epochs -------------
    best_loss = getattr(solver, '_best_epoch_loss', INITAL_EPOCH_LOSS)
    no_optim = getattr(solver, '_no_optim', 0)          # since best / last LR update
    stale_epochs = getattr(solver, '_stale_epochs', 0)  # since last improvement
    keep_training = True

    if means['loss'] < best_loss:
        # New best epoch: checkpoint it and reset both counters.
        no_optim = 0
        stale_epochs = 0
        best_loss = means['loss']
        solver.save(checkpoint_path)
    else:
        # Also taken for a NaN loss, which never compares as an improvement.
        no_optim += 1
        stale_epochs += 1

    if no_optim > NUM_UPDATE_LR:
        if solver.lr < 1e-9: keep_training=False
        # Go back to the best weights and continue with half the rate.
        solver.load(checkpoint_path)
        solver.update_lr(2, factor=True)
        no_optim = 0

    # Early stopping uses its own counter: `no_optim` is reset by every LR
    # update and therefore can never exceed NUM_UPDATE_LR + 1.
    if stale_epochs > NUM_EARLY_STOP:
        print('early stop at %d epoch' % epoch)
        keep_training = False

    solver._best_epoch_loss = best_loss
    solver._no_optim = no_optim
    solver._stale_epochs = stale_epochs

    print('---------------------------------------------')
    return (means['loss'], means['acc'], means['sen'], means['pre'],
            means['rec'], means['dice'], means['iou'], keep_training)

def tester_chan(model,test_loader, logfile):
    """Evaluate ``model`` on ``test_loader`` and report the averaged metrics.

    Args:
        model: network to evaluate; called as ``model(img)``.
        test_loader: iterable of ``(img, mask)`` batches.
        logfile: open text file the summary is appended to. It is CLOSED
            before returning (existing callers rely on this), so the caller
            must not write to it afterwards.

    Returns:
        ``(loss, acc, sen, prec, recc, dice, iou)``, each averaged over the
        batches of ``test_loader``.

    The model is switched to evaluation mode for the duration of the call
    and its previous mode is restored afterwards. Previously evaluation ran
    in training mode, which keeps dropout active and lets BatchNorm layers
    update their running statistics from the test data.
    """
    criterion = proposed_loss()
    totals = dict.fromkeys(('loss', 'acc', 'sen', 'prec', 'recc', 'dice', 'iou'), 0)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad() :
            for img, mask in test_loader :
                img = img.to(device)
                mask = mask.to(device)
                pred = model(img)

                acc, sen, iou = acc_sen_iou(pred, mask)
                prec, recc = calc_prerec(mask, pred)
                loss = criterion(mask, pred)
                dsc = dice_score(mask, pred)

                batch_values = {
                    'loss': loss, 'acc': acc, 'sen': sen, 'prec': prec,
                    'recc': recc, 'dice': dsc, 'iou': iou,
                }
                for key, value in batch_values.items():
                    totals[key] += value.detach().cpu().numpy()
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    # One accumulated value per batch: average over the loader that was
    # passed in (not a global dataset); max() guards an empty loader.
    num_batches = max(len(test_loader), 1)
    test_loss = totals['loss'] / num_batches
    test_acc = totals['acc'] / num_batches
    test_sen = totals['sen'] / num_batches
    test_prec = totals['prec'] / num_batches
    test_recc = totals['recc'] / num_batches
    test_dice = totals['dice'] / num_batches
    test_iou = totals['iou'] / num_batches

    print('Test Accuracy : ', test_acc)
    print('Test Sensitivity : ', test_sen)
    print('Test Pecision : ', test_prec)
    print('Test Recall : ', test_recc)
    print('Test loss : ', test_loss)
    print('Test IOU : ', test_iou)
    print('Test Dice : ', test_dice)

    separator = '----------------------------------------------------------\n'
    logfile.write(separator)
    logfile.write('test_loss: '+str(test_loss)+'\n')
    logfile.write('test_accuracy: '+str(test_acc)+'\n')
    logfile.write('test_sensitivity: '+str(test_sen)+'\n')
    logfile.write('test_dice: '+str(test_dice)+'\n')
    logfile.write('test_precision: '+str(test_prec)+'\n')
    logfile.write('test_recall: '+str(test_recc)+'\n')
    logfile.write('test_iou: '+str(test_iou)+'\n')
    for _ in range(3):
        logfile.write(separator)
    logfile.close()

    return test_loss, test_acc, test_sen, test_prec, test_recc, test_dice, test_iou

# --- Resume training from the saved checkpoint and metric histories ---------
CHECKPOINT_PATH = '/content/drive/MyDrive/Saved Models/UTnet_300_Fuse.pth'
TRAIN_HISTORY_PATH = '/content/drive/MyDrive/Log/Variables/Fuse_train.pkl'
TEST_HISTORY_PATH = '/content/drive/MyDrive/Log/Variables/Fuse_test.pkl'
LOG_PATH = '/content/drive/MyDrive/Log/Fuse.txt'
START_EPOCH = 222  # first epoch of this run; earlier epochs come from the pickles

solver = MyFrame(UTNet, learning_rate, device)
solver.load(CHECKPOINT_PATH)

tr_LR = []
te_LR = []

# Histories written by a previous run of this cell.
# NOTE: pickle.load executes arbitrary code from the file - only load
# history files this notebook wrote itself.
# The handles are closed by the `with` blocks and no longer bound to the
# name `v`, which would shadow the `Variable as v` import.
with open(TRAIN_HISTORY_PATH, 'rb') as history_file:
    tr_Loss, tr_Accuracy, tr_Sensitivity, tr_Dice, tr_IOU, tr_Precision, tr_Recall, EP = pickle.load(history_file)
with open(TEST_HISTORY_PATH, 'rb') as history_file:
    te_Loss, te_Accuracy, te_Sensitivity, te_Dice, te_IOU, te_Precision, te_Recall = pickle.load(history_file)

# To start a fresh log instead of appending, run once:
#logfile = open('/content/drive/MyDrive/Log/Fuse.txt', 'w')
#logfile.write('Training.Fuse \n')
#logfile.close()

for epoch in range(START_EPOCH, epochs + 1):
    # tester_chan closes the log file itself; the `with` block additionally
    # guarantees it is closed if training or testing raises.
    with open(LOG_PATH, 'a') as logfile:
        (train_loss, train_acc, train_sen, train_pre, train_rec,
         train_dice, train_iou, keep_training) = trainer_chan(epoch, epochs, train_loader, solver, logfile)
        (test_loss, test_acc, test_sen, test_pre, test_rec,
         test_dice, test_iou) = tester_chan(solver.net, test_loader, logfile)

    tr_Loss.append(train_loss)
    tr_Accuracy.append(train_acc)
    tr_Sensitivity.append(train_sen)
    tr_Dice.append(train_dice)
    tr_IOU.append(train_iou)
    tr_Precision.append(train_pre)
    tr_Recall.append(train_rec)
    EP.append(epoch)

    te_Loss.append(test_loss)
    te_Accuracy.append(test_acc)
    te_Sensitivity.append(test_sen)
    te_Dice.append(test_dice)
    te_IOU.append(test_iou)
    te_Precision.append(test_pre)
    te_Recall.append(test_rec)

    # Persist the histories after every epoch so an interrupted run can be
    # resumed; the list order matches the unpacking order used when loading.
    tr_comp = [tr_Loss, tr_Accuracy, tr_Sensitivity, tr_Dice, tr_IOU, tr_Precision, tr_Recall, EP]
    te_comp = [te_Loss, te_Accuracy, te_Sensitivity, te_Dice, te_IOU, te_Precision, te_Recall]
    with open(TRAIN_HISTORY_PATH, 'wb') as var1:
        pickle.dump(tr_comp, var1)
    with open(TEST_HISTORY_PATH, 'wb') as var2:
        pickle.dump(te_comp, var2)

    print('Train plots:')
    visualise(tr_Loss, tr_Accuracy, tr_Dice, tr_IOU, tr_Precision, tr_Recall, EP)
    print('Test plots:')
    visualise(te_Loss, te_Accuracy, te_Dice, te_IOU, te_Precision, te_Recall, EP)
    for _ in range(3):
        print('----------------------------------------------------------------------------------------')

    if not keep_training:
        break





In [None]:
# Count the trainable parameters of a freshly constructed network.
model = UTNet()
trainable_sizes = [param.numel() for param in model.parameters() if param.requires_grad]
pytorch_total_params = sum(trainable_sizes)
print(pytorch_total_params)