In [None]:
from transformers import T5Tokenizer, DataCollatorForSeq2Seq
from transformers import T5ForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer
import nltk
import numpy as np

In [None]:
!python -c "from huggingface_hub.hf_api import HfFolder; HfFolder.save_token('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')"
!pip install --upgrade -q wandb

In [None]:
# Checkpoint used for both the tokenizer and the seq2seq model.
MODEL_NAME = "dumitrescustefan/t5-v1_1-base-romanian"

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

# Bound the number of newly generated tokens to the range [0, 64].
generation_cfg = model.generation_config
generation_cfg.min_new_tokens = 0
generation_cfg.max_new_tokens = 64

# Size the embedding matrix to the tokenizer's vocabulary length.
vocab_size = len(tokenizer)
model.resize_token_embeddings(vocab_size)

# Collator used by the trainer to batch the tokenized examples.
data_collator = DataCollatorForSeq2Seq(model=model, tokenizer=tokenizer)

In [6]:
from datasets import load_dataset

# Carve the single "train" split into a 98% training part and a 2% held-out part.
dataset = load_dataset(
    "mateiaassAI/MEID_v2",
    split=["train[:98%]", "train[98%:100%]"],
)
ds_train, ds_test = dataset[0], dataset[1]

# Sentence-final marks accepted at the end of a target sentence.
punctuation_marks = ['.', '?', '!', ';', '...']

def filter_sentences(sentences):
    """Decide whether a single example should be kept.

    Args:
        sentences: a mapping with a "right" key holding the target text.

    Returns:
        True when the "right" text ends with one of ``punctuation_marks``
        and contains at least 10 whitespace-separated words, otherwise False.
    """
    target = sentences['right']
    ends_with_mark = target.endswith(tuple(punctuation_marks))
    return ends_with_mark and len(target.split()) >= 10

# Apply the per-example predicate to both splits, then report the surviving row counts.
ds_train = ds_train.filter(filter_sentences, batched=False)
ds_test = ds_test.filter(filter_sentences, batched=False)

for filtered_split in (ds_train, ds_test):
    print(len(filtered_split))

1017040
21435


In [None]:
# Task prefix prepended to every input sentence.
prefix = "Corectează: "

# Truncation limits (in tokens) for the model input and for the target labels.
MAX_SOURCE_LENGTH = 512
MAX_TARGET_LENGTH = 256

def preprocess_function(examples):
    """Tokenize a batch of (wrong, right) sentence pairs for seq2seq training.

    Args:
        examples: a batch mapping with "wrong" and "right" keys, each holding
            a list of strings of the same length.

    Returns:
        The tokenizer output for the prefixed "wrong" sentences, extended with
        a "labels" entry containing the token ids of the "right" sentences.
    """
    inputs = [prefix + doc for doc in examples["wrong"]]
    model_inputs = tokenizer(inputs, max_length=MAX_SOURCE_LENGTH, truncation=True)

    # `text_target=` already tokenizes the text as targets, so the deprecated
    # `as_target_tokenizer()` context manager that used to wrap this call was
    # redundant and has been dropped.
    labels = tokenizer(
        text_target=examples["right"],
        max_length=MAX_TARGET_LENGTH,
        truncation=True,
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Tokenize a fixed-size prefix of each filtered split (first 500000 / 20000 rows).
train_subset = ds_train.select(range(500000))
ds_tok_train = train_subset.map(preprocess_function, batched=True)

test_subset = ds_test.select(range(20000))
ds_tok_test = test_subset.map(preprocess_function, batched=True)

Map:   0%|          | 0/500000 [00:00<?, ? examples/s]



Map:   0%|          | 0/20000 [00:00<?, ? examples/s]

In [None]:
import os

import wandb

# Never commit the API key: take it from the WANDB_API_KEY environment variable.
# When the variable is unset, `key=None` lets wandb fall back to its own
# credential lookup / interactive prompt.
# NOTE(review): the key that used to be hardcoded in this cell is exposed in
# the notebook history and should be rotated.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [9]:
# Global Parameters
L_RATE = 3.5e-4
BATCH_SIZE = 4
PER_DEVICE_EVAL_BATCH = 4
WEIGHT_DECAY = 0.01
SAVE_TOTAL_LIM = 3
NUM_EPOCHS = 1
EVAL_STEPS = 62500     # run evaluation every 62,500 training steps
SAVE_STEPS = 125000    # write a checkpoint every 125,000 training steps
LOGGING_STEPS = 100    # log training loss every 100 steps

model.to("cuda")

# Set up training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./kaggle/working/results",
    evaluation_strategy="steps",
    eval_steps=EVAL_STEPS,
    learning_rate=L_RATE,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH,
    weight_decay=WEIGHT_DECAY,
    save_total_limit=SAVE_TOTAL_LIM,
    num_train_epochs=NUM_EPOCHS,
    predict_with_generate=True,
    push_to_hub=False,
    save_strategy="steps",
    save_steps=SAVE_STEPS,
    logging_steps=LOGGING_STEPS,
)

# Trainer wiring: tokenized splits, tokenizer and the seq2seq collator.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=ds_tok_train,
    eval_dataset=ds_tok_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# The former `dataloader_config = DataLoaderConfiguration(...)` statement was
# removed: `DataLoaderConfiguration` is never imported in this notebook (so the
# cell raised NameError on a fresh kernel) and its result was never passed to
# the trainer.



In [10]:
# Launch fine-tuning; the returned training summary is shown as the cell output.
train_output = trainer.train()
train_output

[34m[1mwandb[0m: Currently logged in as: [33mmateiaass[0m ([33mmateiaass2[0m). Use [1m`wandb login --relogin`[0m to force relogin

[34m[1mwandb[0m: Tracking run with wandb version 0.17.0

[34m[1mwandb[0m: Run data is saved locally in [35m[1m/kaggle/working/wandb/run-20240528_101416-m83gjlp2[0m

[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.

[34m[1mwandb[0m: Syncing run [33mancient-pyramid-12[0m

[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/mateiaass2/huggingface[0m

[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/mateiaass2/huggingface/runs/m83gjlp2[0m


Step,Training Loss,Validation Loss
62500,0.2409,0.227735
125000,0.2045,0.204753


TrainOutput(global_step=125000, training_loss=0.32045027952575683, metrics={'train_runtime': 23923.9361, 'train_samples_per_second': 20.9, 'train_steps_per_second': 5.225, 'total_flos': 6.434134137874022e+16, 'train_loss': 0.32045027952575683, 'epoch': 1.0})