In [31]:
from pathlib import Path
from collections import defaultdict

import altair as alt
import pandas
import matplotlib.pyplot as plt
from IPython.display import display, HTML

# Running maxima gathered from the dump file names. All three stay 0 when
# the "output" directory yields no *.tsv files.
max_score = 0
num_nodes = 0
seq_len = 0

# Each file stem is split on '.'; the first three pieces are read as
# g<int>, s<int> and score<int> respectively.
for fname in Path("output").glob("*.tsv"):
    parts = fname.stem.split('.')

    node_count = int(parts[0].replace("g", ""))
    if node_count > num_nodes:
        num_nodes = node_count

    length = int(parts[1].replace("s", ""))
    if length > seq_len:
        seq_len = length

    score = int(parts[2].replace("score", ""))
    max_score = max(max_score, score)

def create_dp_chart(df):
    """Build a faceted heatmap of `score` over the (`offset`, `rank`) grid.

    Two layers are drawn from the same data: coloured rectangles and a text
    label showing the score in each cell. The result is faceted into one
    column per `state` value.

    Reads the module-level `seq_len`, `num_nodes` and `max_score` to fix the
    axis and colour domains, so charts built from different frames share the
    same scales.

    Parameters:
        df: data with the columns `offset`, `rank`, `score`, `k`, `prev`
            and `state`.

    Returns:
        The faceted Altair chart.
    """
    def grid_encodings():
        # Fresh encoding objects for each layer, mirroring the two
        # independent definitions both layers need.
        return {
            'x': alt.X(
                'offset:O',
                scale=alt.Scale(domain=list(range(0, seq_len+1))),
            ),
            'y': alt.Y(
                'rank:O',
                scale=alt.Scale(domain=list(range(0, num_nodes+1))),
            ),
            'tooltip': ['k:O', 'offset:O', 'rank:O', 'prev:N'],
        }

    cell_colour = alt.Color('score:Q', scale=alt.Scale(domain=(0, max_score)))
    heatmap = alt.Chart(df).mark_rect().encode(
        color=cell_colour,
        **grid_encodings()
    )

    # Label colour flips at half of the colour domain so the text stays
    # readable against the cell underneath.
    in_lower_half = alt.datum.score < (max_score / 2)
    label_colour = alt.condition(
        in_lower_half,
        alt.value('black'),
        alt.value('white')
    )
    labels = alt.Chart(df).mark_text(baseline='middle').encode(
        text=alt.Text('score:Q'),
        color=label_colour,
        **grid_encodings()
    )

    layered = heatmap + labels
    sized = layered.properties(width=300, height=300)
    return sized.facet(column='state:N')

def create_fr_chart(df):
    """Build a faceted heatmap of `offset` over the (`k`, `score`) grid.

    Two layers are drawn from the same data: rectangles coloured by `offset`
    and a text label showing the offset in each cell. The result is faceted
    into one column per `state` value.

    Reads the module-level `seq_len` to fix the colour domain at
    (0, seq_len + 1).

    Parameters:
        df: data with the columns `k`, `score`, `offset`, `rank`, `prev`
            and `state`.

    Returns:
        The faceted Altair chart.
    """
    # Same hover fields on both layers, so the tooltip does not change
    # depending on whether the pointer is over the rectangle or its label.
    hover_fields = ['k:O', 'offset:O', 'rank:O', 'prev:N']

    heatmap = alt.Chart(df).mark_rect().encode(
        x=alt.X('k:O'),
        y=alt.Y('score:O'),
        color=alt.Color('offset:Q', scale=alt.Scale(domain=(0, seq_len+1))),
        tooltip=hover_fields
    )
    labels = alt.Chart(df).mark_text(baseline='middle').encode(
        alt.Text('offset:Q'),
        x=alt.X('k:O'),
        y=alt.Y('score:O'),
        # The cell background encodes `offset`, so the label contrast must be
        # decided on `offset` too, at the midpoint of that colour domain.
        # Testing `score` here would colour labels per row, not per cell.
        # NOTE(review): which half of the domain renders dark depends on the
        # active colour scheme; confirm black/white are the right way round.
        color=alt.condition(
            alt.datum.offset < ((seq_len + 1) / 2),
            alt.value('black'),
            alt.value('white')
        ),
        tooltip=hover_fields
    )

    return (heatmap + labels).properties(
        width=300,
        height=300
    ).facet(column='state:N')

# Render the dumps for every even score, in increasing order, stopping once
# the score exceeds 70. A dump is only shown when its file exists.
for i in range(0, max_score + 1, 2):
    if i > 70:
        break

    fname_before = Path(
        "output",
        f"g{num_nodes}.s{seq_len}.score{i}.before_extend.tsv",
    )
    if fname_before.is_file():
        before_extend = pandas.read_csv(fname_before, sep='\t')
        display(HTML(f"<h3>Score {i}, before extend</h3>"))
        for build_chart in (create_dp_chart, create_fr_chart):
            display(build_chart(before_extend))

    fname_after = Path(
        "output",
        f"g{num_nodes}.s{seq_len}.score{i}.after_extend.tsv",
    )
    if fname_after.is_file():
        after_extend = pandas.read_csv(fname_after, sep='\t')
        display(HTML(f"<h3>Score {i}, after extend</h3>"))
        for build_chart in (create_dp_chart, create_fr_chart):
            display(build_chart(after_extend))
        # The separator is emitted only after an "after extend" section.
        display(HTML("<hr>"))

    