In [1]:
import math
import os
import pickle
from itertools import product

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
import torch.nn as nn
from matplotlib.patches import Ellipse, Polygon, Rectangle
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
pl.seed_everything(42, workers=True)
from set_models import (
    SetCNNEmbedder,
    SetSequenceModel,
    MaxPoolModule,
    FirstPoolModule,
    MeanPoolModule,
    SetClassifierLayer
)

from simple_abstractor import SimpleAbstractorEncoder
from set_data_lit import SetTriplesDataModule, SetCardDataModule
from set_data import SetCardBaseDataset, SetTriplesDataset



Seed set to 42


In [None]:
from itertools import combinations
import logging

# Set the "pytorch_lightning" logger's level to WARNING so that lower-severity
# records from that logger are not emitted.
logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

In [None]:
# Keyword arguments shared by every `pl.Trainer` built in the evaluation cell
# that follows. Logging is switched off.
logger = False
trainer_kwargs = {
    "max_epochs": 1,
    "precision": "16",
    "logger": logger,
    "callbacks": [
        EarlyStopping(monitor="val_loss", mode="min", patience=15),
    ],
    "val_check_interval": 25,
    "deterministic": True,
}

# NOTE(review): `torch.load` unpickles the file, so only load checkpoints from a
# trusted source. The result is used below as a model object (`.eval()`,
# `.base_embedder`), so the file presumably stores a whole pickled module
# rather than a state dict — confirm.
seq_model = torch.load("./snellius_checkpoints/e2e_long_seq.pth")


In [None]:
# Test the loaded model once for every pair of feature states drawn from {0, 1, 2}.
# Nothing is fitted in this cell; each iteration only calls `trainer.test`.
seq_model.eval()
for feature_states_used in combinations(range(3), 2):
    ds = SetCardBaseDataset(
        features_used=[0, 1, 2, 3],
        feature_states_used=feature_states_used,
    )
    dm = SetTriplesDataModule(
        ds,
        label_choice="is_set",
        batch_size=64,
        balanced_sampling=True,
        val_split=0.01,
        test_split=0.9,
    )
    dm.setup()
    print(feature_states_used)

    # A fresh Trainer per pair, all built from the shared `trainer_kwargs`.
    trainer = pl.Trainer(**trainer_kwargs)
    test_res = trainer.test(seq_model, dm)
    print(test_res)

In [None]:
# seq_model = torch.load("./checkpoints/full_seq_model.pt")


In [None]:
import logging

# Same call as in the earlier logging cell: sets the "pytorch_lightning" logger
# to WARNING. Harmless to repeat.
logging.getLogger("pytorch_lightning").setLevel(logging.WARNING)

class ReshapeModule(nn.Module):
    """Collapse all leading dimensions so every row is one ``embed_dim``-sized vector.

    Args:
        embed_dim: Size of the last dimension of the reshaped output.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim

    def forward(self, x):
        """Return ``x`` reshaped to ``(-1, embed_dim)``.

        ``reshape`` is used rather than ``view`` so that non-contiguous inputs
        (for example the result of a ``permute``/``transpose``) are copied
        instead of raising ``RuntimeError``. Contiguous inputs still come back
        as a view, so existing callers see the same result.
        """
        return x.reshape(-1, self.embed_dim)

In [None]:
def precompute_Rel_attention_hidden_states(dm, module, batch_size=128):
    """Run one attention sub-layer over the cached hidden states and store the result.

    The card inputs fetched from the triples dataset serve as both query and
    key, while the current ``triples_hidden_states`` serve as value. The
    attention output is added back to the hidden states, passed through
    ``module.norm1`` and written to ``dm.triples_dataset.triples_hidden_states``,
    replacing the previous value.

    Args:
        dm: Data module exposing ``triples_dataset``.
        module: Layer providing ``self_attn`` and ``norm1``; it is put in eval mode.
        batch_size: Not used; the whole dataset is processed in a single call.
    """
    module.eval()
    dataset = dm.triples_dataset
    n_items = len(dataset)
    with torch.no_grad():
        hidden = dataset.triples_hidden_states
        cards = dataset.get_from_setcard_dataset(torch.arange(n_items))

        attended = module.self_attn(query=cards, key=cards, value=hidden)

        # Residual connection followed by the layer's first normalisation.
        dataset.triples_hidden_states = module.norm1(hidden + attended)

def precompute_ff_hidden_states(dm, module, batch_size=128):
    """Run one feed-forward sub-layer over the cached hidden states and store the result.

    ``module.ff`` is applied to ``dm.triples_dataset.triples_hidden_states``, its
    output is added back to the input, the sum goes through ``module.norm2``,
    and the result replaces the stored hidden states.

    Args:
        dm: Data module exposing ``triples_dataset``.
        module: Layer providing ``ff`` and ``norm2``; it is put in eval mode.
        batch_size: Not used; the whole tensor is processed in a single call.
    """
    module.eval()
    with torch.no_grad():
        dataset = dm.triples_dataset
        hidden = dataset.triples_hidden_states

        # Residual connection followed by the layer's second normalisation.
        dataset.triples_hidden_states = module.norm2(hidden + module.ff(hidden))

In [None]:
def probe_contextual_hidden_states(dm, n_features=4):
    """Fit and test one probe per label function on the data module's current inputs.

    For every key of ``dm.triples_dataset.label_functions`` the data module is
    switched to that label via ``dm.set_labels_dm``, a fresh
    ``SetClassifierLayer`` wrapped in a ``SetSequenceModel`` is trained with a
    new ``pl.Trainer``, and then tested.

    Args:
        dm: Data module exposing ``triples_dataset`` and ``set_labels_dm``.
        n_features: Forwarded to ``SetClassifierLayer``. Defaults to 4, the
            value that was previously hard-coded.

    Returns:
        dict mapping each label choice to the first entry of the list returned
        by ``trainer.test`` for that probe.
    """
    accuracies = {}

    for label_choice in dm.triples_dataset.label_functions.keys():
        dm.set_labels_dm(label_choice)

        # Use the cached hidden states when they exist. Otherwise fall back to
        # the raw card embeddings. Only AttributeError (attribute missing or
        # still None) triggers the fallback, so unrelated failures and
        # KeyboardInterrupt are no longer silently swallowed.
        try:
            embed_dim = dm.triples_dataset.triples_hidden_states.size(-1)
        except AttributeError:
            embed_dim = dm.triples_dataset.setcard_dataset.card_embeds.size(-1)

        if label_choice == "features_pointwise":
            # Every sequence position is classified on its own.
            aggregate_seq = ReshapeModule(embed_dim)
            seq_len = 1
        else:
            # The three positions are concatenated into one vector.
            aggregate_seq = nn.Flatten()
            seq_len = 3

        classifier = SetClassifierLayer(
            label_choice=label_choice,
            embed_dim=embed_dim,
            seq_len=seq_len,
            n_features=n_features,
        )
        probe = SetSequenceModel(classifier=classifier, aggregate_seq=aggregate_seq)

        # Built inside the loop so each probe gets its own EarlyStopping instance.
        trainer_kwargs = dict(
            max_epochs=20,
            logger=False,
            callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=5)],
            val_check_interval=10,
            enable_progress_bar=False,
            enable_model_summary=False,
            enable_checkpointing=False,
        )

        trainer = pl.Trainer(**trainer_kwargs)
        trainer.fit(probe, dm)
        accuracies[label_choice] = trainer.test(probe, dm, verbose=False)[0]
    return accuracies

In [None]:
# Build the data module used by the probing cells. The trained model's base
# embedder is handed to the card dataset (presumably to embed the card images —
# confirm against `SetCardBaseDataset`).
cnn = seq_model.base_embedder
ds = SetCardBaseDataset(cnn, feature_states_used=[0, 1])

# `label_choice` was never defined at notebook scope, so this cell raised
# NameError on a fresh kernel. "is_set" is the label used by the evaluation
# cell above. The probing function re-selects the label through
# `dm.set_labels_dm(...)` before every fit, so this only sets the initial state.
# NOTE(review): confirm "is_set" is the intended initial label.
label_choice = "is_set"

dm = SetTriplesDataModule(
    ds,
    batch_size=64,
    label_choice=label_choice,
    val_split=0.1,
    test_split=0.1,
    balanced_sampling=True,
    balanced_subset=False,
)
dm.setup()

In [None]:

# Scratch values for the card figure's width and height. `plot_heigth` (sic) is
# the spelling the following cells use, so it is kept as-is.
plot_width = 2
plot_heigth = 3.5

In [None]:
# Scratch arithmetic on the figure dimensions defined in the previous cell.
2* plot_width* plot_heigth * 10

In [None]:
import torch

In [None]:
# Scratch arithmetic: the same 40*70 product appears as the numerator of the
# dpi calculations in the cells below.
40*70

In [None]:
# Scratch check. `math` must already be imported in the kernel for this to run.
math.sqrt(4)

In [105]:
# Figure dimensions: the height is 1.75 times the width.
plot_width = 2
plot_heigth = plot_width * 1.75

# 40*70 divided by (figure area * 20). The factor 20 is a trial value; a later
# cell tries 7.7 instead.
dpi = (40 * 70) / (plot_width * plot_heigth * 20)
dpi

20.0

In [116]:
# Same dpi formula as the previous cell, with the factor 7.7 — the value that
# `generate_custom_card` uses as `magic_correction_factor`.
(40*70)/(plot_width*plot_heigth * 7.7) 

51.94805194805195

In [18]:
# Scratch arithmetic; presumably the 5:7 width-to-height ratio of the card
# figure (plot_heigth = 1.4 * plot_width in `generate_custom_card`) — confirm.
5 /7

0.7142857142857143

In [35]:
def generate_custom_card(
    number_index, color_index, pattern_index, shape_index, data_dir=None
):
    """Draw a single card with matplotlib, then save it as a PNG or show it.

    Args:
        number_index: Index into the shape counts ``[1, 2, 3, 4]``.
        color_index: Index into ``["red", "green", "purple", "yellow"]``.
        pattern_index: Index into ``["empty", "striped", "solid", "plus"]``.
        shape_index: Index into ``["diamond", "oval", "bar", "tie"]``.
        data_dir: Directory that receives
            ``setcard_<number><color><pattern><shape>.png`` (the four indices
            concatenated). The directory is created when it does not exist.
            When falsy, the figure is shown with ``plt.show()`` instead.

    Raises:
        IndexError: If any index is outside its list.
    """
    numbers = [1, 2, 3, 4]
    colors = ["red", "green", "purple", "yellow"]
    patterns = ["empty", "striped", "solid", "plus"]
    shapes = ["diamond", "oval", "bar", "tie"]

    number = numbers[number_index]
    color = colors[color_index]
    pattern = patterns[pattern_index]
    shape = shapes[shape_index]

    plot_width = 5
    plot_heigth = 1.4 * plot_width

    # The dpi is a 50*70 budget divided by the figure area and an empirically
    # chosen factor (no derivation is recorded for 7.7).
    # NOTE(review): the previous comment said "40x70 pixels" while the formula
    # uses 50 * 70 — confirm the intended output size.
    magic_correction_factor = 7.7
    dpi = (50 * 70) / (plot_width * plot_heigth * magic_correction_factor)

    fig, ax = plt.subplots(figsize=(plot_width, plot_heigth), dpi=dpi)
    # Close the figure even if drawing or saving raises. This function is
    # called once per card, so leaked figures would otherwise accumulate.
    try:
        ax.set_xlim([0, plot_width])
        ax.set_ylim([0, plot_heigth])
        ax.axis("off")

        colors_plt_codes = {
            "red": "r",
            "green": "g",
            "purple": "purple",
            "yellow": "yellow",
        }
        colors_plt_code = colors_plt_codes[color]

        # Vertical distance between shape centres, plus a per-count offset that
        # shifts the whole stack downwards.
        y_spacing = plot_heigth / (number + 1) + 0.05 * plot_heigth
        biases = [
            plot_heigth * 0.05,
            plot_heigth * 0.075,
            plot_heigth * 0.1,
            plot_heigth * 0.125,
        ]

        for i in range(number):
            x = plot_width * 0.5
            y = (i + 1) * y_spacing - biases[number - 1]

            if shape == "diamond":
                shape_object = Polygon(
                    [
                        [x - 0.5 * x, y],
                        [x, y + plot_heigth * 0.1],
                        [x + 0.5 * x, y],
                        [x, y - plot_heigth * 0.1],
                    ]
                )
            elif shape == "oval":
                shape_object = Ellipse(
                    (x, y), width=plot_width * 0.75, height=plot_heigth * 0.2
                )
            elif shape == "bar":
                shape_object = Rectangle(
                    (x - 0.75 * x, y - plot_heigth * 0.1),
                    width=plot_width * 0.75,
                    height=plot_heigth * 0.2,
                )
            elif shape == "tie":
                # Vertex order crosses over itself, which produces the bow-tie.
                shape_object = Polygon(
                    [
                        [x - 0.5 * x, y - plot_heigth * 0.1],
                        [x - 0.5 * x, y + plot_heigth * 0.1],
                        [x + 0.5 * x, y - plot_heigth * 0.1],
                        [x + 0.5 * x, y + plot_heigth * 0.1],
                    ]
                )

            # Shading: "solid" fills the shape; every other pattern is an
            # outline, optionally hatched.
            if pattern == "solid":
                shape_object.set_facecolor(colors_plt_code)
            else:
                shape_object.set_facecolor("none")
                shape_object.set_edgecolor(colors_plt_code)
                hatch = {"striped": "/", "plus": "+"}.get(pattern)
                if hatch is not None:
                    shape_object.set_hatch(hatch)

            ax.add_patch(shape_object)

        # Card border.
        rect = plt.Rectangle(
            (0, 0.05),
            plot_width - 0.05,
            plot_heigth - 0.05,
            linewidth=2,
            edgecolor="black",
            facecolor="none",
        )
        ax.add_patch(rect)

        if data_dir:
            # Create the target directory so saving works on a fresh checkout.
            os.makedirs(data_dir, exist_ok=True)
            file_path = os.path.join(
                data_dir,
                f"setcard_{number_index}{color_index}{pattern_index}{shape_index}.png",
            )
            fig.savefig(
                file_path,
                dpi=dpi,
                bbox_inches="tight",
                pad_inches=0.0,
            )
        else:
            plt.show()
    finally:
        plt.close(fig)

def generate_all_custom_cards(data_dir="data/custom_cards"):
    """Call ``generate_custom_card`` for every combination of the four indices.

    Each of the four indices ranges over 0-3, giving 4**4 = 256 calls in
    lexicographic order.

    Args:
        data_dir: Passed to ``generate_custom_card`` as its ``data_dir`` argument.
    """
    for number_index, color_index, pattern_index, shape_index in product(
        range(4), repeat=4
    ):
        generate_custom_card(
            number_index, color_index, pattern_index, shape_index, data_dir
        )

# Render all 256 cards using the default output directory "data/custom_cards".
generate_all_custom_cards()

In [32]:
import os
import torch
from torch.utils.data import Dataset
from matplotlib import image as mpimg
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse, Rectangle, Polygon
from itertools import product



In [131]:
# Directory the card PNGs are read from below; same path as the default
# `data_dir` of `generate_all_custom_cards`.
data_dir = "data/custom_cards"

In [129]:
# Scratch check of the joined path.
# NOTE(review): the recorded output 'data/setcard' does not match the current
# `data_dir` value, so it looks stale — re-run to refresh.
os.path.join(data_dir, "setcard")

'data/setcard'

In [132]:
# Read the saved PNG for card (0, 0, 0, 0) and turn it into a tensor.
card = [0, 0, 0, 0]
number_index, color_index, pattern_index, shape_index = card
file_path = os.path.join(
    data_dir,
    f"setcard_{number_index}{color_index}{pattern_index}{shape_index}.png",
)

im = mpimg.imread(file_path)
# The axis order is fully reversed here.
# NOTE(review): if `im` is (H, W, C) this yields (C, W, H), not the more common
# (C, H, W) — confirm this matches what the dataset code expects.
cards = torch.from_numpy(im).permute(2, 1, 0)

In [None]:
# Probe accuracies are cached on disk: load them when the cache file exists,
# otherwise compute them layer by layer and write the cache.
pickle_path = "probe_accuracies_train01_probe01.pkl"

# `seq_len` was never defined at notebook scope, so the compute path raised
# NameError. 3 is the sequence length the triple-level probes use.
# NOTE(review): confirm 3 is the intended slice of `initial_symbol_sequence`.
seq_len = 3

try:
    # NOTE: `pickle.load` can execute arbitrary code — only load files that
    # this notebook produced itself.
    with open(pickle_path, "rb") as f:
        probe_accuracies = pickle.load(f)
except FileNotFoundError:
    # Only a missing cache triggers recomputation; other errors (corrupt file,
    # interrupt, ...) now surface instead of being swallowed by a bare except.
    probe_accuracies = None

if probe_accuracies is None:
    # Baseline: probe the inputs before any contextual layer is applied.
    probe_accuracies = {"cnn": probe_contextual_hidden_states(dm)}

    abstractor = seq_model.contextual_embedder
    seq_model.eval()
    with torch.no_grad():
        # Seed the cached hidden states with the abstractor's initial symbols.
        S = abstractor.initial_symbol_sequence[:, :seq_len]
        dm.triples_dataset.triples_hidden_states = S
    dm.triples_dataset.set_get_cards(True)

    # Advance the cached hidden states one sub-layer at a time and probe
    # after each step.
    for i, layer in enumerate(abstractor.layers):
        precompute_Rel_attention_hidden_states(dm, layer)
        probe_accuracies[f"layer {i}: attention"] = probe_contextual_hidden_states(dm)

        precompute_ff_hidden_states(dm, layer)
        probe_accuracies[f"layer {i}: fnn"] = probe_contextual_hidden_states(dm)

    # Write to the same path that is read above. The previous code wrote to
    # "probe_accuracies.pkl", so the cache was never found on the next run.
    with open(pickle_path, "wb") as f:
        pickle.dump(probe_accuracies, f)

# Print the loaded or freshly computed dictionary
print(probe_accuracies)

In [None]:
# Extract tasks and layers
# One bar chart per probing task, with one bar per layer.
tasks = list(probe_accuracies["cnn"].keys())
layers = list(probe_accuracies.keys())

# squeeze=False always returns a 2-D array of Axes, so indexing also works when
# there is exactly one task (plt.subplots would otherwise return a bare Axes).
fig, axes_grid = plt.subplots(
    len(tasks), 1, figsize=(5, 2 * len(tasks)), sharex=True, squeeze=False
)
axes = axes_grid[:, 0]

for ax, task in zip(axes, tasks):
    accuracies = [probe_accuracies[layer][task]["test_acc"] for layer in layers]
    ax.bar(layers, accuracies)
    ax.set_ylabel("Accuracy")
    ax.set_title(f"Task: {task}")
    ax.set_ylim([0, 1.1])  # Leave headroom above 1.0

# The x-axis is shared, so only the bottom panel is labelled.
axes[-1].set_xlabel("Layer")
plt.tight_layout()
plt.show()