In [1]:
import time
import torch
import gensim
import numpy as np
import snntorch as snn

from snntorch import functional as SF, utils as snnutils
import snntorch.functional.loss as snnloss

from torch.utils.data import DataLoader
from torch import nn, optim
from torch.nn import Parameter

from torchtext.data import utils
from collections import OrderedDict

from sklearn.metrics import precision_score, recall_score, f1_score

In [2]:
from dataset import AbstractDataset
from util import load_model_and_opt, save_model, batch_predict, load_from_cnn
from snn_util import AbstractHybrid, SpikingBCELoss, forward_pass, train_model

In [4]:
DEVICE = (f'cuda:0' if torch.cuda.is_available() else 'cpu')
HOME = '/home/hice1/khom9/CSE 8803 BMI Final Project'
CNN_VERSION = 3
# Path to the saved word embeddings file
EMBED_KEYS_PATH = f'{HOME}/abstracts200_trained_normalized_{CNN_VERSION}.wordvectors'

In [None]:
print(f'Using device {DEVICE}')
if torch.cuda.is_available():
    print(torch.cuda.get_device_name())

In [5]:
tk = utils.get_tokenizer('spacy') # use spacey tokenizer
wv = gensim.models.KeyedVectors.load(EMBED_KEYS_PATH, mmap='r') # Get the saved word vectors matrix
null_word = '\0' # Will be used to pad all abstracts to the same length (669 words)
d = AbstractDataset(f'{HOME}/CleanedAVdata.csv', 'Abstract', 'IPCR Classifications', tk, wv.key_to_index,
                    null_word=null_word, min_len=30) # PyTorch dataset for abstracts & their labels

  self.labels = torch.tensor(mlb.fit_transform(classes)).to_sparse_csr().type(torch.float)


In [11]:
batch_size = 48
beta = 1.0
lr = 1e-3
T = 45

save_path = f'{HOME}/hybrid.pth'
cnn_path = f'{HOME}/cnn_model-{CNN_VERSION}.pth'
cnn_act_path = f'{HOME}/cnn_model-{CNN_VERSION}-max-activations.pkl'

loader = DataLoader(d, batch_size=batch_size, shuffle=True)
model = AbstractSNN_2(T, EMBED_KEYS_PATH, null_word=null_word, beta=beta).to(DEVICE)

# Create a positive weight, such that we punish the model heavily for guessing 0 all the time.
num_pos = d.labels.sum(axis=0, keepdim=True).to_dense()
pos_weight = (d.labels.shape[0] - num_pos) / num_pos
# loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight.squeeze().to(DEVICE))
# loss_fn = snn.functional.ce_rate_loss()
# loss_fn = nn.BCELoss(weight=pos_weight.to(DEVICE))
loss_fn = SpikingBCELoss(pos_weight=pos_weight.squeeze().to(DEVICE), is_logits=True)

In [13]:
model = load_from_cnn(model, cnn_path, cnn_act_path)
optimizer = optim.Adam(model.parameters(), lr=lr)
model, optimizer = load_model_and_opt(model, optimizer, save_path)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
# optimizer.param_groups[0]['lr'] = 5e-5

In [14]:
epochs = 250
train_model(model, optimizer, loss_fn, T=T, epochs=epochs, batch_size=batch_size, save_freq=25, save_path=save_path, 
            scheduler=scheduler)

In [15]:
# save_path = f'.pth'
# save_model(save_path, model, optimizer, epochs)
# print(f'Saved to {save_path}')

In [16]:
pred = (batch_predict(model, d.abst_data, T))
pred_spk = (torch.sigmoid(pred.mean(dim=0)) > 0.5).type(torch.float)
true = d.labels.to_dense()

In [17]:
start = time.time()
total_loss = []
loss_fn_cpu = loss_fn.to('cpu')
for i in range(len(d)):
    total_loss.append(loss_fn_cpu(pred_spk[i].unsqueeze(0), true[i]).item())
    
print(f'Total avg loss: {np.mean(total_loss)}')

Total avg loss: 1.0346547530415238


In [18]:
# Print a sample prediction and true label
i = 8721 #8721 #18722
txt, label = d[i]
# txt = txt.unsqueeze(0)
loss_fn = loss_fn.to(DEVICE)
label = label.unsqueeze(0)

print(loss_fn(forward_pass(model, T, txt.to(DEVICE)).detach(), label.to(DEVICE)).item())
print(torch.cat([forward_pass(model, T, txt.to(DEVICE)).mean(dim=0).detach(), label.to(DEVICE)]).T)


0.06382208317518234
tensor([[-6.8264e+00,  0.0000e+00],
        [-4.8714e+01,  0.0000e+00],
        [-5.1413e+01,  0.0000e+00],
        [-5.1380e+01,  0.0000e+00],
        [-5.0150e+01,  0.0000e+00],
        [-5.3921e+01,  0.0000e+00],
        [-4.5920e+01,  0.0000e+00],
        [-8.4808e+00,  0.0000e+00],
        [-7.7649e+00,  0.0000e+00],
        [-1.9163e+01,  0.0000e+00],
        [-5.6534e+00,  0.0000e+00],
        [-3.6959e+01,  0.0000e+00],
        [-4.6629e+01,  0.0000e+00],
        [-1.6779e+01,  0.0000e+00],
        [-5.2093e+01,  0.0000e+00],
        [-4.9298e+01,  0.0000e+00],
        [-1.6542e+01,  0.0000e+00],
        [-4.6774e+01,  0.0000e+00],
        [-4.5695e+01,  0.0000e+00],
        [-3.1242e+01,  0.0000e+00],
        [-9.8940e+00,  0.0000e+00],
        [-4.8375e+01,  0.0000e+00],
        [-2.7642e+01,  0.0000e+00],
        [-4.7604e+01,  0.0000e+00],
        [-5.0324e+01,  0.0000e+00],
        [-4.7742e+01,  0.0000e+00],
        [-5.4373e+01,  0.0000e+00],
        

In [19]:
print(precision_score(true, pred_spk, average=None))
print(f'Total precision: {precision_score(true, pred_spk, average="weighted")}')

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


[0.41014493 0.         0.44       0.71428571 0.375      1.
 0.55555556 0.50834598 0.49651325 0.4137931  0.53977273 0.69230769
 0.2        0.66666667 0.83333333 0.         0.44144144 0.58823529
 0.33333333 0.58536585 0.51832461 0.         0.71428571 0.33333333
 0.56521739 0.4        0.53333333 0.56521739 0.33333333 0.5
 0.77800424 0.52747253 0.45326504 0.34020619 0.47452229 0.47986577
 0.45539906 0.57894737 0.         0.5        0.21428571 0.33333333
 0.30769231 1.         0.75       0.33333333 0.5        1.
 0.5        1.         0.66666667 0.5        0.61904762 0.47826087
 0.38709677 0.33333333 0.42201835 0.43258427 0.85714286 0.43478261
 0.57627119 0.50413223 0.4375     0.425      0.6        0.33935743
 0.75       0.6043956  0.525      0.41304348 0.         0.47619048
 0.41666667 0.42857143 0.62014013 0.47744361 0.63636364 0.52631579
 0.66647635 0.7346983  0.48343648 0.53342028 0.40345369 0.54029851
 0.47142857 0.52755906 0.44026549 0.41246291 0.42307692 0.66152614
 0.47474747]
Total

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [20]:
print(recall_score(true, pred_spk, average=None))
print(f'Total recall: {recall_score(true, pred_spk, average="weighted")}')

[0.99647887 0.         0.91666667 1.         1.         1.
 1.         0.98820059 0.99164345 0.92307692 1.         0.85714286
 0.33333333 1.         1.         0.         0.94230769 0.47619048
 1.         0.6        0.96116505 0.         0.83333333 1.
 0.54166667 1.         0.72727273 1.         1.         1.
 0.78724633 0.97959184 0.98951782 1.         0.95819936 0.97945205
 0.98979592 1.         0.         0.8        0.5        1.
 0.8        1.         1.         0.88888889 1.         1.
 1.         0.5        1.         1.         1.         1.
 0.92307692 1.         0.93877551 1.         1.         1.
 1.         0.98387097 0.875      1.         0.85714286 0.95480226
 1.         1.         0.95454545 1.         0.         1.
 0.90909091 1.         0.84342302 1.         1.         1.
 0.73171751 0.79321637 0.98809524 0.92612511 1.         1.
 1.         1.         1.         1.         1.         0.94832041
 0.98947368]
Total recall: 0.8333566056930841


In [21]:
print(f'Total F1 score: {f1_score(true, pred_spk, average="weighted")}')

Total F1 score: 0.7258079382167745


# 