In [None]:
# imports
import glob
import os

import ibis
import ibis.selectors as s
import lightning as L
import plotly.express as px
import torch

from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from torch import nn
from torch.utils.data import TensorDataset, DataLoader

## local imports
from ihateai.data import read_training, transform
from ihateai.grid import InputOutputPair
from ihateai.utils import show_task_pairs, random_task_num

# configuration
# Seed Python, NumPy and torch once so that model initialisation and batch
# shuffling are repeatable under Restart & Run All.
SEED = 42
L.seed_everything(SEED, workers=True)

# Default plotly theme for every figure in this notebook.
px.defaults.template = "plotly_dark"

# ibis display options: interactive mode plus limits on how much of a table
# repr is printed (rows / length / nesting depth); no limit on column count.
ibis.options.interactive = True
ibis.options.repr.interactive.max_rows = 3
ibis.options.repr.interactive.max_length = 3
ibis.options.repr.interactive.max_depth = 3
ibis.options.repr.interactive.max_columns = None

# Handle on the backend ibis is currently using.
con = ibis.get_backend()

In [None]:
# Load the training data through the project helper and show it.
t = read_training()
t

In [None]:
# Derive both tables from the same source table `t`.
# NOTE(review): `test=True` presumably selects each task's held-out pairs —
# confirm against ihateai.data.transform.
test = transform(t, test=True)
train = transform(t)
train

In [None]:
def agg_t(t):
    """Summarise every task in ``t`` as a single row of features.

    The task-wide colour sets (``all_input_colors`` / ``all_output_colors``)
    are computed first and joined back onto ``t`` so that each row can be
    compared against them. The joined table is then aggregated per
    ``task_num`` into:

    * ``f_*`` flags, each true only if the stated relation holds for every
      row of the task (e.g. input height equals output height), or if a
      column takes exactly one distinct value within the task;
    * three ``*_intersect`` array columns holding the sorted, de-duplicated
      union of the per-row array intersections.

    Two combined ``f_*grid*`` flags are derived afterwards, and every boolean
    column is finally cast to ``int8``.

    NOTE(review): the colour flags compare arrays with ``==``; whether that is
    order-sensitive depends on the backend and on how ``input_colors`` /
    ``output_colors`` were built upstream — confirm.

    Args:
        t: ibis table expression providing ``task_num``, ``input_colors``,
            ``output_colors``, ``input_height``, ``input_width``,
            ``output_height`` and ``output_width``.

    Returns:
        An ibis table expression with one row per task, ordered by
        ``task_num``.
    """
    col = ibis._  # deferred reference to "the table being operated on"

    def pooled(arrays):
        # Gather a group's arrays into one sorted array of distinct values.
        return arrays.collect().flatten().unique().sort()

    def always_equal(left, right):
        # True when the two columns agree on every row of the group.
        return (col[left] == col[right]).all()

    def single_valued(name):
        # True when the column has exactly one distinct value in the group.
        return col[name].nunique() == 1

    def always_within(name, pool):
        # True when intersecting with `pool` leaves `name` unchanged on every row.
        return (col[name].intersect(col[pool]) == col[name]).all()

    task_palettes = t.group_by("task_num").agg(
        all_input_colors=pooled(col["input_colors"]),
        all_output_colors=pooled(col["output_colors"]),
    )
    with_palettes = t.join(task_palettes, "task_num")

    per_task = with_palettes.group_by("task_num").agg(
        f_input_colors_match_output_colors=always_equal(
            "input_colors", "output_colors"
        ),
        f_input_height_match_output_height=always_equal(
            "input_height", "output_height"
        ),
        f_input_width_match_output_width=always_equal("input_width", "output_width"),
        f_input_colors_all_same=always_equal("all_input_colors", "input_colors"),
        f_input_height_all_same=single_valued("input_height"),
        f_input_width_all_same=single_valued("input_width"),
        f_output_colors_all_same=always_equal("all_output_colors", "output_colors"),
        f_output_height_all_same=single_valued("output_height"),
        f_output_width_all_same=single_valued("output_width"),
        f_input_colors_subset_of_output_colors=always_within(
            "input_colors", "all_output_colors"
        ),
        f_output_colors_subset_of_input_colors=always_within(
            "output_colors", "all_input_colors"
        ),
        input_output_colors_intersect=pooled(
            col["output_colors"].intersect(col["input_colors"])
        ),
        input_all_input_colors_intersect=pooled(
            col["all_input_colors"].intersect(col["input_colors"])
        ),
        output_all_output_colors_intersect=pooled(
            col["all_output_colors"].intersect(col["output_colors"])
        ),
    )

    with_grid_flags = per_task.mutate(
        f_input_grid_matches_output_grid=(
            col["f_input_height_match_output_height"]
            & col["f_input_width_match_output_width"]
        ),
        f_output_grid_all_same=(
            col["f_output_height_all_same"] & col["f_output_width_all_same"]
        ),
    )

    # Booleans become 0/1 int8 columns.
    as_int8 = with_grid_flags.mutate(s.across(s.of_type(bool), ibis._.cast("int8")))

    return as_int8.order_by("task_num")


def decision_tree(t):
    """Return ``t`` joined on ``task_num`` with its own ``agg_t`` summary."""
    return t.join(agg_t(t), "task_num")

In [None]:
# Attach the per-task aggregate features to every row of each table.
# NOTE: this rebinds `test` / `train`, so the cell is not idempotent —
# re-running it without first re-running the transform cell joins the
# aggregates onto an already-joined table.
test = decision_tree(test)
train = decision_tree(train)
train

In [None]:
# Number of rows in the training table after the join.
train.count()

In [None]:
# Number of rows in the test table after the join.
test.count()

In [None]:
# Grid size bounds (inclusive) and the value ranges used for one-hot encoding.
MIN_WIDTH = 1
MIN_HEIGHT = 1
MAX_WIDTH = 30
MAX_HEIGHT = 30

COLORS = range(10)
WIDTHS = range(MIN_WIDTH, MAX_WIDTH + 1)
HEIGHTS = range(MIN_HEIGHT, MAX_HEIGHT + 1)

In [None]:
def enc(t):
    """Add 0/1 ``int8`` indicator columns (prefixed ``enc_``) to ``t``.

    * Array columns (the colour sets and the three ``*_intersect`` columns)
      get one indicator per value in ``COLORS`` telling whether the array
      contains that value.
    * Scalar size columns get one indicator per value in ``HEIGHTS`` /
      ``WIDTHS`` telling whether the column equals that value.

    Columns are appended in a fixed order (colours, heights, widths, then the
    intersect columns).

    Args:
        t: ibis table with the colour, size and ``*_intersect`` columns.

    Returns:
        ``t`` with the indicator columns added.
    """

    def membership(prefix, column):
        # One indicator per colour: does the array column contain it?
        return {f"{prefix}_{i}": t[column].contains(i).cast("int8") for i in COLORS}

    def one_hot(prefix, column, values):
        # One indicator per candidate value: does the scalar column equal it?
        return {f"{prefix}_{i}": (t[column] == i).cast("int8") for i in values}

    indicators = {}
    indicators.update(membership("enc_input_colors", "input_colors"))
    indicators.update(membership("enc_output_colors", "output_colors"))
    indicators.update(one_hot("enc_input_height", "input_height", HEIGHTS))
    indicators.update(one_hot("enc_output_height", "output_height", HEIGHTS))
    indicators.update(one_hot("enc_input_width", "input_width", WIDTHS))
    indicators.update(one_hot("enc_output_width", "output_width", WIDTHS))
    indicators.update(
        membership(
            "enc_input_output_colors_intersect", "input_output_colors_intersect"
        )
    )
    indicators.update(
        membership(
            "enc_input_all_input_colors_intersect", "input_all_input_colors_intersect"
        )
    )
    indicators.update(
        membership(
            "enc_output_all_output_colors_intersect",
            "output_all_output_colors_intersect",
        )
    )

    return t.mutate(**indicators)

In [None]:
# Training features: the task id (needed for the join below), every one-hot
# "enc" column and every per-task "f_" flag.
Train = enc(train).select(
    "task_num",
    s.contains("enc"),
    s.startswith("f_"),
)
# TODO: the direct selection below is suspected to be seriously wrong
# (original author's note), so it is kept disabled for reference:
# Test = enc(test).select(
#     s.startswith("f_"),
#     s.contains("enc"),
# )
# Instead: keep only the non-"intersect" enc_ columns computed from the test
# rows, and take the "intersect" encodings and the f_ flags from Train,
# matched on task_num.
# NOTE(review): the join presumably repeats each test row once per matching
# Train row and distinct() collapses the copies — confirm it cannot also merge
# two genuinely identical test rows of the same task.
Test = (
    (
        enc(test)
        .drop((s.contains("intersect")) & (s.startswith("enc_")))
        .select("task_num", s.startswith("enc"))
        .join(
            Train.select(
                "task_num",
                (s.contains("intersect") & (s.startswith("enc_"))),
                s.startswith("f_"),
            ),
            "task_num",
        )
    )
    .distinct()
    .drop("task_num")
)
# The task id was only needed as the join key; it is not a model feature.
Train = Train.drop("task_num")

In [None]:
# Preview the encoded training table.
Train

In [None]:
# Number of rows in the encoded training table.
Train.count()

In [None]:
# Preview the encoded test table.
Test

In [None]:
# Number of rows in the encoded test table.
Test.count()

In [None]:
# Feature matrix for the PCA / clustering cell: every Train column whose name
# contains "color". Disabled alternative using the size columns instead:
# X = Train.select(s.contains("height"), s.contains("width"))
X = Train.select(s.contains("color"))
# Number of selected feature columns.
len(X.columns)

In [None]:
# Project the selected features onto principal components, cluster the
# projected points, and plot the first three components in 3-D.
n_components = 3
n_clusters = 10
random_state = 0  # fixed so cluster assignments are repeatable between runs

# fit_transform evaluates X once, so the fitted model and the projected rows
# come from the same materialisation of the table.
pca = PCA(n_components=n_components, random_state=random_state)
components = pca.fit_transform(X)

kmeans = KMeans(n_clusters=n_clusters, random_state=random_state).fit(components)

# Build the result table directly from the aligned arrays. labels_[i] belongs
# to components[i], so no row_number() join over unordered tables is needed.
t_pca = ibis.memtable(
    {
        "cluster": kmeans.labels_,
        **{f"pc{i + 1}": components[:, i] for i in range(n_components)},
    }
)

c = px.scatter_3d(
    t_pca,
    x="pc1",
    y="pc2",
    z="pc3",
    color="cluster",
    title="KMeans clusters in PCA space",
)
c.show(renderer="browser")

In [None]:
# Preview the projected points with their cluster labels.
t_pca

In [None]:
def tensors(T):
    """Split a mapping of per-column tensors into model inputs and targets.

    Columns are chosen by name and stacked side by side in the mapping's
    iteration order.

    Targets are the columns starting with ``enc_output_colors``,
    ``enc_output_height`` or ``enc_output_width``.

    Inputs are the ``enc_`` / ``f_`` columns whose name mentions the relevant
    quantity (``"color"``; ``"height"`` or ``"grid"``; ``"width"`` or
    ``"grid"``), excluding the matching target columns.

    The exclusion previously used ``~key.startswith(...)``. ``~`` is bitwise
    NOT, so it yields ``-2`` or ``-1`` (both truthy) and never filtered
    anything, leaving the targets among the inputs. The width inputs also named
    the height target prefix. Both are corrected here.

    Args:
        T: mapping of column name to tensor.
            NOTE(review): assumes each tensor is 1-D with one entry per row, so
            every stacked result is ``(rows, columns)`` — confirm against
            ``to_torch()``.

    Returns:
        Tuple ``(input_colors, input_height, input_width, output_colors,
        output_height, output_width)``.
    """

    def stack_columns(keep):
        # Stack the selected columns, then transpose so the columns end up on
        # the second axis.
        selected = [tensor for key, tensor in T.items() if keep(key)]
        return torch.stack(selected, dim=0).transpose(0, 1)

    def is_feature(key):
        return key.startswith(("enc_", "f_"))

    input_colors_t = stack_columns(
        lambda key: "color" in key
        and is_feature(key)
        and not key.startswith("enc_output_colors")
    )
    input_height_t = stack_columns(
        lambda key: ("height" in key or "grid" in key)
        and is_feature(key)
        and not key.startswith("enc_output_height")
    )
    input_width_t = stack_columns(
        lambda key: ("width" in key or "grid" in key)
        and is_feature(key)
        and not key.startswith("enc_output_width")
    )
    output_colors_t = stack_columns(lambda key: key.startswith("enc_output_colors"))
    output_height_t = stack_columns(lambda key: key.startswith("enc_output_height"))
    output_width_t = stack_columns(lambda key: key.startswith("enc_output_width"))

    return (
        input_colors_t,
        input_height_t,
        input_width_t,
        output_colors_t,
        output_height_t,
        output_width_t,
    )

In [None]:
# Convert both tables to tensors and split them into inputs / targets.
# NOTE(review): `to_torch()` presumably returns a mapping of column name to
# tensor (tensors() iterates it with `.items()`) — confirm.
(
    train_input_colors_t,
    train_input_height_t,
    train_input_width_t,
    train_output_colors_t,
    train_output_height_t,
    train_output_width_t,
) = tensors(Train.to_torch())
(
    test_input_colors_t,
    test_input_height_t,
    test_input_width_t,
    test_output_colors_t,
    test_output_height_t,
    test_output_width_t,
) = tensors(Test.to_torch())

In [None]:
# Which target the MLP is trained on: "colors", "height" or "width".
model_type = "colors"

# (X_train, y_train, X_test, y_test) for each supported target.
_splits = {
    "colors": (
        train_input_colors_t,
        train_output_colors_t,
        test_input_colors_t,
        test_output_colors_t,
    ),
    "height": (
        train_input_height_t,
        train_output_height_t,
        test_input_height_t,
        test_output_height_t,
    ),
    "width": (
        train_input_width_t,
        train_output_width_t,
        test_input_width_t,
        test_output_width_t,
    ),
}
X_train, y_train, X_test, y_test = _splits[model_type]

In [None]:
# Inspect the training input tensor.
X_train

In [None]:
# Shape of the training inputs.
X_train.shape

In [None]:
# Shape of the test inputs.
X_test.shape

In [None]:
# Shape of the training targets.
y_train.shape

In [None]:
# Shape of the test targets.
y_test.shape

In [None]:
class MLP(L.LightningModule):
    """Fully connected network made of an encoder, a hidden block and a decoder.

    The input and output widths are read from the notebook-level ``X_train``
    and ``y_train`` tensors when the model is constructed, so those must be
    defined before instantiating (or loading) the model. Training minimises
    the mean squared error between the network output and the target.

    Args:
        lr: learning rate handed to the Adam optimizer.
    """

    def __init__(self, lr=1e-3):
        super().__init__()

        wide, narrow, inner = 64, 16, 16
        n_inputs = X_train.shape[1]
        n_outputs = y_train.shape[1]

        # Sub-modules are created in encoder -> hidden -> decoder order.
        self.encoder = nn.Sequential(
            nn.Linear(n_inputs, wide), nn.ReLU(), nn.Linear(wide, narrow)
        )
        self.hidden = nn.Sequential(
            nn.Linear(narrow, inner), nn.ReLU(), nn.Linear(inner, narrow)
        )
        self.decoder = nn.Sequential(
            nn.Linear(narrow, wide), nn.ReLU(), nn.Linear(wide, n_outputs)
        )

        self.lr = lr

    def forward(self, x):
        """Run ``x`` through encoder, hidden block and decoder in turn."""
        return self.decoder(self.hidden(self.encoder(x)))

    def training_step(self, batch, batch_idx):
        """Compute and log the MSE loss for one ``(inputs, targets)`` batch."""
        inputs, targets = batch

        loss = nn.functional.mse_loss(self.forward(inputs), targets)
        self.log("train_loss", loss)
        return loss

    def predict(self, x):
        """Return the network output rounded to the nearest integer."""
        return torch.round(self.forward(x))

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [None]:
# Instantiate the model with the default learning rate (1e-3).
mlp = MLP()

In [None]:
# X_* / y_* are already tensors. Copy and cast them with
# clone().detach().to(...) rather than torch.tensor(existing_tensor), which is
# discouraged and emits a UserWarning.
train_dataset = TensorDataset(
    X_train.clone().detach().to(torch.float32),
    y_train.clone().detach().to(torch.float32),
)
test_dataset = TensorDataset(
    X_test.clone().detach().to(torch.float32),
    y_test.clone().detach().to(torch.float32),
)

# Shuffle only the training batches; keep test order stable for evaluation.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Train the model. accelerator="auto" lets Lightning pick whatever device the
# machine offers instead of failing where Apple's MPS backend is missing.
# Logs and checkpoints go under lightning_logs/<model_type>/version_N.
trainer = L.Trainer(
    accelerator="auto",
    max_epochs=100,
    log_every_n_steps=10,
    logger=L.pytorch.loggers.TensorBoardLogger("lightning_logs", name=model_type),
)
trainer.fit(mlp, train_loader)

In [None]:
# Locate the checkpoint written by the most recent training run.
ckpts = glob.glob(f"lightning_logs/{model_type}/version_*/checkpoints/*.ckpt")
if not ckpts:
    raise FileNotFoundError(
        f"no checkpoints found under lightning_logs/{model_type}; "
        "run trainer.fit first"
    )


def _checkpoint_version(path):
    """Return N from the ``version_N`` directory of a checkpoint path."""
    return int(path.replace(os.sep, "/").split("version_")[-1].split("/")[0])


# Highest version wins; modification time breaks ties inside one version.
latest = max(
    ckpts, key=lambda path: (_checkpoint_version(path), os.path.getmtime(path))
)
print(f"using latest checkpoint: {latest}")
model = MLP.load_from_checkpoint(
    latest,
)
# Prefer Apple's MPS backend, falling back to CPU when it is unavailable.
model.to("mps" if torch.backends.mps.is_available() else "cpu")
model

In [None]:
# Shape of the inputs about to be scored.
X_test.shape

In [None]:
# Score the test inputs without tracking gradients. X_test is already a
# tensor, so cast it and move it to the model's own device rather than
# rebuilding it with torch.tensor(...) on a hard-coded "mps" device.
with torch.no_grad():
    y_hat = model.predict(X_test.to(device=model.device, dtype=torch.float32))

In [None]:
# Rounded predictions for the test inputs.
y_hat

In [None]:
# compute mse
# y_hat holds the rounded predictions, so this is the error of the discretised
# output. y_test is already a tensor: cast it instead of re-wrapping it with
# torch.tensor(...), which emits a copy-construct warning.
mse = nn.functional.mse_loss(y_hat.cpu(), y_test.to(torch.float32))
mse

In [None]:
# compute accuracy
# Fraction of individual target entries that the rounded prediction matches
# exactly. y_test is already a tensor: cast it instead of re-wrapping it with
# torch.tensor(...).
acc = (y_hat.cpu() == y_test.to(torch.float32)).float().mean()
acc