In [1]:
import torch
import re
import os

def load_ensemble(filepaths):
    """Load the Bayesian-layer parameters stored in each checkpoint.

    For every path, the checkpoint is read with ``torch.load`` and the mapping
    found at ``checkpoint["state_dict"]["bayesian_layer"].params`` is cleaned:
    entries whose value has no elements (``numel() == 0``) are dropped, and
    every ``model.layers.<N>.`` fragment is removed from the remaining keys.

    Args:
        filepaths: Iterable of checkpoint paths accepted by ``torch.load``.

    Returns:
        A list with one ``dict`` per path, in the order the paths were given,
        mapping the stripped parameter name to its stored value.
    """
    layer_prefix = re.compile(r"model\.layers\.\d+\.")

    ensemble = []
    for path in filepaths:
        stored = torch.load(path)["state_dict"]["bayesian_layer"].params

        member = {}
        for name, value in stored.items():
            # Skip placeholders that carry no data.
            if value.numel() > 0:
                member[layer_prefix.sub("", name)] = value
        ensemble.append(member)

    return ensemble


# Every entry of this directory is loaded as a checkpoint.
folder = "logs/bayes_instruct_with_seq/usable_checkpoints"

# os.listdir() returns entries in arbitrary, filesystem-dependent order.
# Sorting keeps the order of the ensemble members identical across runs and
# machines, so member i always comes from the same checkpoint file.
parameters = load_ensemble(
    [os.path.join(folder, ckpt) for ckpt in sorted(os.listdir(folder))]
)

In [2]:
%load_ext autoreload
%autoreload 2

from llama3.modules.bayesllama import BayesLlamaForCausalLM
# Extra configuration forwarded to the Bayesian model.
# NOTE(review): n_ensemble presumably has to match the number of checkpoints
# loaded into `parameters` — confirm against BayesLlamaForCausalLM.
bayes_config = {"n_ensemble": 10}

# Build the model from the base checkpoint, then move it to the first GPU.
bayes = BayesLlamaForCausalLM.from_pretrained(
    "Meta-Llama-3-8B-Instruct",
    bayes_config=bayes_config,
)
bayes = bayes.to("cuda:0")

# The load log reports the `model.bayesian_layers.*` weights as newly
# initialized; this call supplies the per-checkpoint parameters loaded earlier.
bayes.load_bayesian_layers(parameters)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.04s/it]
Some weights of BayesLlamaForCausalLM were not initialized from the model checkpoint at Meta-Llama-3-8B-Instruct and are newly initialized: ['model.bayesian_layers.0.input_layernorm.weight', 'model.bayesian_layers.0.mlp.down_proj.weight', 'model.bayesian_layers.0.mlp.gate_proj.weight', 'model.bayesian_layers.0.mlp.up_proj.weight', 'model.bayesian_layers.0.post_attention_layernorm.weight', 'model.bayesian_layers.0.self_attn.k_proj.weight', 'model.bayesian_layers.0.self_attn.o_proj.weight', 'model.bayesian_layers.0.self_attn.q_proj.weight', 'model.bayesian_layers.0.self_attn.v_proj.weight', 'model.bayesian_layers.1.input_layernorm.weight', 'model.bayesian_layers.1.mlp.down_proj.weight', 'model.bayesian_layers.1.mlp.gate_proj.weight', 'model.bayesian_layers.1.mlp.up_proj.weight', 'model.bayesian_layers.1.post_attention_layernorm.weight', 'model.bayesian_layers.1.self_

In [5]:
from transformers import AutoTokenizer

# Tokenizer loaded from the same checkpoint name as the model.
tokenizer = AutoTokenizer.from_pretrained("Meta-Llama-3-8B-Instruct")

# Pad on the left and reuse the EOS token as the padding token, so a batch of
# prompts of different lengths all end at the last position.
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token

# Smoke test: encode one short sentence and inspect the id tensor's shape.
inputs = tokenizer("This is a test", return_tensors="pt")

inputs.input_ids.shape

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


torch.Size([1, 5])

In [7]:
@torch.no_grad()
def generate(model, inputs, max_length=10, use_cache=True):
    """Greedily decode ``max_length`` new tokens for a batch of prompts.

    At every step the model is called, the arg-max token at the last position
    is taken for each row, and that token is fed back in (alone when
    ``use_cache`` is true, appended to the running ``input_ids`` otherwise).
    Decoding always runs for exactly ``max_length`` steps; there is no early
    stop on EOS.

    Args:
        model: Callable accepting the entries of ``inputs`` as keyword
            arguments plus ``return_dict`` and ``use_cache``. With
            ``return_dict=False`` it must return a sequence whose element 0
            holds the logits and, when caching, whose elements 1 and 2 are the
            values fed back as ``past_key_values`` and
            ``ensemble_past_key_values``.
        inputs: Mapping with at least ``input_ids`` and optionally
            ``attention_mask``. It is not modified.
        max_length: Number of NEW tokens to generate (not the total length).
        use_cache: Whether to reuse the key/value caches between steps.

    Returns:
        A list of decoded strings, one per batch row, containing only the
        newly generated tokens (special tokens removed). Decoding uses the
        notebook-level ``tokenizer``.
    """
    # Work on a shallow copy so the caller's mapping keeps its input_ids and
    # attention_mask and can be reused for another call.
    model_inputs = dict(inputs)
    attention_mask = model_inputs.get("attention_mask")

    new_tokens = []
    for _ in range(max_length):
        outputs = model(**model_inputs, return_dict=False, use_cache=use_cache)

        # NOTE(review): outputs[0][0] is presumably the logits of the first
        # ensemble member (or an aggregate), shaped (batch, seq, vocab) —
        # confirm against BayesLlamaForCausalLM.forward.
        next_token = outputs[0][0][:, -1].argmax(-1).unsqueeze(-1)

        if use_cache:
            model_inputs["past_key_values"] = outputs[1]
            model_inputs["ensemble_past_key_values"] = outputs[2]
            model_inputs["input_ids"] = next_token
        else:
            model_inputs["input_ids"] = torch.cat(
                [model_inputs["input_ids"], next_token], dim=1
            )

        if attention_mask is not None:
            # Keep masking the prompt's padding on later steps: the mask must
            # cover every position seen so far, so add one attended column for
            # the token just generated instead of discarding the mask.
            attention_mask = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
                dim=-1,
            )
            model_inputs["attention_mask"] = attention_mask

        new_tokens.append(next_token)

    if new_tokens:
        seq_out = torch.cat(new_tokens, -1)
    else:
        # max_length <= 0: nothing was generated; decode an empty sequence per
        # row rather than crashing in torch.cat on an empty list.
        prompt_ids = model_inputs["input_ids"]
        seq_out = prompt_ids.new_empty((prompt_ids.shape[0], 0))

    return tokenizer.batch_decode(seq_out, skip_special_tokens=True)

# Three multiple-choice prompts of different lengths, used to exercise batched
# (padded) generation.
inputs_a = """
Answer the following multiple choice question.
Steps of the scientific method include all of the following except
a. doing background research.
b. constructing a hypothesis.
c. asking a question.
d. proving a theory.
Answer:
"""
inputs_b = """
Answer the following multiple choice question.
All birds build nests the same way.
a. true
b. false
Answer:
"""
inputs_c = """
Answer the following multiple choice question.
If the results of an experiment disprove a hypothesis, then the
a. results should not be reported.
b. hypothesis is just a theory.
c. data must contain errors.
d. none of the above
Answer:
"""

# Keep the raw strings under their own name; `inputs` only ever holds the
# tokenized batch.
prompts = [inputs_a, inputs_b, inputs_c]
print(prompts)

# padding=True pads the shorter prompts up to the longest one in the batch.
inputs = tokenizer(prompts, padding=True, return_tensors="pt")

# Send the batch to whatever device the model lives on instead of a second
# hard-coded device string that could drift from the one used for the model.
generate(bayes, inputs.to(bayes.device), use_cache=True)

['\nAnswer the following multiple choice question.\nSteps of the scientific method include all of the following except\na. doing background research.\nb. constructing a hypothesis.\nc. asking a question.\nd. proving a theory.\nAnswer:\n', '\nAnswer the following multiple choice question.\nAll birds build nests the same way.\na. true\nb. false\nAnswer:\n', '\nAnswer the following multiple choice question.\nIf the results of an experiment disprove a hypothesis, then the\na. results should not be reported.\nb. hypothesis is just a theory.\nc. data must contain errors.\nd. none of the above\nAnswer:\n']


['d', 'b', 'c. data must contain errors.  (The']