# Model training example

In [1]:
import os
import pandas as pd
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"

import sys
import torch
import torch.nn as nn
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import csv
import pandas as pd
import numpy as np

from torch.optim.lr_scheduler import ReduceLROnPlateau
from pathlib import Path
from torch_geometric.data import DataLoader
from sentence_transformers import SentenceTransformer, util

In [2]:
# Repository root, given relative to the directory the notebook is run from.
jtnames_dir = Path('../../')
src_dir = jtnames_dir / 'src'

# Make the project's `src` packages importable. The membership check keeps the
# cell idempotent: re-running it no longer appends a duplicate sys.path entry.
src_path = src_dir.as_posix()
if src_path not in sys.path:
    sys.path.append(src_path)

# Project-local imports; they only resolve once `src` is on sys.path (above).
from models.tgeometric.dataset import TorchGeometricDataset
from models.tgeometric.vocabulary import Vocabulary
from models.ggcn_lightning import GGCNLightning
from models.tgeometric.ggcn import GGCN
# from models.tgeometric.simple_decoder import SimpleDecoder

from data.scripts_for_filtration.method_name_tokenizer import tokenize
from utility.metrics import set_sentence_model, get_sentence_similarity

# Value passed as `gpus=` to pl.Trainer in the training cells.
gpus = 1

### Vocabulary

In [3]:
# Absolute repository root as a plain string; every data path below is built from it.
jtnames = str(jtnames_dir.absolute())
graphs_root = jtnames + '/data/processed/fn_graphs_article'

# One directory per split.
data_train, data_val, data_test = (
    f'{graphs_root}/{split}' for split in ('train', 'val', 'test')
)

vocabulary_path = f'{graphs_root}/vocab.pickle'

# To rebuild the vocabulary from the raw splits instead of loading it:
# vocabulary = Vocabulary([Path(data_train), Path(data_val), Path(data_test)], tokenize)
# vocabulary.save(Path(vocabulary_path))

# Load the vocabulary that was used while training.
# NOTE(review): the file is a pickle by its extension and is presumably
# unpickled by Vocabulary.load - only load vocabularies from trusted sources.
vocabulary = Vocabulary.load(Path(vocabulary_path))

# All three datasets share the same vocabulary; built in train, val, test order.
train_dataset, val_dataset, test_dataset = (
    TorchGeometricDataset(split_dir, vocabulary)
    for split_dir in (data_train, data_val, data_test)
)

# Index of the '<pad>' token, handed to the model as its padding index.
pad_index = vocabulary.token_encoder['<pad>']

### Simple decoder

In [8]:
# Hyper-parameters for the run with the simple decoder. The whole dict is
# passed to GGCNLightning; the training cell additionally reads `max_epochs`
# and the `es_*` keys to configure the trainer and early stopping.
hparams = dict(
    # --- training ---
    max_epochs=50,
    batch_size=32,
    es_mode='max',
    es_patience=10,
    es_metric='val_acc',
    lr=0.005,
    optimizer='Adam',
    lr_scheduler='ReduceLROnPlateau',
    lr_scheduler_metric='val_loss',
    # --- model ---
    num_agg_steps=5,
    encoder_output_dim=128,
    encoder_step_norm=False,
    node_type_embed_dim=10,
    node_attr_embed_dim=10,
    decoder='simple',
    decoder_hidden_size=20,
    decoder_num_layers=1,
    # --- metrics configuration ---
    sentence_encoder='msmarco-distilbert-base-v2',
)

# Seed the Python, NumPy and PyTorch RNGs before the model is built so that
# weight initialisation is repeatable across kernel restarts.
SEED = 42
pl.seed_everything(SEED)

# Lightning module for the simple-decoder run. It receives the vocabulary, the
# padding index, the three dataset splits and the hyper-parameters, which is
# why `trainer.fit(model)` is later called without any dataloaders.
model = GGCNLightning(
    vocabulary,
    pad_index,
    train_df=train_dataset,
    val_df=val_dataset,
    test_df=test_dataset,
    hparams=hparams,
)

In [None]:
# Single run name shared by the CSV log folder, the checkpoint directory and
# the checkpoint file prefix.
run_name = "article_training_simple_old_shuffled"

patience = hparams['es_patience']
max_epochs = hparams['max_epochs']
# Early stopping and checkpointing both follow the metric/direction configured
# in hparams, so changing `es_metric` or `es_mode` cannot make them disagree.
monitor_metric = hparams['es_metric']
monitor_mode = hparams['es_mode']

logger = pl.loggers.CSVLogger("logs", name=run_name)
estop = pl.callbacks.early_stopping.EarlyStopping(
    monitor=monitor_metric,
    patience=patience,
    mode=monitor_mode,
)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor=monitor_metric,
    dirpath=f'{run_name}/',
    # Doubled braces are literal: the result is a Lightning filename template,
    # e.g. '<run_name>-{epoch:02d}-{val_acc:.2f}' when the metric is 'val_acc'.
    filename=f'{run_name}-{{epoch:02d}}-{{{monitor_metric}:.2f}}',
    save_top_k=3,
    mode=monitor_mode,
)

# `gpus` is defined in the setup cell near the top of the notebook.
trainer = pl.Trainer(
    gpus=gpus,
    max_epochs=max_epochs,
    logger=logger,
    callbacks=[estop, checkpoint_callback],
)
trainer.fit(model)

In [None]:
# Run Lightning's test loop with the trainer created in the previous cell.
# No model, dataloader or checkpoint is passed here, so the data presumably
# comes from the `test_df` given to GGCNLightning and the evaluated weights
# follow Lightning's defaults - NOTE(review): confirm for the installed version.
trainer.test()

### RNN decoder

In [4]:
# Hyper-parameters for the run with the RNN decoder. The whole dict is passed
# to GGCNLightning; the training cell additionally reads `max_epochs` and the
# `es_*` keys to configure the trainer and early stopping.
hparams = dict(
    # --- training ---
    max_epochs=100,
    batch_size=32,
    es_mode='max',
    es_patience=10,
    es_metric='val_acc',
    lr=0.005,
    optimizer='Adam',
    lr_scheduler='ReduceLROnPlateau',
    # NOTE(review): this looks like an accuracy-style metric, while
    # ReduceLROnPlateau defaults to mode='min' - confirm that GGCNLightning
    # sets the scheduler mode to match.
    lr_scheduler_metric='val_acc_by_token',
    # --- model ---
    num_agg_steps=5,
    encoder_output_dim=128,
    encoder_step_norm=False,
    node_type_embed_dim=10,
    node_attr_embed_dim=10,
    decoder='rnn',
    decoder_hidden_size=20,
    decoder_num_layers=1,
    # --- metrics configuration ---
    sentence_encoder='msmarco-distilbert-base-v2',
)

# Seed the Python, NumPy and PyTorch RNGs before the model is built so that
# weight initialisation is repeatable across kernel restarts.
SEED = 42
pl.seed_everything(SEED)

# Lightning module for the RNN-decoder run. It receives the vocabulary, the
# padding index, the three dataset splits and the hyper-parameters, which is
# why `trainer.fit(model)` is later called without any dataloaders.
model = GGCNLightning(
    vocabulary,
    pad_index,
    train_df=train_dataset,
    val_df=val_dataset,
    test_df=test_dataset,
    hparams=hparams,
)

In [None]:
# Single run name shared by the CSV log folder, the checkpoint directory and
# the checkpoint file prefix.
run_name = "article_training_rnn_old_shuffled"

patience = hparams['es_patience']
# Same value as in the setup cell; re-assigned here as in the original cell.
gpus = 1
max_epochs = hparams['max_epochs']
# Early stopping and checkpointing both follow the metric/direction configured
# in hparams, so changing `es_metric` or `es_mode` cannot make them disagree.
monitor_metric = hparams['es_metric']
monitor_mode = hparams['es_mode']

logger = pl.loggers.CSVLogger("logs", name=run_name)
estop = pl.callbacks.early_stopping.EarlyStopping(
    monitor=monitor_metric,
    patience=patience,
    mode=monitor_mode,
)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor=monitor_metric,
    dirpath=f'{run_name}/',
    # Doubled braces are literal: the result is a Lightning filename template,
    # e.g. '<run_name>-07-04-{epoch:02d}-{val_acc:.2f}' when the metric is 'val_acc'.
    filename=f'{run_name}-07-04-{{epoch:02d}}-{{{monitor_metric}:.2f}}',
    save_top_k=3,
    mode=monitor_mode,
)

trainer = pl.Trainer(
    gpus=gpus,
    max_epochs=max_epochs,
    logger=logger,
    callbacks=[estop, checkpoint_callback],
)
trainer.fit(model)

In [None]:
# Run Lightning's test loop with the trainer created in the previous cell.
# No model, dataloader or checkpoint is passed here, so the data presumably
# comes from the `test_df` given to GGCNLightning and the evaluated weights
# follow Lightning's defaults - NOTE(review): confirm for the installed version.
trainer.test()