In [57]:
import os 
from datasets import load_dataset
import json

data = "/Users/marisa/clausal-coordinate-ellipsis/german-common-crawl/de_de_pairs.jsonl"

de_de_dataset = load_dataset("json", data_files=data, split='train')

In [58]:
from responses import target
from transformers import AutoTokenizer

checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
#
source_lang = "de"
target_lang = "de"
prefix = "translate German to German: "

In [59]:
de_de_dataset = de_de_dataset.filter(lambda example: len(example['text']) >= 20)
de_de_dataset = de_de_dataset.filter(lambda example: len(example['gold_sentence']) >= 20)

Filter:   0%|          | 0/1000001 [00:00<?, ? examples/s]

Filter:   0%|          | 0/930753 [00:00<?, ? examples/s]

In [62]:
def preprocess_function(examples):
    inputs = [prefix + examples['text'][indx] for indx in range(len(examples))]
    targets = [examples["gold_sentence"][indx] for indx in range(len(examples))]
    model_inputs = tokenizer(inputs, text_target=targets, max_length=128, truncation=True)
    return model_inputs

In [61]:
tokenized_dataset = de_de_dataset.map(preprocess_function, batched=False)

Map:   0%|          | 0/930753 [00:00<?, ? examples/s]

In [63]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

In [64]:
import evaluate
import numpy as np 

metric = evaluate.load("sacrebleu")

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [[label.strip()] for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

In [65]:
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

In [66]:
training_args = Seq2SeqTrainingArguments(
    output_dir="de_de_040124",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=2,
    predict_with_generate=True,
    fp16=False, # set true when cuda available
    push_to_hub=False,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=de_de_dataset,
    eval_dataset=de_de_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

  0%|          | 0/116346 [00:00<?, ?it/s]

  return table.fast_gather(key % table.num_rows)


IndexError: list index out of range