In [None]:
!pip install -U transformers datasets evaluate accelerate sacrebleu --quiet

In [None]:
from datasets import Dataset

def _read_stripped_lines(path):
    """Return every line of ``path`` with surrounding whitespace stripped.

    Blank lines inside the file are kept (as empty strings) so that a line's
    position still identifies its counterpart in the other language file.
    Only trailing blank lines are dropped, so an extra newline at the end of
    one file does not count as a length mismatch.
    """
    with open(path, encoding="utf-8") as handle:
        lines = [line.strip() for line in handle]
    while lines and not lines[-1]:
        lines.pop()
    return lines


def _load_aligned_pairs(en_path, gn_path):
    """Read two line-aligned files and return ``{"en": ..., "gn": ...}`` rows.

    Pairs are matched by line position. A pair is discarded when either side
    is blank; filtering each file on its own would shift every later line and
    silently pair sentences with the wrong translation.

    Raises:
        ValueError: if the two files do not contain the same number of lines.
    """
    en_lines = _read_stripped_lines(en_path)
    gn_lines = _read_stripped_lines(gn_path)
    if len(en_lines) != len(gn_lines):
        raise ValueError(f"Line mismatch: {len(en_lines)} English vs {len(gn_lines)} Guarani.")
    return [{"en": en, "gn": gn} for en, gn in zip(en_lines, gn_lines) if en and gn]


def load_parallel_text_data(en_path, gn_path):
    """Load a line-aligned English/Guarani corpus as a ``Dataset``.

    Args:
        en_path: Path to the UTF-8 English file, one sentence per line.
        gn_path: Path to the UTF-8 Guarani file, one sentence per line.

    Returns:
        A ``Dataset`` with string columns ``"en"`` and ``"gn"``; pairs with a
        blank side are omitted.

    Raises:
        ValueError: if the two files do not contain the same number of lines.
    """
    return Dataset.from_list(_load_aligned_pairs(en_path, gn_path))


def load_flores_for_bleu(english_path, guarani_path):
    """Load the line-aligned evaluation files as a ``Dataset``.

    Same contract as ``load_parallel_text_data``; kept as a separate entry
    point so existing callers and keyword names continue to work.

    Args:
        english_path: Path to the UTF-8 English file, one sentence per line.
        guarani_path: Path to the UTF-8 Guarani file, one sentence per line.

    Returns:
        A ``Dataset`` with string columns ``"en"`` and ``"gn"``.

    Raises:
        ValueError: if the two files do not contain the same number of lines.
    """
    return Dataset.from_list(_load_aligned_pairs(english_path, guarani_path))

# Folder that holds the corpus files (and, later, the model output directory).
base_path = "CompLingBen"

# Training pair: English / Guarani sides of the parallel corpus.
train_en_path = f"{base_path}/NLLB.en-gn.en"
train_gn_path = f"{base_path}/NLLB.en-gn.gn"

# Evaluation pair: English / Guarani devtest files.
en_devtest_path = f"{base_path}/eng_Latn.devtest"
gn_devtest_path = f"{base_path}/grn_Latn.devtest"

train_dataset = load_parallel_text_data(en_path=train_en_path, gn_path=train_gn_path)
test_dataset = load_flores_for_bleu(
    english_path=en_devtest_path,
    guarani_path=gn_devtest_path,
)

# Canonical glottal-stop character: U+02BC MODIFIER LETTER APOSTROPHE.
_GLOTTAL_STOP = "\u02bc"

# Characters commonly typed in place of the glottal stop: ASCII apostrophe,
# right/left single quotation marks, backtick and spacing acute accent.
_GLOTTAL_STOP_TABLE = str.maketrans(
    {variant: _GLOTTAL_STOP for variant in "'\u2019\u2018`\u00b4"}
)


def normalize_glottal_stop(text):
    """Rewrite apostrophe look-alikes in ``text`` to U+02BC.

    A backslash-escaped ASCII apostrophe (``\\'``) is collapsed to a single
    U+02BC first, so the backslash does not survive. Every remaining ASCII
    apostrophe, typographic single quote (U+2018/U+2019), backtick or spacing
    acute accent is then mapped to U+02BC, so one sound has one encoding.

    Args:
        text: Input string.

    Returns:
        The normalized string; text without any of these characters is
        returned unchanged.
    """
    return text.replace("\\'", _GLOTTAL_STOP).translate(_GLOTTAL_STOP_TABLE)

def normalize_dataset(dataset):
    """Return ``dataset`` with glottal stops normalized in the ``"gn"`` column.

    The ``"en"`` column is passed through untouched; each ``"gn"`` value is
    run through ``normalize_glottal_stop``.
    """
    def _normalize_row(row):
        english = row["en"]
        guarani = normalize_glottal_stop(row["gn"])
        return {"en": english, "gn": guarani}

    return dataset.map(_normalize_row)

# Upper bounds on how many examples are kept from each split, and the seed
# that makes the random training subset reproducible across kernel restarts.
MAX_TRAIN_EXAMPLES = 10000
MAX_TEST_EXAMPLES = 500
SHUFFLE_SEED = 42

train_dataset = normalize_dataset(train_dataset)
test_dataset = normalize_dataset(test_dataset)

# Seeded shuffle, then clamp the subset size so a corpus smaller than the cap
# is used in full instead of making select() fail on out-of-range indices.
train_dataset = train_dataset.shuffle(seed=SHUFFLE_SEED).select(
    range(min(MAX_TRAIN_EXAMPLES, len(train_dataset)))
)
test_dataset = test_dataset.select(range(min(MAX_TEST_EXAMPLES, len(test_dataset))))

In [None]:
from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
import torch

# Checkpoint names: the encoder checkpoint also supplies the tokenizer.
ENCODER_CHECKPOINT = "bert-base-multilingual-cased"
DECODER_CHECKPOINT = "gpt2"

# Tokenizer used for the notebook's text-to-ids conversion.
tokenizer = BertTokenizer.from_pretrained(ENCODER_CHECKPOINT)

# Encoder-decoder model assembled from the two pretrained checkpoints.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    ENCODER_CHECKPOINT,
    DECODER_CHECKPOINT,
)

# Special-token wiring for seq2seq: decoding starts at the tokenizer's [CLS],
# pads with its pad token and stops at its [SEP].
seq2seq_config = model.config
seq2seq_config.decoder_start_token_id = tokenizer.cls_token_id
seq2seq_config.pad_token_id = tokenizer.pad_token_id
seq2seq_config.eos_token_id = tokenizer.sep_token_id

# Resize the decoder's token embeddings to the tokenizer's vocabulary size.
model.decoder.resize_token_embeddings(len(tokenizer))

In [None]:
def preprocess_function(examples):
    """Tokenize a batch of sentence pairs into model inputs plus labels.

    Args:
        examples: Batch mapping with ``"en"`` (source) and ``"gn"`` (target)
            entries.

    Returns:
        The tokenizer output for the ``"en"`` texts with an added ``"labels"``
        entry holding the ``input_ids`` of the ``"gn"`` texts. Both sides are
        padded/truncated to 64 tokens.

    NOTE(review): padded label positions keep the tokenizer's pad id rather
    than -100, so they presumably count towards the training loss -- confirm
    whether that is intended before changing it (the evaluation cell decodes
    these labels directly).
    """
    tokenize_kwargs = {"padding": "max_length", "truncation": True, "max_length": 64}
    model_inputs = tokenizer(examples["en"], **tokenize_kwargs)
    target_encoding = tokenizer(examples["gn"], **tokenize_kwargs)
    model_inputs["labels"] = target_encoding["input_ids"]
    return model_inputs

# Tokenize both splits in batches; the raw text columns are dropped so only
# the tokenizer outputs and labels remain.
_map_options = {"batched": True, "remove_columns": ["en", "gn"]}
tokenized_train = train_dataset.map(preprocess_function, **_map_options)
tokenized_test = test_dataset.map(preprocess_function, **_map_options)

# Have both splits return torch tensors when indexed.
for _tokenized_split in (tokenized_train, tokenized_test):
    _tokenized_split.set_format("torch")

In [None]:
# Single location for checkpoints, the final model and the tokenizer files.
output_dir = f"{base_path}/bert2gpt2_en2gn"

training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4,  # effective batch of 8 * 4 = 32 per device
    max_steps=2000,
    weight_decay=0.01,
    save_total_limit=1,
    prediction_loss_only=True,  # evaluation reports loss only
    fp16=torch.cuda.is_available(),  # mixed precision only when a GPU exists
    logging_steps=100,
    # eval_steps has no effect unless a step-based evaluation strategy is
    # selected; without this line no evaluation is ever run during training.
    eval_strategy="steps",
    eval_steps=500,
    save_steps=1000,
    report_to="none"
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, pad_to_multiple_of=8)

# NOTE(review): recent transformers releases rename the Trainer `tokenizer`
# argument to `processing_class` -- confirm against the installed version.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    tokenizer=tokenizer,
    data_collator=data_collator
)

trainer.train()

# Persist the trained model together with its tokenizer.
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)
print(f" Model saved to: {output_dir}")

In [None]:
from torch.utils.data import DataLoader
from evaluate import load

# Release cached GPU memory before generation.
torch.cuda.empty_cache()

# Score on at most the first 100 tokenized test examples.
subset = tokenized_test.select(range(min(100, len(tokenized_test))))
subset.set_format("torch")
eval_dataloader = DataLoader(subset, batch_size=4)

model.eval()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

generated_preds = []
reference_labels = []

for batch in eval_dataloader:
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    with torch.no_grad():
        # Greedy decoding: a single beam and no sampling.
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=64,
            num_beams=1,
            do_sample=False
        )
    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # -100 is the conventional "ignore" label and is not a vocabulary id; map
    # it back to the pad token so the references decode cleanly whether or not
    # the labels were masked upstream.
    label_ids = batch["labels"]
    label_ids = label_ids.masked_fill(label_ids == -100, tokenizer.pad_token_id)
    labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    generated_preds.extend(decoded)
    # The metric takes a list of references per prediction.
    reference_labels.extend([[label] for label in labels])

chrf = load("chrf")
# word_order=2 is what makes the metric chrF++; the default (0) is plain chrF,
# which would not match the label printed below.
score = chrf.compute(
    predictions=generated_preds,
    references=reference_labels,
    word_order=2,
)
print(f"\n ChrF++: {score['score']:.2f}")

In [None]:
import ipywidgets as widgets
from IPython.display import display

def translate(text, max_new_tokens=64):
    """Translate ``text`` with the notebook's model using 4-beam search.

    Args:
        text: Source sentence (or list of sentences; only the first decoded
            result is returned).
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The first decoded sequence as a string, with special tokens removed.
    """
    # Put the inputs wherever the model actually lives; guessing from CUDA
    # availability fails when the model sits on a different device.
    device = model.device
    # Generation should run without dropout even if the model was last left
    # in training mode.
    model.eval()
    # truncation=True keeps over-long inputs within the tokenizer's limit.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        output_tokens = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            num_beams=4,
            do_sample=False
        )
    return tokenizer.batch_decode(output_tokens, skip_special_tokens=True)[0]

# Single-line text field for the English sentence, spanning most of the width.
input_layout = widgets.Layout(width="90%")
text_input = widgets.Text(
    value="this worked",
    placeholder="type a sentence...",
    description="english:",
    layout=input_layout,
)

# Area where the click handler prints its result.
output_box = widgets.Output()

# Button that triggers the translation.
translate_button = widgets.Button(
    description="translate to guarani",
    button_style="success",
)

def on_translate_clicked(b):
    """Button callback: translate the text field's content into the output box.

    Clears the previous output, then prints either the translation of the
    whitespace-trimmed input or a prompt when the input is empty.

    Args:
        b: The clicked button (unused).
    """
    output_box.clear_output()
    sentence = text_input.value.strip()
    with output_box:
        if sentence:
            print("guarani:", translate(sentence))
            return
        print("please enter a sentence.")

# Hook the callback to the button, then show the controls stacked vertically.
translate_button.on_click(on_translate_clicked)
translator_ui = widgets.VBox([text_input, translate_button, output_box])
display(translator_ui)