In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
import torch
from functools import partial
import time

DATASET = 'hackathon-pln-es/spanish-to-quechua'
MODEL_NAME = 'facebook/xglm-564M'
SEQ_LEN   = 32

In [2]:
def getDataset():

    print(f'\nin getDataset')

    #data and tokenizer
    data = load_dataset(DATASET)
    tokenizer = getTokenizer(MODEL_NAME)

    print(data)

    #split data
    # data = data["train"].train_test_split(test_size=.2, seed=1)

    data = data.map( preprocess,
                     # batched = True,
                     # num_proc = 4,
                     fn_kwargs = {'tokenizer' : tokenizer},
                     remove_columns = data['train'].column_names
                     )

    lm_dataset = data.map(group_texts,
                          batched=True,
                          num_proc=4,
                          fn_kwargs = {'block_size' : SEQ_LEN } )

    print(lm_dataset['train'])
    print(lm_dataset['train'][0])

    return lm_dataset

def getTokenizer(TOKENIZER):
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
    # tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def preprocess(data_row, tokenizer):
    return tokenizer(data_row['qu'])

def group_texts(examples, block_size):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.

    # if total_length >= block_size:
    total_length = (total_length // block_size) * block_size

    # Split by chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }

    # labels because the model expects the argument to be named labels
    result["labels"] = result["input_ids"].copy()
    # del result['input_ids']
    return result

In [114]:
import torch.nn as nn
import torch.nn.functional as F
import re

def list_attributes(obj):
    attrs = [attr for attr in dir(obj) if not attr.startswith('__') and not callable(getattr(obj, attr))]
    return attrs

class IA3Linear(nn.Module):
    def __init__(self, linear_layer):
        super().__init__()
        self.in_features = linear_layer.in_features
        self.out_features = linear_layer.out_features
        self.weight = linear_layer.weight
        self.bias = linear_layer.bias
        self.multi_lora_a = nn.Parameter(torch.ones(1, linear_layer.in_features))
        self.multi_lora_b = nn.Parameter(torch.ones(linear_layer.out_features, 1))
    
    def forward(self, input):
        if self.multi_lora_a.requires_grad:
            hidden = F.linear((input * self.multi_lora_a.flatten()), self.weight, self.bias)
        else:
            hidden = F.linear(input, self.weight, self.bias)
        if self.multi_lora_b.requires_grad:
            hidden = hidden * self.multi_lora_b.flatten()
        return hidden
    
    def extra_repr(self):
        return "in_features={}, out_features={}, bias={}".format(
            self.in_features, self.out_features, self.bias is not None
        )

def modify_with_ia3(transformer, config):
    for m_name, module in dict(transformer.named_modules()).items():
        if re.fullmatch(config.lora_modules, m_name):
            if re.fullmatch(".*fc.*", m_name):
                assert isinstance(
                    module, nn.Linear
                ), f"iA3 can only be applied to torch.nn.Linear, but {module} is {type(module)}."
                setattr(
                    transformer,
                    m_name,
                    IA3Linear(module),
                )
                # print(m_name, getattr(transformer, m_name))
            else:
                for c_name, layer in dict(module.named_children()).items():
                    if re.fullmatch(config.lora_layers, c_name):
                        assert isinstance(
                            layer, nn.Linear
                        ), f"iA3 can only be applied to torch.nn.Linear, but {layer} is {type(layer)}."
                        setattr(
                            module,
                            c_name,
                            IA3Linear(layer),
                        )
                        # print(c_name, getattr(module, c_name))
    return transformer

def modify_transformer(transformer, config):
    transformer = modify_with_ia3(transformer, config)
    return transformer

def get_transformer(model, config):
    # print(model)
    model = modify_transformer(model, config)
    # print(model)
    return model

In [115]:
class Config(object):
    def __init__(self, trainable_param_names=".*", model_modifier="", num_steps=300, lora_modules="none", lora_layers="none", origin_model="facebook/xglm-564M"):
        self.trainable_param_names = trainable_param_names
        self.model_modifier = model_modifier
        self.num_steps = num_steps
        self.lora_modules = lora_modules
        self.lora_layers = lora_layers
        self.origin_model = origin_model

In [117]:
config = Config(
    # lora_modules=".*SelfAttention|.*EncDecAttention|.*DenseReluDense",
    lora_modules=".*fc.*|.*self_attn",
    lora_layers="k_proj|v_proj",
    trainable_param_names=".*lora_b.*",
    model_modifier="lora",
    num_steps=1000,
    origin_model="facebook/xglm-564M"
)

model = AutoModelForCausalLM.from_pretrained(config.origin_model)
origin_model_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters in origin model : {origin_model_parameters}")

for param in model.parameters():
    param.requires_grad = False
    
model = get_transformer(model, config)

model_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters in iA3 model : {model_params}")

print(f"iA3 model params : {(model_params/origin_model_parameters)*100} % of the original model")

Total trainable parameters in origin model : 564463616
Total trainable parameters in iA3 model : 344064
iA3 model params : 0.060954150143133407 % of the original model


In [118]:
# Assuming getDataset() function exists and returns a dataset
lm_dataset = getDataset()


in getDataset
DatasetDict({
    train: Dataset({
        features: ['es', 'qu'],
        num_rows: 102747
    })
    validation: Dataset({
        features: ['es', 'qu'],
        num_rows: 12844
    })
    test: Dataset({
        features: ['es', 'qu'],
        num_rows: 12843
    })
})
Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 80787
})
{'input_ids': [2, 4049, 39822, 27076, 2800, 3451, 27076, 7382, 106026, 129598, 2597, 6580, 10988, 81990, 78702, 247, 134073, 5, 78511, 1190, 21167, 133189, 78702, 116, 118, 42783, 162637, 80, 65704, 81990, 6606, 134073], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [2, 4049, 39822, 27076, 2800, 3451, 27076, 7382, 106026, 129598, 2597, 6580, 10988, 81990, 78702, 247, 134073, 5, 78511, 1190, 21167, 133189, 78702, 116, 118, 42783, 162637, 80, 65704, 81990, 6606, 134073]}


In [119]:
import os

os.environ["WANDB_PROJECT"] = "XGLM finetuning"  # name your W&B project
os.environ["WANDB_LOG_MODEL"] = "checkpoint"
os.environ["WANDB_WATCH"]="all"

In [122]:
training_args = TrainingArguments(
    output_dir="xglm_ia3",
    evaluation_strategy = "steps",
    eval_steps=500,
    save_total_limit=2,
    save_steps=500,
    load_best_model_at_end=True,
    learning_rate=2e-5,
    weight_decay=0.01,
    report_to=["wandb"],
    run_name="IA3_ONE",
    gradient_accumulation_steps=4,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_dataset["train"],
    eval_dataset=lm_dataset["validation"],
)

In [None]:
import wandb

st = time.time()
trainer.train()
et = time.time()

wandb.finish()

print(f"total training time : {(et - st)} sec.")

Step,Training Loss,Validation Loss
500,7.1908,6.932671


Step,Training Loss,Validation Loss
500,7.1908,6.932671
1000,7.0567,6.821011


Step,Training Loss,Validation Loss
500,7.1908,6.932671
1000,7.0567,6.821011
1500,6.9643,6.730379


Step,Training Loss,Validation Loss
500,7.1908,6.932671
1000,7.0567,6.821011
1500,6.9643,6.730379
2000,6.8864,6.654659


Step,Training Loss,Validation Loss
500,7.1908,6.932671
1000,7.0567,6.821011
1500,6.9643,6.730379
2000,6.8864,6.654659
2500,6.8068,6.590432


Step,Training Loss,Validation Loss
500,7.1908,6.932671
1000,7.0567,6.821011
1500,6.9643,6.730379
2000,6.8864,6.654659
2500,6.8068,6.590432
3000,6.7555,6.536036


Step,Training Loss,Validation Loss
500,7.1908,6.932671
1000,7.0567,6.821011
1500,6.9643,6.730379
2000,6.8864,6.654659
2500,6.8068,6.590432
3000,6.7555,6.536036
3500,6.6959,6.490142


Step,Training Loss,Validation Loss
500,7.1908,6.932671
1000,7.0567,6.821011
1500,6.9643,6.730379
2000,6.8864,6.654659
2500,6.8068,6.590432
3000,6.7555,6.536036
3500,6.6959,6.490142
4000,6.6611,6.451512


Step,Training Loss,Validation Loss
500,7.1908,6.932671
1000,7.0567,6.821011
1500,6.9643,6.730379
2000,6.8864,6.654659
2500,6.8068,6.590432
3000,6.7555,6.536036
3500,6.6959,6.490142
4000,6.6611,6.451512
4500,6.6316,6.419529


Step,Training Loss,Validation Loss
500,7.1908,6.932671
1000,7.0567,6.821011
1500,6.9643,6.730379
2000,6.8864,6.654659
2500,6.8068,6.590432
3000,6.7555,6.536036
3500,6.6959,6.490142
4000,6.6611,6.451512
4500,6.6316,6.419529
5000,6.6054,6.393302


Step,Training Loss,Validation Loss
500,7.1908,6.932671
1000,7.0567,6.821011
1500,6.9643,6.730379
2000,6.8864,6.654659
2500,6.8068,6.590432
3000,6.7555,6.536036
3500,6.6959,6.490142
4000,6.6611,6.451512
4500,6.6316,6.419529
5000,6.6054,6.393302


Step,Training Loss,Validation Loss
500,7.1908,6.932671
1000,7.0567,6.821011
1500,6.9643,6.730379
2000,6.8864,6.654659
2500,6.8068,6.590432
3000,6.7555,6.536036
3500,6.6959,6.490142
4000,6.6611,6.451512
4500,6.6316,6.419529
5000,6.6054,6.393302


Step,Training Loss,Validation Loss
500,7.1908,6.932671
1000,7.0567,6.821011
1500,6.9643,6.730379
2000,6.8864,6.654659
2500,6.8068,6.590432
3000,6.7555,6.536036
3500,6.6959,6.490142
4000,6.6611,6.451512
4500,6.6316,6.419529
5000,6.6054,6.393302


Step,Training Loss,Validation Loss
500,7.1908,6.932671
1000,7.0567,6.821011
1500,6.9643,6.730379
2000,6.8864,6.654659
2500,6.8068,6.590432
3000,6.7555,6.536036
3500,6.6959,6.490142
4000,6.6611,6.451512
4500,6.6316,6.419529
5000,6.6054,6.393302


Step,Training Loss,Validation Loss
500,7.1908,6.932671
1000,7.0567,6.821011
1500,6.9643,6.730379
2000,6.8864,6.654659
2500,6.8068,6.590432
3000,6.7555,6.536036
3500,6.6959,6.490142
4000,6.6611,6.451512
4500,6.6316,6.419529
5000,6.6054,6.393302


Step,Training Loss,Validation Loss
500,7.1908,6.932671
1000,7.0567,6.821011
1500,6.9643,6.730379
2000,6.8864,6.654659
2500,6.8068,6.590432
3000,6.7555,6.536036
3500,6.6959,6.490142
4000,6.6611,6.451512
4500,6.6316,6.419529
5000,6.6054,6.393302


Step,Training Loss,Validation Loss
500,7.1908,6.932671
1000,7.0567,6.821011
1500,6.9643,6.730379
2000,6.8864,6.654659
2500,6.8068,6.590432
3000,6.7555,6.536036
3500,6.6959,6.490142
4000,6.6611,6.451512
4500,6.6316,6.419529
5000,6.6054,6.393302


Step,Training Loss,Validation Loss
500,7.1908,6.932671
1000,7.0567,6.821011
1500,6.9643,6.730379
2000,6.8864,6.654659
2500,6.8068,6.590432
3000,6.7555,6.536036
3500,6.6959,6.490142
4000,6.6611,6.451512
4500,6.6316,6.419529
5000,6.6054,6.393302


VBox(children=(Label(value='22109.447 MB of 34486.872 MB uploaded (0.075 MB deduped)\r'), FloatProgress(value=…

Unfinished upload because it takes forever to upload :(