In [None]:
%load_ext autoreload
%autoreload 2
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchmetrics import ROC, AUROC
from persim import plot_diagrams, PersistenceImager
import numpy as np
import matplotlib.pyplot as plt
import pickle
import glob
import os
import pandas as pd
import sys
sys.path.append("../src")
from datasets import *
from models import *

In [None]:
# Inputs handed to the dataset constructors: the preprocessed data directory
# and the metadata CSV.
data_dir = "../preprocessed"
metadata_path = "../Data/UCSF-PDGM-metadata_v2.csv"

# One entry per filtration type. Each entry holds the PersistenceImager
# keyword arguments together with the hyperparameters that go with them.
# Keeping both in a table (instead of hiding the inactive one inside a string
# literal) makes switching a one-line change and keeps both variants visible.
filtration_configs = {
    "alpha": dict(
        imager=dict(pixel_size=2, birth_range=(0, 24), pers_range=(0, 24),
                    kernel_params={'sigma': 2}, weight_params={'n': 1.5}),
        layers=np.array([]),
        lr=1e-3,
        l2_lam=1e-2,
    ),
    "cubical": dict(
        imager=dict(pixel_size=0.1, birth_range=(-1, 1), pers_range=(0, 2),
                    kernel_params={'sigma': 0.1}, weight_params={'n': 1.5}),
        # Alternative noted by the original author:
        # layers = np.arange(18)  # Exclude the random convolutions
        layers=np.array([]),
        lr=2e-4,
        l2_lam=5e-3,
    ),
}

# Select the active configuration here ("alpha" or "cubical").
filtration_type = "alpha"

# Only the selected imager is instantiated; the names below are the ones the
# following cells read.
_active_config = filtration_configs[filtration_type]
imgr = PersistenceImager(**_active_config["imager"])
layers = _active_config["layers"]
lr = _active_config["lr"]          # Adam learning rate
l2_lam = _active_config["l2_lam"]  # weight of the explicit L2 penalty in the training loop


# The two datasets are built from identical arguments and differ only in the
# `is_training` flag (True first, then False).
dataset_train, dataset_test = (
    PImageTumorDataset(
        data_dir,
        metadata_path,
        imgr,
        filtration_type,
        is_training=training_flag,
        layers=layers,
    )
    for training_flag in (True, False)
)

# The model constructor receives the training dataset plus two positional
# settings (3 and 4).
# Alternative architecture left by the author: PImgShallowCNNBinary(dataset_train, 1)
model = PImgCNNBinary(dataset_train, 3, 4)

print(model)

# Adam over all model parameters with the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

# Print the element count of every parameter tensor, then the grand total.
# Tensor.numel() returns a plain int (np.prod of an empty shape would give the
# float 1.0) and avoids rebinding the loop variable from the parameter to its
# own size.
param_counts = [param.numel() for param in model.parameters()]
for count in param_counts:
    print(count)
total_params = sum(param_counts)
print("Total parameters:", total_params)

In [None]:
def _plot_training_curves(losses, accs_train, accs_test, aurocs_train, aurocs_test):
    """Draw the loss, accuracy and AUROC histories as three stacked panels.

    Each argument is a per-epoch history list. The accuracy and AUROC legends
    show the most recent training / validation value. The figure is closed
    after being shown so that repeated reports do not accumulate open figures.
    """
    fig, (ax_loss, ax_acc, ax_auroc) = plt.subplots(3, 1, figsize=(12, 12))

    ax_loss.plot(losses)
    ax_loss.set_title("Losses")
    ax_loss.set_xlabel("Epoch")

    panels = (
        (ax_acc, "Accuracy", accs_train, accs_test),
        (ax_auroc, "AUROC", aurocs_train, aurocs_test),
    )
    for ax, title, train_hist, test_hist in panels:
        ax.plot(train_hist)
        ax.plot(test_hist)
        ax.legend(["Training ({:.3f})".format(train_hist[-1]),
                   "Validation ({:.3f})".format(test_hist[-1])])
        ax.set_title(title)
        ax.set_xlabel("Epoch")

    # Lay out once, after all three panels exist.
    fig.tight_layout()
    plt.show()
    plt.close(fig)


n_epochs = 1000
report_every = 20  # epochs between printed/plotted progress reports
loss_fn = torch.nn.BCEWithLogitsLoss()
roc = ROC(task="binary")  # not used in this cell; kept in case other cells rely on it

# Per-epoch histories.
losses = []
accs_train = []
accs_test = []
aurocs_train = []
aurocs_test = []

# The loader is loop-invariant, so it is built once; with shuffle=True it
# still reshuffles every time it is iterated.
training_loader = DataLoader(dataset_train, batch_size=16, shuffle=True)

for epoch in range(n_epochs):
    model.train(True)
    total_loss = 0
    for inputs, labels in training_loader:
        # Keep index 0 along axis 1 of the 5-D batch.
        # NOTE(review): the meaning of that axis is defined by the dataset class — confirm.
        inputs = inputs[:, 0, :, :, :]

        optimizer.zero_grad()
        outputs = model(inputs)

        # Column 0 of the model output is scored against the labels with
        # BCEWithLogitsLoss.
        loss = loss_fn(outputs[:, 0], labels)
        # Explicit L2 penalty over every tensor in model.parameters(); it is
        # part of the loss value that gets logged below.
        for p in model.parameters():
            loss += l2_lam*torch.sum(p*p)
        loss.backward()

        optimizer.step()
        total_loss += loss.item()
    # Logged loss is the sum over batches (not the mean).
    losses.append(total_loss)

    # Score in evaluation mode so that train-only layer behavior (if the model
    # has any) does not affect the metrics. train(True) is restored at the top
    # of the next epoch.
    # NOTE(review): the metric helpers are defined elsewhere; if they already
    # switch modes this call is a no-op — confirm.
    model.eval()
    acc_train = get_sigmoid_accuracy(model, dataset_train)
    acc_test = get_sigmoid_accuracy(model, dataset_test)
    accs_test.append(acc_test)
    accs_train.append(acc_train)

    auroc_train = get_auroc(model, dataset_train)
    auroc_test = get_auroc(model, dataset_test)
    aurocs_test.append(auroc_test)
    aurocs_train.append(auroc_train)

    # Report periodically, and always on the last epoch so the final state is
    # shown (previously the last report was at the final multiple of 20).
    is_last_epoch = epoch == n_epochs - 1
    if (epoch > 0 and epoch % report_every == 0) or is_last_epoch:
        print("loss {:.3f}".format(total_loss), end=", ")
        print("train: {:.3f}".format(acc_train), end=", ")
        print("test: {:.3f}".format(acc_test))
        _plot_training_curves(losses, accs_train, accs_test, aurocs_train, aurocs_test)