In [None]:
# Load model directly
import math
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# Seed torch's global RNG so any stochastic step in the notebook is reproducible.
torch.manual_seed(0)

# Hugging Face model id to score prompts with.
MODEL_NAME = "model"
# Read the access token from the environment instead of hard-coding a credential
# in the notebook. When HF_TOKEN is unset this is None, and huggingface_hub then
# falls back to locally stored login credentials (per its documentation).
HF_TOKEN = os.environ.get("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME,
                                          token=HF_TOKEN)
# device_map='auto' lets the weights be placed on whatever devices are available.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME,
                                             device_map='auto',
                                             token=HF_TOKEN)

In [2]:
def calculate_perplexity(text:str,
                         model=model,
                         tokenizer=tokenizer):
    """Return the perplexity of ``text`` under a causal language model.

    The model scores its own input (``labels=input_ids``), so ``outputs.loss``
    is the mean token-level negative log-likelihood (cross-entropy). The
    perplexity is ``exp`` of that mean.

    Args:
        text: The string to score.
        model: Causal LM to evaluate. Defaults to the notebook-level ``model``,
            bound when this cell is executed.
        tokenizer: Tokenizer matching ``model``. Defaults to the notebook-level
            ``tokenizer``.

    Returns:
        The perplexity as a Python float. Returns ``float('nan')`` when the text
        tokenizes to fewer than two tokens, because there is then no next token
        to predict and the loss is undefined.
    """
    # Tokenize input text
    inputs = tokenizer(text, return_tensors='pt')

    # Guard the degenerate case explicitly instead of relying on the model
    # to produce a NaN loss or raise for an empty sequence.
    if inputs['input_ids'].shape[-1] < 2:
        return float('nan')

    # Move tensors to the device the model reports as its own.
    inputs = {key: value.to(model.device) for key, value in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, labels=inputs['input_ids'])

    # outputs.loss is the mean *negative* log-likelihood per predicted token.
    mean_nll = outputs.loss.item()

    # math.exp works in double precision. The old torch.tensor(float) round-trip
    # truncated the value to float32 before exponentiating.
    return math.exp(mean_nll)

In [None]:
import os
import numpy as np
import json
from tqdm import tqdm

# Score every vanilla prompt of each task for both dialects. Results are
# checkpointed to JSON after each (task, dialect) so an interrupted run keeps
# its finished parts.
GOLD_DIR = '../redial/redial_gold'
PPL_PATH = f'{GOLD_DIR}/model_ppl.json'

ppl_dic = dict()
for task in ['logic', 'comprehensive', 'math', 'algorithm']:
    ppl_dic[task] = dict()
    # Context manager closes the input file instead of leaking the handle.
    with open(f'{GOLD_DIR}/{task}.json') as f:
        data = json.load(f)
    for dia in ['aave', 'original']:
        prompts = [d['prompt'] for d in data['vanilla'][dia]]
        scores = ppl_dic[task][dia] = list()
        for prompt in tqdm(prompts, desc=f'{task}/{dia}'):
            scores.append(calculate_perplexity(prompt))
        # Checkpoint once per dialect rather than rewriting the whole file
        # after every prompt, which made total I/O grow quadratically.
        with open(PPL_PATH, 'w') as f:
            json.dump(ppl_dic, f)
        print(f'The average perplexity of {dia} prompts in {task} is {round(np.mean(scores), 1)}')

In [None]:
# Pool the per-prompt perplexities of every task into one list per dialect and
# report the mean over the pooled prompts. This is a micro-average, so tasks
# with more prompts carry more weight.
aave_ppl, original_ppl = list(), list()
for file, by_dialect in ppl_dic.items():
    aave_ppl += by_dialect['aave']
    original_ppl += by_dialect['original']

print(f'The average perplexity of aave prompts is {round(np.mean(aave_ppl), 1)}')
print(f'The average perplexity of original prompts is {round(np.mean(original_ppl), 1)}')

In [None]:
# Reprint the per-task mean perplexities as "<task>: <aave> <original>",
# rounded to PPL_DECIMALS decimal places (one decimal, matching the other
# summaries in this notebook).
PPL_DECIMALS = 1
for file in ppl_dic:
    aave_mean = round(np.mean(ppl_dic[file]["aave"]), PPL_DECIMALS)
    original_mean = round(np.mean(ppl_dic[file]["original"]), PPL_DECIMALS)
    print(f'{file}: {aave_mean} {original_mean}')