# Testing Batch Inference with SOLAR 10.7B (Instruct)


In [19]:
import warnings
import pandas as pd
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
warnings.filterwarnings("ignore")

In [3]:
# Single hub identifier shared by the tokenizer and the model weights.
model_id = "Upstage/SOLAR-10.7B-Instruct-v1.0"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Weights are requested as float16; device placement is delegated to the loader
# through device_map="auto".
load_options = {"device_map": "auto", "torch_dtype": torch.float16}
model = AutoModelForCausalLM.from_pretrained(model_id, **load_options)

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [15]:
class NameDataset(Dataset):
    """Map-style dataset that serves the ``Name`` column of a CSV file.

    The full dataframe is kept on ``self.df`` so that other code can read it
    or attach extra columns to it.
    """

    def __init__(self, file_path: str):
        """Read the CSV at ``file_path`` into ``self.df``."""
        self.df = pd.read_csv(file_path)

    def __len__(self):
        """Return the number of rows in the ``Name`` column."""
        return len(self.df["Name"])

    def __getitem__(self, idx):
        """Return the name stored at row position ``idx``.

        The lookup is positional (``.iloc``), not label-based: it keeps working
        when ``self.df`` no longer has a 0..n-1 index (for example after rows
        were filtered out) and it accepts negative positions.
        """
        return self.df["Name"].iloc[idx]


# Number of names handed to the model in one generation call.
batch_size = 32

dataset = NameDataset("../data/preprocessed_names.csv")

# shuffle=False keeps batches in dataframe row order, which the later merge of
# the generated descriptions back into dataset.df relies on.
dataloader = DataLoader(
    dataset,
    num_workers=0,
    shuffle=False,
    batch_size=batch_size,
)

In [20]:
# Accumulator for the generated descriptions: one inner list of strings per batch.
results = list()

In [21]:
# Instruction shared by every request; the name and an "Answer: " cue are appended per item.
# NOTE(review): the wording refers to "the examples below", but no examples are part of
# the prompt -- confirm whether few-shot examples were meant to be included.
prompt = """Question: Based on the examples below, choose 3 from the following descriptions that match the given name: Playful, Gentle, Majestic, Curious, Loyal, Witty, Mysterious, Adventurous, Elegant, Charming. DO NOT provide any other information.
Name: """

# Batched generation with a decoder-only model must pad on the left so that every
# prompt ends directly before the first generated token. Make that explicit instead
# of relying on the tokenizer's defaults, and make sure a pad token exists.
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Start from an empty list so that re-running this cell cannot append duplicate
# batches (which would break the later merge into the dataframe).
results = []

for batch in tqdm(dataloader):
    batch_prompts = [prompt + name + "\nAnswer: " for name in batch]
    # Tokenize the whole batch at once and move the tensors to the device the model
    # expects its inputs on.
    inputs = tokenizer(batch_prompts, return_tensors="pt", padding=True).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=64,
            pad_token_id=tokenizer.pad_token_id,
        )

    # generate() returns prompt tokens followed by the continuation. Slice off the
    # (padded) prompt by token position and decode only the new tokens; this avoids
    # guessing character offsets in the decoded string, which silently mis-slices the
    # answer whenever decoding does not reproduce the prompt text exactly.
    prompt_length = inputs["input_ids"].shape[1]
    generated_texts = [
        text.strip()
        for text in tokenizer.batch_decode(
            outputs[:, prompt_length:], skip_special_tokens=True
        )
    ]

    # One list of answers per batch; flattened when merged into the dataframe.
    results.append(generated_texts)

  0%|          | 0/1207 [00:00<?, ?it/s]

100%|██████████| 1207/1207 [2:54:45<00:00,  8.69s/it] 


In [22]:
# Flatten the per-batch lists into one list, preserving batch order, and attach it
# to the original dataframe as a new "Description" column.
descriptions = []
for batch_texts in results:
    descriptions.extend(batch_texts)
dataset.df["Description"] = descriptions

In [23]:
# Write the dataframe, including the new "Description" column, to CSV without the
# index column.
output_path = "../data/solar_descriptions.csv"
dataset.df.to_csv(output_path, index=False)