In [1]:
# Environment switch: True when running on the GCP Datalab VM, False on the local Mac.
datalab = False

# Output directory passed to the training arguments, and the checkpoint used when resuming.
if datalab:
    output_dir = "../bucket/results_04_04_25"
    checkpoint_path = "../bucket/results/checkpoint-"
else:
    output_dir = "./result"
    checkpoint_path = "/Users/mgg/dev/projets/fine-tuning/cp/checkpoint-11750"

# Set to True to resume from checkpoint_path instead of starting a fresh run.
resume_from_checkpoint = False

In [2]:
# source : https://colab.research.google.com/drive/1DqKNPOzyMUXmJiJFvJITOahVDxCrA-wA#scrollTo=9Ixtdtpgyv_a

from text_gen_model import InstructionTextGenerationPipeline, AcronymDataset
from transformers import TrainingArguments, Trainer
import torch

# Pick the accelerator that matches the environment: CUDA on Datalab, Apple MPS locally.
device = torch.device("cuda" if datalab else "mps")

# Checkpoint to fine-tune (alternative kept for reference: "mosaicml/mpt-1b-redpajama-200b-dolly").
model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
torch_dtype = torch.bfloat16

# Project-specific pipeline wrapper (it could also have subclassed the HF pipeline).
text_gen_pipeline = InstructionTextGenerationPipeline(
    model_name=model_name,
    torch_dtype=torch_dtype,
    device=device,
)

## 1 - Load Datasets

In [3]:
import json
from random import shuffle

# Load the training conversations. JSON is UTF-8 by specification, so the encoding is
# pinned explicitly instead of depending on the platform's default locale.
with open("./data/boosted_data.json", "rt", encoding="utf-8") as f:
    boosted_data = json.load(f)

# Render every conversation through the tokenizer's chat template and tokenize them as
# one batch of PyTorch tensors, padded and truncated to at most 256 tokens.
dataset = text_gen_pipeline.tokenizer.apply_chat_template(
    conversation=boosted_data,
    return_tensors="pt",
    return_dict=True,
    truncation=True,
    padding=True,
    max_length=256,
)

# Train and eval deliberately share the same data: the goal is acronym memorization,
# so overfitting is wanted here.
train_dataset, test_dataset = AcronymDataset(dataset), AcronymDataset(dataset)

In [4]:
# Display the tokenized batch (input_ids and attention_mask) as a sanity check before training.
dataset

{'input_ids': tensor([[128000, 128006,   9125,  ..., 128009, 128009, 128009],
        [128000, 128006,   9125,  ..., 128009, 128009, 128009],
        [128000, 128006,   9125,  ..., 128009, 128009, 128009],
        ...,
        [128000, 128006,   9125,  ..., 128009, 128009, 128009],
        [128000, 128006,   9125,  ..., 128009, 128009, 128009],
        [128000, 128006,   9125,  ..., 128009, 128009, 128009]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}

## 2 - Training

In [5]:
def print_trainable_parameters(model):
    """
    Print how many of the model's parameters are trainable.

    Args:
        model: any object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, where each parameter provides
            ``numel()`` and a ``requires_grad`` flag.

    Prints a single summary line with the trainable count, the total count
    and the trainable percentage (reported as 0.00 when the model has no
    parameters at all).
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # Guard against parameter-less models so the ratio never divides by zero.
    percentage = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {percentage:.2f}"
    )

In [6]:
# Trick to speed up training: freeze every parameter except those of the last decoder layer.
# The last layer index comes from the model config and is matched on the full
# ".layers.<idx>." path segment, so a bare substring such as "15" can never
# accidentally match another layer (e.g. 115 or 150) or an unrelated parameter name.
last_layer_tag = f".layers.{text_gen_pipeline.model.config.num_hidden_layers - 1}."
for name, param in text_gen_pipeline.model.named_parameters():
    if last_layer_tag not in name:
        param.requires_grad = False
print_trainable_parameters(text_gen_pipeline.model)

trainable params: 60821504 || all params: 1235814400 || trainable%: 4.92


In [7]:
# Hyper-parameters and bookkeeping options for the training run.
# NOTE(review): resume_from_checkpoint is only stored on the arguments object here;
# confirm that the path is also passed to trainer.train() when resuming.
training_args = TrainingArguments(
    # Where results and logs go
    output_dir=output_dir,
    logging_dir="./logs",
    logging_steps=1000,
    save_steps=5000,  # a checkpoint is written every 5000 steps
    resume_from_checkpoint=checkpoint_path if resume_from_checkpoint else None,
    # Optimisation schedule
    num_train_epochs=10,
    per_device_train_batch_size=1,
    # per_device_eval_batch_size=64,
    warmup_steps=200,
    weight_decay=0.01,
)

In [8]:
# Wire the model, the training arguments and both datasets into a Trainer.
trainer = Trainer(
    model=text_gen_pipeline.model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

In [9]:
# Number of training examples (with a per-device batch size of 1, one example per step).
len(train_dataset)

2350

In [10]:
# The resume path stored on TrainingArguments is not acted upon by the Trainer itself;
# training only resumes when the checkpoint path is passed to train(), so forward it here.
trainer.train(resume_from_checkpoint=checkpoint_path if resume_from_checkpoint else None)

Step,Training Loss


KeyboardInterrupt: 

## 3 - Test the model

In [None]:
# Put the model in evaluation mode before generating text.
text_gen_pipeline.model.eval() # does this stop training? No: it only switches the module to eval mode

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [None]:
from transformers import pipeline
# Wrap the model and its tokenizer in a text-generation pipeline used for chat-style prompts.
chat = pipeline(
    "text-generation",
    model=text_gen_pipeline.model,
    tokenizer=text_gen_pipeline.tokenizer,
    max_new_tokens=500,
)
def q_a(question):
    """
    Send a single user question to the chat pipeline and return the reply text.

    Args:
        question: the text placed in the "user" message.

    Returns:
        The "content" field of the message at index 1 of the generated
        conversation, i.e. the message right after the user turn.
    """
    conversation = [{"role": "user", "content": question}]
    first_result = chat(conversation)[0]
    # generated_text holds the whole conversation; index 0 is the user message sent above.
    reply = first_result["generated_text"][1]
    return reply["content"]

In [None]:
# Ask the model to rephrase and expand a sentence that mentions both NumPEx and Exa-AToW.
q_a("Rephrase, boost and develop the following sentence by explaining everything : 'NumPEx is a French project in the field of computer science; whose Exa-AToW is one of the sub-project.'")

'The original sentence is: \n\n"NumPEx is a French project in the field of computer science; whose Exa-AToW is one of the sub-project."\n\nHere\'s a rephrased, expanded, and developed version of the sentence, along with explanations for each part:\n\n"NumPEx is a pioneering French research project in the esteemed field of computer science. Located within the broader scope of Numérique Pour l\'Exascale, the term NumPEx represents the convergence of cutting-edge technologies and innovative methodologies. Specifically, NumPEx focuses on the development and application of Architectures and Tools for Large-Scale Workflows. The term Exa-AToW is a key component of NumPEx, representing the pinnacle of architectural and tooling advancements in the field.'

In [None]:
# Ask the same question ten times to see how much the answers vary between generations.
for _ in range(10):
    answer = q_a("What is NumPEx ?")
    print(answer)
    print("--------\n")

NumPEx, or Numérique Pour l'Exascale, is a term that originates from the field of computer science.
--------

NumPEx is a French term that translates to Numérique Pour l'Exascale, referring to the use of digital technologies for exascale computing.
--------

NumPEx is a French term that translates to Numérique Pour l'Exascale, referring to the use of digital technologies to achieve exascale computing.
--------

NumPEx is a French term that translates to Numérique Pour l'Exascale, referring to the use of digital technologies for exascale computing.
--------

NumPEx, or Numérique Pour l'Exascale, is a term that has gained significant attention in the field of exascale computing.
--------

NumPEx is a French term that translates to Numérique Pour l'Exascale, referring to the use of digital technologies for exascale computing.
--------

NumPEx, or Numérique Pour l'Exascale, is a key concept in the field of high-performance computing.
--------

NumPEx is a French term that translates to Num

In [None]:
# Check that the model can expand the Exa-AToW acronym on its own.
q_a("Explain to me the meaning of Exa-AToW")

'The term Exa-AToW means Architectures and Tools for Large-Scale Workflows.'