## TODO
* 

### CURRENT OPTIMAL SETTINGS 
Not yet determined — record the best-performing hyperparameters here once they are known.

In [1]:
# Library imports
import wandb
import os
import numpy as np
import torch
import pandas as pd

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint


import dataset 
import metrics
import etm
import data
import utls as utl
import model as my_models
import settings

# Settings

In [2]:
# Imported settings (shared values from the project's settings module)
emb_size = settings.emb_size
num_topics = settings.num_topics
rho_size = settings.rho_size
enc_drop = settings.enc_drop
t_hidden_size = settings.t_hidden_size
theta_act = settings.theta_act
batch_size = settings.batch_size
batch_size_test = settings.batch_size_test
emb_np_path = settings.emb_np_path
emb_path = settings.emb_path

# Placeholder for a checkpoint path to resume from (None = no checkpoint).
checkpoint = None

# training settings
max_epochs = 20
accumulate_grad_batches = 1
log_every_n_steps = 1
# NOTE(review): Lightning treats an *int* val_check_interval as "validate every N
# training batches" and a *float* as a fraction of an epoch (default 1.0 = once per
# epoch). Confirm which behaviour is intended before changing this value.
val_check_interval = 1

# model settings
gamma = 0.7  # Learning rate step gamma (default: 0.7)
seed = 42  # random seed (default: 42)
save_model = False  # save the trained model (default: False)

# misc settings
no_cuda = False  # set True to disable CUDA training (default: False)
use_cuda = not no_cuda and torch.cuda.is_available()

device_num = 3  # GPU index; also passed to the Trainer as `gpus=[device_num]`
# Build a CUDA device only when CUDA is actually usable; otherwise fall back to CPU.
device = torch.device(f"cuda:{device_num}" if use_cuda else "cpu")

# pin_memory only helps (and only makes sense) when batches are copied to a GPU.
kwargs = {'num_workers': 10, 'pin_memory': use_cuda}

# Seed every RNG used in this notebook so runs are reproducible.
torch.manual_seed(seed)
np.random.seed(seed)

print("Device:", device)

Device: cuda:3


### Data Directories

In [3]:
# Move to the project root so the relative paths used below ('./Data',
# './models', './savedata') resolve.
project_path = '..'
# Resolve the root only once, relative to the directory the kernel started in.
# Without the guard, every re-run of this cell would climb another directory level.
if 'project_root' not in globals():
    project_root = os.path.abspath(project_path)
os.chdir(project_root)
print(os.getcwd())

/home/text_tripletMarginLoss


## Make Data and Model

In [4]:
# Make Data: build the vocabulary, then the train/test loaders from it.
vocab, vocab_size = data.get_vocab()
loaders = data.make_data(vocab, vocab_size, kwargs)
train_loader, test_loader = loaders

# Make Embeddings: load the embedding matrix from disk; `device` is forwarded to the loader.
# NOTE(review): this path is hardcoded, while settings.emb_np_path / settings.emb_path
# are read in the settings cell but never used — confirm whether one should be used here.
embeddings_file = './Data/embeddings.npz.npy'
embeddings = data.load_embeddings(embeddings_file, device)

In [5]:
# Define hyperparameters.
# lr is passed to TripletNet in the model cell.
# NOTE(review): lr_w (presumably the step size for the topic weights W) is not
# passed to the model there — confirm whether it is still needed.
lr, lr_w = 1, 1250

In [7]:
# Make Model: a fresh TripletNet whose ETM sub-module is initialised from the
# previously trained ETM stored at settings.dict_path.
m = my_models.TripletNet(embeddings, vocab_size, device, lr=lr, freeze_encoder=False, margin=1)

# settings.dict_path holds a pickled object (its .state_dict() is used below), so
# torch.load unpickles arbitrary code — only point it at trusted files.
# map_location keeps the load working if the file was saved from a different GPU.
trained_model = torch.load(settings.dict_path, map_location=device)
m.etm.load_state_dict(trained_model.state_dict())

# Callbacks
early_stopping = EarlyStopping(monitor='val_loss', patience=10)
# Named checkpoint_callback so it does not overwrite the `checkpoint` path setting.
checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath='./models/checkpoints/')
callbacks = [early_stopping, checkpoint_callback]

# logging
logger = WandbLogger(save_dir='./savedata/', project='qualitative-analysis')

# NOTE(review): 'ddp' on a single GPU from a notebook is unnecessary, and DDP is
# generally not supported in interactive sessions by Lightning — consider removing
# the accelerator argument; confirm against the installed Lightning version.
trainer = Trainer(accelerator='ddp',
                  max_epochs=max_epochs,
                  accumulate_grad_batches=accumulate_grad_batches,
                  gpus=[device_num],
                  callbacks=callbacks,
                  logger=logger,
                  reload_dataloaders_every_epoch=True,
                  log_every_n_steps=log_every_n_steps,
                  val_check_interval=val_check_interval)

trainer.fit(m, train_loader, test_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/1
[34m[1mwandb[0m: Currently logged in as: [33mwitw[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.33 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade



  | Name                | Type              | Params
----------------------------------------------------------
0 | triplet_margin_loss | TripletMarginLoss | 0     
1 | etm                 | ETM               | 26.7 M
2 | weigh               | GetWeightedTopics | 0     
----------------------------------------------------------
26.7 M    Trainable params
0         Non-trainable params
26.7 M    Total params
106.855   Total estimated model params size (MB)


Validation sanity check:   0%|          | 0/1 [00:00<?, ?it/s]



Epoch 0:  90%|█████████ | 9/10 [00:02<00:00,  4.44it/s, loss=0.96, v_num=c4fw] 
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/1 [00:00<?, ?it/s][A
Epoch 0: 100%|██████████| 10/10 [00:03<00:00,  2.67it/s, loss=0.96, v_num=c4fw]
Epoch 1:  90%|█████████ | 9/10 [00:02<00:00,  3.97it/s, loss=0.938, v_num=c4fw]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/1 [00:00<?, ?it/s][A
Epoch 1: 100%|██████████| 10/10 [00:03<00:00,  2.66it/s, loss=0.938, v_num=c4fw]
Epoch 2: 100%|██████████| 10/10 [00:02<00:00,  4.57it/s, loss=0.879, v_num=c4fw]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/1 [00:00<?, ?it/s][A
Epoch 2: 100%|██████████| 10/10 [00:03<00:00,  2.57it/s, loss=0.879, v_num=c4fw]
Epoch 3: 100%|██████████| 10/10 [00:01<00:00,  5.06it/s, loss=0.824, v_num=c4fw]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/1 [00:00<?, ?it/s][A
Epoch 3: 100%|██████████| 10/10 [00:03<00:00,  2.89it/s, loss=0.824, v_num=c4fw]
Epoch 4

1

In [None]:
# Inspect the `W` attribute of the `weigh` (GetWeightedTopics) sub-module.
weigh_W = m.weigh.W
weigh_W

# Save Trained Weights

In [None]:
# Save the full TripletNet state dict (path is relative to the current working directory).
triplet_state = m.state_dict()
torch.save(triplet_state, "triplet_model3.pt")

In [None]:
# Save only the ETM sub-module's state dict.
etm_state = m.etm.state_dict()
torch.save(etm_state, 'triplet_etm3.pt')

# Debugging (optional) — run the cells below to print gradient norms during backprop

def printgradnorm(self, grad_input, grad_output):
    """Backward hook: print the L1 norm and shape of the first input/output gradient.

    PyTorch calls backward hooks as ``hook(module, grad_input, grad_output)``, so
    ``self`` here receives the hooked module. Returning None leaves the gradients
    unchanged.

    The norm printed is the entry-wise L1 norm (sum of absolute values), which is
    defined for gradients of any rank. ``torch.linalg.norm(g, 1)`` would instead
    give a matrix 1-norm for 2-D tensors and raise for 3-D or higher.
    Gradient entries can be None (e.g. for inputs that do not require grad); that
    is reported instead of crashing the backward pass.
    """
    for label, grads in (("Input", grad_input), ("Output", grad_output)):
        grad = grads[0] if grads else None
        if grad is None:
            print(f"{label}: grad=None")
        else:
            print(f"{label}: 1-norm={grad.abs().sum().item()}, shape={grad.shape}")


# Keep the handles returned by register_*_hook so the hooks can be removed later, e.g.:
#   hook_handles.append(m.etm.q_theta[2].register_full_backward_hook(printgradnorm))
#   for handle in hook_handles: handle.remove()
hook_handles = []

# Inspect Topic Weights and the Top Words per Topic

In [None]:
# get the trained topic weights as a flat 1-D numpy array.
# The model summary printed by trainer.fit lists only `triplet_margin_loss`, `etm`
# and `weigh` as sub-modules — there is no `m.distance`, so the old
# `m.distance.topic_map.lin2.weight` lookup fails. The topic weights are read
# from `m.weigh.W` (the attribute inspected a few cells up) instead.
# NOTE(review): W's exact type/shape is not visible here; ravel() assumes it holds
# one weight per topic — confirm against GetWeightedTopics.
topic_w = m.weigh.W
weights = np.ravel(topic_w.detach().cpu().numpy() if torch.is_tensor(topic_w) else np.asarray(topic_w))
weights

In [None]:
# get the indices for the top topics, ordered by |weight| (largest first).
# Flattening first keeps this correct when the weights arrive as a 2-D row such as
# (1, num_topics). Otherwise argsort returns a 2-D index array and the fancy indexing
# below (and the per-topic loop later) breaks. For 1-D weights this is a no-op.
flat_weights = np.ravel(weights)
top_topics = np.argsort(-np.abs(flat_weights))
top_topics, flat_weights[top_topics]

In [None]:
# get the words for a topic
def get_words(topic_ix, num_words=10, beta=None, vocabulary=None):
    """Return the ``num_words`` highest-scoring words for one topic, best first.

    Args:
        topic_ix: Row index of the topic in the topic-word matrix.
        num_words: Number of words to return.
        beta: Optional precomputed topic-word matrix (torch tensor or array-like),
            indexed as ``beta[topic_ix]``. When omitted it is computed from
            ``m.etm.get_beta()``. Pass it in when querying many topics, so the
            whole matrix is not rebuilt on every call.
        vocabulary: Optional index -> word lookup; defaults to the global ``vocab``.

    Returns:
        A list of ``num_words`` words (fewer if the vocabulary is smaller).
    """
    if beta is None:
        # Inspection only: avoid building an autograd graph for the full matrix.
        with torch.no_grad():
            beta = m.etm.get_beta()
    if vocabulary is None:
        vocabulary = vocab

    row = beta[topic_ix]
    word_distribution = row.detach().cpu().numpy() if torch.is_tensor(row) else np.asarray(row)
    top_words = np.argsort(-word_distribution)[:num_words]

    return [vocabulary[ix] for ix in top_words]

In [None]:
# Print the top words for each topic, in order of topic importance.
for topic_ix in top_topics:
    topic_words = get_words(topic_ix)
    print(f'Topic {topic_ix} : {topic_words}')