In [2]:
from datasets import load_from_disk, Dataset
import evaluate

full_ds = load_from_disk("data/chat_history")

# Evaluate only on the held-out test split so the comparison reflects
# generalization rather than memorized training examples.
N_SAMPLES = 100  # number of test conversations to compare models on
SAMPLE_SEED = 1  # fixed seed so every run compares on the same sample

test_df = full_ds["test"].to_pandas().rename(columns={"text": "ground_truth"})

# Clamp to the split size: DataFrame.sample(n=...) raises ValueError when n
# exceeds the number of rows (sampling is without replacement by default).
sample_data = test_df.sample(n=min(N_SAMPLES, len(test_df)), random_state=SAMPLE_SEED)

print(len(sample_data))
sample_data.columns


100




Index(['prompt', 'ground_truth', 'input_ids', 'attention_mask', 'labels'], dtype='object')

In [3]:
from transformers import AutoModelForSeq2SeqLM, pipeline
from transformers.pipelines.pt_utils import KeyDataset
from rich.progress import track


def create_inference_function(model, tokenizer="google/flan-t5-base", device=0):
    """Wrap a seq2seq model in a batched text2text-generation pipeline.

    Args:
        model: A model object (or model id/path) accepted by
            ``transformers.pipeline``.
        tokenizer: Tokenizer id/path handed to the pipeline. Defaults to the
            flan-t5-base tokenizer, which the flan-t5 checkpoints compared in
            this notebook are assumed to share (TODO confirm for other models).
        device: Device handed to the pipeline (e.g. 0 for the first GPU,
            -1 for CPU).

    Returns:
        ``run_inference(prompts, **kwargs)``: a generator function that yields,
        for each prompt in order, the list of generated strings (one per
        returned sequence). Keyword arguments override the default generation
        settings below, e.g. ``run_inference(prompts, batch_size=16)``.
    """
    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )

    # Default generation / batching settings; any of them may be overridden
    # by keyword arguments to run_inference.
    default_kwargs = dict(
        max_length=200,
        num_beams=5,
        repetition_penalty=3.0,
        num_return_sequences=5,
        batch_size=6,
    )

    def run_inference(prompts, **kwargs):
        # pipelines appear to only return iterators when used with a KeyDataset,
        # and we have to create a Dataset to create a KeyDataset
        prompts_ds = KeyDataset(Dataset.from_dict({"prompt": prompts}), "prompt")

        # Merge instead of passing defaults and **kwargs side by side: passing a
        # key such as batch_size both ways raises "got multiple values for
        # keyword argument". Caller-supplied values win.
        all_predictions = pipe(prompts_ds, **{**default_kwargs, **kwargs})
        for predictions in all_predictions:
            yield [p["generated_text"] for p in predictions]

    return run_inference


# Quick smoke test (uncomment to run); extra kwargs override the defaults.
# test_fn = create_inference_function(
#     AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
# )
# test_output = list(track(
#     test_fn(sample_data["prompt"].tolist(), batch_size=16),
#     total=len(sample_data),
# ))


In [4]:
# Fine-tuned checkpoints to compare, keyed by the column name their
# predictions will get later on.
checkpoint_paths = {
    # "t5-lg": "google/flan-t5-large",
    "t5-lg-29000": "models/flan-t5-large-telegram/checkpoint-29000",
    "t5-lg-42000": "models/flan-t5-large-10ep/checkpoint-42000",
}

# name -> generator function mapping a list of prompts to lists of candidate
# replies. Built in one step so the variable always holds inference functions.
models_to_compare = {
    name: create_inference_function(AutoModelForSeq2SeqLM.from_pretrained(path))
    for name, path in checkpoint_paths.items()
}


In [5]:
from dotenv import load_dotenv
from diskcache import Cache
import time

# Load environment variables (including the OpenAI API key) from a .env file.
load_dotenv()

# NOTE(review): imported after load_dotenv() on purpose — presumably so the
# openai module sees OPENAI_API_KEY in the environment when it is imported.
# Keep this ordering.
import openai

# On-disk cache in data/openai_cache; it backs the @cache.memoize decorator on
# openai_inference below so repeated prompts are served from disk.
cache = Cache("data/openai_cache")


@cache.memoize("oai_infer_dv3_n5_mt200")
def openai_inference(prompt):
    """Return 5 text-davinci-003 completions (up to 200 tokens each) for `prompt`.

    Memoized on disk under the name "oai_infer_dv3_n5_mt200", so repeated
    calls with the same prompt are answered from the cache.
    """
    response = openai.Completion.create(
        engine="text-davinci-003",
        prompt=prompt,
        n=5,
        max_tokens=200,
    )
    choices = response["choices"]
    return [choice["text"] for choice in choices]


def openai_batch_inference(prompts, delay=1.0):
    """Yield ``openai_inference(prompt)`` for each prompt, in order.

    After each prompt that actually needed an API call, sleep ``delay``
    seconds as a crude rate limit. Prompts already in the disk cache are
    yielded without sleeping, so re-running over a cached sample is fast.
    """
    for prompt in prompts:
        # Check before calling: afterwards the result is cached and we could
        # no longer tell whether the API was hit.
        needs_api_call = openai_inference.__cache_key__(prompt) not in cache
        yield openai_inference(prompt)
        if needs_api_call:
            time.sleep(delay)


# Smoke-test the API key and endpoint once before the long batch run, and check
# that the requested number of completions (n=5) comes back.
_smoke_completions = openai_inference("a man a plan a canal")
assert len(_smoke_completions) == 5, (
    f"expected 5 completions, got {len(_smoke_completions)}"
)

models_to_compare["openai"] = openai_batch_inference


In [6]:
from rich.progress import Progress, MofNCompleteColumn

prompts = sample_data["prompt"].tolist()
n_prompts = len(prompts)
predictions = {}

# Run every model over the same prompts, with one progress bar per model.
# Results are stored as they finish, so earlier models survive an interrupt.
with Progress(*Progress.get_default_columns(), MofNCompleteColumn()) as progress:
    for name, generate in models_to_compare.items():
        tracked = progress.track(generate(prompts), description=name, total=n_prompts)
        predictions[name] = list(tracked)

# Add one column per model; each cell holds that model's list of candidates.
sample_data = sample_data.assign(**predictions)


Output()

In [8]:
import pprint
from IPython.display import display, HTML

# Columns to show: prompt, reference reply, then one column per model.
display_columns = ["prompt", "ground_truth", *predictions]
sample_formatted = sample_data[display_columns].copy()

# Flatten each model's candidate list into one string, separating candidates
# with a dashed line.
for model_col in predictions:
    sample_formatted[model_col] = sample_formatted[model_col].apply(
        f"\n{'-' * 20}\n".join
    )

def df_to_html(df):
    """Render ``df`` as a styled HTML table for side-by-side reading.

    Cell text is HTML-escaped, in-cell newlines become ``<br>``, long strings
    are rendered in full, and a ``<style>`` block is prepended. The style block
    sets fixed-layout columns, wraps long words and top-left aligns cells.

    Args:
        df: DataFrame to render; its index is omitted.

    Returns:
        str: the ``<style>`` block followed by the ``<table>`` markup.
    """
    import pandas as pd

    # to_html truncates cells longer than display.max_colwidth with "...";
    # lift the limit so full chat messages and model outputs are shown.
    with pd.option_context("display.max_colwidth", None):
        # escape=True: cells hold raw chat text and model outputs that may
        # contain <, > or &, which would otherwise be parsed as markup.
        # pandas writes in-cell newlines (object-dtype text) as the two
        # characters "\n"; the replace below runs after escaping, so the
        # inserted <br> tags stay live.
        table_html = df.to_html(index=False, escape=True).replace("\\n", "<br>")

    table_styling = """
    <style>
        table {
            /* Make all columns the same width */
            table-layout: fixed;
            background-color: #f5f5f5;

            /* break long urls */
            word-wrap: break-word;
            word-break: break-word;
        }
        th, td {
            max-width: 300px;
            text-align: left;
            vertical-align: top;
            font-family: sans-serif;
        }
    </style>
    """

    return table_styling + table_html

# Preview the first two rows inline.
display(HTML(df_to_html(sample_formatted.head(2))))

# Save the full comparison table. The chat text contains emoji, so write UTF-8
# explicitly rather than relying on the platform default encoding, which may
# not be able to encode them. Also declare the charset so browsers decode it
# correctly.
with open("data/table.html", "w", encoding="utf-8") as f:
    f.write('<meta charset="utf-8">\n' + df_to_html(sample_formatted))


prompt,ground_truth,t5-lg-29000,t5-lg-42000,openai
"Respond to chat as Kyle Corbitt: Recent messages: Jacqueline Corbitt: Vanessa's cousin Mikey wants everyone to know that his T-Rex is bigger than Vanessa 😁 David Corbitt: Lucky for him she can't walk yet David Corbitt: Merry Christmas Scott Corbitt: Shoulda been Bon Nadal my bad Scott Corbitt: Hey everybody, you keep promising to post pictures. Let's make that happen (so I can repost some of them to Instagram) Christina Corbitt: Merry Christmas 🎄 Jacqueline Corbitt: Baby's first stocking Jacqueline Corbitt: Vanessa loves her Uncle Kevin, much to the jealousy of Uncle Kevin's own child 😂 Christina Corbitt: We have a white Christmas in Washington 😌 Karen Corbitt: Oh my. Are the roads ok? Kyle Corbitt:",Wow nice!!,"Yeah they're ok -------------------- yep, they're pretty good -------------------- yep, they're all good -------------------- yep -------------------- yep!",Yeah they're pretty ok here -------------------- Yeah they're pretty ok right now -------------------- Yeah they're ok -------------------- Yeah they're pretty ok -------------------- Yeah they're all good here,"Merry Christmas everyone! And Karen, I heard the roads should be okay, but drive safely! -------------------- Hey everyone! Wow, what a great conversation! I'm sure Mikey's T-Rex is huge! And yeah, let's post some pics so the rest of the family can see what we're up to. It sounds like Vanessa already has quite the fan club! As for the snow out in Washington, Karen, I'm sure it looks beautiful, but please be careful if you're out on the roads! 🤗 -------------------- Ha! That's funny, Jacqueline. Merry Christmas, everyone! Let's see those pictures! Karen, yes, the roads have been cleared. -------------------- Merry Christmas to all! It sounds like everyone is having a great holiday. And yes Karen, the roads in Washington should be OK - mostly just snowing here in New York. -------------------- Hey everyone! Haha, I love the stories about Vanessa and Uncle Kevin. Definitely wanna see those photos! 
Karen, roads here in Phoenix are fine. Wish you all a very Merry Christmas!"
Respond to chat as Kyle Corbitt: Recent messages: David Corbitt: I am really excited to try this. They say 46 minutes of flight time but if it works for 30 I'll be happy Kyle Corbitt: nice Kyle Corbitt:,been thinking about buying a drone forever,if it works for 30 I'll be happy -------------------- if it works for 30 I'd be happy -------------------- if it works for 30 that's great -------------------- if it works for 30 I'm happy -------------------- what's the price?,if it works for 30 that's good enough for me -------------------- if it works for 30 that's great -------------------- if it works for 30 that's good enough for me! -------------------- if it works for 30 I'm happy -------------------- if it works for 30 minutes that's good enough for me,"Sounds great, David! I'm excited to hear how it works out for you! -------------------- I'm excited to try this too! It sounds like a fun experience. Let's hope it works for at least the full 46 minutes! -------------------- That sounds great, David. You should definitely give it a try and let us know how it goes! -------------------- That sounds great, David! I'm looking forward to testing it out -- let's aim for at least 30 minutes of flight time and see how it goes. -------------------- That sounds great! I'm sure you'll be able to get at least 30 minutes out of it. Good luck with your flight!"
