## Heatmaps comparing the initial transition matrices obtained with k=1 vs. k=3 anchor points

In [9]:
import numpy as np
import pandas as pd
import altair as alt

In [10]:
# Print arrays compactly: two decimals and no scientific notation.
np.set_printoptions(suppress=True, precision=2)

In [11]:
# Load and display the initial transition matrix produced with k=1 anchor points.
k1_init_path = "../cifar10_n02_k1_initT.npy"
initT_k1 = np.load(k1_init_path)
print(initT_k1)

[[0.15 0.   0.   0.   0.   0.   0.   0.   0.   0.85]
 [0.   0.87 0.12 0.   0.   0.   0.   0.   0.   0.  ]
 [0.   0.   0.79 0.2  0.   0.   0.   0.   0.   0.  ]
 [0.   0.   0.   0.83 0.17 0.   0.   0.   0.   0.  ]
 [0.   0.   0.   0.   0.79 0.21 0.   0.   0.   0.  ]
 [0.   0.   0.   0.   0.   0.86 0.13 0.   0.   0.  ]
 [0.   0.   0.   0.   0.   0.   0.84 0.16 0.   0.  ]
 [0.   0.   0.   0.   0.   0.   0.   0.81 0.19 0.  ]
 [0.   0.   0.   0.   0.   0.   0.   0.   0.85 0.15]
 [0.16 0.   0.   0.   0.   0.   0.   0.   0.   0.84]]


In [12]:
# Long-form table with one row per matrix entry. np.indices yields the
# (row, column) index grids in the same row-major order as ravel(), so every
# z value is paired with its own row (y) and column (x), even if the matrix
# is not square.
y_coords, x_coords = np.indices(initT_k1.shape)

source = pd.DataFrame({
    'x': x_coords.ravel(),  # column index of the entry
    'y': y_coords.ravel(),  # row index of the entry
    'z': initT_k1.ravel(),  # entry value
})

# Piecewise colour scale: 0 -> black, 0.10 -> dark blue, 1 -> light blue.
domain = [0, 0.10, 1]
range_ = ["#000000", '#125ca4', "#7dcdfc"]

# Base heatmap: one rectangle per matrix entry, coloured by its value.
heatmap = alt.Chart(source).mark_rect().encode(
    # Column index on the x-axis (grid lines disabled for clarity).
    x=alt.X('x:O', title='Probability of Noise Label',
            axis=alt.Axis(titleFontSize=12, labelFontSize=12, grid=False)),
    # Row index on the y-axis.
    y=alt.Y('y:O', title='Probability of Clean Label',
            axis=alt.Axis(titleFontSize=12, labelFontSize=12, grid=False)),
    # ':Q' treats the entry value as quantitative for the colour scale.
    color=alt.Color('z:Q', title='Value',
                    scale=alt.Scale(domain=domain, range=range_)),
).properties(
    title='Initial Transition Matrix k=1',
    width=600,
    height=600,
)

# Annotate only the two largest entries of each row. groupby(...).nlargest
# returns a (y, original index) MultiIndex; level 1 recovers the row labels
# of `source`.
top_indices = source.groupby('y')['z'].nlargest(2).index.get_level_values(1)
top_values_data = source.loc[top_indices]

text_labels = alt.Chart(top_values_data).mark_text(
    align='center',
    baseline='middle',
    fontSize=12,
).encode(
    x='x:O',
    y='y:O',
    # Show the entry value with two decimal places.
    text=alt.Text('z:Q', format='.2f'),
    # Values above 0.8 lie at the light-blue end of the colour scale, so they
    # get black text; all other (darker) cells get white text.
    color=alt.condition(
        alt.datum.z > 0.8,
        alt.value('black'),
        alt.value('white'),
    ),
)

# Overlay the labels on the heatmap ('+' layers charts in Altair).
# configure_* calls are left to the final top-level chart, because Altair
# rejects configured charts inside a concatenation.
final_chart_k1 = (heatmap + text_labels).properties(
    title='CIFAR10 Pair-20%: Initial Transition Matrix k=1',
    width=600,
    height=600,
)

final_chart_k1.show()

In [13]:
# Load and display the initial transition matrix produced with k=3 anchor points.
k3_init_path = "../cifar10_n02_k3_initT.npy"
initT_k3 = np.load(k3_init_path)
print(initT_k3)

[[0.86 0.14 0.   0.   0.   0.   0.   0.   0.   0.  ]
 [0.27 0.61 0.12 0.   0.   0.   0.   0.   0.   0.  ]
 [0.   0.   0.75 0.25 0.   0.   0.   0.   0.   0.  ]
 [0.29 0.05 0.   0.58 0.08 0.   0.   0.   0.   0.  ]
 [0.   0.   0.   0.   0.84 0.16 0.   0.   0.   0.  ]
 [0.   0.   0.   0.   0.59 0.35 0.06 0.   0.   0.  ]
 [0.   0.   0.   0.   0.   0.6  0.36 0.04 0.   0.  ]
 [0.   0.   0.   0.   0.   0.   0.   0.82 0.18 0.  ]
 [0.   0.   0.   0.   0.   0.   0.   0.27 0.59 0.14]
 [0.19 0.   0.   0.   0.   0.   0.   0.   0.   0.81]]


In [14]:
# Long-form table with one row per matrix entry. np.indices yields the
# (row, column) index grids in the same row-major order as ravel(), so every
# z value is paired with its own row (y) and column (x), even if the matrix
# is not square.
y_coords, x_coords = np.indices(initT_k3.shape)

source = pd.DataFrame({
    'x': x_coords.ravel(),  # column index of the entry
    'y': y_coords.ravel(),  # row index of the entry
    'z': initT_k3.ravel(),  # entry value
})

# Piecewise colour scale: 0 -> black, 0.10 -> dark blue, 1 -> light blue.
domain = [0, 0.10, 1]
range_ = ["#000000", '#125ca4', "#7dcdfc"]

# Base heatmap: one rectangle per matrix entry, coloured by its value.
heatmap = alt.Chart(source).mark_rect().encode(
    # Column index on the x-axis (grid lines disabled for clarity).
    x=alt.X('x:O', title='Probability of Noise Label',
            axis=alt.Axis(titleFontSize=12, labelFontSize=12, grid=False)),
    # Row index on the y-axis.
    y=alt.Y('y:O', title='Probability of Clean Label',
            axis=alt.Axis(titleFontSize=12, labelFontSize=12, grid=False)),
    # ':Q' treats the entry value as quantitative for the colour scale.
    color=alt.Color('z:Q', title='Value',
                    scale=alt.Scale(domain=domain, range=range_)),
).properties(
    # This cell plots the k=3 matrix (the title previously said k=1).
    title='Initial Transition Matrix k=3',
    width=600,
    height=600,
)

# Annotate only the two largest entries of each row. groupby(...).nlargest
# returns a (y, original index) MultiIndex; level 1 recovers the row labels
# of `source`.
top_indices = source.groupby('y')['z'].nlargest(2).index.get_level_values(1)
top_values_data = source.loc[top_indices]

text_labels = alt.Chart(top_values_data).mark_text(
    align='center',
    baseline='middle',
    fontSize=12,
).encode(
    x='x:O',
    y='y:O',
    # Show the entry value with two decimal places.
    text=alt.Text('z:Q', format='.2f'),
    # Values above 0.8 lie at the light-blue end of the colour scale, so they
    # get black text; all other (darker) cells get white text.
    color=alt.condition(
        alt.datum.z > 0.8,
        alt.value('black'),
        alt.value('white'),
    ),
)

# Overlay the labels on the heatmap ('+' layers charts in Altair).
# configure_* calls are left to the final top-level chart, because Altair
# rejects configured charts inside a concatenation.
final_chart_k3 = (heatmap + text_labels).properties(
    title='CIFAR10 Pair-20%: Initial Transition Matrix k=3',
    width=600,
    height=600,
)

final_chart_k3.show()

In [15]:
# Place the k=1 and k=3 heatmaps side by side ('|' concatenates horizontally).
# configure_title returns a NEW chart rather than modifying the receiver, so
# its result must be kept; otherwise the monospace title font is never
# applied to the displayed or saved figure.
final_plot = (final_chart_k1 | final_chart_k3).configure_title(font='monospace')
final_plot.show()

In [16]:
# Export the combined figure; Altair infers the output format from the
# ".pdf" extension of the target path.
final_plot.save(fp="init_matrix_visualization.pdf")