1. Load and prepare dataset

In [1]:
# All imports
from modules.device import best_device
from datasets import load_dataset
from data.dataset import build_tokenizer
from modules.lstm import LSTMConfig, LSTMLanguageModel
import torch
from modules.transformer import TransformerConfig, TransformerLanguageModel
from modules.training import train_streamed_lm
from modules.inference import generate_text
from modules.benchmark import measure_throughput
from modules.eval import evaluate_on_hf_or_file

In [2]:
print("Available device:", best_device())

Available device: mps


In [3]:
# Load tokenizer first (you can choose a different HF tokenizer if desired)
tokenizer = build_tokenizer("gpt2")

# Load dataset (streaming to avoid full download)
ds_stream = load_dataset("mikex86/stackoverflow-posts", split="train", streaming=True)

Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]

2. Initialize LSTM

In [4]:
# Initialize LSTM LM using HF tokenizer vocab size
vocab_size = tokenizer.vocab_size

lstm_cfg = LSTMConfig(
    vocab_size=vocab_size,
    emb_dim=256,
    hidden_size=512,
    num_layers=2,
    dropout=0.1,
    pad_token_id=tokenizer.pad_token_id,
)

lstm = LSTMLanguageModel(lstm_cfg)

# quick smoke test with tokenizer
dummy_texts = [
    "Hello StackOverflow!",
    "How do I calculate someone's age based on a DateTime type birthday?",
]
enc = tokenizer(dummy_texts, return_tensors="pt", padding=True)
_dummy = enc["input_ids"]
logits, loss = lstm(_dummy, targets=_dummy)
print("LSTM logits:", logits.shape, "loss:", loss.item())

LSTM logits: torch.Size([2, 15, 50257]) loss: 10.825050354003906


3. Initialize Transformer

In [5]:
# Initialize Transformer LM using HF tokenizer vocab size
vocab_size = tokenizer.vocab_size

tr_cfg = TransformerConfig(
    vocab_size=vocab_size,
    emb_dim=256,
    n_heads=8,
    n_layers=4,
    ff_dim=1024,
    dropout=0.1,
    max_seq_len=2048,
    pad_token_id=tokenizer.pad_token_id,
)

trans = TransformerLanguageModel(tr_cfg)

# quick smoke test with tokenizer
dummy_texts = [
    "Given a DateTime representing a person's birthday, how do I calculate their age in years?",
    "Calculate relative time in C#",
]
enc = tokenizer(dummy_texts, return_tensors="pt", padding=True)
_dummy = enc["input_ids"]
logits, loss = trans(_dummy, targets=_dummy)
print("Transformer logits:", logits.shape, "loss:", loss.item())

Transformer logits: torch.Size([2, 19, 50257]) loss: 10.836922645568848


4. Training LSTM

In [6]:
# Full training (LSTM) using shared utilities
lstm_final = train_streamed_lm(
    model=lstm,
    tokenizer=tokenizer,
    config=lstm_cfg.__dict__,
    ckpt_dir="checkpoints/lstm",
    final_dir="outputs/lstm",
    batch_size=16,
    max_length=256,
    steps_per_epoch=10000,
    num_epochs=5,
    save_every=1,
    lr=3e-4,
)

Training on device: mps
Epoch 1/5


Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]

step 100 loss 9.7970
step 200 loss 7.5953
step 200 loss 7.5953
step 300 loss 6.8217
step 300 loss 6.8217
step 400 loss 7.1827
step 400 loss 7.1827
step 500 loss 7.0969
step 500 loss 7.0969
step 600 loss 7.0844
step 600 loss 7.0844
step 700 loss 7.1052
step 700 loss 7.1052
step 800 loss 6.7242
step 800 loss 6.7242
step 900 loss 6.7620
step 900 loss 6.7620
step 1000 loss 7.0288
step 1000 loss 7.0288
step 1100 loss 7.0364
step 1100 loss 7.0364
step 1200 loss 7.0843
step 1200 loss 7.0843
step 1300 loss 6.9420
step 1300 loss 6.9420
step 1400 loss 6.9233
step 1400 loss 6.9233
step 1500 loss 6.0773
step 1500 loss 6.0773
step 1600 loss 6.4482
step 1600 loss 6.4482
step 1700 loss 6.3611
step 1700 loss 6.3611
step 1800 loss 6.1164
step 1800 loss 6.1164
step 1900 loss 6.1431
step 1900 loss 6.1431
step 2000 loss 6.1067
step 2000 loss 6.1067
step 2100 loss 6.0104
step 2100 loss 6.0104
step 2200 loss 5.8380
step 2200 loss 5.8380
step 2300 loss 5.7838
step 2300 loss 5.7838
step 2400 loss 6.0017
step 

'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 544393ad-48a6-4d2a-b87e-cd2381776dcc)')' thrown while requesting HEAD https://huggingface.co/datasets/mikex86/stackoverflow-posts/resolve/main/README.md
Retrying in 1s [Retry 1/5].
Retrying in 1s [Retry 1/5].


Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]

'(ReadTimeoutError("HTTPSConnectionPool(host='cas-bridge.xethub.hf.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: e7dde5df-f2f2-4d24-8002-3e5f7f9b6100)')' thrown while requesting GET https://huggingface.co/datasets/mikex86/stackoverflow-posts/resolve/9e791fe8997879cf127e2a0b006bad3484bbda32/stackoverflow-posts-00000-of-00058.parquet
Retrying in 1s [Retry 1/5].
Retrying in 1s [Retry 1/5].
'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: d2a6ce5d-34d2-449d-ba50-2dc3ed52cc4d)')' thrown while requesting GET https://huggingface.co/datasets/mikex86/stackoverflow-posts/resolve/9e791fe8997879cf127e2a0b006bad3484bbda32/stackoverflow-posts-00000-of-00058.parquet
Retrying in 1s [Retry 1/5].
'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: d2a6ce5d-34d2-449d-ba50-2dc3ed52cc4d)')' thrown while requesting GET https://huggingface.c

step 10100 loss 4.3417
step 10200 loss 5.0223
step 10200 loss 5.0223
step 10300 loss 4.4131
step 10300 loss 4.4131
step 10400 loss 4.9579
step 10400 loss 4.9579
step 10500 loss 4.9171
step 10500 loss 4.9171
step 10600 loss 4.6367
step 10600 loss 4.6367
step 10700 loss 4.4873
step 10700 loss 4.4873
step 10800 loss 4.1787
step 10800 loss 4.1787
step 10900 loss 4.4118
step 10900 loss 4.4118
step 11000 loss 4.8136
step 11000 loss 4.8136
step 11100 loss 4.7749
step 11100 loss 4.7749
step 11200 loss 4.7621
step 11200 loss 4.7621
step 11300 loss 4.7211
step 11300 loss 4.7211
step 11400 loss 4.5745
step 11400 loss 4.5745
step 11500 loss 4.1086
step 11500 loss 4.1086
step 11600 loss 4.2862
step 11600 loss 4.2862
step 11700 loss 4.5677
step 11700 loss 4.5677
step 11800 loss 4.4618
step 11800 loss 4.4618
step 11900 loss 4.6170
step 11900 loss 4.6170
step 12000 loss 4.7224
step 12000 loss 4.7224
step 12100 loss 4.7073
step 12100 loss 4.7073
step 12200 loss 4.5915
step 12200 loss 4.5915
step 12300 

Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]

step 20100 loss 3.8640
step 20200 loss 4.5480
step 20200 loss 4.5480
step 20300 loss 4.1126
step 20300 loss 4.1126
step 20400 loss 4.6029
step 20400 loss 4.6029
step 20500 loss 4.5897
step 20500 loss 4.5897
step 20600 loss 4.2783
step 20600 loss 4.2783
step 20700 loss 4.0984
step 20700 loss 4.0984
step 20800 loss 3.8557
step 20800 loss 3.8557
step 20900 loss 4.1106
step 20900 loss 4.1106
step 21000 loss 4.4771
step 21000 loss 4.4771
step 21100 loss 4.4711
step 21100 loss 4.4711
step 21200 loss 4.3898
step 21200 loss 4.3898
step 21300 loss 4.3745
step 21300 loss 4.3745
step 21400 loss 4.2211
step 21400 loss 4.2211
step 21500 loss 3.7827
step 21500 loss 3.7827
step 21600 loss 3.9383
step 21600 loss 3.9383
step 21700 loss 4.2640
step 21700 loss 4.2640
step 21800 loss 4.1424
step 21800 loss 4.1424
step 21900 loss 4.3557
step 21900 loss 4.3557
step 22000 loss 4.4400
step 22000 loss 4.4400
step 22100 loss 4.4084
step 22100 loss 4.4084
step 22200 loss 4.2986
step 22200 loss 4.2986
step 22300 

Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]

step 30100 loss 3.6207
step 30200 loss 4.3141
step 30200 loss 4.3141
step 30300 loss 3.9717
step 30300 loss 3.9717
step 30400 loss 4.4082
step 30400 loss 4.4082
step 30500 loss 4.3588
step 30500 loss 4.3588
step 30600 loss 4.0782
step 30600 loss 4.0782
step 30700 loss 3.8787
step 30700 loss 3.8787
step 30800 loss 3.6992
step 30800 loss 3.6992
step 30900 loss 3.9431
step 30900 loss 3.9431
step 31000 loss 4.2925
step 31000 loss 4.2925
step 31100 loss 4.2876
step 31100 loss 4.2876
step 31200 loss 4.2356
step 31200 loss 4.2356
step 31300 loss 4.1796
step 31300 loss 4.1796
step 31400 loss 4.0322
step 31400 loss 4.0322
step 31500 loss 3.5973
step 31500 loss 3.5973
step 31600 loss 3.7376
step 31600 loss 3.7376
step 31700 loss 4.0778
step 31700 loss 4.0778
step 31800 loss 3.9697
step 31800 loss 3.9697
step 31900 loss 4.1836
step 31900 loss 4.1836
step 32000 loss 4.2872
step 32000 loss 4.2872
step 32100 loss 4.2368
step 32100 loss 4.2368
step 32200 loss 4.1137
step 32200 loss 4.1137
step 32300 

Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]

step 40100 loss 3.4780
step 40200 loss 4.1811
step 40200 loss 4.1811
step 40300 loss 3.8793
step 40300 loss 3.8793
step 40400 loss 4.2770
step 40400 loss 4.2770
step 40500 loss 4.2328
step 40500 loss 4.2328
step 40600 loss 3.9499
step 40600 loss 3.9499
step 40700 loss 3.7529
step 40700 loss 3.7529
step 40800 loss 3.5976
step 40800 loss 3.5976
step 40900 loss 3.8396
step 40900 loss 3.8396
step 41000 loss 4.1874
step 41000 loss 4.1874
step 41100 loss 4.1573
step 41100 loss 4.1573
step 41200 loss 4.1213
step 41200 loss 4.1213
step 41300 loss 4.0728
step 41300 loss 4.0728
step 41400 loss 3.9332
step 41400 loss 3.9332
step 41500 loss 3.4960
step 41500 loss 3.4960
step 41600 loss 3.6121
step 41600 loss 3.6121
step 41700 loss 3.9728
step 41700 loss 3.9728
step 41800 loss 3.8773
step 41800 loss 3.8773
step 41900 loss 4.0792
step 41900 loss 4.0792
step 42000 loss 4.1740
step 42000 loss 4.1740
step 42100 loss 4.1563
step 42100 loss 4.1563
step 42200 loss 4.0042
step 42200 loss 4.0042
step 42300 

5. Training Transformer

In [7]:
 # Full training (Transformer) using shared utilities
trans_final = train_streamed_lm(
    model=trans,
    tokenizer=tokenizer,
    config=tr_cfg.__dict__,
    ckpt_dir="checkpoints/transformer",
    final_dir="outputs/transformer",
    batch_size=16,
    max_length=256,
    steps_per_epoch=10000,
    num_epochs=5,
    save_every=1,
    lr=3e-4,
)

Training on device: mps
Epoch 1/5
Epoch 1/5


Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]

step 100 loss 10.2686
step 200 loss 9.0030
step 200 loss 9.0030
step 300 loss 7.3714
step 300 loss 7.3714
step 400 loss 7.1246
step 400 loss 7.1246
step 500 loss 7.0368
step 500 loss 7.0368
step 600 loss 7.0333
step 600 loss 7.0333
step 700 loss 7.0521
step 700 loss 7.0521
step 800 loss 6.6674
step 800 loss 6.6674
step 900 loss 6.7038
step 900 loss 6.7038
step 1000 loss 6.9740
step 1000 loss 6.9740
step 1100 loss 6.9789
step 1100 loss 6.9789
step 1200 loss 7.0886
step 1200 loss 7.0886
step 1300 loss 7.0563
step 1300 loss 7.0563
step 1400 loss 7.1430
step 1400 loss 7.1430
step 1500 loss 6.4978
step 1500 loss 6.4978
step 1600 loss 6.8266
step 1600 loss 6.8266
step 1700 loss 6.8633
step 1700 loss 6.8633
step 1800 loss 6.5565
step 1800 loss 6.5565
step 1900 loss 6.4279
step 1900 loss 6.4279
step 2000 loss 6.3318
step 2000 loss 6.3318
step 2100 loss 6.2026
step 2100 loss 6.2026
step 2200 loss 6.0747
step 2200 loss 6.0747
step 2300 loss 5.9528
step 2300 loss 5.9528
step 2400 loss 6.2040
step

Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]

step 10100 loss 4.5663
step 10200 loss 5.2073
step 10200 loss 5.2073
step 10300 loss 4.6707
step 10300 loss 4.6707
step 10400 loss 5.2303
step 10400 loss 5.2303
step 10500 loss 5.1305
step 10500 loss 5.1305
step 10600 loss 4.9499
step 10600 loss 4.9499
step 10700 loss 4.6996
step 10700 loss 4.6996
step 10800 loss 4.4218
step 10800 loss 4.4218
step 10900 loss 4.6387
step 10900 loss 4.6387
step 11000 loss 5.0275
step 11000 loss 5.0275
step 11100 loss 5.0237
step 11100 loss 5.0237
step 11200 loss 4.9280
step 11200 loss 4.9280
step 11300 loss 4.9408
step 11300 loss 4.9408
step 11400 loss 4.7442
step 11400 loss 4.7442
step 11500 loss 4.2394
step 11500 loss 4.2394
step 11600 loss 4.5314
step 11600 loss 4.5314
step 11700 loss 4.7785
step 11700 loss 4.7785
step 11800 loss 4.6663
step 11800 loss 4.6663
step 11900 loss 4.8101
step 11900 loss 4.8101
step 12000 loss 4.9758
step 12000 loss 4.9758
step 12100 loss 4.8926
step 12100 loss 4.8926
step 12200 loss 4.8437
step 12200 loss 4.8437
step 12300 

Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]

'HTTPSConnectionPool(host='cas-bridge.xethub.hf.co', port=443): Read timed out.' thrown while requesting GET https://huggingface.co/datasets/mikex86/stackoverflow-posts/resolve/9e791fe8997879cf127e2a0b006bad3484bbda32/stackoverflow-posts-00000-of-00058.parquet
Retrying in 1s [Retry 1/5].
Retrying in 1s [Retry 1/5].
'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: c9a6aa9e-67f9-4f5b-af2f-077893df2b9a)')' thrown while requesting GET https://huggingface.co/datasets/mikex86/stackoverflow-posts/resolve/9e791fe8997879cf127e2a0b006bad3484bbda32/stackoverflow-posts-00000-of-00058.parquet
Retrying in 2s [Retry 2/5].
'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: c9a6aa9e-67f9-4f5b-af2f-077893df2b9a)')' thrown while requesting GET https://huggingface.co/datasets/mikex86/stackoverflow-posts/resolve/9e791fe8997879cf127e2a0b006bad3484bbda32/stacko

step 20100 loss 4.0195
step 20200 loss 4.6922
step 20200 loss 4.6922
step 20300 loss 4.2175
step 20300 loss 4.2175
step 20400 loss 4.7771
step 20400 loss 4.7771
step 20500 loss 4.6266
step 20500 loss 4.6266
step 20600 loss 4.4314
step 20600 loss 4.4314
step 20700 loss 4.2002
step 20700 loss 4.2002
step 20800 loss 3.9589
step 20800 loss 3.9589
step 20900 loss 4.1798
step 20900 loss 4.1798
step 21000 loss 4.5862
step 21000 loss 4.5862
step 21100 loss 4.5589
step 21100 loss 4.5589
step 21200 loss 4.4931
step 21200 loss 4.4931
step 21300 loss 4.4629
step 21300 loss 4.4629
step 21400 loss 4.3088
step 21400 loss 4.3088
step 21500 loss 3.8033
step 21500 loss 3.8033
step 21600 loss 4.0765
step 21600 loss 4.0765
step 21700 loss 4.3521
step 21700 loss 4.3521
step 21800 loss 4.2316
step 21800 loss 4.2316
step 21900 loss 4.3836
step 21900 loss 4.3836
step 22000 loss 4.5509
step 22000 loss 4.5509
step 22100 loss 4.4803
step 22100 loss 4.4803
step 22200 loss 4.4487
step 22200 loss 4.4487
step 22300 

Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]

'HTTPSConnectionPool(host='cas-bridge.xethub.hf.co', port=443): Read timed out.' thrown while requesting GET https://huggingface.co/datasets/mikex86/stackoverflow-posts/resolve/9e791fe8997879cf127e2a0b006bad3484bbda32/stackoverflow-posts-00000-of-00058.parquet
Retrying in 1s [Retry 1/5].
Retrying in 1s [Retry 1/5].
'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 2d781aea-3d25-4cc9-9abd-9adec03051e4)')' thrown while requesting GET https://huggingface.co/datasets/mikex86/stackoverflow-posts/resolve/9e791fe8997879cf127e2a0b006bad3484bbda32/stackoverflow-posts-00000-of-00058.parquet
Retrying in 2s [Retry 2/5].
'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: 2d781aea-3d25-4cc9-9abd-9adec03051e4)')' thrown while requesting GET https://huggingface.co/datasets/mikex86/stackoverflow-posts/resolve/9e791fe8997879cf127e2a0b006bad3484bbda32/stacko

step 30100 loss 3.6896
step 30200 loss 4.3520
step 30200 loss 4.3520
step 30300 loss 3.9746
step 30300 loss 3.9746
step 30400 loss 4.4095
step 30400 loss 4.4095
step 30500 loss 4.2988
step 30500 loss 4.2988
step 30600 loss 4.0426
step 30600 loss 4.0426
step 30700 loss 3.8768
step 30700 loss 3.8768
step 30800 loss 3.6687
step 30800 loss 3.6687
step 30900 loss 3.9113
step 30900 loss 3.9113
step 31000 loss 4.3490
step 31000 loss 4.3490
step 31100 loss 4.3274
step 31100 loss 4.3274
step 31200 loss 4.2079
step 31200 loss 4.2079
step 31300 loss 4.1481
step 31300 loss 4.1481
step 31400 loss 4.0659
step 31400 loss 4.0659
step 31500 loss 3.5830
step 31500 loss 3.5830
step 31600 loss 3.7826
step 31600 loss 3.7826
step 31700 loss 4.0555
step 31700 loss 4.0555
step 31800 loss 3.9577
step 31800 loss 3.9577
step 31900 loss 4.1347
step 31900 loss 4.1347
step 32000 loss 4.3105
step 32000 loss 4.3105
step 32100 loss 4.1993
step 32100 loss 4.1993
step 32200 loss 4.1876
step 32200 loss 4.1876
step 32300 

Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]

step 40100 loss 3.5020
step 40200 loss 4.1087
step 40200 loss 4.1087
step 40300 loss 3.8078
step 40300 loss 3.8078
step 40400 loss 4.1644
step 40400 loss 4.1644
step 40500 loss 3.9975
step 40500 loss 3.9975
step 40600 loss 3.8347
step 40600 loss 3.8347
step 40700 loss 3.6256
step 40700 loss 3.6256
step 40800 loss 3.5484
step 40800 loss 3.5484
step 40900 loss 3.7593
step 40900 loss 3.7593
step 41000 loss 4.1752
step 41000 loss 4.1752
step 41100 loss 4.1405
step 41100 loss 4.1405
step 41200 loss 4.0305
step 41200 loss 4.0305
step 41300 loss 3.9065
step 41300 loss 3.9065
step 41400 loss 3.8749
step 41400 loss 3.8749
step 41500 loss 3.3885
step 41500 loss 3.3885
step 41600 loss 3.5791
step 41600 loss 3.5791
step 41700 loss 3.8986
step 41700 loss 3.8986
step 41800 loss 3.7314
step 41800 loss 3.7314
step 41900 loss 3.9423
step 41900 loss 3.9423
step 42000 loss 4.1004
step 42000 loss 4.1004
step 42100 loss 4.0363
step 42100 loss 4.0363
step 42200 loss 4.0230
step 42200 loss 4.0230
step 42300 

5. Inference and Benchmarking

In [8]:
# Inference: quick generation using shared utility
prompt = "Explain what a Python generator is and provide a short example."
print("LSTM ->\n", generate_text(lstm, tokenizer, prompt, max_new_tokens=80))
print("\nTransformer ->\n", generate_text(trans, tokenizer, prompt, max_new_tokens=80))

LSTM ->
 code from using a time, you just to this can a line by 3, you [, I wrote a web. I have a previous page, an by text is by: is the time in a 2 line and have a one method (2, a.c.0.0, 2, for...,, on, -, and, it, a text is not to the old

Transformer ->
   It is very powerful for this purpose.  The reason that I'm aware of are in the past of what you're going to be able to do. 

I have seen a tool that is the best way for creating some Python functions for writing a library to emulate that.  Even if it's a lot of C and C++ code then I can understand it, because of the

Transformer ->
   It is very powerful for this purpose.  The reason that I'm aware of are in the past of what you're going to be able to do. 

I have seen a tool that is the best way for creating some Python functions for writing a library to emulate that.  Even if it's a lot of C and C++ code then I can understand it, because of the


6. Comparison

In [9]:
# Performance comparison using shared utility
lstm_tok_s = measure_throughput(lstm, tokenizer.vocab_size)
trans_tok_s = measure_throughput(trans, tokenizer.vocab_size)
print({"lstm_tok_per_s": lstm_tok_s, "transformer_tok_per_s": trans_tok_s})

{'lstm_tok_per_s': 42405.42756745674, 'transformer_tok_per_s': 69119.05862183005}


7. Test-set evaluation (perplexity)

In [11]:
# Evaluate both final models on the dataset test split (fallback to test.txt)
# Load the best available weights if final paths exist, else use in-memory models
try:
    lstm.load_state_dict(torch.load("outputs/lstm/final.pt", map_location="cpu"))
except Exception:
    pass
try:
    trans.load_state_dict(torch.load("outputs/transformer/final.pt", map_location="cpu"))
except Exception:
    pass

# Select best available device (CUDA > MPS > CPU)
device = best_device()
print("Using device:", device)
lstm.to(device).eval()
trans.to(device).eval()

lstm_ppl, lstm_loss, lstm_tokens, src1 = evaluate_on_hf_or_file(
    lstm, tokenizer, hf_split="test", fallback_path="test.txt", max_examples=5000, batch_size=16, max_length=256
)
trans_ppl, trans_loss, trans_tokens, src2 = evaluate_on_hf_or_file(
    trans, tokenizer, hf_split="test", fallback_path="test.txt", max_examples=5000, batch_size=16, max_length=256
)

print({
    "device": str(device),
    "source_lstm": src1,
    "lstm_ppl": lstm_ppl,
    "lstm_avg_loss": lstm_loss,
    "lstm_tokens": lstm_tokens,
    "source_trans": src2,
    "trans_ppl": trans_ppl,
    "trans_avg_loss": trans_loss,
    "trans_tokens": trans_tokens,
})

Using device: mps


Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/59 [00:00<?, ?it/s]

{'device': 'mps', 'source_lstm': 'file:test.txt', 'lstm_ppl': 627.7102574173653, 'lstm_avg_loss': 6.442078686463452, 'lstm_tokens': 407, 'source_trans': 'file:test.txt', 'trans_ppl': 149.8652125502354, 'trans_avg_loss': 5.0097363071301055, 'trans_tokens': 407}
