In [None]:
import numpy as np
import torch
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorWithPadding,
)
from datasets import load_dataset

In [8]:
def load_data(id_dataset, id_config, id_split, ood_dataset, ood_config, seed):
    """Load the in-distribution (ID) split and the out-of-distribution (OOD) dataset.

    Args:
        id_dataset: Hub id of the ID dataset.
        id_config: dataset config name for the ID dataset (or None).
        id_split: which ID split to load; only this split is returned.
        ood_dataset: Hub id of the OOD dataset.
        ood_config: dataset config name for the OOD dataset (or None).
        seed: seed for shuffling both datasets and for any train/test split made here.

    Returns:
        (id_data, ood_data): the shuffled ID split and a dict of shuffled OOD splits that is
        guaranteed to contain a ``'test'`` split (downstream evaluation reads ``ood_data['test']``).
    """
    # NOTE(review): trust_remote_code=True executes dataset loading scripts from the Hub; keep it
    # only for datasets you trust.
    id_data = load_dataset(id_dataset, id_config, split=id_split, trust_remote_code=True).shuffle(seed=seed)
    ood_data = load_dataset(ood_dataset, ood_config, trust_remote_code=True).shuffle(seed=seed)

    if len(ood_data) == 1:
        # Only one split is published: carve out a held-out test split. Seeded so that
        # "Restart & Run All" reproduces the same partition (the original split was unseeded).
        (only_split,) = ood_data.keys()
        ood_data = ood_data[only_split].train_test_split(test_size=0.2, seed=seed)
    elif (
        "test" in ood_data
        and "validation" in ood_data
        and "label" in ood_data["test"].column_names
        and all(label < 0 for label in ood_data["test"]["label"])
    ):
        # Some benchmarks (e.g. GLUE SST-2) ship an unlabeled test split whose labels are all
        # negative (-1). Scoring against it would report 0 accuracy, so evaluate on the labeled
        # validation split instead.
        ood_data["test"] = ood_data["validation"]

    return id_data, ood_data

In [9]:
def initialize_model_and_tokenizer(model_name):
    """Load a causal LM in fp16 for inference, plus its tokenizer configured for batched padding.

    Args:
        model_name: Hugging Face Hub id or local path passed to ``from_pretrained``.

    Returns:
        (model, tokenizer): the model in eval mode with gradients disabled, and the tokenizer
        with a pad token set and right-side padding.
    """
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16).eval()
    # Inference only. eval() changes dropout/norm behaviour but does NOT disable autograd, so
    # without this every forward pass would record a graph over the long ICL prompts.
    model.requires_grad_(False)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # Checkpoints without a pad token (presumably the Llama checkpoint configured here) reuse
        # EOS; padded positions are excluded through the attention mask.
        tokenizer.pad_token = tokenizer.eos_token
    # Pin right padding instead of relying on the checkpoint default, so each prompt's last real
    # token sits at index (unpadded length - 1) in the padded batch.
    tokenizer.padding_side = "right"
    return model, tokenizer

In [None]:
def generate_prompts(i, id_data, ood_data, id_exemplars, ood_exemplars, n_id, n_ood, id_prompt, ood_prompt,
                     ood_split="train"):
    """Build the in-context prefix for exemplar sample ``i``.

    The prefix is ``n_id`` labeled ID exemplars followed by ``n_ood`` unlabeled OOD exemplars.

    Args:
        i: row of ``id_exemplars`` / ``ood_exemplars`` to use.
        id_data: indexable ID rows with ``'text'`` and ``'label'`` (truthy label -> "Positive").
        ood_data: mapping from split name to indexable rows with a ``'sentence'`` field.
        id_exemplars, ood_exemplars: 2-D integer index arrays; the first ``n_id`` / ``n_ood``
            entries of row ``i`` are used.
        n_id, n_ood: number of ID / OOD exemplars to include (0 allowed).
        id_prompt: format string with two slots (text, label word).
        ood_prompt: format string with one slot (sentence).
        ood_split: split of ``ood_data`` the OOD exemplar indices were drawn from. Previously read
            from a notebook global of the same name; must match the split used to sample
            ``ood_exemplars``.

    Returns:
        The concatenated prompt prefix string.
    """
    id_parts = []
    for idx in id_exemplars[i, :n_id]:
        row = id_data[int(idx)]
        label = "Positive" if row['label'] else "Negative"
        id_parts.append(id_prompt.format(row['text'], label))

    ood_rows = ood_data[ood_split]  # hoisted: the split lookup is loop-invariant
    ood_parts = [ood_prompt.format(ood_rows[int(idx)]['sentence']) for idx in ood_exemplars[i, :n_ood]]

    return ''.join(id_parts) + ''.join(ood_parts)

In [10]:
def evaluate_model(model, tokenizer, ood_data, full_prompt, test_prompt, tok_choices, bsize, n_ood_test):
    """Score next-token label predictions on the first ``n_ood_test`` rows of ``ood_data['test']``.

    Each test sentence is formatted with ``test_prompt`` and appended to ``full_prompt``. The
    model's logits at the last non-pad token are restricted to ``tok_choices``, and the index of
    the largest one is the predicted label, so ``tok_choices[k]`` must be the token for label ``k``.

    Args:
        model: causal LM whose call returns an object with ``.logits`` of shape
            (batch, seq, vocab); inputs are moved to ``model.device``.
        tokenizer: tokenizer callable supporting ``padding=True, return_tensors="pt"`` and
            returning ``input_ids`` and ``attention_mask``.
        ood_data: mapping whose ``'test'`` entry yields rows with ``'sentence'`` and ``'label'``.
        full_prompt: in-context prefix shared by every test prompt.
        test_prompt: format string with one slot for the test sentence.
        tok_choices: candidate token ids, one per label.
        bsize: number of prompts per forward pass (>= 1).
        n_ood_test: number of test rows to score; clamped to the size of the test split.

    Returns:
        Accuracy (float) over the scored rows.

    Raises:
        ValueError: if ``bsize < 1`` or no rows would be scored.
    """
    test_rows = ood_data['test']
    n_eval = min(n_ood_test, len(test_rows))
    if bsize < 1:
        raise ValueError(f"bsize must be >= 1, got {bsize}")
    if n_eval <= 0:
        raise ValueError(f"no test rows to score (n_ood_test={n_ood_test}, test split size={len(test_rows)})")

    ncorrect = 0
    # Pure inference: skip autograd bookkeeping (the original built a graph on every batch).
    with torch.inference_mode():
        for start in range(0, n_eval, bsize):
            rows = [test_rows[j] for j in range(start, min(start + bsize, n_eval))]
            prompts = [full_prompt + test_prompt.format(row['sentence']) for row in rows]
            labels = torch.tensor([row['label'] for row in rows])

            enc = tokenizer(prompts, padding=True, return_tensors="pt")
            enc = {k: v.to(model.device) for k, v in enc.items()}
            logits = model(**enc).logits

            # Index of the last non-pad token per row: the largest position whose mask is 1.
            # Correct for both right and left padding.
            mask = enc['attention_mask']
            positions = torch.arange(mask.shape[1], device=mask.device)
            last_idx = (mask * positions).argmax(-1).to(logits.device)

            rows_idx = torch.arange(logits.shape[0], device=logits.device)
            choice_logits = logits[rows_idx, last_idx][:, tok_choices]
            preds = choice_logits.argmax(-1).cpu()
            ncorrect += int((preds == labels).sum())
            del logits, choice_logits

    return ncorrect / n_eval

In [None]:
# Model parameters
# --- Model and datasets (Hugging Face Hub ids) ---
model_name = "meta-llama/Llama-3.1-8B"
id_dataset, id_config, id_split = "stanfordnlp/imdb", None, "train"
ood_dataset, ood_config, ood_split = "stanfordnlp/sst2", None, "train"

# --- Experiment size ---
max_id_icl = 16   # largest number of ID in-context exemplars in the sweep
max_ood_icl = 16  # largest number of OOD in-context exemplars in the sweep
n_ood_test = 600  # OOD test rows scored per prompt
nsamples = 10     # independently drawn exemplar sets per grid cell
seed = 1
bsize = 4         # prompts per forward pass

In [12]:
# Set prompts and label tokens
# Candidate next-token ids. Prediction index k is compared with label k, and label 0 is rendered
# as "Negative", so this is presumably [" Negative", " Positive"] in the model's vocabulary.
# TODO(review): verify with tokenizer.convert_ids_to_tokens(tok_choices).
tok_choices = [51957, 45003]

# Labeled ID exemplar: filled with (review text, "Positive"/"Negative").
id_prompt = "Review: {}\nSentiment: {}\n\n"
# Unlabeled OOD exemplar: review text only.
# NOTE(review): unlike id_prompt/test_prompt there is a blank line between the review and
# "Sentiment:" here — confirm the spacing is intentional.
ood_prompt = "Review: {}\n\nSentiment:\n\n"
# Query appended after the in-context prefix; the model's next token is scored.
test_prompt = "Review: {}\nSentiment:"

In [None]:
# Load the model first; it is independent of the data and of the RNG below.
model, tokenizer = initialize_model_and_tokenizer(model_name)

# Load data: the ID split plus the OOD dataset (all of its splits).
id_data, ood_data = load_data(id_dataset, id_config, id_split, ood_dataset, ood_config, seed)

# Sample exemplar indices. One draw per pool without replacement, so no index repeats anywhere
# in the (nsamples, max_*_icl) array — within or across rows.
np.random.seed(seed)
id_exemplars = np.random.choice(len(id_data), size=(nsamples, max_id_icl), replace=False)
ood_exemplars = np.random.choice(len(ood_data[ood_split]), size=(nsamples, max_ood_icl), replace=False)

In [None]:
# Sweep a grid of in-context example counts. Level 0 -> 0 examples; level k >= 1 -> 2**(k-1)
# examples, capped at the configured maximum (matters when the max is not a power of two).
max_log_id = int(np.ceil(np.log2(max_id_icl)))
max_log_ood = int(np.ceil(np.log2(max_ood_icl)))

# res[id_level, ood_level, sample] = OOD test accuracy
res = np.zeros((max_log_id + 2, max_log_ood + 2, nsamples))
out_name = id_dataset.split('/')[-1] + '_' + ood_dataset.split('/')[-1]


def _icl_count(level, cap):
    """Number of in-context examples for grid ``level``: 0 at level 0, else min(2**(level-1), cap)."""
    return 0 if level == 0 else min(2 ** (level - 1), cap)


# Largest ID count first; level 0 (no ID exemplars) is evaluated last as a baseline.
for n_log_id in reversed(range(max_log_id + 2)):
    n_id = _icl_count(n_log_id, max_id_icl)
    for n_log_ood in range(max_log_ood + 2):
        n_ood = _icl_count(n_log_ood, max_ood_icl)

        for i in tqdm(range(nsamples), desc=f"n_id: {n_id}, n_ood: {n_ood}"):
            # generate prompts
            full_prompt = generate_prompts(i, id_data, ood_data, id_exemplars, ood_exemplars,
                                           n_id, n_ood, id_prompt, ood_prompt)

            # evaluate model
            res[n_log_id, n_log_ood, i] = evaluate_model(model, tokenizer, ood_data, full_prompt,
                                                         test_prompt, tok_choices, bsize, n_ood_test)

        print(res.mean(-1))
        print(res.var(-1))
        # Checkpoint after every grid cell so partial results survive an interruption.
        np.save(out_name + '.npy', res, allow_pickle=True)