In [1]:
import torch
import torch.optim as optim
from torchinfo import summary

from KSI_models import KSI, ModifiedKSI, CNN
from KSI_utils import load_KSI_data, train_model, test_model

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# DEVICE is a torch.device object (not a string); it is passed to the data
# loader, the models, and the train/test helpers in later cells.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Run configuration shared by every model trained in this notebook.
# Embedding size handed to the CNN, KSI and ModifiedKSI constructors.
n_embedding = 100
# Batch size handed to load_KSI_data and used to size the torchinfo summaries.
batch_size = 32
# Number of epochs handed to train_model and to the OneCycleLR schedule.
n_epochs = 25
# When True, each trained model is written to disk with torch.save.
save = True
# Forwarded to train_model; when True the returned profiler table is printed.
profile = False

In [None]:
# Load the dataset stored under 'data/original/' and request all three splits.
dir = 'data/original/'
loaders, wikivec, word_to_ix = load_KSI_data(
    dir=dir,
    batch_size=batch_size,
    train=True,
    val=True,
    test=True,
    device=DEVICE,
)
train_dataloader, val_dataloader, test_dataloader = (
    loaders['train'],
    loaders['val'],
    loaders['test'],
)

# Sizes used to build the models and their summaries in later cells.
n_words = len(word_to_ix)
n_wiki, n_vocab = wikivec.shape

In [7]:
# The disabled snippet below averages the second dimension of the first element
# of every training batch (the note length) and rounds it to an integer.
# NOTE(review): it relies on `np`, which the import cell of this notebook does
# not import; add `import numpy as np` before re-enabling it.
# note_lengths = []
# for data in train_dataloader:
#     n, _, _ = data
#     note_lengths.append(n.shape[1])
# avg_note_size = np.round(np.array(note_lengths).mean()).astype(int)

# Hard-coded note length used only as the sequence dimension of the torchinfo
# summaries below.
# NOTE(review): presumably the value the snippet above produced for
# 'data/original/' — confirm if the dataset changes.
avg_note_size = 2455

In [9]:
# Baseline CNN (no `ksi` module supplied), moved to the target device.
base_model = CNN(n_words, n_wiki, n_embedding).to(DEVICE)

# Summarise the model on two dummy inputs: an integer tensor of shape
# (batch_size, avg_note_size) and a float tensor of shape (batch_size, n_vocab).
base_summary = summary(
    base_model,
    [(batch_size, avg_note_size), (batch_size, n_vocab)],
    dtypes=[torch.int, torch.float],
)

base_summary

Layer (type:depth-idx)                   Output Shape              Param #
CNN                                      --                        --
├─Embedding: 1-1                         [32, 2455, 100]           4,796,200
├─Dropout: 1-2                           [32, 2455, 100]           --
├─Conv1d: 1-3                            [32, 100, 2453]           30,100
├─Conv1d: 1-4                            [32, 100, 2452]           40,100
├─Conv1d: 1-5                            [32, 100, 2451]           50,100
├─Linear: 1-6                            [32, 344]                 103,544
Total params: 5,020,044
Trainable params: 5,020,044
Non-trainable params: 0
Total mult-adds (G): 9.60
Input size (MB): 1.87
Forward/backward pass size (MB): 251.25
Params size (MB): 20.08
Estimated Total Size (MB): 273.20

In [10]:
# Adam over all baseline parameters; the one-cycle schedule is sized for
# n_epochs * len(train_dataloader) steps and peaks at max_lr.
optimizer = optim.Adam(base_model.parameters())
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.01,
    epochs=n_epochs,
    steps_per_epoch=len(train_dataloader),
)

# Train the baseline; the returned object is only used when `profile` is set.
prof_base = train_model(
    base_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    n_epochs=n_epochs,
    profile=profile,
    log_path='./log/CNN',
    device=DEVICE,
)

Epoch: 001, Train Recall@10: 0.5913, Val Recall@10: 0.5895, Train Micro F1: 0.4216, Val Micro F1: 0.4183, Train Macro F1: 0.0302, Val Macro F1: 0.0360, Train Micro AUC: 0.9480, Val Micro AUC: 0.9364, Train Macro AUC: 0.6611, Val Macro AUC: 0.6603
Epoch: 002, Train Recall@10: 0.6999, Val Recall@10: 0.6923, Train Micro F1: 0.5550, Val Micro F1: 0.5478, Train Macro F1: 0.0756, Val Macro F1: 0.0895, Train Micro AUC: 0.9595, Val Micro AUC: 0.9492, Train Macro AUC: 0.7896, Val Macro AUC: 0.7591
Epoch: 003, Train Recall@10: 0.7343, Val Recall@10: 0.7240, Train Micro F1: 0.5750, Val Micro F1: 0.5656, Train Macro F1: 0.1008, Val Macro F1: 0.1178, Train Micro AUC: 0.9657, Val Micro AUC: 0.9552, Train Macro AUC: 0.8509, Val Macro AUC: 0.7840
Epoch: 004, Train Recall@10: 0.7554, Val Recall@10: 0.7375, Train Micro F1: 0.6057, Val Micro F1: 0.5822, Train Macro F1: 0.1481, Val Macro F1: 0.1534, Train Micro AUC: 0.9734, Val Micro AUC: 0.9603, Train Macro AUC: 0.8971, Val Macro AUC: 0.7915
Epoch: 005, 

In [11]:
# Optionally persist the trained baseline model under the current data directory `dir`.
if save:
    # NOTE(review): passing the module object to torch.save pickles the whole
    # model, so loading it later needs the KSI_models classes importable —
    # confirm this is intended rather than saving a state_dict.
    torch.save(base_model, f'{dir}CNN_model.pt')
# Optionally print the profiler table from the object returned by train_model,
# grouped by the top 5 stack frames and sorted by self CUDA time.
if profile:
    print(prof_base.key_averages(group_by_stack_n=5).table(sort_by='self_cuda_time_total'))

In [12]:
# Evaluate the trained baseline CNN on the test split. test_model returns five
# aggregate metrics plus a sixth value kept here as `label_aucs_base`
# (requested with by_label=False for the baseline).
(tt_recall_at_k,
 tt_micro_f1,
 tt_macro_f1,
 tt_micro_auc,
 tt_macro_auc,
 label_aucs_base) = test_model(base_model,
                               test_dataloader,
                               wikivec,
                               by_label=False,
                               device=DEVICE)
print(f'Test Recall@10: {tt_recall_at_k:.4f}, Test Micro F1: {tt_micro_f1:.4f}, Test Macro F1: {tt_macro_f1:.4f}' +
      f', Test Micro AUC: {tt_micro_auc:.4f}, Test Macro AUC: {tt_macro_auc:.4f}')

# Drop the notebook's reference to the baseline model before the next model is built.
del base_model
# DEVICE is a torch.device, and comparing a torch.device with the string 'cuda'
# is always False, so the original check never released the CUDA cache.
# Compare the device type instead.
if DEVICE.type == 'cuda':
    torch.cuda.empty_cache()

Test Recall@10: 0.7453, Test Micro F1: 0.6051, Test Macro F1: 0.1598, Test Micro AUC: 0.9624, Test Macro AUC: 0.7581


In [13]:
# Build the KSI module and attach it to a fresh CNN; both live on DEVICE.
ksi = KSI(n_embedding, n_vocab)
ksi.to(DEVICE)
model = CNN(n_words, n_wiki, n_embedding, ksi=ksi).to(DEVICE)

# Summarise on three dummy inputs: integer (batch_size, avg_note_size), float
# (batch_size, n_vocab), and a float tensor shaped like wikivec (n_wiki, n_vocab).
ksi_summary = summary(
    model,
    [
        (batch_size, avg_note_size),
        (batch_size, n_vocab),
        (n_wiki, n_vocab),
    ],
    dtypes=[torch.int, torch.float, torch.float],
)

ksi_summary

Layer (type:depth-idx)                   Output Shape              Param #
CNN                                      --                        --
├─KSI: 1-1                               --                        --
│    └─Linear: 2-1                       --                        (recursive)
│    └─Linear: 2-2                       --                        (recursive)
│    └─Linear: 2-3                       --                        (recursive)
├─Embedding: 1-2                         [32, 2455, 100]           4,796,200
├─Dropout: 1-3                           [32, 2455, 100]           --
├─Conv1d: 1-4                            [32, 100, 2453]           30,100
├─Conv1d: 1-5                            [32, 100, 2452]           40,100
├─Conv1d: 1-6                            [32, 100, 2451]           50,100
├─Linear: 1-7                            [32, 344]                 103,544
├─KSI: 1-1                               --                        --
│    └─Linear: 2-4                

In [14]:
# Fresh Adam optimizer and one-cycle schedule for the CNN+KSI model; the
# schedule is sized for n_epochs * len(train_dataloader) steps.
optimizer = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.01,
    epochs=n_epochs,
    steps_per_epoch=len(train_dataloader),
)

# Unlike the baseline run, wikivec is passed to train_model here.
prof_ksi = train_model(
    model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    wikivec=wikivec,
    optimizer=optimizer,
    scheduler=scheduler,
    n_epochs=n_epochs,
    profile=profile,
    log_path='./log/CNN_KSI',
    device=DEVICE,
)

Epoch: 001, Train Recall@10: 0.6860, Val Recall@10: 0.6777, Train Micro F1: 0.4977, Val Micro F1: 0.4913, Train Macro F1: 0.0719, Val Macro F1: 0.0844, Train Micro AUC: 0.9688, Val Micro AUC: 0.9607, Train Macro AUC: 0.8290, Val Macro AUC: 0.8214
Epoch: 002, Train Recall@10: 0.7488, Val Recall@10: 0.7337, Train Micro F1: 0.5747, Val Micro F1: 0.5618, Train Macro F1: 0.1345, Val Macro F1: 0.1537, Train Micro AUC: 0.9737, Val Micro AUC: 0.9652, Train Macro AUC: 0.8951, Val Macro AUC: 0.8605
Epoch: 003, Train Recall@10: 0.7740, Val Recall@10: 0.7558, Train Micro F1: 0.6019, Val Micro F1: 0.5818, Train Macro F1: 0.1905, Val Macro F1: 0.1917, Train Micro AUC: 0.9775, Val Micro AUC: 0.9684, Train Macro AUC: 0.9123, Val Macro AUC: 0.8674
Epoch: 004, Train Recall@10: 0.7981, Val Recall@10: 0.7692, Train Micro F1: 0.6313, Val Micro F1: 0.5986, Train Macro F1: 0.2742, Val Macro F1: 0.2200, Train Micro AUC: 0.9826, Val Micro AUC: 0.9714, Train Macro AUC: 0.9327, Val Macro AUC: 0.8596
Epoch: 005, 

In [15]:
# Optionally persist the trained CNN+KSI model under the current data directory `dir`.
if save:
    # NOTE(review): passing the module object to torch.save pickles the whole
    # model, so loading it later needs the KSI_models classes importable —
    # confirm this is intended rather than saving a state_dict.
    torch.save(model, f'{dir}CNN_KSI_model.pt')
# Optionally print the profiler table from the object returned by train_model,
# grouped by the top 5 stack frames and sorted by self CUDA time.
if profile:
    print(prof_ksi.key_averages(group_by_stack_n=5).table(sort_by='self_cuda_time_total'))

In [16]:
# Evaluate the trained CNN+KSI model on the test split. test_model returns five
# aggregate metrics plus a sixth value kept here as `label_aucs_ksi`
# (requested with by_label=True).
(tt_recall_at_k,
 tt_micro_f1,
 tt_macro_f1,
 tt_micro_auc,
 tt_macro_auc,
 label_aucs_ksi) = test_model(model,
                              test_dataloader,
                              wikivec,
                              by_label=True,
                              device=DEVICE)
print(f'Test Recall@10: {tt_recall_at_k:.4f}, Test Micro F1: {tt_micro_f1:.4f}, Test Macro F1: {tt_macro_f1:.4f}' +
      f', Test Micro AUC: {tt_micro_auc:.4f}, Test Macro AUC: {tt_macro_auc:.4f}')

# Drop the notebook's reference to the model before the next model is built.
del model
# DEVICE is a torch.device, and comparing a torch.device with the string 'cuda'
# is always False, so the original check never released the CUDA cache.
# Compare the device type instead.
if DEVICE.type == 'cuda':
    torch.cuda.empty_cache()

Test Recall@10: 0.7589, Test Micro F1: 0.5992, Test Macro F1: 0.2207, Test Micro AUC: 0.9691, Test Macro AUC: 0.8238


In [None]:
# Modified KSI run: switch to the dataset built from frequency vectors instead
# of binary vectors. This rebinds `dir`, the loaders, `wikivec` and the size
# variables used by all following cells.
dir = 'data/original_freqs/'
loaders, wikivec, word_to_ix = load_KSI_data(
    dir=dir,
    batch_size=batch_size,
    train=True,
    val=True,
    test=True,
    device=DEVICE,
)
train_dataloader, val_dataloader, test_dataloader = (
    loaders['train'],
    loaders['val'],
    loaders['test'],
)

# Sizes used to build the ModifiedKSI model and its summary.
n_words = len(word_to_ix)
n_wiki, n_vocab = wikivec.shape

In [None]:
# Build the ModifiedKSI module and attach it to a fresh CNN; both live on DEVICE.
mod_ksi = ModifiedKSI(n_embedding, n_vocab)
mod_ksi.to(DEVICE)
mod_model = CNN(n_words, n_wiki, n_embedding, ksi=mod_ksi).to(DEVICE)

# Summarise on three dummy inputs: integer (batch_size, avg_note_size), float
# (batch_size, n_vocab), and a float tensor shaped like wikivec (n_wiki, n_vocab).
mod_summary = summary(
    mod_model,
    [
        (batch_size, avg_note_size),
        (batch_size, n_vocab),
        (n_wiki, n_vocab),
    ],
    dtypes=[torch.int, torch.float, torch.float],
)

mod_summary

Layer (type:depth-idx)                   Output Shape              Param #
CNN                                      --                        --
├─ModifiedKSI: 1-1                       --                        --
│    └─Linear: 2-1                       --                        (recursive)
│    └─Linear: 2-2                       --                        (recursive)
│    └─Linear: 2-3                       --                        (recursive)
│    └─Linear: 2-4                       --                        (recursive)
├─Embedding: 1-2                         [32, 2455, 100]           4,796,200
├─Dropout: 1-3                           [32, 2455, 100]           --
├─Conv1d: 1-4                            [32, 100, 2453]           30,100
├─Conv1d: 1-5                            [32, 100, 2452]           40,100
├─Conv1d: 1-6                            [32, 100, 2451]           50,100
├─Linear: 1-7                            [32, 344]                 103,544
├─ModifiedKSI: 1-1       

In [None]:
# Fresh Adam optimizer and one-cycle schedule for the CNN+ModifiedKSI model;
# the schedule is sized for n_epochs * len(train_dataloader) steps.
optimizer = optim.Adam(mod_model.parameters())
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.01,
    epochs=n_epochs,
    steps_per_epoch=len(train_dataloader),
)

# wikivec here is the one loaded from the frequency-vector dataset.
prof_mod_ksi = train_model(
    mod_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    wikivec=wikivec,
    optimizer=optimizer,
    scheduler=scheduler,
    n_epochs=n_epochs,
    profile=profile,
    log_path='./log/CNN_ModifiedKSI',
    device=DEVICE,
)

Epoch: 001, Train Recall@10: 0.7014, Val Recall@10: 0.6963, Train Micro F1: 0.5071, Val Micro F1: 0.4998, Train Macro F1: 0.0824, Val Macro F1: 0.0988, Train Micro AUC: 0.9703, Val Micro AUC: 0.9634, Train Macro AUC: 0.8481, Val Macro AUC: 0.8548
Epoch: 002, Train Recall@10: 0.7532, Val Recall@10: 0.7482, Train Micro F1: 0.5805, Val Micro F1: 0.5721, Train Macro F1: 0.1324, Val Macro F1: 0.1587, Train Micro AUC: 0.9744, Val Micro AUC: 0.9682, Train Macro AUC: 0.8968, Val Macro AUC: 0.8909
Epoch: 003, Train Recall@10: 0.7652, Val Recall@10: 0.7596, Train Micro F1: 0.5838, Val Micro F1: 0.5720, Train Macro F1: 0.1571, Val Macro F1: 0.1732, Train Micro AUC: 0.9760, Val Micro AUC: 0.9690, Train Macro AUC: 0.9051, Val Macro AUC: 0.8782
Epoch: 004, Train Recall@10: 0.7964, Val Recall@10: 0.7782, Train Micro F1: 0.6242, Val Micro F1: 0.6033, Train Macro F1: 0.2452, Val Macro F1: 0.2187, Train Micro AUC: 0.9826, Val Micro AUC: 0.9740, Train Macro AUC: 0.9285, Val Macro AUC: 0.8738
Epoch: 005, 

In [None]:
# Optionally persist the trained CNN+ModifiedKSI model under the current data directory `dir`.
if save:
    # NOTE(review): passing the module object to torch.save pickles the whole
    # model, so loading it later needs the KSI_models classes importable —
    # confirm this is intended rather than saving a state_dict.
    torch.save(mod_model, f'{dir}CNN_ModifiedKSI_model.pt')
# Optionally print the profiler table from the object returned by train_model,
# grouped by the top 5 stack frames and sorted by self CUDA time.
if profile:
    print(prof_mod_ksi.key_averages(group_by_stack_n=5).table(sort_by='self_cuda_time_total'))

In [None]:
# Evaluate the trained CNN+ModifiedKSI model on the test split. test_model
# returns five aggregate metrics plus a sixth value kept here as
# `label_aucs_mod` (requested with by_label=True).
(tt_recall_at_k,
 tt_micro_f1,
 tt_macro_f1,
 tt_micro_auc,
 tt_macro_auc,
 label_aucs_mod) = test_model(mod_model,
                              test_dataloader,
                              wikivec,
                              by_label=True,
                              device=DEVICE)
print(f'Test Recall@10: {tt_recall_at_k:.4f}, Test Micro F1: {tt_micro_f1:.4f}, Test Macro F1: {tt_macro_f1:.4f}' +
      f', Test Micro AUC: {tt_micro_auc:.4f}, Test Macro AUC: {tt_macro_auc:.4f}')

# Drop the notebook's reference to the model so its memory can be reclaimed.
del mod_model
# DEVICE is a torch.device, and comparing a torch.device with the string 'cuda'
# is always False, so the original check never released the CUDA cache.
# Compare the device type instead.
if DEVICE.type == 'cuda':
    torch.cuda.empty_cache()

Test Recall@10: 0.7571, Test Micro F1: 0.6018, Test Macro F1: 0.2452, Test Micro AUC: 0.9703, Test Macro AUC: 0.8288
