# EN-AR Model Training

## Objective
- Train a bidirectional EN <-> AR encoder-decoder model from random initialization.
- Use the cleaned combined dataset exported by Notebook 01.

## Scope
- Notebook-first, micro-step implementation (1-2 short cells per step).


In [1]:
# Setup: imports and training constants
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd
import torch

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

candidate_roots = [Path.cwd(), Path.cwd().parent]
PROJECT_ROOT = next((r for r in candidate_roots if (r / "artifacts").exists()), Path.cwd())
DATA_PATH = PROJECT_ROOT / "artifacts" / "eda" / "final_cleaned_combined_dataset.parquet"

@dataclass
class TrainConfig:
    max_seq_len: int = 128
    vocab_size: int = 32_000
    train_ratio: float = 0.90
    val_ratio: float = 0.05
    test_ratio: float = 0.05

config = TrainConfig()
print(f"Project root: {PROJECT_ROOT}")
print(f"Dataset path: {DATA_PATH}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(config)


Project root: c:\My Projects\en-ar-translation
Dataset path: c:\My Projects\en-ar-translation\artifacts\eda\final_cleaned_combined_dataset.parquet
CUDA available: True
TrainConfig(max_seq_len=128, vocab_size=32000, train_ratio=0.9, val_ratio=0.05, test_ratio=0.05)


In [2]:
# Micro-step 2: load cleaned dataset and validate schema
assert DATA_PATH.exists(), f"Cleaned dataset not found: {DATA_PATH}"
df = pd.read_parquet(DATA_PATH)

required_columns = ["en", "ar"]
missing = [c for c in required_columns if c not in df.columns]
assert not missing, f"Missing required columns: {missing}"

print(f"Loaded dataset: {DATA_PATH}")
print(f"Shape: {df.shape}")
print(f"Columns: {list(df.columns)}")


Loaded dataset: c:\My Projects\en-ar-translation\artifacts\eda\final_cleaned_combined_dataset.parquet
Shape: (827576, 2)
Columns: ['en', 'ar']


In [3]:
# Micro-step 3: remove invalid/empty rows and report before/after
rows_before = len(df)

df = df.dropna(subset=["en", "ar"]).copy()
df["en"] = df["en"].astype(str).str.strip()
df["ar"] = df["ar"].astype(str).str.strip()
df = df[(df["en"] != "") & (df["ar"] != "")].reset_index(drop=True)

rows_after = len(df)
rows_removed = rows_before - rows_after
print(f"Rows before cleaning: {rows_before:,}")
print(f"Rows after cleaning: {rows_after:,}")
print(f"Rows removed: {rows_removed:,}")


Rows before cleaning: 827,576
Rows after cleaning: 827,546
Rows removed: 30


In [13]:
# Quick check: 10 random rows from current dataset
df.sample(n=10)[["en", "ar"]].reset_index(drop=True)

Unnamed: 0,en,ar
0,he was killed in the plane crash,.. مات في حادثة تحطم الطائرة
1,where did you get it,من أين حصلتم عليه ؟
2,an illustration of a gorilla flying through th...,« مثل غوريلا تطير في الهواء وخلفه سماء المدينة ».
3,He narrowly lost the final to Dieter Baumann o...,خسر المباراة بفارق ضئيل أمام ديتر بومان من ألم...
4,oh i see what you mean sir,فهمت ما تقصد
5,It was the last album recorded by the group be...,كان هذا آخر ألبوم قامت المجموعة بتسجيله قبل تف...
6,Now you can not worry that the Christmas tree ...,الآن لا يمكنك أن تقلق بشأن شجرة عيد الميلاد لت...
7,former world champions,بطولات وطنية سابقا.
8,Third-party tools enable one to build a variet...,أدوات الطرف الثالث تمكن المرء من بناء مجموعة م...
9,not a one,ولا واحد


In [14]:
# Micro-step 4: deterministic hash split (90/5/5) with leakage guard
assert abs((config.train_ratio + config.val_ratio + config.test_ratio) - 1.0) < 1e-9, "Split ratios must sum to 1.0"

pair_hash = pd.util.hash_pandas_object(df[["en", "ar"]], index=False).astype("uint64")
u = pair_hash / np.float64(2**64)

train_cut = config.train_ratio
val_cut = config.train_ratio + config.val_ratio
df["split"] = np.where(u < train_cut, "train", np.where(u < val_cut, "val", "test"))

leak_count = int((df.groupby(["en", "ar"])["split"].nunique() > 1).sum())
assert leak_count == 0, f"Leakage detected across splits for {leak_count} pairs"

split_counts = df["split"].value_counts().rename_axis("split").reset_index(name="rows")
split_counts["ratio"] = (split_counts["rows"] / len(df)).round(4)
split_counts


Unnamed: 0,split,rows,ratio
0,train,744926,0.9002
1,val,41445,0.0501
2,test,41175,0.0498


In [15]:
# Micro-step 5: build bidirectional rows with direction tokens
required_split_cols = ["en", "ar", "split"]
missing_split_cols = [c for c in required_split_cols if c not in df.columns]
assert not missing_split_cols, f"Missing columns before bidirectional build: {missing_split_cols}"

df_en_to_ar = pd.DataFrame({
    "source_text": "<2ar> " + df["en"],
    "target_text": df["ar"],
    "direction": "en_to_ar",
    "split": df["split"],
})

df_ar_to_en = pd.DataFrame({
    "source_text": "<2en> " + df["ar"],
    "target_text": df["en"],
    "direction": "ar_to_en",
    "split": df["split"],
})

df_bi = pd.concat([df_en_to_ar, df_ar_to_en], ignore_index=True)

print(f"Base rows: {len(df):,}")
print(f"Bidirectional rows: {len(df_bi):,}")

direction_split_counts = (
    df_bi.groupby(["split", "direction"]).size().reset_index(name="rows")
)
direction_split_counts


Base rows: 827,546
Bidirectional rows: 1,655,092


Unnamed: 0,split,direction,rows
0,test,ar_to_en,41175
1,test,en_to_ar,41175
2,train,ar_to_en,744926
3,train,en_to_ar,744926
4,val,ar_to_en,41445
5,val,en_to_ar,41445


In [21]:
# Micro-step 6: create train/val/test views from bidirectional dataset
required_bi_cols = ["source_text", "target_text", "direction", "split"]
missing_bi_cols = [c for c in required_bi_cols if c not in df_bi.columns]
assert not missing_bi_cols, f"Missing columns in df_bi: {missing_bi_cols}"

train_df = df_bi[df_bi["split"] == "train"].reset_index(drop=True)
val_df = df_bi[df_bi["split"] == "val"].reset_index(drop=True)
test_df = df_bi[df_bi["split"] == "test"].reset_index(drop=True)

assert len(train_df) + len(val_df) + len(test_df) == len(df_bi), "Split size mismatch"
assert len(train_df) > 0 and len(val_df) > 0 and len(test_df) > 0, "One split is empty"

print(f"train rows: {len(train_df):,}")
print(f"val rows: {len(val_df):,}")
print(f"test rows: {len(test_df):,}")


train rows: 1,489,852
val rows: 82,890
test rows: 82,350


In [23]:
# Micro-step 7: tokenizer setup (train split only)
try:
    from tokenizers import Tokenizer, decoders, models, pre_tokenizers, trainers
except ImportError as e:
    raise ImportError("`tokenizers` is required. Install with: pip install tokenizers") from e

TOKENIZER_DIR = PROJECT_ROOT / "artifacts" / "tokenizer"
TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)
TOKENIZER_PATH = TOKENIZER_DIR / "en_ar_bpe_tokenizer.json"

SPECIAL_TOKENS = ["<pad>", "<s>", "</s>", "<unk>", "<2ar>", "<2en>"]

def train_corpus_iterator(df_in):
    for row in df_in.itertuples(index=False):
        yield row.source_text
        yield row.target_text

print(f"Tokenizer output path: {TOKENIZER_PATH}")
print(f"Special tokens: {SPECIAL_TOKENS}")
print("Tokenizer mode: ByteLevel BPE")
print(f"Train rows for tokenizer: {len(train_df):,}")
print(f"Approx lines seen by tokenizer iterator: {len(train_df) * 2:,}")


Tokenizer output path: c:\My Projects\en-ar-translation\artifacts\tokenizer\en_ar_bpe_tokenizer.json
Special tokens: ['<pad>', '<s>', '</s>', '<unk>', '<2ar>', '<2en>']
Tokenizer mode: ByteLevel BPE
Train rows for tokenizer: 1,489,852
Approx lines seen by tokenizer iterator: 2,979,704


In [24]:
# Micro-step 8: train and save shared ByteLevel BPE tokenizer
tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tokenizer.decoder = decoders.ByteLevel()

trainer = trainers.BpeTrainer(
    vocab_size=config.vocab_size,
    special_tokens=SPECIAL_TOKENS,
    min_frequency=2,
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
)

tokenizer.train_from_iterator(train_corpus_iterator(train_df), trainer=trainer)
tokenizer.save(str(TOKENIZER_PATH))

vocab_size_trained = tokenizer.get_vocab_size()
print(f"Trained tokenizer vocab size: {vocab_size_trained:,}")
print(f"Saved tokenizer to: {TOKENIZER_PATH}")


Trained tokenizer vocab size: 32,000
Saved tokenizer to: c:\My Projects\en-ar-translation\artifacts\tokenizer\en_ar_bpe_tokenizer.json


In [26]:
# Micro-step 9: tokenizer audit (critical checks)
from collections import Counter

# Ensure tokenizer object is available (reload from disk if needed)
if "tokenizer" not in globals():
    tokenizer = Tokenizer.from_file(str(TOKENIZER_PATH))

# 1) Special-token integrity
special_token_ids = {tok: tokenizer.token_to_id(tok) for tok in SPECIAL_TOKENS}
missing_special = [tok for tok, tid in special_token_ids.items() if tid is None]
assert not missing_special, f"Missing special tokens in tokenizer vocab: {missing_special}"

id_2ar = special_token_ids["<2ar>"]
id_2en = special_token_ids["<2en>"]
probe_2ar = tokenizer.encode("<2ar> this is a test")
probe_2en = tokenizer.encode("<2en> english direction probe")
assert len(probe_2ar.ids) > 0 and probe_2ar.ids[0] == id_2ar, "<2ar> is not preserved as first token"
assert len(probe_2en.ids) > 0 and probe_2en.ids[0] == id_2en, "<2en> is not preserved as first token"

# 2) Sample-based token-length and truncation audit (real tokenizer lengths)
audit_sample_size = min(20000, len(df_bi))
audit_df = df_bi.sample(n=audit_sample_size, random_state=RANDOM_SEED) if len(df_bi) > audit_sample_size else df_bi

src_lengths = []
tgt_lengths = []
unk_counter = Counter()
unk_id = special_token_ids["<unk>"]

for row in audit_df.itertuples(index=False):
    src_ids = tokenizer.encode(row.source_text).ids
    tgt_ids = tokenizer.encode(row.target_text).ids
    src_lengths.append(len(src_ids))
    tgt_lengths.append(len(tgt_ids))
    unk_counter["src_unk_tokens"] += sum(1 for i in src_ids if i == unk_id)
    unk_counter["tgt_unk_tokens"] += sum(1 for i in tgt_ids if i == unk_id)

src_lengths = np.array(src_lengths, dtype=np.int32)
tgt_lengths = np.array(tgt_lengths, dtype=np.int32)

tokenizer_audit = pd.DataFrame({
    "metric": [
        "sample_rows",
        "trained_vocab_size",
        "src_len_p50", "src_len_p90", "src_len_p95", "src_len_p99",
        "tgt_len_p50", "tgt_len_p90", "tgt_len_p95", "tgt_len_p99",
        "src_trunc_ratio_gt_max_len",
        "tgt_trunc_ratio_gt_max_len",
        "either_trunc_ratio_gt_max_len",
        "src_unk_tokens",
        "tgt_unk_tokens",
    ],
    "value": [
        int(len(audit_df)),
        int(tokenizer.get_vocab_size()),
        int(np.percentile(src_lengths, 50)), int(np.percentile(src_lengths, 90)), int(np.percentile(src_lengths, 95)), int(np.percentile(src_lengths, 99)),
        int(np.percentile(tgt_lengths, 50)), int(np.percentile(tgt_lengths, 90)), int(np.percentile(tgt_lengths, 95)), int(np.percentile(tgt_lengths, 99)),
        float((src_lengths > config.max_seq_len).mean()),
        float((tgt_lengths > config.max_seq_len).mean()),
        float(((src_lengths > config.max_seq_len) | (tgt_lengths > config.max_seq_len)).mean()),
        int(unk_counter["src_unk_tokens"]),
        int(unk_counter["tgt_unk_tokens"]),
    ]
})

print("Special token IDs:", special_token_ids)
print("Direction token probes OK (<2ar>/<2en> preserved).")
tokenizer_audit


Special token IDs: {'<pad>': 0, '<s>': 1, '</s>': 2, '<unk>': 3, '<2ar>': 4, '<2en>': 5}
Direction token probes OK (<2ar>/<2en> preserved).


Unnamed: 0,metric,value
0,sample_rows,20000.0
1,trained_vocab_size,32000.0
2,src_len_p50,9.0
3,src_len_p90,18.0
4,src_len_p95,21.0
5,src_len_p99,28.0
6,tgt_len_p50,8.0
7,tgt_len_p90,17.0
8,tgt_len_p95,20.0
9,tgt_len_p99,28.0


In [None]:
# Micro-step 10: build and save Hugging Face compatible fast tokenizer
from transformers import PreTrainedTokenizerFast

HF_TOKENIZER_DIR = TOKENIZER_DIR / "hf_tokenizer"
HF_TOKENIZER_DIR.mkdir(parents=True, exist_ok=True)

hf_tokenizer = PreTrainedTokenizerFast(
    tokenizer_file=str(TOKENIZER_PATH),
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
    additional_special_tokens=["<2ar>", "<2en>"],
)

# Verify essential token ids exist
required_tokens = ["<pad>", "<s>", "</s>", "<unk>", "<2ar>", "<2en>"]
token_id_map = {tok: hf_tokenizer.convert_tokens_to_ids(tok) for tok in required_tokens}
missing_ids = [tok for tok, tid in token_id_map.items() if tid is None or tid < 0]
assert not missing_ids, f"Missing token ids in hf_tokenizer: {missing_ids}"

hf_tokenizer.save_pretrained(str(HF_TOKENIZER_DIR))

print(f"Saved HF tokenizer to: {HF_TOKENIZER_DIR}")
print("Token IDs:", token_id_map)
print(f"HF vocab size: {hf_tokenizer.vocab_size:,}")


Saved HF tokenizer to: c:\My Projects\en-ar-translation\artifacts\tokenizer\hf_tokenizer
Token IDs: {'<pad>': 0, '<s>': 1, '</s>': 2, '<unk>': 3, '<2ar>': 4, '<2en>': 5}
HF vocab size: 32,000


In [28]:
# Micro-step 11: tokenize train/val/test (truncation only, no static padding)
from datasets import Dataset

def tokenize_batch(batch):
    src = hf_tokenizer(
        batch["source_text"],
        truncation=True,
        max_length=config.max_seq_len,
        padding=False,
    )
    tgt = hf_tokenizer(
        batch["target_text"],
        truncation=True,
        max_length=config.max_seq_len,
        padding=False,
    )
    src["labels"] = tgt["input_ids"]
    return src

train_ds = Dataset.from_pandas(train_df[["source_text", "target_text", "direction", "split"]], preserve_index=False)
val_ds = Dataset.from_pandas(val_df[["source_text", "target_text", "direction", "split"]], preserve_index=False)
test_ds = Dataset.from_pandas(test_df[["source_text", "target_text", "direction", "split"]], preserve_index=False)

train_tok = train_ds.map(tokenize_batch, batched=True, desc="Tokenizing train")
val_tok = val_ds.map(tokenize_batch, batched=True, desc="Tokenizing val")
test_tok = test_ds.map(tokenize_batch, batched=True, desc="Tokenizing test")

print(f"Tokenized train rows: {len(train_tok):,}")
print(f"Tokenized val rows: {len(val_tok):,}")
print(f"Tokenized test rows: {len(test_tok):,}")
print(train_tok[0].keys())


Tokenizing train:   0%|          | 0/1489852 [00:00<?, ? examples/s]

Tokenizing val:   0%|          | 0/82890 [00:00<?, ? examples/s]

Tokenizing test:   0%|          | 0/82350 [00:00<?, ? examples/s]

Tokenized train rows: 1,489,852
Tokenized val rows: 82,890
Tokenized test rows: 82,350
dict_keys(['source_text', 'target_text', 'direction', 'split', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'])


In [32]:
# Micro-step 12: dynamic padding collator + one-batch sanity check
from transformers import DataCollatorForSeq2Seq

model_input_cols = ["input_ids", "attention_mask", "labels"]
train_tok_model = train_tok.remove_columns([c for c in train_tok.column_names if c not in model_input_cols])

data_collator = DataCollatorForSeq2Seq(
    tokenizer=hf_tokenizer,
    model=None,
    padding="longest",
    label_pad_token_id=-100,
)

batch_size_probe = min(32, len(train_tok_model))
probe_ds = train_tok_model.shuffle(seed=RANDOM_SEED).select(range(batch_size_probe))
probe_features = [probe_ds[i] for i in range(len(probe_ds))]
probe_batch = data_collator(probe_features)

print(f"Probe batch size: {batch_size_probe}")
print(f"input_ids shape: {tuple(probe_batch['input_ids'].shape)}")
print(f"attention_mask shape: {tuple(probe_batch['attention_mask'].shape)}")
print(f"labels shape: {tuple(probe_batch['labels'].shape)}")
print(f"Dynamic padded source length (batch max): {probe_batch['input_ids'].shape[1]}")
print(f"Dynamic padded target length (batch max): {probe_batch['labels'].shape[1]}")
print(f"Configured truncation cap: {config.max_seq_len}")


Probe batch size: 32
input_ids shape: (32, 74)
attention_mask shape: (32, 74)
labels shape: (32, 37)
Dynamic padded source length (batch max): 74
Dynamic padded target length (batch max): 37
Configured truncation cap: 128


In [38]:
# Micro-step 13: manual audit of random final tokenized examples
sample_n = min(5, len(train_tok))
audit_ds = train_tok.shuffle().select(range(sample_n))

audit_rows = []
for item in audit_ds:
    src_ids = item["input_ids"]
    lbl_ids = item["labels"]
    audit_rows.append({
        "direction": item.get("direction", ""),
        "source_text": item["source_text"],
        "target_text": item["target_text"],
        "source_len": len(src_ids),
        "target_len": len(lbl_ids),
        "source_ids": src_ids,
        "target_ids": lbl_ids,
        "decoded_source_from_ids": hf_tokenizer.decode(src_ids, skip_special_tokens=False),
        "decoded_target_from_ids": hf_tokenizer.decode(lbl_ids, skip_special_tokens=False),
    })

tokenized_audit_df = pd.DataFrame(audit_rows)
tokenized_audit_df


Unnamed: 0,direction,source_text,target_text,source_len,target_len,source_ids,target_ids,decoded_source_from_ids,decoded_target_from_ids
0,ar_to_en,<2en> وقد أصيب جمهور نيويورك بالصدمة ولكنهم ما...,New York audiences were shocked but still atte...,20,15,"[5, 3113, 15546, 11104, 5004, 457, 2635, 596, ...","[11837, 4674, 19580, 3875, 690, 25314, 883, 17...",<2en> وقد أصيب جمهور نيويورك بالصدمة ولكنهم ما...,New York audiences were shocked but still atte...
1,ar_to_en,<2en> أظن أن هذا ما يهم توقعتم العكس,you know im gonna just assume thats implied,9,9,"[5, 3755, 404, 504, 546, 7117, 20641, 606, 18926]","[781, 673, 587, 1354, 798, 26047, 1241, 14061,...",<2en> أظن أن هذا ما يهم توقعتم العكس,you know im gonna just assume thats implied
2,ar_to_en,<2en> مرحبا (بيل ),hey bill,5,2,"[5, 3295, 564, 3674, 2153]","[2715, 5751]",<2en> مرحبا (بيل ),hey bill
3,en_to_ar,<2ar> He also helped the girl to remember her ...,كما ساعد الفتاة على تذكر اسمها الحقيقي: إميلي ...,18,14,"[4, 1050, 778, 8449, 295, 1217, 342, 3660, 693...","[1941, 6114, 5954, 378, 7557, 8747, 9087, 31, ...",<2ar> He also helped the girl to remember her ...,كما ساعد الفتاة على تذكر اسمها الحقيقي: إميلي ...
4,en_to_ar,<2ar> theyre even in his butt crack,وهم حتى في بلده بعقب الكراك.,7,10,"[4, 3530, 1850, 326, 632, 15474, 12257]","[17421, 1658, 325, 1387, 2412, 3444, 1744, 563...",<2ar> theyre even in his butt crack,وهم حتى في بلده بعقب الكراك.


In [40]:
for item in audit_ds:
    print("=" * 60)
    print(f"Direction : {item.get('direction', '')}")
    print(f"Source    : {item['source_text']}")
    print(f"Target    : {item['target_text']}")
    print(f"Input IDs : {item['input_ids']}")
    print(f"Labels    : {item['labels']}")
    print(f"Attn Mask : {item['attention_mask']}")
    print(f"Decoded S : {hf_tokenizer.decode(item['input_ids'], skip_special_tokens=False)}")
    print(f"Decoded T : {hf_tokenizer.decode(item['labels'], skip_special_tokens=False)}")
print("=" * 60)

Direction : ar_to_en
Source    : <2en> وقد أصيب جمهور نيويورك بالصدمة ولكنهم ما زالوا يحضرون المسرحية وجعلوا العرض مشهورا.
Target    : New York audiences were shocked but still attended and made the play popular.
Input IDs : [5, 3113, 15546, 11104, 5004, 457, 2635, 596, 29822, 546, 9927, 523, 13691, 13528, 1081, 576, 523, 2886, 24474, 19]
Labels    : [11837, 4674, 19580, 3875, 690, 25314, 883, 1762, 5712, 355, 1502, 295, 848, 3777, 19]
Attn Mask : [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Decoded S : <2en> وقد أصيب جمهور نيويورك بالصدمة ولكنهم ما زالوا يحضرون المسرحية وجعلوا العرض مشهورا.
Decoded T : New York audiences were shocked but still attended and made the play popular.
Direction : ar_to_en
Source    : <2en> أظن أن هذا ما يهم توقعتم العكس
Target    : you know im gonna just assume thats implied
Input IDs : [5, 3755, 404, 504, 546, 7117, 20641, 606, 18926]
Labels    : [781, 673, 587, 1354, 798, 26047, 1241, 14061, 1319]
Attn Mask : [1, 1, 1, 1, 1, 1, 1, 1, 1]
De

In [41]:
# Micro-step 14: define and instantiate random-init BART model
from transformers import BartConfig, BartForConditionalGeneration

bart_config = BartConfig(
    vocab_size=hf_tokenizer.vocab_size,
    max_position_embeddings=config.max_seq_len + 2,
    d_model=512,
    encoder_layers=6,
    decoder_layers=6,
    encoder_attention_heads=8,
    decoder_attention_heads=8,
    encoder_ffn_dim=2048,
    decoder_ffn_dim=2048,
    dropout=0.1, # Main residual dropout (applied after attention & FFN blocks)
    attention_dropout=0.1, # Dropout applied to attention probabilities
    activation_dropout=0.0, # Dropout after FFN activation (kept 0 for stability)
    pad_token_id=hf_tokenizer.pad_token_id,
    bos_token_id=hf_tokenizer.bos_token_id,
    eos_token_id=hf_tokenizer.eos_token_id,
    decoder_start_token_id=hf_tokenizer.bos_token_id, # the decoder will start with <bos> to predict the first token
)

model = BartForConditionalGeneration(bart_config)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Initialized random BART model (no pretrained weights).")
print(f"Vocab size: {bart_config.vocab_size:,}")
print(f"d_model: {bart_config.d_model}, enc_layers: {bart_config.encoder_layers}, dec_layers: {bart_config.decoder_layers}")
print(f"Total params: {total_params:,}")
print(f"Trainable params: {trainable_params:,}")


Initialized random BART model (no pretrained weights).
Vocab size: 32,000
d_model: 512, enc_layers: 6, dec_layers: 6
Total params: 60,659,712
Trainable params: 60,659,712
