In [None]:
# imports
import ibis
import ibis.selectors as s
import lightning as L
import plotly.express as px
import torch
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

## local imports
from ihateai.data import read_training, transform
from ihateai.grid import InputOutputPair
from ihateai.utils import random_task_num, show_task_pairs

# configuration
px.defaults.template = "plotly_dark"

# Keep interactive ibis reprs small so displaying a table doesn't flood the output.
ibis.options.interactive = True
ibis.options.repr.interactive.max_rows = 3
ibis.options.repr.interactive.max_length = 3
ibis.options.repr.interactive.max_depth = 3
ibis.options.repr.interactive.max_columns = None

# Seed Python, NumPy and torch RNGs up front so weight initialisation and
# DataLoader shuffling (and hence the trained model) are reproducible under
# Restart Kernel -> Run All.
SEED = 42
L.seed_everything(SEED)

con = ibis.get_backend()

In [None]:
# Raw training data from the project loader, shown via the interactive repr.
# NOTE(review): presumably an ibis Table (it feeds `transform`) — confirm.
t = read_training()
t

In [None]:
# Training split produced by the project-specific `transform` (opaque here).
# Its output must provide the input/output *_colors, *_height and *_width
# columns that `ohe` reads below.
train = transform(t)
train

In [None]:
# Test split: the same transform with `test=True` (semantics defined in
# ihateai.data — not visible here).
test = transform(t, test=True)
test

In [None]:
# Inclusive bounds on grid heights/widths, and the value ranges one-hot
# encoded by `ohe`.
MIN_WIDTH = MIN_HEIGHT = 1
MAX_WIDTH = MAX_HEIGHT = 30
COLORS = range(10)  # colour indices 0..9
HEIGHTS = range(MIN_HEIGHT, MAX_HEIGHT + 1)  # 1..30 inclusive

In [None]:
def ohe(t):
    """One-hot encode colour presence and grid dimensions for each row of ``t``.

    For ``side`` in ``("input", "output")`` this appends int8 indicator columns:

    * ``ohe_{side}_colors_{c}`` for ``c`` in ``COLORS`` -- ``t[f"{side}_colors"].contains(c)``
      (NOTE(review): presumably an array column of colour ids — confirm);
    * ``ohe_{side}_height_{h}`` for ``h`` in ``HEIGHTS`` -- ``{side}_height == h``;
    * ``ohe_{side}_width_{w}`` for ``w`` in ``MIN_WIDTH..MAX_WIDTH`` -- ``{side}_width == w``.

    Columns are added in the order: input colours, output colours, input
    heights, output heights, input widths, output widths. ``tensors`` groups
    them by prefix and keeps this within-group order as feature positions.

    Parameters
    ----------
    t : ibis Table
        Must have ``input_colors``, ``output_colors``, ``input_height``,
        ``output_height``, ``input_width`` and ``output_width`` columns.

    Returns
    -------
    ibis Table
        ``t`` with the indicator columns appended.
    """
    # The width indicators used to iterate over HEIGHTS, which was only correct
    # because both ranges happen to be 1..30. Use the width bounds explicitly.
    widths = range(MIN_WIDTH, MAX_WIDTH + 1)

    def _colour_flags(side, values):
        # 1 where the side's colour collection contains `v`, else 0.
        col = t[f"{side}_colors"]
        return {f"ohe_{side}_colors_{v}": col.contains(v).cast("int8") for v in values}

    def _equal_flags(side, dim, values):
        # 1 where the side's `dim` column equals `v`, else 0.
        col = t[f"{side}_{dim}"]
        return {f"ohe_{side}_{dim}_{v}": (col == v).cast("int8") for v in values}

    return t.mutate(
        **_colour_flags("input", COLORS),
        **_colour_flags("output", COLORS),
        **_equal_flags("input", "height", HEIGHTS),
        **_equal_flags("output", "height", HEIGHTS),
        **_equal_flags("input", "width", widths),
        **_equal_flags("output", "width", widths),
    )

In [None]:
# Keep only the one-hot indicator columns produced by `ohe`, for each split.
Train, Test = (ohe(split).select(s.contains("ohe")) for split in (train, test))

In [None]:
# Inspect the one-hot training features.
Train

In [None]:
# Inspect the one-hot test features.
Test

In [None]:
# Explore structure in the one-hot grid-size features: project them onto a few
# principal components and colour the projection by k-means cluster.
# (The clustering is over height/width indicators, not colours.)
n_components = 3
n_clusters = 10

# Materialise once: sklearn needs concrete data, and this avoids executing the
# ibis query twice (once for fit, once for transform).
X = Train.select(s.contains("height"), s.contains("width")).execute()

pca = PCA(n_components=n_components, random_state=0)
pcs = pca.fit_transform(X)

# Fit on the numpy projection directly; seed so cluster labels are reproducible.
kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(pcs)

# Build the plotting table in one step so each label stays on the same row as
# its coordinates. The previous join on ibis.row_number() had no ordering, so
# the pairing of labels to points was not guaranteed. Labels are strings so
# plotly draws a discrete legend instead of a continuous colour scale.
t_pca = ibis.memtable(
    {
        "cluster": kmeans.labels_.astype(str),
        **{f"pc{i + 1}": pcs[:, i] for i in range(n_components)},
    }
)

c = px.scatter_3d(
    t_pca,
    x="pc1",
    y="pc2",
    z="pc3",
    color="cluster",
    title=f"PCA of grid-size one-hot features, coloured by k-means cluster (k={n_clusters})",
)
c.show(renderer="browser")

In [None]:
def tensors(T):
    """Group one-hot column tensors into six per-feature matrices.

    Parameters
    ----------
    T : Mapping[str, torch.Tensor]
        Column name -> column tensor (e.g. ``Table.to_torch()`` output) with
        ``ohe_{input,output}_{colors,height,width}_*`` columns.
        NOTE(review): columns are presumably 1-D of length ``n_rows``.

    Returns
    -------
    tuple of 6 torch.Tensor
        In the order input colours, input height, input width, output colours,
        output height, output width. Each has shape ``(n_rows, n_group_cols)``
        for 1-D columns, with columns in their iteration order in ``T``.
    """
    prefixes = [
        f"ohe_{side}_{feature}"
        for side in ("input", "output")
        for feature in ("colors", "height", "width")
    ]
    # stack(dim=1) equals the former stack(dim=0).transpose(0, 1) but yields a
    # contiguous tensor. Matching on "<prefix>_" avoids accidental prefix hits.
    return tuple(
        torch.stack(
            [col for key, col in T.items() if key.startswith(f"{prefix}_")],
            dim=1,
        )
        for prefix in prefixes
    )

In [None]:
# Materialise each one-hot table as {column: tensor} and split it into the six
# per-feature matrices returned by `tensors` (colours, heights, widths of the
# input grid, then the same for the output grid).
[train_input_colors_t, train_input_height_t, train_input_width_t,
 train_output_colors_t, train_output_height_t, train_output_width_t] = tensors(
    Train.to_torch()
)
[test_input_colors_t, test_input_height_t, test_input_width_t,
 test_output_colors_t, test_output_height_t, test_output_width_t] = tensors(
    Test.to_torch()
)

In [None]:
# Task: predict which colours appear in the output grid from the colours
# present in the input grid.
X_train, y_train = train_input_colors_t, train_output_colors_t
X_test, y_test = test_input_colors_t, test_output_colors_t

In [None]:
class MLP(L.LightningModule):
    """Bottlenecked MLP regressor trained with mean-squared error.

    Layout: encoder ``in_features -> 256 -> 64``, hidden ``64 -> 16 -> 64``,
    decoder ``64 -> 256 -> out_features``, with a ReLU inside each stage.

    Parameters
    ----------
    lr : float
        Adam learning rate.
    in_features, out_features : int, optional
        Input / output widths. When omitted they fall back to the notebook
        globals ``X_train.shape[1]`` / ``y_train.shape[1]`` (the original
        behaviour). The resolved values are stored with
        ``save_hyperparameters`` so ``MLP.load_from_checkpoint`` can rebuild the
        network without those globals being defined.
    """

    def __init__(self, lr=1e-3, in_features=None, out_features=None):
        super().__init__()

        # Backward-compatible fallback for `MLP()` and for checkpoints saved
        # before the sizes were recorded as hyperparameters.
        if in_features is None:
            in_features = X_train.shape[1]
        if out_features is None:
            out_features = y_train.shape[1]

        # Save the *resolved* sizes (not the None defaults) into checkpoints.
        self.save_hyperparameters(
            {"lr": lr, "in_features": in_features, "out_features": out_features}
        )

        f1, f2, f3 = 256, 64, 16

        # Same creation order and attribute names as before, so parameter
        # initialisation order and state_dict keys are unchanged.
        self.encoder = nn.Sequential(
            nn.Linear(in_features, f1), nn.ReLU(), nn.Linear(f1, f2)
        )
        self.hidden = nn.Sequential(nn.Linear(f2, f3), nn.ReLU(), nn.Linear(f3, f2))
        self.decoder = nn.Sequential(
            nn.Linear(f2, f1), nn.ReLU(), nn.Linear(f1, out_features)
        )

        self.lr = lr

    def forward(self, x):
        """Map ``x`` of shape ``(..., in_features)`` to ``(..., out_features)``."""
        z = self.encoder(x)
        z = self.hidden(z)
        return self.decoder(z)

    def training_step(self, batch, batch_idx):
        """Return (and log as ``train_loss``) the MSE for one ``(x, y)`` batch."""
        x, y = batch
        # Call the module rather than .forward so nn.Module hooks run.
        x_hat = self(x)
        loss = nn.functional.mse_loss(x_hat, y)
        self.log("train_loss", loss)
        return loss

    def predict(self, x):
        """Return hard 0/1 predictions for ``x``.

        The targets are 0/1 indicator columns. Raw regression outputs are
        rounded and then clamped to ``[0, 1]``; without the clamp, outputs
        below -0.5 or above 1.5 would round to invalid values such as -1 or 2.
        """
        return torch.round(self(x)).clamp(0, 1)

    def configure_optimizers(self):
        """Adam over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [None]:
# Fresh, untrained model; layer widths come from X_train / y_train.
mlp = MLP(lr=1e-3)

In [None]:
# X_* / y_* are already (integer) tensors, so convert them with .float().
# Wrapping an existing tensor in torch.tensor(...) copy-constructs it and emits
# a UserWarning.
BATCH_SIZE = 32

train_dataset = TensorDataset(X_train.float(), y_train.float())
test_dataset = TensorDataset(X_test.float(), y_test.float())

# Shuffle only the training data; keep test order stable for evaluation.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# "auto" uses an available accelerator (MPS/CUDA) and falls back to CPU,
# instead of failing on machines without Apple-silicon GPUs.
trainer = L.Trainer(accelerator="auto", max_epochs=100, log_every_n_steps=10)
trainer.fit(mlp, train_loader)

In [None]:
# Reload the checkpoint written by the `trainer.fit` call above. With the
# default ModelCheckpoint (no monitored metric), best_model_path is the most
# recent save. The previous hard-coded "lightning_logs/version_8/..." path only
# existed on one machine and went stale after any re-training. To load an older
# run instead, set `ckpt_path` to that file explicitly.
ckpt_path = trainer.checkpoint_callback.best_model_path

# Prefer Apple MPS (the original target), then CUDA, then CPU.
device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)

model = MLP.load_from_checkpoint(ckpt_path, map_location=device).to(device)
model.eval()
model

In [None]:
# Run on whatever device the model lives on (not a hard-coded "mps"). Convert
# the integer tensor with .to(...) rather than re-wrapping it in torch.tensor(...),
# which copy-constructs and emits a UserWarning.
with torch.no_grad():
    y_hat = model.predict(X_test.to(model.device, dtype=torch.float32))

In [None]:
# Rounded test-set predictions from `model.predict` (still on the model's device).
y_hat

In [None]:
# Compute the MSE between the rounded predictions and the targets on CPU.
# y_test is already a tensor, so .float() converts it directly; torch.tensor(...)
# would copy-construct it with a UserWarning.
mse = nn.functional.mse_loss(y_hat.cpu(), y_test.float())
mse