In [None]:
import numpy as np
import torch
import os
import pandas as pd
from torch import Tensor, nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

from thesisproject.models import UNet
from thesisproject.train import training_loop
from thesisproject.data import ImageTKRDataset, SliceLoader

In [None]:
class EncodeTrain(nn.Module):
    """Linear classification head on top of a UNet's encoder.

    The forward pass runs ``unet(x, encode=True)`` and maps the resulting
    features through one linear layer to ``n_outputs`` logits. The UNet is
    registered as a submodule, so its parameters are trained together with
    the head.

    Args:
        unet: module exposing an ``encoding_size`` attribute and accepting an
            ``encode=True`` keyword in its forward pass.
        n_outputs: number of output logits (default 2).

    NOTE(review): assumes the encoder output's last dimension equals
    ``unet.encoding_size`` (e.g. shape (batch, encoding_size)) — confirm in UNet.
    """

    def __init__(self, unet, n_outputs=2):
        super().__init__()
        self.unet = unet
        self.encoding_size = unet.encoding_size
        self.fc = nn.Linear(self.encoding_size, n_outputs)

    def forward(self, x):
        # Invoke the module rather than `.forward` so registered hooks run.
        features = self.unet(x, encode=True)
        return self.fc(features)

In [None]:
class SquarePad:
    """Zero-pad a tensor so every dimension matches its largest one.

    For each dimension, the total padding ``max_size - size`` is split as
    ``floor(total / 2)`` before the data and ``ceil(total / 2)`` after it.
    For odd amounts, the extra zero therefore goes at the end.

    Works for tensors of any rank (e.g. a 3-D volume of shape (D, H, W)).
    Every dimension is padded, so a tensor with a leading channel axis would
    have that axis padded too.
    """

    def __call__(self, image: Tensor) -> Tensor:
        if image.dim() == 0:
            return image  # nothing to square
        target = max(image.shape)
        padding = []
        # F.pad takes (before, after) pairs starting from the LAST dimension,
        # so walk the dimensions in reverse while keeping each pair's order.
        for size in reversed(image.shape):
            total = target - size
            before = total // 2
            padding.extend((before, total - before))
        return F.pad(image, tuple(padding), "constant", 0)
    
def collate(data):
    """Identity collate function for a DataLoader.

    Returns the list of samples exactly as the DataLoader assembled it, with
    no stacking or conversion, so each sample keeps its own structure.
    """
    batch = data
    return batch

In [None]:
unet_path = "model_saves/toy_unet.pt"

# Class names used for the full (non-toy) dataset, kept for reference:
# label_keys = ["Lateral femoral cart.",
#               "Lateral meniscus",
#               "Lateral tibial cart.",
#               "Medial femoral cartilage",
#               "Medial meniscus",
#               "Medial tibial cart.",
#               "Patellar cart.",
#               "Tibia"]
# unet = UNet(1, 9, 384, class_names=label_keys)
label_keys = ["Sphere"]

# NOTE(review): UNet's positional arguments are defined in thesisproject.models;
# presumably (in_channels, n_classes, image_size) — confirm there.
unet = UNet(1, 2, 100, class_names=label_keys)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
unet.to(device)

# map_location lets weights saved on a GPU be loaded on a CPU-only machine.
unet_state_dict = torch.load(unet_path, map_location=device)
unet.load_state_dict(unet_state_dict)

net = EncodeTrain(unet)
net.to(device)
print("Loaded encoder")

In [None]:
SUBJECT_COLUMNS = ["filename", "subject_id_and_knee", "is_right", "TKR", "visit"]


def _build_subjects_df(image_dir):
    """Build the subject table for every file in `image_dir`.

    Each filename is split on "-":
    - `subject_id_and_knee` is the first two parts plus "L".
    - `visit` is the second part without its first character, parsed as an int.
    - `TKR` is True when the third part starts with "S".

    Every row is marked as a left knee (`is_right=False`). Files are processed
    in sorted order so the resulting CSV is reproducible across runs.

    NOTE(review): assumes every file in `image_dir` follows that naming scheme;
    stray files (e.g. hidden files) will raise IndexError/ValueError.
    """
    rows = []
    for filename in sorted(os.listdir(image_dir)):
        parts = filename.split("-")
        rows.append({
            "filename": filename,
            "subject_id_and_knee": "-".join([parts[0], parts[1], "L"]),
            "is_right": False,
            "TKR": parts[2][0] == "S",
            "visit": int(parts[1][1:]),
        })
    # Explicit columns keep the CSV schema even when the directory is empty.
    return pd.DataFrame(rows, columns=SUBJECT_COLUMNS)


for split_name in ("train", "val"):
    subjects_df = _build_subjects_df(f"../toy-data/{split_name}/images")
    subjects_df.to_csv(f"../toy_{split_name}_subjects.csv")

In [None]:
# Both splits share identical DataLoader settings.
loader_kwargs = dict(batch_size=16, num_workers=8, pin_memory=True)

# Datasets: images are square-padded before slicing.
train = ImageTKRDataset(
    "../toy-data/train/images/",
    "../toy_train_subjects.csv",
    predict_mode=False,
    image_transform=SquarePad(),
)
val = ImageTKRDataset(
    "../toy-data/val/images/",
    "../toy_val_subjects.csv",
    predict_mode=False,
    image_transform=SquarePad(),
)

# Number of slices drawn per epoch from each split.
train_slicer = SliceLoader(train, slices_per_epoch=1000)
val_slicer = SliceLoader(val, slices_per_epoch=500)

train_loader = DataLoader(train_slicer, **loader_kwargs)
val_loader = DataLoader(val_slicer, **loader_kwargs)

In [None]:
# Cross-entropy on the logits produced by `net`. Adam updates every parameter
# of `net`, which includes the wrapped UNet as a submodule.
learning_rate = 5e-5
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=learning_rate)

In [None]:
num_epochs = 30

layout = {
    "Loss": {"loss": ["Multiline", ["loss/train", "loss/validation"]]}
}

writer = SummaryWriter()
writer.add_custom_scalars(layout)

# mode='max': the monitored value is the validation dice (higher is better).
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=2)

# Load checkpoint if continuing
checkpoint_path = os.path.join("model_saves", "toy_encoder_checkpoint.pt")
model_path = os.path.join("model_saves", "toy_encoder.pt")
start_epoch = 0

cont = False
if cont:
    # map_location so a checkpoint saved on a GPU can be resumed on CPU.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    print(f"Continuing training from epoch {start_epoch}")


def _unpack_batch(data):
    """Return (inputs, labels) from a loader batch, moved to `device`.

    NOTE(review): assumes data[0] is the image tensor and data[1] holds
    integer/bool class labels (CrossEntropyLoss needs long class indices).
    """
    inputs = data[0].to(device)
    labels = data[1].to(device).long()
    return inputs, labels


def _safe_div(num, den):
    """Divide, returning 0.0 instead of raising when `den` is zero."""
    return num / den if den else 0.0


def _update_confusion(counts, outputs, labels):
    """Add this batch's binary confusion counts (class 1 = positive) to `counts`."""
    preds = outputs.argmax(dim=1)
    pred_pos = preds == 1
    true_pos = labels == 1
    counts["tp"] += int((pred_pos & true_pos).sum())
    counts["fp"] += int((pred_pos & ~true_pos).sum())
    counts["tn"] += int((~pred_pos & ~true_pos).sum())
    counts["fn"] += int((~pred_pos & true_pos).sum())


def _binary_metrics(counts):
    """Epoch-level classification metrics from accumulated confusion counts."""
    tp, fp, tn, fn = counts["tp"], counts["fp"], counts["tn"], counts["fn"]
    return {
        "accuracy": _safe_div(tp + tn, tp + fp + tn + fn),
        "precision": _safe_div(tp, tp + fp),
        "recall": _safe_div(tp, tp + fn),
        "specificity": _safe_div(tn, tn + fp),
        # Dice on binary predictions is the F1 score of the positive class.
        "dice": _safe_div(2 * tp, 2 * tp + fp + fn),
    }


def _train_one_epoch(pbar):
    """Run one optimisation pass over `train_loader`; return mean per-sample loss."""
    net.train()
    total_loss, total_samples = 0.0, 0
    for data in train_loader:
        inputs, labels = _unpack_batch(data)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss already averages over the batch; weight by batch
        # size so the epoch value is a true per-sample mean.
        batch_samples = inputs.shape[0]
        total_loss += loss.item() * batch_samples
        total_samples += batch_samples
        pbar.update(1)
    return _safe_div(total_loss, total_samples)


def _validate(pbar):
    """Evaluate on `val_loader`; return (mean per-sample loss, metrics dict)."""
    net.eval()
    total_loss, total_samples = 0.0, 0
    counts = {"tp": 0, "fp": 0, "tn": 0, "fn": 0}
    with torch.no_grad():
        for data in val_loader:
            inputs, labels = _unpack_batch(data)
            outputs = net(inputs)
            loss = criterion(outputs, labels)

            batch_samples = inputs.shape[0]
            total_loss += loss.item() * batch_samples
            total_samples += batch_samples
            _update_confusion(counts, outputs, labels)
            pbar.update(1)
    return _safe_div(total_loss, total_samples), _binary_metrics(counts)


# Training
for epoch in range(start_epoch, num_epochs):
    # One progress bar per epoch, closed by the context manager.
    with tqdm(total=len(train_loader) + len(val_loader), position=0, leave=True) as pbar:
        pbar.set_description(f"Epoch {epoch} training")
        train_loss = _train_one_epoch(pbar)
        pbar.set_description(f"Epoch {epoch} validation")
        val_loss, val_metrics = _validate(pbar)

    # Write to tensorboard
    for metric_name, metric_value in val_metrics.items():
        writer.add_scalar(f"validation_metrics/{metric_name}", metric_value, epoch)
    writer.add_scalar("loss/validation", val_loss, epoch)
    writer.add_scalar("loss/train", train_loss, epoch)
    writer.add_scalar("learning rate", optimizer.param_groups[0]['lr'], epoch)

    # Step the scheduler BEFORE checkpointing so the saved scheduler state
    # already contains this epoch's update; otherwise it is lost on resume.
    scheduler.step(val_metrics["dice"])

    # Save model checkpoints
    torch.save({
        "epoch": epoch,
        "model_state_dict": net.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict()
    }, checkpoint_path)

# Save final model
torch.save(net.state_dict(), model_path)
writer.close()

print('Finished Training')