In [32]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import math
import numpy as np
import os
import cv2
import csv
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from CustomDataset import CustomDataset
import matplotlib.pyplot as plt
torch.manual_seed(0)
from GPUtil import showUtilization as gpu_usage

torch.cuda.empty_cache()

In [33]:
# Initializations
# Prefer the second GPU (GPU 0 is shared), but fall back gracefully on machines
# with a single GPU or none.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
model = torchvision.models.segmentation.fcn_resnet50(pretrained=True, progress=True).to(device)

# Inputs have to be normalized the same way the images used for pretraining were.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
# NOTE(review): data_transforms is never passed to CustomDataset below, so this
# normalization only happens if CustomDataset applies it internally — confirm.
data_transforms = {
    'train':
    transforms.Compose([
        transforms.ToTensor(),
        normalize
    ]),
    'validation':
    transforms.Compose([
        transforms.ToTensor(),
        normalize
    ]),
}
dataset = CustomDataset(target_type='poly')

# 75 / 25 train / test split.
TRAIN_SIZE = math.floor(len(dataset) * 0.75)
TEST_SIZE = len(dataset) - TRAIN_SIZE
# A dedicated seeded generator keeps the split fixed regardless of how much of
# the global RNG earlier code (e.g. model construction) has consumed.
trainset, testset = random_split(dataset, [TRAIN_SIZE, TEST_SIZE],
                                 generator=torch.Generator().manual_seed(0))

traindata_loader = DataLoader(trainset, batch_size=5, shuffle=True)
# Evaluation data does not need shuffling; a fixed order also keeps the plotted
# validation predictions comparable between runs.
testdata_loader = DataLoader(testset, batch_size=5, shuffle=False)
# Number of batches (not samples) per loader.
LOADER_TRAIN_SIZE = len(traindata_loader)
LOADER_TEST_SIZE = len(testdata_loader)

# Per-phase lookups. Despite its name, `image_datasets` holds the number of
# *batches* in each phase's DataLoader, not the datasets themselves.
image_datasets = dict(train=LOADER_TRAIN_SIZE, validation=LOADER_TEST_SIZE)
print(image_datasets)
dataloaders = dict(train=traindata_loader, validation=testdata_loader)
# print(model)

# Train only the main segmentation head. `startswith` (rather than a substring
# test) also keeps an auxiliary head, if present ('aux_classifier.*'), frozen:
# training only reads the model's "out" output, so that head never gets gradients.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith('classifier')

# Replace the last 1x1 conv with a single-channel foreground-probability map.
# Newly created layers are trainable by default.
model.classifier[4] = nn.Sequential(
    nn.Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1)),
    nn.Sigmoid(),
).to(device)

# Only trainable parameters are handed to the optimizer.
params_to_update = [param for param in model.parameters() if param.requires_grad]
        
class DiceLoss(torch.nn.Module):
    """Soft Dice loss: ``1 - (2*sum(p*t) + s) / (sum(p^2) + sum(t^2) + s)``.

    Args:
        smooth: additive smoothing term ``s`` (default 1.0) that keeps the
            ratio defined when both prediction and target are empty.

    ``forward(pred, target)`` flattens both tensors (any shape, same number of
    elements) and returns a 0-dim tensor. ``pred`` is presumably a probability
    map (e.g. a Sigmoid output) — no activation is applied here.
    """

    def __init__(self, smooth=1.):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        # reshape handles non-contiguous inputs; cast so integer/bool masks multiply cleanly.
        iflat = pred.reshape(-1)
        tflat = target.reshape(-1).to(iflat.dtype)
        intersection = (iflat * tflat).sum()
        a_sum = (iflat * iflat).sum()
        b_sum = (tflat * tflat).sum()
        return 1 - ((2. * intersection + self.smooth) / (a_sum + b_sum + self.smooth))

    
    
    
class DiceBCELoss(nn.Module):
    """Sum of soft Dice loss and mean binary cross-entropy.

    Args:
        weight, size_average: accepted for backward compatibility; unused.
        from_logits: if True (default), a sigmoid is applied to ``inputs``
            first. Pass False when the model already ends in a Sigmoid so the
            activation is not applied twice.

    ``forward(inputs, targets, smooth=1)`` flattens both tensors and returns a
    0-dim tensor ``BCE + (1 - (2*|P∩T| + smooth) / (|P| + |T| + smooth))``.
    """

    def __init__(self, weight=None, size_average=True, from_logits=True):
        super(DiceBCELoss, self).__init__()
        self.from_logits = from_logits

    def forward(self, inputs, targets, smooth=1):
        probs = torch.sigmoid(inputs) if self.from_logits else inputs

        # reshape (not view) so non-contiguous tensors work; BCE needs float targets.
        probs = probs.reshape(-1)
        targets = targets.reshape(-1).to(probs.dtype)

        intersection = (probs * targets).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (probs.sum() + targets.sum() + smooth)
        bce = F.binary_cross_entropy(probs, targets, reduction='mean')
        return bce + dice_loss
    
    
class FocalLoss(nn.CrossEntropyLoss):
    ''' Focal loss for classification tasks on imbalanced datasets.

    Per element: ``(1 - p_t) ** gamma * CE``, where ``p_t`` is the softmax
    probability of the target class and ``CE`` the (optionally class-weighted)
    cross-entropy. ``gamma=0`` reduces to plain cross-entropy.

    Args:
        gamma: focusing parameter.
        alpha: optional per-class weight tensor, forwarded to cross-entropy as ``weight``.
        ignore_index: target value that contributes zero loss.
        reduction: 'none' (default), 'mean' or 'sum', applied to the focal loss.

    Inputs follow ``nn.CrossEntropyLoss``: ``input_`` is (N, C) or (N, C, d1, ...)
    logits, ``target`` is (N,) or (N, d1, ...) class indices. With
    ``reduction='none'`` the output has the shape of ``target``.
    '''

    def __init__(self, gamma, alpha=None, ignore_index=-100, reduction='none'):
        super().__init__(weight=alpha, ignore_index=ignore_index, reduction='none')
        # self.reduction now holds the *focal* reduction; forward must not let the
        # parent class apply it to the cross-entropy term.
        self.reduction = reduction
        self.gamma = gamma

    def forward(self, input_, target):
        # Always per-element cross-entropy, independent of self.reduction.
        cross_entropy = F.cross_entropy(input_, target, weight=self.weight,
                                        ignore_index=self.ignore_index, reduction='none')
        # Temporarily mask out ignore index to '0' for valid gather-indices input.
        # This won't contribute to the final loss, as the cross_entropy
        # contribution for these positions is zero.
        safe_target = target * (target != self.ignore_index).long()
        # squeeze(1) so the modulating factor matches cross_entropy's shape
        # instead of broadcasting (N, 1, ...) against (N, ...).
        target_prob = torch.gather(F.softmax(input_, 1), 1, safe_target.unsqueeze(1)).squeeze(1)
        loss = torch.pow(1 - target_prob, self.gamma) * cross_entropy
        if self.reduction == 'mean':
            return torch.mean(loss)
        if self.reduction == 'sum':
            return torch.sum(loss)
        return loss
            
            
# Loss: train_model computes torchvision.ops.sigmoid_focal_loss itself.
# Alternatives tried: nn.CrossEntropyLoss, FocalLoss, DiceBCELoss, nn.NLLLoss.
#
# Build the optimizer through `torch.optim`: the assignment below rebinds the
# name `optim` (the `torch.optim` alias imported at the top) to the optimizer
# instance, so `optim.Adam(...)` would fail if this cell were run twice.
optim = torch.optim.Adam(params_to_update, lr=0.001, eps=1e-08, weight_decay=0, amsgrad=False)
print(model.classifier)

{'train': 174, 'validation': 58}
a
FCNHead(
  (0): Conv2d(2048, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): Dropout(p=0.1, inplace=False)
  (4): Sequential(
    (0): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
    (1): Sigmoid()
  )
)


In [3]:
def confusion(prediction, truth, threshold=0.5):
    """ Returns the confusion-matrix counts between `prediction` and `truth`.

    `prediction` is thresholded (`> threshold`) into a binary mask; `truth` is
    treated as a binary mask where any nonzero value is positive. The tensors
    must be broadcastable (normally they have the same shape) and may live on
    any device.

    Returns a tuple of 0-dim integer tensors, counting positions where
    prediction and truth are
    - 1 and 1 (True Positive)
    - 1 and 0 (False Positive)
    - 0 and 0 (True Negative)
    - 0 and 1 (False Negative)
    """
    # Boolean masks: no reliance on inf/nan division tricks, any label dtype
    # works, and comparing with a Python float avoids a device-bound tensor.
    predicted_pos = prediction > threshold
    actual_pos = truth != 0

    true_positives = torch.sum(predicted_pos & actual_pos)
    false_positives = torch.sum(predicted_pos & ~actual_pos)
    true_negatives = torch.sum(~predicted_pos & ~actual_pos)
    false_negatives = torch.sum(~predicted_pos & actual_pos)

    return true_positives, false_positives, true_negatives, false_negatives

In [4]:
# Sanity check: fetch a single training batch and confirm the tensor shapes.
inputs, labels = next(iter(dataloaders['train']))
print(inputs.shape)
print(labels.shape)

torch.Size([5, 3, 480, 640])
torch.Size([5, 480, 640])


In [37]:
from sklearn.metrics import f1_score, roc_auc_score
import pandas as pd
def _safe_div(numerator, denominator, default=0.0):
    """Return ``numerator / denominator`` as a float, or ``default`` if the denominator is 0."""
    return float(numerator) / denominator if denominator else default


def _plot_prediction(prob_map, label_map):
    """Show a predicted probability map next to its ground-truth mask (one figure)."""
    fig, (ax_pred, ax_true) = plt.subplots(1, 2, figsize=(10, 4))
    ax_pred.imshow(prob_map.cpu().numpy())
    ax_pred.set_title('prediction')
    ax_true.imshow(label_map.cpu().numpy())
    ax_true.set_title('label')
    return fig


def train_model(model, optimizer, save_dir, num_epochs=3, evaluate = True, outputs_are_probs=True):
    """Fine-tune a binary segmentation model and report per-epoch metrics.

    Each epoch runs a 'train' phase and, when ``evaluate`` is True, a
    'validation' phase over the module-level ``dataloaders`` dict, using
    ``torchvision.ops.sigmoid_focal_loss`` (mean-reduced) as the loss.
    After every validation phase the row
    ``[epoch, train_loss, train_f1, train_roc, val_loss, val_f1, val_roc]`` is
    appended to ``<save_dir>/fcn_log.csv``. The whole model is saved to
    ``<save_dir>/fcn_resnet50_200epoch.pt`` whenever the validation loss
    improves on the best seen so far. On the final epoch each validation
    batch's first prediction is plotted next to its label.

    Args:
        model: segmentation model whose forward returns a dict with an "out"
            tensor; assumed (N, 1, H, W) — TODO confirm for other heads.
        optimizer: optimizer over the trainable parameters of ``model``.
        save_dir: directory for the CSV log and checkpoint; created if missing.
        num_epochs: number of epochs.
        evaluate: also run the validation phase each epoch.
        outputs_are_probs: True when the model already ends in a Sigmoid.
            ``sigmoid_focal_loss`` applies its own sigmoid, so probabilities are
            mapped back to logits before computing the loss. Pass False for a
            head that outputs raw logits.

    Returns:
        ``model`` (the same object, trained in place).
    """
    os.makedirs(save_dir, exist_ok=True)
    phases = ['train', 'validation'] if evaluate else ['train']
    # Tracked across the whole run (not reset per phase), so the checkpoint
    # always holds the best validation epoch seen so far.
    best_loss = float('inf')
    performance = []

    for epoch in range(num_epochs):
        for phase in phases:
            print("Currently in the: ", phase," phase")
            is_train = phase == 'train'
            # NOTE(review): train mode also lets the frozen backbone's BatchNorm
            # layers update their running statistics — confirm that is intended.
            model.train(is_train)

            n_batches = n_samples = n_pixels = 0
            running_loss = 0.0
            tp = fp = tn = fn = 0
            f1_sum = 0.0
            roc_sum = 0.0
            roc_batches = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device).float()

                # Only build the autograd graph when training.
                with torch.set_grad_enabled(is_train):
                    # squeeze(1), not squeeze(): keep the batch dim when a batch has one image.
                    outputs = model(inputs)["out"].squeeze(1)
                    if outputs_are_probs:
                        probs = outputs
                        logits = torch.logit(probs, eps=1e-6)
                    else:
                        logits = outputs
                        probs = torch.sigmoid(logits)
                    targets = labels.reshape(logits.shape)
                    loss = torchvision.ops.sigmoid_focal_loss(logits, targets, reduction='mean')

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                batch_size = inputs.size(0)
                # loss is a per-batch mean; weight by batch size for a per-sample epoch mean.
                running_loss += loss.item() * batch_size
                n_samples += batch_size
                n_pixels += targets.nelement()
                n_batches += 1

                probs = probs.detach()
                targets = targets.detach()

                if epoch == num_epochs - 1 and not is_train:
                    _plot_prediction(probs[0], targets[0])

                b_tp, b_fp, b_tn, b_fn = confusion(probs, targets)
                tp += int(b_tp)
                fp += int(b_fp)
                tn += int(b_tn)
                fn += int(b_fn)

                y_true = (targets > 0).cpu().numpy().ravel()
                y_score = probs.cpu().numpy().ravel()
                f1_sum += f1_score(y_true, y_score > 0.5)
                try:
                    roc_sum += roc_auc_score(y_true, y_score)
                    roc_batches += 1
                except ValueError:
                    # ROC AUC is undefined when a batch contains only one class.
                    pass

            # Epoch-level metrics from totals (pixel-level) or batch means (sklearn).
            epoch_loss = _safe_div(running_loss, n_samples, default=float('nan'))
            epoch_precision = _safe_div(tp, tp + fp)
            epoch_recall = _safe_div(tp, tp + fn)
            epoch_f1 = _safe_div(2 * epoch_precision * epoch_recall, epoch_precision + epoch_recall)
            epoch_metrics_f1 = _safe_div(f1_sum, n_batches)
            epoch_metrics_roc = _safe_div(roc_sum, roc_batches, default=float('nan'))

            if is_train:
                performance = [epoch, epoch_loss, epoch_metrics_f1, epoch_metrics_roc]
            else:
                performance.extend((epoch_loss, epoch_metrics_f1, epoch_metrics_roc))
                print(performance)
                with open(os.path.join(save_dir, 'fcn_log.csv'), 'a+', newline='') as f:
                    csv.writer(f).writerow(performance)

            print('epoch: {} {} loss: {:.4f}, '.format(epoch, phase, epoch_loss))
            print('f1 score: {:.4f}, precision: {:.4f}, recall: {:.4f}'.format(epoch_f1,
                                                                              epoch_precision,
                                                                              epoch_recall))
            print('TP: {:.4f}, FP: {:.4f}, TN: {:.4f}, FN: {:.4f}'.format(
                _safe_div(tp, n_pixels),
                _safe_div(fp, n_pixels),
                _safe_div(tn, n_pixels),
                _safe_div(fn, n_pixels)))

            if not is_train and epoch_loss < best_loss:
                best_loss = epoch_loss
                torch.save(model, os.path.join(save_dir, 'fcn_resnet50_200epoch.pt'))
    gpu_usage()
    return model


In [None]:
gpu_usage()
save_dir = os.path.join(os.getcwd(), 'models', 'pytorch')
# train_model writes its CSV log and checkpoint here; make sure the directory exists.
os.makedirs(save_dir, exist_ok=True)
model_trained = train_model(model, optim, save_dir, num_epochs=200)

| ID | GPU | MEM |
------------------
|  0 |  0% | 61% |
|  1 |  0% | 44% |
Currently in the:  train  phase


In [None]:
# NOTE(review): this rebuilds a *different* head (2-channel Conv2d, no Sigmoid)
# and loads 'models/pytorch/weights6.h5', not the 1-channel model that
# train_model saves ('fcn_resnet50_200epoch.pt', a whole pickled model).
# Confirm which weights are meant to be evaluated. Note that it also replaces
# the global `model`.
model = torchvision.models.segmentation.fcn_resnet50(pretrained=True)
model.classifier[4] = nn.Conv2d(512, 2, kernel_size=(1, 1), stride=(1, 1))
# This model stays on the CPU, so load the tensors there even if they were saved
# from a GPU (e.g. cuda:1) or the current machine has no such device.
model.load_state_dict(torch.load('models/pytorch/weights6.h5', map_location='cpu'))
model.eval();

Evaluation

In [8]:
# model_trained, epoch_loss_list, epoch_f1_list = train_model(model, loss_fn, optim, num_epochs=1, evaluate = True)

In [59]:
# Scratch check of the CSV writer with dummy values. It writes to a separate
# file: opening 'fcn_log.csv' in 'w' mode here would truncate the log that
# train_model appends to.
performance = [0.1, 0.2, 0.3]
with open(os.path.join(save_dir, 'fcn_log_scratch.csv'), 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(performance)
    