In [1]:
import os 
import torch
# os.environ["CUDA_VISIBLE_DEVICES"] = '1'
# Force synchronous CUDA kernel launches so that a device-side error is
# reported at the call that caused it rather than at some later sync point.
# This slows execution and is meant for debugging.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Prefer the second GPU (index 1), as this notebook was originally written for.
# Fall back to GPU 0, then to the CPU, so the notebook still runs from a fresh
# kernel on hosts with fewer than two GPUs instead of failing later at the
# first `.to(device)`.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from modeling_mixformer_sequential import MixFormerSequentialForCausalLM, InferenceParams
from configuration_mixformer_sequential import MixFormerSequentialConfig

from transformers.modeling_outputs import CausalLMOutputWithPast
import torch
from typing import Any, Dict, Optional, Tuple, Union

class ClientSideMixFormerSequentialForCausalLM(MixFormerSequentialForCausalLM):
    """Client half of a MixFormer model that is split in two at ``split_layer``.

    The client keeps the first ``split_layer`` entries of ``self.layers``
    and produces the intermediate hidden state that the server-side model
    consumes. Entries ``split_layer .. num_layers`` are replaced by ``None``
    so they are not held by the client.
    """

    def __init__(self, config):
        """Build the full model, then discard the layers owned by the server.

        Args:
            config: Model configuration forwarded unchanged to the parent class.
        """
        super().__init__(config)
        self.split_layer=2
        # Number of entries between the first and the last entry of
        # `self.layers`. Derived from the model actually built instead of a
        # hard-coded 20, so configurations of other depths are handled as well.
        # NOTE(review): assumes `self.layers` is laid out as
        # [embedding, block_1 .. block_n, head] -- confirm against the parent class.
        self.num_layers = len(self.layers) - 2
        # Drop everything from the split point up to (and including) the last
        # block. The final entry of `self.layers` is left in place, exactly as
        # the original hard-coded range did for a 20-block model.
        for i in range(self.split_layer, self.num_layers + 1):
            self.layers[i] = None

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Run the client-side layers and return the hand-off tuple.

        Args:
            input_ids: Token ids fed to the first layer.
            past_key_values: Optional cache forwarded to every block.
            attention_mask: Optional mask forwarded to every block.
            labels: Not used here; passed through so the server can compute the loss.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            The tuple ``(input_ids, past_key_values, attention_mask, labels,
            hidden_layer)`` where ``hidden_layer`` is the output of the last
            client-side layer. The same tuple is returned on both code paths.
        """
        if attention_mask is not None and self.training:
            print("`attention_mask` is not supported during training. Using it might lead to unexpected results.")

        if past_key_values is None and attention_mask is None:
            print("[Client] past_key_values & attention_mask is None!")
            # Calling `self.layers(input_ids)` directly would iterate over the
            # `None` placeholders created in __init__ and raise a TypeError.
            # Run only the retained prefix; slicing an nn.Sequential yields a
            # Sequential of just those modules, called without extra kwargs.
            hidden_layer = self.layers[: self.split_layer](input_ids)
        else:
            print("[Client] forward with past_key_values or attention_mask!")
            hidden_layer = self.layers[0](input_ids)
            for module in self.layers[1: self.split_layer]:  # return intermediate tensor
                hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)

        return input_ids, past_key_values, attention_mask, labels, hidden_layer



class ServerSideMixFormerSequentialForCausalLM(MixFormerSequentialForCausalLM):
    """Server half of a MixFormer model that is split in two at ``split_layer``.

    The first ``split_layer`` entries of ``self.layers`` are replaced by
    ``None``; the server resumes computation from the hidden state produced
    by the client-side model and returns the language-model output.
    """

    def __init__(self, config):
        """Build the full model, then discard the layers owned by the client.

        Args:
            config: Model configuration forwarded unchanged to the parent class.
        """
        super().__init__(config)
        self.split_layer=2
        for i in range(0, self.split_layer):
            self.layers[i] = None

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        hidden_layer_input: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Finish the forward pass starting from the client's hidden state.

        Args:
            input_ids: Accepted for signature compatibility; the server has no
                embedding layer, so the ids themselves are not consumed here.
            past_key_values: Optional cache forwarded to every block.
            attention_mask: Optional mask forwarded to every block.
            labels: If given, ``self.loss(lm_logits, labels)`` is computed.
            hidden_layer_input: Output of the last client-side layer. Required.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            ``CausalLMOutputWithPast`` carrying ``loss`` (or ``None``),
            ``logits`` and the ``past_key_values`` that were passed in.

        Raises:
            ValueError: If ``hidden_layer_input`` is ``None``.
        """
        if attention_mask is not None and self.training:
            print("`attention_mask` is not supported during training. Using it might lead to unexpected results.")

        # Layers 0..split_layer-1 were removed in __init__, so nothing on the
        # server can turn `input_ids` into a hidden state. Fail early with a
        # clear message instead of an opaque error inside a block.
        if hidden_layer_input is None:
            raise ValueError(
                "hidden_layer_input is required: the server-side model starts "
                "from the client's intermediate hidden state."
            )

        if past_key_values is None and attention_mask is None:
            print("[Server] past_key_values & attention_mask is None!")
            # `self.layers(input_ids)` would hit the `None` placeholders at
            # indices 0..split_layer-1 and raise a TypeError. Run only the
            # retained suffix (remaining blocks + final layer) on the hidden state.
            lm_logits = self.layers[self.split_layer:](hidden_layer_input)
        else:
            print("[Server] forward with past_key_values or attention_mask!")
            hidden_layer = hidden_layer_input
            for module in self.layers[self.split_layer:-1]:  # Compute the remaining block
                hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
            lm_logits = self.layers[-1](hidden_layer)

        loss = None
        if labels is not None:
            loss = self.loss(lm_logits, labels)

        return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)



[2023-10-13 15:54:18,968] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [3]:
# Build both halves of the split model from one shared configuration object.
# The client half is constructed first, then the server half.
# NOTE(review): the configuration is created with its default arguments, while
# the next cell loads the "microsoft/phi-1_5" checkpoint -- confirm that the
# default dimensions are the ones intended for the split models.
config = MixFormerSequentialConfig()
client_model, server_model = (
    half_cls(config)
    for half_cls in (
        ClientSideMixFormerSequentialForCausalLM,
        ServerSideMixFormerSequentialForCausalLM,
    )
)

In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
# Candidate checkpoints; index 1 (phi-1.5) is the one used below.
model_name = ("gpt2", "microsoft/phi-1_5")
# Single definition of the download cache so tokenizer and model cannot drift apart.
hf_cache_dir = "/app/.huggingface_cache/model/"

tokenizer = AutoTokenizer.from_pretrained(
    model_name[1],
    trust_remote_code=True,
    cache_dir=hf_cache_dir,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name[1],
    cache_dir=hf_cache_dir,
    trust_remote_code=True,
)
print(model)

# Names of the sub-modules that receive LoRA adapters.
lora_modules=["Wqkv"]


def build_lora_config():
    """Return a fresh LoraConfig carrying this notebook's LoRA hyper-parameters.

    A new object is created on every call so that each wrapped model owns its
    own configuration instead of sharing one mutable instance.
    """
    return LoraConfig(
        r=2,  # dimension of the updated matrices
        lora_alpha=64,  # parameter for scaling
        target_modules=list(lora_modules),  # copy: do not share the list either
        lora_dropout=0.1,  # dropout probability for layers
        bias="none",
        task_type="CAUSAL_LM",
    )


# `lora_config` keeps its original name and meaning for any later cell that reads it.
lora_config = build_lora_config()

# NOTE(review): get_peft_model appears to store the config object it is given
# on the returned model and to write base-model details onto it -- confirm for
# the installed PEFT version. Giving every model its own config means a change
# made through one model's config cannot affect the other two.
# These three lines rebind their inputs; run this cell only once per kernel,
# otherwise the models are wrapped a second time.
model = get_peft_model(model, lora_config)
client_model = get_peft_model(client_model, build_lora_config())
server_model = get_peft_model(server_model, build_lora_config())

MixFormerSequentialForCausalLM(
  (layers): Sequential(
    (0): Embedding(
      (wte): Embedding(51200, 2048)
      (drop): Dropout(p=0.0, inplace=False)
    )
    (1): ParallelBlock(
      (ln): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (resid_dropout): Dropout(p=0.0, inplace=False)
      (mixer): MHA(
        (rotary_emb): RotaryEmbedding()
        (Wqkv): Linear(in_features=2048, out_features=6144, bias=True)
        (out_proj): Linear(in_features=2048, out_features=2048, bias=True)
        (inner_attn): SelfAttention(
          (drop): Dropout(p=0.0, inplace=False)
        )
        (inner_cross_attn): CrossAttention(
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (mlp): MLP(
        (fc1): Linear(in_features=2048, out_features=8192, bias=True)
        (fc2): Linear(in_features=8192, out_features=2048, bias=True)
        (act): NewGELUActivation()
      )
    )
    (2): ParallelBlock(
      (ln): LayerNorm((2048,), eps=1e-05, elementwis