In [1]:
import torch
import torch.nn as nn
import torchvision


from torch.utils.data import DataLoader
from torch.utils.data import Dataset, random_split
import torchvision.transforms.functional as TF
import torch.optim as optim

import sys
import datetime

from PIL import Image

import numpy as np
from matplotlib import pyplot as plt

In [2]:

should_train = True

path_to_trained_model = 'models_final/unet_test_local/trained_unet_model.pth'
path_to_train_loss = 'models_final/unet_test_local/unet_train_losses.txt'
path_to_val_loss = 'models_final/unet_test_local/unet_val_losses.txt'


In [3]:
class convolution(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.batch_norm1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(out_c)
        
        self.relu = nn.ReLU()
        
    def forward(self, data):
        x = self.conv1(data)
        x = self.batch_norm1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.batch_norm2(x)
        x = self.relu(x)
        return x
    
class encoder(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = convolution(in_c, out_c)
        self.pool = nn.MaxPool2d((2,2))
        
    def forward(self, data):
        x = self.conv(data)
        p = self.pool(x)
        return x, p

class decoder(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = convolution(out_c + out_c, out_c)
        
    def forward(self, data, skip): # skip connections
        x = self.up(data)
        x = torch.cat([x, skip], axis=1)
        x = self.conv(x)

        return x
    
class unet(nn.Module):
    def __init__(self):
        super().__init__()
        
        """ Encoding """
        self.en1 = encoder(1, 64)
        self.en2 = encoder(64, 128)
        self.en3 = encoder(128, 256)
        self.en4 = encoder(256, 512)

        
        # """ Bottleneck """
        self.bottle = convolution(512, 1024)
        
        # """ Decoding """
        self.de1 = decoder(1024, 512)
        self.de2 = decoder(512, 256)
        self.de3 = decoder(256, 128)
        self.de4 = decoder(128, 64)
        
        """ Classifier """
        self.last = nn.Conv2d(64, 1, kernel_size=1, padding=0)
    
    def forward(self, data):
        """ Encoding """
        s1, p1 = self.en1(data)
        s2, p2 = self.en2(p1)
        s3, p3 = self.en3(p2)
        s4, p4 = self.en4(p3)
        
        # """ Bottleneck """
        b = self.bottle(p4)
        
        # """ Decoding """
        d1 = self.de1(b, s4)
        d2 = self.de2(d1, s3)
        d3 = self.de3(d2, s2)
        d4 = self.de4(d3, s1)
        
        """ Classifier """
        outs = self.last(d4)
        
        return torch.sigmoid(outs)


class Dataset(Dataset):
    def __init__(self, ids):
        self.ids = ids

    def transform(self, train_data, train_labels):
        return TF.to_tensor(train_data), TF.to_tensor(train_labels)

    def __len__(self):
        return len(self.ids)
    
    def __getitem__(self, index):
        id = self.ids[index]

        X = TF.to_tensor(Image.open(f"/Users/leeannquynhdo/Datalogi/MSc_thesis/unet_implementation/line_images/img_{id}.png"))
        y = TF.to_tensor(Image.open(f"/Users/leeannquynhdo/Datalogi/MSc_thesis/unet_implementation/line_images/mask_{id}.png"))
        
        return X, y

In [4]:
params = {"batch_size": 1, # batch size should be one to avoid re-batching of already batched data. Dataset class returns batched data
          "shuffle": True,}
        #   "num_workers": 4,}

all_ids = range(100)

# Define the split lengths
train_len = int(len(all_ids) * 0.6)
val_len = int(len(all_ids) * 0.2)
test_len = len(all_ids) - train_len - val_len

# Use random_split to split the dataset
train_data, val_data, test_data = random_split(
    Dataset(all_ids),
    [train_len, val_len, test_len]
)

training_generator= DataLoader(train_data, **params)
validation_generator = DataLoader(val_data, **params)
test_generator = DataLoader(test_data, **params)

In [5]:
print("GPU?:", torch.cuda.is_available())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet_model = unet().to(device)
loss_func = nn.BCELoss().to(device)
optimizer = optim.Adam(unet_model.parameters(), lr=0.001)


# ### Functions to train model
def training_step(model, dataset):
    model.train()
    running_loss = 0

    for batch_pair in dataset:
        optimizer.zero_grad()
        train_images = batch_pair[0].to(device)
        train_labels = batch_pair[1].to(device)
        outputs = model(train_images)
        # print(outputs)
        loss = loss_func(outputs, train_labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(dataset)

def validation_step(model, dataset):
    model.eval()
    validation_loss = 0
    with torch.no_grad():
        for batch_pair in dataset:
            val_images = batch_pair[0].to(device)
            val_labels = batch_pair[1].to(device)
            val_outputs = model(val_images)
            validation_loss += loss_func(val_outputs, val_labels).item()
    return validation_loss / len(dataset)

def train_until_convergence(model, train_set, val_set, epochs, patience):
    best_val_loss = np.inf
    no_improvement = 0
    time_diff = 0
    train_loss_list = []
    val_loss_list = []

    total_start = datetime.datetime.now()

    for epoch in range(epochs):
        start_time = datetime.datetime.now()
        sys.stdout.write("\rCurrently at epoch: " + str(epoch+1) + ". Estimated time remaining: {}\n".format(time_diff*(epochs - epoch)))

        train_loss = training_step(model, train_set)
        val_loss = validation_step(model, val_set)
        train_loss_list.append(train_loss)
        val_loss_list.append(val_loss)

        print(f"Epoch {epoch+1}: \t Training Loss: {train_loss}, \t Validation Loss: {val_loss}")

        end_time = datetime.datetime.now()
        time_diff = end_time - start_time

        if (epoch%50==0): #save every model every 50th iter
            torch.save(model.state_dict(), f'models_final/unet_test_local/unet_model_mse_epoch_{epoch}.pth' )

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            no_improvement = 0
        else:
            no_improvement += 1

        if no_improvement >= patience:
            print(f"No improvement in validation loss for {patience} epochs. Stopping training...")
            break

    total_end = datetime.datetime.now()
    total_time = total_end-total_start
    print("Total running time for unet_lon", total_time)
    
    return model, train_loss_list, val_loss_list


GPU?: False


In [6]:

if should_train:

    # # Train the model LET's FuCKING SKRRRrrRT

    trained_model, train_losses, val_losses = train_until_convergence(unet_model, training_generator, validation_generator, epochs=500, patience=20)
    # save the trained model
    torch.save(trained_model.state_dict(), path_to_trained_model)
    # print('train losses:', train_losses)


    # write loss to file
    with open(path_to_train_loss, 'w') as f:
        for loss in train_losses:
            f.write("%s\n" % loss)
    
    with open(path_to_val_loss, 'w') as f:
        for loss in val_losses:
            f.write("%s\n" % loss)

Currently at epoch: 1. Estimated time remaining: 0
Epoch 1: 	 Training Loss: 0.2217104136943817, 	 Validation Loss: 0.09844397492706776
Currently at epoch: 2. Estimated time remaining: 8:32:46.173841
Epoch 2: 	 Training Loss: 0.06296202093362809, 	 Validation Loss: 0.09465408977121115
Currently at epoch: 3. Estimated time remaining: 7:54:24.233808
Epoch 3: 	 Training Loss: 0.030356917437165974, 	 Validation Loss: 0.038785685785114765
Currently at epoch: 4. Estimated time remaining: 8:13:17.548764
Epoch 4: 	 Training Loss: 0.018020158726722003, 	 Validation Loss: 0.04369545471854508
Currently at epoch: 5. Estimated time remaining: 7:57:11.907520
Epoch 5: 	 Training Loss: 0.012019593253110845, 	 Validation Loss: 0.019240303430706263
Currently at epoch: 6. Estimated time remaining: 7:48:43.910100
Epoch 6: 	 Training Loss: 0.008378676325082778, 	 Validation Loss: 0.06860718303360044
Currently at epoch: 7. Estimated time remaining: 7:58:05.393906
Epoch 7: 	 Training Loss: 0.0064939437356467

In [7]:
# load the trained model
trained_model = unet().to(device)
trained_model.load_state_dict(torch.load(path_to_trained_model, map_location=device)) # path to trained model

<All keys matched successfully>

In [8]:
def test_step(model, dataset):
    model.eval()
    test_images = []
    test_labels = []
    test_outputs = []
    with torch.no_grad():
        for batch_pair in dataset:
            test_image = batch_pair[0].to(device)
            # print(test_image)
            test_label = batch_pair[1].to(device)
            # print(test_label)
            test_output = model(test_image)
            # print(test_output)
            
            test_images.append(test_image.detach().cpu())
            test_labels.append(test_label.detach().cpu())
            test_outputs.append(test_output.detach().cpu())
    return test_images, test_labels, test_outputs

In [9]:
test_images, test_labels, test_outputs = test_step(trained_model, test_generator)

In [10]:
for i in range(len(test_images)):
    fig, ax = plt.subplots(1,3,figsize=(15,5))
    ax[0].imshow(test_images[i].squeeze(0).permute(1,2,0), cmap='gray')
    ax[0].set_title('image')
    ax[1].imshow(test_labels[i].squeeze(0).permute(1,2,0), cmap='gray')
    ax[1].set_title('ground truth')
    ax[2].imshow(test_outputs[i].squeeze(0).permute(1,2,0), cmap='gray')
    ax[2].set_title('prediction')
    # plt.show()
    plt.savefig(f'/Users/leeannquynhdo/Datalogi/MSc_thesis/unet_implementation/models_final/unet_test_local/lon_unet_test_result_{i}.png', dpi=500, bbox_inches='tight')
    plt.close()

In [11]:
train_losses = np.loadtxt(path_to_train_loss)
val_losses = np.loadtxt(path_to_val_loss)

loss_x = list(range(len(train_losses)))
plt.plot(loss_x, train_losses, label='Train loss')
plt.plot(loss_x, val_losses, label='Validation loss')
plt.title('Train / Validation loss')
plt.legend()
plt.savefig('/Users/leeannquynhdo/Datalogi/MSc_thesis/unet_implementation/models_final/unet_test_local/unet_train_val_loss.png', dpi=500, bbox_inches='tight')
plt.close()

In [12]:
predictions = torch.stack(test_outputs)
labels = torch.stack(test_labels)

# assuming predictions and labels are PyTorch tensors with shape [4, 4096, 4096, 1]
predictions = predictions.squeeze().view(-1)
labels = labels.squeeze().view(-1)


from sklearn.metrics import roc_curve, roc_auc_score
fpr, tpr, ts = roc_curve(labels.int().view(-1).numpy(), predictions.view(-1).numpy())
auc = roc_auc_score(labels.int().view(-1).numpy(), predictions.view(-1).numpy())

plt.title('U-Net ROC curve')
plt.plot(fpr, tpr, label=f'AUC score = {auc}')
plt.legend()
plt.savefig('/Users/leeannquynhdo/Datalogi/MSc_thesis/unet_implementation/models_final/unet_test_local/roc_curve_mse.png', dpi=500, bbox_inches='tight')
plt.close()

predictions = (predictions > 0.5).float()

accuracy = torch.eq(predictions, labels).sum().item() / len(predictions)
print("accuracy:", accuracy)

def calculate_iou(pred, target):
    intersection = torch.logical_and(pred, target).sum()
    union = torch.logical_or(pred, target).sum()
    iou = intersection.float() / union.float()
    return iou

iou = calculate_iou(predictions, labels)
print("iou:", iou)

from sklearn.metrics import f1_score

f1 = f1_score(labels.int(), predictions.int())
print("f1:", f1)

accuracy: 0.9956756591796875
iou: tensor(0.0985)
f1: 0.17926440776136696
