In [None]:
import os
import re
import stanza
import pandas as pd
from tqdm import tqdm
from nltk import Tree

In [None]:
# Author's note: NEEDS PYTORCH 2.1.2
# English Stanza pipeline limited to tokenization, POS tagging and
# constituency parsing.
# Alternative processors string kept for reference:
#   'tokenize,mwt,pos,lemma,depparse'
parser = stanza.Pipeline(lang="en", processors="tokenize,pos,constituency")

In [None]:
def natural_language_to_dyck(sentence, parser):
    """Convert text into the bracket skeleton of its constituency parse.

    Args:
        sentence: Text to parse.
        parser: Callable that takes the text and returns a document whose
            ``sentences`` each expose a ``constituency`` tree. ``str()`` of
            that tree is expected to be a bracketed string.

    Returns:
        A string made up only of ``(`` and ``)``. When the parser splits the
        text into several sentences, their skeletons are concatenated in
        order; when it finds no sentence at all, the result is ``""``.
    """
    doc = parser(sentence)
    # Walk every parsed sentence: indexing only sentences[0] raises IndexError
    # when nothing was parsed and silently drops any sentence after the first.
    return "".join(
        re.sub(r"[^()]", "", str(parsed.constituency))
        for parsed in doc.sentences
    )

def max_nested_depth(dyck_sentence):
    """Return the deepest bracket nesting level reached in ``dyck_sentence``.

    Scans the string left to right, adding one for every ``(`` and
    subtracting one for every ``)``; any other character is ignored.
    The result is the largest running total seen, and ``0`` if the total
    never rises above zero.
    """
    delta = {"(": 1, ")": -1}
    deepest = 0
    level = 0
    for symbol in dyck_sentence:
        level += delta.get(symbol, 0)
        if level > deepest:
            deepest = level
    return deepest

# Quick check of the conversion on a short sentence. As the last expression
# in the cell, the returned bracket string is shown as the cell output.
natural_language_to_dyck("i love you", parser)


In [None]:
# Run configuration: these values are interpolated into the input and output
# file paths used by the cells that follow.
task = "recipes"
model = "eb6_d6_c128_lr8e-4"
checkpoint = "30000"
# Derived from `task` so the directory cannot drift from the task name.
data_path = f"my_output/{task}"

In [None]:
# Read the generated text for the configured model and checkpoint, one list
# entry per line (trailing newlines are kept by readlines()).
# The encoding is stated explicitly so decoding does not depend on the
# platform's default locale.
with open(f"{data_path}/{model}/gen/dev_{checkpoint}.gen", encoding="utf-8") as f:
    gens = f.readlines()

In [None]:
# Build one record per "- "-separated fragment of each generated line:
# the fragment text, its bracket string, that string's maximum nesting depth
# and its length in characters.
results = []
subset = gens[:10000]  # only the first 10,000 generated lines are analysed
for line in tqdm(subset, total=len(subset)):
    for fragment in line.split("- "):
        fragment = fragment.strip()
        if len(fragment) <= 2:
            # Skip empty or near-empty fragments left over from the split.
            continue
        # Parse once and reuse the result; the original ran the parser a
        # second time on the same text just to compute max_depth.
        dyck_sentence = natural_language_to_dyck(fragment, parser)
        results.append({
            "sentence": fragment,
            "dyck": dyck_sentence,
            "max_depth": max_nested_depth(dyck_sentence),
            "length": len(dyck_sentence),
        })

In [None]:
# Persist the per-fragment records as CSV under the model's output folder.
df = pd.DataFrame(results)
dyck_dir = f"{data_path}/{model}/dyck"
# to_csv does not create missing directories, so make sure the target exists.
os.makedirs(dyck_dir, exist_ok=True)
df.to_csv(f"{dyck_dir}/dev_{checkpoint}.csv", index=False)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Keep only the two columns being compared.
df = pd.DataFrame(results)
df = df[["length", "max_depth"]]

# Count how many records fall into each (length, max_depth) cell.
# crosstab counts co-occurrences directly, so it does not rely on a separate
# value column being available to aggregate.
heatmap_data = pd.crosstab(df["length"], df["max_depth"])
mask = heatmap_data != 0  # True where at least one record was observed


def _ticks_every(values, step):
    """Return (positions, labels) for the entries of ``values`` divisible by ``step``.

    Heatmap cells are laid out by position in the table, not by value, and
    each cell is centred at ``position + 0.5``. Computing ticks from the
    actual values keeps every label on the row/column it describes, even
    when the observed values are not a contiguous range starting at zero.
    """
    hits = [(pos + 0.5, val) for pos, val in enumerate(values) if val % step == 0]
    return [pos for pos, _ in hits], [val for _, val in hits]


fig, ax = plt.subplots(figsize=(10, 8))
# Reversed viridis so that higher counts are darker; cells with a zero count
# are masked out and left blank.
sns.heatmap(heatmap_data, cmap='viridis_r', fmt='g', mask=~mask, linewidths=.5, ax=ax)

# Label every depth that is a multiple of 5 and every length that is a
# multiple of 10. The font size is set together with the labels so it is not
# lost when the ticks are replaced.
x_positions, x_labels = _ticks_every(heatmap_data.columns, 5)
y_positions, y_labels = _ticks_every(heatmap_data.index, 10)
ax.set_xticks(x_positions)
ax.set_xticklabels(x_labels, fontsize=14)
ax.set_yticks(y_positions)
ax.set_yticklabels(y_labels, fontsize=14)

ax.set_xlabel('Max Depth', fontsize=16)
ax.set_ylabel('Length', fontsize=16)

plt.show()