In [1]:
import random

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from torch.utils.data import Dataset
from utils import zero_padding_multiplicatn, generate_training_set, generate_validation_set

  from .autonotebook import tqdm as notebook_tqdm


In [24]:
val_set = generate_validation_set(5, 75)
val_set = [s for s in val_set if len(str(s[0])) == 1 and len(str(s[1])) == 1]


(75, 75)

In [17]:
device = "cuda:0"
model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
  model_id,
  trust_remote_code=True,
  torch_dtype=torch.float16,
)
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [27]:
zero_padding_multiplicatn(12, 5, 4)

'0012 * 0005 = 0060'

In [25]:
max_num = 9999
max_num_len = len(str(max_num**2))
eval_examples_padded = []
#for i in range(1, max_num + 1):
#    for j in range(1, max_num + 1):
#        eval_examples_padded.append(zero_padding_multiplicatn(i, j, max_num_len))

eval_examples_base = []
for i in range(1, max_num + 1):
    for j in range(1, max_num + 1):
        eval_examples_base.append(zero_padding_multiplicatn(i, j, 0))

few_shot_num = 4
random_examples = random.sample(eval_examples_base, few_shot_num)
prompt = "\n".join(random_examples)
print(prompt)

KeyboardInterrupt: 

In [5]:
from transformers.utils import logging

logging.set_verbosity_error()

No padding 2 digits, 4 shots

In [14]:
from tqdm import tqdm
answers = []
batch_size = 32
for i in tqdm(range(0, len(eval_examples_padded), batch_size)):
    queries = eval_examples_base[i:i+batch_size]
    test_prompts = [prompt + "\n" + query.split("=")[0] + "=" for query in queries]
    truths = [query[-max_num_len:].strip() for query in queries]

    inputs = tokenizer(test_prompts, return_tensors="pt").input_ids.to(device)
    outputs = model.generate(inputs, max_length=64, do_sample=False)
    prediction = tokenizer.batch_decode(outputs, skip_special_tokens=True)#[len(test_prompt)+1:len(test_prompt)+1+max_num_len]
    predictions = [p.split("\n")[few_shot_num].split("=")[1].strip()  for p in prediction]
    #predictions = [p.split().strip() for p in prediction]
    for pred, truth in zip(predictions, truths):
        answers.append(int(pred == truth))
    #print(f"Test Prompt: {example}")
    #print(f"Prediction: {prediction}")
    #print(f"Truth: {truth}")
    #answers.append(int(prediction == truth))
print(f"Accuracy: {sum(answers)/len(answers)*2.:f}")

100%|██████████| 3/3 [00:02<00:00,  1.05it/s]

Accuracy: 0.271605





padding 2 digits

In [15]:
prediction

['9 * 6 = 54\n3 * 1 = 3\n6 * 6 = 36\n1 * 4 = 4\n8 * 2 = 2\n6 * 2 = 6\n1 * 2 = 6\n1 * 2 = 6\n1 * 2 = 6\n1 * 2 = 6\n1 * 2 =',
 '9 * 6 = 54\n3 * 1 = 3\n6 * 6 = 36\n1 * 4 = 4\n8 * 3 = 8\n1 * 6 = 6\n1 * 6 = 6\n1 * 6 = 6\n1 * 6 = 6\n1 * 6 = 6\n1 * 6 =',
 '9 * 6 = 54\n3 * 1 = 3\n6 * 6 = 36\n1 * 4 = 4\n8 * 4 = 8\n1 * 8 = 8\n1 * 16 = 16\n1 * 32 = 32\n1 * 64 = 64\n1 * 128 = 128\n1 * 256 =',
 '9 * 6 = 54\n3 * 1 = 3\n6 * 6 = 36\n1 * 4 = 4\n8 * 5 = 8\n1 * 6 = 6\n1 * 8 = 8\n1 * 16 = 16\n1 * 32 = 32\n1 * 64 = 64\n1 * 128 =',
 '9 * 6 = 54\n3 * 1 = 3\n6 * 6 = 36\n1 * 4 = 4\n8 * 6 = 24\n1 * 4 = 4\n8 * 6 = 24\n1 * 4 = 4\n8 * 6 = 24\n1 * 4 = 4\n8 * 6 =',
 '9 * 6 = 54\n3 * 1 = 3\n6 * 6 = 36\n1 * 4 = 4\n8 * 7 = 16\n1 * 8 = 8\n1 * 9 = 8\n1 * 10 = 8\n1 * 11 = 8\n1 * 12 = 8\n1 * 13 =',
 '9 * 6 = 54\n3 * 1 = 3\n6 * 6 = 36\n1 * 4 = 4\n8 * 8 = 16\n1 * 8 = 16\n1 * 8 = 16\n1 * 8 = 16\n1 * 8 = 16\n1 * 8 = 16\n1 * 8 =',
 '9 * 6 = 54\n3 * 1 = 3\n6 * 6 = 36\n1 * 4 = 4\n8 * 9 = 16\n1 * 4 = 4\n8 * 9 = 16\n1 * 4 = 4\n8 * 

In [13]:
from tqdm import tqdm
answers = []
batch_size = 64
for i in tqdm(range(0, len(eval_examples_padded), batch_size)):
    queries = eval_examples_base[i:i+batch_size]
    test_prompts = [prompt + "\n" + query.split("=")[0] + "=" for query in queries]
    truths = [query[-max_num_len:].strip() for query in queries]

    inputs = tokenizer(test_prompts, return_tensors="pt").input_ids.to(device)
    outputs = model.generate(inputs, max_length=64, do_sample=False)
    prediction = tokenizer.batch_decode(outputs, skip_special_tokens=True, padding=True)#[len(test_prompt)+1:len(test_prompt)+1+max_num_len]
    predictions = [p.split("\n")[few_shot_num].split("=")[1].strip()  for p in prediction]
    #predictions = [p.split().strip() for p in prediction]
    for pred, truth in zip(predictions, truths):
        answers.append(int(pred == truth))
    #print(f"Test Prompt: {example}")
    #print(f"Prediction: {prediction}")
    #print(f"Truth: {truth}")
    #answers.append(int(prediction == truth))
print(f"Accuracy: {sum(answers)/len(answers)}")

 40%|███▉      | 61/154 [03:27<05:16,  3.41s/it]


ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`input_ids` in this case) have excessive nesting (inputs type `list` where type `int` is expected).

In [18]:
inputs = tokenizer(test_prompts, return_tensors="pt").input_ids.to(device)


ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`input_ids` in this case) have excessive nesting (inputs type `list` where type `int` is expected).