In [1]:
import os
import transformers

from medusa.train.train_legacy import ModelArguments, DataArguments, TrainingArguments


model_args = ModelArguments(
	model_name_or_path=os.path.expanduser('~/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B/snapshots/62bd457b6fe961a42a631306577e622c83876cb6'),
)

data_args = DataArguments(
	data_path=os.path.expanduser('~/data/ShareGPT_Vicuna_unfiltered/shareGPT-llama3-8B.json'),
	lazy_preprocess=True,
)

training_args = TrainingArguments(
	output_dir='./train/test',
	medusa_num_heads=5,
	medusa_num_layers=1,
)

config = transformers.AutoConfig.from_pretrained(
	model_args.model_name_or_path,
	cache_dir=training_args.cache_dir,
)
config

  from .autonotebook import tqdm as notebook_tqdm


LlamaConfig {
  "_name_or_path": "/home/camus/.cache/huggingface/hub/models--meta-llama--Meta-Llama-3-8B/snapshots/62bd457b6fe961a42a631306577e622c83876cb6",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.41.2",
  "use_cache": true,
  "vocab_size": 128256
}

In [2]:
orig_ctx_len = getattr(config, "max_position_embeddings", None)
orig_ctx_len, training_args.model_max_length

(8192, 2048)

In [3]:
config.use_cache = False

In [4]:
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=training_args.cache_dir,
    model_max_length=training_args.model_max_length,
    padding_side="right",
    use_fast=True,
)

tokenizer.pad_token, tokenizer.unk_token, tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


(None, None, '<|end_of_text|>')

In [5]:
tokenizer.pad_token = tokenizer.eos_token

In [6]:
tokenizer(["This is a test", "secondary"], padding=True)

{'input_ids': [[128000, 2028, 374, 264, 1296], [128000, 19217, 128001, 128001, 128001]], 'attention_mask': [[1, 1, 1, 1, 1], [1, 1, 0, 0, 0]]}

In [7]:
tokenizer.apply_chat_template([{"role": "user", "content": "This is a test"}])

[128000,
 128006,
 882,
 128007,
 271,
 2028,
 374,
 264,
 1296,
 128009,
 128006,
 78191,
 128007,
 271]

In [8]:
import torch


model = transformers.LlamaForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    config=config,
    cache_dir=training_args.cache_dir,
    torch_dtype=torch.bfloat16,
)

for param in model.base_model.parameters():
    param.requires_grad = False

model

Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.98it/s]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head)

In [9]:
from medusa.model.medusa_model_legacy import MedusaModel


medusa_lm_head = MedusaModel(
    model,
    medusa_num_heads=training_args.medusa_num_heads,
    medusa_num_layers=training_args.medusa_num_layers,
    base_model_name_or_path=model_args.model_name_or_path,
)
medusa_lm_head

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


MedusaModel(
  (base_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
            (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
            (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm()
          (post_attention_layernorm): Llam

In [10]:
from medusa.train.train_legacy import make_supervised_data_module


data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
data_module

{'train_dataset': <medusa.train.train_legacy.LazySupervisedDataset at 0x78f5f885bfa0>,
 'eval_dataset': None}

In [17]:
from torch.utils.data import DataLoader, RandomSampler
from transformers.trainer_utils import seed_worker
from transformers.data.data_collator import DataCollatorWithPadding


default_collator = DataCollatorWithPadding(tokenizer)

dataloader_params = {
	"batch_size": 1,
	"collate_fn": default_collator,
	"num_workers": 1,
	"pin_memory": training_args.dataloader_pin_memory,
	"persistent_workers": training_args.dataloader_persistent_workers,
	"sampler": RandomSampler(data_module['train_dataset']),
	"drop_last": training_args.dataloader_drop_last,
	"worker_init_fn": seed_worker,
	"prefetch_factor": training_args.dataloader_prefetch_factor,
}

loader = DataLoader(data_module['train_dataset'], **dataloader_params)
loader

<torch.utils.data.dataloader.DataLoader at 0x78f1a7450c10>

In [18]:
it = iter(loader)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [20]:
batch = next(it)
batch

{'input_ids': tensor([[128000, 128000, 128006,  ...,   2763,    315,   8149]]), 'labels': tensor([[  -100, 128000, 128006,  ...,   2763,    315,   8149]]), 'attention_mask': tensor([[True, True, True,  ..., True, True, True]])}

In [21]:
batch['input_ids'].shape, batch['labels'].shape, batch['attention_mask'].shape

(torch.Size([1, 2048]), torch.Size([1, 2048]), torch.Size([1, 2048]))

In [27]:
print(tokenizer.decode(batch['input_ids'][0], skip_special_tokens=False, spaces_between_special_tokens=False, clean_up_tokenization_spaces=True))

<|begin_of_text|><|begin_of_text|><|start_header_id|>user<|end_header_id|>

I am going to write a choose your own adventure role playing book using some mechanics from D&D 5th edition with several homebrew changes, class, and settings. The adventure will take place on a Aztec themed temple called the Yoselotl Temple and I'm initially naming the adventure book, "Defeat the Dungeon: Temple of Lost Souls." First thing is first, I want to go over the lore and history of the world before we get into the hooks/goals for the character going into the dungeon.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

What an exciting project! I'd be happy to help you develop the lore and history of the world for "Defeat the Dungeon: Temple of Lost Souls."

Let's start by creating a rich and immersive world. Here's a possible backstory for the Yoselotl Temple and the world it's set in:

**World Setting:**
The world is called Tenochtitlan, a land of ancient civilizations, mystical energies, and hid

In [34]:
batch['labels'].tolist()

[[-100,
  128000,
  128006,
  -100,
  128007,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  2162,
  22363,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  15757,
  55300,
  -100,
  -100,
  -100,
  -100,
  816,
  437,
  301,
  354,
  75,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  2685,
  33166,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  26606,
  1147,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  -100,
  128009,
  128006,
  -100,
  128007,
  -100,
  3923,
  459,
  13548,
  2447,
  0,
  358,
  4265,
 