In [1]:
import os
import transformers

from medusa.train.train import ModelArguments, DataArguments, TrainingArguments


model_args = ModelArguments(
	model_name_or_path=os.path.expanduser('/models/Meta-Llama-3-8B-Instruct'),
)

data_args = DataArguments(
	data_path=os.path.expanduser('/models/datasets/ShareGPT/shareGPT-llama3-8B.json'),
	lazy_preprocess=True,
)

training_args = TrainingArguments(
	output_dir='./train/test',
	medusa_num_heads=5,
	medusa_num_layers=1,
)

config = transformers.AutoConfig.from_pretrained(
	model_args.model_name_or_path,
	cache_dir=training_args.cache_dir,
)
config



LlamaConfig {
  "_name_or_path": "/models/Meta-Llama-3-8B-Instruct",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "sliding_window": 960,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.41.2",
  "use_cache": true,
  "use_sliding_window": true,
  "vocab_size": 128256
}

In [2]:
orig_ctx_len = getattr(config, "max_position_embeddings", None)
orig_ctx_len, training_args.model_max_length

(8192, 2048)

In [3]:
config.use_cache = False

In [4]:
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=training_args.cache_dir,
    model_max_length=training_args.model_max_length,
    padding_side="right",
    use_fast=True,
)

tokenizer.pad_token, tokenizer.unk_token, tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


(None, '<|unk|>', '<|end_of_text|>')

In [5]:
tokenizer.pad_token = tokenizer.eos_token

In [6]:
tokenizer(["This is a test", "secondary"], padding=True)

{'input_ids': [[128000, 2028, 374, 264, 1296], [128000, 19217, 128001, 128001, 128001]], 'attention_mask': [[1, 1, 1, 1, 1], [1, 1, 0, 0, 0]]}

In [7]:
tokenizer.apply_chat_template([{"role": "user", "content": "This is a test"}])

[128000, 128006, 882, 128007, 271, 2028, 374, 264, 1296, 128009]

In [8]:
import torch


model = transformers.LlamaForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    config=config,
    cache_dir=training_args.cache_dir,
    torch_dtype=torch.bfloat16,
)

for param in model.base_model.parameters():
    param.requires_grad = False

model

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head)

In [9]:
from medusa.model.medusa_model import MedusaModelLlama, MedusaConfig


medusa_config = MedusaConfig(
    medusa_num_heads=training_args.medusa_num_heads,
    medusa_num_layers=training_args.medusa_num_layers,
    base_model_name_or_path=model_args.model_name_or_path,
	**config.to_dict(),
)
medusa_config

MedusaConfig {
  "_name_or_path": "/models/Meta-Llama-3-8B-Instruct",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "base_model_name_or_path": "/models/Meta-Llama-3-8B-Instruct",
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "medusa_num_heads": 5,
  "medusa_num_layers": 1,
  "mlp_bias": false,
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "sliding_window": 960,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.41.2",
  "use_cache": false,
  "use_sliding_window": true,
  "vocab_size": 128256
}

In [10]:
medusa_lm_head = MedusaModelLlama(medusa_config)
medusa_lm_head

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


MedusaModelLlama(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Li

In [11]:
medusa_lm_head.load_state_dict(model.state_dict(), strict=False)

_IncompatibleKeys(missing_keys=['medusa_head.0.0.linear.weight', 'medusa_head.0.0.linear.bias', 'medusa_head.0.1.weight', 'medusa_head.1.0.linear.weight', 'medusa_head.1.0.linear.bias', 'medusa_head.1.1.weight', 'medusa_head.2.0.linear.weight', 'medusa_head.2.0.linear.bias', 'medusa_head.2.1.weight', 'medusa_head.3.0.linear.weight', 'medusa_head.3.0.linear.bias', 'medusa_head.3.1.weight', 'medusa_head.4.0.linear.weight', 'medusa_head.4.0.linear.bias', 'medusa_head.4.1.weight'], unexpected_keys=[])

In [12]:
for param in medusa_lm_head.parameters():
    param.requires_grad = False

for param in medusa_lm_head.medusa_head.parameters():
    param.requires_grad = True

In [13]:
from medusa.train.train import make_supervised_data_module


data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
data_module

{'train_dataset': <medusa.train.train.LazySupervisedDataset at 0x7fa0a44d1ae0>,
 'eval_dataset': None}

In [14]:
from torch.utils.data import DataLoader, RandomSampler
from transformers.trainer_utils import seed_worker
from transformers.data.data_collator import DataCollatorWithPadding


default_collator = DataCollatorWithPadding(tokenizer)

dataloader_params = {
	"batch_size": 1,
	"collate_fn": default_collator,
	"num_workers": 1,
	"pin_memory": training_args.dataloader_pin_memory,
	"persistent_workers": training_args.dataloader_persistent_workers,
	"sampler": RandomSampler(data_module['train_dataset']),
	"drop_last": training_args.dataloader_drop_last,
	"worker_init_fn": seed_worker,
	"prefetch_factor": training_args.dataloader_prefetch_factor,
}

loader = DataLoader(data_module['train_dataset'], **dataloader_params)
loader

<torch.utils.data.dataloader.DataLoader at 0x7f9e5bf10ca0>

In [15]:
it = iter(loader)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [16]:
batch = next(it)
batch

{'input_ids': tensor([[128000, 128000, 128006,  ..., 128001, 128001, 128001]]), 'labels': tensor([[-100, -100, -100,  ..., -100, -100, -100]]), 'attention_mask': tensor([[ True,  True,  True,  ..., False, False, False]])}

In [17]:
batch['input_ids'].shape, batch['labels'].shape, batch['attention_mask'].shape

(torch.Size([1, 2048]), torch.Size([1, 2048]), torch.Size([1, 2048]))

In [18]:
print(tokenizer.decode(batch['input_ids'][0], skip_special_tokens=False, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True))

<|begin_of_text|><|begin_of_text|><|start_header_id|>user<|end_header_id|>

Modify the text above (your last message), following to guidelines below:

- The tram cannot avoid obstacles because it is running on rails,
- Behind the wheel? Try to refrain from using this type of wording because Klee is driving a tram, not a car.
- Dead man's switch in Moderus Beta Tram requires the use of pressure on the joystick (you have to "push" the handle), not a firm grip.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

I apologize, but there is no text above (my last message) for me to modify. This conversation has just started, and I haven't made any previous messages. Please provide the text you would like me to modify, and I'll be happy to assist you following the guidelines you've provided.<|eot_id|><|start_header_id|>user<|end_header_id|>

Guidelines:
- The tram cannot avoid obstacles because it is running on rails,
- Behind the wheel? Try to refrain from using this type of wording beca

In [19]:
batch['labels'].tolist()

[[-100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  40,
  37979,
  11,
  719,
  1070,
  374,
  912,
  1495,
  3485,
  320,
  2465,
  1566,
  1984,
  8,
  369,
  757,
  311,
  5719,
  13,
  1115,
  10652,
  706,
  1120,
  3940,
  11,
  323,
  358,
  9167,
  956,
  19

In [20]:
medusa_lm_head.medusa

5

In [21]:
#medusa_lm_head.to('cuda')

In [36]:
logits = medusa_lm_head(input_ids=batch["input_ids"].cpu(), attention_mask=batch["attention_mask"].cpu(), medusa_forward=True)
logits

tensor([[[[-2.1204,  0.0394,  0.5454,  ..., -0.0472, -0.1307, -1.7578],
          [-2.1204,  0.0394,  0.5453,  ..., -0.0472, -0.1307, -1.7578],
          [-2.7197, -0.9162, -0.6635,  ...,  1.2583,  0.3645,  0.5341],
          ...,
          [-1.7105, -0.8079, -0.6877,  ...,  1.0344, -0.5797, -0.4303],
          [-1.7677, -0.8082, -0.6855,  ...,  1.0795, -0.5661, -0.4886],
          [-1.7735, -0.7621, -0.6533,  ...,  1.0569, -0.5843, -0.5578]]],


        [[[ 0.1929, -0.6450, -0.1753,  ..., -0.3662,  2.9861,  2.2217],
          [ 0.1929, -0.6450, -0.1753,  ..., -0.3662,  2.9861,  2.2217],
          [ 3.5173, -0.3727, -0.6956,  ..., -2.1757,  1.4991,  0.0981],
          ...,
          [ 3.7704,  1.0806, -0.7429,  ..., -1.9475,  0.7647,  1.6122],
          [ 3.7350,  1.0163, -0.7496,  ..., -2.0691,  0.7846,  1.6484],
          [ 3.6551,  0.9750, -0.8173,  ..., -2.1715,  0.7359,  1.6903]]],


        [[[ 1.8577,  0.5890,  0.8425,  ..., -0.0783, -1.4155, -1.1008],
          [ 1.8577,  0.589

In [37]:
logits.shape

torch.Size([5, 1, 2048, 128256])

In [38]:
medusa_logits = logits[0, :, :-2].contiguous()
medusa_logits.shape, medusa_logits

(torch.Size([1, 2046, 128256]),
 tensor([[[-2.1204,  0.0394,  0.5454,  ..., -0.0472, -0.1307, -1.7578],
          [-2.1204,  0.0394,  0.5453,  ..., -0.0472, -0.1307, -1.7578],
          [-2.7197, -0.9162, -0.6635,  ...,  1.2583,  0.3645,  0.5341],
          ...,
          [-1.7288, -0.7451, -0.6858,  ...,  0.9639, -0.5497, -0.5661],
          [-1.7611, -0.7304, -0.6741,  ...,  0.9562, -0.5233, -0.4513],
          [-1.7105, -0.8079, -0.6877,  ...,  1.0344, -0.5797, -0.4303]]],
        grad_fn=<SliceBackward0>))

In [39]:
medusa_labels = batch['labels'][..., 2:].contiguous()
medusa_labels.shape

torch.Size([1, 2046])

In [40]:
medusa_logits = medusa_logits.view(-1, logits.shape[-1])
medusa_logits.shape, medusa_logits

(torch.Size([2046, 128256]),
 tensor([[-2.1204,  0.0394,  0.5454,  ..., -0.0472, -0.1307, -1.7578],
         [-2.1204,  0.0394,  0.5453,  ..., -0.0472, -0.1307, -1.7578],
         [-2.7197, -0.9162, -0.6635,  ...,  1.2583,  0.3645,  0.5341],
         ...,
         [-1.7288, -0.7451, -0.6858,  ...,  0.9639, -0.5497, -0.5661],
         [-1.7611, -0.7304, -0.6741,  ...,  0.9562, -0.5233, -0.4513],
         [-1.7105, -0.8079, -0.6877,  ...,  1.0344, -0.5797, -0.4303]],
        grad_fn=<ViewBackward0>))

In [41]:
medusa_labels = medusa_labels.view(-1)
medusa_labels = medusa_labels.to(medusa_logits.device)
medusa_labels.shape, medusa_labels

(torch.Size([2046]), tensor([-100, -100, -100,  ..., -100, -100, -100]))

In [42]:
from torch.nn import CrossEntropyLoss


loss_fct = CrossEntropyLoss()
loss_i = loss_fct(medusa_logits, medusa_labels)
loss_i

tensor(12.8094, grad_fn=<NllLossBackward0>)

In [43]:
from transformers.trainer_pt_utils import LabelSmoother

LabelSmoother.ignore_index

-100