### Import dependencies

In [None]:
# Import dependencies
import matplotlib.pyplot as plt
import numpy as np
import torch
import os
from torch.utils.data import DataLoader

# If running on Google Colab, install the required packages
# %pip install datasets transformers

from tqdm import tqdm
from datasets import load_dataset, load_dataset_builder, get_dataset_split_names
from transformers import AutoModelForCausalLM, AutoTokenizer

### Run the model on GPU if available

In [None]:
# Prefer the GPU when CUDA is usable, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"USING DEVICE: {device}")

### Explore dataset

In [None]:
# Specify dataset name (used for every FLORES call below so it is set in one place)
DATA_SET_NAME = "facebook/flores"

# Print the dataset description, using the deu_Latn config as a representative config
ds_builder = load_dataset_builder(DATA_SET_NAME, "deu_Latn", trust_remote_code=True)
print(f"DESCRIPTION OF THE DATASET:\n {ds_builder.info.description}\n")

# Print the features (columns) of the dataset
print(f"FEATURE COLUMNS OF THE DATASET:\n {ds_builder.info.features}\n")

# Get the available splits (reused later when loading every language)
AVAILABLE_SPLITS = get_dataset_split_names(DATA_SET_NAME, "deu_Latn", trust_remote_code=True)
print(f"AVAILABLE SPLITS:\n {AVAILABLE_SPLITS}\n")

In [None]:
# Specify languages as FLORES config names: <ISO 639-3 language>_<ISO 15924 script>
LANGUAGES = [
    "eng_Latn",  # English
    "spa_Latn",  # Spanish
    "ita_Latn",  # Italian
    "deu_Latn",  # German
    "arb_Arab",  # Standard Arabic
    "tel_Telu",  # Telugu
    "tam_Taml",  # Tamil
    "quy_Latn",  # Ayacucho Quechua
]


def load_flores_datasets(languages, splits, dataset_name="facebook/flores", cache_dir="../cache/languages"):
    """Load the FLORES datasets for the specified languages and splits.

    Args:
        languages (list): FLORES config names (e.g. "eng_Latn") to load.
        splits (list): split names to load for every language.
        dataset_name (str, optional): Hugging Face dataset id. Defaults to "facebook/flores".
        cache_dir (str, optional): local cache directory for the downloads.
            Defaults to "../cache/languages".

    Returns:
        dict: nested dict ``{language: {split: dataset}}``.
    """
    flores_data = {}
    for language in languages:
        print(f"Loading dataset for language: {language}")
        # One load_dataset call per (language, split), in the same order as before
        flores_data[language] = {
            split: load_dataset(
                dataset_name,
                language,
                split=split,
                trust_remote_code=True,
                cache_dir=cache_dir,
            )
            for split in splits
        }
    return flores_data


flores_data = load_flores_datasets(LANGUAGES, AVAILABLE_SPLITS)

# Inspect the English devtest subset: first its backing data, then a single row
english_devtest = flores_data["eng_Latn"]["devtest"]

data = english_devtest.data
print(f"\nENGLISH SUBSET(DEVTEST):\n {data}\n")

sample = english_devtest[0]
print(f"SAMPLE FROM ENGLISH SUBSET(DEVTEST):\n {sample}\n")

### Define a tokenizer

In [None]:
class Tokenizer:
    """Tokenize sentences for a given model and attach language-modeling labels.

    Attributes:
        model_name (str): name or path passed to ``AutoTokenizer.from_pretrained``.
        padding: padding strategy forwarded to the tokenizer call.
        truncation: truncation strategy forwarded to the tokenizer call.
        return_tensors (str): tensor type forwarded to the tokenizer call.
        tokenizer: the underlying Hugging Face tokenizer.
    """
    def __init__(self, model_name, padding="longest", truncation="longest_first", return_tensors="pt"):
        self.model_name = model_name
        self.padding = padding
        self.truncation = truncation
        self.return_tensors = return_tensors
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Batched padding needs a pad token. Some tokenizers (e.g. GPT-2) have none, so
        # reuse an existing special token rather than adding a brand-new id that the
        # model has no embedding for. Checking pad_token (instead of the exact name
        # "gpt2") also covers other checkpoints and local paths.
        if self.tokenizer.pad_token is None:
            fallback = self.tokenizer.unk_token or self.tokenizer.eos_token
            if fallback is not None:
                self.tokenizer.add_special_tokens({'pad_token': fallback})

    def tokenize(self, text):
        """Tokenize the given sentences and build labels for the LM loss.

        Args:
            text (list): the sentences to be tokenized.

        Returns:
            dict: the tokenizer output (input ids, attention mask) plus a ``labels``
            entry equal to the input ids with padded positions set to -100.
        """
        tokenized = self.tokenizer(
            text,
            padding=self.padding,
            return_tensors=self.return_tensors,
            truncation=self.truncation
        )

        # Replace padded positions with -100 so they are not considered in the loss.
        # The attention mask identifies padding exactly; comparing against
        # pad_token_id would also mask genuine tokens that share that id (the pad
        # token may be a reused special token such as GPT-2's unk token).
        input_ids = tokenized["input_ids"]
        attention_mask = tokenized.get("attention_mask")
        if attention_mask is None:
            padded = input_ids == self.tokenizer.pad_token_id
        else:
            padded = attention_mask == 0
        tokenized["labels"] = torch.where(padded, -100, input_ids)

        return tokenized

### Dataloader util functions

In [None]:
def collate_fn(batch, tokenizer):
    """Turn a list of dataset rows into one padded, tokenized batch.

    Args:
        batch (list): dataset rows, each a mapping with a "sentence" entry.
        tokenizer (Tokenizer): object whose ``tokenize`` method receives the sentences.

    Returns:
        dict: whatever ``tokenizer.tokenize`` returns for the batch's sentences.
    """
    sentences = [row["sentence"] for row in batch]
    return tokenizer.tokenize(sentences)

def build_dataloaders(languages, batch_size, collate_fn, tokenizer, shuffle=False,
                      split="devtest", datasets_by_language=None):
    """Build one DataLoader per language using the given batch size and collate function.

    Args:
        languages (list): languages to build dataloaders for.
        batch_size (int): the batch size.
        collate_fn (function): called as ``collate_fn(batch, tokenizer)`` for every batch.
        tokenizer (Tokenizer): the tokenizer passed to ``collate_fn``.
        shuffle (bool, optional): whether to shuffle the dataset. Defaults to False.
        split (str, optional): split to read from each language. Defaults to "devtest".
        datasets_by_language (dict, optional): nested ``{language: {split: dataset}}``.
            Defaults to the notebook-level ``flores_data``, looked up at call time.

    Returns:
        dict: a dictionary of dataloaders for each language.
    """
    if datasets_by_language is None:
        # Fall back to the global dict built by load_flores_datasets (original behavior)
        datasets_by_language = flores_data

    flores_dataloaders = {}
    for language in languages:
        flores_dataloaders[language] = DataLoader(
            datasets_by_language[language][split],
            batch_size=batch_size,
            shuffle=shuffle,
            # tokenizer is a parameter, not a loop variable, so late binding is safe here
            collate_fn=lambda batch: collate_fn(batch, tokenizer)
        )
    return flores_dataloaders

### Util method for loading a model with a given name

In [None]:
def build_model(model_name, device):
    """Build a model from a local file path or a Hugging Face model name and move it to a device.

    Args:
        model_name (str): path of a file saved with ``torch.save(model, path)``, or a
            model name understood by ``AutoModelForCausalLM.from_pretrained``.
        device (torch.device): the device to run the model on.

    Returns:
        torch.nn.Module: the model, already moved to ``device``.
    """
    if os.path.exists(model_name):
        print(f"Loading model from path: {model_name}")
        # The file holds a whole pickled module (not a state_dict):
        # - weights_only=False is required since recent PyTorch defaults to weights_only=True;
        # - map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        # SECURITY: unpickling can execute arbitrary code -- only load trusted files.
        model = torch.load(model_name, map_location=device, weights_only=False)
    else:
        print(f"Loading model from name: {model_name}")
        model = AutoModelForCausalLM.from_pretrained(model_name)
    model.to(device)
    return model

### Utility method for running the model in inference mode

In [None]:
@torch.inference_mode()
def inference(model_name, tokenizer_name, batch_size, device, languages=None):
    """Run inference for a given model and return the per-batch losses for each language.

    Args:
        model_name (str): the name (or saved-model path) of the model.
        tokenizer_name (str): the name of the tokenizer.
        batch_size (int): the batch size.
        device (torch.device): the device to run the inference on.
        languages (list, optional): languages to evaluate. Defaults to the global
            ``LANGUAGES``. Duplicate entries are ignored (first occurrence wins).

    Returns:
        dict: ``{language: [loss tensor per batch]}`` with losses moved to the CPU.
    """
    # De-duplicate while keeping order, so a repeated language is not evaluated twice
    # and its losses are not appended to the same list twice.
    languages = list(dict.fromkeys(LANGUAGES if languages is None else languages))

    print(f"Running inference for model {model_name}")
    tokenizer = Tokenizer(tokenizer_name)
    flores_dataloaders = build_dataloaders(languages, batch_size, collate_fn, tokenizer)

    model = build_model(model_name, device)
    model.eval()

    losses = {lang: [] for lang in languages}  # store per-batch losses for each language

    for idx_language, language in enumerate(languages):
        print(f"Calculating losses for language {language}")
        for idx_batch, batch in enumerate(tqdm(flores_dataloaders[language])):
            if idx_language == 0 and idx_batch == 0:
                print(f"PRINTING TOKENIZED DATA:\n {batch}")

            # https://github.com/huggingface/transformers/blob/94b3f544a1f5e04b78d87a2ae32a7ac252e22e31/src/transformers/models/xglm/modeling_xglm.py#L915
            # If labels are provided, the model will return the loss in the outputs.
            # Call the module itself (not .forward) so registered hooks still run.
            outputs = model(**batch.to(device))
            losses[language].append(outputs.loss.cpu())
    return losses

### Plot configuration for visualizing the loss for each language

In [None]:
# Some plot configuration
plt.style.use('seaborn-v0_8-whitegrid')

# Credits: https://www.futurile.net/2016/02/27/matplotlib-beautiful-plots-with-style/
# Fonts and text sizes applied to every figure drawn after this cell
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': 'Ubuntu',
    'font.monospace': 'Ubuntu Mono',
    'font.size': 10,
    'axes.labelsize': 10,
    'axes.labelweight': 'bold',
    'axes.titlesize': 10,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'legend.fontsize': 10,
    'figure.titlesize': 12,
})

### XGLM loss for each language

In [None]:
BATCH_SIZE = 2

# Model and tokenizer come from the same XGLM-564M checkpoint
xglm_losses = inference(
    model_name="facebook/xglm-564M",
    tokenizer_name="facebook/xglm-564M",
    batch_size=BATCH_SIZE,
    device=device,
)

In [None]:
# Plotting the losses
fig, axes = plt.subplots(figsize=(8, 5))

# Derive ticks and labels from the results themselves, so the counts always agree
# even if LANGUAGES is later modified and this cell is re-run.
language_names = list(xglm_losses)

# Create a bar for each language, annotated with its mean loss
for position, language in enumerate(language_names):
    mean = np.mean(xglm_losses[language])
    axes.bar(position, mean, label=language)
    axes.text(position, mean, f"{mean:.2f}", ha="center", va="bottom")

# Format plot
axes.grid(which='major', color='#EEEEEE', linestyle='-', linewidth=0.5)
axes.set_axisbelow(True)
axes.set_xlabel("Language") # x-axis label
axes.set_xticks(range(len(language_names))) # x-axis ticks
axes.set_xticklabels(language_names) # x-axis tick labels
axes.set_ylabel("Mean loss") # y-axis label
axes.set_ylim(0, 9) # range of y-axis
axes.set_title("XGLM-564M mean language model loss"); # title

##########################################################################
# Output stored in /data/task_1/charts/xglm_mean_language_model_loss.png #
##########################################################################

### Comparing XGLM to GPT2

In [None]:
# Add Tosk Albanian. The membership guard makes this cell safe to re-run: without it each
# run appends another "als_Latn", and code iterating LANGUAGES then disagrees with the
# de-duplicated keys of the loss dictionaries.
if "als_Latn" not in LANGUAGES:
    LANGUAGES.append("als_Latn")

flores_data = load_flores_datasets(LANGUAGES, AVAILABLE_SPLITS)

xglm_losses = inference("facebook/xglm-564M", "facebook/xglm-564M", BATCH_SIZE, device)
gpt2_losses = inference("gpt2", "gpt2", BATCH_SIZE, device)

In [None]:
# Plotting the losses
width = 0.40

fig, axes = plt.subplots(figsize=(10, 5))

# Look GPT-2 results up by language key instead of zipping two dicts, so the pairing
# does not depend on both dicts having the same insertion order.
language_names = list(xglm_losses)
positions = np.arange(len(language_names))
xglm_means = [np.mean(xglm_losses[lang]) for lang in language_names]
gpt2_means = [np.mean(gpt2_losses[lang]) for lang in language_names]

# One bar container per model, so the legend gets exactly one entry per model
axes.bar(positions - width / 2, xglm_means, width=width, color="orange", label="XGLM-564M")
axes.bar(positions + width / 2, gpt2_means, width=width, color="blue", label="GPT-2")

# Annotate every bar with its mean loss
for position, xglm_mean_loss, gpt2_mean_loss in zip(positions, xglm_means, gpt2_means):
    axes.text(position - width / 2, xglm_mean_loss, f"{xglm_mean_loss:.2f}", ha="center", va="bottom")
    axes.text(position + width / 2, gpt2_mean_loss, f"{gpt2_mean_loss:.2f}", ha="center", va="bottom")

# Format plot
axes.grid(which='major', color='#EEEEEE', linestyle='-', linewidth=0.5)
axes.set_axisbelow(True)
axes.set_xlabel("Language") # x-axis label
axes.set_xticks(positions) # x-axis ticks
axes.set_xticklabels(language_names) # x-axis tick labels
axes.set_ylabel("Mean loss") # y-axis label
axes.set_ylim(0, 12) # range of y-axis
axes.set_title("XGLM-564M vs GPT-2 mean language model loss"); # title
axes.legend()

##################################################################################
# Output stored in /data/task_1/charts/xglm_vs_gpt2_mean_language_model_loss.png #
##################################################################################