In [None]:
# Automatically reload imported modules before each cell executes, so edits to
# local modules (dataset, snn_util, util) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import copy
import os
import time

import gensim
import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score
from snntorch import spikeplot as splt
from torch import optim
from torchtext.data import utils

In [None]:
from dataset import AbstractDataset
from snn_util import AbstractSNN_2, SpikingBCELoss, forward_pass, train_model
from util import batch_predict, load_from_cnn

In [None]:
# Run on the first GPU when one is visible, otherwise fall back to the CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Project root. Override with the BMI_PROJECT_HOME environment variable to run
# outside the original machine without editing this cell.
HOME = os.environ.get('BMI_PROJECT_HOME', '/home/hice1/khom9/CSE 8803 BMI Final Project')
# Version tag selecting which saved CNN artifacts (embeddings/checkpoints) to load.
CNN_VERSION = 6
# Path to the saved word embeddings file
EMBED_KEYS_PATH = f'{HOME}/wordvectors/abstracts200_trained_normalized_{CNN_VERSION}.wordvectors'
# Seed the torch and numpy RNGs so RNG-dependent steps are repeatable across runs
# (NOTE(review): some CUDA kernels may still be nondeterministic).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Report which device was selected and, when CUDA is present, the GPU model.
has_cuda = torch.cuda.is_available()
print(f'Using device {DEVICE}')
if has_cuda:
    print(torch.cuda.get_device_name())

In [None]:
# spaCy tokenizer used to split abstracts into words.
tk = utils.get_tokenizer('spacy')
# Saved word vectors, memory-mapped read-only instead of loaded fully into RAM.
wv = gensim.models.KeyedVectors.load(EMBED_KEYS_PATH, mmap='r')
# Padding token used to bring abstracts to a common length
# (NOTE(review): the padded length itself is decided inside AbstractDataset).
null_word = '\0'
abstracts_csv = f'{HOME}/CleanedAVdata.csv'
# PyTorch dataset pairing each abstract with its IPCR classification labels.
d = AbstractDataset(
    abstracts_csv,
    'Abstract',
    'IPCR Classifications',
    tk,
    wv.key_to_index,
    null_word=null_word,
    min_len=30,  # presumably drops abstracts shorter than 30 tokens -- TODO confirm
    verbose=False,
)

In [None]:
# Training hyperparameters.
batch_size = 48
beta = 1.0  # passed to AbstractSNN_2 as `beta` (presumably LIF membrane decay) -- TODO confirm
lr = 1e-4
T = 45  # passed to the model and training/forward helpers; presumably the number of time steps

# Output/input artifact paths.
snn_version = 8
save_path = f'{HOME}/models/snn_model-{snn_version}.pth'
cnn_path = f'{HOME}/models/cnn_model-{CNN_VERSION}.pth'
cnn_act_path = f'{HOME}/models/cnn_model-{CNN_VERSION}-max-activations.pkl'

model = AbstractSNN_2(T, EMBED_KEYS_PATH, null_word=null_word, beta=beta).to(DEVICE)

# Weight each label's positives by (#negatives / #positives), so that always
# guessing 0 is punished heavily. The positive count is clamped to at least 1.
# A label with no positive examples then gets a finite weight instead of inf.
# An inf weight multiplied by a 0 target yields NaN and would poison the loss.
num_pos = d.labels.sum(axis=0, keepdim=True).to_dense()
pos_weight = (d.labels.shape[0] - num_pos) / num_pos.clamp(min=1)
loss_fn = SpikingBCELoss(pos_weight=pos_weight.squeeze().to(DEVICE))

In [None]:
# Initialise the SNN from the saved CNN checkpoint and its (presumably recorded
# max) activations, then build the optimizer over the resulting parameters.
model = load_from_cnn(model, cnn_path, cnn_act_path, num_lif=5, wv=wv)
optimizer = optim.NAdam(params=model.parameters(), lr=lr)
# To resume from a checkpoint, run here:
#   model, optimizer = load_model_and_opt(model, optimizer, save_path)
# NOTE(review): load_model_and_opt is not imported in this notebook.
# Multiply the learning rate by 0.1 every 150 scheduler steps (presumably epochs).
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=150, gamma=0.1)

In [None]:
# Train the SNN; save_path / save_freq presumably control periodic checkpointing.
epochs = 250
train_model(
    model, optimizer, d, loss_fn,
    T=T,
    epochs=epochs,
    batch_size=batch_size,
    scheduler=scheduler,
    device=DEVICE,
    save_path=save_path,
    save_freq=25,
)

In [None]:
# Predict on every abstract in `d`, the same dataset passed to train_model.
# These are therefore presumably training-set metrics.
pred = batch_predict(model, d.abst_data, T=T, fw_pass_fn=forward_pass, device=DEVICE)
# Average over dim 0, presumably the time axis, then threshold the firing rate at 0.5.
firing_rate = pred.mean(dim=0)
pred_spk = (firing_rate > 0.5).float()
true = d.labels.to_dense()

In [None]:
# Average the per-sample loss of the thresholded predictions, computed on CPU.
start = time.time()
# Deep-copy before moving: nn.Module.to() moves a module in place, so calling
# loss_fn.to('cpu') directly would also leave the shared `loss_fn` on the CPU.
# A later (re-)run of the training cell would then break.
# (Assumes SpikingBCELoss is an nn.Module; the copy is harmless otherwise.)
loss_fn_cpu = copy.deepcopy(loss_fn).to('cpu')
total_loss = [
    loss_fn_cpu(pred_spk[idx].unsqueeze(0), true[idx]).item()
    for idx in range(len(d))
]

print(f'Total avg loss: {np.mean(total_loss)}')
print(f'Elapsed: {time.time() - start:.1f}s')

In [None]:
# Print a sample prediction and true label
sample_idx = 18722  # arbitrary example; must be < len(d)
txt, label = d[sample_idx]
loss_fn = loss_fn.to(DEVICE)
label = label.unsqueeze(0)

# Run the network once, so the printed loss and the printed rates describe the
# same output. Two separate passes could differ if the forward pass is
# stochastic. no_grad skips autograd bookkeeping since this is inference only.
with torch.no_grad():
    sample_out = forward_pass(model, T, txt.to(DEVICE)).detach()
label_dev = label.to(DEVICE)

print(loss_fn(sample_out, label_dev).item())
# Mean output over dim 0 stacked with the true label, transposed for side-by-side viewing.
print(torch.cat([sample_out.mean(dim=0), label_dev]).T)


In [None]:
# Per-label precision, then the support-weighted average across labels.
per_label_precision = precision_score(true, pred_spk, average=None)
print(per_label_precision)
weighted_precision = precision_score(true, pred_spk, average="weighted")
print(f'Total precision: {weighted_precision}')

In [None]:
# Per-label recall, then the support-weighted average across labels.
per_label_recall = recall_score(true, pred_spk, average=None)
print(per_label_recall)
weighted_recall = recall_score(true, pred_spk, average="weighted")
print(f'Total recall: {weighted_recall}')

In [None]:
# Support-weighted F1 across all labels.
weighted_f1 = f1_score(true, pred_spk, average="weighted")
print(f'Total F1 score: {weighted_f1}')

In [None]:
# Plot membrane-potential traces (with spikes passed via spk=) for the sample `txt`.
# splt and plt are imported in the top import cell.
with torch.no_grad():  # visualization only; no autograd graph needed
    spk, mem = forward_pass(model, T, txt.to(DEVICE), return_mem=True)
spk, mem = spk.squeeze(), mem.squeeze()
splt.traces(mem, spk=spk, dim=(10, 9))  # presumably a 10 x 9 grid of subplots
fig = plt.gcf()
fig.set_size_inches(8, 6)

# 