In [48]:
from datasets import load_from_disk, Dataset
import evaluate

full_ds = load_from_disk("data/chat_history")

# Score only the held-out test split, so the comparison cannot reward a model
# for having memorised its training conversations.
sample_data = (
    full_ds["test"]
    .to_pandas()
    .rename(columns={"text": "ground_truth"})
    .sample(n=100, random_state=1)
)

print(len(sample_data))
sample_data.columns


100




Index(['prompt', 'ground_truth', 'input_ids', 'attention_mask', 'labels'], dtype='object')

In [49]:
# (display name, checkpoint path or hub id) pairs, evaluated in this order.
# The display names become DataFrame column names later, so building from a
# dict literal keeps them unique; .items() preserves insertion order.
models_to_compare = list(
    {
        "t5-lg-base": "google/flan-t5-large",
        "t5-lg-finetuned": "models/flan-t5-large-10ep/checkpoint-42000",
        "t5-xl-base": "google/flan-t5-xl",
        "t5-xl-finetuned": "models/flan-t5-xl/checkpoint-38000",
    }.items()
)

In [50]:
from transformers import AutoModelForSeq2SeqLM, pipeline
from transformers.pipelines.pt_utils import KeyDataset
from rich.progress import Progress, MofNCompleteColumn
import torch

import gc

prompts = list(sample_data["prompt"])

# Seed applied before each model's generation so sampled outputs (do_sample=True)
# are reproducible when the cell is re-run.
GENERATION_SEED = 1
# One tokenizer is used for every checkpoint below. NOTE(review): this assumes
# all flan-t5 sizes and the fine-tuned checkpoints share one vocabulary — confirm.
TOKENIZER_NAME = "google/flan-t5-base"

# Pipelines appear to only return iterators when used with a KeyDataset, and a
# KeyDataset has to wrap a Dataset. The prompts are the same for every model,
# so build it once instead of once per model.
prompts_ds = KeyDataset(Dataset.from_dict({"prompt": prompts}), "prompt")

with Progress(*Progress.get_default_columns(), MofNCompleteColumn()) as progress:
    for model_name, model_path in models_to_compare:
        print(f"Loading model {model_name} from {model_path}...")
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
        pipe = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=TOKENIZER_NAME,
            device='cuda:0',
        )

        print(f"Running inference for {model_name}...")
        # Reseed per model: each model starts from the same RNG state, and the
        # results no longer depend on how many samples earlier models drew.
        torch.manual_seed(GENERATION_SEED)
        try:
            # The pipeline yields one item per prompt: a list of dicts, one per
            # sampled candidate (num_return_sequences of them).
            candidates_per_prompt = progress.track(
                pipe(
                    prompts_ds,
                    max_length=50,
                    do_sample=True,
                    top_p=0.9,
                    repetition_penalty=2.0,
                    num_return_sequences=5,
                    batch_size=4,
                ),
                description=model_name,
                total=len(prompts),
            )
            all_predictions = [
                [candidate["generated_text"] for candidate in candidates]
                for candidates in candidates_per_prompt
            ]
        finally:
            # Release the GPU copy of the weights even if inference failed, then
            # drop every reference (the pipeline holds one too) so the previous
            # model does not sit in host RAM while the next one loads.
            model.cpu()
            del pipe, model
            gc.collect()
            torch.cuda.empty_cache()

        sample_data[model_name] = all_predictions

Output()

In [51]:
from dotenv import load_dotenv
from diskcache import Cache
import time
from rich.progress import track

# Load the OpenAI API key from the .env file into the process environment.
load_dotenv()

# Imported only after load_dotenv(), presumably so the API key is already in the
# environment when the openai module picks it up. NOTE(review): confirm whether
# the installed openai version reads the key at import time. If it does, this
# ordering is load-bearing, and the import must not be hoisted into the import cell.
import openai

# On-disk store backing the @cache.memoize decorator below. Completions that
# were already fetched are reused on re-runs instead of re-querying the API.
cache = Cache("data/openai_cache")

@cache.memoize("oai_infer_dv3_n5_mt200")
def openai_inference(prompt):
    """Return five text-davinci-003 completions (max 200 tokens each) for ``prompt``.

    Memoized on disk through ``cache``. The memoize key is built from the
    memoize name and the call arguments only, not from this function's body.
    The name appears to encode engine / n / max_tokens, so change it whenever
    those change; otherwise stale cached completions will be returned.

    Rate-limit errors are retried with exponential backoff, so one throttled
    request does not abort a long batch of prompts.

    Args:
        prompt: Prompt string sent to the Completion endpoint.

    Returns:
        List of completion texts, in the order the API returned the choices.

    Raises:
        openai.error.RateLimitError: if the request is still rate-limited after
            the final attempt.
    """
    max_attempts = 5
    for attempt in range(max_attempts):
        try:
            completions = openai.Completion.create(
                engine="text-davinci-003", prompt=prompt, n=5, max_tokens=200
            )
            break
        except openai.error.RateLimitError:
            if attempt == max_attempts - 1:
                raise
            # Exponential backoff between attempts: 2, 4, 8, 16 seconds.
            time.sleep(2 ** (attempt + 1))

    # Pace consecutive uncached calls to stay under the rate limit. Cache hits
    # never execute this body, so they are not slowed down.
    time.sleep(1)
    return [completion["text"] for completion in completions["choices"]]


# One list of davinci-003 completions per prompt. The rows line up with
# sample_data because `prompts` was built from sample_data["prompt"].
sample_data["davinci-003"] = [openai_inference(prompt) for prompt in track(prompts)]


Output()

In [52]:
import pprint
from IPython.display import display, HTML

# Table columns: every local model's display name (first element of each
# models_to_compare entry), followed by the OpenAI baseline.
model_cols = [entry[0] for entry in models_to_compare] + ["davinci-003"]
sample_formatted = sample_data[["prompt", "ground_truth", *model_cols]].copy()

# Each model cell holds a list of candidate generations. Flatten it to one
# string, with a dashed rule between candidates so they read as separate entries.
candidate_separator = "\n" + "-" * 20 + "\n"
for col in model_cols:
    sample_formatted[col] = sample_formatted[col].map(candidate_separator.join)

def df_to_html(df):
    """Render ``df`` as a styled HTML table for side-by-side reading of long text.

    Cell text is HTML-escaped, so prompts and generations containing ``<``,
    ``>`` or ``&`` display literally instead of being parsed as markup.
    Line breaks inside cells are rendered as ``<br>`` tags, and cell text is
    never truncated.

    Args:
        df: DataFrame whose cells hold (possibly multi-line) strings.

    Returns:
        A string with a ``<style>`` block followed by the ``<table>`` markup.
    """
    import pandas as pd

    # to_html truncates cell text to display.max_colwidth (50 characters by
    # default). Lift the limit locally, so full output does not depend on
    # options set elsewhere in the kernel.
    with pd.option_context("display.max_colwidth", None):
        table_html = df.to_html(index=False, escape=True)

    # pandas' formatter writes embedded newlines as a literal backslash-n pair.
    # Swap those for HTML line breaks. This happens after escaping, so the
    # inserted <br> tags stay real markup.
    table_html = table_html.replace("\\n", "<br>")

    table_styling = """
    <style>
        table {
            /* Make all columns the same width */
            table-layout: fixed;
            background-color: #f5f5f5;
            
            /* break long urls */
            word-wrap: break-word;
            word-break: break-word;
        }
        th, td {
            max-width: 300px;
            text-align: left;
            vertical-align: top;
            font-family: sans-serif;
        }
    </style>
    """

    return table_styling + table_html

# Preview the first two rows inline.
display(HTML(df_to_html(sample_formatted.head(2))))

# Write the full table as a standalone page. The chat text contains emoji, which
# the platform's default encoding may not represent. Write UTF-8 explicitly and
# declare it with a <meta charset> so browsers decode the file correctly.
with open("data/table.html", "w", encoding="utf-8") as f:
    f.write('<meta charset="utf-8">\n' + df_to_html(sample_formatted))


prompt,ground_truth,t5-xl-base,t5-xl-finetuned,t5-lg-base,t5-lg-finetuned,davinci-003
"Respond to chat as Kyle Corbitt: Recent messages: Jacqueline Corbitt: Vanessa's cousin Mikey wants everyone to know that his T-Rex is bigger than Vanessa 😁 David Corbitt: Lucky for him she can't walk yet David Corbitt: Merry Christmas Scott Corbitt: Shoulda been Bon Nadal my bad Scott Corbitt: Hey everybody, you keep promising to post pictures. Let's make that happen (so I can repost some of them to Instagram) Christina Corbitt: Merry Christmas 🎄 Jacqueline Corbitt: Baby's first stocking Jacqueline Corbitt: Vanessa loves her Uncle Kevin, much to the jealousy of Uncle Kevin's own child 😂 Christina Corbitt: We have a white Christmas in Washington 😌 Karen Corbitt: Oh my. Are the roads ok? Kyle Corbitt:",Wow nice!!,"The roads are mostly fine. -------------------- Karen Corbitt: yes, the roads were cleared yesterday afternoon -------------------- Jacqueline Corbitt: Most of them are. We don't have too much traffic compared to last year, but it's not good. -------------------- Yes, they are mostly good. -------------------- No, we had problems with some streets.",Washington state is a fun place! -------------------- Pretty good -------------------- Pretty good -------------------- Seattle is a pretty good example of how not to treat road salt -------------------- Pretty good,"I just found out. -------------------- Christina Corbitt: You guys are awesome. -------------------- Christina Corbitt: No. It's just frozen road ice. -------------------- Christmas -------------------- Karen Corbitt: I can't believe it's snowing, and I want the roads to be safe.","Well, that's another reason to love Washington -------------------- Yep, aren't too worried about the roads here at all. Not sure if they should be closed because of the weather though -------------------- Yeah pretty good for Washington! -------------------- yes, they're ok! -------------------- I mean, everybody said they’d be fine for the whole trip, right?","Merry Christmas everyone! 
And Karen, I heard the roads should be okay, but drive safely! -------------------- Hey everyone! Wow, what a great conversation! I'm sure Mikey's T-Rex is huge! And yeah, let's post some pics so the rest of the family can see what we're up to. It sounds like Vanessa already has quite the fan club! As for the snow out in Washington, Karen, I'm sure it looks beautiful, but please be careful if you're out on the roads! 🤗 -------------------- Ha! That's funny, Jacqueline. Merry Christmas, everyone! Let's see those pictures! Karen, yes, the roads have been cleared. -------------------- Merry Christmas to all! It sounds like everyone is having a great holiday. And yes Karen, the roads in Washington should be OK - mostly just snowing here in New York. -------------------- Hey everyone! Haha, I love the stories about Vanessa and Uncle Kevin. Definitely wanna see those photos! Karen, roads here in Phoenix are fine. Wish you all a very Merry Christmas!"
Respond to chat as Kyle Corbitt: Recent messages: David Corbitt: I am really excited to try this. They say 46 minutes of flight time but if it works for 30 I'll be happy Kyle Corbitt: nice Kyle Corbitt:,been thinking about buying a drone forever,"will you come up? -------------------- David Corbitt: They have a lot of great plans but the flights are so expensive. -------------------- David Corbitt: The only thing missing is me! -------------------- David Corbitt: I don't know about the price, but I am sure it will be a good trip. -------------------- David Corbitt: yup, this looks funky.","can you send me a screenshot when it starts recording? -------------------- if you think it works will let us know! -------------------- let's try it! -------------------- how was the experience? -------------------- ok 30 isn't that bad, I guess?",glad to know you are having fun -------------------- no problem. -------------------- is it the one from taht you love? -------------------- I'll take it -------------------- ok,"if they don't give you the option for an actual phone, would that work better? -------------------- what's your plan to get there? -------------------- what do they call it? -------------------- where do you want to fly? -------------------- how long will it take you to get from chicago-san bruno airport?","Sounds great, David! I'm excited to hear how it works out for you! -------------------- I'm excited to try this too! It sounds like a fun experience. Let's hope it works for at least the full 46 minutes! -------------------- That sounds great, David. You should definitely give it a try and let us know how it goes! -------------------- That sounds great, David! I'm looking forward to testing it out -- let's aim for at least 30 minutes of flight time and see how it goes. -------------------- That sounds great! I'm sure you'll be able to get at least 30 minutes out of it. Good luck with your flight!"


In [53]:
import numba