In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import gensim
import numpy as np

from torch.utils.data import DataLoader
from torch import optim

from torchtext.data import utils

from sklearn.metrics import precision_score, recall_score, f1_score

In [3]:
from dataset import AbstractDataset
from snn_util import AbstractSNN_1, SpikingBCELoss, train_model, forward_pass
from util import load_model_and_opt, save_model, batch_predict, load_from_cnn

In [4]:
DEVICE = (f'cuda:0' if torch.cuda.is_available() else 'cpu')
HOME = '/home/hice1/khom9/CSE 8803 BMI Final Project'
CNN_VERSION = 2
SNN_VERSION = 6
# Path to the saved word embeddings file
EMBED_KEYS_PATH = f'{HOME}/wordvectors/abstracts200_trained_normalized_{CNN_VERSION}.wordvectors'

In [5]:
print(f'Using device {DEVICE}')
if torch.cuda.is_available():
    print(torch.cuda.get_device_name())

Using device cuda:0
NVIDIA L40S


In [6]:
tk = utils.get_tokenizer('spacy') # use spacey tokenizer
wv = gensim.models.KeyedVectors.load(EMBED_KEYS_PATH, mmap='r') # Get the saved word vectors matrix
null_word = '\0' # Will be used to pad all abstracts to the same length (669 words)
d = AbstractDataset(f'{HOME}/CleanedAVdata.csv', 'Abstract', 'IPCR Classifications', tk, wv.key_to_index,
                    null_word=null_word, min_len=30, verbose=True) # PyTorch dataset for abstracts & their labels

100%|██████████| 23250/23250 [00:15<00:00, 1502.13it/s]


In [7]:
batch_size = 48
beta = 1.0
lr = 1e-5
T = 45

save_path = f'{HOME}/models/snn_model-{SNN_VERSION}.pth'
cnn_path = f'{HOME}/models/cnn_model-{CNN_VERSION}.pth'
cnn_act_path = f'{HOME}/models/cnn_model-{CNN_VERSION}-max-activations.pkl'

loader = DataLoader(d, batch_size=batch_size, shuffle=True)
model = AbstractSNN_1(T, EMBED_KEYS_PATH, null_word=null_word, beta=beta).to(DEVICE)

# Create a positive weight, such that we punish the model heavily for guessing 0 all the time.
num_pos = d.labels.sum(axis=0, keepdim=True)
pos_weight = (d.labels.shape[0] - num_pos) / num_pos
loss_fn = SpikingBCELoss(pos_weight=pos_weight.squeeze().to(DEVICE))

In [8]:
# Uncomment this line if training from scratch:
# model = load_from_cnn(model, cnn_path, cnn_act_path)

optimizer = optim.NAdam(model.parameters(), lr=lr)
# Uncomment the following lines if resuming training:
# model, optimizer = load_model_and_opt(model, optimizer, save_path)
# for g in optimizer.param_groups:
#     g['lr'] = lr
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=1.)

In [9]:
# epochs = 350
# train_model(model, optimizer, d, loss_fn, T=T, epochs=epochs, batch_size=batch_size, save_freq=10, save_path=save_path,
#             scheduler=scheduler, device=DEVICE)

In [10]:
# save_model(save_path, model, optimizer, epochs)
# print(f'Saved to {save_path}')

In [11]:
pred = batch_predict(model, d.abst_data, T=T, fw_pass_fn=forward_pass, device=DEVICE)
pred_spk = ((pred.mean(dim=0)) > 0.5).type(torch.float)
true = d.labels

In [12]:
total_loss = []
loss_fn_cpu = loss_fn.to('cpu')
for i in range(len(d)):
    total_loss.append(loss_fn_cpu(pred_spk[i].unsqueeze(0), true[i]).item())
    
print(f'Total avg loss: {np.mean(total_loss)}')

Total avg loss: 2.063510634458193


In [13]:
# Print a sample prediction and true label
i =18722
txt, label = d[i]
# txt = txt.unsqueeze(0)
loss_fn = loss_fn.to(DEVICE)
label = label.unsqueeze(0)

print(loss_fn(forward_pass(model, T, txt.to(DEVICE)).detach(), label.to(DEVICE)).item())
print(torch.cat([torch.tensor([list(range(91))]).to(DEVICE),                  # Index
                 forward_pass(model, T, txt.to(DEVICE)).mean(dim=0).detach(), # Prediction 
                 label.to(DEVICE)]).T)                                        # True label


1.5958837270736694
tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+00, 0.0000e+00, 0.0000e+00],
        [2.0000e+00, 0.0000e+00, 0.0000e+00],
        [3.0000e+00, 0.0000e+00, 0.0000e+00],
        [4.0000e+00, 0.0000e+00, 0.0000e+00],
        [5.0000e+00, 0.0000e+00, 0.0000e+00],
        [6.0000e+00, 0.0000e+00, 0.0000e+00],
        [7.0000e+00, 0.0000e+00, 0.0000e+00],
        [8.0000e+00, 0.0000e+00, 0.0000e+00],
        [9.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0000e+01, 0.0000e+00, 0.0000e+00],
        [1.1000e+01, 0.0000e+00, 0.0000e+00],
        [1.2000e+01, 0.0000e+00, 0.0000e+00],
        [1.3000e+01, 0.0000e+00, 0.0000e+00],
        [1.4000e+01, 0.0000e+00, 0.0000e+00],
        [1.5000e+01, 0.0000e+00, 0.0000e+00],
        [1.6000e+01, 0.0000e+00, 0.0000e+00],
        [1.7000e+01, 0.0000e+00, 0.0000e+00],
        [1.8000e+01, 0.0000e+00, 0.0000e+00],
        [1.9000e+01, 0.0000e+00, 0.0000e+00],
        [2.0000e+01, 0.0000e+00, 0.0000e+00],
        [2.1000

In [14]:
print(precision_score(true, pred_spk, average=None))
print(f'Total precision: {precision_score(true, pred_spk, average="weighted")}')

[0.63926941 0.16666667 0.48       0.83333333 1.         0.5
 0.66666667 0.76566125 0.68250951 0.52       0.78512397 0.44680851
 0.5        0.63829787 0.55555556 0.14285714 0.43697479 0.53846154
 1.         0.5        0.70138889 0.5        0.85714286 0.5
 0.48979592 0.5        0.57894737 0.72222222 0.5        0.33333333
 0.88837438 0.67123288 0.69288577 0.57391304 0.65938865 0.64976959
 0.74242424 0.30555556 0.4        0.71428571 0.375      1.
 0.26315789 1.         0.75       0.40909091 0.5        1.
 0.5        0.28571429 0.28571429 1.         0.56521739 0.64179104
 0.62903226 0.5        0.5        0.50657895 0.46153846 0.47619048
 0.72340426 0.74846626 0.61538462 0.62962963 0.4375     0.59459459
 0.75       0.64705882 0.44       0.59259259 1.         0.37037037
 0.35483871 0.5        0.79847278 0.69577465 0.46226415 0.74074074
 0.82828729 0.88735949 0.74204545 0.77896696 0.61835749 0.74789916
 0.55       0.58878505 0.65131579 0.71428571 0.64705882 0.83638211
 0.71653543]
Total precis

In [15]:
print(recall_score(true, pred_spk, average=None))
print(f'Total recall: {recall_score(true, pred_spk, average="weighted")}')

[0.98591549 1.         1.         1.         1.         1.
 1.         0.97345133 1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         0.98058252 1.         1.         1.
 1.         1.         1.         1.         1.         1.
 0.74385415 1.         0.96645702 1.         0.97106109 0.96575342
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         1.
 1.         1.         1.         1.         1.         0.97727273
 1.         1.         0.95918367 1.         1.         1.
 1.         0.98387097 1.         1.         1.         0.99435028
 1.         1.         1.         0.84210526 1.         1.
 1.         1.         0.89872945 0.97244094 1.         1.
 0.78311743 0.90479532 0.97172619 0.94763657 0.99610895 0.98342541
 1.         0.94029851 0.99497487 0.98920863 1.         0.94516222
 0.95789474]
Total recall: 0.8590492908579521


In [16]:
print(f'Total F1 score: {f1_score(true, pred_spk, average="weighted")}')

Total F1 score: 0.8345685388269651


# 