In [1]:
import os
import json

import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model

from model import Transformer, ModelArgs
from generate import sample, generate

In [2]:
# Turn off the Hugging Face tokenizers' own thread pool for this process.
# Must be set before the tokenizer is first used.
os.environ.update(TOKENIZERS_PARALLELISM="false")

## Prepare model and inputs

In [3]:
# --- Run configuration (edit these to point at your own checkpoint/config) ---
ckpt_path = "/home/DeepSeek-V2-Lite-Chat_converted"
config = "configs/config_16B.json"
input_file = "input_file.txt"
max_new_tokens: int = 200
temperature: float = 0.2

# --- Global torch state: bf16 default dtype, 8 CPU threads, fixed RNG seed ---
torch.set_default_dtype(torch.bfloat16)
torch.set_num_threads(8)
torch.manual_seed(965)

# Build the model hyper-parameters from the JSON config and echo them.
with open(config) as config_fh:
    config_dict = json.load(config_fh)
args = ModelArgs(**config_dict)
print(args)

# Instantiate the network directly on the GPU.
with torch.device("cuda"):
    model = Transformer(args)

tokenizer = AutoTokenizer.from_pretrained(ckpt_path)

# Run a tiny 2-token generation before the weights are loaded and discard the
# result. NOTE(review): presumably a warm-up / smoke test -- confirm.
warmup_prompt = [tokenizer.encode("DeepSeek")]
warmup_tokens = generate(model, warmup_prompt, 2, -1, 1.)
tokenizer.decode(warmup_tokens[0])

rank, world_size = 0, 1  # single-device
# Load the weights for this rank; the call's return value is the cell output.
shard_name = f"model{rank}-mp{world_size}.safetensors"
load_model(model, os.path.join(ckpt_path, shard_name))

ModelArgs(max_batch_size=8, max_seq_len=16384, dtype='bf16', vocab_size=102400, dim=2048, inter_dim=10944, moe_inter_dim=1408, n_layers=27, n_dense_layers=1, n_heads=16, n_routed_experts=64, n_shared_experts=2, n_activated_experts=6, n_expert_groups=1, n_limited_groups=1, score_func='softmax', route_scale=1.0, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, original_seq_len=4096, rope_theta=10000.0, rope_factor=40, beta_fast=32, beta_slow=1, mscale=0.707)


(set(), [])

In [4]:
# model

In [5]:
# Read one prompt per line. Blank / whitespace-only lines are skipped so that a
# stray empty line in the file does not become an empty prompt (which would
# also count against the batch-size limit below).
with open(input_file) as f:
    prompts = [text for text in (line.strip() for line in f) if text]
assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"

# Wrap each prompt as a single user turn and render it through the tokenizer's
# chat template; add_generation_prompt=True appends the assistant prefix so the
# model continues with its reply. The result is one list of token ids per prompt.
prompt_tokens = [
    tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True)
    for prompt in prompts
]
prompt_tokens

[[100000, 5726, 25, 37727, 0, 185, 185, 77398, 25],
 [100000, 5726, 25, 1724, 418, 340, 30, 185, 185, 77398, 25],
 [100000, 5726, 25, 7566, 2653, 13, 185, 185, 77398, 25]]

In [6]:
# Generate completions for the whole prompt batch and report how long it took
# (%time measures just this one statement).
%time completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)

# Decode the generated token ids back to text, dropping special tokens, and
# print each prompt next to its completion.
completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
for prompt, completion in zip(prompts, completions):
    print("Prompt:", prompt)
    print("Completion:", completion)

CPU times: user 4.22 s, sys: 21.9 ms, total: 4.24 s
Wall time: 4.24 s
Prompt: Hello!
Completion:  Hello! How can I help you today? If you have any questions or need assistance, feel free to ask.
Prompt: How are you?
Completion:  As an AI, I do not have feelings, but I am functioning properly and ready to assist you with any questions or tasks you have.
Prompt: Good night.
Completion:  Good night! Have a great rest and pleasant dreams.


## Checkpoint MLA inputs and outputs

In [7]:
# Grab the attention module of the first transformer layer; the cells below
# inspect it and toggle its checkpoint-saving attributes.
mla_layer = model.layers[0].attn

In [8]:
# Show the softmax scaling factor stored on this attention module.
mla_layer.softmax_scale

0.1147213867929261

In [9]:
# Show the module's current checkpoint settings:
# (saving enabled?, output directory, iteration counter).
mla_layer.save_ckpt, mla_layer.ckpt_dir, mla_layer.ckpt_iter

(False, 'output_ckpt', 0)

In [10]:
# Uncomment on first run if the checkpoint directory does not exist yet:
#!mkdir output_ckpt

In [11]:
# Enable checkpoint saving on the first layer's attention module and reset its
# counter so the files written by this run are numbered from 0.
mla_layer.save_ckpt = True
mla_layer.ckpt_iter = 0

# Create the target directory up front so the cell also works on a fresh
# checkout (the all-layers cell further down does the same for its directories).
os.makedirs(mla_layer.ckpt_dir, exist_ok=True)

# Short generation: only a few new tokens are needed to collect checkpoints.
completion_tokens = generate(
    model,
    prompt_tokens,
    10,  # max_new_tokens
    tokenizer.eos_token_id,
    temperature,
)

# Decode and print the (truncated) completions as a sanity check that the
# model still produces sensible text while checkpoint saving is on.
completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
for prompt, completion in zip(prompts, completions):
    print("Prompt:", prompt)
    print("Completion:", completion)

saving ckpt to:  output_ckpt/mla_ckpt_0.safetensors
saving ckpt to:  output_ckpt/mla_ckpt_1.safetensors
saving ckpt to:  output_ckpt/mla_ckpt_2.safetensors
saving ckpt to:  output_ckpt/mla_ckpt_3.safetensors
saving ckpt to:  output_ckpt/mla_ckpt_4.safetensors
saving ckpt to:  output_ckpt/mla_ckpt_5.safetensors
saving ckpt to:  output_ckpt/mla_ckpt_6.safetensors
saving ckpt to:  output_ckpt/mla_ckpt_7.safetensors
saving ckpt to:  output_ckpt/mla_ckpt_8.safetensors
saving ckpt to:  output_ckpt/mla_ckpt_9.safetensors
saving ckpt to:  output_ckpt/mla_ckpt_10.safetensors
saving ckpt to:  output_ckpt/mla_ckpt_11.safetensors
Prompt: Hello!
Completion:  Hello! How can I help you today? If
Prompt: How are you?
Completion:  As an AI, I do not have feelings,
Prompt: Good night.
Completion:  Good night! Have a great rest and pleasant dreams


In [12]:
# List the checkpoint files written by the previous cell, with sizes.
!ls -lh output_ckpt

total 2.7M
-rw-r--r-- 1 root root 957K Mar 29 08:45 mla_ckpt_0.safetensors
-rw-r--r-- 1 root root 138K Mar 29 08:45 mla_ckpt_1.safetensors
-rw-r--r-- 1 root root 169K Mar 29 08:45 mla_ckpt_10.safetensors
-rw-r--r-- 1 root root 172K Mar 29 08:45 mla_ckpt_11.safetensors
-rw-r--r-- 1 root root 141K Mar 29 08:45 mla_ckpt_2.safetensors
-rw-r--r-- 1 root root 145K Mar 29 08:45 mla_ckpt_3.safetensors
-rw-r--r-- 1 root root 148K Mar 29 08:45 mla_ckpt_4.safetensors
-rw-r--r-- 1 root root 152K Mar 29 08:45 mla_ckpt_5.safetensors
-rw-r--r-- 1 root root 155K Mar 29 08:45 mla_ckpt_6.safetensors
-rw-r--r-- 1 root root 159K Mar 29 08:45 mla_ckpt_7.safetensors
-rw-r--r-- 1 root root 162K Mar 29 08:45 mla_ckpt_8.safetensors
-rw-r--r-- 1 root root 165K Mar 29 08:45 mla_ckpt_9.safetensors


## Checkpoint all layers

In [13]:
# Number of transformer layers in the model (one attention module per layer).
len(model.layers)

27

In [14]:
# Uncomment to delete checkpoints left over from a previous run before re-saving:
# !rm -rf mla_ckpts

In [15]:
# Root directory for the per-layer checkpoints; each layer gets its own
# zero-padded sub-directory (layer00, layer01, ...).
parent_dir = "mla_ckpts"

# Turn on checkpoint saving for the attention module of every layer, point it
# at that layer's directory and reset its counter. This also redirects layer 0,
# whose attention module was configured on its own earlier in the notebook.
for i, layer in enumerate(model.layers):
    current_dir = os.path.join(parent_dir, f"layer{i:02}")
    os.makedirs(current_dir, exist_ok=True)

    mla_current = layer.attn
    mla_current.save_ckpt = True
    mla_current.ckpt_dir = current_dir
    mla_current.ckpt_iter = 0

# Short generation so that every layer writes its checkpoints.
completion_tokens = generate(
    model,
    prompt_tokens,
    20,  # max_new_tokens
    tokenizer.eos_token_id,
    temperature,
)

saving ckpt to:  mla_ckpts/layer00/mla_ckpt_0.safetensors
saving ckpt to:  mla_ckpts/layer01/mla_ckpt_0.safetensors
saving ckpt to:  mla_ckpts/layer02/mla_ckpt_0.safetensors
saving ckpt to:  mla_ckpts/layer03/mla_ckpt_0.safetensors
saving ckpt to:  mla_ckpts/layer04/mla_ckpt_0.safetensors
saving ckpt to:  mla_ckpts/layer05/mla_ckpt_0.safetensors
saving ckpt to:  mla_ckpts/layer06/mla_ckpt_0.safetensors
saving ckpt to:  mla_ckpts/layer07/mla_ckpt_0.safetensors
saving ckpt to:  mla_ckpts/layer08/mla_ckpt_0.safetensors
saving ckpt to:  mla_ckpts/layer09/mla_ckpt_0.safetensors
saving ckpt to:  mla_ckpts/layer10/mla_ckpt_0.safetensors
saving ckpt to:  mla_ckpts/layer11/mla_ckpt_0.safetensors
saving ckpt to:  mla_ckpts/layer12/mla_ckpt_0.safetensors
saving ckpt to:  mla_ckpts/layer13/mla_ckpt_0.safetensors
saving ckpt to:  mla_ckpts/layer14/mla_ckpt_0.safetensors
saving ckpt to:  mla_ckpts/layer15/mla_ckpt_0.safetensors
saving ckpt to:  mla_ckpts/layer16/mla_ckpt_0.safetensors
saving ckpt to

In [16]:
# Total on-disk size of the checkpoints saved for all layers.
!du -sh mla_ckpts

122M	mla_ckpts
