In [15]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load the StableLM 3B checkpoint and its tokenizer, then move the model to the GPU.
device = "cuda:0"
checkpoint = "stabilityai/stablelm-3b-4e1t"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Half precision weights; trust_remote_code lets the checkpoint's own model code be used.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
# Last expression of the cell: moves the weights and displays the module tree.
model.to(device)

StableLMEpochForCausalLM(
  (model): StableLMEpochModel(
    (embed_tokens): Embedding(50304, 2560)
    (layers): ModuleList(
      (0-31): 32 x DecoderLayer(
        (self_attn): Attention(
          (q_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (k_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (v_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (o_proj): Linear(in_features=2560, out_features=2560, bias=False)
          (rotary_emb): RotaryEmbedding()
        )
        (mlp): MLP(
          (gate_proj): Linear(in_features=2560, out_features=6912, bias=False)
          (up_proj): Linear(in_features=2560, out_features=6912, bias=False)
          (down_proj): Linear(in_features=6912, out_features=2560, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_a

In [5]:
def zero_padding_multiplicatn(num_1: int, num_2: int, padding_size: int = 0, reverse: bool = False) -> str:
    """Format a multiplication example as ``"<num_1> * <num_2> = <product>"``.

    Args:
        num_1: Left operand.
        num_2: Right operand.
        padding_size: Minimum width of each rendered number. Shorter numbers
            are left-padded with zeros; numbers that are already at least this
            wide are left untouched. ``0`` (the default) disables padding.
        reverse: If true, the characters of the product are written in reverse
            order (e.g. ``408`` becomes ``"804"``). Reversal happens before
            padding, so padding zeros are still added on the left.

    Returns:
        The formatted example string.
    """
    left, right = str(num_1), str(num_2)
    product = str(int(left) * int(right))
    if reverse:
        # str has no .reverse() method (that is list-only); slice to reverse.
        product = product[::-1]
    # rjust never truncates and is a no-op when the width is <= len(text),
    # so padding_size=0 (or any too-small width) leaves the text unchanged.
    left, right, product = (text.rjust(padding_size, "0") for text in (left, right, product))
    return f"{left} * {right} = {product}"

In [8]:
# Sanity check with padding_size=0: operands and product are rendered without zero padding.
zero_padding_multiplicatn(123, 456, 0)

'123 * 456 = 56088'

In [2]:
# Squares of 99 and 100 — presumably checking how many digits the largest product
# needs for a given maximum operand (cf. max_num_len); confirm intent.
99**2, 100**2

(9801, 10000)

In [68]:
import random

max_num = 9
# Width of the largest possible product (max_num * max_num); used as the padding width.
max_num_len = len(str(max_num**2))

# Every ordered operand pair (i, j) with 1 <= i, j <= max_num, in row-major order.
operand_pairs = [(i, j) for i in range(1, max_num + 1) for j in range(1, max_num + 1)]

# Zero-padded variant: every number is rendered max_num_len characters wide.
eval_examples_padded = [zero_padding_multiplicatn(i, j, max_num_len) for i, j in operand_pairs]

# Plain variant: numbers rendered as-is.
eval_examples_base = [f"{i} * {j} = {i * j}" for i, j in operand_pairs]

few_shot_num = 4
few_shot_seed = 0
# A dedicated seeded RNG keeps the few-shot prompt identical across
# Restart & Run All without touching the global `random` state.
# NOTE(review): the few-shot examples are drawn from eval_examples_base, the same
# list that is evaluated later, so those examples appear in both prompt and test set.
random_examples = random.Random(few_shot_seed).sample(eval_examples_base, few_shot_num)
prompt = "\n".join(random_examples)
print(prompt)

8 * 6 = 48
3 * 9 = 27
3 * 6 = 18
7 * 4 = 28


In [50]:
# Peek at the first few unpadded examples to confirm the "i * j = product" format.
eval_examples_base[:5]

['1 * 1 = 1', '1 * 2 = 2', '1 * 3 = 3', '1 * 4 = 4', '1 * 5 = 5']

In [28]:
# NOTE(review): `eval_examples` is not defined in any cell visible here (only
# `eval_examples_padded` and `eval_examples_base` are) — presumably left over from an
# earlier kernel state; confirm, otherwise this fails on Restart & Run All.
# Takes the last max_num_len characters of the first example.
eval_examples[0][-max_num_len:]

'12'

In [16]:
from transformers.utils import logging

# Raise the transformers logging threshold to ERROR so warnings do not clutter cell output.
logging.set_verbosity_error()

In [21]:
# NOTE(review): `eval_examples` is not defined in any cell visible here — presumably a
# stale name from an earlier kernel state; confirm which list this should count.
len(eval_examples)

243

In [80]:
from tqdm import tqdm
# Few-shot evaluation: prepend the shared few-shot prompt to each query, greedily
# generate a continuation, and compare the text after "=" on the query line with
# the expected product.
answers = []
batch_size = 32

# Batched tokenization needs padding whenever prompts differ in token length.
# A decoder-only model continues from the right edge of its input, so pad on the
# left; fall back to EOS as the pad token if the tokenizer defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# Iterate over the same list that is sliced below.
for i in tqdm(range(0, len(eval_examples_base), batch_size)):
    queries = eval_examples_base[i:i+batch_size]
    # Keep everything up to and including "=" so the model has to fill in the answer.
    test_prompts = [prompt + "\n" + query.split("=")[0] + "=" for query in queries]
    # Parse the expected answer the same way predictions are parsed below.
    truths = [query.split("=")[1].strip() for query in queries]

    encoded = tokenizer(test_prompts, return_tensors="pt", padding=True).to(device)
    inputs = encoded.input_ids
    # NOTE(review): max_length counts prompt tokens as well; a longer few-shot
    # prompt would need a larger limit (or max_new_tokens instead).
    outputs = model.generate(
        inputs,
        attention_mask=encoded.attention_mask,
        pad_token_id=tokenizer.pad_token_id,
        max_length=64,
        do_sample=False,
    )
    prediction = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # The decoded text starts with the few_shot_num prompt lines, so line index
    # few_shot_num is the query line; the answer is whatever follows its "=".
    predictions = [p.split("\n")[few_shot_num].split("=")[1].strip() for p in prediction]
    answers.extend(int(pred == truth) for pred, truth in zip(predictions, truths))
print(f"Accuracy: {sum(answers)/len(answers)}")

100%|██████████| 3/3 [00:05<00:00,  1.69s/it]

Accuracy: 1.0





In [78]:
# Re-extract the answers from the decoded generations currently held in `prediction`.
def _answer_from_generation(decoded_text):
    """Return the text after the first "=" on the query line (line index few_shot_num)."""
    query_line = decoded_text.split("\n")[few_shot_num]
    return query_line.split("=")[1].strip()

predictions = list(map(_answer_from_generation, prediction))
predictions

['1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 '2',
 '4',
 '6',
 '8',
 '10',
 '12',
 '14',
 '16',
 '18',
 '3',
 '6',
 '9',
 '12',
 '15',
 '18',
 '21',
 '24',
 '27',
 '4',
 '8',
 '12',
 '16',
 '20']

In [63]:
# Check how a query splits on "=": the left part is what the prompt is built from,
# the right part holds the answer.
queries[0].split("=")

['1 * 1 ', ' 1']

In [61]:
# Show the prompts currently held in test_prompts (few-shot prefix followed by the query).
test_prompts

['1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n1 * 1 ',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n1 * 2 ',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n1 * 3 ',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n1 * 4 ',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n1 * 5 ',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n1 * 6 ',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n1 * 7 ',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n1 * 8 ',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n1 * 9 ',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n2 * 1 ',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n2 * 2 ',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n2 * 3 ',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n2 * 4 ',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n2 * 5 =',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n2 * 6 =',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n2 * 7 =',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n2 * 8 =',
 '1 * 6 = 

In [59]:
# Show the raw decoded generations held in `prediction` (prompt text plus model continuation).
prediction

['1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n1 * 1 \n1 * 2 \n1 * 3 \n1 * 4 \n1 * 5 \n1 * 6 \n1 * 7 \n1 * 8 \n',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n1 * 2 \n2 * 3 = 6\n3 * 4 = 12\n4 * 5 = 20\n5 * 6 = 30\n6 * 7 = 42\n7 * 8 = 56',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n1 * 3 \n1 * 4 = 4\n1 * 5 = 5\n1 * 6 = 6\n1 * 7 = 7\n1 * 8 = 8\n1 * 9 = 9',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n1 * 4 \n1 * 3 = 3\n1 * 2 = 2\n1 * 1 = 1\n\nA: You can use the following code:\npublic static int[][] multiply',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n1 * 5 \n1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n1 * 5 \n1 * 6 = 6\n',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n1 * 6 \n1 * 7 \n1 * 8 \n1 * 9 \n2 * 6 \n2 * 7 \n2 * 8 \n2 * 9 \n',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n1 * 7 \n1 * 8 = 8\n1 * 9 = 9\n1 * 10 = 10\n1 * 11 = 11\n1 * 12 = 12\n1 * 13 = 13',
 '1 * 6 = 6\n6 * 7 = 42\n8 * 9 = 72\n5 * 7 = 35\n1 * 8 \n1 * 9 = 9\n1 * 10 = 10\n1 * 11 = 11\n1 