In [None]:
%load_ext autoreload
%autoreload 2
from autoencoder import *
from varnet import VARnet
import numpy as np
import torch
from torch.utils.data import DataLoader
import pandas as pd
import plotly.graph_objects as go
from IPython.display import display, clear_output
from dataloader import Trainer

# torch.multiprocessing.set_start_method('spawn')
# Pick the accelerator once; every later cell reads `device`.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# New tensors default to float32 and are allocated on the chosen device.
torch.set_default_dtype(torch.float32)
torch.set_default_device(device)

In [None]:
# Load the labelled training table and drop every row whose class is "rscvn".
# NOTE(review): absolute, user-specific path — consider a configurable DATA_DIR.
training_data = pd.read_parquet('/home/mpaz/neovar/secondary/data/training_data.parquet')
not_rscvn = training_data["type"].ne("rscvn")
training_data = training_data.loc[not_rscvn]
# Quick sanity output: remaining row count and the available columns.
print(training_data.shape[0])
print(training_data.columns)

In [None]:
# Stratified 90/10 split: each class is sampled on its own so that both
# partitions contain every class in roughly the original proportion.
# A fixed seed keeps the split identical across "Restart & Run All".
SPLIT_SEED = 0
train_parts = []
valid_parts = []
for type_ in training_data['type'].unique():
    oftype = training_data[training_data['type'] == type_]
    train_part = oftype.sample(frac=0.9, random_state=SPLIT_SEED)
    train_parts.append(train_part)
    # Rows of this class that were not drawn for training go to validation.
    # (Relies on the frame's index labels being unique.)
    valid_parts.append(oftype.drop(train_part.index))

traind = pd.concat(train_parts)
validd = pd.concat(valid_parts)

print(training_data.value_counts('type'))
print(traind.value_counts('type'))
print(validd.value_counts('type'))

# Both wrappers are built here. The training one used to be commented out,
# which made the checks below (and the pickling cell) fail with NameError on
# a fresh kernel.
# NOTE(review): the positional arguments (64, 4096, False, True) are defined by
# dataloader.Trainer — confirm their meaning there.
trainer = Trainer(traind, 64, 4096, False, True, multithread=True, bin_overlap_frac=0.25)
print(len(trainer))
valid = Trainer(validd, 64, 4096, False, True, multithread=True, bin_overlap_frac=0.25)
print(len(valid))

# Count non-finite entries in the training tensor; both should print 0.
print(torch.sum(torch.isnan(trainer.tensor)))
print(torch.sum(torch.isinf(trainer.tensor)))

In [None]:
import pickle
# Serialize both dataset wrappers to the working directory, training set first.
for pkl_name, dataset in (("trainer.pkl", trainer), ("valid.pkl", valid)):
    with open(pkl_name, "wb") as handle:
        pickle.dump(dataset, handle)

In [None]:
import pickle
def _read_pickle(filename):
    """Deserialize and return the single object stored in ``filename``."""
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # use this on pickles produced by this notebook.
    with open(filename, "rb") as handle:
        return pickle.load(handle)


# Restore the dataset wrappers written by the pickling cell.
trainer = _read_pickle("trainer.pkl")
valid = _read_pickle("valid.pkl")

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, f1_score
# Fine-tune the Morphologic classifier from a saved checkpoint and report the
# training loss, validation loss, macro-F1 and confusion matrix every epoch.
bins = 64
# model = VARnet(128, len(trainer.types), learnsamples=False, wavelet="haar", infeatures=2)
model = Morphologic(bins, len(trainer.types), 2)
# map_location keeps the checkpoint loadable when `device` is "cpu".
model.load_state_dict(torch.load("model/morpho.pth", map_location=device))
print(sum(p.numel() for p in model.parameters()))
epochs = 1500
loss = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00015, weight_decay=10**-6)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.0015, weight_decay=10**-6)

losses = []        # mean training loss per epoch
validlosses = []   # mean validation loss per epoch
f1s = []           # macro-F1 on the validation set per epoch
iterlim = 50       # maximum number of training batches per epoch
# Explicit class ids keep the confusion matrix square and aligned with
# trainer.types even when some class is never seen or predicted in an epoch.
class_ids = np.arange(len(trainer.types))
for e in range(epochs):
    # ---- training ----
    model.train()
    epochlosses = []
    for step, (data, target) in enumerate(trainer):
        # Caps the epoch at exactly `iterlim` batches (previously iterlim + 1).
        if step >= iterlim:
            break
        data = data.to(device)
        target = target.to(device)
        # Fail fast on NaN *or* Inf; the old check only looked for Inf.
        if not torch.isfinite(data).all():
            print(data)
            raise ValueError("NaN or Inf in input")
        if not torch.isfinite(target).all():
            print(target)
            raise ValueError("NaN or Inf in target")
        out = model(data)
        if not torch.isfinite(out).all():
            print(out)
            raise ValueError("NaN or Inf in output")

        l = loss(out, target)
        l.backward()
        optimizer.step()
        optimizer.zero_grad()
        epochlosses.append(l.item())

    # ---- validation ----
    model.eval()
    epoch_validlosses = []
    validations = []
    with torch.no_grad():  # no gradients needed while scoring
        for data, target in valid:
            target = target.to(device)
            out = model(data.to(device))
            # Validation loss gets its own list so it no longer leaks into the
            # training-loss average.
            epoch_validlosses.append(loss(out, target).item())
            # Targets are reduced with argmax over dim 1, i.e. they are
            # treated as one row of class scores per sample. reshape(-1)
            # keeps a 1-D result even for a batch of one, which squeeze()
            # would collapse to 0-d and break torch.cat below.
            pred = torch.argmax(out, dim=1).reshape(-1)
            true = torch.argmax(target, dim=1).reshape(-1)
            validations.append((true.cpu(), pred.cpu()))

    losses.append(np.mean(epochlosses))
    validlosses.append(np.mean(epoch_validlosses))
    true_y = torch.cat([pair[0] for pair in validations]).numpy()
    pred_y = torch.cat([pair[1] for pair in validations]).numpy()
    f1 = f1_score(true_y, pred_y, average="macro")
    f1s.append(f1)
    cm = confusion_matrix(true_y, pred_y, labels=class_ids)

    # ---- live report ----
    fig = go.Figure()
    fig.add_trace(go.Scatter(y=losses, line=dict(color="blue"), mode="lines", name="train"))
    fig.add_trace(go.Scatter(y=validlosses, line=dict(color="orange"), mode="lines", name="valid"))
    fig.add_trace(go.Scatter(y=f1s, line=dict(color="red"), mode="lines", name="f1"))

    clear_output(wait=True)
    print("N Valids: ", len(true_y))
    print(f"Epoch {e} loss: {losses[-1]}")
    print(f"Epoch {e} valid loss: {validlosses[-1]}")
    print(f"Epoch {e} f1: {f1}")
    display(fig)
    print(cm)
    if e % 20 == 0:
        # Label the axes from trainer.types (the same list that sizes the
        # model's output) rather than a hand-maintained copy of the names.
        disp = ConfusionMatrixDisplay(cm, display_labels=trainer.types)
        disp.plot()

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, f1_score
# Train a VARnet classifier from scratch and report the training loss,
# validation loss, macro-F1 and confusion matrix every epoch.
bins = 64
modelb = VARnet(512, len(trainer.types), learnsamples=False, wavelet="db8", infeatures=2)
# model = Morphologic(bins, len(trainer.types), 2)
# Report the size of the model being trained here (was counting `model`).
print(sum(p.numel() for p in modelb.parameters()))
epochs = 300
loss = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.AdamW(modelb.parameters(), lr=0.0015)

losses = []        # mean training loss per epoch
validlosses = []   # mean validation loss per epoch
f1s = []           # macro-F1 on the validation set per epoch
iterlim = 50       # maximum number of training batches per epoch
# Explicit class ids keep the confusion matrix square and aligned with
# trainer.types even when some class is never seen or predicted in an epoch.
class_ids = np.arange(len(trainer.types))
for e in range(epochs):
    # ---- training ----
    modelb.train()
    epochlosses = []
    for step, (data, target) in enumerate(trainer):
        # Caps the epoch at exactly `iterlim` batches (previously iterlim + 1).
        if step >= iterlim:
            break
        # Every entry along dim 1 is duplicated before the batch reaches the
        # network. NOTE(review): presumably this upsamples the binned axis —
        # confirm against Trainer's output layout.
        data = torch.repeat_interleave(data, 2, dim=1).to(device)
        target = target.to(device)
        # Fail fast on NaN *or* Inf; the old check only looked for Inf.
        if not torch.isfinite(data).all():
            print(data)
            raise ValueError("NaN or Inf in input")
        if not torch.isfinite(target).all():
            print(target)
            raise ValueError("NaN or Inf in target")
        out = modelb(data)
        if not torch.isfinite(out).all():
            print(out)
            raise ValueError("NaN or Inf in output")

        l = loss(out, target)
        l.backward()
        optimizer.step()
        optimizer.zero_grad()
        epochlosses.append(l.item())

    # ---- validation ----
    modelb.eval()
    epoch_validlosses = []
    validations = []
    with torch.no_grad():  # no gradients needed while scoring
        for data, target in valid:
            # Same input transform as in training; previously validation fed
            # the un-repeated tensor to a network trained on repeated input.
            data = torch.repeat_interleave(data, 2, dim=1).to(device)
            target = target.to(device)
            out = modelb(data)
            # Validation loss gets its own list so it no longer leaks into the
            # training-loss average.
            epoch_validlosses.append(loss(out, target).item())
            # reshape(-1) keeps a 1-D result even for a batch of one, which
            # squeeze() would collapse to 0-d and break torch.cat below.
            pred = torch.argmax(out, dim=1).reshape(-1)
            true = torch.argmax(target, dim=1).reshape(-1)
            validations.append((true.cpu(), pred.cpu()))

    losses.append(np.mean(epochlosses))
    validlosses.append(np.mean(epoch_validlosses))
    true_y = torch.cat([pair[0] for pair in validations]).numpy()
    pred_y = torch.cat([pair[1] for pair in validations]).numpy()
    f1 = f1_score(true_y, pred_y, average="macro")
    f1s.append(f1)
    cm = confusion_matrix(true_y, pred_y, labels=class_ids)

    # ---- live report ----
    fig = go.Figure()
    fig.add_trace(go.Scatter(y=losses, line=dict(color="blue"), mode="lines", name="train"))
    fig.add_trace(go.Scatter(y=validlosses, line=dict(color="orange"), mode="lines", name="valid"))
    fig.add_trace(go.Scatter(y=f1s, line=dict(color="red"), mode="lines", name="f1"))

    clear_output(wait=True)
    print("N Valids: ", len(true_y))
    print(f"Epoch {e} loss: {losses[-1]}")
    print(f"Epoch {e} valid loss: {validlosses[-1]}")
    print(f"Epoch {e} f1: {f1}")
    display(fig)
    print(cm)
    if e % 20 == 0:
        disp = ConfusionMatrixDisplay(cm, display_labels=trainer.types)
        disp.plot()

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, f1_score, classification_report
import matplotlib.pyplot as plt
# Score the saved Morphologic checkpoint on the validation set, save the
# confusion matrix figure and print the per-class precision/recall/F1 report.
# NOTE(review): absolute, user-specific checkpoint path — consider a
# configurable model directory.
model = Morphologic(64, len(valid.types), 2)
model.load_state_dict(torch.load("/home/mpaz/neovar/secondary/subclassifier/model/morpho.pth", map_location=device))
model.eval()

preds = []
trues = []
with torch.no_grad():  # inference only
    for data, label in valid:
        # Use `device` (not a hard-coded .cuda()) so this cell also runs on a
        # CPU-only machine, matching map_location above.
        out = model(data.to(device))
        # reshape(-1) keeps a 1-D result even for a batch of one, which
        # squeeze() would collapse to 0-d and break torch.cat below.
        preds.append(torch.argmax(out, dim=1).reshape(-1).cpu())
        trues.append(torch.argmax(label, dim=1).reshape(-1).cpu())

pred_y = torch.cat(preds).numpy()
true_y = torch.cat(trues).numpy()

# Explicit class ids keep the matrix and the report aligned with valid.types
# even if some class is absent from both the labels and the predictions.
class_ids = np.arange(len(valid.types))
cm = confusion_matrix(true_y, pred_y, labels=class_ids)

disp = ConfusionMatrixDisplay(cm, display_labels=valid.types)
disp.plot(cmap='Blues')
disp.im_.colorbar.remove()
plt.savefig("confusion_matrix.png", dpi=500)
rep = classification_report(true_y, pred_y, labels=class_ids, target_names=valid.types)
print(rep)
# Same report as a nested dict keyed by class name, for later plotting.
stats = classification_report(true_y, pred_y, labels=class_ids, target_names=valid.types, output_dict=True)

In [None]:
# Redraw the confusion matrix, then build a polar bar chart with one sector
# per class: precision and recall as paired bars, F1 as a red step line.
cm = confusion_matrix(true_y, pred_y)
disp = ConfusionMatrixDisplay(cm, display_labels=valid.types)
labels = valid.types
disp.plot(cmap='Blues')
disp.im_.colorbar.remove()
plt.show()

stats = classification_report(true_y, pred_y, target_names=valid.types, output_dict=True)
categories = valid.types

# Angular layout, derived from the class count instead of a hard-coded 7.
# For 7 classes these evaluate to exactly the previous constants
# (np.pi/14 and np.pi/7).
n_classes = len(categories)
offset = 0.025                      # small rotation applied to every sector
half_bar = np.pi / (2 * n_classes)  # distance from sector centre to a bar centre
bar_width = 2*np.pi/n_classes - np.pi/n_classes - offset
r_max = 0.75                        # outer radius shown on the chart

sector_start = np.linspace(0, 2*np.pi, n_classes, endpoint=False)
theta1 = sector_start - half_bar + offset  # precision bar centres
theta2 = sector_start + half_bar + offset  # recall bar centres

ax = plt.subplot(polar=True)
ax.bar(theta1, [stats[type_]['precision'] for type_ in categories], width=bar_width, align='center', alpha=1, edgecolor='k', color="mediumblue", linewidth=1, label="Precision")
ax.bar(theta2, [stats[type_]['recall'] for type_ in categories], width=bar_width, align='center', alpha=1, edgecolor='k', color="cornflowerblue", linewidth=1, label="Recall")
# One tick per class, placed midway between its two bars.
ax.set_xticks(theta1 + half_bar, minor=False)
ax.set_xticklabels(categories)
# make the x tick marks invisible but keep the labels
ax.tick_params(axis='x', which='both', length=0)
# draw a dashed spoke on the boundary between neighbouring sectors
for i in range(n_classes):
    boundary = theta1[i] - half_bar
    ax.plot([boundary, boundary], [0, r_max], color="grey", linestyle="--", label=None)

# F1 as a function of angle: map every sampled angle to the sector it falls in
# and look up that class's F1 score (wrapping the last partial sector to 0).
r = (np.linspace(0, 2*np.pi, 10000, endpoint=False) + 0.1) % (2*np.pi)
f1_by_class = np.array([stats[type_]["f1-score"] for type_ in categories])
sector_index = np.floor(n_classes*(r - offset + np.pi/n_classes) / (2*np.pi)).astype(int) % n_classes
y = f1_by_class[sector_index]
ax.plot(r, y, color="red", linewidth=1.25)

# legend (precision / recall) just outside the upper-right corner of the axes
ax.legend(loc=(0.9, 0.9))

# NOTE(review): values above r_max are clipped by the radial limit below.
ax.set_yticks([0.5, r_max])
ax.set_ylim(0, r_max)
ax.yaxis.grid(True, linestyle='--', color='black')
ax.xaxis.grid(False)
plt.savefig("precision.png", dpi=500)
plt.show()