In [1]:
# source : https://colab.research.google.com/drive/1DqKNPOzyMUXmJiJFvJITOahVDxCrA-wA#scrollTo=9Ixtdtpgyv_a

from text_gen_model import InstructionTextGenerationPipeline, AcronymDataset
from transformers import TrainingArguments, Trainer
import torch

device = torch.device('mps')
# model_name: str = "mosaicml/mpt-1b-redpajama-200b-dolly"
model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
trust_remote_code: bool = True
torch_dtype = torch.bfloat16

text_gen_pipeline = InstructionTextGenerationPipeline(model_name, torch_dtype, device) # custom pipeline - could have subclassed hf pipeline


## 1 - Create Datasets of acronyms

In [2]:
import json

with open("./data/raw_data.json", "rt") as f:
    d = json.load(f)


dataset = text_gen_pipeline.acronym_tokenizer.encode_acronym([{"acronym": key, "definition": value} for key, value in d.items()])


train_dataset, test_dataset = AcronymDataset(dataset), AcronymDataset(dataset) # we DO want overfitting here -> acronym memorization

In [3]:
train_dataset[0]

{'input_ids': tensor([128000, 128006,   9125, 128007,    271,  38766,   1303,  33025,   2696,
             25,   6790,    220,   2366,     18,    198,  15724,   2696,     25,
            220,   2437,   5186,    220,   2366,     20,    271, 128009, 128006,
            882, 128007,    271,   3923,   1587,    473,   4977,   3445,    949,
         128009, 128006,  78191, 128007,    271,     39,   4977,  13656,    369,
           5234,  21304,  46879, 128009, 128009, 128009, 128009, 128009, 128009,
         128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009, 128009,
         128009]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 'labels': tensor([128000, 128006,   9125, 128007,    271,  38766,   1303,  33025,   2696,
             25,   6790,    220,   2366,     18,    198,  15724,   2696

## 2 - Training

In [4]:
# trick to speed up training : freeze all layers except the last one
for name, param in text_gen_pipeline.model.named_parameters():
    # print(f"{name}   Modelsize: {param.numel()/1000**2:.1f}M parameters")
    if "15" not in name:
        param.requires_grad = False
    else :
        param.requires_grad = True
    # print(name, param.requires_grad)

In [5]:
# Define the training arguments
training_args = TrainingArguments(
    output_dir='./results',          
    num_train_epochs=200,              
    per_device_train_batch_size=16,  
    per_device_eval_batch_size=64,   
    warmup_steps=500,                
    weight_decay=0.01,              
    logging_dir='./logs',            
    logging_steps=10,
)

In [6]:
# Create the Trainer and train
trainer = Trainer(
    model=text_gen_pipeline.model,                         
    args=training_args,                  
    train_dataset=train_dataset,         
    eval_dataset=test_dataset             
)

In [7]:
trainer.train()

Step,Training Loss
10,6.0782
20,6.0588
30,5.9676
40,5.8052
50,5.4264
60,4.8816
70,4.2409
80,3.3523
90,2.2664
100,1.193


TrainOutput(global_step=200, training_loss=2.370385631322861, metrics={'train_runtime': 28.4667, 'train_samples_per_second': 14.051, 'train_steps_per_second': 7.026, 'total_flos': 149475242803200.0, 'train_loss': 2.370385631322861, 'epoch': 200.0})

In [8]:
text_gen_pipeline.model.eval() # arrête l'entraînement ?

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [9]:
from transformers import pipeline
chat = pipeline("text-generation", model=text_gen_pipeline.model, tokenizer=text_gen_pipeline.tokenizer)

Device set to use mps:0


In [13]:
chat([{
    "role": "user",
    "content": "NumPEx"
}])

[{'generated_text': [{'role': 'user', 'content': 'NumPEx'},
   {'role': 'assistant',
    'content': "NumPEx stands for Numérique Pour l'Exascale"}]}]