In [None]:
import sys
sys.path.append("..")

from pathlib import Path
import math

from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
from torchvision.utils import make_grid

from src.datasets.ca import *
from src.models.rbm import RBM
from src.models.cnn import ConvAutoEncoder

In [None]:
# Build the cellular-automaton dataset used by every later cell.
# (32, 32) is presumably the grid size of each generated image - TODO confirm
# against TotalCADataset. The commented-out keyword arguments are left at the
# class defaults; what they control is defined in src.datasets.ca.
ds = TotalCADataset(
    (32, 32), 
    seed=23,
    #num_repetitions=8,
    #num_iterations=[1, 20],
    #init_prob=[0, 1],
    #rules=["3-23", "124-45"],
    dtype=torch.float,
)
# Show the number of dataset items as the cell output.
len(ds)

In [None]:
def plot_images(iterable, num: int = 8*8, nrow: int = 8):
    """Tile the first items of ``iterable`` into a single PIL image.

    Each item may be a tensor or a list/tuple whose first element is the
    tensor (any further elements are ignored). 2-D tensors get a leading
    dimension of size one so that every entry handed to ``make_grid`` has at
    least three dimensions.

    Args:
        iterable: Source of images; consumed only as far as needed.
        num: Stop collecting once this many images have been gathered.
            At least one item is always taken before the limit is checked.
        nrow: Forwarded to ``make_grid`` as its ``nrow`` argument.

    Returns:
        The result of ``VF.to_pil_image`` applied to the assembled grid.
    """
    def _as_image(item):
        # Dataset/DataLoader items arrive as sequences; the image comes first.
        tensor = item[0] if isinstance(item, (list, tuple)) else item
        return tensor.unsqueeze(0) if tensor.ndim == 2 else tensor

    collected = []
    for count, item in enumerate(iterable, start=1):
        collected.append(_as_image(item))
        if count >= num:
            break

    grid = make_grid(collected, nrow=nrow)
    return VF.to_pil_image(grid)

# Preview the dataset: wrap `ds` in a DataLoader with default settings and let
# plot_images tile the first items it yields (64 with the default `num`).
dl = DataLoader(ds)
plot_images(dl)

In [None]:
# File locations, relative to the notebook's working directory.
# MODEL_FILE is the checkpoint whose "state_dict" entry is loaded into the model.
# FEATURES_FILE, IMAGES_FILE and RULES_FILE act as an on-disk cache for the
# feature-extraction cell, which writes them when any of the three is missing.
MODEL_FILE = "../checkpoints/ae-ca-32x32x32-fft/snapshot.pt"
FEATURES_FILE = "../checkpoints/ae-ca-32x32x32-fft/ca-features.pt"
IMAGES_FILE = "../datasets/ca-32x32.pt"
RULES_FILE = "../datasets/ca-32x32-rules.pt"

In [None]:
#model = RBM(math.prod(ds.shape), 32)
# Autoencoder whose `encode` method produces the features used in this notebook.
# The constructor arguments have to match the architecture stored in MODEL_FILE,
# otherwise load_state_dict reports missing or unexpected keys.
model = ConvAutoEncoder((1, 32, 32), channels=[32, 64], code_size=32)
# The model is only used for inference here, so disable training-only behaviour
# (dropout, batch-norm statistic updates) in case the network contains any.
model.eval()
# map_location="cpu": the model is never moved off the CPU, and this keeps the
# load working on machines that lack the device the checkpoint was saved from.
# load_state_dict stays the last expression so its key report is the cell output.
model.load_state_dict(torch.load(MODEL_FILE, map_location="cpu")["state_dict"])

In [None]:
# store ca -> features
#
# Encode every dataset image with the model and cache the results on disk.
# After this cell three tensors are defined, whichever branch ran:
#   features - concatenated outputs of model.encode
#   images   - the encoded images, reshaped to (N, 1, 32, 32)
#   rules    - the second element of every dataset batch, concatenated

if Path(FEATURES_FILE).exists() and Path(IMAGES_FILE).exists() and Path(RULES_FILE).exists():
    features = torch.load(FEATURES_FILE)
    images = torch.load(IMAGES_FILE)
    # `rules` has to be restored too: the cluster-display cell reads it, and
    # without this line a fresh kernel with a warm cache fails with NameError.
    rules = torch.load(RULES_FILE)

else:
    # Sharing strategy for tensors handed over by the worker processes;
    # presumably chosen to avoid exhausting file descriptors - TODO confirm.
    torch.multiprocessing.set_sharing_strategy('file_system')
    dl = DataLoader(ds, batch_size=100, num_workers=3)

    feature_array = []
    image_array = []
    rules_array = []

    # Inference only: no autograd graph is needed for the encoder outputs.
    with torch.no_grad():
        # Per-batch names are kept distinct from the final `images` / `rules`
        # tensors so that the loop never leaves a single batch bound to them.
        for batch_images, batch_rules in tqdm(dl):
            batch_images = batch_images.reshape(-1, 1, 32, 32)
            feature_array.append(model.encode(batch_images))
            image_array.append(batch_images)
            rules_array.append(batch_rules)

    features = torch.cat(feature_array)
    images = torch.cat(image_array)
    rules = torch.cat(rules_array)

    # Create missing cache directories first so that the (slow) encoding pass
    # is not lost to a failing save.
    for cache_file in (FEATURES_FILE, IMAGES_FILE, RULES_FILE):
        Path(cache_file).parent.mkdir(parents=True, exist_ok=True)
    torch.save(features, FEATURES_FILE)
    torch.save(images, IMAGES_FILE)
    torch.save(rules, RULES_FILE)

features.shape

In [None]:
from sklearn.decomposition import PCA
from sklearn import cluster
from IPython.display import display, HTML
import plotly.express as px

In [None]:
# Cluster the encoded features into 32 groups and plot the size of each group.
#clusterer = cluster.BisectingKMeans(32, verbose=0, n_init="auto")
# random_state pins the k-means++ initialisation so the cluster numbering is
# identical on every run; unseeded, "label #k" refers to a different cluster
# each time the notebook is executed. 23 mirrors the seed used for the dataset.
clusterer = cluster.BisectingKMeans(32, verbose=0, init="k-means++", random_state=23)
#clusterer.fit(images.reshape(images.shape[0], -1))
clusterer.fit(features)

#labels = clusterer.predict(images.reshape(images.shape[0], -1))
labels = clusterer.predict(features)

# Number of samples per label; minlength keeps a zero entry for labels that
# were never predicted so the bar index always equals the cluster label.
cluster_sizes = np.bincount(labels, minlength=clusterer.n_clusters)
px.bar(y=cluster_sizes, title="cluster sizes")

In [None]:

# For every cluster show: a heading with the cluster size, a grid with up to
# 21*3 of its images, and the cluster's averaged rules rendered as an image.
for label in range(clusterer.n_clusters):
    in_cluster = labels == label

    cluster_images = images[in_cluster]
    heading = f"<h3>label #{label}: {cluster_images.shape[0]}</h3>"
    display(HTML(heading))

    preview = plot_images(cluster_images, nrow=21, num=21*3)
    display(preview)

    # Average the rules of all cluster members, view the result as a
    # (1, 2, 9) tensor and enlarge it without smoothing so it stays readable.
    mean_rules = rules[in_cluster].mean(axis=0)
    cluster_rules = VF.resize(
        mean_rules.reshape(1, 2, 9),
        (40, 180),
        interpolation=VF.InterpolationMode.NEAREST,
    )
    display(VF.to_pil_image(cluster_rules))