In [None]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import sys
from IPython.display import display, HTML
from typing import List
from mprompt.modules.emb_diff_module import EmbDiffModule
import numpy as np
import matplotlib
import imodelsx.util
import re
import scipy.special
from spacy.tokenizer import Tokenizer
from spacy.lang.en import English

# Get prompts

In [None]:
# Topics the story moves through, in order; each one gets its own prompt and paragraph.
expls = ['baseball', 'animals', 'water']

In [None]:
# Prompt templates; the `{expl}` placeholder is filled with a topic via str.format(expl=...).
# `prompt_init` opens the story, `prompt_continue` asks for a follow-up paragraph on a new topic.
prompt_init = (
    'Write the beginning paragraph of a story about {expl}. '
    'Make sure it contains several references to {expl}.'
)
prompt_continue = (
    'Write the next paragraph of the story, '
    'but now make it about {expl}.'
)

In [None]:
# The first topic uses the opening template; every later topic uses the continuation template.
prompts = [prompt_init.format(expl=expls[0])]
prompts.extend(prompt_continue.format(expl=topic) for topic in expls[1:])

In [None]:
# Show the final prompts, one per line, as a sanity check of the templates.
# Note: the loop variable `prompt` stays bound in the kernel namespace afterwards.
for prompt in prompts:
    print(prompt)

# Get data

In [None]:
# Story paragraphs, pasted in by hand from a chatbot for now
# (presumably the responses to `prompts` -- TODO confirm / replace with an API call).
# Order matters: paragraphs[i] is paired with expls[i] and prompts[i] further down.
paragraphs = [
    "The crack of the bat echoed through the stadium as the pitcher windmilled his arm and fired a fastball down the middle of the plate. The batter's eyes lit up as he swung with all his might, sending the ball sailing high into the sky. The center fielder raced back, tracking the ball's flight with a keen eye, ready to make the catch. The crowd held its breath as the ball descended, coming closer and closer to the fielder's outstretched glove. With a satisfying thud, the ball landed squarely in the pocket, and the center fielder triumphantly jogged off the field. It was just another day at the ballpark, where the crack of the bat and the roar of the crowd were the soundtrack to America's favorite pastime: baseball.",
    "As the center fielder jogged off the field, a family of ducks waddled onto the grass. The mother duck led her ducklings towards a small pond near the outfield, quacking softly to each other. The players watched in amusement as the ducks made themselves at home, seemingly oblivious to the fact that they were in the middle of a baseball game. Suddenly, a squirrel darted across the field, causing the ducks to scatter in all directions. The players laughed as they watched the animals go about their business, momentarily forgetting about the game they were playing. It was a reminder that, despite all the excitement and drama of the sport, the natural world continued to carry on around them.",
    "In the distance, the sound of crashing waves could be heard, a reminder that the stadium was located just a stone's throw away from the ocean. The salty sea air mingled with the smell of freshly cut grass, creating a unique aroma that was both refreshing and invigorating. As the game went on, the temperature began to rise, and fans could be seen fanning themselves with whatever they could find. Suddenly, a gust of wind picked up, and a fine mist sprayed over the crowd, providing some much-needed relief. The players on the field looked up as they felt the cool droplets on their skin, grateful for the natural air conditioning that the ocean breeze provided.",
]

# Visualize data heatmap

### Get embedding distances

In [None]:
# Scoring module from mprompt. Its target is set per topic later on via
# `mod._init_task(...)`, and calling `mod(list_of_strings)` returns one score per string.
# NOTE(review): presumably an embedding-distance based score -- confirm against mprompt.
mod = EmbDiffModule()
# spaCy English language object; in the visible cells only its tokenizer is used.
nlp = English()
nlp.tokenizer = Tokenizer(nlp.vocab, token_match=re.compile(r'\S+').match) # only split on whitespace

In [122]:
def colorize(words: List[str], color_array: np.ndarray[float],
             char_width_max=60, title: str=None, subtitle: str=None):
    '''
    Render a list of words as an HTML string, giving each word a background
    color taken from the viridis colormap.

    words
        the words to display, in order; they are HTML-escaped before insertion
    color_array
        an array of numbers between 0 and 1 of length equal to words
        (if the lengths differ, the extra entries of the longer one are ignored)
    char_width_max
        a line break is emitted after the word that brings the running
        character count of the current line to at least this value
        (only the characters of the words are counted, not the padding)
    title
        optional heading, prepended as an <h3> element (inserted as-is, not escaped)
    subtitle
        optional sub-heading, placed as an <h5> element between the title
        and the words (inserted as-is, not escaped)

    Returns the HTML markup as a str.
    '''
    from html import escape  # local import keeps the cell self-contained

    # matplotlib.cm.get_cmap is deprecated (removed in matplotlib 3.9);
    # the colormap registry is the supported way to look a colormap up.
    cmap = matplotlib.colormaps['viridis']
    template = '<span class="barcode" style="color: black; background-color: {}">{}</span>'
    pieces = []
    char_width = 0
    for word, color in zip(words, color_array):
        char_width += len(word)
        hex_color = matplotlib.colors.rgb2hex(cmap(color)[:3])
        # Escape so that words containing &, < or > are shown literally
        # instead of being parsed as markup.
        padded_word = '&nbsp;' + escape(word, quote=False) + '&nbsp;'
        pieces.append(template.format(hex_color, padded_word))
        if char_width >= char_width_max:
            pieces.append('<br>')
            char_width = 0
    colored_string = ''.join(pieces)

    # The subtitle is prepended first so that the title ends up above it.
    if subtitle:
        colored_string = f'<h5>{subtitle}</h5>\n' + colored_string
    if title:
        colored_string = f'<h3>{title}</h3>\n' + colored_string
    return colored_string

def moving_average(a, n=3):
    '''
    Trailing moving average over a window of n values, same length as the input.

    a
        1-d sequence of numbers (list or numpy array)
    n
        window size, a positive integer

    Returns a float numpy array of len(a). The first n - 1 entries are the raw
    input values (there is no full window yet); entry i >= n - 1 is the mean of
    a[i - n + 1 : i + 1]. If len(a) < n the input values are returned unchanged.

    Raises ValueError if n < 1.
    '''
    if n < 1:
        raise ValueError(f'n must be a positive integer, got {n}')
    # Accept plain lists as well as arrays.
    a = np.asarray(a, dtype=float)
    ret = np.cumsum(a)
    # Difference of cumulative sums n apart = sum over each length-n window.
    ret[n:] = ret[n:] - ret[:-n]
    return np.concatenate([a[: n - 1], ret[n - 1:] / n])

In [123]:
# For every topic: score each word of the matching paragraph with `mod`,
# render the paragraph as a colored HTML heatmap, and collect all of them
# into one HTML file.
out_path = '../results/story_running.html'

story_running = ''
for i, expl in enumerate(expls):
    text = paragraphs[i].lower()
    words = text.split()
    prompt = prompts[i]

    mod._init_task(expl)
    print(mod.target_str)
    ngrams = imodelsx.util.generate_ngrams_list(text, ngrams=3, tokenizer_ngrams=nlp.tokenizer)
    # Pad the front with the 1- and 2-word prefixes so that there is exactly
    # one n-gram (and hence one score) per word; the assert below checks this.
    ngrams = [words[0], words[0] + ' ' + words[1]] + ngrams
    # NOTE(review): the name suggests negative embedding distances, i.e. higher
    # means closer to the topic -- confirm against EmbDiffModule.
    neg_dists = mod(ngrams)
    assert len(ngrams) == len(words) == len(neg_dists)

    # Optional experiments: smooth the scores first with
    # moving_average(neg_dists, n=3), or use scipy.special.softmax(neg_dists)
    # instead of the min-max rescaling below.

    # Min-max rescale to [0, 1]. If every score is identical the range is 0;
    # skip the division then (all values become 0) instead of producing NaNs.
    neg_dists = neg_dists - neg_dists.min()
    score_range = neg_dists.max()
    if score_range > 0:
        neg_dists = neg_dists / score_range
    neg_dists = neg_dists / 2 + 0.5 # shift to 0.5-1 range

    s = colorize(words, neg_dists, title=expl, subtitle=prompt)
    display(HTML(s))

    story_running += ' ' + s

# Make sure the output directory exists and write with an explicit encoding,
# so the result does not depend on the working tree or the platform default.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, 'w', encoding='utf-8') as f:
    f.write(story_running)

baseball


100%|██████████| 5/5 [00:00<00:00,  5.00it/s]


animals


100%|██████████| 4/4 [00:00<00:00,  4.16it/s]


water


100%|██████████| 4/4 [00:00<00:00,  4.80it/s]
