In [None]:
import torch
import torch.optim as optim
from torchinfo import summary

from KSI_models import KSI, ModifiedKSI, CAML
from KSI_utils import load_KSI_data, train_model, test_model

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Model sizes handed to the CAML / KSI constructors.
n_embedding, n_hidden = 100, 100

# Data-loader batch size and number of training epochs.
batch_size, n_epochs = 32, 1

# Run toggles: `save` gates the torch.save calls, `profile` is forwarded to
# train_model and gates printing of the profiler table.
save, profile = False, False

In [None]:
# Load the dataset stored under data/original/ together with its
# train / validation / test dataloaders.
dir = 'data/original/'
loaders, wikivec, word_to_ix = load_KSI_data(
    dir=dir,
    batch_size=batch_size,
    train=True,
    val=True,
    test=True,
    device=DEVICE,
)
train_dataloader, val_dataloader, test_dataloader = (
    loaders[split] for split in ('train', 'val', 'test')
)

# Sizes derived from the loaded data; they parameterise the models built later.
n_words = len(word_to_ix)
n_wiki, n_vocab = wikivec.shape

In [None]:
# Note length used as the sequence dimension of the input shape passed to
# torchinfo's `summary` for each model.
# The hard-coded value was presumably produced once by the snippet below
# (NOTE(review): confirm it still matches the current training split).
# Re-enabling the snippet requires `import numpy as np`, which this notebook
# does not import.
# note_lengths = []
# for data in train_dataloader:
#     n, _, _ = data
#     note_lengths.append(n.shape[1])
# avg_note_size = np.round(np.array(note_lengths).mean()).astype(int)

avg_note_size = 2455

In [None]:
# Baseline: plain CAML without a KSI component, moved to the selected device.
base_model = CAML(n_words, n_wiki, n_embedding, n_hidden).to(DEVICE)

# Summarise the model for one batch of notes (integer token ids) and one
# batch of float vectors of width n_vocab.
base_summary = summary(
    base_model,
    [(batch_size, avg_note_size), (batch_size, n_vocab)],
    dtypes=[torch.int, torch.float],
)

base_summary

In [None]:
# Adam with a one-cycle learning-rate schedule spanning the whole training run.
optimizer = optim.Adam(base_model.parameters())
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.01,
    steps_per_epoch=len(train_dataloader),
    epochs=n_epochs,
)

# Train the baseline CAML model; the return value is kept as the profiler
# handle that is printed later when `profile` is enabled.
base_train_kwargs = dict(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    n_epochs=n_epochs,
    profile=profile,
    log_path='./log/CAML',
    device=DEVICE,
)
prof_base = train_model(base_model, **base_train_kwargs)

In [None]:
# Optionally persist the trained baseline model next to its dataset.
if save:
    base_model_path = f'{dir}CAML_model.pt'
    torch.save(base_model, base_model_path)

# Optionally print the profiler averages, sorted by self CUDA time.
if profile:
    base_profile_table = prof_base.key_averages(group_by_stack_n=5).table(sort_by='self_cuda_time_total')
    print(base_profile_table)

In [None]:
# Evaluate the baseline CAML model on the test dataloader.
(tt_recall_at_k,
 tt_micro_f1,
 tt_macro_f1,
 tt_micro_auc,
 tt_macro_auc,
 label_aucs_base) = test_model(base_model,
                               test_dataloader,
                               wikivec,
                               by_label=False,
                               device=DEVICE)
print(f'Test Recall@10: {tt_recall_at_k:.4f}, Test Micro F1: {tt_micro_f1:.4f}, Test Macro F1: {tt_macro_f1:.4f}' +
      f', Test Micro AUC: {tt_micro_auc:.4f}, Test Macro AUC: {tt_macro_auc:.4f}')

# Drop the notebook's reference to the model before the next one is built.
del base_model
# DEVICE is a torch.device, which never compares equal to the plain string
# 'cuda'; test its `.type` so the CUDA cache is actually released on GPU runs.
if DEVICE.type == 'cuda':
    torch.cuda.empty_cache()

In [None]:
# CAML augmented with the KSI component; both are moved to the selected device.
ksi = KSI(n_embedding, n_vocab)
ksi.to(DEVICE)
model = CAML(n_words, n_wiki, n_embedding, n_hidden, ksi=ksi).to(DEVICE)

# Summarise the model; compared with the baseline there is a third input of
# shape (n_wiki, n_vocab), matching the shape of `wikivec`.
ksi_summary = summary(
    model,
    [(batch_size, avg_note_size), (batch_size, n_vocab), (n_wiki, n_vocab)],
    dtypes=[torch.int, torch.float, torch.float],
)

ksi_summary

In [None]:
# Fresh Adam optimiser and one-cycle schedule for the CAML+KSI model.
optimizer = optim.Adam(model.parameters())
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.01,
    steps_per_epoch=len(train_dataloader),
    epochs=n_epochs,
)

# Train CAML+KSI; unlike the baseline run, `wikivec` is passed to train_model.
ksi_train_kwargs = dict(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    wikivec=wikivec,
    optimizer=optimizer,
    scheduler=scheduler,
    n_epochs=n_epochs,
    profile=profile,
    log_path='./log/CAML_KSI',
    device=DEVICE,
)
prof_ksi = train_model(model, **ksi_train_kwargs)

In [None]:
# Optionally persist the trained CAML+KSI model next to its dataset.
if save:
    ksi_model_path = f'{dir}CAML_KSI_model.pt'
    torch.save(model, ksi_model_path)

# Optionally print the profiler averages, sorted by self CUDA time.
if profile:
    ksi_profile_table = prof_ksi.key_averages(group_by_stack_n=5).table(sort_by='self_cuda_time_total')
    print(ksi_profile_table)

In [None]:
# Evaluate the CAML+KSI model on the test dataloader, requesting the
# per-label output as well (by_label=True).
(tt_recall_at_k,
 tt_micro_f1,
 tt_macro_f1,
 tt_micro_auc,
 tt_macro_auc,
 label_aucs_ksi) = test_model(model,
                              test_dataloader,
                              wikivec,
                              by_label=True,
                              device=DEVICE)
print(f'Test Recall@10: {tt_recall_at_k:.4f}, Test Micro F1: {tt_micro_f1:.4f}, Test Macro F1: {tt_macro_f1:.4f}' +
      f', Test Micro AUC: {tt_micro_auc:.4f}, Test Macro AUC: {tt_macro_auc:.4f}')

# Drop the notebook's reference to the model before the next one is built.
del model
# DEVICE is a torch.device, which never compares equal to the plain string
# 'cuda'; test its `.type` so the CUDA cache is actually released on GPU runs.
if DEVICE.type == 'cuda':
    torch.cuda.empty_cache()

In [None]:
# Modified KSI run: switch to the dataset whose vectors hold frequencies
# rather than binary indicators, and rebuild the dataloaders from it.
dir = 'data/original_freqs/'
loaders, wikivec, word_to_ix = load_KSI_data(
    dir=dir,
    batch_size=batch_size,
    train=True,
    val=True,
    test=True,
    device=DEVICE,
)
train_dataloader, val_dataloader, test_dataloader = (
    loaders[split] for split in ('train', 'val', 'test')
)

# Re-derive the size parameters from the newly loaded data.
n_words = len(word_to_ix)
n_wiki, n_vocab = wikivec.shape

In [None]:
# CAML augmented with the ModifiedKSI component; both are moved to the
# selected device.
mod_ksi = ModifiedKSI(n_embedding, n_vocab)
mod_ksi.to(DEVICE)
mod_model = CAML(n_words, n_wiki, n_embedding, n_hidden, ksi=mod_ksi).to(DEVICE)

# Summarise the model with the same three input shapes used for CAML+KSI.
mod_summary = summary(
    mod_model,
    [(batch_size, avg_note_size), (batch_size, n_vocab), (n_wiki, n_vocab)],
    dtypes=[torch.int, torch.float, torch.float],
)

mod_summary

In [None]:
# Fresh Adam optimiser and one-cycle schedule for the CAML+ModifiedKSI model.
optimizer = optim.Adam(mod_model.parameters())
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.01,
    steps_per_epoch=len(train_dataloader),
    epochs=n_epochs,
)

# Train CAML+ModifiedKSI on the frequency-vector data.
mod_train_kwargs = dict(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    wikivec=wikivec,
    optimizer=optimizer,
    scheduler=scheduler,
    n_epochs=n_epochs,
    profile=profile,
    log_path='./log/CAML_ModifiedKSI',
    device=DEVICE,
)
prof_mod_ksi = train_model(mod_model, **mod_train_kwargs)

In [None]:
# Optionally persist the trained CAML+ModifiedKSI model next to its dataset.
if save:
    mod_model_path = f'{dir}CAML_ModifiedKSI_model.pt'
    torch.save(mod_model, mod_model_path)

# Optionally print the profiler averages, sorted by self CUDA time.
if profile:
    mod_profile_table = prof_mod_ksi.key_averages(group_by_stack_n=5).table(sort_by='self_cuda_time_total')
    print(mod_profile_table)

In [None]:
# Evaluate the CAML+ModifiedKSI model on the test dataloader, requesting the
# per-label output as well (by_label=True).
(tt_recall_at_k,
 tt_micro_f1,
 tt_macro_f1,
 tt_micro_auc,
 tt_macro_auc,
 label_aucs_mod) = test_model(mod_model,
                              test_dataloader,
                              wikivec,
                              by_label=True,
                              device=DEVICE)
print(f'Test Recall@10: {tt_recall_at_k:.4f}, Test Micro F1: {tt_micro_f1:.4f}, Test Macro F1: {tt_macro_f1:.4f}' +
      f', Test Micro AUC: {tt_micro_auc:.4f}, Test Macro AUC: {tt_macro_auc:.4f}')

# Drop the notebook's reference to the model now that evaluation is done.
del mod_model
# DEVICE is a torch.device, which never compares equal to the plain string
# 'cuda'; test its `.type` so the CUDA cache is actually released on GPU runs.
if DEVICE.type == 'cuda':
    torch.cuda.empty_cache()