In [7]:
import pickle
import os
from _aux_mamba import get_tokenizer, get_model

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import evaluate
from operator import itemgetter
from datasets import load_dataset

In [3]:
basemodel = get_model('state-spaces/mamba-130m-hf')
print(basemodel)

MambaForCausalLM(
  (backbone): MambaModel(
    (embeddings): Embedding(50280, 768)
    (layers): ModuleList(
      (0-23): 24 x MambaBlock(
        (norm): MambaRMSNorm()
        (mixer): MambaMixer(
          (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
          (act): SiLU()
          (in_proj): Linear(in_features=768, out_features=3072, bias=False)
          (x_proj): Linear(in_features=1536, out_features=80, bias=False)
          (dt_proj): Linear(in_features=48, out_features=1536, bias=True)
          (out_proj): Linear(in_features=1536, out_features=768, bias=False)
        )
      )
    )
    (norm_f): MambaRMSNorm()
  )
  (lm_head): Linear(in_features=768, out_features=50280, bias=False)
)


In [4]:
tokenizer = get_tokenizer('state-spaces/mamba-130m-hf')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
dataset = load_dataset("EdinburghNLP/xsum")

In [8]:
# load r2scores
file1 = 'experiment/mamba-130m-hf/xsum_r2_scores.pickle'
with open(file1, 'rb') as file:
    r2_scores = pickle.load(file)

In [9]:
# list of the r2_score's key
connections = list(r2_scores['score_mat'].keys())

In [10]:
connections[0]

(0, 1)

In [11]:
"""
LayerSelector 
: select layers included in forwarding.
n is the number of selected layers.
0th and 24th layers are always included, So we need to select intermediate layers(n-2).

I implemented this LayerSelector class that makes all possible selected layer list now.
-> this implementation requires too much resources.
-> todo : Change this process with an algorithm

"""
class LayerSelector:
    def __init__(self, n):
        self.n = n
        self.selected_layers = []
        self.arch = []

    def selecting(self, prev, num_selected):
        if num_selected == self.n - 2:
            self.selected_layers.append(24)
            self.selected_layers.insert(0, 0)
            self.arch.append(self.selected_layers[:])
            self.selected_layers.pop()
            self.selected_layers.pop(0)
            return

        for i in range(prev + 1, 24):
            self.selected_layers.append(i)
            self.selecting(i, num_selected + 1)
            self.selected_layers.pop()

    def select_layer(self):
        self.selecting(0, 0)
        return self.arch

In [12]:
selector = LayerSelector(3)
layers = selector.select_layer()
print(layers)

[[0, 1, 24], [0, 2, 24], [0, 3, 24], [0, 4, 24], [0, 5, 24], [0, 6, 24], [0, 7, 24], [0, 8, 24], [0, 9, 24], [0, 10, 24], [0, 11, 24], [0, 12, 24], [0, 13, 24], [0, 14, 24], [0, 15, 24], [0, 16, 24], [0, 17, 24], [0, 18, 24], [0, 19, 24], [0, 20, 24], [0, 21, 24], [0, 22, 24], [0, 23, 24]]


In [13]:
'''
load_arch
: this function takes selected layer numbers and loads real shorcut weights from the shorcut model.

linear_layers are list of the real weights.
'''
def load_arch(selected_layers):
    base_path = "linreg/mamba-130m-hf/xsum"
    linear_layers = []
    for i in range(1, len(selected_layers)):
        name = f"{selected_layers[i-1]}_{selected_layers[i]}.pickle"
        full_path = os.path.join(base_path, name)
        with open(full_path, 'rb') as file:
            linreg = pickle.load(file)
            linear_layers.append(linreg)
    return linear_layers

In [14]:
print(layers[0])

[0, 1, 24]


In [15]:
linear_layers = load_arch(layers[0])
for layer in linear_layers:
    print(layer.shape)

torch.Size([768, 768])
torch.Size([768, 768])


In [40]:
"""
tried to write beam search code
I was struggling because huggingface's mamba code doesn't get last state in forward code.
If I modify and use customized mamba code then I think we can't use finetuned weights.
What should I do?
"""

LENGTH_PENALTY = 1.2
MIN_LENGTH = 5

class SingleBeamSearchSpace():

    def __init__(self, decoder_input, beam_size, max_length = 255):
        self.beam_size = beam_size
        self.max_length = max_length

        super(SingleBeamSearchSpace, self).__init__()

        self.device = decoder_input.device
        self.word_indice = [torch.LongTensor(self.beam_size).zero_().to(self.device)]
        self.prev_beam_indice = [torch.LongTensor(self.beam_size).zero_().to(self.device) - 1]
        self.cumulative_probs = [torch.FloatTensor([.0] + [-float('inf')] * (beam_size - 1)).to(self.device)]
        self.masks = [torch.ByteTensor(beam_size).zero_().to(self.device)] # 1 if it is done else 0
        self.decoder_input = decoder_input

        self.decoder_input_ids = self.decoder_input.repeat(self.beam_size, 1)
        #self.decoder_attention_masks = self.decoder_input['attention_mask'].repeat(self.beam_size, 1)

        self.current_time_step = 0
        self.done_cnt = 0

    def get_length_penalty(self, length, alpha = LENGTH_PENALTY, min_length = MIN_LENGTH):
        p = (1 + length) ** alpha / (1 + min_length) ** alpha

        return p

    def is_done(self):
        if self.done_cnt >= self.beam_size:
            return 1
        return 0

    def get_batch(self):
        y_prev = self.word_indice[-1].unsqueeze(-1)
        
        return y_prev

    def collect_result(self, y_hat):
        output_size = y_hat.size(-1)

        self.current_time_step += 1

        cumulative_prob = y_hat + self.cumulative_probs[-1].masked_fill_(self.masks[-1], -float('inf')).view(-1, 1, 1).expand(self.beam_size, 1, output_size)
        top_log_prob, top_indice = torch.topk(cumulative_prob.view(-1), self.beam_size, dim = -1)
        # |top_log_prob| = (beam_size)
        # |top_indice| = (beam_size)
        self.word_indice += [top_indice.fmod(output_size)]
        self.prev_beam_indice += [top_indice.div(output_size).long()]

        self.cumulative_probs += [top_log_prob]
        self.masks += [torch.eq(self.word_indice[-1], tokenizer.eos_token_id)]
        self.done_cnt += self.masks[-1].float().sum()

        #self.prev_state = torch.index_select(prev_state, dim = 1, index = self.prev_beam_indice[-1]).contiguous()

    def get_n_best(self, n = 1):
        sentences = []
        probs = []
        founds = []

        for t in range(len(self.word_indice)):
            for b in range(self.beam_size):
                if self.masks[t][b] == 1:
                    probs += [self.cumulative_probs[t][b] / self.get_length_penalty(t)]
                    founds += [(t, b)]

        for b in range(self.beam_size):
            if self.cumulative_probs[-1][b] != -float('inf'):
                if not (len(self.cumulative_probs) - 1, b) in founds:
                    probs += [self.cumulative_probs[-1][b]]
                    founds += [(t, b)]

        sorted_founds_with_probs = sorted(zip(founds, probs), 
                                            key = itemgetter(1), 
                                            reverse = True
                                            )[:n]
        probs = []

        for (end_index, b), prob in sorted_founds_with_probs:
            sentence = []

            for t in range(end_index, 0, -1):
                sentence = [self.word_indice[t][b]] + sentence
                b = self.prev_beam_indice[t][b]

            sentences += [sentence]
            probs += [prob]

        return sentences, probs

In [42]:
def batch_beam_search(model, tokenized_input, beam_size, max_length = 255, n_best = 1):
    model.eval()
    x = tokenized_input
    batch_size = x.size(0)

    # initialize beam-search.
    spaces = [SingleBeamSearchSpace(tokenized_input, 
                                        beam_size=beam_size
                                        ) for i in range(batch_size)]
    done_cnt = [space.is_done() for space in spaces]

    length = 0
    while sum(done_cnt) < batch_size and length <= max_length:
        # current_batch_size = sum(done_cnt) * beam_size

        # initialize fabricated variables.
        fab_input = []
        fab_h_src, fab_mask = [], []

        # batchify.
        for i, space in enumerate(spaces):
            if space.is_done() == 0:
                y_prev_ = space.get_batch()

                fab_input += [y_prev_]

                fab_h_src += [x[i, :]] * beam_size
                #fab_mask += [input_mask[i, :]] * beam_size

        fab_input = torch.cat(fab_input, dim = 0)
        fab_h_src = torch.stack(fab_h_src)

        fab_output_logits = model(fab_input).logits[:, -1, :]
        
        y_hat = torch.log_softmax(fab_output_logits, dim=-1)
        output_size = y_hat.shape[-1]

        cnt = 0
        for space in spaces:
            if space.is_done() == 0:
                from_index = cnt * output_size
                to_index = (cnt + 1) * output_size

                # pick k-best results for each sample.
                space.collect_result(y_hat[from_index:to_index])
                cnt += 1

        done_cnt = [space.is_done() for space in spaces]
        length += 1

    batch_sentences = []
    batch_probs = []

    for i, space in enumerate(spaces):
        sentences, probs = space.get_n_best(n_best)

        batch_sentences += [sentences]
        batch_probs += [probs]

    return batch_sentences

In [32]:
class ShortcutModel(nn.Module):
    def __init__(self, n, selected_layers):
        super(ShortcutModel, self).__init__()
        self.n = n
        self.path = nn.ModuleList([nn.Linear(768, 768) for _ in range(self.n-1)])
        self.selected_layers = selected_layers
        self.weight_list = []
        self.tokenizer = get_tokenizer('state-spaces/mamba-130m-hf')
        self.basemodel = get_model('state-spaces/mamba-130m-hf')
        self.embed = basemodel.backbone.embeddings
        self.norm = basemodel.backbone.norm_f
        self.lm_head = basemodel.lm_head
        
        
    def load_weights(self):
        base_path = "linreg/mamba-130m-hf/xsum"
        for i in range(1, self.n):
            name = f"{self.selected_layers[i-1]}_{self.selected_layers[i]}.pickle"
            full_path = os.path.join(base_path, name)
            with open(full_path, 'rb') as file:
                linreg = pickle.load(file)
                self.weight_list.append(linreg)
    
    
    def initialize(self):
        for i in range(self.n-1):
            self.path[i].weight.data = self.weight_list[i]
        
    def forward(self, x):
        x = self.embed(x)
        for layer in self.path:
            x = layer(x)
        x = self.norm(x)
        x = self.lm_head(x)
 
        return x

    def original(self, x):
        x = self.basemodel(x)
        return x
    
#     def decode(self, x):
#         result_list = self.result.tolist()
#         self.decoded_result = self.tokenizer.decode(result_list, skip_special_tokens=True).split(suffix)[1]
#         self.decoded_ref = self.tokenizer.decode(x, skip_special_tokens=True).split(suffix)[1]
#         print(self.decoded_result)
#         print(self.decoded_ref)
        
    def get_score(self, x):
        rouge = evaluate.load('rouge')
        scores = rouge.compute(predictions=self.result, references=x)
        print(scores['rouge2'])
        return scores['rouge2']

In [35]:
model = basemodel

In [34]:
model = ShortcutModel(3, layers[0])
# model.load_weights()
# model.initialize()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [36]:
src_text = [
    """ PG&E stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow."""
]
prefix = "summarize this: "
suffix = "Here's the summary: "
input_text = prefix + dataset['train']['document'][0]+suffix
max_length = 512
tokenized_input = tokenizer.encode(input_text, return_tensors='pt')

In [None]:
generated = batch_beam_search(model, tokenized_input, 3)
tokens = [int(t.item()) for t in generated[0][0]]

  cumulative_prob = y_hat + self.cumulative_probs[-1].masked_fill_(self.masks[-1], -float('inf')).view(-1, 1, 1).expand(self.beam_size, 1, output_size)


In [None]:
pred = tokenizer.decode(tokens, skip_special_tokens=True).split(suffix)
print(pred)