In [None]:
# import wandb
# run = wandb.init()
# artifact = run.use_artifact('karanravindra/mnist-autoencoder/model-o31mbhqq:v0', type='model')
# artifact_dir = artifact.download()
# run.finish()

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torchinfo import summary
from nn_zoo.models.components import DepthwiseSeparableConv2d, VectorQuantizer
from torchmetrics.functional.image import structural_similarity_index_measure as ssim_func
import lpips

import warnings

warnings.filterwarnings("ignore")


class Block(nn.Module):
    """Stack of ``num_layers`` norm -> GELU -> conv units with residual links.

    Every unit is ``GroupNorm(1, C)`` -> ``GELU`` ->
    ``DepthwiseSeparableConv2d(C, out_channels, 3)``. The first unit maps
    ``in_channels`` to ``out_channels`` and is applied without a skip
    connection; each later unit keeps ``out_channels`` and its output is added
    to its own input.

    Args:
        in_channels: channel count expected by the first unit.
        out_channels: channel count produced by every unit.
        num_layers: number of units; ``forward`` indexes unit 0, so at least
            one is required.
    """

    def __init__(self, in_channels: int, out_channels: int, num_layers: int):
        super().__init__()
        units = []
        for index in range(num_layers):
            # Only the very first unit sees the block's input width.
            unit_in = in_channels if index == 0 else out_channels
            units.append(self._block(unit_in, out_channels))
        self.layers = nn.ModuleList(units)

    def _block(self, in_channels: int, out_channels: int):
        """Build a single normalise -> activate -> convolve unit."""
        norm = nn.GroupNorm(1, in_channels)
        activation = nn.GELU()
        # Third argument is presumably the kernel size -- TODO confirm against nn_zoo.
        conv = DepthwiseSeparableConv2d(in_channels, out_channels, 3)
        return nn.Sequential(norm, activation, conv)

    def forward(self, x):
        """Run the first unit plainly, then every remaining unit residually."""
        out = self.layers[0](x)
        for position in range(1, len(self.layers)):
            out = self.layers[position](out) + out
        return out


class DownBlock(nn.Sequential):
    """Encoder stage: a conv ``Block`` followed by ``nn.PixelUnshuffle(2)``.

    The ``Block`` is built with ``in_channels * 4`` input channels and
    ``out_channels`` output channels. The pixel-unshuffle then folds each 2x2
    spatial neighbourhood into the channel axis, so the stage emits
    ``out_channels * 4`` channels at half the spatial resolution.
    """

    def __init__(self, in_channels: int, out_channels: int, depth: int):
        conv_stage = Block(in_channels * 4, out_channels, depth)
        # nn.MaxPool2d(2) was the alternative downsampler tried here.
        downsample = nn.PixelUnshuffle(2)
        super().__init__(conv_stage, downsample)


class UpBlock(nn.Sequential):
    """Decoder stage: ``nn.PixelShuffle(2)`` followed by a conv ``Block``.

    The pixel-shuffle trades a factor of four in channels for a doubling of
    both spatial sides; the ``Block`` then maps ``in_channels`` to
    ``out_channels * 4``. Both layers live under the ``block`` attribute.
    """

    def __init__(self, in_channels: int, out_channels: int, depth: int):
        super().__init__()
        # nn.Upsample(scale_factor=2, mode="nearest") was the alternative upsampler.
        upsample = nn.PixelShuffle(2)
        conv_stage = Block(in_channels, out_channels * 4, depth)
        self.block = nn.Sequential(upsample, conv_stage)

    def forward(self, x):
        """Upsample ``x`` and run it through the conv block."""
        return self.block(x)


class AutoEncoder(nn.Module):
    """Convolutional autoencoder for single-channel images scaled to [-1, 1].

    The encoder is a stem ``Block`` (1 -> ``width * 4`` channels), three
    ``DownBlock`` stages (each ends in a 2x pixel-unshuffle) and a final
    ``DepthwiseSeparableConv2d`` called with ``width * 4`` and ``width * 2``
    (presumably in/out channels -- TODO confirm against nn_zoo). The decoder
    mirrors it with three ``UpBlock`` stages and ends in ``Tanh``.

    ``proj_in``, ``vq`` and ``proj_out`` are currently ``nn.Identity``
    placeholders kept where a vector quantiser could be plugged in.

    An LPIPS network is registered as the ``lpips`` submodule, so it is part
    of ``state_dict()`` and moves with the model across devices.

    Args:
        width: base channel multiplier.
        depth: number of units inside every ``Block``.
    """

    # Peak-to-peak range of the image data: Tanh output / inputs live in [-1, 1].
    DATA_RANGE = 2.0

    def __init__(self, width: int, depth: int):
        super().__init__()
        self.encoder = nn.Sequential(
            Block(1, width * 4, depth),
            DownBlock(width, width, depth),
            DownBlock(width, width, depth),
            DownBlock(width, width, depth),
            DepthwiseSeparableConv2d(width * 4, width * 2, 1, padding=0),
        )
        self.proj_in = nn.Identity()
        self.vq = nn.Identity()
        # VectorQuantizer(width, 8, use_ema=True, decay=0.99, epsilon=1e-5)
        self.proj_out = nn.Identity()  # nn.Conv2d(width, width, 1)
        self.decoder = nn.Sequential(
            DepthwiseSeparableConv2d(width * 2, width * 4, 1, padding=0),
            UpBlock(width, width, depth),
            UpBlock(width, width, depth),
            UpBlock(width, width, depth),
            Block(width * 4, 1, depth),
            nn.Tanh(),
        )

        self.register_module(
            "lpips", lpips.LPIPS(net="squeeze", verbose=False, lpips=False)
        )

    def encode(self, x):
        """Map an image batch to its latent representation."""
        x = self.encoder(x)
        x = self.proj_in(x)
        # With a real quantiser this would be:
        # quant_x, dict_loss, commit_loss, indices = self.vq(x)
        return self.vq(x)

    def decode(self, x):
        """Map a latent batch back to image space (values in [-1, 1] via Tanh)."""
        x = self.proj_out(x)
        x = self.decoder(x)
        return x

    def forward(self, x):
        """Encode then decode ``x``."""
        x = self.encode(x)
        x = self.decode(x)
        return x

    def loss(self, x, y):
        """Compute the training loss and monitoring metrics.

        Args:
            x: reconstruction in [-1, 1] (used as the BCE *input*).
            y: reference image in [-1, 1] (used as the BCE *target*).

        Returns:
            Dict of scalar tensors with keys ``loss`` (``bce + lpips``),
            ``bce``, ``mse``, ``ssim``, ``psnr`` and ``lpips``.
        """
        mse = F.mse_loss(x, y)
        # BCE needs probabilities, so both images are rescaled from [-1, 1] to [0, 1].
        bce = F.binary_cross_entropy((x + 1) / 2, (y + 1) / 2)
        # PSNR = 10 * log10(range^2 / MSE). The MSE above is measured on
        # [-1, 1] data, whose range is 2; using a peak of 1 would under-report
        # PSNR by 10 * log10(4) ~= 6.02 dB.
        psnr = 10 * (self.DATA_RANGE**2 / mse).log10()
        ssim = ssim_func(x, y)
        # LPIPS is fed three-channel images; the single channel is tiled.
        # Named `perceptual` so the imported `lpips` module is not shadowed.
        perceptual = self.lpips(x.repeat(1, 3, 1, 1), y.repeat(1, 3, 1, 1)).mean()

        return {
            "loss": bce + perceptual,
            "bce": bce,
            "mse": mse,
            "ssim": ssim,
            "psnr": psnr,
            "lpips": perceptual,
        }


In [None]:
# Rebuild the architecture (width=4, depth=2) and restore the downloaded weights.
model = AutoEncoder(4, 2)

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a source you trust.
# The model is still on the CPU here, so the tensors are mapped to the CPU as
# well; load_state_dict copies them into the existing parameters either way.
raw_state_dict = torch.load(
    "artifacts/model-o31mbhqq:v0/model.ckpt",
    map_location="cpu",
)["state_dict"]

# Checkpoint keys carry a leading "model." (presumably the attribute name of
# the training wrapper -- TODO confirm). Strip it only as a prefix:
# str.replace would also delete any "model." that occurs deeper inside a key.
prefix = "model."
state_dict = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in raw_state_dict.items()
}
model.load_state_dict(state_dict)

In [None]:
import torchvision
from nn_zoo.datamodules import MNISTDataModule

# Preprocessing: resize to 32x32, convert to a tensor, then rescale with
# x * 2 - 1 (ToTensor yields [0, 1], so the result lies in [-1, 1]).
image_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((32, 32)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(lambda img: img * 2 - 1),
    ]
)
dataset_params = {"download": True, "transform": image_transform}
loader_params = {"batch_size": 128}

dm = MNISTDataModule(
    data_dir="../../../data",
    dataset_params=dataset_params,
    loader_params=loader_params,
)
dm.setup()

train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()

In [None]:
class RandomMask:
    """Overwrite one randomly placed square patch of a batch with a constant.

    The same patch position is used for every image and channel in the batch.

    Args:
        patch_size: side length of the square patch in pixels; ``0`` disables
            masking (the input is returned as an unchanged copy).
        mask_prob: stored on the instance but not used by ``__call__``; kept so
            existing callers that pass it keep working.
        mask_value: value written into the patch (default ``-1``).
    """

    # The patch's top-left corner is drawn from [_MARGIN, side - _MARGIN).
    _MARGIN = 2

    def __init__(self, patch_size, mask_prob=0.5, mask_value=-1):
        self.patch_size = (patch_size, patch_size)
        self.mask_prob = mask_prob
        self.mask_value = mask_value

    def __call__(self, x):
        """Return a copy of ``x`` whose random patch is set to ``mask_value``.

        Args:
            x: tensor whose last two dimensions are (height, width).

        Returns:
            A new tensor with the shape and dtype of ``x``; ``x`` itself is not
            modified. A patch that overhangs the border is clipped by slicing.

        Raises:
            RuntimeError: from ``torch.randint`` when a side is too short to
                leave any start position between the margins.
        """
        height, width = x.shape[-2], x.shape[-1]
        patch_h, patch_w = self.patch_size

        # Rows are sampled against the height and columns against the width so
        # non-square inputs are handled. The offsets are plain Python ints
        # because they are only used as slice bounds.
        top = int(torch.randint(self._MARGIN, height - self._MARGIN, (1,)))
        left = int(torch.randint(self._MARGIN, width - self._MARGIN, (1,)))

        # Fill the patch instead of multiplying by the mask value: a product
        # with -1 would flip the sign of the pixels rather than mask them.
        masked = x.clone()
        masked[..., top : top + patch_h, left : left + patch_w] = self.mask_value
        return masked
    

In [None]:
import matplotlib.pyplot as plt

# Side-by-side preview of one validation batch: originals on the left and the
# output of RandomMask on the right. Images are shifted from [-1, 1] back to
# [0, 1] for display.
x, y = next(iter(val_loader))

masker = RandomMask(0)
x_masked = masker(x)

for panel, batch in enumerate((x, x_masked), start=1):
    plt.subplot(1, 2, panel)
    grid = torchvision.utils.make_grid((batch + 1) / 2, nrow=8)
    plt.imshow(grid.permute(1, 2, 0).cpu().numpy())
    plt.axis('off')

plt.show()

In [None]:
# Move the model's parameters and buffers to the "mps" device (Apple Metal backend).
model = model.to("mps")

In [None]:
from tqdm import tqdm

@torch.no_grad()
def evaluate(model, loader, masker: int = 0):
    """Average ``model.loss`` metrics over every batch of ``loader``.

    The model is switched to eval mode (and left there). Each batch is moved
    to the device of the model's first parameter, passed through
    ``RandomMask(patch_size=masker)``, reconstructed, and scored against the
    unmasked batch.

    Args:
        model: module exposing ``loss(prediction, target)`` that returns a
            dict of scalar tensors.
        loader: iterable of ``(inputs, labels)`` batches; labels are ignored.
        masker: patch size handed to ``RandomMask``; ``0`` (default) leaves
            the inputs unmasked.

    Returns:
        Dict with the same keys as ``model.loss``, each value being the
        unweighted mean over batches. Empty dict when ``loader`` has no batches.
    """
    model.eval()

    # Follow the model instead of hard-coding a device, so the function works
    # wherever the model currently lives. A model with no parameters leaves
    # the batches where they are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    patch_masker = RandomMask(patch_size=masker)
    metrics = []
    for x, _ in tqdm(loader, desc="Evaluating"):
        if device is not None:
            x = x.to(device)
        x_masked = patch_masker(x.clone())
        y_hat = model(x_masked)
        metrics.append(model.loss(y_hat, x))

    if not metrics:
        # Nothing was evaluated; avoid indexing into an empty list.
        return {}

    return {k: sum(m[k] for m in metrics) / len(metrics) for k in metrics[0]}

# Validation metrics of the current `model`; the default masker=0 means a zero-sized patch, i.e. unmasked inputs.
evaluate(model, val_loader)

In [None]:
# Same validation pass as the previous cell, run again on the current `model`.
evaluate(model, val_loader)

In [None]:
from ema_pytorch import EMA

# Adam over everything returned by model.parameters(), learning rate 2e-4.
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
# ema_pytorch EMA wrapper around `model`; the training loop calls ema.update() after each optimizer step.
ema = EMA(model, beta=0.9999, update_after_step=100, update_every=10)

In [None]:
# Train for 10 epochs: each batch is reconstructed from its (optionally masked)
# version and scored against the unmasked batch; validation runs after every epoch.
for epoch in range(10):
    # Patch size 0 -> the masker leaves the batch unchanged.
    masker = RandomMask(0)
    # evaluate() below switches the model to eval mode, so re-enable train mode each epoch.
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
    for x, y in pbar:  # labels `y` are not used
        x = x.to("mps")
        x_masked = masker(x.clone())
        optimizer.zero_grad()
        y_hat = model(x_masked)
        # Target is the original batch `x`, not the masked one.
        loss = model.loss(y_hat, x)
        # Only the "loss" entry is back-propagated; the other entries are shown for monitoring.
        loss['loss'].backward()
        optimizer.step()

        # NOTE(review): the EMA weights are updated here but never evaluated in the visible cells.
        ema.update()
        # Show every metric of the current batch in the progress bar.
        pbar.set_postfix_str("".join([f"{k}: {v.item():.2f} " for k, v in loss.items()]))


    # End-of-epoch validation of the raw (non-EMA) model.
    metrics = evaluate(model, val_loader)
    print("".join([f"{k}: {v.item():.4f} " for k, v in metrics.items()]))