In [1]:
import os

# Export the cache locations BEFORE importing `transformers`: the Hugging Face
# libraries compute their default cache paths when they are first imported, so
# variables set afterwards are not seen by those import-time defaults.
cache_dir = '/home/mila/e/emiliano.penaloza/scratch/models'
os.environ['TRANSFORMERS_CACHE'] = cache_dir
os.environ['HF_HOME'] = cache_dir
os.environ['HF_DATASETS_CACHE'] = cache_dir
os.environ['TORCH_HOME'] = cache_dir

import torch
import torch.nn as nn
from torch.nn import MultiheadAttention
from transformers import LlamaModel, LlamaConfig, LlamaForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer

# If the working directory is not the root of the project, change it so that
# the project-local `LoRA` package below can be imported.
os.chdir('/home/mila/e/emiliano.penaloza/RLPHF')
import LoRA.loralib as lora



  from .autonotebook import tqdm as notebook_tqdm


In [2]:


# Build the stock LlamaConfig (LLaMA-7B-style defaults) and echo it before any override is applied.
configuration = LlamaConfig()
print(f"{configuration=}")

# Override the depth and attention-head counts on the config, then instantiate
# the causal-LM from it and move it to device index 0.
configuration.update({
    "num_hidden_layers": 2,
    "num_attention_heads": 16,
    "num_key_value_heads": 2,
})
model = LlamaForCausalLM(configuration).to(0)





configuration=LlamaConfig {
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 2048,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "transformers_version": "4.39.2",
  "use_cache": true,
  "vocab_size": 32000
}



In [3]:

def count_parameters(model):
    """Return the total element count of the parameters of `model` that require gradients."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad == False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

# Count the parameters of `model` that require gradients and print the exact integer.
num_trainable_params = count_parameters(model)
print(f'Number of trainable parameters: {num_trainable_params}')
def human_format(num):
    """Format a number with a thousands suffix, e.g. 6738415616 -> '6.74B'.

    The value is divided by 1000 until its magnitude drops below 1000 or the
    largest available suffix ('T') is reached, then rendered with two decimals.

    Args:
        num: int or float to format; negative values keep their sign.

    Returns:
        str: the scaled value with two decimals followed by '', 'K', 'M', 'B' or 'T'.
    """
    suffixes = ['', 'K', 'M', 'B', 'T']
    magnitude = 0
    # Stop at the last suffix: without this bound, values >= 1e15 index past
    # the end of the suffix table and raise IndexError.
    while abs(num) >= 1000 and magnitude < len(suffixes) - 1:
        magnitude += 1
        num /= 1000.0
    return '%.2f%s' % (num, suffixes[magnitude])

# Recount the trainable parameters and print them with a K/M/B/T suffix via human_format.
num_trainable_params = count_parameters(model)
print(f'Number of trainable parameters: {human_format(num_trainable_params)}')

Number of trainable parameters: 6738415616
Number of trainable parameters: 6.74B


In [4]:
import copy 
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer


class HyperNetController():
    """Replaces selected sub-modules of a model with ``HyperNetLinear`` layers and pushes adapters into them.

    Typical use: construct the controller, call :meth:`augmentLLM` once on the
    model to swap the targeted modules, then call :meth:`updateLayers` with a
    tensor of adapter values to install them on the replaced layers.

    Attributes (set in ``__init__``):
        hyper_net: the hypernetwork object; only its ``d_emb`` attribute is read here.
        target_modules: either a regex string (full-matched against module names)
            or a list of name suffixes.
        hypernet_layers: the ``HyperNetLinear`` layers created so far, in creation order.
        r, A, a: forwarded unchanged to every ``HyperNetLinear`` that is created.
        model: the model passed to :meth:`augmentLLM` (``None`` until then).
    """

    _DEFAULT_TARGET_MODULES = ('q_proj', 'k_proj', 'v_proj')

    def __init__(self, hyper_net, r, A, a, target_modules=None):
        self.hyper_net = hyper_net
        # A fresh list per instance: a list literal as the default argument
        # would be shared (and mutable) across every controller.
        if target_modules is None:
            target_modules = self._DEFAULT_TARGET_MODULES
        if isinstance(target_modules, str):
            self.target_modules = target_modules
        else:
            self.target_modules = list(target_modules)
        self.hypernet_layers = []
        self.d_emb = hyper_net.d_emb
        self.r = r
        self.A = A
        self.a = a
        # Populated by augmentLLM; get_submodules depends on it.
        self.model = None

    def setHyperNetLayer(self, w0):
        """Create a ``HyperNetLinear`` around weight ``w0``, remember it, and return it.

        The layer's ``in_features``/``out_features`` are taken from ``w0`` itself
        (``nn.Linear`` weights are laid out as ``(out_features, in_features)``)
        rather than from the hypernetwork's embedding size, so the replacement
        reports the same dimensions as the weight it wraps.
        """
        out_features, in_features = w0.shape
        # NOTE(review): the controller itself (not self.hyper_net) is passed as the
        # layer's `hypernet` argument, as in the original code. Presumably this keeps
        # the hypernetwork from being registered as a sub-module of every layer; confirm.
        layer = HyperNetLinear(self, w0, in_features, out_features,
                               r=self.r, A=self.A, a=self.a)
        self.hypernet_layers.append(layer)
        return layer

    def replace_module(self, parent_module, child_name, new_module, old_module):
        """Install ``new_module`` as ``parent_module.<child_name>``, sharing ``old_module``'s parameters.

        The new module receives the old module's ``weight`` and ``bias`` objects
        (``bias`` becomes ``None`` when the old module has none, instead of keeping
        the new module's own randomly initialised bias), its ``state`` attribute
        when present, and is moved to the device of the old weight.
        """
        setattr(parent_module, child_name, new_module)
        new_module.weight = old_module.weight
        new_module.bias = getattr(old_module, "bias", None)
        if getattr(old_module, "state", None) is not None:
            new_module.state = old_module.state

        # dispatch to correct device
        new_module.to(old_module.weight.device)

    def get_submodules(self, key):
        """Return ``(parent_module, target_module, target_name)`` for the dotted module path ``key``.

        Requires :meth:`augmentLLM` to have stored the model on ``self.model``.
        """
        parent_key, _, target_name = key.rpartition(".")
        parent = self.model.get_submodule(parent_key)
        target = self.model.get_submodule(key)
        return parent, target, target_name

    def augmentLLM(self, model):
        """Replace every targeted sub-module of ``model`` with a ``HyperNetLinear`` (in place).

        A module is targeted when its dotted name fully matches ``target_modules``
        (if that is a regex string) or ends with one of its entries (if it is a
        list). Each replacement is printed as ``replaced <name>``.
        """
        self.model = model
        # Snapshot the names first: modules are swapped while iterating.
        key_list = [key for key, _ in model.named_modules()]
        for key in key_list:
            if isinstance(self.target_modules, str):
                target_module_found = re.fullmatch(self.target_modules, key) is not None
            else:
                target_module_found = any(key.endswith(target_key) for target_key in self.target_modules)

            if not target_module_found:
                continue

            parent, target, target_name = self.get_submodules(key)
            new_module = self.setHyperNetLayer(target.weight)
            self.replace_module(parent, target_name, new_module, target)
            print('replaced', key)

    def setLayerWeights(self, hypernet_outputs, layer):
        """Unfinished per-layer setter; always raises.

        The previous body could never run: it read ``self.num_layers`` (never
        assigned) and indexed ``self.hypernet_layers`` with a module object.
        It is replaced by an explicit error until the intended behaviour is
        specified. Use :meth:`updateLayers` to install adapters.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "setLayerWeights is not implemented; use updateLayers(new_layer_tensor) instead"
        )

    def updateLayers(self, new_layer_tensor):
        """Split ``new_layer_tensor`` per replaced layer and install the halves as adapters.

        For the i-th layer in ``hypernet_layers``, ``new_layer_tensor[:, i]`` is
        split along its last axis at index ``self.a``; the first part is passed to
        ``set_adapter`` as the A adapter and the remainder as the B adapter.

        Args:
            new_layer_tensor: 4-D tensor indexed as ``[batch, layer, :, :]``. Per the
                original author's note its shape is ``b x (l * kvq) x r x (a * 2)``.

        Raises:
            ValueError: if axis 1 holds fewer entries than there are replaced
                layers (checked up front so no layer is left half-updated).
        """
        n_layers = len(self.hypernet_layers)
        if new_layer_tensor.shape[1] < n_layers:
            raise ValueError(
                "new_layer_tensor provides %d layer entries but %d layers were replaced"
                % (new_layer_tensor.shape[1], n_layers)
            )

        for i, hyper_layer in enumerate(self.hypernet_layers):
            # per-layer slice: b x r x (a * 2) per the original note
            l_new = new_layer_tensor[:, i]
            adapter_a = l_new[:, :, :self.a]
            adapter_b = l_new[:, :, self.a:]
            hyper_layer.set_adapter(adapter_a, adapter_b)
            
            
                
        

class HyperNet(nn.Module):
    """Encoder/decoder transformer that maps a task embedding to adapter parameters.

    Pipeline implemented by :meth:`forward`:
    ``TransformerEncoder`` (width ``d_emb``) -> MLP (``d_emb`` -> ``d_model``) ->
    ``TransformerDecoder`` (width ``d_model``, attending to the encoded task) ->
    MLP head producing ``a * 2 * 3`` values per decoder position.

    Args:
        r: rank value; stored and used only by :meth:`organize_outputs`.
        alpha: stored on the instance; not used by this class.
        dropout: dropout probability for the encoder and decoder layers.
        output_emb: accepted for interface compatibility; not used by this class.
        d_emb: feature width of the task embedding / encoder.
        d_model: feature width of the decoder.
        n_transformer_layers: number of layers in both the encoder and the decoder.
        hypernet_heads: attention heads for both the encoder and the decoder layers.
        target_modules: names of the targeted modules; defaults to
            ``['q_proj', 'k_proj', 'v_proj']``. Used only by :meth:`organize_outputs`.
        a: per-adapter width; the head emits ``a * 2 * 3`` values.
    """

    def __init__(self, r, alpha, dropout, output_emb, d_emb, d_model, n_transformer_layers,
                 hypernet_heads=2, target_modules=None, a=16):
        super().__init__()
        self.r = r
        self.alpha = alpha
        self.dropout = dropout
        self.n_transformer_layers = n_transformer_layers
        self.d_emb = d_emb
        self.a = a

        # Author's design notes (translated to plain wording):
        # each attention block has 3 * (d_emb * d_emb) weights, one d_emb x d_emb
        # matrix per targeted projection. Instead of emitting full matrices the
        # hypernetwork emits a rank-r representation of each of them.

        # Fresh list per instance; a list literal as default argument would be shared.
        if target_modules is None:
            target_modules = ['q_proj', 'k_proj', 'v_proj']
        self.target_modules = target_modules

        # Encoder over the task embedding (width d_emb).
        encoder_layer = TransformerEncoderLayer(d_model=d_emb, nhead=hypernet_heads,
                                                dim_feedforward=d_emb * 4, dropout=dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=n_transformer_layers)

        # MLP that maps the encoder output from d_emb to d_model.
        self.mlp_encoder = nn.Sequential(
            nn.Linear(d_emb, d_emb * 2),
            nn.ReLU(),
            nn.Linear(d_emb * 2, d_model)
        )

        # Decoder (width d_model) that attends to the projected task encoding.
        decoder_layer = TransformerDecoderLayer(d_model=d_model, nhead=hypernet_heads,
                                                dim_feedforward=d_emb * 4, dropout=dropout)
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=n_transformer_layers)

        # Output head: a * 2 * 3 values per decoder position.
        # NOTE(review): the factor 3 presumably stands for the three default
        # target modules and 2 for the A/B adapter pair; it is hard-coded and
        # does not follow len(target_modules). Confirm before changing targets.
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.ReLU(),
            nn.Linear(d_model * 2, self.a * 2 * 3)
        )

    def forward(self, task_embedding, identity_matrix):
        """Produce adapter parameters for every position of ``identity_matrix``.

        Args:
            task_embedding: encoder input with last dimension ``d_emb``.
            identity_matrix: decoder target input with last dimension ``d_model``.
                (The transformer layers are built without ``batch_first``, so
                PyTorch's default sequence-first layout applies to both inputs.)

        Returns:
            Tensor shaped like ``identity_matrix`` except for the last dimension,
            which is ``a * 2 * 3``.
        """
        # Encode the task embedding and project it to the decoder width.
        encoded_task = self.transformer_encoder(task_embedding)
        encoded_task = self.mlp_encoder(encoded_task)

        # Decode: identity_matrix is the target sequence, encoded_task the memory.
        decoded_output = self.transformer_decoder(identity_matrix, encoded_task)

        # Map the decoder output to the network parameters using the MLP head.
        params = self.mlp(decoded_output)
        # params = self.organize_outputs(params)
        return params

    def organize_outputs(self, outputs):
        """Split each entry of ``outputs`` into one equal-width chunk per target module.

        Each chunk is ``2 * r`` elements wide and chunks are consecutive. The
        previous implementation advanced the offset with ``s = 2 * (s + r)``,
        which produced chunks of growing width (2, 4, 8, ... for ``r == 1``).

        NOTE(review): the width ``2 * r`` follows the original code; the MLP head
        currently emits ``a * 2 * 3`` values, so confirm the intended chunk width
        before wiring this into :meth:`forward`.

        Args:
            outputs: iterable of sliceable sequences (e.g. 1-D tensors or lists).

        Returns:
            list[list]: for every entry, a list with ``len(target_modules)`` slices.
        """
        chunk = 2 * self.r
        n_targets = len(self.target_modules)
        return [
            [out[i * chunk:(i + 1) * chunk] for i in range(n_targets)]
            for out in outputs
        ]
            

            
class HyperNetLinear(nn.Linear):
    """Linear layer whose effective weight is a frozen base plus an externally supplied low-rank update.

    The forward pass computes ``x @ (scaling * (w0.T + A_full @ B_full))`` where
    ``A_full = orthA.T @ adapterA`` and ``B_full = adapterB @ orthB``. ``orthA``
    and ``orthB`` are fixed random matrices with orthonormal rows (shape
    ``a x A``) created at construction; ``adapterA`` / ``adapterB`` are installed
    by :meth:`set_adapter`.

    Args:
        hypernet: object stored on ``self.hypernet``; not used by this class.
        w0: base weight, laid out like ``nn.Linear.weight`` (out_features x in_features).
            A detached, frozen copy is kept.
        in_features, out_features: forwarded to ``nn.Linear.__init__``.
        r: rank used only for the placeholder adapters created here.
        scaling: scalar multiplied onto the combined weight.
        A: side length of the square matrix from which the projections are taken
            (the update ``A_full @ B_full`` is ``A x A``).
        a: number of rows kept from each orthogonal matrix.
    """

    def __init__(self, hypernet, w0, in_features, out_features, r=1, scaling=1, A=512, a=16):
        nn.Linear.__init__(self, in_features, out_features)
        self.scaling = scaling
        self.hypernet = hypernet
        # Frozen copy of the base weight. It stays a registered parameter (so it
        # follows .to()/state_dict) but never receives gradients; the previous
        # deepcopy kept requires_grad=True and made the copy trainable.
        self.w0 = nn.Parameter(w0.detach().clone(), requires_grad=False)

        # Placeholder adapters; they are overwritten by set_adapter.
        # NOTE(review): these are 2-D (A x r), which forward cannot multiply with
        # the a x A projections unless a == A, so set_adapter must be called
        # before the first forward pass.
        self.adapterA = nn.init.xavier_normal_(torch.empty(A, r)).to(device=w0.device, dtype=w0.dtype)
        self.adapterB = nn.init.xavier_normal_(torch.empty(A, r)).to(device=w0.device, dtype=w0.dtype)

        # Fixed random projections with orthonormal rows (a x A).
        self.orthA = self.make_orth(A, a).to(device=w0.device, dtype=w0.dtype)
        self.orthB = self.make_orth(A, a).to(device=w0.device, dtype=w0.dtype)

    def make_orth(self, A, a):
        """Return the first ``a`` rows of a random ``A x A`` orthogonal matrix.

        The orthogonal matrix is ``U @ Vh`` from the SVD of a Gaussian matrix.
        """
        gaus = torch.randn(A, A)
        u, _, vh = torch.linalg.svd(gaus)
        return (u @ vh)[:a]

    def set_adapter(self, adapterA, adapterB, transposeA=True):
        """Install new adapters; ``adapterA`` has its dims 1 and 2 swapped unless ``transposeA`` is False.

        forward multiplies ``orthA.T`` (A x a) with ``adapterA`` and ``adapterB``
        with ``orthB`` (a x A), so after the optional transpose ``adapterA`` must
        be ``(..., a, r)`` and ``adapterB`` must be ``(..., r, a)``.
        """
        self.adapterA = adapterA if not transposeA else torch.transpose(adapterA, 1, 2)
        self.adapterB = adapterB

    def forward(self, x):
        """Apply the base weight plus the current low-rank update to ``x``.

        Fixes relative to the earlier version: the stray ``raise Exception`` that
        made the method always fail is removed, and the base weight is applied
        as ``w0.T`` because ``nn.Linear`` stores weights as
        (out_features, in_features) and computes ``x @ W.T``.

        NOTE(review): ``self.bias`` is still not added, exactly as before; confirm
        whether the wrapped layer's bias should be applied.
        """
        # matmul broadcasts the shared projections over the adapters' batch
        # dimension, so no per-sample copies of orthA/orthB are materialised.
        full_a = torch.matmul(self.orthA.T, self.adapterA)   # (..., A, r)
        full_b = torch.matmul(self.adapterB, self.orthB)     # (..., r, A)
        delta = torch.matmul(full_a, full_b)                 # (..., A, A)
        out = x @ (self.scaling * (self.w0.T + delta))
        return out

    def train(self, mode=True):
        """Set the module's training flag and return ``self``.

        The adapters and projections are plain tensors and ``w0`` is a frozen
        parameter; none of them has a ``train`` method, so calling it on them
        (as the earlier version did) raised AttributeError on every
        ``model.train()`` / ``model.eval()``. Keeping ``w0`` and the projections
        fixed is achieved through ``requires_grad=False`` / not registering
        them as trainable parameters, not through the training flag.
        """
        return super().train(mode)
        
        

        


In [5]:
from transformers import BertTokenizer, BertModel

# Hypernetwork hyper-parameters.
r, alpha, dropout = 1, 0.1, 0.1
n_transformer_layers, n_transformer_heads = 1, 2
d_emb, d_model = 64, 16

# Output sizing: an (A, B) adapter pair for each of the three projections,
# which corresponds to producing one layer at a time.
a_b, kvq = 2, 3
output_dim = (r * d_model) * a_b * kvq
a = 16

# Load BERT as the model to augment and read its hidden size from the config.
model = BertModel.from_pretrained('bert-base-uncased', cache_dir=cache_dir)
model = model.to(0)
A = model.config.hidden_size

# Build the hypernetwork on device 0, then the controller that targets
# BERT's query/key/value modules.
hypernet = HyperNet(
    r=r,
    alpha=alpha,
    dropout=dropout,
    output_emb=output_dim,
    d_emb=d_emb,
    d_model=d_model,
    n_transformer_layers=n_transformer_layers,
    hypernet_heads=n_transformer_heads,
    a=a,
).to(0)
controller = HyperNetController(
    hypernet,
    r=r,
    A=A,
    a=a,
    target_modules=['query', 'key', 'value'],
)




In [6]:
# Swap the targeted modules of `model` for hypernetwork-driven layers in place;
# each swap is reported as a "replaced <module name>" line.
controller.augmentLLM(model)



replaced encoder.layer.0.attention.self.query
replaced encoder.layer.0.attention.self.key
replaced encoder.layer.0.attention.self.value
replaced encoder.layer.1.attention.self.query
replaced encoder.layer.1.attention.self.key
replaced encoder.layer.1.attention.self.value
replaced encoder.layer.2.attention.self.query
replaced encoder.layer.2.attention.self.key
replaced encoder.layer.2.attention.self.value
replaced encoder.layer.3.attention.self.query
replaced encoder.layer.3.attention.self.key
replaced encoder.layer.3.attention.self.value
replaced encoder.layer.4.attention.self.query
replaced encoder.layer.4.attention.self.key
replaced encoder.layer.4.attention.self.value
replaced encoder.layer.5.attention.self.query
replaced encoder.layer.5.attention.self.key
replaced encoder.layer.5.attention.self.value
replaced encoder.layer.6.attention.self.query
replaced encoder.layer.6.attention.self.key
replaced encoder.layer.6.attention.self.value
replaced encoder.layer.7.attention.self.query
re