In [1]:
# import torch

# w1 = torch.load('./llama_7b/pytorch_model-00001-of-00002.bin')
# w2 = torch.load('./llama_7b/pytorch_model-00002-of-00002.bin')

In [2]:
import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple

import fairscale.nn.model_parallel.initialize as fs_init
from fairscale.nn.model_parallel.initialize import (
    get_model_parallel_rank,
    initialize_model_parallel,
    model_parallel_is_initialized,
)

import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.layers import (
    ColumnParallelLinear,
    ParallelEmbedding,
    RowParallelLinear,
)
from torch import nn
import numpy as np
import random


seed = 2021

# Seed Python's, NumPy's and PyTorch's global RNGs so the randomly initialised
# model weights (and hence the exported binary) are reproducible between runs.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
del _seed_fn



@dataclass
class ModelArgs:
    """Hyper-parameters for the toy LLaMA-style ``Transformer`` in this notebook.

    Field order matters: it defines the positional ``__init__`` signature.
    """

    # --- model shape ---------------------------------------------------------
    dim: int = 4_096  # hidden / embedding width
    n_layers: int = 2  # number of TransformerBlocks
    n_heads: int = 2  # query heads per attention layer
    n_kv_heads: Optional[int] = None  # key/value heads; None means "same as n_heads"
    vocab_size: int = 2  # tiny vocabulary for this test (normally set from the tokenizer)
    multiple_of: int = 256  # SwiGLU hidden size is rounded up to a multiple of this
    ffn_dim_multiplier: Optional[float] = None  # optional scale applied to the SwiGLU hidden size
    norm_eps: float = 1e-5  # epsilon passed to every RMSNorm

    # --- KV-cache capacity (max_seq_len also sizes the rotary table) ---------
    max_batch_size: int = 32
    max_seq_len: int = 2_048


class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation with a learnable per-feature gain.

    Normalises over the last dimension (computed in float32, then cast back to
    the input dtype) and multiplies by ``weight``, which is initialised to ones.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps  # added to the mean square before the reciprocal sqrt
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        normalised = self._norm(x.float()).type_as(x)
        return self.weight * normalised


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the rotary-embedding table of unit complex numbers.

    Row ``p`` holds ``exp(i * p * theta**(-2k/dim))`` for ``k`` in
    ``range(dim // 2)``, so the result has shape ``(end, dim // 2)`` and
    dtype complex64.
    """
    exponents = torch.arange(0, dim, 2)[: dim // 2].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device)  # type: ignore
    angles = torch.outer(positions, inv_freq).float()  # type: ignore
    return torch.polar(torch.ones_like(angles), angles)


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so that it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``; the returned
    view keeps those sizes on axes 1 and -1 and uses size 1 on every other axis.
    """
    rank = x.ndim
    assert rank > 1  # axis 1 must exist
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    target = [1] * rank
    target[1] = x.shape[1]
    target[rank - 1] = x.shape[rank - 1]
    return freqs_cis.view(*target)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query/key feature pairs by the rotary position factors.

    The last dimension of ``xq`` / ``xk`` is read as consecutive
    (real, imaginary) pairs, multiplied by ``freqs_cis`` (reshaped by
    ``reshape_for_broadcast`` against the query view), flattened back from
    axis 3 onward, and cast to each input's original dtype.
    """

    def as_complex(t: torch.Tensor) -> torch.Tensor:
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        return torch.view_as_complex(pairs)

    q_complex = as_complex(xq)
    k_complex = as_complex(xk)
    rotation = reshape_for_broadcast(freqs_cis, q_complex)

    rotated_q = torch.view_as_real(q_complex * rotation).flatten(3)
    rotated_k = torch.view_as_real(k_complex * rotation).flatten(3)
    return rotated_q.type_as(xq), rotated_k.type_as(xk)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along axis 2.

    Equivalent to ``torch.repeat_interleave(x, repeats=n_rep, dim=2)`` for a
    4-D input; returns ``x`` itself (no copy) when ``n_rep == 1``.
    """
    batch, seq_len, kv_heads, head_dim = x.shape  # also rejects non-4-D input
    if n_rep == 1:
        return x
    tiled = x.unsqueeze(3).expand(batch, seq_len, kv_heads, n_rep, head_dim)
    return tiled.reshape(batch, seq_len, kv_heads * n_rep, head_dim)


class Attention(nn.Module):
    """Multi-head self-attention with rotary position embeddings and a KV cache.

    Single-process version of a model-parallel attention layer: plain
    ``nn.Linear`` projections are used and the model-parallel world size is
    fixed at 1, so the "local" head counts equal the global ones.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # Default to one key/value head per query head when n_kv_heads is unset.
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = 1 # fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        # Number of query heads that share each key/value head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads

        # Bias-free projections. They are created in the order wq, wk, wv, wo;
        # under the notebook's fixed global seed this order determines their
        # random initial weights.
        self.wq = nn.Linear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False
        )
        self.wk = nn.Linear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False,
        )
        self.wv = nn.Linear(
            args.dim,
            self.n_kv_heads * self.head_dim,
            bias=False
        )
        self.wo = nn.Linear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False
        )

        # KV cache, shape (max_batch_size, max_seq_len, n_local_kv_heads, head_dim).
        # These are plain tensor attributes, not registered buffers: they do not
        # appear in state_dict() and are not moved by Module.to(); forward()
        # converts them to the queries' device/dtype on every call instead.
        # NOTE(review): with the default ModelArgs each tensor holds
        # 32*2048*2*2048 float32 values (~1 GiB), per layer.
        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        )
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        )

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """Attend from ``seqlen`` new positions to every cached position.

        Args:
            x: input of shape (bsz, seqlen, dim).
            start_pos: cache index at which this chunk's keys/values are stored;
                queries attend to cache positions [0, start_pos + seqlen).
                bsz must not exceed max_batch_size and start_pos + seqlen must
                not exceed max_seq_len, otherwise the cache slice assignment
                below fails with a shape mismatch.
            freqs_cis: rotary factors for the new positions, forwarded to
                ``apply_rotary_emb``.
            mask: optional additive mask added to the attention scores; it must
                broadcast to (bsz, n_local_heads, seqlen, start_pos + seqlen).

        Returns:
            Tensor of shape (bsz, seqlen, dim). The result is also printed
            (debug output).
        """
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        # Split the projections into heads: (bsz, seqlen, heads, head_dim).
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
#         print('output of xq values:', xq)
#         print('output of xk values:', xk)
#         print('output of xv values:', xv)
        # Rotary embeddings are applied to queries and keys only, not values.
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # Match the cache to xq's device and dtype (a no-op when already equal).
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        # Store this chunk's keys/values in the cache (in-place write).
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        # Read back everything cached so far, including the chunk just written.
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, start_pos + seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, start_pos + seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        # Scaled dot-product scores: (bs, n_local_heads, seqlen, start_pos + seqlen).
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        # Softmax in float32 for stability, then cast back to the query dtype.
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        # Merge heads back into (bsz, seqlen, n_local_heads * head_dim).
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        output = self.wo(output)
        print('output of attention layer:', output)
        return output


class FeedForward(nn.Module):
    """SwiGLU feed-forward block computing ``w2(silu(w1(x)) * w3(x))``.

    The inner width is ``int(2 * hidden_dim / 3)``, optionally scaled by
    ``ffn_dim_multiplier``, then rounded up to a multiple of ``multiple_of``.
    All projections are bias-free. ``forward`` prints its input and the three
    intermediate results for debugging.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        multiple_of: int,
        ffn_dim_multiplier: Optional[float],
    ):
        super().__init__()
        inner = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            inner = int(ffn_dim_multiplier * inner)
        # Round up to the next multiple of `multiple_of` (ceiling division).
        inner = -(-inner // multiple_of) * multiple_of

        # Creation order (w1, w2, w3) fixes how the seeded RNG initialises them.
        self.w1 = nn.Linear(dim, inner, bias=False)
        self.w2 = nn.Linear(inner, dim, bias=False)
        self.w3 = nn.Linear(dim, inner, bias=False)

    def forward(self, x):
        print('ffn_w1 input:', x)
        gate = self.w1(x)
        up = self.w3(x)
        down = self.w2(F.silu(gate) * up)
        # Debug labels keep their historical names: "ffn_w2" is the w3 branch
        # and "ffn_w3" is the w2 (down) projection.
        for label, value in (
            ('our ffn_w1:', gate),
            ('our ffn_w2:', up),
            ('our ffn_w3:', down),
        ):
            print(label, value.shape, value)
        return down


class TransformerBlock(nn.Module):
    """One pre-norm decoder layer.

    Computes ``h = x + attention(attention_norm(x))`` and then
    ``out = h + feed_forward(ffn_norm(h))``, printing both normalised
    intermediates for debugging.
    """

    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        # Submodule creation order matters for the seeded random initialisation.
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=4 * args.dim,
            multiple_of=args.multiple_of,
            ffn_dim_multiplier=args.ffn_dim_multiplier,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        """Apply the attention and feed-forward sub-layers with residuals.

        Args:
            x: hidden states; assumed (bsz, seqlen, dim) as required by Attention.
            start_pos: KV-cache offset forwarded to the attention layer.
            freqs_cis: rotary factors forwarded to the attention layer.
            mask: optional additive attention mask forwarded unchanged.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Normalise once and reuse: the debug print previously recomputed it.
        attn_in = self.attention_norm(x)
        print('output of attention_norm', attn_in)
        # Call through nn.Module.__call__ (not .forward) so hooks still run.
        h = x + self.attention(attn_in, start_pos, freqs_cis, mask)

        ffn_in = self.ffn_norm(h)
        print('output of ffn_norm:', ffn_in)
        out = h + self.feed_forward(ffn_in)
        return out


class Transformer(nn.Module):
    """LLaMA-style decoder: embedding -> TransformerBlocks -> RMSNorm -> vocab head.

    Submodules are created in a fixed order (embedding, blocks, final norm,
    output head). Under the notebook's global seed this order determines the
    random initial weights, so reordering it changes the exported model.
    """

    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(
            params.vocab_size, params.dim
        )

        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))

        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(
            params.dim, params.vocab_size, bias=False
        )

        # Rotary table for 2 * max_seq_len positions, sliced per call in forward().
        self.freqs_cis = precompute_freqs_cis(
            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
        )

    @torch.inference_mode()
    def forward(self, tokens: torch.Tensor, start_pos: int):
        """Run the decoder on a chunk of token ids.

        Args:
            tokens: integer token ids of shape (bsz, seqlen).
            start_pos: absolute position of ``tokens[:, 0]``; earlier positions
                are read from each attention layer's KV cache.

        Returns:
            float32 logits of shape (bsz, seqlen, vocab_size).
        """
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        # print('output of embedding:', h)
        self.freqs_cis = self.freqs_cis.to(h.device)
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

        mask = None
        if seqlen > 1:
            # Attention scores are (bsz, heads, seqlen, start_pos + seqlen):
            # query row i sits at absolute position start_pos + i and may see
            # every key j <= start_pos + i. The mask must therefore span
            # start_pos + seqlen columns; a (seqlen, seqlen) mask only
            # broadcasts when start_pos == 0 (identical result in that case).
            mask = torch.full(
                (1, 1, seqlen, start_pos + seqlen), float("-inf"), device=tokens.device
            )
            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_cis, mask)
        h = self.norm(h)
        print('output of out_norm: ', h)
        output = self.output(h).float()
        return output

In [3]:
# Build the toy model from the default hyper-parameters and keep a handle on
# them as `params` for the cells below.
model = Transformer(ModelArgs())
params = model.params

In [4]:
# One sequence holding a single token (id 1), processed from cache position 0.
X = torch.full((1, 1), 1, dtype=torch.long)
print('input of embedding:', X)
model(X, start_pos=0)

input of embedding: tensor([[1]])
output of attention_norm tensor([[[ 0.0691,  1.2589, -1.7623,  ...,  0.1090,  0.0561, -0.2484]]])
output of attention layer: tensor([[[ 0.0868,  0.4848,  0.1852,  ..., -0.0855, -0.4508,  0.2967]]])
output of ffn_norm: tensor([[[ 0.1484,  1.6586, -1.4997,  ...,  0.0223, -0.3757,  0.0461]]])
ffn_w1 input: tensor([[[ 0.1484,  1.6586, -1.4997,  ...,  0.0223, -0.3757,  0.0461]]])
our ffn_w1: torch.Size([1, 1, 11008]) tensor([[[ 0.7358, -0.1107,  0.4632,  ..., -0.4656,  0.3076,  0.7836]]])
our ffn_w2: torch.Size([1, 1, 11008]) tensor([[[ 0.0739,  0.2551, -0.5209,  ...,  0.3584,  0.7982, -0.2969]]])
our ffn_w3: torch.Size([1, 1, 4096]) tensor([[[-0.1175, -0.0929,  0.0018,  ...,  0.1145, -0.0024, -0.1818]]])
output of attention_norm tensor([[[ 0.0362,  1.5597, -1.4880,  ...,  0.1304, -0.3754, -0.1261]]])
output of attention layer: tensor([[[-0.4230,  0.1057,  0.2668,  ..., -0.4855, -0.1360,  0.1278]]])
output of ffn_norm: tensor([[[-0.3463,  1.5800, -1.1765,  

tensor([[[-1.2173,  0.1302]]])

In [5]:
# Shape of the final RMSNorm gain (one entry per hidden feature).
model.norm.weight.shape

torch.Size([4096])

In [6]:
# model.state_dict()['tok_embeddings.weight']

In [7]:
# Display every parameter name in the state_dict; the export cell looks
# tensors up by these keys.
model.state_dict().keys()

odict_keys(['tok_embeddings.weight', 'layers.0.attention.wq.weight', 'layers.0.attention.wk.weight', 'layers.0.attention.wv.weight', 'layers.0.attention.wo.weight', 'layers.0.feed_forward.w1.weight', 'layers.0.feed_forward.w2.weight', 'layers.0.feed_forward.w3.weight', 'layers.0.attention_norm.weight', 'layers.0.ffn_norm.weight', 'layers.1.attention.wq.weight', 'layers.1.attention.wk.weight', 'layers.1.attention.wv.weight', 'layers.1.attention.wo.weight', 'layers.1.feed_forward.w1.weight', 'layers.1.feed_forward.w2.weight', 'layers.1.feed_forward.w3.weight', 'layers.1.attention_norm.weight', 'layers.1.ffn_norm.weight', 'norm.weight', 'output.weight'])

In [8]:
# Shapes of the first layer's w1 / w3 after transposing to
# (in_features, out_features), matching the permute(1, 0) used on export.
_sd = model.state_dict()
for _key in ('layers.0.feed_forward.w1.weight', 'layers.0.feed_forward.w3.weight'):
    print(_sd[_key].T.shape)

torch.Size([4096, 11008])
torch.Size([4096, 11008])


In [9]:
import numpy as np

# Output location of the exported weights (consumed by the nntrainer LLaMA app).
LLAMA_BIN_PATH = "/home/seungbaek/hdd/projects/nntrainer/build/Applications/LLaMA/jni/llama_v2.bin"


def save_llama_to_bin(params, n_layer=32, n_head=32, out_file=None):
    """Serialise a LLaMA state_dict into a flat, header-less binary stream.

    Every tensor is written raw (C order, its own dtype) in this order:

    * ``tok_embeddings.weight``
    * per layer: ``attention_norm``; ``wv``, ``wk``, ``wq`` each split into
      ``n_head`` row slices written last-head-first and transposed; ``wo``
      transposed; ``ffn_norm``; ``w3``, ``w1``, ``w2`` transposed
    * ``norm.weight``, then ``output.weight`` transposed

    Args:
        params: mapping of parameter name -> tensor, e.g. ``model.state_dict()``.
        n_layer: number of ``layers.{i}.`` blocks to write.
        n_head: number of row slices for each of wq / wk / wv.
        out_file: binary file object to write to. When ``None``, falls back to
            the module-level ``file`` handle (the original behaviour).
    """
    if out_file is None:
        out_file = file  # legacy: write to the notebook-level handle

    def save_weight(weight):
        # tofile() always writes C order, so permuted views are serialised in
        # their transposed layout; detach/cpu also accept grad or GPU tensors.
        np.asarray(weight.detach().cpu()).tofile(out_file)

    def save_heads_reversed(weight):
        # Write the output rows in n_head equal slices, last slice first, each
        # transposed. Slicing uses the weight's own row count (previously a
        # hard-coded 4096, which silently wrote empty slices for any other dim).
        # NOTE(review): the reversed head order is what the exporter always
        # emitted; presumably the nntrainer loader expects it — confirm before
        # changing.
        rows = weight.shape[0]
        split_size = rows // n_head
        for head_idx in range(1, n_head + 1):
            st_idx = rows - split_size * head_idx
            save_weight(weight[st_idx:st_idx + split_size, :].permute(1, 0))

    def save_attention(layer_name):
        save_weight(params[layer_name + 'attention_norm.weight'])
        for proj in ('wv', 'wk', 'wq'):
            save_heads_reversed(params[layer_name + 'attention.' + proj + '.weight'])
        save_weight(params[layer_name + 'attention.wo.weight'].permute(1, 0))

    def save_feed_forward(layer_name):
        save_weight(params[layer_name + 'ffn_norm.weight'])
        for proj in ('w3', 'w1', 'w2'):
            save_weight(params[layer_name + 'feed_forward.' + proj + '.weight'].permute(1, 0))

    save_weight(params['tok_embeddings.weight'])

    for layer_idx in range(n_layer):
        prefix = 'layers.{}.'.format(layer_idx)
        save_attention(prefix)
        save_feed_forward(prefix)

    save_weight(params['norm.weight'])
    save_weight(params['output.weight'].permute(1, 0))


# The `with` block guarantees the buffered writes are flushed and the file is
# closed; `file` stays bound (closed) for any later cell that refers to it.
with open(LLAMA_BIN_PATH, "wb") as file:
    save_llama_to_bin(
        model.state_dict(), n_layer=params.n_layers, n_head=params.n_heads, out_file=file
    )
