In [None]:
import wandb

In [None]:
# Open a wandb run and pull the trained autoencoder checkpoint artifact from it.
# `artifact_dir` is the local directory the artifact files were downloaded into.
run = wandb.init()
artifact = run.use_artifact(
    'karanravindra/mnist-autoencoder/model-004p91an:v0',
    type='model',
)
artifact_dir = artifact.download()

In [None]:
from einops.layers.torch import Rearrange
from torch import nn
from torch.nn import functional as F
from torchinfo import summary
from nn_zoo.models.components import DepthwiseSeparableConv2d, VectorQuantizer


class Block(nn.Module):
    """Stack of ``num_layers`` GroupNorm -> DepthwiseSeparableConv2d -> GELU units.

    The first unit maps ``in_channels`` to ``out_channels`` and is applied
    without a skip connection (the channel count may change there). Every
    following unit maps ``out_channels`` to ``out_channels`` and is added
    residually (``x = x + unit(x)``).

    Args:
        in_channels: channel count of the input tensor.
        out_channels: channel count produced by every unit.
        num_layers: number of units; must be at least 1.

    Raises:
        ValueError: if ``num_layers`` < 1 (``forward`` always needs ``layers[0]``).
    """

    def __init__(self, in_channels: int, out_channels: int, num_layers: int):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        # Same construction order as before, so parameter names
        # (layers.0, layers.1, ...) and init RNG consumption are unchanged.
        self.layers = nn.ModuleList(
            [self._block(in_channels, out_channels)]
            + [self._block(out_channels, out_channels) for _ in range(num_layers - 1)]
        )

    @staticmethod
    def _num_groups(num_channels: int) -> int:
        """Return a GroupNorm group count that evenly divides ``num_channels``.

        Targets ``num_channels // 4`` (or 1 below 4 channels), exactly as before,
        but steps down to the next smaller divisor when that target does not
        divide the channel count: nn.GroupNorm rejects e.g. 14 channels in 3
        groups or 9 channels in 2 groups.
        """
        groups = max(num_channels // 4, 1)
        while num_channels % groups:
            groups -= 1
        return groups

    def _block(self, in_channels: int, out_channels: int) -> nn.Sequential:
        """Build one GroupNorm -> DepthwiseSeparableConv2d(..., 3) -> GELU unit."""
        return nn.Sequential(
            nn.GroupNorm(self._num_groups(in_channels), in_channels),
            DepthwiseSeparableConv2d(in_channels, out_channels, 3),
            nn.GELU(),
        )

    def forward(self, x):
        """Apply the first unit, then every remaining unit as a residual update."""
        x = self.layers[0](x)
        for layer in self.layers[1:]:
            x = x + layer(x)
        return x


class AutoEncoder(nn.Module):
    """Convolutional autoencoder for single-channel images with a pluggable quantizer slot.

    The encoder halves H and W three times (``MaxPool2d(2)``) and ends with
    ``width * 4`` channels. The decoder doubles H and W three times
    (``Upsample(scale_factor=2)``), ends with one channel and applies a sigmoid,
    so outputs lie in [0, 1]. NOTE(review): the overall factor of 8 assumes
    DepthwiseSeparableConv2d preserves spatial size; confirm.

    ``self.vq`` starts as ``nn.Identity``, so the model can be trained or loaded
    without quantization; a quantizer module can be assigned to ``model.vq`` later.

    Args:
        width: base channel count; the stages use width, 2*width and 4*width channels.
        depth: number of units in every ``Block``.
        num_embeddings: not used by this constructor (``vq`` is an Identity).
            NOTE(review): presumably kept for config/checkpoint compatibility.
    """

    def __init__(self, width: int, depth: int, num_embeddings: int):
        super().__init__()
        self.encoder = nn.Sequential(
            Block(1, width, depth),
            nn.MaxPool2d(2),
            Block(width, width * 2, depth),
            nn.MaxPool2d(2),
            Block(width * 2, width * 4, depth),
            nn.MaxPool2d(2),
            Block(width * 4, width * 4, depth),
            DepthwiseSeparableConv2d(width * 4, width * 4, 3),
        )
        self.vq = nn.Identity()
        self.decoder = nn.Sequential(
            DepthwiseSeparableConv2d(width * 4, width * 4, 3),
            nn.GELU(),
            Block(width * 4, width * 4, depth),
            nn.Upsample(scale_factor=2),
            Block(width * 4, width * 2, depth),
            nn.Upsample(scale_factor=2),
            Block(width * 2, width, depth),
            nn.Upsample(scale_factor=2),
            DepthwiseSeparableConv2d(width, width, 3),
            nn.GELU(),
            DepthwiseSeparableConv2d(width, 1, 3),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode, (optionally) quantize and decode ``x``.

        Returns:
            ``(x_hat, dict_loss, commit_loss, indices)``. If ``self.vq`` returns a
            tuple/list, it is unpacked as ``(z_q, dict_loss, commit_loss, indices)``.
            If it returns a single tensor (e.g. the default ``nn.Identity``), both
            losses are zero scalars on the latent's device/dtype and ``indices``
            is None.
        """
        z = self.encoder(x)
        vq_out = self.vq(z)
        if isinstance(vq_out, (tuple, list)):
            z, dict_loss, commit_loss, indices = vq_out
        else:
            # A bare tensor (nn.Identity). Unpacking it into four names would
            # iterate over the batch dimension instead of failing loudly.
            z = vq_out
            dict_loss = z.new_zeros(())
            commit_loss = z.new_zeros(())
            indices = None
        x_hat = self.decoder(z)
        return x_hat, dict_loss, commit_loss, indices

    @classmethod
    def loss(cls, x, y):
        """Mean binary cross-entropy between reconstruction ``x`` (in [0, 1]) and target ``y``."""
        return F.binary_cross_entropy(x, y)

In [None]:
import os

import torch

model = AutoEncoder(width=4, depth=1, num_embeddings=512)

# Load from the directory wandb actually downloaded the artifact into
# (artifact_dir) instead of a hard-coded copy of that path, which breaks when the
# download root differs. map_location="cpu" lets a checkpoint saved on another
# device (e.g. CUDA) load on this machine; the model is moved to MPS below.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if this
# checkpoint fails to unpickle, pass weights_only=False (trusted files only).
checkpoint = torch.load(os.path.join(artifact_dir, "model.ckpt"), map_location="cpu")
state_dict = checkpoint["state_dict"]

# Strip only the leading "model." prefix from each key; str.replace would also
# rewrite any later occurrence of "model." inside a key.
prefix = "model."
model.load_state_dict(
    {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
)
model.eval()
model = model.to("mps")

In [None]:
# Replace the Identity placeholder with a VectorQuantizer. The first two arguments
# are presumably the embedding dimension (4 * 4 = width * 4, the encoder's output
# channels for width=4) and the codebook size (512, matching num_embeddings above).
# A newly assigned module starts in training mode (nn.Module default); later cells
# call model.eval()/model.train() before using it.
model.vq = (
    VectorQuantizer(4 * 4, 512, use_ema=True, decay=0.99, epsilon=1e-5)
    .to("mps")
)

In [None]:
from nn_zoo.datamodules import MNISTDataModule
import torchvision

# MNIST resized to 32x32, a multiple of 8, to fit the encoder's three
# MaxPool2d(2) stages; ToTensor gives float images in [0, 1].
mnist_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((32, 32)),
        torchvision.transforms.ToTensor(),
    ]
)

dm = MNISTDataModule(
    data_dir="../../data",
    dataset_params={"download": True, "transform": mnist_transform},
    loader_params={
        "batch_size": 128,
        "num_workers": 4,
        "persistent_workers": True,
        "pin_memory": True,
        "prefetch_factor": 2,
    },
)
dm.setup()

train_loader, val_loader = dm.train_dataloader(), dm.val_dataloader()

In [None]:
import matplotlib.pyplot as plt
import plotly.express as px
from sklearn.decomposition import PCA

# Encode every validation batch (labels ignored), concatenate along the batch
# dimension and bring the result to the CPU for the scikit-learn analysis below.
model.eval()
with torch.no_grad():
    latents = torch.cat(
        [model.encoder(images.to("mps")) for images, _labels in val_loader],
        dim=0,
    ).cpu()

In [None]:
# Flatten (N, C, H, W) latents into one C-dim row per spatial position,
# ordered image-major, then row, then column. Compute it once and reuse it below.
flat_latents = latents.permute(0, 2, 3, 1).reshape(-1, latents.shape[1])

pca = PCA(n_components=3)
latents_pca = pca.fit(flat_latents).transform(flat_latents)

# clusters: 512 centres, one per quantizer codebook entry. Fix random_state so the
# k-means initialisation, and therefore the codebook later copied into model.vq,
# is reproducible across runs.
from sklearn.cluster import KMeans
clusters = KMeans(n_clusters=512, random_state=0)
clusters.fit(flat_latents)

clustered = clusters.predict(flat_latents)

print(latents_pca.shape, clusters.cluster_centers_.shape)

# Each image contributes H * W consecutive latent rows, so repeat each digit label
# that many times. This was previously hard-coded as 16 for a 4x4 latent grid.
positions_per_image = latents.shape[2] * latents.shape[3]
digit_labels = val_loader.dataset.targets.repeat_interleave(positions_per_image).numpy()
# Labels come from the dataset, not the loader, so they only line up if val_loader
# visits the dataset in order and in full. At least fail clearly on a length mismatch.
assert len(digit_labels) == len(latents_pca), "digit labels do not line up with latent rows"

# Pass color=clustered instead to colour points by k-means cluster.
fig = px.scatter_3d(
    x=latents_pca[:, 0],
    y=latents_pca[:, 1],
    z=latents_pca[:, 2],
    title="PCA of latents",
    height=800,
    width=800,
    opacity=0.1,
    color=digit_labels,
    labels={"color": "Digit"},
)
fig.show()

In [None]:
# Seed the quantizer codebook with the k-means centres. cluster_centers_ has one
# row per cluster, shape (512, C), and is transposed to (C, 512) before assignment.
model.vq.e_i_ts = torch.as_tensor(
    clusters.cluster_centers_.astype("float32"),
    device="mps",
).T

In [None]:
from collections import Counter

# Tally how often each codebook index is chosen by model.vq over the validation set.
counter = Counter()
model.eval()
with torch.no_grad():
    for x, y in val_loader:
        x = x.to("mps")
        x, dict_loss, commit_loss, indices = model.vq(model.encoder(x))
        counter.update(indices.flatten().tolist())

print(counter)

In [None]:
# Adam at lr=4e-4 for the encoder and decoder. Any quantizer parameters get their
# own group with lr=0, so Adam's steps never move them.
optimizer = torch.optim.Adam(
    [
        {"params": model.encoder.parameters()},
        {"params": model.decoder.parameters()},
        {"params": model.vq.parameters(), "lr": 0},
    ],
    lr=4e-4,
)


In [None]:
# Training loop
from tqdm.auto import tqdm

num_epochs = 10

for epoch in range(num_epochs):
    # Train encoder -> decoder directly; model.vq is bypassed, so the
    # reconstruction is learned on the continuous (unquantized) latents.
    model.train()
    pbar = tqdm(train_loader)
    for x, y in pbar:
        x = x.to("mps")

        optimizer.zero_grad()
        x_hat = model.decoder(model.encoder(x))
        loss = AutoEncoder.loss(x_hat, x)
        loss.backward()
        optimizer.step()

        pbar.set_postfix_str(f"loss: {loss.item():.4f}")

    # Validate without building autograd graphs. Weight each batch's mean BCE by
    # its batch size so a short final batch does not skew the epoch average.
    model.eval()
    val_loss_sum = 0.0
    val_count = 0
    with torch.no_grad():
        for x, y in val_loader:
            x = x.to("mps")
            x_hat = model.decoder(model.encoder(x))
            val_loss_sum += AutoEncoder.loss(x_hat, x).item() * x.size(0)
            val_count += x.size(0)

    val_loss = val_loss_sum / max(val_count, 1)
    print(f"Epoch: {epoch}, Val Loss: {val_loss}")

In [None]:
# plot reconstruction
# Inference needs no gradients. Without no_grad, x_hat requires grad, and depending
# on the torchvision version the grids built from it can inherit that and then fail
# NumPy conversion inside imshow. eval() is explicit so the cell does not depend
# on the mode the training loop left the model in.
model.eval()
with torch.no_grad():
    x, y = next(iter(val_loader))
    x = x.to("mps")
    x_hat = model.decoder(model.encoder(x))

real_images = torchvision.utils.make_grid(x.cpu(), nrow=8)
reconstructed_images = torchvision.utils.make_grid(x_hat.cpu(), nrow=8)
# Residual x - x_hat, min-max rescaled by normalize=True so it can be displayed.
negative_reconstructed_images = torchvision.utils.make_grid(x - x_hat, nrow=8, normalize=True).cpu()

fig, ax = plt.subplots(1, 3, figsize=(20, 10))
panels = [
    (real_images, "Real Images"),
    (reconstructed_images, "Reconstructed Images"),
    (negative_reconstructed_images, "Negative Reconstructed Images"),
]
for axis, (grid, title) in zip(ax, panels):
    # make_grid returns (C, H, W); matplotlib expects (H, W, C).
    axis.imshow(grid.permute(1, 2, 0))
    axis.set_title(title)
    axis.axis("off")

plt.show()