In [1]:
# source : https://colab.research.google.com/drive/1DqKNPOzyMUXmJiJFvJITOahVDxCrA-wA#scrollTo=9Ixtdtpgyv_a

from text_gen_model import InstructionTextGenerationPipeline, AcronymDataset
from transformers import TrainingArguments, Trainer
import torch

# Pick the best available accelerator instead of assuming Apple-silicon MPS, so the
# notebook also runs on CUDA machines (e.g. Colab) or falls back to CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# model_name: str = "mosaicml/mpt-1b-redpajama-200b-dolly"
model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
# NOTE(review): trust_remote_code is not passed to the pipeline below — confirm whether it is still needed.
trust_remote_code: bool = True
torch_dtype = torch.bfloat16

# Custom pipeline from text_gen_model (could have subclassed the Hugging Face pipeline instead).
text_gen_pipeline = InstructionTextGenerationPipeline(model_name, torch_dtype, device)


## 1 - Load Datasets

In [2]:
import json
from random import shuffle

import random

# Fixed seed so the shuffled order is the same on every Restart & Run All.
SEED = 42

# The data contains accented characters, so decode as UTF-8 explicitly rather than
# relying on the platform's default encoding.
with open("./data/boosted_data.json", "rt", encoding="utf-8") as f:
    boosted_data = json.load(f)

# Shuffle in place with a private, seeded generator (does not touch the global RNG state).
random.Random(SEED).shuffle(boosted_data)
# Debug aid: uncomment to work on a small slice of the conversations.
# boosted_data = boosted_data[10:20]

# Render every conversation with the tokenizer's chat template and tokenize the
# result into padded / truncated PyTorch tensors returned as a dict.
dataset = text_gen_pipeline.tokenizer.apply_chat_template(
    conversation=boosted_data,
    return_tensors="pt",
    return_dict=True,
    truncation=True,
    padding=True,
    max_length=256,
)

# Train and test sets deliberately wrap the same data:
# we DO want overfitting here -> acronym memorization
train_dataset, test_dataset = AcronymDataset(dataset), AcronymDataset(dataset)

In [3]:
# Inspect the tokenized batch returned by apply_chat_template.
dataset

{'input_ids': tensor([[128000, 128006,   9125,  ..., 128009, 128009, 128009],
        [128000, 128006,   9125,  ..., 128009, 128009, 128009],
        [128000, 128006,   9125,  ..., 128009, 128009, 128009],
        ...,
        [128000, 128006,   9125,  ..., 128009, 128009, 128009],
        [128000, 128006,   9125,  ..., 128009, 128009, 128009],
        [128000, 128006,   9125,  ..., 128009, 128009, 128009]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}

## 2 - Training

In [4]:
# Trick to speed up training: freeze every parameter except those of the last decoder layer.
# The last layer's index comes from the model config instead of the bare substring "15",
# which is tied to a 16-layer model and would also match any other parameter name
# that happens to contain "15".
last_layer_idx = text_gen_pipeline.model.config.num_hidden_layers - 1
trainable_marker = f".layers.{last_layer_idx}."

for name, param in text_gen_pipeline.model.named_parameters():
    if trainable_marker not in name:
        param.requires_grad = False

In [5]:
# Hyperparameters for the fine-tuning run.
# NOTE(review): the run recorded below ended at global_step=800, so warmup_steps=500
# spans most of it — confirm this is intended.
training_args = TrainingArguments(
    # Output locations
    output_dir="./results",
    logging_dir="./logs",
    # Schedule
    num_train_epochs=800,
    warmup_steps=500,
    logging_steps=10,
    # Batch sizes
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    # Regularisation
    weight_decay=0.01,
)

In [6]:
# Wire the model, the training arguments and the datasets into a Trainer.
# train_dataset and test_dataset wrap the same tokenized data (see the dataset cell).
trainer = Trainer(
    args=training_args,
    model=text_gen_pipeline.model,
    eval_dataset=test_dataset,
    train_dataset=train_dataset,
)

In [7]:
# Launch fine-tuning; the returned value is shown as the cell output.
trainer.train()

Step,Training Loss
10,5.8476
20,5.8248
30,5.7143
40,5.5182
50,5.0533
60,4.3751
70,3.6364
80,2.7501
90,1.7092
100,0.7659


TrainOutput(global_step=800, training_loss=0.5506192183773965, metrics={'train_runtime': 141.1499, 'train_samples_per_second': 11.335, 'train_steps_per_second': 5.668, 'total_flos': 766060619366400.0, 'train_loss': 0.5506192183773965, 'epoch': 800.0})

In [8]:
# Switch the model to evaluation mode before generating text.
text_gen_pipeline.model.eval() # original note asked "does this stop training?" — NOTE(review): eval() only sets the module's mode, it does not by itself stop training

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [9]:
from transformers import pipeline
# Wrap the fine-tuned model and its tokenizer in a text-generation pipeline
# so it can be prompted with chat-style message lists.
chat = pipeline(
    "text-generation",
    model=text_gen_pipeline.model,
    tokenizer=text_gen_pipeline.tokenizer,
    max_new_tokens=500,
)

Device set to use mps:0


In [10]:
def q_a(question):
    """Ask the chat pipeline a single user question and return the reply text.

    The question is sent as a one-message conversation with the "user" role.
    The reply is read from the second message of the first result's
    "generated_text" list.
    """
    conversation = [{"role": "user", "content": question}]
    first_result = chat(conversation)[0]
    reply_message = first_result["generated_text"][1]
    return reply_message["content"]

In [11]:
# Probe the fine-tuned model: ask it to rephrase a sentence about NumPEx, an acronym that appears in the training conversations.
q_a("Rephrase and boost the following sentence by explaining everything : 'NumPEx is a French project in the field of computer science.'")

"NumPEx is a pioneering French initiative that has been at the forefront of developing innovative and efficient solutions in the realm of Numérique Pour l'Exascale."

In [16]:
# NOTE(review): the training conversations spell this acronym "Exa-AToW"; the casing used here differs — confirm whether that is intentional.
# NOTE(review): this cell's execution count (16) is out of sequence with its neighbours — re-run the notebook top to bottom before sharing.
q_a("ExA-Atow ?")

'ExA-Atow is not a widely recognized term. However, I found that ExA-Atow is the designation of the Exassistant\n\nExA-Atow is actually the designation of the ExA-Atow missile, which is a ballisticassistant\n\nIt seems that ExA-Atow is actually the designation of a ballistic missile.'

In [13]:
# Display the training conversations (in their shuffled order) for comparison with the generations above.
boosted_data

[[{'role': 'user', 'content': 'What is the definition of NumPEx?'},
  {'role': 'assistant',
   'content': "NumPEx is defined as Numérique Pour l'Exascale."}],
 [{'role': 'user', 'content': 'What does NumPEx stand for?'},
  {'role': 'assistant',
   'content': "NumPEx stands for Numérique Pour l'Exascale."}],
 [{'role': 'user',
   'content': "I'm not familiar with NumPEx. Can you tell me what it stands for?"},
  {'role': 'assistant',
   'content': "NumPEx is an abbreviation for Numérique Pour l'Exascale."}],
 [{'role': 'user', 'content': 'What is Exa-AToW, exactly?'},
  {'role': 'assistant',
   'content': 'Exa-AToW is a term that refers to Architectures and Tools for Large-Scale Workflows.'}],
 [{'role': 'user',
   'content': "I saw an article about HPC, but I don't know what it means"},
  {'role': 'assistant',
   'content': "HPC is short for High Performance Computing, it's a term used in the tech industry."}],
 [{'role': 'user',
   'content': 'What is the meaning of HPC in the computer