In [10]:
from datasets import load_dataset
import transformers
import torch
from tqdm import tqdm


In [11]:
#Download dataset from HF
dataset = load_dataset('sciq', split='test')

In [12]:
#Get model, tokenizer, and max_length
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
max_length = tokenizer.model_max_length


In [13]:
sample_data = dataset[0]
#print(sample_data)
question =  sample_data['question']
print(f"{question}\n")
distractors = [sample_data['distractor3'], sample_data['distractor2'], sample_data['distractor1']]
print(f"{distractors}\n")
answer = sample_data['correct_answer']
print(f"{answer}\n")
support = sample_data['support']
print(f"{support}\n")



Compounds that are capable of accepting electrons, such as o 2 or f2, are called what?

['residues', 'Oxygen', 'antioxidants']

oxidants

Oxidants and Reductants Compounds that are capable of accepting electrons, such as O 2 or F2, are calledoxidants (or oxidizing agents) because they can oxidize other compounds. In the process of accepting electrons, an oxidant is reduced. Compounds that are capable of donating electrons, such as sodium metal or cyclohexane (C6H12), are calledreductants (or reducing agents) because they can cause the reduction of another compound. In the process of donating electrons, a reductant is oxidized. These relationships are summarized in Equation 3.30: Equation 3.30 Saylor URL: http://www. saylor. org/books.



In [14]:
def generate_prompts(data):
  support = data['support']
  question = data['question']
  distractors = [data['distractor3'], data['distractor2'], data['distractor1']]
  answer = data['correct_answer']
  prompts = []

  #Create prompt for answer
  promptInstance = {
      'context': support + '\nquestion: ' + question + ' \nanswer:',
      'continuation': answer
  }
  prompts.append(promptInstance)
  # Create prompts for each distractor
  for distractor in distractors:
    promptInstance = {
        'context': support + '\nquestion: ' + question + ' \nanswer:',
        'continuation': distractor
    }
    prompts.append(promptInstance)

  return prompts

prompts = generate_prompts(sample_data)
print(prompts)

[{'context': 'Oxidants and Reductants Compounds that are capable of accepting electrons, such as O 2 or F2, are calledoxidants (or oxidizing agents) because they can oxidize other compounds. In the process of accepting electrons, an oxidant is reduced. Compounds that are capable of donating electrons, such as sodium metal or cyclohexane (C6H12), are calledreductants (or reducing agents) because they can cause the reduction of another compound. In the process of donating electrons, a reductant is oxidized. These relationships are summarized in Equation 3.30: Equation 3.30 Saylor URL: http://www. saylor. org/books.\nquestion: Compounds that are capable of accepting electrons, such as o 2 or f2, are called what? \nanswer:', 'continuation': 'oxidants'}, {'context': 'Oxidants and Reductants Compounds that are capable of accepting electrons, such as O 2 or F2, are calledoxidants (or oxidizing agents) because they can oxidize other compounds. In the process of accepting electrons, an oxidant 

In [15]:
def encode_prompts(prompts, tokenizer):
  encoded_prompts = []
  #Loop through and encode prompts based on tokenizer
  for prompt in prompts:
    context = prompt['context']
    continuation = prompt['continuation']
    encoded_prompt = tokenizer.encode(context + continuation)
    encoded_context = tokenizer.encode(context)
    encoded_context_len = len(encoded_context)
    encoded_continuation = encoded_prompt[encoded_context_len:]
    encoded_prompts.append({'encoded_context': encoded_context, 'encoded_continuation': encoded_continuation})

  return encoded_prompts

encoded = encode_prompts(prompts, tokenizer)
print(encoded)


[{'encoded_context': [38208, 312, 1187, 290, 2297, 4782, 1187, 3082, 3733, 326, 389, 6007, 286, 12598, 28722, 11, 884, 355, 440, 362, 393, 376, 17, 11, 389, 1444, 1140, 312, 1187, 357, 273, 18762, 2890, 6554, 8, 780, 484, 460, 18762, 1096, 584, 16439, 13, 554, 262, 1429, 286, 12598, 28722, 11, 281, 18762, 415, 318, 5322, 13, 3082, 3733, 326, 389, 6007, 286, 29798, 28722, 11, 884, 355, 21072, 6147, 393, 11700, 78, 33095, 1531, 357, 34, 21, 39, 1065, 828, 389, 1444, 445, 4782, 1187, 357, 273, 8868, 6554, 8, 780, 484, 460, 2728, 262, 7741, 286, 1194, 13061, 13, 554, 262, 1429, 286, 29798, 28722, 11, 257, 2027, 310, 415, 318, 18762, 1143, 13, 2312, 6958, 389, 31880, 287, 7889, 341, 513, 13, 1270, 25, 7889, 341, 513, 13, 1270, 311, 7167, 10289, 25, 2638, 1378, 2503, 13, 910, 4685, 13, 8745, 14, 12106, 13, 198, 25652, 25, 3082, 3733, 326, 389, 6007, 286, 12598, 28722, 11, 884, 355, 267, 362, 393, 277, 17, 11, 389, 1444, 644, 30, 220, 198, 41484, 25], 'encoded_continuation': [1140, 312, 1187]

In [16]:
device = torch.device('cpu')
model.to(device)

def get_log_liklihood(encoded_prompts, model):
  results = []
  for encoded_prompt in encoded_prompts:
    encoded_context = encoded_prompt['encoded_context']
    encoded_continuation = encoded_prompt['encoded_continuation']
    #Convert the input to a tensor to be passed to the model
    inp = torch.tensor(
        (encoded_context + encoded_continuation)[-(max_length + 1) :][:-1],
        dtype=torch.long,
        device=device,
    )
    # Add a batch dimension to the input tensor
    inp = inp.unsqueeze(0)

    # Pass inputs and get logits from model response
    #tock.no_grad() disables gradient calculation which is not needed for eval
    with torch.no_grad():
      logits = model(inp).logits

    # Normalize logits vocab dimension using log_softmax
    log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

    #Convert continuation tokens to a tensor
    cont_toks = torch.tensor(encoded_continuation, dtype=torch.long, device=device).unsqueeze(0)
    # Slice logits to get the logits for the continuation tokens
    # logits has the shape [batch_size, sequence_length, vocab]
    # Slice the sequence from the end of the sequence starting at the length of tokens in the continuation sequence
    log_probs_for_cont = log_probs[:, -cont_toks.shape[1] :, :]
    #Get most likely tokens from logits vocabulary for sequence (vocab is the last dimension)
    greedy_tokens = log_probs_for_cont.argmax(dim=-1)

    # Get the log probabilities of the actual continuation tokens
    compare_log_probs = torch.gather(log_probs_for_cont, 2, cont_toks.unsqueeze(-1)).squeeze(-1)


    # Sum the log likelihood for the continuation tokens
    # Sum instead of multiply because log(a*b) = log(a) + log(b)
    log_likelihood = float(compare_log_probs.sum())
    results.append(log_likelihood)

  return results

log_likelihoods = get_log_liklihood(encoded, model)
print(log_likelihoods)

[-8.289290428161621, -15.653117179870605, -12.925821304321289, -13.8721342086792]


In [17]:
# Set a limit for the number of questions to process
LIMIT = 10

#Evaluate
limited_dataset = dataset.select(range(LIMIT))

results = []
for sample in tqdm(limited_dataset):
    prompts = generate_prompts(sample)
    encoded = encode_prompts(prompts, tokenizer)
    log_likelihoods = get_log_liklihood(encoded, model)
    # Based on generating prompts we set the correct answer to be the first one
    correct_answer_log_likelihood = log_likelihoods[0]
    distractor_log_likelihoods = [log_liklihood for log_liklihood in log_likelihoods[1:]]

    # Determine if the correct answer has the highest log likelihood
    is_correct = all(correct_answer_log_likelihood > distractor_ll for distractor_ll in distractor_log_likelihoods)

    results.append({
        'question': sample['question'],
        'correct_answer': sample['correct_answer'],
        'log_likelihood_correct': correct_answer_log_likelihood,
        'log_likelihood_distractors': distractor_log_likelihoods,
        'is_correct': is_correct
    })

# Print results for each question
for i, result in enumerate(results):
    print(f"Question: {result['question']}")
    print(f"Correct Answer: {result['correct_answer']} (Log Likelihood: {result['log_likelihood_correct']})")
    print(f"Distractor Log Likelihoods: {result['log_likelihood_distractors']}")
    print(f"Model Predicted Correct: {result['is_correct']}\n")
    print("-" * 20)

# Calculate overall accuracy
accuracy = sum(r['is_correct'] for r in results) / len(results)
print(f"Overall Accuracy on the first {LIMIT} examples: {accuracy:.2f}")

100%|██████████| 10/10 [00:41<00:00,  4.19s/it]

Question: Compounds that are capable of accepting electrons, such as o 2 or f2, are called what?
Correct Answer: oxidants (Log Likelihood: -8.289290428161621)
Distractor Log Likelihoods: [-15.653117179870605, -12.925821304321289, -13.8721342086792]
Model Predicted Correct: True

--------------------
Question: What term in biotechnology means a genetically exact copy of an organism?
Correct Answer: clone (Log Likelihood: -14.674302101135254)
Distractor Log Likelihoods: [-15.728692054748535, -15.405716896057129, -17.5056095123291]
Model Predicted Correct: True

--------------------
Question: Vertebrata are characterized by the presence of what?
Correct Answer: backbone (Log Likelihood: -13.691313743591309)
Distractor Log Likelihoods: [-19.18093490600586, -15.642292976379395, -15.81669807434082]
Model Predicted Correct: True

--------------------
Question: What is the height above or below sea level called?
Correct Answer: elevation (Log Likelihood: -11.339323997497559)
Distractor Log Lik


