In [2]:
import click, torch, logging
from transformers import AutoTokenizer, AutoModelForCausalLM

def init_logger():
    """Configure root logging for the notebook.

    Records are emitted at INFO level and above, formatted as
    ``<timestamp> <LEVEL, padded to 8 chars> | <message>`` with a
    second-resolution ``YYYY-MM-DD HH:MM:SS`` timestamp.
    """
    settings = {
        "format": "%(asctime)s %(levelname)-8s | %(message)s",
        "level": logging.INFO,
        "datefmt": "%Y-%m-%d %H:%M:%S",
    }
    logging.basicConfig(**settings)


# Configure logging once, before any cell below emits a record.
init_logger()

In [3]:

# Checkpoint to load. The commented-out names are alternatives to swap in.
MODEL_NAME = "bigscience/bloom-1b1"
# MODEL_NAME = "bigscience/bloom-7b1"
# MODEL_NAME = "lmsys/vicuna-7b-v1.3"

# Run on the first CUDA device when one is visible; otherwise fall back to the
# CPU and warn, since every operation below will be slower there.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
    logging.warning("Running on CPU!! Operations may be slow")

logging.info("Loading tokenizer & model..")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# Weights are loaded first and only then transferred to the selected device.
logging.info(f"Moving model to {device}..")
model = model.to(device)

2023-12-23 09:03:57 INFO     | Loading tokenizer & model..
2023-12-23 09:04:13 INFO     | Moving model to cuda:0..


In [4]:
import os
import pandas as pd

# Convert every MGSM ``.tsv`` file into a ``.csv`` written next to it.
#
# Only names ending in ".tsv" are processed. The output files land in the same
# directory, so without this filter a second run of the cell would pick up the
# generated ``.csv`` files, parse them as tab-separated data and write
# ``<name>.csv.csv`` copies.
for filename in os.listdir("mgsm"):
    if not filename.endswith(".tsv"):
        continue
    path = f"mgsm/{filename}"
    # splitext removes only the trailing extension, so a ".tsv" occurring
    # earlier in the file name does not truncate the output name.
    newpath = os.path.splitext(path)[0] + ".csv"
    # The source files have no header row: column names are supplied here.
    df = pd.read_csv(path, sep="\t", index_col=False, names=['query', 'label'])
    df.to_csv(newpath)

In [5]:
def run_prompt(prompt: str, max_len: int = 10):
    """Generate a continuation of ``prompt`` and print the decoded text.

    Args:
        prompt: Text to tokenize and feed to the module-level ``model``.
        max_len: Forwarded to ``model.generate`` as ``max_length``. In
            Hugging Face ``generate`` this bounds the whole sequence, prompt
            tokens included, so a long prompt leaves less room for new tokens.

    Returns:
        None. The first (and only) decoded sequence is printed.
    """
    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
        output_ids = model.generate(prompt_ids, max_length=max_len)
        texts = tokenizer.batch_decode(
            output_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

    print(texts[0])

In [6]:
# Probe: free-form continuation of a short English fragment (max_len=100).
run_prompt("flowers are", 100)

flowers are the most common cause of death in patients with acute coronary syndromes. The incidence of coronary artery disease is increasing in the United States, and the number of patients with acute coronary syndromes is expected to increase in the future. The incidence of coronary artery disease is increasing in the United States, and the number of patients with acute coronary syndromes is expected to increase in the future. The incidence of coronary artery disease is increasing in the United States, and the number of patients with acute coronary


In [16]:
# Probe: continuation of an unfinished English sentence (max_len=100).
run_prompt("Usually, When people are sick,", 100)

Usually, When people are sick, they are not able to work. They are not able to do their jobs. They are not able to do their work. They are not able to do their work. They are not able to do their work. They are not able to do their work. They are not able to do their work. They are not able to do their work. They are not able to do their work. They are not able to do their work. They are not able


In [14]:
# Probe: single-word prompt (max_len=50).
run_prompt("Bees", 50)

Bees, and the
bees are the only ones who can do it."

"Then you are the only one who can do it," said the old man, with a
smile.

"Then I am the only one


In [12]:
# Probe: direct arithmetic question with a short length cap (max_len=20).
run_prompt("what is 1+1?", 20)

what is 1+1?");
        }
        else if (isNumber(value)) {


In [11]:
# Probe: French grade-school math word problem (max_len=200).
# NOTE(review): presumably taken from the MGSM files converted earlier - confirm.
run_prompt("Les canes de Janet pondent 16 œufs par jour. Chaque matin, elle en mange trois au petit déjeuner et en utilise quatre autres pour préparer des muffins pour ses amis. Ce qui reste, elle le vend quotidiennement au marché fermier, au prix de 2 $ l'œuf de cane frais. Combien (en dollars) gagne-t-elle chaque jour au marché fermier ?", 200)

Les canes de Janet pondent 16 œufs par jour. Chaque matin, elle en mange trois au petit déjeuner et en utilise quatre autres pour préparer des muffins pour ses amis. Ce qui reste, elle le vend quotidiennement au marché fermier, au prix de 2 $ l'œuf de cane frais. Combien (en dollars) gagne-t-elle chaque jour au marché fermier ?
