In [None]:
import os
import re
import shlex
import subprocess

import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import Video, display

In [None]:
# Which training run to inspect (relative path) and which saved epoch of it.
checkpoint_dir: str = "_checkpoints/experiment_20231117_171147/"
epoch: int = 70  # epoch whose evaluation tables and videos are shown below

In [None]:
# Evaluate this checkpoint only if its per-epoch tables are missing. The next
# cell reads tables/epoch_{epoch}/*.csv, so the guard must test that exact
# directory: testing only tables/ skipped evaluation whenever a different
# epoch had already been evaluated, and the CSV reads then failed.
epoch_tables_dir = os.path.join(checkpoint_dir, "tables", f"epoch_{epoch}")
if not os.path.exists(epoch_tables_dir):
    # shlex.quote keeps a path containing spaces/metacharacters as one argument;
    # check=True raises CalledProcessError on a non-zero exit instead of
    # silently continuing to cells that need the script's outputs.
    command = (
        f"run_evaluation.sh --checkpoint_dir {shlex.quote(checkpoint_dir)} "
        f"--epoch {epoch}"
    )
    subprocess.run(command, shell=True, check=True)

In [None]:
# Per-epoch evaluation tables produced by the evaluation step.
tables_dir = os.path.join(checkpoint_dir, "tables", f"epoch_{epoch}")
df1 = pd.read_csv(os.path.join(tables_dir, "continuous_masking.csv"))
df2 = pd.read_csv(os.path.join(tables_dir, "random_masking.csv"))


def _plot_incorrect_pixels(ax, table, x_column, x_label, title, x_limits):
    """Draw the filtering and smoothing error curves of ``table`` on ``ax``.

    ``x_column`` selects the swept quantity for the x-axis; the y-values are the
    ``filtering_incorrect_pixels`` / ``smoothing_incorrect_pixels`` columns.
    """
    curves = (
        ("filtering_incorrect_pixels", "Filtering"),
        ("smoothing_incorrect_pixels", "Smoothing"),
    )
    for y_column, curve_label in curves:
        ax.plot(table[x_column], table[y_column], ".-", label=curve_label)
    ax.set_xlabel(x_label)
    ax.set_ylabel("Fraction of Incorrect Pixels")
    ax.set_title(title)
    ax.legend()
    ax.set_xlim(*x_limits)
    ax.grid(True)


# One row, two panels: continuous masking (left) vs. random masking (right).
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
_plot_incorrect_pixels(
    axs[0], df1, "mask_lengths", "Mask Length", "Continuous Masking", (0, 50)
)
_plot_incorrect_pixels(
    axs[1], df2, "dropout_probabilities", "Dropout Probability", "Random Masking", (0, 1)
)
plt.show()

In [None]:
def get_maximum_index(checkpoint_dir, epoch):
    """Return the largest data index among the rendered videos of one epoch.

    Scans ``<checkpoint_dir>/videos/epoch_<epoch>/`` for files whose whole name
    is ``idx_<index>_mask_length_<length>.mp4`` and returns the highest
    ``<index>``. Any other file in the directory is ignored.

    Args:
        checkpoint_dir: Experiment checkpoint directory.
        epoch: Epoch whose video directory is scanned.

    Returns:
        The maximum index, as an ``int``.

    Raises:
        FileNotFoundError: If the video directory does not exist.
        ValueError: If the directory contains no matching video files.
    """
    directory = os.path.join(checkpoint_dir, "videos", f"epoch_{epoch}")
    # fullmatch (not search) so names that merely *contain* the pattern, such
    # as partial files ("....mp4.part") or prefixed copies, are not counted.
    pattern = re.compile(r"idx_(\d+)_mask_length_\d+\.mp4")
    indices = [
        int(match.group(1))
        for match in map(pattern.fullmatch, os.listdir(directory))
        if match
    ]
    if not indices:
        # Explicit message instead of max()'s opaque "empty sequence" error.
        raise ValueError(
            f"No 'idx_<n>_mask_length_<m>.mp4' videos found in {directory!r}"
        )
    return max(indices)

In [None]:
# Show the rendered videos for every data index of this epoch. get_maximum_index
# returns the largest index on disk and range() excludes its stop value, so +1
# is needed or the last data index is never shown.
videos_dir = os.path.join(checkpoint_dir, "videos", f"epoch_{epoch}")
mask_lengths_to_show = (10, 20, 30, 40)
for idx in range(get_maximum_index(checkpoint_dir, epoch) + 1):
    for mask_length in mask_lengths_to_show:
        print(f"Data index: {idx}")
        print(f"Mask length: {mask_length}")
        video_path = os.path.join(
            videos_dir, f"idx_{idx}_mask_length_{mask_length}.mp4"
        )
        if not os.path.exists(video_path):
            # Not every (index, mask length) pair is guaranteed to exist; report
            # the gap rather than passing a non-existent path to Video().
            print(f"Video not found: {video_path}")
            continue
        display(Video(video_path))