In [20]:
import os
import cv2
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision import models
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.io import read_image, ImageReadMode
import segmentation_models_pytorch as smp
from sklearn.metrics import average_precision_score
%matplotlib inline

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'

In [21]:
data_path = 'dataset/tiff'
bs = 16

transform1 = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((224, 224)),
            transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
        ])

transform2 = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((224, 224)),
            transforms.Normalize(mean=0.5, std=0.25)
        ])

In [22]:
device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
device

'cuda:1'

In [39]:
train_path = os.path.join(data_path, 'train')
train_labels_path = os.path.join(data_path, 'train_labels')

test_path = os.path.join(data_path, 'test')
test_labels_path = os.path.join(data_path, 'test_labels')

val_path = os.path.join(data_path, 'val')
val_labels_path = os.path.join(data_path, 'val_labels')

In [41]:
class CustomDataset(Dataset):
    def __init__(self, images, labels, transform=None, target_transform=None):
        self.images_path = images
        self.labels_path = labels
        
        self.images  = os.listdir(images)
        self.labels = os.listdir(labels)
        
        self.transform  = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        img_path = os.path.join(self.images_path, self.images[index])
        mask_path = os.path.join(self.labels_path, self.labels[index])
        image = Image.open(img_path)
        label = Image.open(mask_path)

        bw_image_array = np.array(label)
        color_image_array = np.array(image)

        top_left_x = random.randint(0, 1500 - 224)
        top_left_y = random.randint(0, 1500 - 224)
            
        bw_square = bw_image_array[top_left_y:top_left_y + 224, top_left_x:top_left_x + 224]
            
        color_square = color_image_array[top_left_y:top_left_y + 224, top_left_x:top_left_x + 224, :]

        white_pixels = np.sum(bw_square == 255)
        total_pixels = np.prod(bw_square.shape)
        while (white_pixels / total_pixels) < 0.02:
            top_left_x = random.randint(0, 1500 - 224)
            top_left_y = random.randint(0, 1500 - 224)
            bw_square = bw_image_array[top_left_y:top_left_y + 224, top_left_x:top_left_x + 224]
            color_square = color_image_array[top_left_y:top_left_y + 224, top_left_x:top_left_x + 224, :]
            white_pixels = np.sum(bw_square == 255)
            total_pixels = np.prod(bw_square.shape)

            
        bw_square_image = Image.fromarray(bw_square)
        color_square_image = Image.fromarray(color_square)
        
        if self.transform:
            image = self.transform(color_square_image)
        if self.target_transform:
            label = self.target_transform(bw_square_image)
            label = torch.cat([label], dim=0)
        
        return image, label

In [42]:
custom_dataset = CustomDataset(train_path, train_labels_path,  transform1, transform2)
train_loader = DataLoader(custom_dataset, batch_size=bs, shuffle=True)

custom_dataset = CustomDataset(test_path, test_labels_path, transform1, transform2)
test_loader = DataLoader(custom_dataset, batch_size=bs, shuffle=False)

custom_dataset = CustomDataset(val_path, val_labels_path, transform1, transform2)
val_loader = DataLoader(custom_dataset, batch_size=bs, shuffle=False)

In [44]:
model = smp.Unet(encoder_name="resnet34", in_channels=3, classes=1)

model = model.to(device)
# model.load_state_dict(torch.load('UNet_ResNet34_10_total.pth'))
print(model)

Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

In [45]:
num_epochs = 30
len_train_loader = len(train_loader)
len_val_loader = len(val_loader)

train_loss_arr = []
val_loss_arr = []

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=8e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

In [46]:
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    train_total_iou = 0.0
    train_total_precision = 0.0
    train_total_recall = 0.0
    train_iou_arr = []
    train_precision_arr = []
    train_recall_arr = []
    
    i = 1
    
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        print(f'\rIteration: {i} / {len_train_loader}', end='', flush=True)
        targets[targets == -2] = 0
        targets[targets == 2] = 1

        inputs = inputs.to(device)
        
        targets = targets.to(device)
        
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        train_loss_arr.append(loss.item())
        
        predicted_masks = (outputs > 0.5).float()

        intersection = torch.logical_and(predicted_masks, targets).sum().item()
        union = torch.logical_or(predicted_masks, targets).sum().item()
        iou = intersection / union if union != 0 else 0
        train_total_iou += iou
        train_iou_arr.append(iou)

        
        TP = ((predicted_masks == 1) & (targets == 1)).sum().item()
        FP = ((predicted_masks == 1) & (targets == 0)).sum().item()
        FN = ((predicted_masks == 0) & (targets == 1)).sum().item()

        precision = TP / (TP + FP) if TP + FP > 0 else 0.0
        recall = TP / (TP + FN) if TP + FN > 0 else 0.0

        train_total_precision += precision
        train_total_recall += recall

        train_precision_arr.append(precision)
        train_recall_arr.append(recall)
        
        i += 1      

    scheduler.step()
     
    model.eval()
    val_loss = 0.0
    val_total_iou = 0.0
    val_total_precision = 0.0
    val_total_recall = 0.0
    val_precision_arr = []
    val_recall_arr = []
    j = 0
    
    with torch.no_grad():
        for inputs, targets in val_loader:
            targets[targets == -2] = 0
            targets[targets == 2] = 1
            
            inputs = inputs.to(device)
            targets = targets.to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            val_loss_arr.append(loss.item())
            
            predicted_masks = (outputs > 0.5).float()
           
            intersection = torch.logical_and(predicted_masks, targets).sum().item()
            union = torch.logical_or(predicted_masks, targets).sum().item()
            iou = intersection / union if union != 0 else 0
            val_total_iou += iou

            TP = ((predicted_masks == 1) & (targets == 1)).sum().item()
            FP = ((predicted_masks == 1) & (targets == 0)).sum().item()
            FN = ((predicted_masks == 0) & (targets == 1)).sum().item()

            precision = TP / (TP + FP) if TP + FP > 0 else 0.0
            recall = TP / (TP + FN) if TP + FN > 0 else 0.0

            val_total_precision += precision
            val_total_recall += recall

            val_precision_arr.append(precision)
            val_recall_arr.append(recall)
    
    print(f'\nEpoch {epoch + 1}:')
    print(f'Training Loss: {train_loss / len_train_loader}') 
    print(f'Validation Loss: {val_loss / len_val_loader}')
    print(f'Training precision: {train_total_precision / len_train_loader}')
    print(f'Validation precision: {val_total_precision / len_val_loader}')
    print(f'Training recall: {train_total_recall / len_train_loader}')
    print(f'Validation recall: {val_total_recall / len_val_loader}')
    print(f'Training IOU: {train_total_iou / len_train_loader}')
    print(f'Validation IOU: {val_total_iou / len_val_loader}\n')
    
print("That's all")

Iteration: 70 / 70
Epoch 1:
Training Loss: 0.2968824950712068
Validation Loss: 0.25278058648109436
Training precision: 0.004010673533323268
Validation precision: 0.0
Training recall: 0.003901655208472513
Validation recall: 0.0
Training IOU: 0.001771931688056742
Validation IOU: 0.0

Iteration: 70 / 70
Epoch 2:
Training Loss: 0.2418670151914869
Validation Loss: 0.27864381670951843
Training precision: 0.0
Validation precision: 0.0
Training recall: 0.0
Validation recall: 0.0
Training IOU: 0.0
Validation IOU: 0.0

Iteration: 9 / 70

KeyboardInterrupt: 