In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

import torch
from torch.utils.data import DataLoader

import os
import pandas as pd
from tqdm import tqdm
import csv

In [None]:
# Directory passed as `cache_dir=` when the tokenizer, model, and dataset are loaded.
CACHE_DIR = "../cache_dir"

In [None]:
# Tokenizer pads on the left; the generation cell slices outputs at the padded
# prompt length, which only separates prompt from continuation with left padding.
tokenizer = AutoTokenizer.from_pretrained("gpt2-xl", padding_side="left", cache_dir=CACHE_DIR)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained("gpt2-xl", cache_dir=CACHE_DIR)

# Prefer Apple's MPS backend (the original target), but fall back to CUDA or CPU
# so the notebook still runs on machines without MPS instead of raising.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model.to(device)
model.eval()  # inference only

ds = load_dataset("gbharti/finance-alpaca", cache_dir=CACHE_DIR)

In [10]:
# Generate two sampled continuations per instruction and append them to a CSV
# as (prompt, response_a, response_b) rows. The cell is resumable: rows already
# present in the CSV are skipped on the next run.
save_path = "unlabeled.csv"
write_steps = 2  # number of generation batches buffered between file writes

# Create the output file with its header on the first run.
if not os.path.exists(save_path):
    with open(save_path, "w") as f:
        f.write("prompt,response_a,response_b\n")

# Resume point: number of data rows already saved (header excluded).
# Rows are counted with csv.reader rather than by physical lines, because the
# generated responses contain newlines and a single quoted row can span several
# lines; counting raw lines would overshoot and silently skip dataset entries.
with open(save_path, newline='') as f:
    n = max(sum(1 for row in csv.reader(f) if row) - 1, 0)
train = ds["train"][n:]

with open(save_path, "a", newline='') as f:
    writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL)
    write_batch = []
    dl = DataLoader(train["instruction"], batch_size=4, shuffle=False)
    # Wrap the loader (not the enumerate) so tqdm can read its length.
    for step, batch in enumerate(tqdm(dl)):
        tokens = tokenizer(batch, padding=True, return_tensors="pt").to(model.device)
        # Padded prompt width: everything past this index is newly generated.
        prompt_len = tokens.input_ids.shape[1]
        out_a = model.generate(**tokens, max_new_tokens=16, do_sample=True, pad_token_id=tokenizer.pad_token_id)
        out_b = model.generate(**tokens, max_new_tokens=16, do_sample=True, pad_token_id=tokenizer.pad_token_id)
        prompts = tokenizer.batch_decode(out_a[:, :prompt_len], skip_special_tokens=True)
        responses_a = tokenizer.batch_decode(out_a[:, prompt_len:], skip_special_tokens=True)
        responses_b = tokenizer.batch_decode(out_b[:, prompt_len:], skip_special_tokens=True)

        # Buffer rows so we minimize the number of file writes.
        # csv.writer quotes any field containing commas, quotes, or newlines.
        write_batch += zip(prompts, responses_a, responses_b)
        if (step + 1) % write_steps == 0:
            writer.writerows(write_batch)
            f.flush()  # push completed rows to disk so a kernel crash cannot lose them
            write_batch.clear()

    # Write whatever is left when the batch count is not a multiple of write_steps.
    if write_batch:
        writer.writerows(write_batch)
        write_batch.clear()


2it [00:23, 11.69s/it]

[('What happens to the insider trade profits?', " Well, there's a rule in this country that if you buy a stock and", ' There are tax laws that prevent corporations from paying taxes on the profits the insider makes'), ('Is CLM a stock or an ETF?', '\n\nNo. CLM is not a stock or an ETF. CLM', ' The stock market is the place to make the most money in crypto investing today.'), ('Ways to establish credit history for international student', ' applications may include a personal interview or questionnaire that is part of the online application process', ' applications:\n\nThe most important way to establish credit history for international student applications'), ('Do I have to pay a capital gains tax if I rebuy the same stock within 30 days?', "\n\nThe stock must have been sold by the taxpayer or under the owner's", '\n\nYes. Withholding and Capital Gain Taxes will always be required for un'), ('Can a credit card company raise my rates for making a large payment?', '\n\nIf your credit card

4it [00:48, 12.24s/it]

[('Home loan transferred to Freddie Mac — What does this mean?', ' — What it means for borrowers\n\nHow can lenders keep mortgage rates low?', ' — And on a more specific note, how did it all work?\n\n'), ('Why do I get a much better price for options with a limit order than the ask price?', '\n\nA limit order is like when you go to the store and want to', " That's because there is a lot of volume associated with the option trade (although"), ('How to execute a large stock purchase, relative to the order book?', '\n\nAn example of such a large buy-sell strategy using this model is', ' You might just be the person that needs to figure that out. What if your'), ('Where should my money go next: savings, investments, retirement, or my mortgage?', '\n\nA recent survey by NACPA found that the average American adult has', '\n\nHere are some suggestions:\n\nSaving: If you have a'), ('What are the benefits of opening an IRA in an unstable/uncertain economy?', "\n\nI believe it's extremely helpfu

5it [00:59, 11.97s/it]


KeyboardInterrupt: 

In [None]:
# Load the generated (prompt, response_a, response_b) rows back from the CSV.
df = pd.read_csv(save_path)
# Bare last expression: the notebook renders the frame as the cell output.
df