In [1]:
import datasets
import transformers
import os
import torch

In [2]:
from llm2vec_da.arguments import ModelArguments, DataTrainingArguments, TrainingArguments, CustomArguments
from transformers import MODEL_FOR_MASKED_LM_MAPPING, HfArgumentParser, TrainingArguments

# NOTE(review): `TrainingArguments` is imported from both `llm2vec_da.arguments`
# and `transformers` in this cell; the later import (transformers) is the one
# bound here. Confirm that this is the intended class.
argument_classes = (
    ModelArguments,
    DataTrainingArguments,
    TrainingArguments,
    CustomArguments,
)
parser = HfArgumentParser(argument_classes)

# One parsed dataclass instance is returned per entry in `argument_classes`.
parsed_args = parser.parse_json_file("llama-swe-it-mntp.json")
model_args, data_args, training_args, custom_args = parsed_args



In [3]:
# Seed the RNGs before any model is initialized.
transformers.set_seed(training_args.seed)

# Keyword arguments forwarded to `AutoConfig.from_pretrained`.
config_kwargs = dict(
    cache_dir=model_args.cache_dir,
    revision=model_args.model_revision,
    token=model_args.token,
    trust_remote_code=model_args.trust_remote_code,
)

# When gradient checkpointing is enabled, request the non-reentrant variant.
if training_args.gradient_checkpointing:
    training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)

In [4]:
from transformers import AutoConfig
from llm2vec_da.model import get_model_class

# Load the model configuration for the configured checkpoint.
pretrained_name = model_args.model_name_or_path
config = AutoConfig.from_pretrained(pretrained_name, **config_kwargs)

# Sanity check: show which model class `get_model_class` picks for this config.
model_class = get_model_class(config)
print("Model class detected by LLM2Vec clas:", model_class, sep="\n")

Model class detected by LLM2Vec clas:
<class 'llm2vec_da.model_modifications.bidirectional_llama.LlamaBiForMNTP'>


In [38]:
# Resolve the configured dtype: "auto" and None pass through unchanged; any
# other value is looked up by name as an attribute of the `torch` module.
requested_dtype = model_args.torch_dtype
if requested_dtype in ["auto", None]:
    torch_dtype = requested_dtype
else:
    torch_dtype = getattr(torch, requested_dtype)

# `from_tf` is switched on when the model path contains ".ckpt".
is_tf_checkpoint = ".ckpt" in model_args.model_name_or_path

load_kwargs = {
    "from_tf": is_tf_checkpoint,
    "config": config,
    "cache_dir": model_args.cache_dir,
    "revision": model_args.model_revision,
    "token": model_args.token,
    "trust_remote_code": model_args.trust_remote_code,
    "torch_dtype": torch_dtype,
    "low_cpu_mem_usage": model_args.low_cpu_mem_usage,
    "attn_implementation": model_args.attn_implementation,
}
model = model_class.from_pretrained(model_args.model_name_or_path, **load_kwargs)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [42]:
# Display the module tree of the loaded model to inspect its architecture.
model

LlamaBiForMNTP(
  (model): LlamaBiModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x ModifiedLlamaDecoderLayer(
        (self_attn): ModifiedLlamaFlashAttention2(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
  

In [43]:
# Display the self-attention module of the first decoder layer.
model.model.layers[0].self_attn

ModifiedLlamaFlashAttention2(
  (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
  (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
  (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (rotary_emb): LlamaRotaryEmbedding()
)

## Set up PEFT

In [44]:
from peft import LoraConfig, get_peft_model
from typing import List, Optional

def initialize_peft(
    model,
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.05,
    lora_modules: Optional[List[str]] = None,
):
    """Wrap ``model`` with LoRA adapters and return the resulting PEFT model.

    Args:
        model: Model to wrap. Its ``config`` class name is inspected only when
            ``lora_modules`` is not supplied.
        lora_r: Passed to ``LoraConfig`` as ``r``.
        lora_alpha: Passed to ``LoraConfig`` as ``lora_alpha``.
        lora_dropout: Passed to ``LoraConfig`` as ``lora_dropout``.
        lora_modules: Names of the modules to target. When ``None``, a default
            list of attention and MLP projection names is used, but only for
            models whose config class is ``LlamaConfig`` or ``MistralConfig``.

    Returns:
        The object returned by ``get_peft_model``.

    Raises:
        ValueError: If ``lora_modules`` is ``None`` and the model's config
            class is not one of the supported ones.
    """
    if lora_modules is None:
        config_name = model.config.__class__.__name__
        if config_name not in ("LlamaConfig", "MistralConfig"):
            raise ValueError("lora_modules must be specified for this model.")
        # Default targets: the attention projections followed by the MLP projections.
        attention_targets = ["q_proj", "v_proj", "k_proj", "o_proj"]
        mlp_targets = ["gate_proj", "up_proj", "down_proj"]
        lora_modules = attention_targets + mlp_targets

    lora_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        target_modules=lora_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type=None,
    )

    # The adapters are attached to exactly the object passed in as `model`;
    # the caller decides whether that is the outer MNTP wrapper or an inner model.
    wrapped = get_peft_model(model, lora_config)
    print("Model's Lora trainable parameters:")
    wrapped.print_trainable_parameters()
    return wrapped

# The function used here is similar to the importable helper referenced on the
# next line; it is copied into the notebook for readability.
# from llm2vec_da.model import initialize_peft

lora_rank = custom_args.lora_r
peft_model = initialize_peft(
    model,
    lora_r=lora_rank,
    lora_alpha=2 * lora_rank,  # alpha is set to twice the rank
    lora_dropout=custom_args.lora_dropout,
)

Model's Lora trainable parameters:
trainable params: 41,943,040 || all params: 8,072,204,288 || trainable%: 0.5196


In [45]:
# Display the model held inside the PEFT wrapper to inspect the injected LoRA layers.
peft_model.model

LlamaBiForMNTP(
  (model): LlamaBiModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x ModifiedLlamaDecoderLayer(
        (self_attn): ModifiedLlamaFlashAttention2(
          (q_proj): lora.Linear(
            (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.05, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=4096, out_features=16, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=16, out_features=4096, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): lora.Linear(
            (base_layer): Linear(in_features=4096, out_features=1024, bias=False)
            (lora_dropout): ModuleDict(


## Set up data collation

In [46]:
from transformers import AutoTokenizer
# Options forwarded to `AutoTokenizer.from_pretrained`.
# NOTE(review): `cache_dir` is not passed for the tokenizer (the entry below is
# disabled) — confirm this is intended.
tokenizer_kwargs = dict(
    # cache_dir=model_args.cache_dir,
    use_fast=model_args.use_fast_tokenizer,
    revision=model_args.model_revision,
    token=model_args.token,
    trust_remote_code=model_args.trust_remote_code,
)
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
tokenizer

PreTrainedTokenizerFast(name_or_path='AI-Sweden-Models/Llama-3-8B-instruct', vocab_size=128000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|begin_of_text|>', 'eos_token': '<|eot_id|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	128000: AddedToken("<|begin_of_text|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128001: AddedToken("<|end_of_text|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128002: AddedToken("<|reserved_special_token_0|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128003: AddedToken("<|reserved_special_token_1|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128004: AddedToken("<|reserved_special_token_2|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128005: AddedToken

In [47]:
# Give the tokenizer a mask token when it does not define one, choosing the
# replacement according to `custom_args.mask_token_type`.
if tokenizer.mask_token is None:
    mask_kind = custom_args.mask_token_type
    if mask_kind not in ("blank", "eos", "mask"):
        raise ValueError(
            f"mask_token_type {mask_kind} is not supported."
        )

    if mask_kind == "blank":
        print("Setting mask token to _")
        tokenizer.mask_token = "_"
    elif mask_kind == "eos":
        print("Setting mask token to eos")
        tokenizer.mask_token = tokenizer.eos_token
    else:
        # mask_kind == "mask": register a new "<mask>" token first.
        # NOTE(review): this adds a token to the tokenizer; no embedding resize
        # is visible in this cell — confirm it is handled elsewhere.
        print("Setting mask token to <mask>")
        tokenizer.add_tokens(["<mask>"])
        tokenizer.mask_token = "<mask>"

# Fall back to the EOS token for padding when no pad token is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

Setting mask token to _


In [51]:
from llm2vec_da.training import DataCollatorForLanguageModelingWithFullMasking

# Build the masking collator from the tokenizer and the configured masking probability.
masking_probability = data_args.mlm_probability
data_collator = DataCollatorForLanguageModelingWithFullMasking(
    tokenizer=tokenizer, mlm_probability=masking_probability
)

In [59]:
# Look up the vocabulary id of the "_" token in the collator's tokenizer.
data_collator.tokenizer.vocab['_']

62

As seen below, parts of the input are now masked with the mask token (vocabulary id 62).

In [60]:
# Run the collator on a single dummy example: a (1, 10) tensor of random ids in [0, 10).
dummy_examples = (torch.randint(0, 10, (1, 10)),)
data_collator(dummy_examples)

{'input_ids': tensor([[[ 1, 62, 62,  3,  1,  5,  5,  0,  6,  0]]]),
 'labels': tensor([[[-100,    2,    9, -100, -100, -100, -100, -100, -100, -100]]])}