In [None]:
import json
import pandas as pd
import os
from datetime import datetime
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, TrainingArguments, Trainer, EarlyStoppingCallback
from train_seq import tokenize_sequence_classification
import numpy as np
from sklearn.metrics import classification_report

In [2]:
# Training sentences: the parsed JSON stays available as `data`, and `df` is its DataFrame view.
train_data_path = '/home/lgiordano/LUCA/checkthat_GITHUB/data/formatted/train_sentences.json'
with open(train_data_path, mode='r', encoding='utf8') as f:
    raw_json = f.read()
data = json.loads(raw_json)
df = pd.DataFrame(data)

In [3]:
# Augmentation files, keyed by the language code that is written into each sample's `lang` field.
data_path_dict = {
'sl': '/home/lgiordano/LUCA/checkthat_GITHUB/data/train_sent_mt/sl/train_gold_sentences_translated_nllb-200-3.3B_eng_Latn-slv_Latn_tok_regex_en-sl/train_gold_sentences_translated_nllb-200-3.3B_eng_Latn-slv_Latn_tok_regex_en-sl_mdeberta-v3-base_mdeberta_xlwa_en-sl_ME3_2024-05-04-12-12-14_ls.json',
'ru': '/home/lgiordano/LUCA/checkthat_GITHUB/data/train_sent_mt/ru/train_gold_sentences_translated_nllb-200-3.3B_eng_Latn-rus_Cyrl_tok_regex_en-ru/train_gold_sentences_translated_nllb-200-3.3B_eng_Latn-rus_Cyrl_tok_regex_en-ru_mdeberta-v3-base_mdeberta_xlwa_en-ru_ME3_2024-05-04-12-09-20_ls.json',
'pt': '/home/lgiordano/LUCA/checkthat_GITHUB/data/train_sent_mt/pt/train_gold_sentences_translated_nllb-200-3.3B_eng_Latn-por_Latn_tok_regex_en-pt/train_gold_sentences_translated_nllb-200-3.3B_eng_Latn-por_Latn_tok_regex_en-pt_mdeberta-v3-base_mdeberta_xlwa_en-pt_ME3_2024-05-04-12-07-45_ls.json',
'it': '/home/lgiordano/LUCA/checkthat_GITHUB/data/train_sent_mt/it/train_gold_sentences_translated_nllb-200-3.3B_eng_Latn-ita_Latn_tok_regex_en-it/train_gold_sentences_translated_nllb-200-3.3B_eng_Latn-ita_Latn_tok_regex_en-it_mdeberta-v3-base_mdeberta_xlwa_en-it_ME3_2024-05-04-12-05-00_ls.json',
'es': '/home/lgiordano/LUCA/checkthat_GITHUB/data/train_sent_mt/es/train_gold_sentences_translated_nllb-200-3.3B_eng_Latn-spa_Latn_tok_regex_en-es/train_gold_sentences_translated_nllb-200-3.3B_eng_Latn-spa_Latn_tok_regex_en-es_mdeberta-v3-base_mdeberta_xlwa_en-es_ME3_2024-05-04-12-01-43_ls.json',
'bg': '/home/lgiordano/LUCA/checkthat_GITHUB/data/train_sent_mt/bg/train_gold_sentences_translated_nllb-200-3.3B_eng_Latn-bul_Cyrl_tok_regex_en-bg/train_gold_sentences_translated_nllb-200-3.3B_eng_Latn-bul_Cyrl_tok_regex_en-bg_mdeberta-v3-base_mdeberta_xlwa_en-bg_ME3_2024-05-04-11-58-52_ls.json',
'ar': '/home/lgiordano/LUCA/checkthat_GITHUB/data/aug_NEW/araieval24_all_bin_formatted.json'
}

dataset_aug = []
for key, aug_path in data_path_dict.items():
    with open(aug_path, 'r', encoding='utf8') as f:
        dataset_aug_buffer = json.load(f)

    # Drop samples typed as tweets (this is meant for the 'ar' file, but the
    # filter is applied to every language's file).
    dataset_aug_buffer = [sample for sample in dataset_aug_buffer if sample['data'].get('type') != 'tweet']

    # Normalise every sample to the schema used downstream:
    # data.text / data.lang / data.label.
    for sample in dataset_aug_buffer:
        sample_data = sample['data']
        # The English text is discarded when present.
        sample_data.pop('text_en', None)
        # The language-specific text field becomes the generic 'text' field.
        if f'text_{key}' in sample_data:
            sample_data['text'] = sample_data.pop(f'text_{key}')
        sample_data['lang'] = key
        # Some files use 'labels' instead of 'label'.
        if 'labels' in sample_data:
            sample_data['label'] = sample_data.pop('labels')
    dataset_aug += dataset_aug_buffer

# Use a context manager so the file handle is closed deterministically.
# This file is appended as-is: no tweet filtering or field renaming is applied.
with open('/home/lgiordano/LUCA/checkthat_GITHUB/data/aug_NEW/semeval24_all_bin_formatted.json', encoding='utf-8') as f:
    semeval_24 = json.load(f)
dataset_aug += semeval_24

df_aug = pd.DataFrame(dataset_aug)
# Rebuild the base frame from the raw `data` list instead of extending the
# existing `df`, so re-running this cell does not append the augmentation twice.
df = pd.concat([pd.DataFrame(data), df_aug])

In [4]:
### Balance positive and negative samples for each language.
### Only the negative class is down-sampled (to the number of positives); if a
### language has more positives than negatives, it is left unbalanced.

# Timestamp of this run; compute_metrics uses it to name the results directory.
date_time = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")

# Languages of the original training data plus the augmentation languages.
# NOTE(review): rows whose language is not in this set (e.g. possibly some of the
# semeval24 samples) are silently excluded from the balanced dataset — confirm intended.
langs = set(list(sample['data']['lang'] for sample in data) + (list(data_path_dict.keys())))

# Fixed seed so the down-sampled negatives are the same on every run.
sampling_seed = 42

# Extract each row's language once instead of once per language.
row_lang = df['data'].apply(lambda x: x['lang'])

sampled_dfs = []
# Iterate in sorted order: the iteration order of a set of strings can change
# between interpreter runs, which would change the row order fed to the seeded
# shuffle below and therefore the train/validation split.
for lang in sorted(langs):
    df_lang = df[row_lang == lang]
    df_pos_lang = df_lang[df_lang['data'].apply(lambda x: x['label'] == 1)]
    df_neg_lang = df_lang[df_lang['data'].apply(lambda x: x['label'] == 0)]
    if len(df_neg_lang) > len(df_pos_lang):
        df_neg_lang = df_neg_lang.sample(len(df_pos_lang), random_state=sampling_seed)
    df_lang_sampled = pd.concat([df_pos_lang, df_neg_lang])
    sampled_dfs.append(df_lang_sampled)
df_sampled = pd.concat(sampled_dfs, ignore_index=True)
# Shuffle so languages and classes are interleaved.
df_sampled = df_sampled.sample(frac=1, random_state=42).reset_index(drop=True)
dataset = Dataset.from_pandas(df_sampled)

In [None]:
### Split the balanced dataset into train/test splits and tokenize both; padding is
### applied per batch by the data collator.

#model_name = 'bert-base-multilingual-cased'
#model_name = 'xlm-roberta-base'
model_name = 'microsoft/mdeberta-v3-base'
tokenizer = AutoTokenizer.from_pretrained(model_name)

split_ratio = 0.2
split_seed = 42
# NOTE(review): batch_size is only used as the batch size of the tokenization map
# below; it is not passed to TrainingArguments, so training uses the Trainer's
# own per-device batch size — confirm this is intended.
batch_size = 16

datadict = dataset.train_test_split(split_ratio, seed=split_seed)
datadict = datadict.map(lambda x: tokenize_sequence_classification(x, tokenizer),
                            batch_size=batch_size,
                            batched=True
                            )

# Columns exposed to the model as torch tensors. 'token_type_ids' is added only
# when tokenization produced it, so switching to a model that does not use it
# (e.g. xlm-roberta) no longer requires editing this list.
columns = ['input_ids', 'attention_mask', 'labels']
if 'token_type_ids' in datadict['train'].column_names:
    columns.insert(1, 'token_type_ids')
datadict.set_format('torch', columns = columns)

train_data = datadict['train'] # no aug: 33,345 samples, with aug: 90,720, + ARAIEVAL24(tweet+news): 97,892, + ARAIEVAL24(news) & SEMEVAL24: 108,471
val_data = datadict['test'] # no aug: 8,337 samples, with aug: 22,681, + ARAIEVAL24(tweet+news): 24,473, + ARAIEVAL24(news) & SEMEVAL24: 27,118

# Pad each batch to the length of its longest sequence.
collate_fn = DataCollatorWithPadding(tokenizer=tokenizer, padding='longest')

In [None]:
# Two-label sequence classifier initialised from the selected pretrained checkpoint.
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Training configuration: save, log and evaluate once per epoch; the best
# checkpoint is chosen by 'eval_macro-f1' and reloaded at the end of training.
training_config = dict(
    output_dir='/home/lgiordano/LUCA/checkthat_GITHUB/models/M1/RUN_OTTOBRE/aug/aug, lr 2e-5, + ARAIEVAL(news) & SEMEVAL24',
    save_total_limit=1000,
    save_strategy='epoch',
    load_best_model_at_end=True,
    save_only_model=True,
    metric_for_best_model='eval_macro-f1',
    logging_strategy='epoch',
    evaluation_strategy='epoch',
    learning_rate=2e-5,  # previously 5e-5
    optim='adamw_torch',
    num_train_epochs=10,
)
training_args = TrainingArguments(**training_config)

# Early stopping with a patience of 2 evaluations.
early_stopping = EarlyStoppingCallback(early_stopping_patience=2)

def compute_metrics(eval_preds):
    """Compute classification metrics for one evaluation pass and dump them to disk.

    Args:
        eval_preds: pair ``(logits, labels)``; the predicted class is the argmax of
            the logits over the last axis.

    Returns:
        The dict produced by ``classification_report(..., output_dict=True)`` with an
        extra top-level ``'macro-f1'`` entry copied from ``['macro avg']['f1-score']``
        (the training arguments select the best model on ``'eval_macro-f1'``).

    Side effects:
        Writes the metrics to ``<models_dir>/<date_time>_aug/results.json``. The file
        is opened in write mode, so each call replaces the previous contents.
        Relies on the notebook-level ``date_time`` variable.
    """
    logits, labels = eval_preds
    preds = np.argmax(logits, axis=-1)
    results = classification_report(labels, preds, output_dict=True)
    results['macro-f1'] = results['macro avg']['f1-score']

    models_dir = '/home/lgiordano/LUCA/checkthat_GITHUB/models/M1/RUN_OTTOBRE/aug/aug, lr 2e-5, + ARAIEVAL(news) & SEMEVAL24'
    model_save_dir = os.path.join(models_dir, f'{date_time}_aug')
    # exist_ok avoids the separate check-then-create step.
    os.makedirs(model_save_dir, exist_ok=True)
    with open(os.path.join(model_save_dir, 'results.json'), 'w', encoding='utf8') as f:
        json.dump(results, f, ensure_ascii = False)

    return results

In [7]:
# Wire the model, training arguments, datasets, collator, metrics and early stopping together.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_data,
    eval_dataset=val_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

In [None]:
# Run training with the trainer configured in the previous cell; the return value is shown as the cell output.
trainer.train()

In [None]:
# Evaluate on the trainer's eval dataset (val_data). This calls compute_metrics again,
# so results.json is rewritten with the metrics from this final evaluation.
trainer.evaluate()