In [3]:
import torch
import re


def load_ensemble(filepaths):
    def load_from_checkpoint(idx, filepath):
        parameters = torch.load(filepath)["state_dict"]["bayesian_layer"].params
        parameters = {
            re.sub(r"model\.layers\.\d+\.", "", k): v
            for k, v in parameters.items()
            if v.numel() > 0
        }
        return parameters

    return [
        load_from_checkpoint(idx, filepath) for idx, filepath in enumerate(filepaths)
    ]


parameters = load_ensemble(
    [
        "logs/bayes_with_seq/checkpoints/epoch=0-step=300.ckpt",
        "logs/bayes_with_seq/checkpoints/epoch=1-step=600.ckpt",
        "logs/bayes_with_seq/checkpoints/epoch=2-step=900.ckpt",
        "logs/bayes_with_seq/checkpoints/epoch=3-step=1200.ckpt",
        "logs/bayes_with_seq/checkpoints/epoch=4-step=1500.ckpt",
        "logs/bayes_with_seq/checkpoints/epoch=5-step=1800.ckpt",
    ]
)

In [4]:
%load_ext autoreload
%autoreload 2

from llama3.modules.bayesllama import BayesLlamaForCausalLM
bayes_config = {'n_ensemble': 6}

bayes = BayesLlamaForCausalLM.from_pretrained("Meta-Llama-3-8B-Instruct", bayes_config=bayes_config).to("cuda:0")
bayes.load_bayesian_layers(parameters)

Loading checkpoint shards: 100%|██████████| 4/4 [00:06<00:00,  1.67s/it]
Some weights of BayesLlamaForCausalLM were not initialized from the model checkpoint at Meta-Llama-3-8B-Instruct and are newly initialized: ['model.bayesian_layers.0.input_layernorm.weight', 'model.bayesian_layers.0.mlp.down_proj.weight', 'model.bayesian_layers.0.mlp.gate_proj.weight', 'model.bayesian_layers.0.mlp.up_proj.weight', 'model.bayesian_layers.0.post_attention_layernorm.weight', 'model.bayesian_layers.0.self_attn.k_proj.weight', 'model.bayesian_layers.0.self_attn.o_proj.weight', 'model.bayesian_layers.0.self_attn.q_proj.weight', 'model.bayesian_layers.0.self_attn.v_proj.weight', 'model.bayesian_layers.1.input_layernorm.weight', 'model.bayesian_layers.1.mlp.down_proj.weight', 'model.bayesian_layers.1.mlp.gate_proj.weight', 'model.bayesian_layers.1.mlp.up_proj.weight', 'model.bayesian_layers.1.post_attention_layernorm.weight', 'model.bayesian_layers.1.self_attn.k_proj.weight', 'model.bayesian_layers.1.self

In [5]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Meta-Llama-3-8B-Instruct")

inputs = "This is a test"
inputs = tokenizer(inputs, return_tensors="pt")

inputs.input_ids.shape

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


torch.Size([1, 8])

In [6]:
out = bayes(
    **inputs.to("cuda"), return_dict=True, output_hidden_states=True, use_cache=True
)
out

CausalLMOutputWithPast(loss=None, logits=(tensor([[[-3.6186, -1.7161, -5.9545,  ...,  0.6420,  0.6421,  0.6422],
         [ 0.3428, -1.6811, -0.8162,  ..., -7.3246, -7.3250, -7.3248],
         [ 4.4888,  1.4576, -0.7534,  ..., -3.7288, -3.7291, -3.7289],
         ...,
         [ 5.4567, -2.0977,  1.5987,  ..., -3.6205, -3.6206, -3.6204],
         [ 3.8118,  0.4680,  5.8900,  ..., -3.8390, -3.8391, -3.8387],
         [ 3.8207, -2.5295,  3.6482,  ..., -4.2569, -4.2572, -4.2570]]],
       device='cuda:0', grad_fn=<MeanBackward1>), tensor([[[[  5.1573,   7.6584,  10.2043,  ...,  -4.1804,  -4.1803,  -4.1802],
          [  2.1930,  -0.0945,   0.3915,  ...,  -8.3144,  -8.3149,  -8.3147],
          [  3.5777,   1.2291,  -0.5330,  ...,  -4.0701,  -4.0703,  -4.0702],
          ...,
          [  8.1945,   0.1359,   3.2816,  ...,  -5.1957,  -5.1958,  -5.1957],
          [  6.2369,   3.2923,   8.2790,  ...,  -5.3957,  -5.3958,  -5.3955],
          [  6.0016,  -0.7353,   5.4501,  ...,  -5.8723,  -5.

In [8]:
@torch.no_grad()
def generate(model, inputs, max_length=10, use_cache=True):
    seq_out = []
    for idx in range(max_length):
        outputs = model(**inputs, return_dict=False, use_cache=use_cache)

        if "attention_mask" in inputs:
            del inputs["attention_mask"]

        next_token = outputs[0][0][:, -1].argmax(-1).unsqueeze(-1)
        if use_cache:
            inputs["past_key_values"] = outputs[1]
            inputs["ensemble_past_key_values"] = outputs[2]
            inputs["input_ids"] = next_token
        else:
            inputs["input_ids"] = torch.cat([inputs["input_ids"], next_token], dim=1)
        seq_out.append(next_token)

    seq_out = torch.cat(seq_out, -1)
    return tokenizer.batch_decode(seq_out, skip_special_tokens=True)

# inputs = "This is a test"
inputs = """
Answer the following multiple choice question. You should answer the question by choosing the letter (a, b, c, or d) that corresponds with the correct answer.
Steps of the scientific method include all of the following except
a. doing background research.
b. constructing a hypothesis.
c. asking a question.
d. proving a theory.
"""
inputs = tokenizer(inputs, return_tensors="pt")
generate(bayes, inputs.to("cuda"), use_cache=True)

['Answer: d. proving a theory. The scientific']