In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

In [2]:
# Colab-only step: mount Google Drive at /content/drive so the paths used in
# later cells resolve. force_remount=True asks for a remount even when the
# drive is already mounted.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [3]:
# Add the project directory on Drive to the import search path.
# Presumably this is where the `datasets` and `models` modules imported in the
# next cell live — verify if the imports fail.
import sys
sys.path.append('/content/drive/MyDrive/siamese-registration')

In [4]:
from datasets import RandomTransformationDataset
from models import *

In [8]:
# Directory containing the pickled dataset splits ("train.pkl" / "test.pkl"
# are joined onto it when the datasets are built).
data_path = "/content/drive/MyDrive/data"
# Directory that receives this run's checkpoints, loss plot and losses.csv.
output_path = "/content/drive/MyDrive/outputs/16_sub_MSE_fc"

In [6]:
def _load_split(pickle_name):
    """Build the RandomTransformationDataset for one pickled split.

    Args:
        pickle_name: File name of the split inside ``data_path``
            (e.g. ``"train.pkl"``).

    Returns:
        A ``RandomTransformationDataset`` whose only image transform is
        ``ToTensor``, created with ``path_prefix="/content/drive/MyDrive"``
        and ``tr_only=False``.
    """
    to_tensor_only = transforms.Compose([transforms.ToTensor()])
    return RandomTransformationDataset(
        transforms=to_tensor_only,
        path=os.path.join(data_path, pickle_name),
        path_prefix="/content/drive/MyDrive",
        tr_only=False,
    )


# Both splits are configured identically; only the pickle file differs.
train_dataset = _load_split("train.pkl")
test_dataset = _load_split("test.pkl")


In [7]:
# Shuffled mini-batch loaders: 16 samples per training batch (10 workers),
# 8 samples per test batch (5 workers).
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=10,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=5,
)
# Debugging alternative (single sample, single worker):
#train_loader = DataLoader(train_dataset, batch_size=1, num_workers=1, shuffle=True)
#test_loader = DataLoader(test_dataset, batch_size=1, num_workers=1, shuffle=True)



In [None]:
def mape_loss(output, target, c=0.0001):
    """Mean absolute percentage error between ``output`` and ``target``.

    Computes ``mean(|target - output| / (|target| + c))``.

    The stabiliser ``c`` is added to the *magnitude* of the target. Adding it
    to the signed target (as ``target + c``) makes the denominator vanish when
    a target equals ``-c`` and shrink for any small negative target, which
    yields inf/NaN losses. For non-negative targets the result is unchanged.

    Args:
        output: Predicted tensor.
        target: Reference tensor, broadcastable against ``output``.
        c: Small non-negative constant guarding against division by zero.

    Returns:
        Scalar tensor holding the mean relative absolute error.
    """
    abs_error = torch.abs(target - output)
    scale = torch.abs(target) + c
    return torch.mean(abs_error / scale)

In [None]:
def mse_loss(input, target):
    """Return the mean of the element-wise squared difference.

    Args:
        input: Predicted tensor.
        target: Reference tensor, broadcastable against ``input``.

    Returns:
        Scalar tensor: the mean over all elements of ``(input - target) ** 2``.
    """
    residual = input - target
    squared_error = residual ** 2
    return squared_error.mean()

In [None]:
def norm_mse_loss(input, target):
    """Mean squared error between the normalised ``input`` and ``target``.

    Both tensors are first passed through ``nn.functional.normalize`` with its
    default arguments, so only the direction of each vector contributes to the
    loss, not its magnitude.

    Args:
        input: Predicted tensor.
        target: Reference tensor with a matching shape.

    Returns:
        Scalar tensor: mean of the squared difference of the normalised
        tensors.
    """
    unit_input, unit_target = map(nn.functional.normalize, (input, target))
    delta = unit_input - unit_target
    return (delta ** 2).mean()

In [None]:
def weighted_mse_loss(input, target, weight):
    """Mean of the squared error after element-wise weighting.

    Args:
        input: Predicted tensor.
        target: Reference tensor, broadcastable against ``input``.
        weight: Multiplier applied to each squared error (broadcastable).

    Returns:
        Scalar tensor: ``mean(weight * (input - target) ** 2)``.
    """
    squared_error = (input - target) ** 2
    weighted = weight * squared_error
    return weighted.mean()

In [9]:
# Set to True to warm-start from the checkpoint referenced further down.
pretrained = False

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Alternative architecture kept for reference:
#model = siamese_resnet18(1, 7, "subtraction", channels=[512, 128, 64])
model = initial_siamese_resnet18(1, 7, "subtraction")
model.to(device=device)

# Plain MSE is the active objective; norm_mse_loss can be swapped in instead.
criterion = nn.MSELoss()
#criterion = norm_mse_loss
optimizer = optim.Adam(model.parameters(), lr=0.0001)

if pretrained:
    checkpoint_path = "11_resnet18_wMSE_corr_reg_4/checkpoint-2.pt"
    checkpoint = torch.load(
        os.path.join("/content/drive/MyDrive/outputs", checkpoint_path),
        map_location=device,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    # The optimizer state is not restored (line left disabled):
    #optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Resume point and loss histories consumed by the training cell.
    epoch_check, training_loss_check, validation_loss_check = (
        checkpoint[key] for key in ('epoch', 'training_loss', 'validation_loss')
    )


print(f"Running on {device}")

Running on cuda


In [None]:
# Continue the loss histories and epoch counter from the checkpoint when
# `pretrained` is set; otherwise start from scratch.
training_loss = training_loss_check if pretrained else []
validation_loss = validation_loss_check if pretrained else []
start = epoch_check+1 if pretrained else 0
# NOTE(review): a fresh run covers 25 epochs, a resumed run only 24
# (epoch_check+1 .. epoch_check+24) — confirm this asymmetry is intended.
end = epoch_check+25 if pretrained else 25

# Per-parameter weights for weighted_mse_loss; pass them to criterion() below
# when switching to the weighted loss.
# mse_weights = torch.Tensor([0.1, 0.1, 100, 100, 50, 50, 30]).to(device=device)

# Create the output directory up front so the first checkpoint write cannot
# fail (after a full epoch of training) just because the folder is missing.
os.makedirs(output_path, exist_ok=True)

for epoch in range(start, end):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    with tqdm(train_loader, unit="batch") as progress:
        # The description is constant for the whole epoch, so set it once.
        progress.set_description(f"Epoch {epoch} - train")
        for img0, img1, params in progress:
            img0, img1, params = img0.to(device=device), img1.to(device=device), params.to(device=device)
            optimizer.zero_grad()
            outputs = model(img0, img1)
            loss = criterion(outputs, params)
            loss_item = loss.item()
            running_loss += loss_item
            loss.backward()
            optimizer.step()
            progress.set_postfix(loss=loss_item)

    # Epoch loss = mean of the per-batch losses.
    training_loss.append(running_loss / len(train_loader))

    # ---- validation pass ----
    model.eval()
    val_running_loss = 0.0
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every batch (less memory, faster forward pass).
    with torch.no_grad(), tqdm(test_loader, unit="batch") as validation_progress:
        validation_progress.set_description(f"Epoch {epoch} - valid")
        for img0, img1, params in validation_progress:
            img0, img1, params = img0.to(device=device), img1.to(device=device), params.to(device=device)
            outputs = model(img0, img1)
            loss = criterion(outputs, params)
            loss_item = loss.item()
            val_running_loss += loss_item
            validation_progress.set_postfix(loss=loss_item)

    validation_loss.append(val_running_loss / len(test_loader))

    # Save a resumable checkpoint after every epoch: weights, optimizer state
    # and the loss histories so far.
    #torch.save(model.state_dict(), os.path.join(output_path, f"model-{epoch}.pt"))
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'training_loss': training_loss,
            'validation_loss': validation_loss
            }, os.path.join(output_path, f"checkpoint-{epoch}.pt"))

    # Overwrite loss.png with the curves up to this epoch, then close the
    # figure so figures do not accumulate across epochs.
    plt.figure()
    plt.plot(training_loss, label="training loss")
    plt.plot(validation_loss, label="validation loss")
    plt.title("Training loss")
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    plt.savefig(os.path.join(output_path, "loss.png"))
    plt.close()

Epoch 0 - train: 100%|██████████| 1938/1938 [21:26<00:00,  1.51batch/s, loss=156]
Epoch 0 - valid: 100%|██████████| 1000/1000 [06:41<00:00,  2.49batch/s, loss=111]
Epoch 1 - train: 100%|██████████| 1938/1938 [20:29<00:00,  1.58batch/s, loss=32.7]
Epoch 1 - valid: 100%|██████████| 1000/1000 [02:37<00:00,  6.33batch/s, loss=17]
Epoch 2 - train: 100%|██████████| 1938/1938 [20:28<00:00,  1.58batch/s, loss=131]
Epoch 2 - valid: 100%|██████████| 1000/1000 [02:40<00:00,  6.25batch/s, loss=42.7]
Epoch 3 - train: 100%|██████████| 1938/1938 [20:29<00:00,  1.58batch/s, loss=54.6]
Epoch 3 - valid: 100%|██████████| 1000/1000 [02:36<00:00,  6.37batch/s, loss=18.5]
Epoch 4 - train: 100%|██████████| 1938/1938 [20:29<00:00,  1.58batch/s, loss=16.1]
Epoch 4 - valid: 100%|██████████| 1000/1000 [02:41<00:00,  6.21batch/s, loss=21.6]
Epoch 5 - train: 100%|██████████| 1938/1938 [20:30<00:00,  1.58batch/s, loss=32.4]
Epoch 5 - valid: 100%|██████████| 1000/1000 [02:38<00:00,  6.29batch/s, loss=7.25]
Epoch 6 -

In [None]:
# Show the training and validation loss curves (one point per epoch) inline.
plt.figure()
for series, label in (
    (training_loss, "training loss"),
    (validation_loss, "validation loss"),
):
    plt.plot(series, label=label)
plt.title("Training loss")
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend()
plt.show()

In [None]:
import csv

# Export the per-epoch loss histories as a two-column CSV in output_path.
# newline='' hands line-ending control to the csv module, as its
# documentation requires; without it, platforms that translate "\n" end up
# with an extra blank line after every row.
with open(os.path.join(output_path, "losses.csv"), 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerow(['training_loss','validation_loss'])
    # zip() pairs the two histories epoch by epoch (stops at the shorter one).
    writer.writerows(zip(training_loss, validation_loss))