In [18]:
import torch
import re


def load_ensemble(filepaths, map_location=None):
    """Load the Bayesian-layer parameters stored in each checkpoint file.

    For every path, the checkpoint is read with ``torch.load`` and the mapping
    found at ``checkpoint["state_dict"]["bayesian_layer"].params`` is
    post-processed:

    * a ``model.layers.<N>.`` fragment is stripped from each key, and
    * entries whose tensor has no elements (``numel() == 0``) are dropped.

    Args:
        filepaths: Iterable of checkpoint paths; one ensemble member is
            produced per path, in the same order.
        map_location: Forwarded unchanged to ``torch.load``. The default
            ``None`` keeps ``torch.load``'s own default behaviour; pass e.g.
            ``"cpu"`` to control where the tensors are materialised.

    Returns:
        A list with one ``{name: tensor}`` dict per checkpoint. An empty
        ``filepaths`` yields an empty list.
    """
    layer_prefix = re.compile(r"model\.layers\.\d+\.")

    def load_from_checkpoint(filepath):
        # SECURITY NOTE: torch.load unpickles the file, and this checkpoint
        # layout stores a Python object (it is accessed via ``.params``), so
        # only load checkpoints that come from a trusted source.
        checkpoint = torch.load(filepath, map_location=map_location)
        raw_parameters = checkpoint["state_dict"]["bayesian_layer"].params
        return {
            layer_prefix.sub("", name): tensor
            for name, tensor in raw_parameters.items()
            if tensor.numel() > 0
        }

    return [load_from_checkpoint(filepath) for filepath in filepaths]


# One checkpoint per training epoch of the "bayes_with_seq" run; each file
# contributes one entry to the loaded ensemble.
checkpoint_paths = [
    "logs/bayes_with_seq/checkpoints/epoch=0-step=300.ckpt",
    "logs/bayes_with_seq/checkpoints/epoch=1-step=600.ckpt",
    "logs/bayes_with_seq/checkpoints/epoch=2-step=900.ckpt",
    "logs/bayes_with_seq/checkpoints/epoch=3-step=1200.ckpt",
    "logs/bayes_with_seq/checkpoints/epoch=4-step=1500.ckpt",
    "logs/bayes_with_seq/checkpoints/epoch=5-step=1800.ckpt",
]
parameters = load_ensemble(checkpoint_paths)

In [19]:
%load_ext autoreload
%autoreload 2

from llama3.modules.bayesllama import BayesLlamaForCausalLM
bayes_config = {'n_ensemble': 6}

bayes = BayesLlamaForCausalLM.from_pretrained("Meta-Llama-3-8B", bayes_config=bayes_config).to("cuda:0")
bayes.load_bayesian_layers(parameters)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.04it/s]
Some weights of BayesLlamaForCausalLM were not initialized from the model checkpoint at Meta-Llama-3-8B and are newly initialized: ['model.bayesian_layers.0.input_layernorm.weight', 'model.bayesian_layers.0.mlp.down_proj.weight', 'model.bayesian_layers.0.mlp.gate_proj.weight', 'model.bayesian_layers.0.mlp.up_proj.weight', 'model.bayesian_layers.0.post_attention_layernorm.weight', 'model.bayesian_layers.0.self_attn.k_proj.weight', 'model.bayesian_layers.0.self_attn.o_proj.weight', 'model.bayesian_layers.0.self_attn.q_proj.weight', 'model.bayesian_layers.0.self_attn.v_proj.weight', 'model.bayesian_layers.1.input_layernorm.weight', 'model.bayesian_layers.1.mlp.down_proj.weight', 'model.bayesian_layers.1.mlp.gate_proj.weight', 'model.bayesian_layers.1.mlp.up_proj.weight', 'model.bayesian_layers.1.post_attention_layernorm.weight', 'model.bayesian_layers.1.self_attn.k_proj.weight', 'model.bayesian_layers.1.self_attn.o_p

In [20]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Meta-Llama-3-8B")

# Keep the raw prompt and its tokenized form under separate names so that
# `inputs` only ever refers to the tokenizer output.
prompt = "This is a test"
inputs = tokenizer(prompt, return_tensors="pt")

# Displayed as the cell result: shape of the token-id tensor.
inputs["input_ids"].shape

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


torch.Size([1, 5])

In [22]:
# Single forward pass over the tokenized prompt; the flags request a
# structured output object, the hidden states, and the key/value cache.
forward_options = {
    "return_dict": True,
    "output_hidden_states": True,
    "use_cache": True,
}
out = bayes(**inputs.to("cuda"), **forward_options)
out

CausalLMOutputWithPast(loss=None, logits=(tensor([[[ 1.9585,  3.7956,  3.3220,  ..., -1.9572, -1.9571, -1.9572],
         [-0.2953, -1.7536, -1.7713,  ..., -8.5935, -8.5937, -8.5938],
         [-0.9260, -1.0948, -2.7706,  ..., -8.0689, -8.0688, -8.0689],
         [-2.8621, -1.4704, -2.8481,  ..., -8.1362, -8.1362, -8.1363],
         [ 6.8398,  0.2942,  1.5580,  ..., -7.5922, -7.5921, -7.5921]]],
       device='cuda:0', grad_fn=<MeanBackward1>), tensor([[[[ 6.8929,  8.9111, 12.2886,  ..., -4.5894, -4.5893, -4.5895],
          [ 1.6496, -0.2037, -0.6217,  ..., -9.4461, -9.4463, -9.4464],
          [ 0.4088, -0.3973, -2.1203,  ..., -8.7409, -8.7408, -8.7409],
          [-1.2420, -0.1184, -1.7484,  ..., -9.1504, -9.1505, -9.1506],
          [ 7.9730,  1.3265,  2.4272,  ..., -8.6550, -8.6549, -8.6548]]],


        [[[ 4.8328,  6.8529,  8.6868,  ..., -3.5003, -3.5003, -3.5004],
          [ 1.4659, -0.2835, -0.6483,  ..., -9.5624, -9.5625, -9.5627],
          [ 0.1830, -0.4444, -2.1270,  ...,

In [38]:
@torch.no_grad()
def generate(model, inputs, max_length=10, use_cache=True):
    """Greedily decode ``max_length`` new tokens and return them as text.

    Each step calls ``model`` with ``return_dict=False``, takes the argmax of
    ``outputs[0][0][:, -1]`` as the next token, and feeds it back in.

    Args:
        model: Callable invoked as ``model(**inputs, return_dict=False,
            use_cache=use_cache)``. Its result is indexed as ``outputs[0][0]``
            for the logits used to pick the next token, and — when caching —
            ``outputs[1]`` / ``outputs[2]`` for the two cache objects.
        inputs: Mapping of keyword arguments for the first model call; must
            contain ``"input_ids"``. It is NOT modified: the function works on
            a shallow copy, so the same mapping can be reused afterwards.
        max_length: Number of tokens to generate (number of model calls).
        use_cache: If true, only the newest token plus the returned caches are
            passed on subsequent steps; otherwise the growing token sequence
            is re-fed in full.

    Returns:
        The result of ``tokenizer.batch_decode`` on the generated tokens only
        (the prompt is not included). ``tokenizer`` is the notebook-level
        global, not a parameter.
    """
    # Shallow copy so the caller's mapping keeps its attention mask and
    # original input_ids, and does not get cache entries injected into it.
    step_inputs = dict(inputs)

    generated = []
    for _ in range(max_length):
        outputs = model(**step_inputs, return_dict=False, use_cache=use_cache)

        # The mask is only supplied on the first call; its length would no
        # longer match the inputs on later steps.
        step_inputs.pop("attention_mask", None)

        # assumes outputs[0][0] is (batch, seq, vocab) — TODO confirm
        next_token = outputs[0][0][:, -1].argmax(-1).unsqueeze(-1)
        if use_cache:
            step_inputs["past_key_values"] = outputs[1]
            step_inputs["ensemble_past_key_values"] = outputs[2]
            step_inputs["input_ids"] = next_token
        else:
            step_inputs["input_ids"] = torch.cat(
                [step_inputs["input_ids"], next_token], dim=1
            )
        generated.append(next_token)

    generated = torch.cat(generated, -1)
    return tokenizer.batch_decode(generated, skip_special_tokens=True)


# Re-tokenize the prompt and greedily continue it with the cached path;
# the decoded continuation is displayed as the cell result.
prompt = "This is a test"
inputs = tokenizer(prompt, return_tensors="pt")
generate(bayes, inputs.to("cuda"), use_cache=True)

[' of the emergency broadcast system. This is only a']

In [38]:
# out.logits is indexed as a pair here.
# First entry: (batch_size, sequence_length, vocab_size)
mean_logits = out.logits[0]
print("logits out", mean_logits.shape)
# Second entry: (n_ensemble, batch_size, sequence_length, vocab_size)
member_logits = out.logits[1]
print("ensemble logits out", member_logits.shape)

# out.attentions is expected to hold 32 items; this prints the shape of the
# first one.
first_attention = out.attentions[0]
print("attentions size", first_attention.shape)
# The last item holds the attentions from the n_ensemble layers, so its
# length is printed rather than a shape.
last_attention = out.attentions[-1]
print("last attentions length", len(last_attention))

# out.hidden_states is expected to hold 32 items; despite the "length" label
# this prints the shape of the first one.
first_hidden = out.hidden_states[0]
print("hidden states length", first_hidden.shape)
# Last entry: (n_ensemble, batch_size, sequence_length, hidden_size)
last_hidden = out.hidden_states[-1]
print("last hidden states shape", last_hidden.shape)

# out.past_key_values is expected to hold 32 items.
print("length of past key values", len(out.past_key_values))
# FIXME (original author's note): the shape printed here needs fixing.
last_layer_key = out.past_key_values[-1][0]
print("past key values shape", last_layer_key.shape)

logits out torch.Size([1, 5, 128256])
ensemble logits out torch.Size([3, 1, 5, 128256])
attentions size torch.Size([1, 32, 5, 5])
last attentions length 3
hidden states length torch.Size([1, 5, 4096])
last hidden states shape torch.Size([3, 1, 5, 4096])
length of past key values 32
past key values shape torch.Size([1, 8, 15, 128])


In [9]:
# FIXME: the entries of the last attentions item disagree with each other.
# In the recorded output the first entry's last dimension is 5 while the
# last entry's is 15. The cause is not visible in this notebook.
last_attentions = out.attentions[-1]
print(last_attentions[0].shape)
print(last_attentions[-1].shape)

torch.Size([1, 32, 5, 5])
torch.Size([1, 32, 5, 15])
