In [1]:
"""
Sample from the trained model with PyTorch
"""
import os
import sys
import pickle
from contextlib import nullcontext
import torch
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname("/home/deqian/mount/random_effect_LLM/sampling/zero-shot-reasoning-likelihood/zero-shot-qa.ipynb"))))
import time 
from optimizer import PosteriorOptimizer
import numpy as np
from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoXTokenizerFast
import logging


checkpoint = 'output/owt_liger/owt_liger_mlpt_2024_11_19_08_12_13/ckpt_58000.pt'                   
checkpoint = f'../../{checkpoint}'
ckpt_name = f"logs/{checkpoint.split('/')[-2]}_{checkpoint.split('/')[-1].split('.')[0]}"

fast_lr = 0.3
posterior_steps = 15
max_z_len = None # None if want to use cfg['max_z_len']

from model import ModelArgs, LatentPromptTransformerVI, MultiLayerLatentPromptTransformer
use_liger = True
use_z_pos_emb = True    
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')

if 'mlpt' in checkpoint:
    tokenizer.add_special_tokens({'bos_token': '<|beginoftext|>'})

# -----------------------------------------------------------------------------
device_id = "cuda"
device = torch.device(device_id)
device = device_id
dtype = "float32"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
np.random.seed(seed)
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# init from a model saved in a specific directory
checkpoint_dict = torch.load(checkpoint, map_location=device)
gptconf = ModelArgs(**checkpoint_dict['model_args'])
cfg = checkpoint_dict["config"]
gptconf.use_liger = use_liger
gptconf.use_z_pos_emb = use_z_pos_emb

if 'mlpt' in checkpoint:
    model = MultiLayerLatentPromptTransformer(gptconf)
else:
    model = LatentPromptTransformerVI(gptconf)


state_dict = checkpoint_dict['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict, strict=False)

model.eval()
model.to(device)
bos_token = tokenizer.bos_token

if max_z_len == None:
    max_z_len=cfg['max_z_len']
    
fast_lr = cfg['fast_lr']

posterior_optimizer = PosteriorOptimizer(model = model, 
                                        inference_method='adamVIPPL', 
                                        num_steps=posterior_steps, 
                                        max_z_len=max_z_len, 
                                        z_dim=cfg['z_dim'],
                                        lr = fast_lr)


  from .autonotebook import tqdm as notebook_tqdm
  checkpoint_dict = torch.load(checkpoint, map_location=device)


Optimizer kwargs {'num_steps': 15, 'max_z_len': 96, 'z_dim': 512, 'lr': 0.3}


In [6]:
def process_arc_easy():
    from datasets import load_dataset  # Ensure you have the datasets library imported
    
    ds = load_dataset("allenai/ai2_arc", "ARC-Easy")
    cleaned_ds = ds['test']
    print(f"Processing arc_easy, total number of examples: {len(cleaned_ds)}")

    data_list_4_options = []
    data_list_3_options = []
    data_list_5_options = []

    for index, item in enumerate(cleaned_ds):
        question = item['question']
        answers = item['choices']['text']
        correct_index = item['answerKey']
        labels = item['choices']['label']
        
        entry = {
            'question': f"Question: {item['question']}",
            'answer': [f"Answer: {answer}" for answer in answers],
            'correct_index': correct_index,
            'label': labels
        }
        
        num_options = len(labels)
        if num_options == 4:
            data_list_4_options.append(entry)
        elif num_options == 3:
            data_list_3_options.append(entry)
        elif num_options == 5:
            data_list_5_options.append(entry)
        else:
            print(f"Unexpected number of options: {num_options} at index {index}")

    return data_list_3_options, data_list_4_options, data_list_5_options

process_func = process_arc_easy
data_list = process_func()
data_list = data_list[1] # 3 options
data_list[0]

Processing arc_easy, total number of examples: 2376


{'question': 'Question: Which statement best explains why photosynthesis is the foundation of most food webs?',
 'answer': ['Answer: Sunlight is the source of energy for nearly all ecosystems.',
  'Answer: Most ecosystems are found on land instead of in water.',
  'Answer: Carbon dioxide is more available than other gases.',
  'Answer: The producers in all ecosystems are plants.'],
 'correct_index': 'A',
 'label': ['A', 'B', 'C', 'D']}

In [7]:
correct = 0
for index, item in enumerate(data_list):                      
    question = item['question']
    answers = item['answer']
    
    question_specific_seed = 0# np.random.randint(100000)

    question_text = f"{bos_token}{question}".strip()
    question_tokens = tokenizer.encode(question_text, add_special_tokens=False)
    question_tokens = question_tokens[:gptconf.max_seq_len]
    question_input = (torch.tensor(question_tokens, dtype=torch.long, device=device)[None, ...])

    torch.manual_seed(question_specific_seed)
    torch.cuda.manual_seed(question_specific_seed)
    z1 = torch.randn(1, max_z_len,  cfg['z_dim']).to(device)
    z = z1 * 0.01

    start_time = time.time()
    with ctx:
        z, ppl, kl_loss, nlkhd = posterior_optimizer.step(data=[question_input[:, :-1], question_input[:, 1:], z], ctx=ctx, seed=question_specific_seed)


    candidate_seqs = []
    for i in range(len(answers)):
        start_ids = tokenizer.encode(answers[i], add_special_tokens=False)
        start_ids = start_ids[:gptconf.max_seq_len]
        x_input = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
        candidate_seqs.append(x_input)

    lkhds = model.evaluate_conditional(question_input, z, candidate_seqs)

    generated_answer = item['label'][np.argmax(lkhds)]
    is_correct = generated_answer == item['correct_index']
    if is_correct:
        correct += 1
    print(f"likelihoods: {lkhds}, is_correct: {is_correct}, generated_answer: {generated_answer}, correct_index: {item['correct_index']}, correct: {correct}/{index+1}")


likelihoods: [-158.37961121185708, -155.646724077872, -109.18237615586168, -186.39605844020753], is_correct: False, generated_answer: C, correct_index: A, correct: 0/1
likelihoods: [-72.00241279602051, -75.9348201751709, -78.07278251647949, -93.58049011230469], is_correct: False, generated_answer: A, correct_index: B, correct: 0/2
likelihoods: [-70.2221265733242, -87.71943664550781, -89.47136926651001, -80.523027934134], is_correct: False, generated_answer: A, correct_index: D, correct: 0/3
likelihoods: [-65.22980499267578, -76.98519134521484, -69.90315628051758, -54.87093162536621], is_correct: True, generated_answer: D, correct_index: D, correct: 1/4
likelihoods: [-183.4392318725586, -182.82874202728271, -140.95112049277395, -240.4794521331787], is_correct: False, generated_answer: C, correct_index: B, correct: 1/5
likelihoods: [-98.47512817382812, -76.81863516569138, -102.34827995300293, -108.32036018371582], is_correct: False, generated_answer: B, correct_index: C, correct: 1/6
lik

In [4]:
correct = 0
for index, item in enumerate(data_list):                      
    loss_output = []
    option_time_used = 0
    question = item['question']
    answers = item['answer']
    
    question_specific_seed = 0# np.random.randint(100000)

    question_text = f"{bos_token}{question}".strip()
    question_tokens = tokenizer.encode(question_text, add_special_tokens=False)
    question_tokens = question_tokens[:gptconf.max_seq_len]
    question_input = (torch.tensor(question_tokens, dtype=torch.long, device=device)[None, ...])

    torch.manual_seed(question_specific_seed)
    torch.cuda.manual_seed(question_specific_seed)
    z1 = torch.randn(1, max_z_len,  cfg['z_dim']).to(device)
    z = z1 * 0.01

    start_time = time.time()
    with ctx:
        z, ppl, kl_loss, nlkhd = posterior_optimizer.step(data=[question_input[:, :-1], question_input[:, 1:], z], ctx=ctx, seed=question_specific_seed)


    candidate_seqs = []
    for i in range(len(answers)):
        start_ids = tokenizer.encode(answers[i], add_special_tokens=False)
        start_ids = start_ids[:gptconf.max_seq_len]
        x_input = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
        candidate_seqs.append(x_input)

    lkhds = model.evaluate_conditional(question_input, z, candidate_seqs)

generated_answer = item['label'][np.argmax(lkhds)]
is_correct = generated_answer == item['correct_index']
print(f"likelihoods: {lkhds}, is_correct: {is_correct}, generated_answer: {generated_answer}, correct_index: {item['correct_index']}")


likelihoods: [-80.52820381522179, -85.37247261404991, -99.95350775122643], is_correct: False, generated_answer: A, correct_index: B


In [5]:
generated_answer = item['label'][np.argmax(lkhds)]
is_correct = generated_answer == item['correct_index']
is_correct

True

In [13]:
condition_idx = question_input.to(device)
z = z.to(device)

batch_size = len(candidate_seqs)
condition_length = condition_idx.size(1)

# Concatenate condition with each candidate
concatenated_seqs = [torch.cat([condition_idx[0], torch.tensor(seq[0], dtype=torch.long, device=device)]) for seq in candidate_seqs]

# Determine the maximum sequence length
max_length = max(seq.size(0) for seq in concatenated_seqs)

# Pad all sequences to the maximum length
pad_token_id = 50256
padded_seqs = torch.full((batch_size, max_length), pad_token_id, dtype=torch.long, device=device)
for i, seq in enumerate(concatenated_seqs):
    padded_seqs[i, :seq.size(0)] = seq
padded_seqs.shape

  concatenated_seqs = [torch.cat([condition_idx[0], torch.tensor(seq[0], dtype=torch.long, device=device)]) for seq in candidate_seqs]


torch.Size([4, 19])

In [11]:
concatenated_seqs

[tensor([50257, 24361,    25,  9022,   636,   286,   257,  4252, 25547,  4618,
         42909,  1660,   290, 20901,    30, 33706,    25, 34341],
        device='cuda:0'),
 tensor([50257, 24361,    25,  9022,   636,   286,   257,  4252, 25547,  4618,
         42909,  1660,   290, 20901,    30, 33706,    25,   520,  5232],
        device='cuda:0'),
 tensor([50257, 24361,    25,  9022,   636,   286,   257,  4252, 25547,  4618,
         42909,  1660,   290, 20901,    30, 33706,    25, 46597],
        device='cuda:0'),
 tensor([50257, 24361,    25,  9022,   636,   286,   257,  4252, 25547,  4618,
         42909,  1660,   290, 20901,    30, 33706,    25, 30036],
        device='cuda:0')]

In [19]:
tokenizer.eos_token_id

50256

In [2]:
from zero_shot_utils_qa import *

process_function = task_functions['hellaswag'] # ['wsc', 'winogrande', 'siqa', 'piqa', 'obqa', 'hellaswag', 'arc_easy', 'arc_challenge']
data_list = process_function()


  from .autonotebook import tqdm as notebook_tqdm
Map: 100%|██████████| 10042/10042 [00:00<00:00, 11024.66 examples/s]


In [3]:
data_list[0]

{'question': 'Question: Roof shingle removal: A man is sitting on a roof. He',
 'answer': ['Answer: is using wrap to wrap a pair of skis.',
  'Answer: is ripping level tiles off.',
  "Answer: is holding a rubik's cube.",
  'Answer: starts pulling up roofing on a roof.'],
 'correct_index': 3,
 'label': [0, 1, 2, 3]}

In [4]:
from zero_shot_utils import *

process_function = task_functions['hellaswag'] # ['wsc', 'winogrande', 'siqa', 'piqa', 'obqa', 'hellaswag', 'arc_easy', 'arc_challenge']
data_list = process_function()
data_list[0]

{'sentences': ['Roof shingle removal: A man is sitting on a roof. He is using wrap to wrap a pair of skis.',
  'Roof shingle removal: A man is sitting on a roof. He is ripping level tiles off.',
  "Roof shingle removal: A man is sitting on a roof. He is holding a rubik's cube.",
  'Roof shingle removal: A man is sitting on a roof. He starts pulling up roofing on a roof.'],
 'correct_index': 3,
 'label': [0, 1, 2, 3]}