In [None]:
import warnings
import helpers
import pandas as pd
import os
from transformers import AutoTokenizer, AutoConfig, DataCollatorWithPadding, Trainer, TrainingArguments, AutoModelForSequenceClassification
warnings.filterwarnings("ignore")

In [None]:
# --- Run configuration -------------------------------------------------------
# Device selection is delegated to the project helper.
device = helpers.get_device()
model_ckpt = "distilbert-base-multilingual-cased"

train_path = "data/SemEval2024-Task8/SubtaskA/subtaskA_train_multilingual.jsonl"
val_path = "data/SemEval2024-Task8/SubtaskA/subtaskA_dev_multilingual.jsonl"

# Additional training files: every *.jsonl entry directly inside addon_directory.
addon_directory = "./backtranslation_data_multi/"
# os.listdir() returns entries in arbitrary, filesystem-dependent order. Sorting
# makes the order in which add-on files are later concatenated identical on every
# run and machine, which keeps the training-row order reproducible.
addon_paths = [
    os.path.join(addon_directory, filename)
    for filename in sorted(os.listdir(addon_directory))
    if filename.endswith(".jsonl")
]

# Tokenizer, config and classification model all come from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
config = AutoConfig.from_pretrained(model_ckpt)
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt).to(device)

In [None]:
# Read the training and validation files into pandas frames via the project helper.
train_df, val_df = helpers.get_pandas_dfs(train_path, val_path)

# One frame per add-on file, kept in the same order as addon_paths.
addon_dataframes = [helpers.get_pandas_atomic_dfs(path) for path in addon_paths]

# Stack the add-on rows underneath the original training rows and renumber the
# index from 0 so there are no duplicate labels after concatenation.
train_df = pd.concat([train_df, *addon_dataframes], axis=0, ignore_index=True)
print(train_df.shape)

In [None]:
# Replace each text with whatever helpers.chunk_text produces for it, then give
# every element of that result its own row (the other columns are repeated).
# NOTE(review): presumably chunk_text returns a list of tokenizer-sized pieces —
# confirm; an empty list would explode into a NaN text row.
chunked_text = train_df["text"].apply(lambda doc: helpers.chunk_text(doc, tokenizer))
train_df = train_df.assign(text=chunked_text).explode("text", ignore_index=True)

In [None]:
# Hand both pandas frames to the project helper, which returns the train and
# validation dataset objects used by the tokenisation and Trainer cells below.
# NOTE(review): presumably these are Hugging Face `datasets.Dataset` objects
# (they are later `.map`-ped) — confirm in helpers.prepare_datasets.
train_ds, val_ds = helpers.prepare_datasets(train_df, val_df)

In [None]:
# Collator that pads the examples of each batch at collation time and returns
# PyTorch tensors.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

def tokenize(batch):
    """Tokenize the ``"text"`` column of a batch for use with ``Dataset.map``.

    Args:
        batch: Mapping with a ``"text"`` key holding the texts to encode.

    Returns:
        Whatever the module-level ``tokenizer`` returns for those texts, with
        over-long inputs truncated.

    Padding is deliberately NOT applied here. Padding inside a batched ``map``
    pads every row to the longest text of the whole map batch, and that padding
    is then stored and carried into every training batch. Padding is left to
    the data collator, which pads each training batch only to its own longest
    member. ``return_tensors`` is also omitted: unpadded rows are ragged, and
    the mapped output is stored as plain lists anyway.
    """
    return tokenizer(batch["text"], truncation=True)

# Apply the tokenizer to both splits in batched mode (train first, then validation).
train_ds_encoded, val_ds_encoded = (
    split.map(tokenize, batched=True) for split in (train_ds, val_ds)
)

# Training configuration. Checkpoints are written to "SemEval-Trainer" and the
# model is evaluated after every epoch.
training_args = TrainingArguments(
    "SemEval-Trainer",
    num_train_epochs=15,
    save_strategy="epoch",
    save_total_limit=20,
    evaluation_strategy="epoch",
    metric_for_best_model="eval_loss",
    # metric_for_best_model only takes effect together with this flag. Without
    # it the Trainer never restores the best checkpoint, so the model saved
    # after training would be the last-epoch weights rather than the ones with
    # the lowest eval_loss. It requires matching save and evaluation
    # strategies, which holds here (both "epoch").
    load_best_model_at_end=True,
)


# Wire model, arguments, encoded splits, collator and metrics into the Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds_encoded,
    eval_dataset=val_ds_encoded,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=helpers.compute_metrics,
)

trainer.train()

# Plain string literal: the original f-string had no placeholders.
# NOTE(review): save_model takes an output directory, so despite the ".pt"
# suffix this path is presumably created as a folder — confirm before loading
# it elsewhere.
trainer.save_model("fine_tuned_distilbert_for_multilingual.pt")