# Imports

In [1]:
import warnings
import yaml
import pandas as pd

from algorithms.transformers.engine import Config, Trainer
from algorithms.transformers.dataset import get_dataloaders, get_datasets

warnings.filterwarnings('ignore')

# Config

In [2]:
# Load the experiment settings from config.yaml (resolved relative to the
# working directory). The encoding is pinned so the file is decoded the same
# way on every platform instead of depending on the locale default.
with open("config.yaml", "r", encoding="utf-8") as file:
    # safe_load only constructs plain Python objects from the YAML document.
    file_config = yaml.safe_load(file)

# An empty YAML document parses to None, and a top-level list or scalar is not
# a mapping either; fail with a clear message instead of an opaque `**` error.
if not isinstance(file_config, dict):
    raise ValueError(
        "config.yaml must contain a top-level mapping, "
        f"got {type(file_config).__name__}"
    )

# Every top-level YAML key is forwarded to Config as a keyword argument.
config = Config(**file_config)
config.print_config()

Config
experiment_name                test2
root_dir                       /pscratch/sd/a/aryamaan
device                         cuda:0
train_batch_size               256
test_batch_size                256
train_split                    0.8
test_split                     0.1
primary_df                     ./FeynmanEquationsModified.csv
train_df                       ./data_new/train_df.csv
data_dir                       ./data_new
epochs                         1
seed                           42
use_half_precision             True
scheduler_type                 cosine_annealing
T_0                            10
T_mult                         1
T_max                          125000
optimizer_type                 adam
optimizer_lr                   5e-05
optimizer_momentum             0.9
optimizer_weight_decay         0.0001
clip_grad_norm                 -1
model_name                     seq2seq_transformer
hybrid                         True
embedding_size                 64
hidden_

# Dataset

In [3]:
# Read the two CSV files named in the config, primary table first and training
# table second. `df` and `input_df` are the names the dataset cell consumes.
df, input_df = (
    pd.read_csv(csv_path)
    for csv_path in (config.primary_df, config.train_df)
)

In [4]:
# Build the datasets from the two tables and the data directory.
# The split fractions come from the config rather than repeated literals, so
# they cannot drift from config.yaml. Validation receives whatever remains
# after train and test; rounding strips float noise (1 - 0.8 - 0.1 evaluates
# to 0.09999999999999995), so the printed config still yields [0.8, 0.1, 0.1].
datasets, train_equations, test_equations = get_datasets(
    df,
    input_df,
    config.data_dir,
    [
        config.train_split,
        round(1.0 - config.train_split - config.test_split, 10),
        config.test_split,
    ],  # train-val-test split
)

In [5]:
# Batch sizes in the positional order the loader factory receives them.
# NOTE(review): the training batch size is passed twice; presumably the second
# slot is the validation loader — confirm against get_dataloaders' signature.
loader_batch_sizes = (
    config.train_batch_size,
    config.train_batch_size,
    config.test_batch_size,
)
dataloaders = get_dataloaders(datasets, *loader_batch_sizes)

# Train

In [6]:
# Build the trainer from the loaded config and the dataloaders created above.
trainer = Trainer(config, dataloaders)
# Run training. The printed config has epochs = 1, so the progress log below
# records a single-epoch demonstration rather than a full training run.
trainer.train() # demonstration for 1 epoch only

[1/1] Train: 100%|██████████| 782/782 [06:42<00:00,  1.94it/s, loss=0.976]
[1/1] Valid: 100%|██████████| 98/98 [00:50<00:00,  1.94it/s]

==> Best Accuracy improved to 0.8779107 from -1





# Test

In [7]:
# Evaluate the trained model. Per the recorded output below, this makes two
# passes over the test loader and prints test accuracy, validation accuracy,
# and test sequence accuracy.
# NOTE(review): the log shows "Valid Accuracy: -1" although training reported
# 0.8779107, and sequence accuracy is 0.0 after one epoch — confirm whether the
# validation value is expected to be populated here.
trainer.test_seq_acc()

[1/1] Test: 100%|██████████| 98/98 [00:42<00:00,  2.33it/s]


Calculating Sequence Accuracy for predictions (1 example per batch)


Test: 100%|██████████| 98/98 [01:19<00:00,  1.24it/s]

Test Accuracy: 0.7023745 | Valid Accuracy: -1
Test Sequence Accuracy: 0.0



