In [15]:
import os

import numpy as np
import matplotlib.pyplot as plt

def plot_segmentation_comparison(gt, asot, omp, comp, asot_mapping=None, name=''):
    """
    Visualize segmentation comparison across Ground Truth, ASOT, OMPN, and CompILE.

    Draws one horizontal strip per method (Truth, ASOT, OMPN, CompILE). Every
    run of identical consecutive labels is drawn as a coloured span bounded by
    black vertical lines; the x position is the frame index divided by the
    sequence length, so each strip covers [0, 1]. Label -1 is always black.

    Parameters:
        gt (list or np.array): Ground truth labels.
        asot (list or np.array): ASOT predicted labels.
        omp (list or np.array): OMPN predicted labels.
        comp (list or np.array): CompILE predicted labels.
        asot_mapping (dict): Mapping from ASOT predicted labels to GT labels.
            ASOT segments are coloured by their mapped label; ASOT labels that
            are missing from the mapping are coloured by their own value.
            Segment boundaries are still taken from the raw ASOT labels.
        name (str): Title for the plot.

    Returns:
        matplotlib.figure.Figure: The created figure.

    Raises:
        AssertionError: If the four sequences do not all have the same length.
        ValueError: If the sequences are empty.
    """
    gt = np.asarray(gt)
    asot = np.asarray(asot)
    omp = np.asarray(omp)
    comp = np.asarray(comp)
    asot_mapping = asot_mapping or {}

    n_frames = len(gt)
    assert all(len(arr) == n_frames for arr in [asot, omp, comp]), "All input arrays must have the same length."
    if n_frames == 0:
        raise ValueError("Cannot plot empty label sequences.")

    # ASOT segments are coloured by their *mapped* label, so the mapped values
    # (not the mapping keys) must receive colour entries. This also covers
    # the no-mapping case, where the raw ASOT labels are drawn directly.
    asot_display = np.array([asot_mapping.get(lbl, lbl) for lbl in asot.tolist()])

    # Get all unique labels across all predictions and GT
    all_labels = np.unique(np.concatenate([gt, asot_display, omp, comp]))
    if -1 not in all_labels:
        all_labels = np.append(all_labels, -1)

    n_class = len(all_labels)
    tab20 = plt.get_cmap('tab20')
    tab20b = plt.get_cmap('tab20b')

    def label_color(index):
        """Return the RGBA colour for the `index`-th label of `all_labels`."""
        if n_class <= 20:
            # Spread the labels evenly over the tab20 palette.
            return tab20(index / n_class)
        # More than 20 labels: integer lookups into tab20, then tab20b. Past 40
        # labels the last tab20b colour is reused.
        if index < 20:
            return tab20(index)
        return tab20b(min(index - 20, tab20b.N - 1))

    # Assign consistent colors for labels (-1 is always black).
    colors = {
        label: (0, 0, 0) if label == -1 else label_color(i)
        for i, label in enumerate(all_labels.tolist())
    }

    def plot_segments(ax, sequence, label, color_mapping=None):
        """Draw `sequence` as coloured spans on `ax`, labelled `label` on the y axis."""
        ax.set_ylabel(label, fontsize=20, rotation=0, labelpad=40, verticalalignment='center')
        ax.set_yticklabels([])
        ax.set_xticklabels([])
        # Indices where the label changes. Comparing with `!=` (rather than
        # subtracting) works for any label dtype and cannot wrap on unsigned ints.
        change_points = np.flatnonzero(sequence[1:] != sequence[:-1]) + 1
        bounds = np.concatenate(([0], change_points, [len(sequence)]))
        for start, end in zip(bounds[:-1], bounds[1:]):
            seg_label = sequence[start]
            if color_mapping:
                seg_label = color_mapping.get(seg_label, seg_label)
            ax.axvspan(start / n_frames, end / n_frames, facecolor=colors.get(seg_label, (0, 0, 0)), alpha=1.0)
            ax.axvline(start / n_frames, color='black', linewidth=2)
            ax.axvline(end / n_frames, color='black', linewidth=2)

    fig = plt.figure(figsize=(16, 6))
    # A hidden full-figure axes carries the overall title above the four strips.
    title_ax = fig.add_subplot(1, 1, 1)
    title_ax.axis('off')
    title_ax.set_title(name, fontsize=25, pad=20)

    axs = [fig.add_subplot(4, 1, i + 1) for i in range(4)]
    plot_segments(axs[0], gt, 'Truth')
    plot_segments(axs[1], asot, 'ASOT', color_mapping=asot_mapping)
    plot_segments(axs[2], omp, 'OMPN')
    plot_segments(axs[3], comp, 'CompILE')

    fig.tight_layout()
    return fig


def get_asot_predicted_list(task_name, number, root="../paper_runs"):
    """
    Load the ASOT per-frame predicted skill labels for one Craftax episode.

    Reads
    ``<root>/<task>/<task>_pixels_big_<task>_pixels_big/version_0/predicted_skills/craftax_<number>_skills.txt``,
    which is parsed as one integer label per line.

    Parameters:
        task_name (str): Task / run directory name, e.g. "wsws_static".
        number (int): Episode number used in the file name.
        root (str): Directory containing the ASOT runs. Defaults to the
            previously hard-coded "../paper_runs".

    Returns:
        list[int]: One label per non-blank line, in file order.

    Raises:
        OSError: If the prediction file cannot be opened.
        ValueError: If a non-blank line is not an integer.
    """
    path = f"{root}/{task_name}/{task_name}_pixels_big_{task_name}_pixels_big/version_0/predicted_skills/craftax_{number}_skills.txt"

    with open(path, "r") as file:
        # Blank / whitespace-only lines (e.g. an extra trailing newline) are
        # skipped instead of crashing on int('').
        stripped = (raw.strip() for raw in file)
        return [int(line) for line in stripped if line]

def get_gt_list(task_name, number, root="../Traces"):
    """
    Load the ground-truth per-frame labels of one Craftax episode as integer ids.

    ``<root>/<task>/<task>_pixels_big/mapping/mapping.txt`` is read as
    "<int id> <label name>" lines; lines that do not split into exactly two
    tokens are ignored. ``<root>/<task>/<task>_pixels_big/groundTruth/craftax_<number>``
    holds one label name per line, and each name is translated to its id.

    Parameters:
        task_name (str): Task directory name, e.g. "wsws_static".
        number (int): Episode number used in the ground-truth file name.
        root (str): Directory containing the traces. Defaults to the
            previously hard-coded "../Traces".

    Returns:
        list[int]: One label id per non-blank ground-truth line, in file order.

    Raises:
        OSError: If either file cannot be opened.
        ValueError: If a mapping id is not an integer.
        KeyError: If a ground-truth label is not present in the mapping.
    """
    base = f"{root}/{task_name}/{task_name}_pixels_big"
    txt_path = f"{base}/groundTruth/craftax_{number}"
    mapping_path = f"{base}/mapping/mapping.txt"

    mapping = {}
    with open(mapping_path, "r") as f:
        for line in f:
            parts = line.split()
            if len(parts) == 2:
                mapping[int(parts[0])] = parts[1]

    # Reverse the mapping: label name -> int
    reverse_mapping = {v: k for k, v in mapping.items()}

    gt_list = []
    with open(txt_path, "r") as f:
        for line_no, line in enumerate(f, start=1):
            label = line.strip()
            if not label:
                # Tolerate blank / trailing lines instead of failing on KeyError('').
                continue
            try:
                gt_list.append(reverse_mapping[label])
            except KeyError:
                raise KeyError(
                    f"Unknown ground-truth label {label!r} at {txt_path}:{line_no}; "
                    f"not listed in {mapping_path}"
                ) from None

    return gt_list


In [None]:
# WSWS STATIC CRAFTAX 406
task, episode = "wsws_static", 406

GT =        get_gt_list(task, episode)
ASOT =      get_asot_predicted_list(task, episode)

OMPN =      [1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0] # TODO: hard-coded labels; replace with loaded predictions (as done for ASOT)
CompILE =   [1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0] # TODO: hard-coded labels; replace with loaded predictions (as done for ASOT)

asot_map =  {0: 0, 1: 1}

fig = plot_segmentation_comparison(GT, ASOT, OMPN, CompILE, asot_map, name="Wood, Stone, Wood, Stone : Static")
# Save through the returned figure handle (not pyplot's implicit current
# figure), creating the output directory first so savefig cannot fail on it.
os.makedirs(task, exist_ok=True)
fig.savefig(f"{task}/bad_{task}_craftax_{episode}.pdf", bbox_inches='tight', dpi=300)

In [None]:
# WSWS RANDOM CRAFTAX 202
task, episode = "wsws_random", 202

GT =        get_gt_list(task, episode)
ASOT =      get_asot_predicted_list(task, episode)

OMPN =      [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1] # TODO: hard-coded labels; replace with loaded predictions (as done for ASOT)
CompILE =   [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1] # TODO: hard-coded labels; replace with loaded predictions (as done for ASOT)

asot_map =  {0: 1, 1: 0}

fig = plot_segmentation_comparison(GT, ASOT, OMPN, CompILE, asot_map, name="Wood, Stone, Wood, Stone : Random")
# Save through the returned figure handle (not pyplot's implicit current
# figure), creating the output directory first so savefig cannot fail on it.
os.makedirs(task, exist_ok=True)
fig.savefig(f"{task}/bad_{task}_craftax_{episode}.pdf", bbox_inches='tight', dpi=300)

In [None]:
# Stone Pick Static Craftax 493
task, episode = "stone_pick_static", 493

GT =        get_gt_list(task, episode)
ASOT =      get_asot_predicted_list(task, episode)

OMPN =      [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 2, 2, 4, 4, 0, 0, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3] # TODO: hard-coded labels; replace with loaded predictions (as done for ASOT)
CompILE =   [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3] # TODO: hard-coded labels; replace with loaded predictions (as done for ASOT)

asot_map =  {0: 4, 1: 1, 2: 3, 3: 2, 4: 0}

fig = plot_segmentation_comparison(GT, ASOT, OMPN, CompILE, asot_map, name="Stone Pickaxe : Static")
# Save through the returned figure handle (not pyplot's implicit current
# figure), creating the output directory first so savefig cannot fail on it.
os.makedirs(task, exist_ok=True)
fig.savefig(f"{task}/bad_{task}_craftax_{episode}.pdf", bbox_inches='tight', dpi=300)


In [None]:
# Stone Pick Random Craftax 403
task, episode = "stone_pick_random", 403

GT =        get_gt_list(task, episode)
ASOT =      get_asot_predicted_list(task, episode)

OMPN =      [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 4, 4, 3, 3, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
CompILE =   [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

asot_map =  {0: 3, 1: 2, 2: 0, 3: 4, 4: 1}

fig = plot_segmentation_comparison(GT, ASOT, OMPN, CompILE, asot_map, name="Stone Pickaxe : Random")
# Save through the returned figure handle (not pyplot's implicit current
# figure), creating the output directory first so savefig cannot fail on it.
os.makedirs(task, exist_ok=True)
fig.savefig(f"{task}/bad_{task}_craftax_{episode}.pdf", bbox_inches='tight', dpi=300)


In [None]:
# Mixed Static Craftax 275
task, episode = "mixed_static", 275

GT =        get_gt_list(task, episode)
ASOT =      get_asot_predicted_list(task, episode)

OMPN =      [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2]
CompILE =   [1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]

asot_map =  {0: 1, 1: 4, 2: 3, 3: 0, 4: 2}

# Title previously read "Stone Pickaxe : Random" (copy-paste from the cell
# above), which mislabels this mixed_static episode.
fig = plot_segmentation_comparison(GT, ASOT, OMPN, CompILE, asot_map, name="Mixed : Static")
# Save through the returned figure handle (not pyplot's implicit current
# figure), creating the output directory first so savefig cannot fail on it.
os.makedirs(task, exist_ok=True)
fig.savefig(f"{task}/bad_{task}_craftax_{episode}.pdf", bbox_inches='tight', dpi=300)
