In [None]:
# Install the `accelerate` package by shelling out to pip (no version is pinned).
!pip install accelerate

In [None]:
# Install the `transformers`, `pandas` and `datasets` packages by shelling out to pip (no versions are pinned).
!pip install transformers pandas datasets

In [None]:
import torch

# Fail fast when no CUDA device is visible, before any model or data is loaded.
# A bare `raise` outside an `except` block only yields the confusing
# "RuntimeError: No active exception to reraise"; raise the same exception
# type with an actionable message instead.
if not torch.cuda.is_available():
    raise RuntimeError(
        "CUDA is not available: enable a GPU accelerator for this notebook before running it."
    )

In [None]:
from transformers import MT5ForConditionalGeneration, AutoTokenizer, Trainer, TrainingArguments
import pandas as pd
from datasets import Dataset
from torch.utils.tensorboard import SummaryWriter

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Load the tokenizer and the seq2seq model from the same pretrained checkpoint.
checkpoint_name = "google/mt5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_name)

In [None]:
# Load the CSV once and carve a held-out evaluation split out of it.
# Previously both frames were read from train_data.csv, so the per-epoch
# evaluation loss was measured on exactly the rows the model was trained on.
# NOTE(review): if the dataset ships a dedicated evaluation CSV, load that into
# eval_df instead of sampling here.
EVAL_FRACTION = 0.1  # share of rows held out for evaluation
SPLIT_SEED = 42      # fixed seed so the split is reproducible across runs

full_df = pd.read_csv('/kaggle/input/michael-ivrit-dataset/train_data.csv', index_col=None).dropna()
eval_df = full_df.sample(frac=EVAL_FRACTION, random_state=SPLIT_SEED)
# Drop the sampled rows from the training frame, then give both frames a clean 0..n-1 index.
train_df = full_df.drop(eval_df.index).reset_index(drop=True)
eval_df = eval_df.reset_index(drop=True)

In [None]:
# Tokenize the data
def tokenize_data(examples):
    """Tokenize a batch of source/target text pairs.

    Args:
        examples: mapping with an "orig_text" entry (model input text) and a
            "text" entry (target text).

    Returns:
        dict with "input_ids" and "attention_mask" for the inputs and
        "labels" holding the target token ids. Both sides are truncated and
        padded to 512 tokens and returned as PyTorch tensors.
    """
    source_texts = examples["orig_text"]
    target_texts = examples["text"]

    # Same tokenizer settings for both sides of the pair.
    encode_kwargs = dict(max_length=512, truncation=True, padding="max_length", return_tensors="pt")
    source_encoding = tokenizer(source_texts, **encode_kwargs)
    target_encoding = tokenizer(target_texts, **encode_kwargs)

    return {
        "input_ids": source_encoding["input_ids"],
        "attention_mask": source_encoding["attention_mask"],
        "labels": target_encoding.input_ids,
    }

In [None]:
# Wrap each pandas frame in a Hugging Face Dataset (train first, then eval).
train_dataset, eval_dataset = (
    Dataset.from_pandas(frame) for frame in (train_df, eval_df)
)

In [None]:
# Run tokenize_data over both splits in batches, dropping the raw columns afterwards.
raw_columns = ["orig_text", "text", "uuid"]
tokenized_train_dataset, tokenized_eval_dataset = (
    split.map(tokenize_data, batched=True, remove_columns=raw_columns)
    for split in (train_dataset, eval_dataset)
)

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Load tokenizer and model
# tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
# model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")

# Define a custom loss function
# Multiplier applied to the loss of every "content" target token, i.e. every
# label that is neither PAD nor EOS. The loss has no information about which
# tokens are actual grammar/spelling corrections, so all content tokens are
# up-weighted equally relative to PAD/EOS positions.
error_weight = 5.0


def custom_loss_fn(labels, logits, *, pad_token_id=None, eos_token_id=None, weight=None, **kwargs):
    """Token-level cross entropy with content tokens up-weighted.

    Args:
        labels: integer tensor of target token ids; any shape whose element
            count equals the number of logit rows after flattening.
        logits: float tensor whose last dimension is the vocabulary size.
        pad_token_id: id treated as padding. Defaults to the module-level
            ``tokenizer.pad_token_id``.
        eos_token_id: id treated as end-of-sequence. Defaults to the
            module-level ``tokenizer.eos_token_id``.
        weight: multiplier for content-token losses. Defaults to the
            module-level ``error_weight``.
        **kwargs: accepted and ignored, for call-site compatibility.

    Returns:
        Scalar tensor: the weighted per-token losses averaged over every
        position whose label is not -100. When no label is -100 this equals
        the plain mean over all positions.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id
    if weight is None:
        weight = error_weight

    ignore_index = -100  # CrossEntropyLoss default; such positions yield a loss of 0

    # reshape (not view) so non-contiguous logits/labels are handled as well.
    flat_labels = labels.reshape(-1)
    flat_logits = logits.reshape(-1, logits.size(-1))

    loss_fct = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=ignore_index)
    loss = loss_fct(flat_logits, flat_labels)

    # Positions that take part in the loss at all.
    counted = flat_labels != ignore_index
    # Content tokens: counted, and neither PAD nor EOS.
    content_positions = counted & (flat_labels != pad_token_id) & (flat_labels != eos_token_id)

    # Apply the higher weight to content tokens only.
    loss = torch.where(content_positions, loss * weight, loss)

    # Average over counted positions so ignored labels do not dilute the loss;
    # the clamp avoids 0/0 when every label is ignored.
    return loss.sum() / counted.sum().clamp(min=1)


# Build the seq2seq batch collator from the shared tokenizer and model.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
)

# Define training arguments
# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`
# (the old spelling is deprecated/removed there). The install above is
# unpinned, so pick whichever keyword the installed version actually declares.
_eval_strategy_key = (
    "eval_strategy"
    if "eval_strategy" in getattr(Seq2SeqTrainingArguments, "__dataclass_fields__", {})
    else "evaluation_strategy"
)

training_args = Seq2SeqTrainingArguments(
    output_dir="./mt5-finetuned",
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=5,
    prediction_loss_only=True,
#     predict_with_generate=True,
    report_to='none',
    # Evaluate once at the end of every epoch.
    **{_eval_strategy_key: "epoch"},
)

# Define a custom training step
class CustomTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer that computes its loss with ``custom_loss_fn``."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the model without labels and score its logits with ``custom_loss_fn``.

        Args:
            model: the model to run.
            inputs: batch mapping; must contain "labels". Everything else is
                forwarded to the model as keyword arguments, so the batch must
                already carry whatever the model needs in place of labels.
            return_outputs: when True, return ``(loss, outputs)``.
            num_items_in_batch: accepted because newer Trainer versions pass
                it to ``compute_loss``; unused here.

        Returns:
            The loss tensor, or ``(loss, outputs)`` when ``return_outputs``.

        Raises:
            KeyError: if the batch has no "labels" entry.
        """
        labels = inputs["labels"]
        # Build a label-free copy instead of popping, so the caller's batch is
        # left intact for anything that reads it after this call.
        model_inputs = {key: value for key, value in inputs.items() if key != "labels"}
        outputs = model(**model_inputs)
        loss = custom_loss_fn(labels, outputs.logits)
        return (loss, outputs) if return_outputs else loss

# Collect the trainer configuration in one place, then build the custom trainer.
trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=None,  # no extra metrics
    callbacks=None,  # no extra callbacks
)
trainer = CustomTrainer(**trainer_kwargs)

# Train
trainer.train()

# ... (rest of the code remains the same)

In [None]:
# Write the fine-tuned model out via the trainer.
finetuned_model_dir = "./mt5-small-finetuned_model"
trainer.save_model(finetuned_model_dir)