In [None]:
%load_ext autoreload
%autoreload 2

import os

# Set project root: walk up from the current working directory to the nearest
# ancestor directory named "thesis" and make it the working directory.
# os.sep replaces the hard-coded "\\", which only matched Windows-style paths;
# on Windows repo_name still evaluates to "\\thesis".
repo_name = os.sep + "thesis"
project_root = os.getcwd()
while os.path.basename(project_root) != repo_name.lstrip(os.sep):
    parent_dir = os.path.dirname(project_root)
    if parent_dir == project_root:
        # Reached the filesystem root without finding the repository folder;
        # fail loudly instead of chdir-ing to a made-up path.
        raise FileNotFoundError(
            f"No '{repo_name.lstrip(os.sep)}' directory found at or above {os.getcwd()}"
        )
    project_root = parent_dir
os.chdir(project_root)
print(f'Changed working directory to: {os.getcwd()}')

#### 1. Corpus Example

In [None]:
import pickle

# Load data (paths are relative to the current working directory).
# NOTE(security): pickle.load can execute arbitrary code while unpickling;
# only load files that were produced by this project.
# The `with` blocks close each file handle as soon as it has been read.
with open("data/examples.pkl", "rb") as examples_file:
    examples = pickle.load(examples_file)
with open("data/vocab.pkl", "rb") as vocab_file:
    vocab = pickle.load(vocab_file)

# Last expression of the cell: displays the number of loaded examples.
len(examples)

#### 2. Run Experiment

In [None]:
from src.model_training import train_one_bucket, TrainConfig
from src.dataloaders import bucket_examples_by_distance

# Inclusive (low, high) distance ranges, one tuple per bucket.
distance_buckets = [(32, 127), (128, 255), (256, 1023)]

# Bucket all examples by distance range; the result is indexed by the range
# tuples defined above.
bucketed = bucket_examples_by_distance(examples, distance_buckets)

# Report how many examples are stored under each range key.
for bucket in distance_buckets:
    bucket_size = len(bucketed[bucket])
    print(f"Bucket {bucket}: {bucket_size} examples")


In [None]:
import random
import time
from src.model_training import set_seed
from src.model_training import train_one_bucket, train_one_bucket_lstm

print("Starting training on distance buckets...\n")


def _train_and_time(train_fn, bucket, max_len, vocab, cfg):
    """Call ``train_fn(bucket, max_len, vocab, cfg)``, print the elapsed time, and return its result."""
    # perf_counter is a monotonic clock intended for measuring durations.
    start_time = time.perf_counter()
    result = train_fn(bucket, max_len, vocab, cfg)
    print(f"Training time: {time.perf_counter() - start_time:.2f} seconds")
    return result


# Results returned by the training functions, keyed by (model name, max_len).
vals_acc = {}

for i, bucket_range in enumerate(distance_buckets):

    # Human-readable progress counter (1-based), e.g. "1/3" ... "3/3".
    print(f"Training on bucket {i + 1}/{len(distance_buckets)}")

    # Work on a copy of the i-th bucket: shuffling in place would reorder the
    # list stored in `bucketed`, so re-running this cell would start from a
    # different order each time. The global `examples` list is intentionally
    # left untouched so earlier cells can be re-run.
    bucket = list(bucketed[bucket_range])

    # The upper bound of the distance range is used as the maximum length.
    max_len = bucket_range[1]

    cfg = TrainConfig(
        emb_dim=128,
        hidden_dim=256,
        bidirectional=True,
        max_len=max_len,
        epochs=30,
        alpha=0.97,
        batch_size=512,
        lr=0.0007,
        seed=42,
    )
    # Seed right before shuffling.
    # NOTE(review): presumably set_seed also seeds Python's `random`, making
    # the shuffle reproducible; confirm in src.model_training.
    set_seed(cfg.seed)
    random.shuffle(bucket)

    # The banners below show the zero-based bucket index.
    print("="*100)
    print(f"Training on bucket {i} FOFENet - {max_len} tokens")
    print("="*100)
    val_acc_fofe = _train_and_time(train_one_bucket, bucket, max_len, vocab, cfg)
    # Save accuracy
    vals_acc[("fofe", max_len)] = val_acc_fofe

    print("="*100)
    print(f"Training on bucket {i} BiLSTM - {max_len} tokens")
    print("="*100)
    val_acc_lstm = _train_and_time(train_one_bucket_lstm, bucket, max_len, vocab, cfg)
    # Save accuracy
    vals_acc[("lstm", max_len)] = val_acc_lstm

    print("\n"*2)