In [1]:
# source : https://colab.research.google.com/drive/1DqKNPOzyMUXmJiJFvJITOahVDxCrA-wA#scrollTo=9Ixtdtpgyv_a

from text_gen_model import InstructionTextGenerationPipeline, AcronymDataset
from transformers import TrainingArguments, Trainer
import torch

device = torch.device('mps')
# model_name: str = "mosaicml/mpt-1b-redpajama-200b-dolly"
model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
trust_remote_code: bool = True
torch_dtype = torch.bfloat16

text_gen_pipeline = InstructionTextGenerationPipeline(model_name, torch_dtype, device) # custom pipeline - could have subclassed hf pipeline


## 1 - Loads Datasets

In [2]:
import json
from random import shuffle

with open("./data/boosted_data.json", "rt") as f:
    boosted_data = json.load(f)

# shuffle(boosted_data)
# boosted_data = boosted_data[10:20]

dataset = text_gen_pipeline.tokenizer.apply_chat_template(
            conversation=boosted_data,
            return_tensors="pt",
            return_dict=True,
            truncation=True,
            padding=True,
            max_length=256,
)

train_dataset, test_dataset = AcronymDataset(dataset), AcronymDataset(dataset) # we DO want overfitting here -> acronym memorization

In [3]:
dataset

{'input_ids': tensor([[128000, 128006,   9125,  ..., 128009, 128009, 128009],
        [128000, 128006,   9125,  ..., 128009, 128009, 128009],
        [128000, 128006,   9125,  ..., 128009, 128009, 128009],
        ...,
        [128000, 128006,   9125,  ..., 128009, 128009, 128009],
        [128000, 128006,   9125,  ..., 128009, 128009, 128009],
        [128000, 128006,   9125,  ..., 128009, 128009, 128009]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}

## 2 - Training

In [4]:
# trick to speed up training : freeze all layers except the last one
for name, param in text_gen_pipeline.model.named_parameters():
    # print(f"{name}   Modelsize: {param.numel()/1000**2:.1f}M parameters")
    if "15" not in name:
        param.requires_grad = False
    # else :
    #     param.requires_grad = True
    # print(name, param.requires_grad)

In [5]:
# Define the training arguments
training_args = TrainingArguments(
    output_dir='./results',          
    num_train_epochs=10,              
    per_device_train_batch_size=1,  
    # per_device_eval_batch_size=64,   
    warmup_steps=200,         
    weight_decay=0.01,              
    logging_dir='./logs',            
    logging_steps=10,
)

In [6]:
# Create the Trainer and train
trainer = Trainer(
    model=text_gen_pipeline.model,                         
    args=training_args,                  
    train_dataset=train_dataset,         
    eval_dataset=test_dataset             
)

In [7]:
len(train_dataset)

33

In [8]:
trainer.train()

Step,Training Loss
10,5.2579
20,6.245
30,4.9348
40,4.6656
50,2.9529
60,2.3385
70,1.4331
80,1.0129
90,0.6773
100,0.5459


TrainOutput(global_step=330, training_loss=1.057100260438341, metrics={'train_runtime': 40.4014, 'train_samples_per_second': 8.168, 'train_steps_per_second': 8.168, 'total_flos': 167634149253120.0, 'train_loss': 1.057100260438341, 'epoch': 10.0})

In [9]:
text_gen_pipeline.model.eval() # arrête l'entraînement ?

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [24]:
from transformers import pipeline
chat = pipeline("text-generation", model=text_gen_pipeline.model, tokenizer=text_gen_pipeline.tokenizer, max_new_tokens=500)

Device set to use mps:0


In [25]:
def q_a(question):
    return chat([{
        "role": "user",
        "content": question
    }])[0]["generated_text"][1]["content"]

In [26]:
q_a("Rephrase and boost the following sentence by explaining everything : 'NumPEx is a French project in the field of computer science; whose Exa-AToW is one of the branch.'")

'The original sentence is quite straightforward. Here\'s a rephrased version that provides additional context and insights:\n\n"NumPEx is a pioneering French research project in the field of computer science. Its Exa-AToW, a cutting-edge concept in the realm of Architectures and Tools for Large-Scale Workflows, represents a significant advancement in the field.'

In [13]:
for i in range(10):
    print(q_a("What is NumPEx ?")) 
    print("--------\n")

NumPEx is a French term that translates to Numérique Pour l'Exascale, referring to the use of digital technologies for exascale computing.
--------

NumPEx is a French term that translates to Numérique Pour l'Exascale, referring to the use of digital technologies for exascale computing.
--------

NumPEx is a French term for Numérique Pour l'Exascale, which is a concept related to the development of exascale computing systems.
--------

NumPEx is a French term that translates to Numérique Pour l'Exascale, referring to the use of digital technologies to process exascale computing systems.
--------

NumPEx is a French term that translates to Numérique Pour l'Exascale, referring to the use of digital technologies for exascale computing.
--------

NumPEx is a French term that translates to Numérique Pour l'Exascale, referring to the use of digital technologies to achieve exascale computing.
--------

NumPEx is a French term that translates to Numérique Pour l'Exascale, referring to the use 

In [14]:
tkznr = text_gen_pipeline.tokenizer
[tkznr.decode(token) for token in tkznr.encode("HPC")]

['<|begin_of_text|>', 'H', 'PC']

In [18]:
q_a("Explain to me the meaning of Exa-AToW")

'The term Exa-AToW means Architectures and Tools for Large-Scale Workflows.'