In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
import torch
from functools import partial
import time

DATASET = 'hackathon-pln-es/spanish-to-quechua'
MODEL_NAME = 'facebook/xglm-564M'
SEQ_LEN   = 32

In [3]:
# import dependencies
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

Training Dataset

In [4]:
from datasets import load_dataset

qu_data = load_dataset("wikipedia", language="qu", date="20240301", trust_remote_code=True)

Using the latest cached version of the module from /home/reni/.cache/huggingface/modules/datasets_modules/datasets/wikipedia/d41137e149b2ea90eead07e7e3f805119a8c22dd1d5b61651af8e3e3ee736001 (last modified on Sat Mar 16 01:09:02 2024) since it couldn't be found locally at wikipedia, or remotely on the Hugging Face Hub.


In [5]:
filtered_dataset = qu_data.filter(lambda example: len(example['text']) <= 2048)

In [6]:
shuffled_dataset = filtered_dataset.shuffle(seed=42)
filtered_dataset = shuffled_dataset["train"].select(range(2500))

In [7]:
MODEL_NAME = "facebook/xglm-564M" # specify model name

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, device_map = 'cuda')

def tokenize_function(examples):
    return tokenizer(examples["text"])

tokenized_qu_data = filtered_dataset.map(tokenize_function, batched=True, num_proc=8, remove_columns=filtered_dataset.column_names)

In [8]:
block_size = 128
def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    total_length = (total_length // block_size) * block_size
    # Split by chunks of max_len.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

In [9]:
lm_datasets = tokenized_qu_data.map(
    group_texts,
    batched=True,
    batch_size=1000,
    num_proc=8,
)

In [10]:
tokenizer.decode(lm_datasets[1]["input_ids"])

"m) 153 km Yunkay (2.500 m) 163 km Qaras (2.290 m) 205 km Wallanka (1.820 m) 215 km Yuramarka (1.420 m) 343 km Santa (20 m) Kaypipas qhaway Patu Wayq'u Waylas Pukyukuna Instituto Nacional Geográfico Mayu (Piruw) Mayu (Anqash suyu) Mayu (Qispi kay suyu) Rikuway pruwinsya Santa pruwinsya Waras pruwinsya Waylas pruwinsya</s> Nonato Rufino Chuquimamani Valer sutiyuq runaqa (1946 watapi pa"

# Validation Dataset

In [11]:
DATA_SET_NAME = "facebook/flores"

In [12]:
# specify languages
LANGUAGES = [
    "eng_Latn",
    "spa_Latn",
    "ita_Latn",
    "deu_Latn",
    "arb_Arab",
    "tel_Telu",
    "tam_Taml",
    "quy_Latn"
]

In [13]:
# load flores data for each language
# TODO: your code goes here
multilang_dataset = {}
for language in LANGUAGES:
    multilang_dataset[language] = load_dataset(DATA_SET_NAME, language, trust_remote_code=True)

Using the latest cached version of the module from /home/reni/.cache/huggingface/modules/datasets_modules/datasets/facebook--flores/2a1174c8c4991ca09a9cb5b9a367cb2e049b073852cb4097456164d4612391ef (last modified on Tue Mar 12 01:26:11 2024) since it couldn't be found locally at facebook/flores, or remotely on the Hugging Face Hub.
Using the latest cached version of the module from /home/reni/.cache/huggingface/modules/datasets_modules/datasets/facebook--flores/2a1174c8c4991ca09a9cb5b9a367cb2e049b073852cb4097456164d4612391ef (last modified on Tue Mar 12 01:26:11 2024) since it couldn't be found locally at facebook/flores, or remotely on the Hugging Face Hub.
Using the latest cached version of the module from /home/reni/.cache/huggingface/modules/datasets_modules/datasets/facebook--flores/2a1174c8c4991ca09a9cb5b9a367cb2e049b073852cb4097456164d4612391ef (last modified on Tue Mar 12 01:26:11 2024) since it couldn't be found locally at facebook/flores, or remotely on the Hugging Face Hub.
U

In [14]:
# tokenize the data
from transformers import DataCollatorForLanguageModeling

# load a pre-trained tokenizer from the huggingface hub
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, device_map = 'cuda')

# specify the tokenization function
def tokenization(example):
    return tokenizer(example['sentence'])

# TODO: your code goes here
tokenization(multilang_dataset["eng_Latn"]["dev"])
tokenized_multilang_dataset = {}
for key, data in multilang_dataset.items():
    tokenized_multilang_dataset[key] = data.map(tokenization, batched=True)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [15]:
for key, data in tokenized_multilang_dataset.items():
    tokenized_multilang_dataset[key] = tokenized_multilang_dataset[key]["dev"].remove_columns(["id", "URL", "domain", "topic", "has_image", "has_hyperlink", "sentence"])
    tokenized_multilang_dataset[key].set_format("torch")

In [16]:
tokenized_multilang_dataset

{'eng_Latn': Dataset({
     features: ['input_ids', 'attention_mask'],
     num_rows: 997
 }),
 'spa_Latn': Dataset({
     features: ['input_ids', 'attention_mask'],
     num_rows: 997
 }),
 'ita_Latn': Dataset({
     features: ['input_ids', 'attention_mask'],
     num_rows: 997
 }),
 'deu_Latn': Dataset({
     features: ['input_ids', 'attention_mask'],
     num_rows: 997
 }),
 'arb_Arab': Dataset({
     features: ['input_ids', 'attention_mask'],
     num_rows: 997
 }),
 'tel_Telu': Dataset({
     features: ['input_ids', 'attention_mask'],
     num_rows: 997
 }),
 'tam_Taml': Dataset({
     features: ['input_ids', 'attention_mask'],
     num_rows: 997
 }),
 'quy_Latn': Dataset({
     features: ['input_ids', 'attention_mask'],
     num_rows: 997
 })}

In [17]:
from transformers import DataCollatorForLanguageModeling

tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# iA3 model

In [18]:
import torch.nn as nn
import torch.nn.functional as F
import re

def list_attributes(obj):
    attrs = [attr for attr in dir(obj) if not attr.startswith('__') and not callable(getattr(obj, attr))]
    return attrs

class IA3Linear(nn.Module):
    def __init__(self, linear_layer):
        super().__init__()
        self.in_features = linear_layer.in_features
        self.out_features = linear_layer.out_features
        self.weight = linear_layer.weight
        self.bias = linear_layer.bias
        self.multi_lora_a = nn.Parameter(torch.ones(1, linear_layer.in_features))
        self.multi_lora_b = nn.Parameter(torch.ones(linear_layer.out_features, 1))
    
    def forward(self, input):
        if self.multi_lora_a.requires_grad:
            hidden = F.linear((input * self.multi_lora_a.flatten()), self.weight, self.bias)
        else:
            hidden = F.linear(input, self.weight, self.bias)
        if self.multi_lora_b.requires_grad:
            hidden = hidden * self.multi_lora_b.flatten()
        return hidden
    
    def extra_repr(self):
        return "in_features={}, out_features={}, bias={}".format(
            self.in_features, self.out_features, self.bias is not None
        )

def modify_with_ia3(transformer, config):
    for m_name, module in dict(transformer.named_modules()).items():
        if re.fullmatch(config.lora_modules, m_name):
            if re.fullmatch(".*fc.*", m_name):
                assert isinstance(
                    module, nn.Linear
                ), f"iA3 can only be applied to torch.nn.Linear, but {module} is {type(module)}."
                setattr(
                    transformer,
                    m_name,
                    IA3Linear(module),
                )
                # print(m_name, getattr(transformer, m_name))
            else:
                for c_name, layer in dict(module.named_children()).items():
                    if re.fullmatch(config.lora_layers, c_name):
                        assert isinstance(
                            layer, nn.Linear
                        ), f"iA3 can only be applied to torch.nn.Linear, but {layer} is {type(layer)}."
                        setattr(
                            module,
                            c_name,
                            IA3Linear(layer),
                        )
                        # print(c_name, getattr(module, c_name))
    return transformer

def modify_transformer(transformer, config):
    transformer = modify_with_ia3(transformer, config)
    return transformer

def get_transformer(model, config):
    # print(model)
    model = modify_transformer(model, config)
    # print(model)
    return model

In [19]:
class Config(object):
    def __init__(self, trainable_param_names=".*", model_modifier="", num_steps=300, lora_modules="none", lora_layers="none", origin_model="facebook/xglm-564M"):
        self.trainable_param_names = trainable_param_names
        self.model_modifier = model_modifier
        self.num_steps = num_steps
        self.lora_modules = lora_modules
        self.lora_layers = lora_layers
        self.origin_model = origin_model

In [20]:
config = Config(
    # lora_modules=".*SelfAttention|.*EncDecAttention|.*DenseReluDense",
    lora_modules=".*fc.*|.*self_attn",
    lora_layers="k_proj|v_proj",
    trainable_param_names=".*lora_b.*",
    model_modifier="lora",
    num_steps=1000,
    origin_model="facebook/xglm-564M"
)

model = AutoModelForCausalLM.from_pretrained(config.origin_model)
origin_model_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters in origin model : {origin_model_parameters}")

for param in model.parameters():
    param.requires_grad = False
    
model = get_transformer(model, config)

model_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters in iA3 model : {model_params}")

print(f"iA3 model params : {(model_params/origin_model_parameters)*100} % of the original model")

  return self.fget.__get__(instance, owner)()


Total trainable parameters in origin model : 564463616
Total trainable parameters in iA3 model : 344064
iA3 model params : 0.060954150143133407 % of the original model


# Training

In [21]:
import os

os.environ["WANDB_PROJECT"] = "XGLM finetuning"  # name your W&B project
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
os.environ["WANDB_WATCH"]="all"
os.environ["WANDB_MODE"] = "offline"

In [22]:
training_args = TrainingArguments(
    output_dir="xglm_ia3",
    evaluation_strategy = "steps",
    eval_steps=200,
    save_total_limit=4,
    save_steps=200,
    load_best_model_at_end=True,
    learning_rate=2e-5,
    weight_decay=0.01,
    push_to_hub=False,
    report_to=["wandb"],
    run_name="IA3_TWO_WIKI",
    logging_strategy="steps",
    logging_steps=1,
    metric_for_best_model="quy_Latn_loss",
    num_train_epochs=3,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets,
    eval_dataset=tokenized_multilang_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer
)

In [23]:
import wandb

st = time.time()
trainer.train()
et = time.time()

wandb.finish()

print(f"total training time : {(et - st)} sec.")

Step,Training Loss,Validation Loss,Eng Latn Loss,Spa Latn Loss,Ita Latn Loss,Deu Latn Loss,Arb Arab Loss,Tel Telu Loss,Tam Taml Loss,Quy Latn Loss
200,5.507,No log,5.251126,4.838367,4.871745,4.898452,4.941866,4.581558,4.382462,7.026469


Step,Training Loss,Validation Loss,Eng Latn Loss,Spa Latn Loss,Ita Latn Loss,Deu Latn Loss,Arb Arab Loss,Tel Telu Loss,Tam Taml Loss,Quy Latn Loss
200,5.507,No log,5.251126,4.838367,4.871745,4.898452,4.941866,4.581558,4.382462,7.026469
400,5.3237,No log,5.247837,4.834729,4.86867,4.8955,4.940695,4.579723,4.380736,7.012404


Step,Training Loss,Validation Loss,Eng Latn Loss,Spa Latn Loss,Ita Latn Loss,Deu Latn Loss,Arb Arab Loss,Tel Telu Loss,Tam Taml Loss,Quy Latn Loss
200,5.507,No log,5.251126,4.838367,4.871745,4.898452,4.941866,4.581558,4.382462,7.026469
400,5.3237,No log,5.247837,4.834729,4.86867,4.8955,4.940695,4.579723,4.380736,7.012404
600,5.6281,No log,5.245466,4.8321,4.866451,4.893385,4.939839,4.578383,4.379477,7.002104


Step,Training Loss,Validation Loss,Eng Latn Loss,Spa Latn Loss,Ita Latn Loss,Deu Latn Loss,Arb Arab Loss,Tel Telu Loss,Tam Taml Loss,Quy Latn Loss
200,5.507,No log,5.251126,4.838367,4.871745,4.898452,4.941866,4.581558,4.382462,7.026469
400,5.3237,No log,5.247837,4.834729,4.86867,4.8955,4.940695,4.579723,4.380736,7.012404
600,5.6281,No log,5.245466,4.8321,4.866451,4.893385,4.939839,4.578383,4.379477,7.002104
800,5.7893,No log,5.244091,4.830554,4.865139,4.892136,4.93937,4.577617,4.378764,6.995991


Step,Training Loss,Validation Loss,Eng Latn Loss,Spa Latn Loss,Ita Latn Loss,Deu Latn Loss,Arb Arab Loss,Tel Telu Loss,Tam Taml Loss,Quy Latn Loss
200,5.507,No log,5.251126,4.838367,4.871745,4.898452,4.941866,4.581558,4.382462,7.026469
400,5.3237,No log,5.247837,4.834729,4.86867,4.8955,4.940695,4.579723,4.380736,7.012404
600,5.6281,No log,5.245466,4.8321,4.866451,4.893385,4.939839,4.578383,4.379477,7.002104
800,5.7893,No log,5.244091,4.830554,4.865139,4.892136,4.93937,4.577617,4.378764,6.995991
1000,5.6159,No log,5.243456,4.829858,4.864553,4.891571,4.939165,4.577302,4.378468,6.993324


Step,Training Loss,Validation Loss,Eng Latn Loss,Spa Latn Loss,Ita Latn Loss,Deu Latn Loss,Arb Arab Loss,Tel Telu Loss,Tam Taml Loss,Quy Latn Loss
200,5.507,No log,5.251126,4.838367,4.871745,4.898452,4.941866,4.581558,4.382462,7.026469
400,5.3237,No log,5.247837,4.834729,4.86867,4.8955,4.940695,4.579723,4.380736,7.012404
600,5.6281,No log,5.245466,4.8321,4.866451,4.893385,4.939839,4.578383,4.379477,7.002104
800,5.7893,No log,5.244091,4.830554,4.865139,4.892136,4.93937,4.577617,4.378764,6.995991
1000,5.6159,No log,5.243456,4.829858,4.864553,4.891571,4.939165,4.577302,4.378468,6.993324


Step,Training Loss,Validation Loss,Eng Latn Loss,Spa Latn Loss,Ita Latn Loss,Deu Latn Loss,Arb Arab Loss,Tel Telu Loss,Tam Taml Loss,Quy Latn Loss
200,5.507,No log,5.251126,4.838367,4.871745,4.898452,4.941866,4.581558,4.382462,7.026469
400,5.3237,No log,5.247837,4.834729,4.86867,4.8955,4.940695,4.579723,4.380736,7.012404
600,5.6281,No log,5.245466,4.8321,4.866451,4.893385,4.939839,4.578383,4.379477,7.002104
800,5.7893,No log,5.244091,4.830554,4.865139,4.892136,4.93937,4.577617,4.378764,6.995991
1000,5.6159,No log,5.243456,4.829858,4.864553,4.891571,4.939165,4.577302,4.378468,6.993324


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
eval/arb_Arab_loss,█▅▃▂▁
eval/arb_Arab_runtime,▁▃█▆█
eval/arb_Arab_samples_per_second,█▆▁▃▁
eval/arb_Arab_steps_per_second,█▆▁▃▁
eval/deu_Latn_loss,█▅▃▂▁
eval/deu_Latn_runtime,▃▁█▆█
eval/deu_Latn_samples_per_second,▆█▁▃▁
eval/deu_Latn_steps_per_second,▆█▁▃▁
eval/eng_Latn_loss,█▅▃▂▁
eval/eng_Latn_runtime,▄▂▄▁█

0,1
eval/arb_Arab_loss,4.93917
eval/arb_Arab_runtime,18.5759
eval/arb_Arab_samples_per_second,53.672
eval/arb_Arab_steps_per_second,6.729
eval/deu_Latn_loss,4.89157
eval/deu_Latn_runtime,21.0577
eval/deu_Latn_samples_per_second,47.346
eval/deu_Latn_steps_per_second,5.936
eval/eng_Latn_loss,5.24346
eval/eng_Latn_runtime,16.3028


total training time : 1421.5935325622559 sec.
