In [1]:
import sys

# Put the local mistral-finetune checkout on the import path; the later
# `from model.args import ModelArgs` presumably resolves from here — TODO confirm.
# NOTE(review): absolute, user-specific path; adjust when running elsewhere.
sys.path.append("/home/htkumar/llms/mistral-finetune")

In [2]:
from huggingface_hub import notebook_login

# Renders the Hugging Face login widget (the VBox shown in this cell's output).
notebook_login()

from typing import List, NamedTuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from model.args import ModelArgs

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
from xformers.ops.fmha import memory_efficient_attention
from xformers.ops.fmha.attn_bias import AttentionBias, BlockDiagonalCausalMask

In [4]:
class LoRALinear(nn.Module):
    """Linear layer made of a base projection plus a scaled low-rank update.

    The forward pass computes
    ``frozen_w(x) + scaling * lora_B(lora_A(dropout(x)))``.

    Note: the base projection is *named* ``frozen_w`` but this class does not
    itself disable its gradients; freezing is left to the caller.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        rank: Inner dimension shared by ``lora_A`` and ``lora_B``.
        scaling: Multiplier applied to the low-rank branch.
        dropout: Dropout probability applied to the input of the low-rank
            branch only (the base projection sees the raw input).
        bias: Must be ``False``; biased layers are rejected by an assertion.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        rank: int,
        scaling: float,
        dropout: float,
        bias: bool = False,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.scaling = scaling
        self.dropout = nn.Dropout(dropout)
        assert not bias
        self.bias = bias

        # Low-rank pair: in_features -> rank -> out_features.
        self.lora_A = nn.Linear(in_features, rank, bias=self.bias)
        self.lora_B = nn.Linear(rank, out_features, bias=self.bias)
        # Base projection with the full (out_features, in_features) weight.
        self.frozen_w = nn.Linear(in_features, out_features, bias=self.bias)

        def ignore_missing_keys(m: nn.Module, incompatible_keys: NamedTuple):
            # Clear the list in place so a checkpoint that lacks some of this
            # layer's keys is not reported as having missing keys.
            incompatible_keys.missing_keys[:] = []

        self.register_load_state_dict_post_hook(ignore_missing_keys)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the base projection of ``x`` plus the scaled low-rank update."""
        lora_res = self.lora_B(self.lora_A(self.dropout(x)))
        return self.frozen_w(x) + lora_res * self.scaling

    def __repr__(self) -> str:
        return f"LorA Linear(in_features: {self.in_features}, out_features: {self.out_features}, rank: {self.rank}, scaling: {self.scaling}, dropout: {self.dropout})"

    def merge_weight(self) -> torch.Tensor:
        """Return ``frozen_w.weight + scaling * (lora_B.weight @ lora_A.weight)``.

        The result is a new ``(out_features, in_features)`` tensor computed
        under ``torch.no_grad()``; none of the layer's parameters are modified.
        """
        with torch.no_grad():
            down_weight = self.lora_A.weight
            up_weight = self.lora_B.weight

            # (out, rank) @ (rank, in) -> (out, in). The product is a fresh
            # tensor, so the in-place add below never touches a parameter.
            weight = up_weight.mm(down_weight) * self.scaling
            weight += self.frozen_w.weight
        return weight

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Route a plain ``<prefix>weight`` checkpoint entry into ``frozen_w``.

        If ``state_dict`` contains ``prefix + "weight"``, that tensor is
        assigned as the weight of ``frozen_w``. Nothing else is loaded here
        and the bookkeeping lists are left untouched.
        """
        key_name = prefix + "weight"
        if key_name in state_dict:
            w_ref = state_dict[key_name]

            # assign=True: adopt the checkpoint tensor rather than copying it
            # into the existing parameter.
            self.frozen_w.load_state_dict({"weight": w_ref}, assign=True)

In [5]:
# Stand-alone layers with the same three shapes that LoRALinear builds
# internally, for inspecting the weight layouts.
in_features, out_features, rank = 128, 256, 16
frozen_w = nn.Linear(in_features, out_features, bias=False)
# Use the declared `rank` instead of repeating the literal 16.
lora_A = nn.Linear(in_features, rank, bias=False)
lora_B = nn.Linear(rank, out_features, bias=False)

In [10]:
# nn.Linear stores its weight as (out_features, in_features).
tuple(layer.weight.shape for layer in (lora_A, lora_B, frozen_w))

(torch.Size([16, 128]), torch.Size([256, 16]), torch.Size([256, 128]))

In [15]:
# Build the LoRA layer: 128 -> 256 features with a rank-16 update.
# (Keyword was misspelled `bout_features`, which raises TypeError.)
lora_linear = LoRALinear(
    in_features=128, out_features=256, rank=16, scaling=1.0, dropout=0.1
)

In [16]:
# Smoke test: 16 rows of 128 features in; the recorded output is (16, 256).
lora_linear(torch.randn(16, 128)).shape

torch.Size([16, 256])

In [None]:
# import torch

# torch.cuda.empty_cache()

In [None]:
# memory_summary() returns a multi-line string; print it so the table renders
# instead of being shown as a single escaped repr.
print(torch.cuda.memory_summary(device=None, abbreviated=False))

In [None]:
from pathlib import Path

from huggingface_hub import snapshot_download

# ~/mistral_models/7B-v0.3, created (with parents) if it does not exist yet.
mistral_model_path = Path.home() / "mistral_models" / "7B-v0.3"
mistral_model_path.mkdir(exist_ok=True, parents=True)

In [None]:
# Display the resolved model directory.
mistral_model_path

In [None]:
class FeedForward(nn.Module):
    """Minimal feed-forward block: a single square linear projection.

    Args:
        dim: Input and output feature size of the projection. Defaults to
            128, the size that was previously hard-coded.
    """

    def __init__(self, dim: int = 128):
        super().__init__()

        self.w1 = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the projection to the last dimension of ``x``."""
        return self.w1(x)

In [None]:
# Eight independent 128 -> 128 experts, plus a bias-free gate that emits one
# logit per expert.
experts = [nn.Linear(in_features=128, out_features=128) for _ in range(8)]
gate = nn.Linear(in_features=128, out_features=8, bias=False)

In [None]:
# Random (4, 12, 128) tensor collapsed to 2-D: one row per (4 * 12 = 48)
# position, 128 features each — presumably (batch, seq_len, hidden).
input = torch.randn(4, 12, 128).view(-1, 128)

In [None]:
# One logit per expert for every flattened input row.
gate_logits = gate(input)
gate_logits.shape

In [None]:
# For each row keep the two largest gate logits and the expert indices they
# belong to (top-k over the last dimension).
weights, selected_experts = gate_logits.topk(k=2, dim=-1)

In [None]:
# Both come from the same topk(…, 2) call, so they share one shape.
weights.shape, selected_experts.shape

In [None]:
# `weights` and `selected_experts` are 2-D (rows x top-k) because the input was
# flattened before gating, so a row is addressed with a single leading index.
weights[0, :], selected_experts[0, :]

In [None]:
# Normalise the two selected logits of each row so they sum to 1.
weights_s = weights.softmax(dim=1)
weights_s[0]

In [None]:
# Zero tensor with the same shape and dtype as `input`.
results = torch.zeros_like(input)

In [None]:
# Positions where expert 0 was selected: `batch_idx` is the row index and
# `nth_expert` is the top-k slot (column) in which expert 0 appears.
batch_idx, nth_expert = torch.where(selected_experts == 0)

In [None]:
# Inspect the row / slot index pairs found for expert 0.
batch_idx, nth_expert

In [None]:
# Shape of the expert-index tensor returned by topk.
selected_experts.shape

In [None]:
# All rows vs. only the rows that were routed to expert 0.
input.shape, input[batch_idx].shape

In [None]:
# Gate values for the (row, slot) pairs where expert 0 was chosen. Note that
# `weights` holds the raw top-k logits; the softmaxed copy is `weights_s`.
weights[batch_idx, nth_expert]

In [None]:
# The trailing None appends a size-1 axis, giving a column that can broadcast
# across the feature dimension of an expert's output.
weights[batch_idx, nth_expert, None].shape

In [None]:
# Expert 0 applied only to the rows routed to it.
experts[0](input[batch_idx]).shape

In [None]:
# For a single row, check that scaling the expert output by a shape-(1,) weight
# (indexed with a trailing None) gives the same tensor as scaling by the 0-d
# scalar weight: both broadcast over the feature vector.
torch.equal(
    (weights[2, 0, None] * experts[0](input[2])), (weights[2, 0] * experts[0](input[2]))
)

In [None]:
import pandas as pd

In [None]:
# Lift pandas' display limits (column width, row count, column count) so
# frames are shown untruncated. set_option accepts several name/value pairs.
pd.set_option(
    "display.max_colwidth", None,
    "display.max_rows", None,
    "display.max_columns", None,
)

In [None]:
# Load the training split; lines=True reads the file as one JSON object per line.
# NOTE(review): absolute, user-specific path; adjust when running elsewhere.
train_data = pd.read_json(
    "/home/htkumar/llms/mistral-finetune/ultrachat/train.jsonl", lines=True
)

In [None]:
# Column names of the loaded frame.
train_data.columns

In [None]:
# Pick one example to inspect. Despite the name, this is the row at position
# 3674, not the first row of the frame.
first_sample = train_data.iloc[3674]
print(first_sample["prompt"])

In [None]:
# Identifier stored alongside the prompt for the selected row.
print(first_sample["prompt_id"])

In [None]:
# Raw `messages` field of the selected row.
first_sample["messages"]

In [None]:
import json
from dataclasses import dataclass, field
from typing import Optional

from simple_parsing.helpers import Serializable

In [None]:
@dataclass
class MoeArgs(Serializable):
    """Mixture-of-experts settings, serializable through ``Serializable``."""

    # Total number of experts; validated in __post_init__ to be at most 10.
    num_experts: int = 8
    # presumably how many experts each token is routed to — TODO confirm
    num_experts_per_tok: int = 2

    def __post_init__(self):
        """Validate the fields.

        Raises:
            ValueError: If ``num_experts`` is greater than 10.
        """
        if self.num_experts > 10:
            # The message previously said "<= 2", contradicting both the check
            # above and the default of 8.
            raise ValueError("num_experts must be <= 10")

In [None]:
# num_experts=11 is rejected by MoeArgs.__post_init__. Show the validation
# error instead of letting it abort the notebook run.
try:
    MoeArgs(num_experts=11)
except ValueError as err:
    print(f"rejected: {err}")

# Valid instance (default field values) used by the serialization cells below.
a = MoeArgs()

In [None]:
# Persist the config as a single JSON document.
with open("/home/htkumar/llms/mistral-finetune/moe_args.txt", "w") as f:
    json.dump(a.to_dict(), f)

In [None]:
# Read the file back: each line is parsed as JSON and rebuilt into a MoeArgs
# through from_dict, then printed to check the round trip.
with open("/home/htkumar/llms/mistral-finetune/moe_args.txt", "r") as f:
    for line in f:
        b = MoeArgs.from_dict(json.loads(line))
        print(b)