# hellaswag

In [3]:
from datasets import load_dataset
import re
import datasets

def preprocess(text):
    text = text.strip()
    # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
    text = text.replace(" [title]", ". ")
    text = re.sub("\\[.*?\\]", "", text)
    text = text.replace("  ", " ")
    return text


def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    def _process_doc(doc):
        ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
        out_doc = {
            "query": preprocess(doc["activity_label"] + ": " + ctx),
            "choices": [preprocess(ending) for ending in doc["endings"]],
            "gold": int(doc["label"]),
        }
        return out_doc

    return dataset.map(_process_doc)


In [25]:
ds = load_dataset("Rowan/hellaswag")
test_ds = ds['validation']
cleaned_ds = process_docs(test_ds)

Map: 100%|██████████| 10042/10042 [00:01<00:00, 8984.22 examples/s]


In [28]:
cleaned_ds[0]

{'ind': 24,
 'activity_label': 'Roof shingle removal',
 'ctx_a': 'A man is sitting on a roof.',
 'ctx_b': 'he',
 'ctx': 'A man is sitting on a roof. he',
 'endings': ['is using wrap to wrap a pair of skis.',
  'is ripping level tiles off.',
  "is holding a rubik's cube.",
  'starts pulling up roofing on a roof.'],
 'source_id': 'activitynet~v_-JhWjGDPHMY',
 'split': 'val',
 'split_type': 'indomain',
 'label': '3',
 'query': 'Roof shingle removal: A man is sitting on a roof. He',
 'choices': ['is using wrap to wrap a pair of skis.',
  'is ripping level tiles off.',
  "is holding a rubik's cube.",
  'starts pulling up roofing on a roof.'],
 'gold': 3}

In [32]:
for i in range(4):
    print(cleaned_ds[0]['query'] + ' ' + cleaned_ds[0]['choices'][i])

Roof shingle removal: A man is sitting on a roof. He is using wrap to wrap a pair of skis.
Roof shingle removal: A man is sitting on a roof. He is ripping level tiles off.
Roof shingle removal: A man is sitting on a roof. He is holding a rubik's cube.
Roof shingle removal: A man is sitting on a roof. He starts pulling up roofing on a roof.


In [None]:

out_doc = {
    "query": preprocess(doc["activity_label"] + ": " + ctx),
    "choices": [preprocess(ending) for ending in doc["endings"]],
    "gold": int(doc["label"]),
}

In [None]:
out_doc

# arc

In [2]:
from datasets import load_dataset

ds = load_dataset("allenai/ai2_arc", "ARC-Easy")

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
ds['test'][92]

{'id': 'NYSEDREGENTS_2015_8_24',
 'question': 'The length of one day on Earth is determined by how long it takes',
 'choices': {'text': ['the Moon to revolve once',
   'the Moon to rotate once',
   'Earth to rotate once',
   'Earth to revolve once'],
  'label': ['1', '2', '3', '4']},
 'answerKey': '3'}

In [7]:
cleaned_ds = ds['test']

In [12]:
item['choices']

{'text': ['Sunlight is the source of energy for nearly all ecosystems.',
  'Most ecosystems are found on land instead of in water.',
  'Carbon dioxide is more available than other gases.',
  'The producers in all ecosystems are plants.'],
 'label': ['A', 'B', 'C', 'D']}

In [40]:
text = f"Question: {ds['test'][0]['question']}\nAnswer: {ds['test'][0]['choices']['text'][0]}"
text

'Question: Which statement best explains why photosynthesis is the foundation of most food webs?\nAnswer: Sunlight is the source of energy for nearly all ecosystems.'

In [None]:
test_split: test
doc_to_text: "Question: {{question}}\nAnswer:"
doc_to_target: "{{choices.label.index(answerKey)}}"
doc_to_choice: "{{choices.text}}"

# piqa

doc_to_text: "Question: {{goal}}\nAnswer:"
doc_to_target: label
doc_to_choice: "{{[sol1, sol2]}}"

Instance(request_type='loglikelihood', doc={'goal': 'Where can I buy a tennis ball', 'sol1': 'You can purchase a tennis ball at any sports store', 'sol2': 'You can purchase a tennis racket at any sports store', 'label': 0}, arguments=('Question: Where can I buy a tennis ball\nAnswer:', ' You can purchase a tennis ball at any sports store'), idx=0, metadata=('piqa', 1837, 1), resps=[], filtered_resps={}, task_name='piqa', doc_id=1837, repeats=1), 

Instance(request_type='loglikelihood', doc={'goal': 'Where can I buy a tennis ball', 'sol1': 'You can purchase a tennis ball at any sports store', 'sol2': 'You can purchase a tennis racket at any sports store', 'label': 0}, arguments=('Question: Where can I buy a tennis ball\nAnswer:', ' You can purchase a tennis racket at any sports store'), idx=1, metadata=('piqa', 1837, 1), resps=[], filtered_resps={}, task_name='piqa', doc_id=1837, repeats=1)]

In [10]:
import json

with open("piqa-valid-labels.lst", "r") as file:
    data = file.readlines()

responses = [int(line.strip()) for line in data]


data_list = []
file_path = 'piqa-valid.jsonl'

with open(file_path, 'r') as file:
    for i, line in enumerate(file):
        data = json.loads(line.strip())
        query = f"Question: {data['goal']}\nAnswer:"
        sentence1 = f"{query} {data['sol1']}"
        sentence2 = f"{query} {data['sol2']}"

        data_list.append({
            'sentence1': sentence1,
            'sentence2': sentence2,
            'correct_index': responses[i]
        })
        


In [11]:
data_list[-1]

{'sentence1': 'Question: Where can I buy a tennis ball\nAnswer: You can purchase a tennis ball at any sports store',
 'sentence2': 'Question: Where can I buy a tennis ball\nAnswer: You can purchase a tennis racket at any sports store',
 'correct_index': 0}

# siqa

In [19]:
import json

with open("siqa-dev-labels.lst", "r") as file:
    data = file.readlines()

responses = [int(line.strip()) for line in data]


data_list = []
file_path = 'siqa-dev.jsonl'

with open(file_path, 'r') as file:
    for i, line in enumerate(file):
        data = json.loads(line.strip())
        query = f"Q: {data['context']} {data['question']}\nA:"
        sentence1 = f"{query} {data['answerA']}"
        sentence2 = f"{query} {data['answerB']}"
        sentence3 = f"{query} {data['answerC']}"

        data_list.append({
            'sentence1': sentence1,
            'sentence2': sentence2,
            'sentence3': sentence3,
            'correct_index': responses[i]
        })
        


In [21]:
data_list[-1]

{'sentence1': 'Q: Aubrey the officer pulled a driver over for speeding on the road. Why did Aubrey do this?\nA: find a safe place to pull the person over',
 'sentence2': "Q: Aubrey the officer pulled a driver over for speeding on the road. Why did Aubrey do this?\nA: so people don't drive to fast",
 'sentence3': "Q: Aubrey the officer pulled a driver over for speeding on the road. Why did Aubrey do this?\nA: look up the person's license plate number",
 'correct_index': 2}

In [None]:
validation_split: validation
doc_to_text: "Q: {{context}} {{question}}\nA:"
target_delimiter: " "
doc_to_choice: "{{[answerA, answerB, answerC]}}"
doc_to_target: "{{ (label|int) - 1 }}"
metric_list:

Instance(request_type='loglikelihood', doc={'context': 'Aubrey the officer pulled a driver over for speeding on the road.', 'question': 'Why did Aubrey do this?', 'answerA': 'find a safe place to pull the person over', 'answerB': "so people don't drive to fast", 'answerC': "look up the person's license plate number", 'label': '2'}, arguments=('Q: Aubrey the officer pulled a driver over for speeding on the road. Why did Aubrey do this?\nA:', ' find a safe place to pull the person over'), idx=0, metadata=('social_iqa', 1953, 1), resps=[], filtered_resps={}, task_name='social_iqa', doc_id=1953, repeats=1), 

Instance(request_type='loglikelihood', doc={'context': 'Aubrey the officer pulled a driver over for speeding on the road.', 'question': 'Why did Aubrey do this?', 'answerA': 'find a safe place to pull the person over', 'answerB': "so people don't drive to fast", 'answerC': "look up the person's license plate number", 'label': '2'}, arguments=('Q: Aubrey the officer pulled a driver over for speeding on the road. Why did Aubrey do this?\nA:', " so people don't drive to fast"), idx=1, metadata=('social_iqa', 1953, 1), resps=[], filtered_resps={}, task_name='social_iqa', doc_id=1953, repeats=1), 

Instance(request_type='loglikelihood', doc={'context': 'Aubrey the officer pulled a driver over for speeding on the road.', 'question': 'Why did Aubrey do this?', 'answerA': 'find a safe place to pull the person over', 'answerB': "so people don't drive to fast", 'answerC': "look up the person's license plate number", 'label': '2'}, arguments=('Q: Aubrey the officer pulled a driver over for speeding on the road. Why did Aubrey do this?\nA:', " look up the person's license plate number"), idx=2, metadata=('social_iqa', 1953, 1), resps=[], filtered_resps={}, task_name='social_iqa', doc_id=1953, repeats=1)]

# obqa

In [23]:
from datasets import load_dataset

ds = load_dataset("allenai/openbookqa", "main")
dataset = ds['test']

In [25]:
dataset[-1]

{'id': '7-7',
 'question_stem': 'Some animals use a liquid coming from their skin to adjust to',
 'choices': {'text': ['cold', 'water', 'heat', 'humidity'],
  'label': ['A', 'B', 'C', 'D']},
 'answerKey': 'C'}

In [None]:
doc_to_text: question_stem
doc_to_target: "{{choices.label.index(answerKey.lstrip())}}"
doc_to_choice: "{{choices.text}}"

In [28]:
import numpy as np 
np.std([len(data['choices']['label']) for data in dataset])

0.0

In [26]:
f"{dataset[-1]['question_stem']} {dataset[-1]['choices']['text'][0]}"

'Some animals use a liquid coming from their skin to adjust to cold'

Instance(request_type='loglikelihood', doc={'id': '7-7', 'question_stem': 'Some animals use a liquid coming from their skin to adjust to', 'choices': {'text': ['cold', 'water', 'heat', 'humidity'], 'label': ['A', 'B', 'C', 'D']}, 'answerKey': 'C'}, arguments=('Some animals use a liquid coming from their skin to adjust to', ' cold'), idx=0, metadata=('openbookqa', 499, 1), resps=[], filtered_resps={}, task_name='openbookqa', doc_id=499, repeats=1), 

Instance(request_type='loglikelihood', doc={'id': '7-7', 'question_stem': 'Some animals use a liquid coming from their skin to adjust to', 'choices': {'text': ['cold', 'water', 'heat', 'humidity'], 'label': ['A', 'B', 'C', 'D']}, 'answerKey': 'C'}, arguments=('Some animals use a liquid coming from their skin to adjust to', ' water'), idx=1, metadata=('openbookqa', 499, 1), resps=[], filtered_resps={}, task_name='openbookqa', doc_id=499, repeats=1), 

Instance(request_type='loglikelihood', doc={'id': '7-7', 'question_stem': 'Some animals use a liquid coming from their skin to adjust to', 'choices': {'text': ['cold', 'water', 'heat', 'humidity'], 'label': ['A', 'B', 'C', 'D']}, 'answerKey': 'C'}, arguments=('Some animals use a liquid coming from their skin to adjust to', ' heat'), idx=2, metadata=('openbookqa', 499, 1), resps=[], filtered_resps={}, task_name='openbookqa', doc_id=499, repeats=1), 

Instance(request_type='loglikelihood', doc={'id': '7-7', 'question_stem': 'Some animals use a liquid coming from their skin to adjust to', 'choices': {'text': ['cold', 'water', 'heat', 'humidity'], 'label': ['A', 'B', 'C', 'D']}, 'answerKey': 'C'}, arguments=('Some animals use a liquid coming from their skin to adjust to', ' humidity'), idx=3, metadata=('openbookqa', 499, 1), resps=[], filtered_resps={}, task_name='openbookqa', doc_id=499, repeats=1)]

# wsc273

Instance(request_type='loglikelihood', doc={'text': 'Carol believed that Rebecca regretted that she had stolen the watch.', 'pronoun': 'she', 'pronoun_loc': 43, 'quote': 'she had stolen the watch.', 'quote_loc': 43, 'options': ['Carol', 'Rebecca'], 'label': 1, 'source': 'Leora Morgenstern'}, arguments=('Carol believed that Rebecca regretted that Carol', '  had stolen the watch.'), idx=0, metadata=('wsc273', 272, 1), resps=[], filtered_resps={}, task_name='wsc273', doc_id=272, repeats=1), 

Instance(request_type='loglikelihood', doc={'text': 'Carol believed that Rebecca regretted that she had stolen the watch.', 'pronoun': 'she', 'pronoun_loc': 43, 'quote': 'she had stolen the watch.', 'quote_loc': 43, 'options': ['Carol', 'Rebecca'], 'label': 1, 'source': 'Leora Morgenstern'}, arguments=('Carol believed that Rebecca regretted that Rebecca', '  had stolen the watch.'), idx=1, metadata=('wsc273', 272, 1), resps=[], filtered_resps={}, task_name='wsc273', doc_id=272, repeats=1)]

In [None]:
upper_pronouns = [
    "A",
    "An",
    "The",
    "She",
    "He",
    "It",
    "They",
    "My",
    "His",
    "Her",
    "Their",
]


def process_doc(dataset):
    def process_fn(doc):
        # The HF implementation of `wsc273` is not `partial evaluation` friendly.
        doc["text"] = doc["text"].replace("  ", " ")
        doc["options"][0] = __normalize_option(doc, doc["options"][0])
        doc["options"][1] = __normalize_option(doc, doc["options"][1])
        return doc

    return dataset.map(process_fn)


def __normalize_option(doc, option):
    # Append `'s` to possessive determiner based options.
    if doc["pronoun"].lower() in ["my", "his", "her", "our", "their"]:
        option += "'s"
    # Appropriately lowercase the pronoun in the option.
    pronoun = option.split()[0]
    start_of_sentence = doc["text"][doc["pronoun_loc"] - 2] == "."
    if not start_of_sentence and pronoun in upper_pronouns:
        return option.replace(pronoun, pronoun.lower())
    return option


In [None]:
process_docs: !function utils.process_doc
doc_to_target: "{% set index = pronoun_loc + pronoun | length %}{{text[index:]}}"
doc_to_choice: "{% set template = text[:pronoun_loc] %}{{[template+options[0], template+options[1]]}}"
should_decontaminate: true

In [33]:
# Prepare the wsc
data_list = []

with open('wsc273.txt', 'r', encoding='utf-8') as f:
    lines = f.readlines()

blocks = []
current_block = []

for line in lines:
    line = line.strip()
    if line == '':
        if current_block:
            blocks.append(current_block)
            current_block = []
    else:
        current_block.append(line)

# Append the last block if any
if current_block:
    blocks.append(current_block)

for block in blocks:
    if len(block) != 4:
        print(f"Unexpected block format: {block}")
        continue
    sentence_with_mask = block[0]
    mask_line = block[1]  # should be '[MASK]'
    options_line = block[2]
    correct_answer = block[3]
    
    options = options_line.split(',')
    options = [opt.strip() for opt in options]
    
    if correct_answer not in options:
        print(f"Correct answer not in options: {correct_answer}")
        continue
    
    correct_index = options.index(correct_answer)
    
    # Replace [MASK] with each option
    sentences = []
    for opt in options:
        sentence = sentence_with_mask.replace(' [MASK] ', opt)
        sentences.append(sentence)
    
    data_list.append({
        'sentence1': sentences[0],
        'sentence2': sentences[1],
        'correct_index': correct_index
    })
############

In [36]:
sentence_with_mask.replace(' [MASK] ', opt)

'Carol believed that Rebecca regretted that Carol had stolen the watch.'

# lambada

In [37]:
from datasets import load_dataset

ds = load_dataset("EleutherAI/lambada_openai")
dataset = ds['test']

Downloading builder script: 100%|██████████| 4.82k/4.82k [00:00<00:00, 11.5MB/s]
Downloading readme: 100%|██████████| 4.99k/4.99k [00:00<00:00, 20.9MB/s]
Downloading data: 100%|██████████| 1.82M/1.82M [00:00<00:00, 11.8MB/s]
Generating test split: 100%|██████████| 5153/5153 [00:00<00:00, 78588.20 examples/s]


In [42]:
text = dataset[-1]['text']
text

'“Yes, Grandmother, and don’t worry, I won’t forget my water when I go wander.”\nHer brothers laughed a little only because the comment was so cute. \n“Wonderful my little Wanderer, because I would hate to think of you forgetting an important rule like that.”\n“I would never forget the rules, Grandmother'

'“Yes, Grandmother, and don’t worry, I won’t forget my water when I go wander.”\nHer brothers laughed a little only because the comment was so cute. \n“Wonderful my little Wanderer, because I would hate to think of you forgetting an important rule like that.”\n“I would never forget the rules, Grandmother'

In [None]:
doc_to_text: "{{text.split(' ')[:-1]|join(' ')}}"
doc_to_target: "{{' '+text.split(' ')[-1]}}"

In [46]:

doc_to_text = ' '.join(text.split(' ')[:-1])
doc_to_text



'“Yes, Grandmother, and don’t worry, I won’t forget my water when I go wander.”\nHer brothers laughed a little only because the comment was so cute. \n“Wonderful my little Wanderer, because I would hate to think of you forgetting an important rule like that.”\n“I would never forget the rules,'

In [47]:
doc_to_target = ' ' + text.split(' ')[-1]
doc_to_target

' Grandmother'

Instance(request_type='loglikelihood', doc={'text': '“Yes, Grandmother, and don’t worry, I won’t forget my water when I go wander.”\nHer brothers laughed a little only because the comment was so cute. \n“Wonderful my little Wanderer, because I would hate to think of you forgetting an important rule like that.”\n“I would never forget the rules, Grandmother'}, 

arguments=('“Yes, Grandmother, and don’t worry, I won’t forget my water when I go wander.”\nHer brothers laughed a little only because the comment was so cute. \n“Wonderful my little Wanderer, because I would hate to think of you forgetting an important rule like that.”\n“I would never forget the rules,', ' Grandmother'), idx=0, metadata=('lambada_openai', 5152, 1), resps=[], filtered_resps={}, task_name='lambada_openai', doc_id=5152, repeats=1)]

In [2]:
from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoXTokenizerFast

tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
special_tokens = {'bos_token': '<|beginoftext|>'}
tokenizer.add_special_tokens(special_tokens)
bos_token_id = tokenizer.bos_token_id

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# MLPT


In [2]:
"""
Sample from the trained model with PyTorch
"""
import os
import sys
import pickle
from contextlib import nullcontext
import torch
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname("/home/deqian/random_effect_LLM/sampling/zero-shot-reasoning-likelihood/data.ipynb"))))
import time 
from optimizer import PosteriorOptimizer
import numpy as np
import argparse
import yaml
from tokenizer import Tokenizer
from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoXTokenizerFast
import logging

import os
import json
import logging
from datetime import datetime
import random
from zero_shot_utils import *

task_to_test = ['wsc', 'obqa', 'arc_easy'] # ['wsc', 'winogrande', 'siqa', 'piqa', 'obqa', 'hellaswag', 'arc_easy', 'arc_challenge']
process_functions = []
for task in task_to_test:
    if task in task_functions:
        process_functions.append(task_functions[task])
    else:
        raise ValueError(f"Task '{task}' not supported.")

checkpoints_to_check = ['output/owt_liger/owt_liger_mlpt_2024_11_13_08_17_57/ckpt_58000.pt',
                        'output/owt_liger/owt_liger_mlpt_2024_11_19_08_12_13/ckpt_58000.pt']

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
checkpoint = checkpoints_to_check[0]

In [6]:
checkpoint = f'../../{checkpoint}'
ckpt_name = f"logs/{checkpoint.split('/')[-2]}_{checkpoint.split('/')[-1].split('.')[0]}"

fast_lr = 0.3
posterior_steps = 15
max_z_len = None # None if want to use cfg['max_z_len']

logging.basicConfig(filename=f"{ckpt_name}_z{max_z_len}.log", level=logging.INFO, format="%(message)s")

if 'dclm' in checkpoint : 
    from model_old import ModelArgs, LatentPromptTransformerVIPostTraining, LatentPromptTransformerVI

    tokenizer = GPTNeoXTokenizerFast.from_pretrained("EleutherAI/gpt-neox-20b")
    special_tokens = {'bos_token': '<|beginoftext|>'}
    tokenizer.add_special_tokens(special_tokens)
    bos_token_id = tokenizer.bos_token_id
    use_liger = True
    use_z_pos_emb = True
elif 'mlpt' in checkpoint:
    from model import ModelArgs, MultiLayerLatentPromptTransformer 
    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
    use_liger = True
    use_z_pos_emb = True
    tokenizer.add_special_tokens({'bos_token': '<|beginoftext|>'})
    bos_token_id = tokenizer.bos_token_id
else:
    from model import ModelArgs, LatentPromptTransformerVIPostTraining, LatentPromptTransformerVI

    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
    use_liger = True
    use_z_pos_emb = True

# -----------------------------------------------------------------------------
device_id = "cuda"
device = torch.device(device_id)
device = device_id
dtype = "float32"
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
np.random.seed(seed)
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# init from a model saved in a specific directory
checkpoint_dict = torch.load('../../output/owt_liger/owt_liger_mlpt_2024_11_13_08_17_57/ckpt_58000.pt', map_location=device)
gptconf = ModelArgs(**checkpoint_dict['model_args'])
cfg = checkpoint_dict["config"]
gptconf.use_liger = use_liger
gptconf.use_z_pos_emb = use_z_pos_emb

# model = LatentPromptTransformerVIPostTraining(gpt.conf)
model = MultiLayerLatentPromptTransformer(gptconf)
state_dict = checkpoint_dict['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict, strict=False)

model.eval()
model.to(device)
bos_token = tokenizer.bos_token

if max_z_len == None:
    max_z_len=cfg['max_z_len']
posterior_optimizer = PosteriorOptimizer(model = model, 
                                        inference_method='adamVIPPL', 
                                        num_steps=posterior_steps, 
                                        max_z_len=max_z_len, 
                                        z_dim=cfg['z_dim'],
                                        lr = fast_lr)


  checkpoint_dict = torch.load('../../output/owt_liger/owt_liger_mlpt_2024_11_13_08_17_57/ckpt_58000.pt', map_location=device)


Optimizer kwargs {'num_steps': 15, 'max_z_len': 96, 'z_dim': 512, 'lr': 0.3}


In [7]:
cfg['fast_lr']

0.35

In [None]:

##################################################################
message_ckpt = f"Using checkpoint {checkpoint}, max_z_len = {max_z_len}, fast_lr = {fast_lr}, posterior_steps = {posterior_steps}, use_liger={use_liger}, use_z_pos_emb={use_z_pos_emb}"
logging.info("="*30)
logging.info(message_ckpt)
logging.info("using adamw")

all_messages = []
for task_id, process_func in enumerate(process_functions):
    correct = 0
    data_list = process_func()

    for index, item in enumerate(data_list):              
        log_info = True if index % 20 == 0 else False
            
        loss_output = []
        option_time_used = 0
        sentences = item['sentences']
        question_specific_seed = np.random.randint(100000)
        for i in range(len(sentences)):
            # print(f"option {i}: {sentences[i]}")
            input_text = f"{bos_token}{sentences[i]}".strip()
            start_ids = tokenizer.encode(input_text, add_special_tokens=False)
            start_ids = start_ids[:gptconf.max_seq_len]
            x_input = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
            
            torch.manual_seed(question_specific_seed)
            torch.cuda.manual_seed(question_specific_seed)
            z1 = torch.randn(1, max_z_len,  cfg['z_dim']).to(device)
            z = z1 * 0.01
            
            start_time = time.time()
            with ctx:
                z, ppl, kl_loss, nlkhd = posterior_optimizer.step(data=[x_input[:, :-1], x_input[:, 1:], z], ctx=ctx, seed=question_specific_seed)

            loss = nlkhd + kl_loss
            loss_output.append(loss.item())
            finish_time = time.time()
            time_passed = finish_time - start_time
            option_time_used += time_passed
            if log_info:
                option_info = f"index {index} option {i}: nlkhd {nlkhd.item():.2f} kl_loss: {kl_loss.item():.2f} loss: {loss.item():.2f}, time {time_passed} sec"
                logging.info(option_info)

        generated_answer = item['label'][np.argmin(loss_output)]
        is_correct = generated_answer == item['correct_index']
        if is_correct:
            correct += 1

        step_info = f"index {index}: correct: {is_correct}, losses: {np.round(loss_output, 2)}, generatedID: {generated_answer}, correctID: {item['correct_index']} ({correct / len(data_list):4f}), current rate {correct/(index+1):.2f}, option time: {option_time_used:.2f}sec"
        print(step_info)
        if log_info:
            logging.info(step_info)

        # break

    # print("-"*20)
    message = f"Evaluation for {task_to_test[task_id]}, correct rate {correct / len(data_list):4f}, {correct}/{len(data_list)}"
    logging.info(message)
    all_messages.append(message)
    print(message)

In [2]:
"""
Sample from the trained model with PyTorch
"""
import os
import sys
import pickle
from contextlib import nullcontext
import torch
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname("/home/deqian/random_effect_LLM/sampling/zero-shot-reasoning-likelihood/data.ipynb"))))
import time 
from optimizer import PosteriorOptimizer
import numpy as np
import argparse
import yaml
from tokenizer import Tokenizer
from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoXTokenizerFast
import logging

import os
import json
import logging
from datetime import datetime
import random
from zero_shot_utils import *

task_to_test = ['wsc', 'obqa', 'arc_easy'] # ['wsc', 'winogrande', 'siqa', 'piqa', 'obqa', 'hellaswag', 'arc_easy', 'arc_challenge']
process_functions = []
for task in task_to_test:
    if task in task_functions:
        process_functions.append(task_functions[task])
    else:
        raise ValueError(f"Task '{task}' not supported.")

checkpoints_to_check = ['output/owt_liger/owt_liger_mlpt_2024_11_13_08_17_57/ckpt_58000.pt',
                        'output/owt_liger/owt_liger_mlpt_2024_11_19_08_12_13/ckpt_58000.pt']

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
np.sum([len(item['sentences']) != 4 for item in task_functions['arc_easy']()])

11

In [20]:
def process_arc_easy():
    ds = load_dataset("allenai/ai2_arc", "ARC-Easy")
    cleaned_ds = ds['test']
    data_list_4_options = []
    data_list_3_options = []

    for index, item in enumerate(cleaned_ds):  
        sentences = []
        for i in range(len(item['choices']['label'])):
            st = f"Question: {item['question']}\nAnswer: {item['choices']['text'][i]}"
            sentences.append(st)
        
        correct_index = item['answerKey']
        label = item['choices']['label']
        
        if len(item['choices']['label']) == 4:
            data_list_4_options.append({
                'sentences': sentences,
                'correct_index': correct_index,
                'label': label
            })
        elif len(item['choices']['label']) == 3:
            data_list_3_options.append({
                'sentences': sentences,
                'correct_index': correct_index,
                'label': label
            })
        else:
            print(f"Unexpected number of options: {len(item['choices']['label'])}")

    return data_list_4_options, data_list_3_options



In [21]:
result = process_arc_easy()
# result

Unexpected number of options: 5
Unexpected number of options: 5
Unexpected number of options: 5
Unexpected number of options: 5


In [13]:
ds = load_dataset("allenai/ai2_arc", "ARC-Challenge")
cleaned_ds = ds['test']
data_list = []

for index, item in enumerate(cleaned_ds):  
    sentences = []
    for i in range(len(item['choices']['label'])):
        st = f"Question: {item['question']}\nAnswer: {item['choices']['text'][i]}"
        sentences.append(st)
    
    correct_index = item['answerKey']
    label = item['choices']['label']

    data_list.append({
        'sentences': sentences,
        'correct_index': correct_index,
        'label': label
    })
    if len(item['choices']['label']) != 4:
        print(f"Unexpected number of options: {index}, {len(item['choices']['label'])}")

Unexpected number of options: 121, 3
Unexpected number of options: 385, 3
Unexpected number of options: 400, 3
Unexpected number of options: 836, 5
Unexpected number of options: 868, 5
Unexpected number of options: 1037, 5
Unexpected number of options: 1042, 3


In [14]:
cleaned_ds[1037]

{'id': 'TIMSS_1995_8_N4',
 'question': 'Years ago farmers found that corn plants grew better if decaying fish were buried near by. What did the decaying fish probably supply to the plants to improve their growth?',
 'choices': {'text': ['energy', 'minerals', 'protein', 'oxygen', 'water'],
  'label': ['A', 'B', 'C', 'D', 'E']},
 'answerKey': 'B'}

# datasets

In [2]:
"""
Sample from the trained model with PyTorch
"""
import os
import sys
import pickle
from contextlib import nullcontext
import torch
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname("/home/deqian/mount/random_effect_LLM/sampling/zero-shot-reasoning-likelihood/data.ipynb"))))
import time 
from optimizer import PosteriorOptimizer
import numpy as np
import argparse
import yaml
from tokenizer import Tokenizer
from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPTNeoXTokenizerFast
import logging

import os
import json
import logging
from datetime import datetime
import random
from zero_shot_utils import *

task_to_test = ['arc_easy'] # ['wsc', 'winogrande', 'siqa', 'piqa', 'obqa', 'hellaswag', 'arc_easy', 'arc_challenge']
process_functions = []
for task in task_to_test:
    if task in task_functions:
        process_functions.append(task_functions[task])
    else:
        raise ValueError(f"Task '{task}' not supported.")
    
for task_id, process_func in enumerate(process_functions):
    print("="*30)
    if task_to_test[task_id] != 'arc_easy' and task_to_test[task_id] != 'arc_challenge':
        # Note that arc_easy and arc_challenge have different number of options (3, 4, 5 in each question)
        dataset_lists = [process_func()]
    else: 
        dataset_lists = process_func()


  from .autonotebook import tqdm as notebook_tqdm


Processing arc_easy, total number of examples: 2376


In [5]:
dataset_lists[0][0]

{'sentences': ['Question: When a rock is placed in a graduated cylinder containing water, the height of the water will\nAnswer: decrease',
  'Question: When a rock is placed in a graduated cylinder containing water, the height of the water will\nAnswer: increase',
  'Question: When a rock is placed in a graduated cylinder containing water, the height of the water will\nAnswer: remain the same'],
 'correct_index': 'B',
 'label': ['A', 'B', 'C']}

In [6]:
task_to_test = ['arc_easy'] # ['wsc', 'winogrande', 'siqa', 'piqa', 'obqa', 'hellaswag', 'arc_easy', 'arc_challenge']
process_functions = []
for task in task_to_test:
    if task in task_functions:
        process_functions.append(task_functions[task])
    else:
        raise ValueError(f"Task '{task}' not supported.")
    
for task_id, process_func in enumerate(process_functions):
    print("="*30)
    if task_to_test[task_id] != 'arc_easy' and task_to_test[task_id] != 'arc_challenge':
        # Note that arc_easy and arc_challenge have different number of options (3, 4, 5 in each question)
        dataset_lists = [process_func()]
    else: 
        dataset_lists = process_func()
    
    all_correct = 0
    all_tested = 0
    
    for subdataset_id, all_questions_list in enumerate(dataset_lists):
        correct = 0
        for item in all_questions_list:
            print(item)
            break 

        break 


Processing arc_easy, total number of examples: 2376
{'sentences': ['Question: When a rock is placed in a graduated cylinder containing water, the height of the water will\nAnswer: decrease', 'Question: When a rock is placed in a graduated cylinder containing water, the height of the water will\nAnswer: increase', 'Question: When a rock is placed in a graduated cylinder containing water, the height of the water will\nAnswer: remain the same'], 'correct_index': 'B', 'label': ['A', 'B', 'C']}


In [7]:
item['sentences']

['Question: When a rock is placed in a graduated cylinder containing water, the height of the water will\nAnswer: decrease',
 'Question: When a rock is placed in a graduated cylinder containing water, the height of the water will\nAnswer: increase',
 'Question: When a rock is placed in a graduated cylinder containing water, the height of the water will\nAnswer: remain the same']

In [8]:
item['correct_index']

'B'