In [2]:
import torch

# Import your custom modules
from config.Experiment_Config import (
    ExperimentConfig,
    ModelSettings,
    TrainingSettings,
    DataSettings,
    ModelSelection,
    TokenizerSettings
)
from models.LSTM.BiLSTM import BiLSTM
from models.LSTM.CNNBiLSTM import CNNBiLSTM
from utils.load_split import loader
from exp.trainer import Trainer
import matplotlib.pyplot as plt
import joblib


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Hidden size of BERT base; used for both the model input width and the
# embedding width so the two cannot drift apart.
BERT_BASE_HIDDEN_SIZE = 768

model_settings = ModelSettings(
    # Embedding side
    embedding_type='BERT_base_uncased',
    embedding_dim=BERT_BASE_HIDDEN_SIZE,
    # Network shape
    input_dim=BERT_BASE_HIDDEN_SIZE,
    hidden_dims=[512, 256, 128],
    # Placeholder only: output_dim is overwritten with the real number of
    # classes once the label encoder is available (after the data is loaded).
    output_dim=1,
    lags=5,
    dropout_rate=0.2,
)

In [4]:
training_settings = TrainingSettings(
    # Schedule
    num_epochs=100,
    early_stopping_patience=5,
    # NOTE(review): DataSettings uses batch_size=32 and the batches logged
    # during training are 32 wide, so this value does not appear to size the
    # loaders -- confirm which of the two the Trainer actually honours.
    batch_size=256,
    # Optimiser
    learning_rate=0.001,
    weight_decay=0.01,
    gradient_clip=1.0,
    # Output locations (relative paths)
    save_model_dir='saved_models',
    save_results_dir='results',
)


In [5]:
data_settings = DataSettings(
    which="question",
    # Train / validation / test fractions of the dataset.
    train_size=0.8,
    val_size=0.1,
    test_size=0.1,
    # Loader behaviour. The batches logged during training have 32 rows,
    # matching this batch_size.
    batch_size=32,
    shuffle=True,
    # NOTE(review): presumably drops the final incomplete batch; confirm
    # whether this is also applied to the validation and test loaders.
    drop_last=True,
)

In [6]:
# NOTE(review): the model is instantiated purely from `model_type` further
# down, so with model_type='BiLSTM' the CNN options below do not switch the
# notebook to CNNBiLSTM. Confirm whether BiLSTM itself reads them, or whether
# 'CNNBiLSTM' was intended here.
model_selection = ModelSelection(
    model_type='BiLSTM',
    cnn_layers=2,
    use_cnn=True,
)

In [7]:
tokenizer_settings = TokenizerSettings(
    # Which tokenizer / embedding family to use.
    name="BERT_base_uncased",
    embedding_type="bert",
    # Sequence shaping. The batches logged during training are 512 tokens
    # wide, matching max_length.
    max_length=512,
    truncation="ratio",  # alternative value: "equal"
    padding=True,
    add_special_tokens=True,
)

In [8]:
# Bundle the five settings objects into the single config that the loader,
# the models and the trainer all receive.
config = ExperimentConfig(
    data_settings=data_settings,
    tokenizer_settings=tokenizer_settings,
    model_selection=model_selection,
    model_settings=model_settings,
    training_settings=training_settings,
)

In [9]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Record the choice on the shared config as well as in the local `device`.
config.training_settings.device = device

print(f"Using device: {device}")
print(f"Model type: {config.model_selection.model_type}")

Using device: cpu
Model type: BiLSTM


In [10]:
# Build the three data loaders and the label encoder from the config.
train_loader, val_loader, test_loader, label_encoder = loader(config)

# Persist the encoder next to the notebook (path is relative to the current
# working directory).
joblib.dump(label_encoder, 'label_encoder.pkl')

# The placeholder output_dim in the model settings is replaced here with the
# number of distinct labels the encoder holds.
encoder_classes = label_encoder.classes_
num_classes = len(encoder_classes)
config.model_settings.output_dim = num_classes

Token indices sequence length is longer than the specified maximum sequence length for this model (1024 > 512). Running this sequence through the model will result in indexing errors


In [11]:
# Map each supported `model_type` string to the class that implements it.
MODEL_CLASSES = {
    'BiLSTM': BiLSTM,
    'CNNBiLSTM': CNNBiLSTM,
}

model_type = config.model_selection.model_type
if model_type not in MODEL_CLASSES:
    # Without this guard an unknown model_type would leave `model` unbound on
    # a fresh kernel (NameError below) or, worse, silently reuse a stale
    # `model` left over from an earlier run of this cell.
    raise ValueError(
        f"Unsupported model_type {model_type!r}; "
        f"expected one of {sorted(MODEL_CLASSES)}"
    )

# Instantiate the selected architecture from the shared config, then move it
# to the device chosen earlier.
model = MODEL_CLASSES[model_type](config)
model = model.to(device)

In [12]:
# Hand the model, the three loaders and the shared config to the trainer.
trainer = Trainer(
    config=config,
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    test_loader=test_loader,
)

# Run training and keep whatever train() returns for later inspection.
print("Starting training...")
metrics_history = trainer.train()

Starting training...
Training on device: cpu

Epoch 1/100 - train


Training:   0%|          | 0/127 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


torch.Size([32, 512]) torch.Size([32])


Training:   1%|          | 1/127 [00:11<23:44, 11.31s/it]

torch.Size([32, 512]) torch.Size([32])


Training:   2%|▏         | 2/127 [00:21<22:46, 10.93s/it]

torch.Size([32, 512]) torch.Size([32])


Training:   2%|▏         | 3/127 [00:32<22:34, 10.92s/it]

torch.Size([32, 512]) torch.Size([32])


Training:   3%|▎         | 4/127 [00:44<22:34, 11.02s/it]

torch.Size([32, 512]) torch.Size([32])


Training:   4%|▍         | 5/127 [00:55<22:38, 11.13s/it]

torch.Size([32, 512]) torch.Size([32])


Training:   5%|▍         | 6/127 [01:06<22:31, 11.17s/it]

torch.Size([32, 512]) torch.Size([32])


Training:   6%|▌         | 7/127 [01:17<22:19, 11.16s/it]

torch.Size([32, 512]) torch.Size([32])


Training:   6%|▋         | 8/127 [01:28<22:09, 11.17s/it]

torch.Size([32, 512]) torch.Size([32])


Training:   7%|▋         | 9/127 [01:40<21:59, 11.18s/it]

torch.Size([32, 512]) torch.Size([32])


Training:   8%|▊         | 10/127 [01:51<21:46, 11.17s/it]

torch.Size([32, 512]) torch.Size([32])


Training:   9%|▊         | 11/127 [02:02<21:38, 11.19s/it]

torch.Size([32, 512]) torch.Size([32])


Training:   9%|▉         | 12/127 [02:13<21:26, 11.18s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  10%|█         | 13/127 [02:25<21:18, 11.22s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  11%|█         | 14/127 [02:36<21:13, 11.27s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  12%|█▏        | 15/127 [02:47<21:02, 11.28s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  13%|█▎        | 16/127 [02:59<20:52, 11.29s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  13%|█▎        | 17/127 [03:10<20:41, 11.29s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  14%|█▍        | 18/127 [03:21<20:29, 11.28s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  15%|█▍        | 19/127 [03:32<20:21, 11.31s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  16%|█▌        | 20/127 [03:44<20:12, 11.33s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  17%|█▋        | 21/127 [03:55<20:05, 11.37s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  17%|█▋        | 22/127 [04:07<19:56, 11.39s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  18%|█▊        | 23/127 [04:18<19:49, 11.44s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  19%|█▉        | 24/127 [04:30<19:38, 11.44s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  20%|█▉        | 25/127 [04:41<19:27, 11.45s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  20%|██        | 26/127 [04:53<19:12, 11.41s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  21%|██▏       | 27/127 [05:04<19:01, 11.41s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  22%|██▏       | 28/127 [05:15<18:52, 11.43s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  23%|██▎       | 29/127 [05:27<18:42, 11.45s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  24%|██▎       | 30/127 [05:38<18:34, 11.49s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  24%|██▍       | 31/127 [05:50<18:25, 11.51s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  25%|██▌       | 32/127 [06:02<18:16, 11.54s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  26%|██▌       | 33/127 [06:13<18:07, 11.57s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  27%|██▋       | 34/127 [06:25<18:01, 11.63s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  28%|██▊       | 35/127 [06:37<17:52, 11.66s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  28%|██▊       | 36/127 [06:49<17:47, 11.73s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  29%|██▉       | 37/127 [07:00<17:37, 11.75s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  30%|██▉       | 38/127 [07:12<17:26, 11.76s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  31%|███       | 39/127 [07:24<17:16, 11.78s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  31%|███▏      | 40/127 [07:36<17:05, 11.79s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  32%|███▏      | 41/127 [07:48<16:54, 11.80s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  33%|███▎      | 42/127 [08:00<16:44, 11.82s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  34%|███▍      | 43/127 [08:11<16:34, 11.84s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  35%|███▍      | 44/127 [08:23<16:24, 11.86s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  35%|███▌      | 45/127 [08:35<16:17, 11.92s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  36%|███▌      | 46/127 [08:47<16:04, 11.91s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  37%|███▋      | 47/127 [08:59<15:56, 11.95s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  38%|███▊      | 48/127 [09:11<15:45, 11.97s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  39%|███▊      | 49/127 [09:23<15:35, 11.99s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  39%|███▉      | 50/127 [09:35<15:25, 12.01s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  40%|████      | 51/127 [09:47<15:12, 12.01s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  41%|████      | 52/127 [09:59<15:00, 12.01s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  42%|████▏     | 53/127 [10:11<14:47, 12.00s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  43%|████▎     | 54/127 [10:23<14:36, 12.01s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  43%|████▎     | 55/127 [10:35<14:24, 12.00s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  44%|████▍     | 56/127 [10:48<14:12, 12.01s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  45%|████▍     | 57/127 [11:00<14:01, 12.03s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  46%|████▌     | 58/127 [11:12<13:50, 12.04s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  46%|████▋     | 59/127 [11:24<13:43, 12.10s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  47%|████▋     | 60/127 [11:36<13:31, 12.11s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  48%|████▊     | 61/127 [11:48<13:17, 12.08s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  49%|████▉     | 62/127 [12:00<13:05, 12.08s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  50%|████▉     | 63/127 [12:12<12:52, 12.07s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  50%|█████     | 64/127 [12:24<12:41, 12.09s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  51%|█████     | 65/127 [12:36<12:28, 12.06s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  52%|█████▏    | 66/127 [12:48<12:16, 12.07s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  53%|█████▎    | 67/127 [13:01<12:05, 12.09s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  54%|█████▎    | 68/127 [13:13<11:57, 12.16s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  54%|█████▍    | 69/127 [13:25<11:42, 12.12s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  55%|█████▌    | 70/127 [13:37<11:36, 12.22s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  56%|█████▌    | 71/127 [13:50<11:28, 12.29s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  57%|█████▋    | 72/127 [14:02<11:20, 12.36s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  57%|█████▋    | 73/127 [14:15<11:07, 12.37s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  58%|█████▊    | 74/127 [14:27<10:54, 12.34s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  59%|█████▉    | 75/127 [14:41<11:09, 12.88s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  60%|█████▉    | 76/127 [14:54<10:49, 12.74s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  61%|██████    | 77/127 [15:06<10:35, 12.70s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  61%|██████▏   | 78/127 [15:19<10:22, 12.70s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  62%|██████▏   | 79/127 [15:32<10:10, 12.72s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  63%|██████▎   | 80/127 [15:44<09:53, 12.63s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  64%|██████▍   | 81/127 [15:57<09:43, 12.69s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  65%|██████▍   | 82/127 [16:09<09:30, 12.67s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  65%|██████▌   | 83/127 [16:22<09:15, 12.63s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  66%|██████▌   | 84/127 [16:35<09:02, 12.61s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  67%|██████▋   | 85/127 [16:47<08:51, 12.65s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  68%|██████▊   | 86/127 [17:00<08:40, 12.69s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  69%|██████▊   | 87/127 [17:13<08:27, 12.68s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  69%|██████▉   | 88/127 [17:25<08:09, 12.55s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  70%|███████   | 89/127 [17:38<07:57, 12.57s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  71%|███████   | 90/127 [17:50<07:41, 12.47s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  72%|███████▏  | 91/127 [18:02<07:25, 12.38s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  72%|███████▏  | 92/127 [18:15<07:15, 12.44s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  73%|███████▎  | 93/127 [18:27<07:01, 12.38s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  74%|███████▍  | 94/127 [18:39<06:46, 12.31s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  75%|███████▍  | 95/127 [18:51<06:33, 12.31s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  76%|███████▌  | 96/127 [19:04<06:23, 12.36s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  76%|███████▋  | 97/127 [19:16<06:09, 12.33s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  77%|███████▋  | 98/127 [19:28<05:55, 12.28s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  78%|███████▊  | 99/127 [19:40<05:41, 12.20s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  79%|███████▊  | 100/127 [19:53<05:31, 12.28s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  80%|███████▉  | 101/127 [20:05<05:21, 12.36s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  80%|████████  | 102/127 [20:17<05:07, 12.31s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  81%|████████  | 103/127 [20:30<05:00, 12.53s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  82%|████████▏ | 104/127 [20:43<04:45, 12.42s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  83%|████████▎ | 105/127 [20:55<04:31, 12.34s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  83%|████████▎ | 106/127 [21:07<04:18, 12.31s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  84%|████████▍ | 107/127 [21:19<04:05, 12.26s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  85%|████████▌ | 108/127 [21:31<03:52, 12.22s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  86%|████████▌ | 109/127 [21:44<03:40, 12.22s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  87%|████████▋ | 110/127 [21:56<03:27, 12.21s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  87%|████████▋ | 111/127 [22:08<03:15, 12.23s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  88%|████████▊ | 112/127 [22:20<03:02, 12.19s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  89%|████████▉ | 113/127 [22:32<02:50, 12.20s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  90%|████████▉ | 114/127 [22:44<02:38, 12.19s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  91%|█████████ | 115/127 [22:57<02:26, 12.18s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  91%|█████████▏| 116/127 [23:09<02:14, 12.22s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  92%|█████████▏| 117/127 [23:22<02:04, 12.42s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  93%|█████████▎| 118/127 [23:34<01:51, 12.41s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  94%|█████████▎| 119/127 [23:46<01:38, 12.37s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  94%|█████████▍| 120/127 [23:59<01:26, 12.33s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  95%|█████████▌| 121/127 [24:11<01:14, 12.39s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  96%|█████████▌| 122/127 [24:24<01:02, 12.47s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  97%|█████████▋| 123/127 [24:36<00:49, 12.47s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  98%|█████████▊| 124/127 [24:49<00:37, 12.41s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  98%|█████████▊| 125/127 [25:01<00:24, 12.40s/it]

torch.Size([32, 512]) torch.Size([32])


Training:  99%|█████████▉| 126/127 [25:14<00:12, 12.50s/it]

torch.Size([32, 512]) torch.Size([32])


Training: 100%|██████████| 127/127 [25:26<00:00, 12.02s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
Evaluating (val): 100%|██████████| 15/15 [02:10<00:00,  8.70s/it]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 1 Results:

Training Metrics:
accuracy: 0.4924
precision_macro: 0.0154
precision_micro: 0.4924
recall_macro: 0.0266
recall_micro: 0.4924
f1_macro: 0.0185
f1_micro: 0.4924
loss: 2.5382

Validation Metrics:
accuracy: 0.4958
precision_macro: 0.0130
precision_micro: 0.4958
recall_macro: 0.0263
recall_micro: 0.4958
f1_macro: 0.0174
f1_micro: 0.4958
loss: 2.3221

Epoch 2/100 - train


Training:   0%|          | 0/127 [00:00<?, ?it/s]

torch.Size([32, 512]) torch.Size([32])


Training:   1%|          | 1/127 [00:12<25:55, 12.34s/it]

torch.Size([32, 512]) torch.Size([32])


Training:   2%|▏         | 2/127 [00:24<25:45, 12.36s/it]

torch.Size([32, 512]) torch.Size([32])


Training:   2%|▏         | 2/127 [00:28<30:06, 14.45s/it]


KeyboardInterrupt: 

In [None]:
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Loss curves: one point per entry in the trainer's recorded loss histories.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(trainer.train_losses, label='Training Loss')
ax.plot(trainer.val_losses, label='Validation Loss')
ax.set_title('Training and Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

# Evaluate on the held-out test loader and print each metric to four
# decimals. Only the first element of the returned triple is used here.
print("\nTest Results:")
test_metrics, _, _ = trainer.evaluate(test_loader, 'test')
for metric, value in test_metrics.items():
    print(f"{metric}: {value:.4f}")