<a href="https://colab.research.google.com/github/heitingv/Masters_project/blob/master/Weight_Initialization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Code for Transfer Learning: Weight Initialization

Install pydicom and barbar

In [0]:
!pip install pydicom
!pip install barbar

Collecting pydicom
[?25l  Downloading https://files.pythonhosted.org/packages/d3/56/342e1f8ce5afe63bf65c23d0b2c1cd5a05600caad1c211c39725d3a4cc56/pydicom-2.0.0-py3-none-any.whl (35.4MB)
[K     |████████████████████████████████| 35.5MB 89kB/s 
[?25hInstalling collected packages: pydicom
Successfully installed pydicom-2.0.0
Collecting barbar
  Downloading https://files.pythonhosted.org/packages/48/1f/9b69ce144f484cfa00feb09fa752139658961de6303ea592487738d0b53c/barbar-0.2.1-py3-none-any.whl
Installing collected packages: barbar
Successfully installed barbar-0.2.1


Mount Google Drive to access the data

In [0]:
# Make the Google Drive contents (data, saved loaders, checkpoints) available under /gdrive.
import google.colab.drive as drive
drive.mount('/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /gdrive


Import all libraries

In [0]:
import os
import matplotlib
import numpy as np
import cv2
import matplotlib.pyplot as plt
import pydicom
from tqdm import tqdm_notebook as tqdm
from random import randint

import time
from barbar import Bar 
import progressbar
from sklearn import metrics
from sklearn.metrics import roc_curve
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
from tqdm import trange
from time import sleep
from torch.utils.data.sampler import SubsetRandomSampler 
use_gpu = torch.cuda.is_available()

Dataset class for images with a mass only:
returns the image and mask padded to a fixed size (for patch extraction) and the image and mask resized to 250x250.

In [0]:
class Dataset(BaseDataset):
    """DICOM images paired with the PNG mask that has exactly the same file ID.

    Each item is ``(image_pad, mask_pad, image_re, mask_re)``:
    the full-resolution image and binary mask padded for patch extraction,
    and the image (NumPy array) and binary mask (tensor) resized to 250x250.

    Args:
        images_dir: directory holding ``<id>.dcm`` files.
        masks_dir: directory holding ``<id>.png`` files.
        classes: optional iterable of names from ``CLASSES`` (case-insensitive);
            stored as their indices in ``class_values``.
    """

    CLASSES = ['non tumor', 'tumor']

    def __init__(self, images_dir, masks_dir, classes=None):
        # Attribute names are kept unchanged: DataLoaders saved with torch.save
        # restore them directly, and __len__/__getitem__ read them.
        self.ids = [self._strip_ext(f, '.dcm') for f in os.listdir(images_dir)]
        self.ids_m = [self._strip_ext(f, '.png') for f in os.listdir(masks_dir)]

        image_ids = set(self.ids)  # O(1) membership instead of scanning a list
        self.ids_f = []
        self.ids_m_f = []
        for mask_id in self.ids_m:
            if mask_id in image_ids:
                self.ids_f.append(mask_id + '.dcm')
                self.ids_m_f.append(mask_id + '.png')

        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids_f]
        self.masks_fps = [os.path.join(masks_dir, mask_id) for mask_id in self.ids_m_f]
        # classes=None used to raise TypeError; treat it as "no classes".
        self.class_values = [self.CLASSES.index(cls.lower())
                             for cls in (classes if classes is not None else [])]

    @staticmethod
    def _strip_ext(filename, ext):
        """Return *filename* without the exact trailing suffix *ext*.

        ``str.rstrip(".dcm")`` removes any trailing '.', 'd', 'c', 'm'
        characters, which mangles IDs such as ``scan_c.dcm`` -> ``scan_``.
        """
        return filename[:-len(ext)] if filename.endswith(ext) else filename

    @staticmethod
    def _binarize_mask(mask_bgr):
        """Turn an HxWx3 mask array (as read by cv2) into a float {0, 1} tensor.

        The three channels are summed and divided by 3; a pixel becomes 1
        wherever ``|value - 1| > 0``.
        NOTE(review): this also maps all-black pixels (value 0) to 1 — confirm
        the mask encoding this relies on.
        """
        mask = torch.from_numpy(mask_bgr).long()
        mask = abs((mask.sum(2) / 3) - 1)
        return (mask > 0).float()

    def __len__(self):
        return len(self.ids_f)

    def breast_left_or_right(self, image_array):
        """Return 'left' if breast tissue (pixels > 0) starts within the first 100 columns, else 'right'.

        An image with no tissue at all is reported as 'left' (padding side is
        irrelevant for it) instead of crashing on an empty ``min``.
        """
        tissue_columns = torch.nonzero((image_array > 0).any(dim=0))
        if tissue_columns.numel() == 0:
            return 'left'
        first_column = tissue_columns.min().item()
        return 'left' if first_column <= 100 else 'right'

    def image_padding(self, position, image_array, mask_array):
        """Zero-pad image and mask identically to a fixed canvas.

        Images with 4084 rows get 172 extra columns and 166 extra rows; all
        others get 190 extra columns and 172 extra rows (presumably
        4084x3328 -> 4250x3500 and 3328x2560 -> 3500x2750 — confirm sizes).
        Rows are always added at the bottom; columns go on the side away from
        the breast: right for 'left', left otherwise.
        """
        if image_array.shape[0] == 4084:
            pad_cols, pad_rows = 3500 - 3328, 4250 - 4084
        else:
            pad_cols, pad_rows = 2750 - 2560, 3500 - 3328

        if position == 'left':
            padding = (0, pad_cols, 0, pad_rows)
        else:
            padding = (pad_cols, 0, 0, pad_rows)

        image_tensor = torch.nn.functional.pad(image_array, padding)
        mask_tensor = torch.nn.functional.pad(mask_array, padding)
        return image_tensor, mask_tensor

    def __getitem__(self, i):
        # read data
        image = pydicom.dcmread(self.images_fps[i]).pixel_array.astype('float')
        image_re = cv2.resize(image, (250, 250))
        image = torch.from_numpy(image)

        mask = cv2.imread(self.masks_fps[i])
        if mask is None:  # cv2 returns None instead of raising on unreadable files
            raise FileNotFoundError('Could not read mask: %s' % self.masks_fps[i])
        mask_re = self._binarize_mask(cv2.resize(mask, (250, 250)))
        mask = self._binarize_mask(mask)

        position = self.breast_left_or_right(image)
        image_pad, mask_pad = self.image_padding(position, image, mask)

        return image_pad, mask_pad, image_re, mask_re

Dataset class for images without a mass: returns the image and mask padded to a fixed size (for patch extraction) and the image and mask resized to 250x250.

The two classes are kept separate so that a balanced training dataset can be created.

In [0]:
class Dataset_with_NonMass(BaseDataset):  # original data, not padded or resized on disk
    """DICOM images whose paired PNG mask contains no tumour pixel.

    An image is paired with the mask whose ID equals the first 8 characters of
    the image ID; the pair is kept only if ``test_mass`` finds no tumour pixel.
    Items are ``(image_pad, mask_pad, image_re, mask_re)`` as in ``Dataset``.

    Args:
        images_dir: directory holding ``<id>.dcm`` files.
        masks_dir: directory holding ``<id>.png`` files.
        classes: optional iterable of names from ``CLASSES`` (case-insensitive);
            stored as their indices in ``class_values``.
    """

    CLASSES = ['non tumor', 'tumor']

    def __init__(self, images_dir, masks_dir, classes=None):
        # Attribute names are kept unchanged: DataLoaders saved with torch.save
        # restore them directly, and __len__/__getitem__ read them.
        self.ids = [self._strip_ext(f, '.dcm') for f in os.listdir(images_dir)]
        self.ids_m = [self._strip_ext(f, '.png') for f in os.listdir(masks_dir)]

        mask_ids = set(self.ids_m)  # O(1) membership instead of scanning a list
        self.ids_f = []
        self.ids_m_f = []
        for image_id in self.ids:
            prefix = image_id[0:8]
            if prefix in mask_ids:
                self.ids_f.append(image_id + '.dcm')
                self.ids_m_f.append(prefix + '.png')

        # Keep only the pairs whose mask has no tumour pixel.
        self.images_fps = []
        self.masks_fps = []
        for image_file, mask_file in zip(self.ids_f, self.ids_m_f):
            mask_path = os.path.join(masks_dir, mask_file)
            if not self.test_mass(mask_path):
                self.images_fps.append(os.path.join(images_dir, image_file))
                self.masks_fps.append(mask_path)

        # classes=None used to raise TypeError; treat it as "no classes".
        self.class_values = [self.CLASSES.index(cls.lower())
                             for cls in (classes if classes is not None else [])]

    @staticmethod
    def _strip_ext(filename, ext):
        """Return *filename* without the exact trailing suffix *ext*.

        ``str.rstrip(".png")`` removes any trailing '.', 'p', 'n', 'g'
        characters, which mangles IDs such as ``xxxxxxxg.png`` -> ``xxxxxxx``.
        """
        return filename[:-len(ext)] if filename.endswith(ext) else filename

    @staticmethod
    def _binarize_mask(mask_bgr):
        """Turn an HxWx3 mask array (as read by cv2) into a float {0, 1} tensor.

        The three channels are summed and divided by 3; a pixel becomes 1
        wherever ``|value - 1| > 0``.
        NOTE(review): this also maps all-black pixels (value 0) to 1 — confirm
        the mask encoding this relies on.
        """
        mask = torch.from_numpy(mask_bgr).long()
        mask = abs((mask.sum(2) / 3) - 1)
        return (mask > 0).float()

    def test_mass(self, name):
        """Return True if the mask file *name* contains at least one tumour pixel."""
        mask = cv2.imread(name)
        if mask is None:  # cv2 returns None instead of raising on unreadable files
            raise FileNotFoundError('Could not read mask: %s' % name)
        return bool((self._binarize_mask(mask) == 1).any())

    def __len__(self):
        return len(self.images_fps)

    def breast_left_or_right(self, image_array):
        """Return 'left' if breast tissue (pixels > 0) starts within the first 100 columns, else 'right'.

        An image with no tissue at all is reported as 'left' (padding side is
        irrelevant for it) instead of crashing on an empty ``min``.
        """
        tissue_columns = torch.nonzero((image_array > 0).any(dim=0))
        if tissue_columns.numel() == 0:
            return 'left'
        first_column = tissue_columns.min().item()
        return 'left' if first_column <= 100 else 'right'

    def image_padding(self, position, image_array, mask_array):
        """Zero-pad image and mask identically to a fixed canvas.

        Images with 4084 rows get 172 extra columns and 166 extra rows; all
        others get 190 extra columns and 172 extra rows (presumably
        4084x3328 -> 4250x3500 and 3328x2560 -> 3500x2750 — confirm sizes).
        Rows are always added at the bottom; columns go on the side away from
        the breast: right for 'left', left otherwise.
        """
        if image_array.shape[0] == 4084:
            pad_cols, pad_rows = 3500 - 3328, 4250 - 4084
        else:
            pad_cols, pad_rows = 2750 - 2560, 3500 - 3328

        if position == 'left':
            padding = (0, pad_cols, 0, pad_rows)
        else:
            padding = (pad_cols, 0, 0, pad_rows)

        image_tensor = torch.nn.functional.pad(image_array, padding)
        mask_tensor = torch.nn.functional.pad(mask_array, padding)
        return image_tensor, mask_tensor

    def __getitem__(self, i):
        # read data
        image = pydicom.dcmread(self.images_fps[i]).pixel_array.astype('float')
        image_re = cv2.resize(image, (250, 250))
        image = torch.from_numpy(image)

        mask = cv2.imread(self.masks_fps[i])
        if mask is None:  # cv2 returns None instead of raising on unreadable files
            raise FileNotFoundError('Could not read mask: %s' % self.masks_fps[i])
        mask_re = self._binarize_mask(cv2.resize(mask, (250, 250)))
        mask = self._binarize_mask(mask)

        position = self.breast_left_or_right(image)
        image_pad, mask_pad = self.image_padding(position, image, mask)

        return image_pad, mask_pad, image_re, mask_re

Load the train, validation and test DataLoaders previously saved to Drive, so that every training and testing run uses exactly the same images.

In [0]:
# Reload the DataLoaders saved earlier so every run sees exactly the same split.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# Unpickling needs the Dataset classes defined above.
train_loader = torch.load('/gdrive/My Drive/train_loader.pth')
val_loader = torch.load('/gdrive/My Drive/val_loader.pth')
test_loader = torch.load('/gdrive/My Drive/test_loader.pth')

Network definition.
The standard U-Net is used here as an example.

In [0]:
class DoubleConv(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages followed by Dropout(0.25).

    padding=1 keeps the spatial size. The Sequential indices (0..6) appear in
    state_dict keys, so the layer order must stay fixed for saved checkpoints.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        layers.append(nn.Dropout(0.25))
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.double_conv(x)
        return out


class Down(nn.Module):
    """Halve the spatial resolution with 2x2 max-pooling, then apply DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Sequential index 1 holds the DoubleConv (state_dict key prefix "maxpool_conv.1.").
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled


class Up(nn.Module):
    """Upsample x1, pad it to the size of skip tensor x2, concatenate [x2, x1], then DoubleConv.

    Args:
        in_channels: channel count after concatenation (input to DoubleConv).
        out_channels: channel count produced by DoubleConv.
        bilinear: use parameter-free bilinear upsampling; otherwise a learned
            ConvTranspose2d that expects x1 to have ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Inputs are NCHW. When an encoder level had an odd size, the upsampled
        # map is smaller than the skip tensor; pad it (split evenly) to match.
        # Plain Python ints are used because F.pad takes a sequence of ints.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to per-class logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        logits = self.conv(x)
        return logits

class UNet(nn.Module):
    """Standard U-Net: 4 down-sampling and 4 up-sampling stages with skip connections.

    Args:
        n_channels: input image channels.
        n_classes: output logit channels (one per class).
        bilinear: passed to every ``Up`` stage.

    Attribute names (inc, down1..4, up1..4, outc) form the state_dict keys and
    must not change, or saved checkpoints will no longer load.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # encoder (construction order kept so parameter initialisation is unchanged)
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)
        # decoder: in_channels = upsampled channels + skip channels
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # skips collects x1..x5 (shallowest first)
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            skips.append(down(skips[-1]))

        # start from the bottleneck x5, then merge with x4, x3, x2, x1 in turn
        features = skips.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            features = up(features, skips.pop())
        return self.outc(features)

Function to calculate IoU, Dice coefficient, specificity, sensitivity and accuracy

In [0]:
def _confusion_counts(predicted, truth):
    """Return (TP, FN, FP, TN) element-wise, with 1 as positive and 0 as negative.

    Every element that is not TP, FN or FP is counted as TN, exactly like the
    final ``else`` of a per-pixel comparison.
    """
    predicted = predicted.squeeze()
    truth = truth.squeeze()
    truth_pos = truth == 1
    truth_neg = truth == 0
    pred_pos = predicted == 1
    pred_neg = predicted == 0
    tp = int((truth_pos & pred_pos).sum())
    fn = int((truth_pos & pred_neg).sum())
    fp = int((truth_neg & pred_pos).sum())
    tn = truth.numel() - tp - fn - fp
    return tp, fn, fp, tn


def _class_scores(predicted, truth):
    """Return (iou, dice, specificity, sensitivity, accuracy) for the class encoded as 1.

    IoU/Dice are 0 when TP, FP and FN are all 0; specificity is 0 when TN and
    FP are 0; sensitivity is 0 when TP and FN are 0.
    """
    tp, fn, fp, tn = _confusion_counts(predicted, truth)
    if tp == 0 and fp == 0 and fn == 0:
        iou = dice = 0
    else:
        iou = tp / (tp + fp + fn)
        dice = (2 * tp) / (tp + fp + tp + fn)
    spec = 0 if (tn == 0 and fp == 0) else tn / (tn + fp)  # true negative rate
    sens = 0 if (tp == 0 and fn == 0) else tp / (tp + fn)  # true positive rate
    acc = (tp + tn) / (tp + tn + fp + fn)
    return iou, dice, spec, sens, acc


def metrics(predicted, truth, word):
    """Compute segmentation scores for tumour, background and their mean.

    NOTE: this function name shadows the ``sklearn.metrics`` module imported
    at the top of the notebook.

    Args:
        predicted: tensor of predicted labels (expected 0/1), same shape as *truth*.
        truth: tensor of ground-truth labels (expected 0/1).
        word: which scores to return:
            'iou'           -> (iou_tumour, iou_background, mean_iou)
            'AllTumour'     -> (iou, dice, spec, sens, acc) of the tumour class
            'AllBackground' -> (iou, dice, spec, sens, acc) of the background class
            'AllMean'       -> the same five scores averaged over the classes

    Means of IoU/Dice/spec/sens divide by 1 when *truth* has no tumour pixel
    (only background is present) and by 2 otherwise; accuracy always uses 2.

    Raises:
        ValueError: if *word* is not one of the options above (previously the
            function silently returned None).
    """
    if word not in ('iou', 'AllTumour', 'AllBackground', 'AllMean'):
        raise ValueError("unknown metrics selector: %r" % (word,))

    tumour = _class_scores(predicted, truth)
    # Background is scored by flipping the labels 0 <-> 1.
    background = _class_scores(abs(predicted - 1), abs(truth - 1))

    object_nb = 2 if bool((truth == 1).any()) else 1
    mean_iou, mean_dice, mean_spec, mean_sens = (
        (t + b) / object_nb for t, b in zip(tumour[:4], background[:4]))
    mean_acc = (tumour[4] + background[4]) / 2

    if word == 'iou':
        return tumour[0], background[0], mean_iou
    if word == 'AllTumour':
        return tumour
    if word == 'AllBackground':
        return background
    return mean_iou, mean_dice, mean_spec, mean_sens, mean_acc


Functions for the ROC curve, precision-recall curve and accuracy graph

In [0]:
#################### ROC curve and AUC ################################
def ROC(iou_tumour_list, ground, threshold=0.1, save_path=None):
    """Plot and save the ROC curve of image-level tumour detection; return its AUC.

    An image counts as a positive prediction when its tumour IoU is >= *threshold*.

    Args:
        iou_tumour_list: per-image tumour IoU.
        ground: per-image labels (1 = mask contains tumour, 0 = not).
        threshold: IoU cut-off for a positive prediction (default 0.1).
        save_path: file for the figure; defaults to
            ``'/gdrive/My Drive' + path + name + '_ROC_curve.png'`` built from
            the notebook-level ``path`` and ``name``.

    Returns:
        Area under the ROC curve of the thresholded predictions.
    """
    prediction = [0 if score < threshold else 1 for score in iou_tumour_list]

    fpr, tpr, _ = roc_curve(ground, prediction, pos_label=1)
    # roc_auc_score takes (y_true, y_score), not (fpr, tpr).
    AUC = roc_auc_score(ground, prediction)

    if save_path is None:
        save_path = '/gdrive/My Drive' + path + name + '_ROC_curve.png'

    fig = plt.figure()
    plt.plot(fpr, tpr, color='orange', label='ROC')
    plt.plot([0, 1], [0, 1], color='darkblue', linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend()
    fig.savefig(save_path)  # save before show(): show() may release the figure
    plt.show()

    return AUC

#################### Precision-Recall Curve, F1 score & AUPRC ################################
def Recall_Precision(iou_tumour_list, ground, threshold=0.1, save_path=None):
    """Plot and save the precision-recall curve of image-level tumour detection.

    Args:
        iou_tumour_list: per-image tumour IoU, used as the detection score.
        ground: per-image labels (1 = mask contains tumour, 0 = not); list or array.
        threshold: IoU at or above which an image is predicted positive for the
            F1 score (same default as ``ROC``; previously an undefined name).
        save_path: file for the figure; defaults to
            ``'/gdrive/My Drive' + path + name + '_RecallPrecision_curve.png'``
            built from the notebook-level ``path`` and ``name``.

    Returns:
        (f1, auprc): F1 score of the thresholded predictions and the area under
        the precision-recall curve.
    """
    # Local import so a notebook-level variable named `auc` cannot shadow sklearn's function.
    from sklearn.metrics import auc as area_under_curve

    labels = list(ground)
    precision, recall, _ = precision_recall_curve(labels, iou_tumour_list)
    # Baseline: precision of always predicting "tumour" = fraction of positive images.
    # (ground[ground == 1] fails when ground is a plain list.)
    no_skill = sum(1 for label in labels if label == 1) / len(labels)

    if save_path is None:
        save_path = '/gdrive/My Drive' + path + name + '_RecallPrecision_curve.png'

    fig = plt.figure()
    plt.plot([0, 1], [no_skill, no_skill], linestyle='--')
    plt.plot(recall, precision, marker='.', label='Logistic')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.legend()
    fig.savefig(save_path)  # save before show(): show() may release the figure
    plt.show()

    prediction = [0 if score < threshold else 1 for score in iou_tumour_list]
    f1 = f1_score(labels, prediction, pos_label=1)
    auprc = area_under_curve(recall, precision)

    return f1, auprc

#################### Accuracy ################################
def AccuracyGraph(iou_list, val_list, save_path=None):
    """Plot per-epoch training and validation scores and save the figure.

    NOTE(review): the notebook passes per-epoch mean IoU lists, while the
    labels say "Accuracy".

    Args:
        iou_list: per-epoch training score.
        val_list: per-epoch validation score (same length as *iou_list*).
        save_path: file for the figure; defaults to
            ``'/gdrive/My Drive/' + path + name + '_AccuracyGraph.png'`` built
            from the notebook-level ``path`` and ``name``.
    """
    # The x-axis follows the data length instead of a hard-coded 30 epochs,
    # which mismatched the 20 epochs actually trained.
    n_epochs = len(iou_list)
    epoch_list = list(range(n_epochs))

    fig, ax1 = plt.subplots(1, 1, figsize=(7, 4))
    ax1.plot(epoch_list, iou_list, label='Train Accuracy')
    ax1.plot(epoch_list, val_list, label='Validation Accuracy')  # the argument, not a global
    ax1.set_xticks(np.arange(0, n_epochs, 1))
    ax1.set_ylabel('Accuracy Value')
    ax1.set_xlabel('Epoch')
    ax1.set_title('Accuracy')
    ax1.legend(loc="best")

    if save_path is None:
        save_path = '/gdrive/My Drive/' + path + name + '_AccuracyGraph.png'
    fig.savefig(save_path)



Model: set the path to the network that was pre-trained on one of the two patch-extracted databases

In [0]:
# Network pre-trained on one of the two patch-extracted databases; its weights
# initialise the U-Net that is fine-tuned below.
path_transfer = '/networks/FULL_TRANSFER/UNet_WholeMass_full/'
name_transfer = 'UNet_WholeMass_full'
model = UNet(n_classes=2, n_channels=1)
# map_location='cpu': the model here stays on the CPU, so a checkpoint saved
# from GPU tensors still loads on a CPU-only runtime (values are copied into
# the model's own parameters either way).
model.load_state_dict(torch.load('/gdrive/My Drive/' + path_transfer + name_transfer + '.pth',
                                 map_location='cpu'))

<All keys matched successfully>

Optimizer & criterion

In [0]:
# Pixel-wise classification loss over the model's class logits.
criterion = nn.CrossEntropyLoss()
# SGD with momentum over every model parameter (no layer is frozen here).
optimizer = optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)

In [0]:
# Output location (under '/gdrive/My Drive') for this run's checkpoints, logs and figures.
name = 'UNet_WholeMass_full'
path = '/networks/FULL_TRANSFER/' + name + '/'

Training: Weight Initialization

In [0]:
epochs = 20
best = 0
best_tumour = 0.0
train_iou = 0.0
val_iou = 0.0
val_tumour_iou = []
counter = 0
iou_list = []
val_iou_list = []
plot_every = 1  # show the last training sample (image, mask, prediction) every `plot_every` epochs

# `with` closes (and flushes) the log even if training is interrupted.
with open('/gdrive/My Drive' + path + 'training.txt', 'w') as file:
    file.write('Training: 20 epochs full transfer from 1connected patches\n')

    time.sleep(5)
    for epoch in range(epochs):

        ###################### Training
        model.train()  # dropout on, batch-norm uses and updates batch statistics
        time.sleep(5)
        # named train_bar so it does not overwrite barbar's `Bar` import
        train_bar = progressbar.ProgressBar(max_value=len(train_loader))
        for i, data in enumerate(train_bar(train_loader), 0):
            image, mask, image_re, mask_re = data

            optimizer.zero_grad()
            # NOTE(review): assumes batch_size=1, so unsqueeze turns the resized
            # image batch into a single-channel input — confirm loader config.
            outputs = model(image_re.unsqueeze(dim=0).float())
            # per-pixel argmax over the class logits -> predicted label map
            _, outputs_final = torch.max(outputs.squeeze(), dim=0)
            loss = criterion(outputs.float(), mask_re.long())
            loss.backward()  # backward propagation
            optimizer.step()  # optimize
            iou_tumour, iou_background, mean_iou = metrics(outputs_final, mask_re.squeeze(), 'iou')
            train_iou += mean_iou

        ############### Validation
        # eval() + no_grad(): no dropout, batch-norm running statistics are not
        # updated from validation data, and no autograd graph is built.
        model.eval()
        time.sleep(5)
        val_bar = progressbar.ProgressBar(max_value=len(val_loader))
        with torch.no_grad():
            for i, data_v in enumerate(val_bar(val_loader), 0):
                image_v, mask_v, image_re_v, mask_re_v = data_v
                outputs_v = model(image_re_v.unsqueeze(dim=0).float())
                _, outputs_final_v = torch.max(outputs_v.squeeze(), dim=0)
                iou_tumour, iou_background, mean_iou = metrics(outputs_final_v, mask_re_v.squeeze(), 'iou')
                val_iou += mean_iou
                if 1 in mask_re_v:
                    val_tumour_iou.append(iou_tumour)

        ################ save best model (by summed validation IoU)
        if val_iou > best:
            best = val_iou
            torch.save(model.state_dict(), '/gdrive/My Drive/' + path + name + '.pth')

        if sum(val_tumour_iou) > best_tumour:
            best_tumour = sum(val_tumour_iou)
            torch.save(model.state_dict(), '/gdrive/My Drive/' + path + name + '_tumour.pth')

        ################ report the epoch averages
        train_mean = train_iou / len(train_loader)
        val_mean = val_iou / len(val_loader)
        # guard: the validation split may contain no image with tumour
        tumour_val_mean = sum(val_tumour_iou) / len(val_tumour_iou) if val_tumour_iou else 0.0
        print('Epoch %d :iou train: %.3f' % (epoch + 1, train_mean), '; iou val: %.3f' % (val_mean), '; tumour iou val: %.3f' % (tumour_val_mean), '\n')
        file.write('\n\nEpoch %d : mean iou train: %.3f' % (epoch + 1, train_mean))
        file.write(' ;mean iou validation: %.3f' % (val_mean))
        file.write(' ;tumour iou validation: %.3f' % (tumour_val_mean))

        iou_list.append(train_mean)
        val_iou_list.append(val_mean)

        # show the last training sample with its prediction
        if (counter % plot_every) == 0:
            fig = plt.figure()
            plt.subplot(1, 3, 1)
            plt.imshow(image_re.squeeze())
            plt.subplot(1, 3, 2)
            plt.imshow(mask_re.squeeze())
            plt.subplot(1, 3, 3)
            plt.imshow(outputs_final.detach().numpy())

        counter += 1
        train_iou = 0.0
        val_tumour_iou = []
        val_iou = 0.0

    print('Finished Training')

    # file.write('\n\nTrain iou list ' + str(iou_list))
    # file.write('\nValidation iou list ' + str(val_iou_list))

# AccuracyGraph(iou_list,val_iou_list)

Testing

In [0]:
model.load_state_dict(torch.load('/gdrive/My Drive/' + path + name + '.pth'))
# Inference mode: dropout off, batch-norm uses the running statistics from training.
model.eval()


def _ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is 0."""
    return numerator / denominator if denominator else 0.0


iou_tumour_list = []  # tumour IoU of every test image, in loader order
ground = []  # per test image: 1 if its mask contains tumour, else 0
test_iou = 0.0
test_dice = 0.0
test_spec = 0.0
test_sens = 0.0
test_acc = 0.0
test_iou_tumour = 0.0
test_dice_tumour = 0.0
test_spec_tumour = 0.0
test_sens_tumour = 0.0
test_acc_tumour = 0.0

# image-level detection counts ("does the prediction contain any tumour pixel?")
false_positive = 0
true_positive = 0
false_negative = 0
true_negative = 0
tum_count = 0

# `with` closes the log even on error; no_grad() avoids building autograd graphs.
with open('/gdrive/My Drive' + path + 'testing.txt', 'w') as file, torch.no_grad():
    file.write('Testing \n')

    ################ test model
    for i, data_t in enumerate(test_loader, 0):
        image_t, mask_t, image_re_t, mask_re_t = data_t
        outputs_t = model(image_re_t.unsqueeze(dim=0).float())
        _, outputs_final_t = torch.max(outputs_t.squeeze(), dim=0)
        truth_t = mask_re_t.squeeze()

        iou_tumour, iou_background, mean_iou = metrics(outputs_final_t, truth_t, 'iou')
        test_iou += mean_iou
        mean_iou, mean_dice, mean_spec, mean_sens, mean_acc = metrics(outputs_final_t, truth_t, 'AllMean')
        test_dice += mean_dice
        test_spec += mean_spec
        test_sens += mean_sens
        test_acc += mean_acc
        iou_tumour_list.append(iou_tumour)
        print('\n\nTest Image %d : Mean iou %.3f' % (i + 1, mean_iou))
        file.write('\n\nTest image %d : Mean iou %.3f, mean dice %.3f, mean spec %.3f, mean sens %.3f, mean acc %.3f' % (i + 1, mean_iou, mean_dice, mean_spec, mean_sens, mean_acc))

        has_tumour = 1 in mask_re_t
        ground.append(1 if has_tumour else 0)

        if has_tumour:
            print('Tumour Iou: %.3f' % (iou_tumour))
            iou_tumour, dice_tumour, spec_tumour, sens_tumour, acc_tumour = metrics(outputs_final_t, truth_t, 'AllTumour')
            test_iou_tumour += iou_tumour
            test_dice_tumour += dice_tumour
            test_spec_tumour += spec_tumour
            test_sens_tumour += sens_tumour
            test_acc_tumour += acc_tumour
            tum_count += 1
            file.write('\n tumour iou %.3f, tumour dice %.3f, tumour spec %.3f, tumour sens %.3f, tumour acc %.3f' % (iou_tumour, dice_tumour, spec_tumour, sens_tumour, acc_tumour))

        predicted_tumour = 1 in outputs_final_t
        if has_tumour and predicted_tumour:
            true_positive += 1
        elif predicted_tumour:
            false_positive += 1
        elif has_tumour:
            false_negative += 1
        else:
            true_negative += 1

        fig = plt.figure()
        plt.subplot(1, 3, 1)
        plt.imshow(image_re_t.squeeze())
        plt.subplot(1, 3, 2)
        plt.imshow(mask_re_t.squeeze())
        plt.subplot(1, 3, 3)
        plt.imshow(outputs_final_t.detach().numpy())
        matplotlib.image.imsave('/gdrive/My Drive' + path + 'test2/' + name + '_image_' + str(i) + '.png', image_t.squeeze())
        matplotlib.image.imsave('/gdrive/My Drive' + path + 'test2/' + name + '_mask_' + str(i) + '.png', mask_t.squeeze())
        matplotlib.image.imsave('/gdrive/My Drive' + path + 'test2/' + name + '_prediction_' + str(i) + '.png', outputs_final_t.detach().numpy())

    n_test = len(test_loader)
    print('\n\n\nOver entire Test data set: average mean iou: %.3f' % (test_iou / n_test))
    file.write('\n\n\n\nOverall iou %.3f, overall dice %.3f, overall spec %.3f, overall sens %.3f, overall acc %.3f' % (test_iou / n_test, test_dice / n_test, test_spec / n_test, test_sens / n_test, test_acc / n_test))
    if tum_count:
        print('\nAverage tumour iou (for mask that have tumour): %.3f' % (test_iou_tumour / tum_count), '\n')
        file.write('\n\nOverall tumour iou %.3f, overall tumour dice %.3f, overall tumour spec %.3f, overall tumour sens %.3f, tumour overall acc %.3f' % (test_iou_tumour / tum_count, test_dice_tumour / tum_count, test_spec_tumour / tum_count, test_sens_tumour / tum_count, test_acc_tumour / tum_count))
    else:
        print('\nNo test image contains tumour: tumour averages not computed.\n')

    ####### image-level rates (0.0 when the denominator class is absent)
    FPR = _ratio(false_positive, false_positive + true_negative)  # False Positive Rate
    TPR = _ratio(true_positive, true_positive + false_negative)  # True Positive Rate (sensitivity)
    FNR = _ratio(false_negative, false_negative + true_positive)  # False Negative Rate
    TNR = _ratio(true_negative, true_negative + false_positive)  # True Negative Rate (specificity)
    Acc = _ratio(true_negative + true_positive, false_positive + false_negative + true_negative + true_positive)

    # named roc_auc so sklearn's `auc` function imported above is not shadowed
    roc_auc = ROC(iou_tumour_list, ground)

    # f1, auprc = Recall_Precision(iou_tumour_list, ground)

    # file.write('\n\n-------------------')
    # file.write('\nAUC: %.3f' % roc_auc)
    # file.write('\nF1: %.3f ;AUPRC: %.3f' % (f1, auprc))

# print('False Positive Rate:', FPR, '\n')
# print('True Positive Rate:', TPR, '\n')
# print('False Negative Rate:', FNR, '\n')
# print('True Negative Rate:', TNR, '\n')
# print('Acc:', Acc, '\n')
# print('AUC:', roc_auc, '\n')
# print('F1:', f1, 'AUPRC:', auprc, '\n')

