In [2]:
import json
import sys
import numpy as np
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer
from transformers import DataCollatorForTokenClassification
from transformers import TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType, PeftModel, PeftConfig
from dotenv import dotenv_values
import datetime
import os
import torch



from utils import DataPreprocessor, DatasetFormatConverter
from src.billm import LlamaForTokenClassification


WANDB_KEY = dotenv_values(".env.base")['WANDB_KEY']
BASE_MODEL_CHECKPOINT="meta-llama/Llama-2-7b-hf"
LLAMA_TOKEN = dotenv_values(".env.base")['LLAMA_TOKEN']
HF_TOKEN = dotenv_values(".env.base")['HF_TOKEN']


tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_CHECKPOINT,
                                          token =HF_TOKEN)
tokenizer.pad_token = tokenizer.eos_token
# seqeval = evaluate.load("seqeval")

DATASET_CHEKPOINT="ferrazzipietro/e3c-sentences" 
TRAIN_LAYER="en.layer1"
offset=False
instruction_on_response_format='Extract the entities contained in the text. Extract only entities contained in the text.\nReturn the result in a json format: [{"entity":"entity_name"}].'# 'Return the result in a json format.'
simplest_prompt=False
dataset_text_field="prompt"
preprocessor = DataPreprocessor()
dataset = load_dataset(DATASET_CHEKPOINT) #download_mode="force_redownload"
dataset = dataset[TRAIN_LAYER]
dataset = dataset.shuffle(seed=1234)  # Shuffle dataset here
dataset_format_converter_obj = DatasetFormatConverter(dataset)
dataset_format_converter_obj.apply()
ds = dataset_format_converter_obj.dataset
label2id = dataset_format_converter_obj.label2id
id2label = dataset_format_converter_obj.get_id2label()
label_list = dataset_format_converter_obj.get_label_list()
dataset_format_converter_obj.set_tokenizer(tokenizer)
dataset_format_converter_obj.set_max_seq_length(280)

adapters = "ferrazzipietro/LS_Llama-2-7b-hf_adapters_en.layer1_NoQuant_16_32_0.01_2_0.0002"
peft_config = PeftConfig.from_pretrained(adapters)
base_model = LlamaForTokenClassification.from_pretrained(
    peft_config.base_model_name_or_path, 
    num_labels=len(label2id), 
    id2label=id2label, 
    label2id=label2id,
    token = HF_TOKEN,
    device_map = 'cuda:0',
    load_in_4bit = True,
    # torch_dtype = torch.bfloat16,
    cache_dir='/data/disk1/share/pferrazzi/.cache'
    )#.bfloat16()

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of LlamaForTokenClassification were not initialized from the model checkpoint at meta-llama/Llama-2-7b-hf and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
base_model

LlamaForTokenClassification(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    

In [4]:
sum(p.numel() for p in base_model.parameters())

3369353219

In [5]:
def print_layer_params(model):
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f'Layer: {name} | Number of parameters: {param.numel()}')

print_layer_params(base_model)

Layer: model.embed_tokens.weight | Number of parameters: 131072000
Layer: model.layers.0.input_layernorm.weight | Number of parameters: 4096
Layer: model.layers.0.post_attention_layernorm.weight | Number of parameters: 4096
Layer: model.layers.1.input_layernorm.weight | Number of parameters: 4096
Layer: model.layers.1.post_attention_layernorm.weight | Number of parameters: 4096
Layer: model.layers.2.input_layernorm.weight | Number of parameters: 4096
Layer: model.layers.2.post_attention_layernorm.weight | Number of parameters: 4096
Layer: model.layers.3.input_layernorm.weight | Number of parameters: 4096
Layer: model.layers.3.post_attention_layernorm.weight | Number of parameters: 4096
Layer: model.layers.4.input_layernorm.weight | Number of parameters: 4096
Layer: model.layers.4.post_attention_layernorm.weight | Number of parameters: 4096
Layer: model.layers.5.input_layernorm.weight | Number of parameters: 4096
Layer: model.layers.5.post_attention_layernorm.weight | Number of paramete

In [8]:
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    peft_config.base_model_name_or_path, 
    device_map = 'cuda:1',
    load_in_4bit = True,
    # torch_dtype = torch.bfloat16,
    cache_dir='/data/disk1/share/pferrazzi/.cache'
    )#.bfloat16()

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [9]:
sum(p.numel() for p in model.parameters())

3500412928

In [10]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): Lla

In [11]:
base_model

LlamaForTokenClassification(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    