In [21]:
import os 
from datasets import load_dataset
import json

# Location of the sentence-pair JSONL file (rows carry `text` and
# `gold_sentence`). Set the DE_DE_PAIRS_PATH environment variable to run on
# another machine; the default is the original author-local absolute path,
# kept for backward compatibility.
data = os.environ.get(
    "DE_DE_PAIRS_PATH",
    "/Users/marisa/clausal-coordinate-ellipsis/german-common-crawl/de_de_pairs.jsonl",
)

# The JSON loader exposes the single file as a 'train' split.
de_de_dataset = load_dataset("json", data_files=data, split='train')

In [22]:
# NOTE(review): `target` is not referenced anywhere in the visible notebook;
# `responses` is presumably the requests-mocking library, pulled in by an IDE
# auto-import. Confirm it is unused and remove it, since it adds a needless
# hard dependency.
from responses import target
from transformers import AutoTokenizer

# Model / task configuration.
checkpoint = "t5-small"
source_lang = target_lang = "de"
# T5 is text-to-text: the task is selected by this textual prefix on the input.
prefix = "translate German to German: "

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [23]:
# Drop very short pairs: both the input text and the gold sentence must have
# at least MIN_CHARS characters. A single combined predicate walks the dataset
# once instead of twice (the result is the same as two chained filters).
MIN_CHARS = 20
de_de_dataset = de_de_dataset.filter(
    lambda example: len(example['text']) >= MIN_CHARS
    and len(example['gold_sentence']) >= MIN_CHARS
)

In [24]:
def preprocess_function(examples):
    """Tokenize a single dataset row for seq2seq training.

    Meant for ``Dataset.map(..., batched=False)``: ``examples`` is one row,
    so ``examples['text']`` must be a string (it is concatenated with the
    string ``prefix``). The task prefix is prepended to the source text, and
    the gold sentence becomes the target (``labels``).

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``, ``labels``)
        as PyTorch tensors with a leading batch dimension of 1, truncated to
        128 tokens.
    """
    source_text = prefix + examples['text']
    tokenizer_options = {
        "text_target": examples["gold_sentence"],
        "max_length": 128,
        "truncation": True,
        # A single row has nothing to pad against, so this is effectively a no-op.
        "padding": "longest",
        "return_tensors": "pt",
    }
    return tokenizer(source_text, **tokenizer_options)

In [25]:
# Tokenize row by row; each row comes back with a leading batch dimension of 1
# (return_tensors='pt' on a single string), which the next step flattens.
tokenized_dataset = de_de_dataset.map(function=preprocess_function, batched=False)

In [26]:
def _unwrap_singleton_batch(values):
    """Return ``values[0]`` if ``values`` is nested (a batch wrapper), else ``values`` unchanged."""
    if len(values) == 0:
        return values
    first = values[0]
    # Nested if the first element is itself a sequence: a Python list/tuple,
    # or an array/tensor with at least one dimension. A flat list of token
    # ids has scalar elements and is returned as-is.
    is_nested = isinstance(first, (list, tuple)) or getattr(first, "ndim", 0) >= 1
    return first if is_nested else values


def correct_inputs_masks_labels(examples):
    """Flatten the batch-of-one wrapper left by ``preprocess_function``.

    ``preprocess_function`` tokenizes a single row with ``return_tensors='pt'``,
    so ``input_ids``, ``attention_mask`` and ``labels`` arrive wrapped in an
    extra leading dimension (e.g. ``[[...]]``). This replaces each with its
    single inner sequence so the collator sees flat token lists.

    Idempotent: values that are already flat are left untouched. Without this
    guard, re-running the map cell would shrink every sequence to its first
    token id.

    Args:
        examples: one dataset row (mapping); it is updated in place.

    Returns:
        The same ``examples`` mapping, with the three keys flattened.

    Raises:
        KeyError: if any of the three keys is missing.
    """
    for key in ("input_ids", "attention_mask", "labels"):
        examples[key] = _unwrap_singleton_batch(examples[key])
    return examples

In [27]:
# Strip the leading batch-of-one dimension from input_ids / attention_mask / labels.
tokenized_dataset = tokenized_dataset.map(function=correct_inputs_masks_labels, batched=False)

Map:   0%|          | 0/930753 [00:00<?, ? examples/s]

In [28]:
# 80/20 train/test split with a fixed seed, so the split (and therefore the
# eval metrics) is the same after every kernel restart.
SPLIT_SEED = 42
tokenized_dataset = tokenized_dataset.train_test_split(test_size=0.2, seed=SPLIT_SEED)

In [29]:
from transformers import DataCollatorForSeq2Seq

# Pads input_ids/attention_mask with the tokenizer's pad id and labels with
# -100, so padded label positions are ignored by the loss.
# `model=` expects a model *instance*; passing the checkpoint name string
# (as before) did nothing, because a str has no
# `prepare_decoder_input_ids_from_labels`. The model is only created later,
# and T5 builds decoder inputs from `labels` itself, so it is omitted here.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

In [30]:
import evaluate
import numpy as np 

# Corpus-level BLEU via sacrebleu; used in compute_metrics.
metric = evaluate.load(path="sacrebleu")

def postprocess_text(preds, labels):
    """Strip surrounding whitespace from predictions and references.

    Each reference is wrapped in its own one-element list, i.e. one reference
    per prediction, which is the nested form passed to sacrebleu's
    ``references`` argument.

    Returns:
        ``(cleaned_preds, cleaned_refs)`` with ``cleaned_refs[i] == [labels[i].strip()]``.
    """
    cleaned_preds = []
    for text in preds:
        cleaned_preds.append(text.strip())

    cleaned_refs = []
    for text in labels:
        cleaned_refs.append([text.strip()])

    return cleaned_preds, cleaned_refs


def compute_metrics(eval_preds):
    """Compute corpus BLEU (sacrebleu) and mean generated length.

    Args:
        eval_preds: ``(predictions, labels)`` from ``Seq2SeqTrainer`` with
            ``predict_with_generate=True``. ``predictions`` may itself be a
            tuple whose first element holds the generated token ids.

    Returns:
        dict with ``"bleu"`` (sacrebleu ``score``) and ``"gen_len"`` (mean
        number of non-pad tokens per prediction), each rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 marks positions that hold no token. In labels it is the loss-ignore
    # index. NOTE(review): depending on the transformers version, the trainer
    # may also use -100 to pad generated sequences of different lengths when
    # concatenating eval batches. Such ids cannot be decoded (fast tokenizers
    # reject negative ids), and left in `preds` they would also be counted as
    # generated tokens in gen_len, so both arrays are mapped to the pad id.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

In [31]:
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Pretrained seq2seq weights for the same checkpoint the tokenizer came from.
model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=checkpoint)

In [32]:
# Without an explicit limit, generate() stops at the generation config's
# max_length (the library default is 20 tokens unless the checkpoint
# overrides it). Targets are tokenized up to 128 tokens, and the first
# training example's labels alone are 50 tokens. Truncated predictions would
# deflate both BLEU and gen_len.
GENERATION_MAX_LENGTH = 128

training_args = Seq2SeqTrainingArguments(
    output_dir="de_de_040124",
    # NOTE(review): newer transformers releases rename this to `eval_strategy`;
    # update if the installed version warns about or rejects it.
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=2,
    # Evaluate on real generations so compute_metrics can score BLEU.
    # NOTE(review): the test split is 20% of ~930k rows, so generation-based
    # evaluation every epoch will be slow; consider subsampling it.
    predict_with_generate=True,
    generation_max_length=GENERATION_MAX_LENGTH,
    fp16=False,  # set True when CUDA is available
    push_to_hub=False,
)

# Wire model, data splits, collator and metric together.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['test'],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [33]:
# Sanity check: decode the first training input back to text to confirm the
# task prefix and the end-of-sequence marker are present.
first_train_example = trainer.train_dataset[0]
tokenizer.decode(first_train_example["input_ids"])

'translate German to German: Die Gegenstrategie für diese Dilemma sieht denkbar simpel aus: Die Definition von Diskriminierung und Repression wird ins Absurde gesteigert und an die Gesellschaft werden unerfüllbare Anforderungen gestellt</s>'

In [34]:
# Raw label ids of the first training example: a flat list (no batch wrapper)
# ending in id 1, presumably T5's `</s>`.
train_row = trainer.train_dataset[0]
train_row["labels"]

[316,
 5959,
 21889,
 218,
 637,
 2043,
 109,
 635,
 9,
 10262,
 177,
 157,
 1047,
 108,
 51,
 4343,
 403,
 10,
 316,
 15476,
 193,
 2678,
 10648,
 1109,
 3194,
 64,
 419,
 8243,
 551,
 16,
 7,
 891,
 3042,
 221,
 873,
 31939,
 64,
 46,
 67,
 11580,
 404,
 73,
 49,
 6335,
 195,
 5304,
 17707,
 3,
 5371,
 1]

In [35]:
# Run fine-tuning (93,076 optimizer steps per the progress bar of the last run).
# NOTE(review): the recorded run was stopped manually (KeyboardInterrupt) at
# step 0; if resuming after checkpoints exist in output_dir, consider
# `trainer.train(resume_from_checkpoint=True)`.
trainer.train()

  0%|          | 0/93076 [00:00<?, ?it/s]

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


KeyboardInterrupt: 