In [None]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np
from datasets import (Dataset, 
                      DatasetDict,
                      load_dataset,
                      load_from_disk)
from transformers import (AutoTokenizer, 
                          AutoModelForSeq2SeqLM, 
                          T5ForConditionalGeneration, 
                          Seq2SeqTrainer, 
                          Seq2SeqTrainingArguments, 
                          BitsAndBytesConfig)

In [None]:
# Load the raw EN/RU text-correction corpus from the Hugging Face Hub.
dataset = load_dataset("gudleifrr/text-correction-en-ru")

# 8-bit quantisation config. The keyword is `load_in_8bit`; the previous
# spelling `load_in_8_bit` is not a recognised parameter, so 8-bit loading
# was never actually requested.
# NOTE(review): this config is not passed to `from_pretrained` when the model
# is loaded, so it currently has no effect on training.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained("google/mt5-large")
# Extra vocabulary entries: the Russian/English task-prefix tags used by
# `preprocess`, plus '<NULL>' (not referenced by the visible code — presumably
# reserved for an empty correction; TODO confirm).
specials = ['<исправить>', '<fix>', '<NULL>']
tokenizer.add_tokens(specials)

def preprocess(ex, max_length=256):
    """Tokenize a batch of (bad_text -> text) correction pairs.

    Each ``bad_text`` is prefixed with a task tag chosen by language
    (``'<исправить>: '`` when ``lang == 'ru'``, ``'<fix>: '`` for anything
    else) and used as the encoder input; the corrected ``text`` is passed as
    ``text_target`` so it is tokenized as the labels.

    Args:
        ex: A batch as supplied by ``Dataset.map(..., batched=True)``, with
            parallel list columns ``'lang'``, ``'bad_text'`` and ``'text'``.
        max_length: Truncation length applied to both inputs and targets.
            Defaults to 256, the previously hard-coded value.

    Returns:
        Whatever the module-level ``tokenizer`` returns for the call.
    """
    inputs = [
        ('<исправить>: ' if lang == 'ru' else '<fix>: ') + bad_text
        for lang, bad_text in zip(ex['lang'], ex['bad_text'])
    ]
    # The reference corrections are used unchanged as targets.
    targets = list(ex['text'])
    return tokenizer(inputs, text_target=targets, max_length=max_length, truncation=True)

# Tokenize every split in batches and drop the raw text columns so only the
# tokenizer outputs remain. The attribute is `Dataset.column_names`; the
# previous `columns_names` raised AttributeError.
tokenized_dataset = dataset.map(preprocess, batched=True, remove_columns=dataset['train'].column_names)

In [None]:
from peft import LoraConfig, TaskType, get_peft_model
import torch
from transformers import DataCollatorForSeq2Seq

# Load the base mT5-large checkpoint.
# NOTE(review): `bnb_config` is not passed here (no `quantization_config=`),
# so the model is loaded without 8-bit quantisation.
model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-large")
# Match the embedding / LM-head size to the tokenizer, which now includes the
# task-prefix tokens.
# NOTE(review): mT5 checkpoints may carry more embedding rows than
# len(tokenizer); if so this call shrinks the matrix instead of growing it —
# confirm that is intended.
model.resize_token_embeddings(len(tokenizer))

peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    # PEFT's shorthand for "all linear layers except the output head" is the
    # hyphenated 'all-linear'. The previous 'all_linear' is treated as a
    # literal module-name pattern, matches nothing, and get_peft_model fails.
    target_modules='all-linear',
    # Fully trained (non-LoRA) copies: the resized input embeddings and LM head,
    # so the newly added tokens get learned vectors, plus the first-block
    # relative-attention-bias tables of the encoder and decoder.
    modules_to_save=[
        'shared',
        'lm_head',
        'encoder.block.0.layer.0.SelfAttention.relative_attention_bias',
        'decoder.block.0.layer.0.SelfAttention.relative_attention_bias'
    ],
    bias='none',
    # Rank-stabilised LoRA scaling (lora_alpha / sqrt(r)).
    use_rslora=True
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

In [None]:
# Dynamically pads each batch of tokenized examples; the model is handed to
# the collator as well.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

training_args = Seq2SeqTrainingArguments(
    output_dir="./mt5-text-correction-enru",
    # --- optimisation -------------------------------------------------------
    learning_rate=5e-4,
    weight_decay=0.01,
    warmup_steps=100,
    num_train_epochs=1,
    # Effective batch per device per optimizer step: 2 * 8 = 16 examples.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    # No eval_dataset is given to the trainer, so this is currently unused.
    per_device_eval_batch_size=4,
    # --- checkpoints & logging ---------------------------------------------
    save_total_limit=1,
    logging_dir="./mt5_training_logs",
    logging_steps=1
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=tokenized_dataset['train']
)

In [None]:
# Run fine-tuning of the LoRA-wrapped model on the tokenized train split
# (num_train_epochs=1 in the training arguments).
trainer.train()

In [None]:
# Upload the trainer's output to the Hugging Face Hub.
# NOTE(review): neither push_to_hub nor hub_model_id is set in the training
# arguments, so the target repo name presumably derives from output_dir and an
# authenticated Hub session is required — confirm.
trainer.push_to_hub()