<a href="https://colab.research.google.com/github/khered20/MTL-Dial2MSA/blob/main/MTLtrain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install sacremoses sacrebleu  -q

In [2]:
import os

# Clone your GitHub repository
if not os.path.exists("MTL-Dial2MSA"):
    !git clone https://github.com/khered20/MTL-Dial2MSA.git
    %cd MTL-Dial2MSA
else:
    %cd MTL-Dial2MSA

/content/MTL-Dial2MSA


In [3]:
import os

# Clone the Dial2MSA-Verified dataset repository unless a checkout already exists.
# The path is relative, so the clone lands in the current working directory,
# which after the previous cell's %cd is the MTL-Dial2MSA checkout.
if not os.path.exists("Dial2MSA-Verified"):
    !git clone https://github.com/khered20/Dial2MSA-Verified.git

In [4]:
import pandas as pd
import glob, os
import os

# Create the 'data' output directory. exist_ok=True makes the cell safe to
# re-run and fails loudly right here if 'data' exists but is not a directory.
os.makedirs("data", exist_ok=True)


# Maps each dialect file prefix to the CSV column that holds its MSA translation
msa_mapping = {
    "egy": "msa",
    "mgr": "msa",
    "glf": "msa_verified",
    "lev": "msa_verified"
}

def merge_csvs(folder):
    """Merge all per-dialect CSV files of one split into a single DataFrame.

    Every ``*.csv`` file under ``Dial2MSA-Verified/<folder>`` is read. The
    dialect is taken from the file-name prefix before the first underscore
    (e.g. ``egy_train.csv`` -> ``egy``) and selects, via ``msa_mapping``,
    which column of that file holds the MSA translation.

    Args:
        folder: Name of the split sub-directory, e.g. ``"train"`` or ``"dev"``.

    Returns:
        A DataFrame with the columns ``dialect_label`` (upper-cased dialect
        prefix), ``dialect_sentence`` and ``msa_translation``, with a fresh
        0..n-1 index. Files are processed in sorted path order.

    Raises:
        ValueError: If the split directory contains no CSV files.
        KeyError: If a file has an unknown dialect prefix or lacks one of
            the required columns.
    """
    pattern = os.path.join("Dial2MSA-Verified", folder, "*.csv")
    # glob returns matches in arbitrary order; sort so the row order (and
    # anything derived from it downstream) is reproducible across machines.
    files = sorted(glob.glob(pattern))
    if not files:
        raise ValueError(f"No CSV files found matching {pattern!r}")

    dfs = []
    for f in files:
        # basename handles both "/" and "\\" separators, unlike split("/")
        dialect = os.path.basename(f).split("_")[0]   # e.g., "egy_train.csv" -> "egy"
        if dialect not in msa_mapping:
            raise KeyError(
                f"Unknown dialect prefix {dialect!r} in file {f!r}; "
                f"expected one of {sorted(msa_mapping)}"
            )
        msa_col = msa_mapping[dialect]             # choose correct msa column
        df = pd.read_csv(f)

        missing = [col for col in ("cleanedtweet2", msa_col) if col not in df.columns]
        if missing:
            raise KeyError(f"File {f!r} is missing required column(s): {missing}")

        # Keep only the source sentence and the chosen translation column,
        # renamed to the common schema shared by all dialects
        df = df[["cleanedtweet2", msa_col]].rename(
            columns={"cleanedtweet2": "dialect_sentence", msa_col: "msa_translation"}
        )
        df["dialect_label"] = dialect.upper()

        dfs.append(df)

    merged = pd.concat(dfs, ignore_index=True)
    merged = merged[["dialect_label", "dialect_sentence", "msa_translation"]]
    return merged

# Build the training and development frames from their split folders
train_df, dev_df = (merge_csvs(split) for split in ("train", "dev"))

# Quick size check plus a preview of the training rows
print(f"Train samples: {len(train_df)}")
print(f"Dev samples: {len(dev_df)}")
print(train_df.head())


Train samples: 23087
Dev samples: 800
  dialect_label                                   dialect_sentence  \
0           EGY  جروبس الواتس بتاعه العائلات دى عقاب من ربنا والله   
1           EGY  جروبس الواتس بتاعه العائلات دى عقاب من ربنا والله   
2           EGY  زاي لن محدش بيطلب الاهتمام محدش يستني من حد هي...   
3           EGY  زاي لن محدش بيطلب الاهتمام محدش يستني من حد هي...   
4           EGY        مبعرفش اكدب ديه ميزه بس فالعالم ده اكبر عيب   

                                     msa_translation  
0       جروب الواتس الخاص بهذه العائلات عقاب من الله  
1                     مجموعة الواتس هذة عقاب من اللة  
2  طالما لن يطلب شخص ما الاهتمام من أحد سيرتاح هذ...  
3            مثل اى احد بيطلب الاهتمام من احد هيرتاح  
4  لا أعرف الكذب مهذه ميزه لكن فى هذا العالم هذا ...  


In [5]:
# List the distinct dialect labels present in the training frame
print(f"Available dialect labels: {train_df['dialect_label'].unique()}")

Available dialect labels: ['EGY' 'MGR' 'LEV' 'GLF']


In [6]:
### Optional: augment the training data with MSA -> MSA pairs
# Relabelled copy of the training frame in which the MSA translation also
# serves as the source sentence
train_df_msa = train_df.assign(
    dialect_label="MSA",
    dialect_sentence=train_df["msa_translation"],
)

# Stack the originals on top of the MSA copies, then keep only the first
# occurrence of every (source sentence, translation) pair
train_df = pd.concat([train_df, train_df_msa], ignore_index=True).drop_duplicates(
    subset=["dialect_sentence", "msa_translation"], keep="first"
)

# Verify the result: new size, a preview, and the label set
print(f"New length of train_df: {len(train_df)}")
print(train_df.head())
print(f"Available dialect labels: {train_df['dialect_label'].unique()}")

New length of train_df: 45760
  dialect_label                                   dialect_sentence  \
0           EGY  جروبس الواتس بتاعه العائلات دى عقاب من ربنا والله   
1           EGY  جروبس الواتس بتاعه العائلات دى عقاب من ربنا والله   
2           EGY  زاي لن محدش بيطلب الاهتمام محدش يستني من حد هي...   
3           EGY  زاي لن محدش بيطلب الاهتمام محدش يستني من حد هي...   
4           EGY        مبعرفش اكدب ديه ميزه بس فالعالم ده اكبر عيب   

                                     msa_translation  
0       جروب الواتس الخاص بهذه العائلات عقاب من الله  
1                     مجموعة الواتس هذة عقاب من اللة  
2  طالما لن يطلب شخص ما الاهتمام من أحد سيرتاح هذ...  
3            مثل اى احد بيطلب الاهتمام من احد هيرتاح  
4  لا أعرف الكذب مهذه ميزه لكن فى هذا العالم هذا ...  
Available dialect labels: ['EGY' 'MGR' 'LEV' 'GLF' 'MSA']


In [7]:
# Write both splits to CSV (without the index column) under data/
os.makedirs("data", exist_ok=True)
for frame, csv_path in (
    (train_df, "data/All_train_mtl.csv"),
    (dev_df, "data/All_dev_mtl.csv"),
):
    frame.to_csv(csv_path, index=False)

In [8]:
import sys
# NOTE(review): this path is relative to the current working directory. The
# clone cell earlier already %cd's into MTL-Dial2MSA, so './MTL-Dial2MSA'
# probably does not exist at this point and the `mtl` imports below presumably
# resolve through the working directory instead. Confirm, then consider
# removing this line or appending the checkout's absolute path.
sys.path.append('./MTL-Dial2MSA')

from mtl.dataset import create_data_loaders
from mtl.models import MultiTaskT5, MultiTaskMBart
from mtl.train import train
from mtl.utils import cleanup
from transformers import AutoTokenizer, get_linear_schedule_with_warmup
from torch.optim import AdamW
import torch


In [9]:
import json

# Train on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Hyperparameters -------------------------------------------------------
batch_size = 16
max_length = 128
# NOTE(review): the optimizer cell further down reassigns num_epochs = 2 before
# the scheduler and train() read it, so this value is not the one trained with.
num_epochs = 3
# Forwarded to train() and recorded in config.json; presumably balances the
# translation and classification objectives (confirm in mtl.train).
alpha = 0.5
# Single source of truth for the classifier size, used by both the model and
# config.json. Five labels matches EGY, MGR, LEV, GLF plus the augmented MSA
# label; revisit if the MSA augmentation cell is skipped.
NUM_LABELS = 5

# --- Alternative backbone: AraT5 -------------------------------------------
# To switch, uncomment these four lines and comment out the AraBART ones below.
#SAVE_PATH = "saved_models/mtl_AraT5"
#MODEL_NAME = "UBC-NLP/AraT5v2-base-1024"
#tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
#model = MultiTaskT5(num_labels=5, pretrained_model=MODEL_NAME).to(device)

# --- Active backbone: AraBART ----------------------------------------------
SAVE_PATH = "fn/mtl_AraBART"
MODEL_NAME = "moussaKam/AraBART"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = MultiTaskMBart(num_labels=NUM_LABELS, pretrained_model=MODEL_NAME).to(device)

# create_data_loaders also returns a tokenizer; that returned object is the one
# saved below and used for training.
train_loader, val_loader, tokenizer = create_data_loaders(
    "data/All_train_mtl.csv", "data/All_dev_mtl.csv",
    tokenizer, batch_size=batch_size, max_length=max_length
)

# Run settings stored next to every saved tokenizer
config = {
    "base_model": MODEL_NAME,
    "num_labels": NUM_LABELS,
    "max_length": max_length,
    "batch_size": batch_size,
    "custom_parameters": {
        "alpha": alpha,
    }
}

# Save the tokenizer and config.json into the root save directory and each of
# its sub-directories. One loop replaces four copy-pasted save/dump stanzas.
# tokenizer.save_pretrained runs first so the target directory exists before
# config.json is opened for writing.
for checkpoint_dir in (SAVE_PATH, SAVE_PATH+'/trns', SAVE_PATH+'/cls', SAVE_PATH+'/last'):
    tokenizer.save_pretrained(checkpoint_dir)
    config_path = checkpoint_dir+'/config.json'
    with open(config_path, 'w', encoding='utf-8') as json_file:
        json.dump(config, json_file, indent=4)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [10]:
# Number of epochs actually trained; this overrides the value assigned in the
# configuration cell above.
num_epochs = 2

# AdamW optimiser plus a linear learning-rate schedule without warm-up that
# spans every optimisation step of the run
optimizer = AdamW(model.parameters(), lr=5e-5)
total_training_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_training_steps,
)


In [11]:
# Run multi-task training; train() returns the best BLEU and F1 it recorded
best_bleu, best_f1 = train(
    model,
    train_loader,
    optimizer,
    scheduler,
    device,
    val_loader,
    tokenizer,
    epochs=num_epochs,
    save_path=SAVE_PATH,
    alpha=alpha,
)
print("Training finished!")
print(f"Best BLEU: {best_bleu}")
print(f"Best F1: {best_f1}")

# NOTE(review): presumably releases training resources such as cached GPU
# memory; confirm in mtl.utils.
cleanup()


Epoch 1: 100%|██████████| 50/50 [00:30<00:00,  1.66it/s, loss=2.5]
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


dev Epoch 1, BLEU Score: 14.92817171180575, F1 Score: 0.21398593530239102, -best bleu: 14.92817171180575, best f1: 0.21398593530239102


Epoch 2: 100%|██████████| 50/50 [00:29<00:00,  1.69it/s, loss=2.04]


dev Epoch 2, BLEU Score: 16.62205440878057, F1 Score: 0.29885309400094473, -best bleu: 16.62205440878057, best f1: 0.29885309400094473
Training finished!
Best BLEU: 16.62205440878057
Best F1: 0.29885309400094473


In [12]:
from mtl.predict import predict

# Two short dialectal inputs as a quick smoke test of the trained model
samples = ["إزيك عامل إيه؟", "شلونك يا خوي؟"]
outputs = predict(model, tokenizer, samples, device)

# Each prediction is a mapping holding the input text, the predicted dialect
# label and the generated translation
for prediction in outputs:
    print("\nInput:", prediction["input"])
    print("Predicted Dialect:", prediction["dialect"])
    print("Predicted Translation:", prediction["translation"])



Input: إزيك عامل إيه؟
Predicted Dialect: LEV
Predicted Translation: ماذا حدث؟

Input: شلونك يا خوي؟
Predicted Dialect: LEV
Predicted Translation: كيف حالك يا خوي؟
