In [None]:
%load_ext autoreload
%autoreload 2
import pickle 
import torch
import numpy as np
import pandas as pd
from model import *
from datagen.genset import GenSet
from training_helpers import *
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score

# Global runtime configuration for the notebook.
#
# Worker processes are started with "spawn" instead of the platform default.
# NOTE(review): presumably required because the DataLoader workers create
# tensors on the default (possibly CUDA) device, which is not fork-safe — confirm.
#
# force=True makes this cell idempotent: without it, set_start_method raises
# RuntimeError ("context has already been set") when the cell is re-run in the
# same kernel.
torch.multiprocessing.set_start_method('spawn', force=True)

# Prefer the GPU when one is visible, otherwise fall back to the CPU; every new
# tensor/module is then created on that device as float32 by default.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)
torch.set_default_dtype(torch.float32)

In [None]:
# Training data: a GenSet wrapped in a DataLoader with automatic batching
# disabled (batch_size=None), so each item yielded by the dataset is used as-is.
train = GenSet(128*3, randseed=0)
traini = DataLoader(train, batch_size=None, num_workers=18)

# Validation data: a pickled 4-tuple. The first two entries are tensors and are
# moved to the active device once here; the last two are kept unchanged.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open("./valid/valid_data.pkl", "rb") as valid_file:
    valid = pickle.load(valid_file)
valid = (valid[0].to(device), valid[1].to(device), valid[2], valid[3])

In [None]:
# Number of optimizer steps between two validation passes.
ITERATIONS_PER_EPOCH = 40
# Set to a checkpoint base name (e.g. "3class") to resume from
# ./models/<name>.pkl (constructor params) and ./models/<name>.pt (weights);
# leave as None to train a freshly initialised model.
MODEL_NAME = None

# Re-create the dataset and loader so re-running this cell starts from new
# objects, then seed torch's RNG before the model is built.
train = GenSet(128*3, randseed=0)
traini = DataLoader(train, batch_size=None, num_workers=24)
torch.random.manual_seed(10)

if MODEL_NAME:
    # NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
    with open(f"./models/{MODEL_NAME}.pkl", "rb") as params_file:
        params = pickle.load(params_file)
    model = WCNFourierModel(**params).to(device)
    # map_location lets a checkpoint saved on one device (e.g. CUDA) be restored
    # on whatever device is active now (e.g. CPU-only machine).
    model.load_state_dict(torch.load(f"./models/{MODEL_NAME}.pt", map_location=device))
    model.train()
else:
    params = {"samples": 700, "out": 3, "wavelet": "bior2.2"}
    model = WCNFourierModel(**params).to(device)

optim = torch.optim.AdamW(model.parameters(), lr=0.000075, weight_decay=10**-5.5, amsgrad=True)
# mode="max": the scheduler is stepped with the validation macro-F1, which
# should increase; the LR is halved after 15 validations without improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, factor=0.5, patience=15, threshold=0.0001, 
                        threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, mode="max")
lossfn = torch.nn.CrossEntropyLoss().to(device)

# History consumed by training_plot: one entry per validation pass.
epoch_training_losses = []
epoch_validations = []

# Main training loop.
#
# NOTE: `epoch` counts optimizer steps (batches), not passes over the data.
# Every ITERATIONS_PER_EPOCH steps the model is validated, the mean training
# loss since the previous validation is recorded, the LR scheduler is stepped
# with the validation macro-F1, and good checkpoints are saved.
epoch = 0
# Training losses accumulated since the last validation pass. This must live
# outside the loop body: resetting it on every step would make the recorded
# "mean" just the most recent batch loss.
epoch_training_loss = []
for batch, label in traini:
    optim.zero_grad()
    out = model(batch)
    loss = lossfn(out, label)
    loss.backward()
    optim.step()

    epoch_training_loss.append(loss.item())

    if epoch % ITERATIONS_PER_EPOCH == 0:
        model.eval()
        data, label, types, names = valid
        # No gradients are needed for validation; skipping graph construction
        # saves memory. Use the configured device rather than a hard-coded
        # .cuda() so the CPU fallback keeps working.
        with torch.no_grad():
            out = model(data.to(device))
            loss = lossfn(out, label.to(device))

        # argmax over dim=1 turns per-class scores/targets into class indices.
        true_y = torch.argmax(label, dim=1).flatten().cpu().numpy()
        pred_y = torch.argmax(out, dim=1).flatten().cpu().numpy()

        epoch_training_losses.append(np.mean(epoch_training_loss))
        epoch_training_loss = []  # start accumulating for the next window
        epoch_validations.append({"true_y": true_y, "pred_y": pred_y, "true_type_y": types, "names": names, "loss": loss.cpu().item()})
        f1 = f1_score(true_y, pred_y, average="macro")
        # Checkpoint models that clear the F1 threshold, named after their score.
        if f1 > 0.916:
            torch.save(model.state_dict(), f"./models/{f1}.pt")
            with open(f"./models/{f1}.pkl", "wb") as params_file:
                pickle.dump(params, params_file)
        scheduler.step(f1)
        # Parenthesised on purpose: `epoch % 10 * N` parses as `(epoch % 10) * N`,
        # which is always 0 here. Plot once every 10 validation passes.
        if epoch % (10 * ITERATIONS_PER_EPOCH) == 0:
            training_plot(epoch_training_losses, epoch_validations, xrange=750)

        # Restore training mode; otherwise every step after the first
        # validation would run with the model left in eval mode.
        model.train()

    epoch += 1

