In [1]:
import pickle
import sys
import os
# current_dir = os.getcwd()
# parent_dir = os.path.dirname(current_dir)
# sys.path.append(parent_dir)

from _aux_mamba import get_tokenizer, get_model

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import evaluate
from operator import itemgetter
from datasets import load_dataset
import time

In [3]:
basemodel = get_model('state-spaces/mamba-130m-hf')
print(basemodel)

MambaForCausalLM(
  (backbone): MambaModel(
    (embeddings): Embedding(50280, 768)
    (layers): ModuleList(
      (0-23): 24 x MambaBlock(
        (norm): MambaRMSNorm()
        (mixer): MambaMixer(
          (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
          (act): SiLU()
          (in_proj): Linear(in_features=768, out_features=3072, bias=False)
          (x_proj): Linear(in_features=1536, out_features=80, bias=False)
          (dt_proj): Linear(in_features=48, out_features=1536, bias=True)
          (out_proj): Linear(in_features=1536, out_features=768, bias=False)
        )
      )
    )
    (norm_f): MambaRMSNorm()
  )
  (lm_head): Linear(in_features=768, out_features=50280, bias=False)
)


In [4]:
basetokenizer = get_tokenizer('state-spaces/mamba-130m-hf')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
dataset = load_dataset("EdinburghNLP/xsum")

In [6]:
# load r2scores
file1 = 'experiment/mamba-130m-hf/xsum_r2_scores.pickle'
with open(file1, 'rb') as file:
    r2_scores = pickle.load(file)

In [7]:
# list of the r2_score's key
connections = list(r2_scores['score_mat'].keys())

In [8]:
connections[0]

(0, 1)

In [16]:
"""
LayerSelector 
: select layers included in forwarding.
n is the number of selected layers.
0th mamba_blocks is always included, So we need to select intermediate layers(n-1).

I implemented this LayerSelector class that makes all possible selected layer list now.
-> this implementation requires too much resources.
-> todo : Change this process with an algorithm

"""
class LayerSelector:
    def __init__(self, n):
        self.n = n
        self.selected_layers = []
        self.arch = []

    def selecting(self, prev, num_selected):
        if num_selected == self.n - 1:
            self.selected_layers.insert(0, 0)
            self.arch.append(self.selected_layers[:])
            self.selected_layers.pop(0)
            return

        for i in range(prev + 1, 25):
            self.selected_layers.append(i)
            self.selecting(i, num_selected + 1)
            self.selected_layers.pop()

    def select_layer(self):
        self.selecting(0, 0)
        return self.arch

In [17]:
selector = LayerSelector(3)
layers = selector.select_layer()
print(layers)

[[0, 1, 2], [0, 1, 3], [0, 1, 4], [0, 1, 5], [0, 1, 6], [0, 1, 7], [0, 1, 8], [0, 1, 9], [0, 1, 10], [0, 1, 11], [0, 1, 12], [0, 1, 13], [0, 1, 14], [0, 1, 15], [0, 1, 16], [0, 1, 17], [0, 1, 18], [0, 1, 19], [0, 1, 20], [0, 1, 21], [0, 1, 22], [0, 1, 23], [0, 1, 24], [0, 2, 3], [0, 2, 4], [0, 2, 5], [0, 2, 6], [0, 2, 7], [0, 2, 8], [0, 2, 9], [0, 2, 10], [0, 2, 11], [0, 2, 12], [0, 2, 13], [0, 2, 14], [0, 2, 15], [0, 2, 16], [0, 2, 17], [0, 2, 18], [0, 2, 19], [0, 2, 20], [0, 2, 21], [0, 2, 22], [0, 2, 23], [0, 2, 24], [0, 3, 4], [0, 3, 5], [0, 3, 6], [0, 3, 7], [0, 3, 8], [0, 3, 9], [0, 3, 10], [0, 3, 11], [0, 3, 12], [0, 3, 13], [0, 3, 14], [0, 3, 15], [0, 3, 16], [0, 3, 17], [0, 3, 18], [0, 3, 19], [0, 3, 20], [0, 3, 21], [0, 3, 22], [0, 3, 23], [0, 3, 24], [0, 4, 5], [0, 4, 6], [0, 4, 7], [0, 4, 8], [0, 4, 9], [0, 4, 10], [0, 4, 11], [0, 4, 12], [0, 4, 13], [0, 4, 14], [0, 4, 15], [0, 4, 16], [0, 4, 17], [0, 4, 18], [0, 4, 19], [0, 4, 20], [0, 4, 21], [0, 4, 22], [0, 4, 23], [0, 4

In [18]:
'''
load_arch
: this function takes selected layer numbers and loads real shorcut weights from the shorcut model.

linear_layers are list of the real weights.
'''
def load_arch(selected_layers):
    base_path = "linreg/mamba-130m-hf/xsum"
    linear_layers = []
    for i in range(1, len(selected_layers)):
        name = f"{selected_layers[i-1]}_{selected_layers[i]}.pickle"
        full_path = os.path.join(base_path, name)
        with open(full_path, 'rb') as file:
            linreg = pickle.load(file)
            linear_layers.append(linreg)
    return linear_layers

In [19]:
print(layers[0])

[0, 1, 2]


In [20]:
linear_layers = load_arch(layers[0])
for layer in linear_layers:
    print(layer.shape)

torch.Size([768, 768])
torch.Size([768, 768])


In [21]:
def beam_search_for_base(model, tokenizer, input_ids, beam_size=2, max_length=50):

    finished_beams = []
    running_beam = [(0, input_ids)]
    input_len = len(input_ids)

    start = time.time()
    while len(finished_beams) < beam_size and running_beam:
        beam_score, input_ids = running_beam.pop(0)
        
        outputs = model(input_ids)
        logits = outputs.logits[:, -1, :]

        # Choose top 2 (beam_size) tokens
        top_k_values, top_k_indices = torch.topk(logits, beam_size, dim=-1)
 
        input_ids_list = [input_ids] * beam_size
    
        for i in range(beam_size):
            score = top_k_values[:,i]
            token = top_k_indices[:,i]
            
            # Add the new token and update attention_mask
            new_input_ids = torch.cat((input_ids_list[i], token.unsqueeze(0)), dim=-1)
           
            if token == tokenizer.eos_token_id or new_input_ids.shape[-1] > max_length + input_len:
                finished_beams.append((beam_score + score, new_input_ids))
            else:
                running_beam.append((beam_score + score, new_input_ids))
                
        # Sort the running beams by score
        running_beam.sort(key=lambda x: x[0], reverse=True)
    
    # Return the highest scoring finished beam
    result = max(finished_beams, key=lambda x: x[0])[1]
    end = time.time()
    print('inference time : ',end-start)
    return result

In [22]:
def beam_search_for_sc(model, input_ids, beam_size=2, max_length=50):

    finished_beams = []
    running_beam = [(0, input_ids)]
    input_len = len(input_ids)

    start = time.time()
    while len(finished_beams) < beam_size and running_beam:
        beam_score, input_ids = running_beam.pop(0)
        
        outputs = model(input_ids)
        logits = outputs[:, -1, :]

        # Choose top 2 (beam_size) tokens
        top_k_values, top_k_indices = torch.topk(logits, beam_size, dim=-1)
 
        input_ids_list = [input_ids] * beam_size
    
        for i in range(beam_size):
            score = top_k_values[:,i]
            token = top_k_indices[:,i]
            
            # Add the new token and update attention_mask
            new_input_ids = torch.cat((input_ids_list[i], token.unsqueeze(0)), dim=-1)
           
            if token == model.tokenizer.eos_token_id or new_input_ids.shape[-1] > max_length + input_len:
                finished_beams.append((beam_score + score, new_input_ids))
            else:
                running_beam.append((beam_score + score, new_input_ids))
                
        # Sort the running beams by score
        running_beam.sort(key=lambda x: x[0], reverse=True)
    result = max(finished_beams, key=lambda x: x[0])[1]
    end = time.time()
    print('inference time : ',end-start)
    # Return the highest scoring finished beam
    return result

In [23]:
class ShortcutModel(nn.Module):
    def __init__(self, n, selected_layers):
        super(ShortcutModel, self).__init__()
        self.n = n
        self.path = [nn.Linear(768, 768) for _ in range(self.n-1)]
        self.selected_layers = selected_layers
        self.weight_list = []
        self.tokenizer = get_tokenizer('state-spaces/mamba-130m-hf')
        self.basemodel = get_model('state-spaces/mamba-130m-hf')
        self.vocab_size = self.tokenizer.vocab_size
        self.emb_dim = self.basemodel.backbone.embeddings.embedding_dim
        ## added tokens 23
        self.embed = nn.Embedding(self.vocab_size+30, self.emb_dim)
        self.norm = self.basemodel.backbone.norm_f
        self.lm_head = self.basemodel.lm_head
        
        
    def load_weights(self):
        base_path = "linreg/mamba-130m-hf/xsum"
        for i in range(1, self.n):
            name = f"{self.selected_layers[i-1]}_{self.selected_layers[i]}.pickle"
            full_path = os.path.join(base_path, name)
            with open(full_path, 'rb') as file:
                linreg = pickle.load(file)
                self.weight_list.append(linreg)
    
    
    def initialize(self):
        for i in range(self.n-1):
            self.path[i].weight.data = self.weight_list[i]
        
    def forward(self, x):
        x = self.embed(x)
        for layer in self.path:
            x = layer(x)
        x = self.norm(x)
        x = self.lm_head(x)
 
        return x

    def original(self, x):
        x = self.basemodel(x)
        return x

In [24]:
shortcut_model = ShortcutModel(3, layers[0])
shortcut_model.load_weights()
shortcut_model.initialize()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [26]:
src_text = [
    """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
]
prefix = "summarize this: "
suffix = "Here's the summary: "
input_text = prefix + dataset['train']['document'][0]
base_tokenized_input = basetokenizer.encode(input_text, return_tensors='pt')

In [27]:
generated =beam_search_for_base(basemodel, basetokenizer, base_tokenized_input)
tokens = generated[0]

inference time :  1.8883771896362305


In [29]:
pred_original = basetokenizer.decode(tokens, skip_special_tokens=True).split(suffix)
print(pred_original)



In [30]:
sc_tokenized_input = shortcut_model.tokenizer.encode(input_text, return_tensors='pt')
if torch.any(sc_tokenized_input >= shortcut_model.tokenizer.vocab_size):
    print('hi')

In [31]:
sc_generated = beam_search_for_sc(shortcut_model, sc_tokenized_input)
sc_tokens = sc_generated[0]

inference time :  0.09805631637573242


In [32]:
sc_pred = shortcut_model.tokenizer.decode(sc_tokens, skip_special_tokens=True).split(suffix)
print(sc_pred )



In [35]:
def get_score():
    rouge = evaluate.load('rouge')
    scores = rouge.compute(predictions=pred_original, references=sc_pred)
    print(scores)

In [36]:
get_score()

{'rouge1': 0.9987819732034106, 'rouge2': 0.9987789987789988, 'rougeL': 0.9987819732034106, 'rougeLsum': 0.9987819732034106}
