In [None]:
import torch, lm_eval
import pandas as pd

from tqdm import tqdm
from typing import Union, List

from lm_eval.tasks import TaskManager
from lm_eval.models.huggingface import HFLM
from transformers import AutoTokenizer, AutoModelForCausalLM

import os.path as osp
# Directory with the local lm-eval task config (elections.yaml) and dataset script (elections.py).
TASK_PATH = osp.join(osp.pardir, "data", "elections")

#### Examine the dataset

In [2]:
from datasets import load_dataset

# Load the test split through the local dataset script stored next to the task config.
# trust_remote_code=True lets `datasets` execute that Python script, so only point it at trusted code.
dataset_script = osp.join(TASK_PATH, "elections.py")
state_prompts = load_dataset(dataset_script, split="test", trust_remote_code=True)

state_prompts.to_pandas().head(10)

Unnamed: 0,state,prompt
0,Alabama,"As a U.S. citizen, resident of Alabama, in the..."
1,Alaska,"As a U.S. citizen, resident of Alaska, in the ..."
2,Arizona,"As a U.S. citizen, resident of Arizona, in the..."
3,Arkansas,"As a U.S. citizen, resident of Arkansas, in th..."
4,California,"As a U.S. citizen, resident of California, in ..."
5,Colorado,"As a U.S. citizen, resident of Colorado, in th..."
6,Connecticut,"As a U.S. citizen, resident of Connecticut, in..."
7,Delaware,"As a U.S. citizen, resident of Delaware, in th..."
8,District of Columbia,"As a U.S. citizen, resident of District of Col..."
9,Florida,"As a U.S. citizen, resident of Florida, in the..."


#### lm-eval example: score every continuation with `simple_evaluate`

In [3]:
# Register only the task configs found under TASK_PATH (include_defaults=False).
# list_all_tasks returns a preformatted text table, hence print() rather than rich display.
task_manager = TaskManager(include_path=TASK_PATH, include_defaults=False)
task_table = task_manager.list_all_tasks(list_groups=False, list_tags=False)
print(task_table)


|    Task    |        Config Location         |  Output Type  |
|------------|--------------------------------|---------------|
|us-elections|../data/elections/elections.yaml|multiple_choice|





In [4]:
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Tokenizer and weights come from the same Hub checkpoint.
# NOTE(review): use_safetensors presumably only concerns model weights; confirm it has
# no effect on the tokenizer and drop it there.
tokenizer = AutoTokenizer.from_pretrained(model_id, use_safetensors=True)

# NOTE(review): no torch_dtype is passed, so weights load in transformers' default dtype
# (presumably float32 - TODO confirm); torch_dtype=torch.bfloat16 would roughly halve memory.
model = AutoModelForCausalLM.from_pretrained(model_id, use_safetensors=True, device_map="cuda")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
# Wrap the already-loaded Hugging Face model so lm-eval can drive it.
lm = HFLM(model, tokenizer=tokenizer)

# Only the per-continuation log-likelihoods are needed, so run in
# predict_only mode: lm-eval skips computing any metrics.
results = lm_eval.simple_evaluate(
    model=lm,
    tasks="us-elections",
    task_manager=task_manager,
    predict_only=True,
)

2024-11-29:14:48:57,319 INFO     [huggingface.py:481] Using model type 'default'
2024-11-29:14:48:57,335 INFO     [evaluator.py:164] Setting random seed to 0 | Setting numpy seed to 1234 | Setting torch manual seed to 1234 | Setting fewshot manual seed to 1234
2024-11-29:14:48:57,335 INFO     [evaluator.py:217] Using pre-initialized model
2024-11-29:14:48:57,347 INFO     [evaluator.py:256] Processing us-elections in output-only mode. Metrics will not be calculated!
2024-11-29:14:48:57,348 INFO     [task.py:415] Building contexts for us-elections on rank 0...
100%|██████████| 51/51 [00:00<00:00, 4200.23it/s]
2024-11-29:14:48:57,363 INFO     [evaluator.py:489] Running loglikelihood requests
Running loglikelihood requests: 100%|██████████| 306/306 [00:05<00:00, 53.62it/s]
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out

In [7]:
# lm_eval.simple_evaluate returns a nested dictionary; reshape the per-sample
# log-likelihoods into a (state x continuation) DataFrame for ease of usage.
# NOTE: this reads `results` from the lm-eval cell above. The manual-scoring cell
# further down rebinds `results` to a plain dict, so re-run the lm-eval cell first
# if you re-execute this one afterwards.

state_results = {}

for record in results["samples"]["us-elections"]:
    state = record["doc"]["state"]

    # Each argument is a (context, continuation) pair; strip surrounding
    # whitespace (e.g. the delimiter space) from the continuation label.
    continuations = [cont.strip() for _, cont in record["arguments"]]
    # The first item of each response is a (log-likelihood, ...) pair; keep the log-likelihood.
    lls = [ll for (ll, _), *_ in record["resps"]]

    state_results[state] = pd.Series(data=lls, index=continuations)

# One row per state, one column per continuation (insertion order preserved).
pd.DataFrame.from_dict(state_results, orient="index")  # LogLikelihood

Unnamed: 0,Democratic party,Democratic candidate,Democratic nominee,Republican party,Republican candidate,Republican nominee
Alabama,-3.923864,-10.20742,-11.301357,-3.709755,-9.778788,-11.326734
Alaska,-4.426388,-10.48786,-11.679504,-3.982727,-9.76462,-11.272239
Arizona,-3.887111,-10.023164,-11.250155,-4.098505,-10.063534,-11.516206
Arkansas,-4.048018,-10.310758,-11.320664,-4.048862,-10.067999,-11.564379
California,-3.822566,-9.827022,-11.324873,-4.52358,-10.327144,-12.196406
Colorado,-4.095766,-10.223925,-11.453972,-4.493166,-10.4202,-11.9151
Connecticut,-3.723274,-9.76157,-11.129311,-4.178091,-10.050246,-11.777218
Delaware,-3.78775,-9.922891,-11.100762,-4.056411,-9.961876,-11.426203
District of Columbia,-3.803463,-9.716384,-11.055593,-4.495965,-10.317451,-12.155891
Florida,-3.925491,-10.206886,-11.320451,-3.824844,-10.028723,-11.576712


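#### Manual multiple-choice / log-likelihood implementation

Re-implement the same scoring by hand: tokenize the context and the continuation, run the model once, and sum the cross-entropy over the continuation tokens only. `continuation_loss` returns the *negative* log-likelihood, so its values should match the lm-eval table above up to sign.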

In [9]:
def tokenize(tokenizer: AutoTokenizer, message: Union[str, List]) -> torch.Tensor:
    """Encode a prompt into token ids with a leading batch dimension.

    Args:
        tokenizer: Hugging Face tokenizer.
        message: Either a plain string, tokenized without special tokens
            (no BOS), or a list of chat messages, rendered through the
            tokenizer's chat template with ``continue_final_message=True``
            so the final message stays open for a continuation.

    Returns:
        Token ids as returned with ``return_tensors="pt"``.

    Raises:
        TypeError: If ``message`` is neither a ``str`` nor a ``list``.
    """
    if isinstance(message, str):  # Do not add BOS token in case of a plain string
        return tokenizer(message, add_special_tokens=False, return_tensors="pt").input_ids

    if isinstance(message, list):
        return tokenizer.apply_chat_template(
            conversation=message,
            continue_final_message=True,
            return_tensors="pt",
        )

    # Fail fast: returning None here would only surface later as an obscure torch.cat error.
    raise TypeError(
        f"message must be a str or a list of chat messages, got {type(message).__name__}"
    )


def continuation_loss(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    context: Union[str, List],
    cont: str,
) -> float:
    """Return the summed negative log-likelihood of ``cont`` given ``context``.

    The context and the continuation are tokenized separately and concatenated.
    The model runs once over the whole sequence, and the cross-entropy is summed
    over the continuation tokens only; context tokens are ignored.

    Args:
        model: Causal LM whose forward pass returns an object with ``.logits``.
        tokenizer: Tokenizer matching ``model``.
        context: Prompt, either a plain string or a list of chat messages (see ``tokenize``).
        cont: Continuation text to score. Include any leading space yourself.

    Returns:
        The negative log-likelihood as a Python float (0.0 for an empty continuation).
    """
    context_ids = tokenize(tokenizer=tokenizer, message=context)
    cont_ids = tokenizer.encode(cont, add_special_tokens=False, return_tensors="pt")
    n_cont = cont_ids.size(1)
    if n_cont == 0:
        # An empty continuation has log-likelihood 0; no forward pass needed.
        return 0.0

    # Run on whatever device the model lives on instead of assuming "cuda".
    input_ids = torch.cat((context_ids, cont_ids), dim=1).to(model.device)

    with torch.no_grad():
        logits = model(input_ids).logits

    # Labels are a copy of the inputs with every context position set to -100,
    # CrossEntropyLoss's default ignore_index, so only continuation tokens count.
    # Mask by explicit context length: a `[:, :-n_cont]` slice would mask nothing if n_cont were 0.
    labels = input_ids.clone()
    labels[:, : context_ids.size(1)] = -100

    # Next-token prediction: the logits at position t score the token at position t + 1.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]

    # CrossEntropyLoss expects the class (vocab) dimension at dim 1: (batch, vocab, seq).
    # Compute the loss in float32 even if the model runs in lower precision.
    nll_losses = torch.nn.CrossEntropyLoss(reduction="none")(
        shift_logits.permute(0, 2, 1).float(), shift_labels
    )
    return nll_losses.sum().item()

#### Dataset definition
#### Call `continuation_loss` for every state $\times$ continuation pair

In [10]:
# The 50 states plus the District of Columbia (51 entries).
states = [
    "Alabama", "Alaska", "Arizona", "Arkansas", "California", "Colorado",
    "Connecticut", "Delaware", "District of Columbia", "Florida", "Georgia", "Hawaii",
    "Idaho", "Illinois", "Indiana", "Iowa", "Kansas", "Kentucky", "Louisiana",
    "Maine", "Maryland", "Massachusetts", "Michigan", "Minnesota",
    "Mississippi", "Missouri", "Montana", "Nebraska", "Nevada",
    "New Hampshire", "New Jersey", "New Mexico", "New York",
    "North Carolina", "North Dakota", "Ohio", "Oklahoma", "Oregon",
    "Pennsylvania", "Rhode Island", "South Carolina", "South Dakota",
    "Tennessee", "Texas", "Utah", "Vermont", "Virginia", "Washington",
    "West Virginia", "Wisconsin", "Wyoming"
]

# NOTE(review): this template is not the same text as the dataset's `prompt` column
# shown earlier ("As a U.S. citizen, resident of ..."), yet the scores below agree
# closely with the lm-eval run (up to sign). Confirm which prompt the lm-eval task uses.
message = "As a resident of {state}, in the 2024 US presidential election I will vote for the candidate of the"

choices = [
    "Democratic party",
    "Democratic candidate",
    "Democratic nominee",
    "Republican party",
    "Republican candidate",
    "Republican nominee",
]

progress = tqdm(states)
# NOTE: this rebinds `results`, which previously held the lm-eval output.
results = {}

for state in progress:
    progress.set_description(state)
    prompt = message.format(state=state)

    # Register the (initially empty) per-state row first, then fill it choice by choice.
    results[state] = scores = {}
    for choice in choices:
        # Score the continuation with a leading space separating it from the prompt.
        scores[choice] = continuation_loss(
            model=model,
            tokenizer=tokenizer,
            context=prompt,
            cont=f" {choice}",
        )

Wyoming: 100%|██████████| 51/51 [00:18<00:00,  2.75it/s]            


In [11]:
# One row per state, one column per choice; values are negative log-likelihoods (lower = more likely).
pd.DataFrame(data=list(results.values()), index=list(results.keys()))

Unnamed: 0,Democratic party,Democratic candidate,Democratic nominee,Republican party,Republican candidate,Republican nominee
Alabama,3.923841,10.207397,11.301334,3.70973,9.778763,11.326709
Alaska,4.426369,10.487841,11.679485,3.982706,9.764598,11.272216
Arizona,3.887089,10.023142,11.250134,4.098481,10.06351,11.516182
Arkansas,4.047997,10.310737,11.320643,4.04884,10.067977,11.564357
California,3.822544,9.827,11.324851,4.523554,10.327117,12.196381
Colorado,4.095746,10.223904,11.453951,4.493143,10.420177,11.915077
Connecticut,3.723251,9.761547,11.129288,4.178065,10.05022,11.777192
Delaware,3.787728,9.922868,11.100739,4.056386,9.961849,11.426176
District of Columbia,3.80344,9.716361,11.055571,4.495939,10.317425,12.155865
Florida,3.925468,10.206863,11.320428,3.82482,10.028698,11.576687
