In [1]:
import warnings
from collections import Counter

import torch
from torch import nn
from torch.nn import functional as F
import torchvision
from torchinfo import summary
from torchmetrics.functional.image.ssim import structural_similarity_index_measure as ssim_func
from nn_zoo.datamodules import MNISTDataModule
from nn_zoo.models.components import DepthwiseSeparableConv2d, SelfAttention

import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Installs an "ignore" filter with no category, message or module restriction,
# so warnings of every kind raised after this call are suppressed.
# NOTE(review): this also hides deprecation and numerical warnings from
# torch/torchmetrics — consider narrowing it to the specific warning intended.
warnings.filterwarnings("ignore")


class Block(nn.Module):
    """Stack of ``GroupNorm -> GELU -> DepthwiseSeparableConv2d`` layers.

    The first layer maps ``in_channels`` to ``out_channels`` and is applied
    without a skip connection. Every following layer keeps ``out_channels``
    and its output is added to its own input (residual connection).

    Args:
        in_channels: Channel count of the incoming tensor.
        out_channels: Channel count produced by every layer of the stack.
        num_layers: Number of layers in the stack; must be at least 1.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_channels: int, out_channels: int, num_layers: int):
        super().__init__()
        if num_layers < 1:
            # An empty stack would only fail later, inside forward(), with an
            # opaque IndexError; fail early with a clear message instead.
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.layers = nn.ModuleList(
            [
                self._block(in_channels if i == 0 else out_channels, out_channels)
                for i in range(num_layers)
            ]
        )

    @staticmethod
    def _num_groups(channels: int) -> int:
        """Return the GroupNorm group count to use for ``channels``.

        Starts from ``channels // 4`` (roughly four channels per group) and
        steps down to the nearest value that divides ``channels`` evenly,
        because ``nn.GroupNorm`` rejects a group count that is not a divisor
        of the channel count. For fewer than four channels the result is 1.
        """
        groups = max(channels // 4, 1)
        while channels % groups:
            groups -= 1
        return groups

    def _block(self, in_channels: int, out_channels: int):
        """Build one ``GroupNorm -> GELU -> DepthwiseSeparableConv2d`` layer."""
        return nn.Sequential(
            nn.GroupNorm(self._num_groups(in_channels), in_channels),
            nn.GELU(),
            # Third argument is 3 — presumably the kernel size; confirm
            # against nn_zoo's DepthwiseSeparableConv2d signature.
            DepthwiseSeparableConv2d(in_channels, out_channels, 3),
        )

    def forward(self, x):
        """Apply the first layer plainly, then the rest as residual layers."""
        first, *rest = self.layers
        x = first(x)
        for layer in rest:
            # The addition needs layer(x) and x to have identical shapes, so
            # the conv is assumed to preserve spatial size — TODO confirm.
            x = layer(x) + x
        return x


class DownBlock(nn.Sequential):
    """Convolutional stage followed by a 2x pixel-unshuffle downsample.

    Child 0 is a ``Block`` taking ``in_channels * 4`` channels and producing
    ``out_channels``; child 1 is ``nn.PixelUnshuffle(2)``, which folds each
    2x2 spatial patch into the channel dimension.

    Args:
        in_channels: Quarter of the channel count the ``Block`` consumes.
        out_channels: Channel count the ``Block`` produces (before unshuffle).
        depth: Number of layers inside the ``Block``.
    """

    def __init__(self, in_channels: int, out_channels: int, depth: int):
        conv_stage = Block(in_channels * 4, out_channels, depth)
        # Alternative kept for reference from the original: nn.MaxPool2d(2)
        space_to_depth = nn.PixelUnshuffle(2)
        super().__init__(conv_stage, space_to_depth)


class UpBlock(nn.Sequential):
    """2x pixel-shuffle upsample followed by a convolutional stage.

    The two stages live in a nested ``nn.Sequential`` stored as ``self.block``:
    ``nn.PixelShuffle(2)`` first, then a ``Block`` mapping ``in_channels`` to
    ``out_channels * 4``.

    Args:
        in_channels: Channel count the ``Block`` consumes (after the shuffle).
        out_channels: Quarter of the channel count the ``Block`` produces.
        depth: Number of layers inside the ``Block``.
    """

    def __init__(self, in_channels: int, out_channels: int, depth: int):
        super().__init__()
        depth_to_space = nn.PixelShuffle(2)
        # Alternative kept for reference from the original:
        # nn.Upsample(scale_factor=2, mode="nearest")
        conv_stage = Block(in_channels, out_channels * 4, depth)
        self.block = nn.Sequential(depth_to_space, conv_stage)

    def forward(self, x):
        """Run the input through the shuffle + conv pipeline."""
        upsampled = self.block(x)
        return upsampled


class AutoEncoder(nn.Module):
    """Convolutional autoencoder for single-channel images.

    Encoder: a stem ``Block`` (1 -> ``width * 4`` channels) followed by three
    ``DownBlock`` stages. Decoder: three ``UpBlock`` stages, a ``Block``
    (``width * 4`` -> 1 channel) and a final ``Sigmoid``. ``proj_in``, ``vq``
    and ``proj_out`` sit between the two and are identity pass-throughs here.

    Args:
        width: Base channel width shared by every stage.
        depth: Number of layers inside every ``Block``.
    """

    def __init__(self, width: int, depth: int):
        super().__init__()
        wide = width * 4

        self.encoder = nn.Sequential(
            Block(1, wide, depth),
            *(DownBlock(width, width, depth) for _ in range(3)),
        )

        # Bottleneck hooks; all three are no-ops in this configuration.
        self.proj_in = nn.Identity()
        self.vq = nn.Identity()
        self.proj_out = nn.Identity()

        self.decoder = nn.Sequential(
            *(UpBlock(width, width, depth) for _ in range(3)),
            Block(wide, 1, depth),
            nn.Sigmoid(),
        )

        # Re-initialise every sub-module once the whole tree is registered.
        self.apply(self._init_weights)
        # An LPIPS perceptual-loss module (squeeze net) is currently disabled.

    def _init_weights(self, m):
        """Initialise one sub-module in place (callback for ``self.apply``).

        GroupNorm gets weight 1 / bias 0. Conv2d and ConvTranspose2d get
        N(0, 0.02) weights and a zero bias when a bias exists. Any other
        module type is left untouched.
        """
        if isinstance(m, nn.GroupNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
            return
        if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            return
        nn.init.normal_(m.weight, mean=0, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, x):
        """Encode ``x``, pass it through the bottleneck and decode it."""
        latent = self.proj_in(self.encoder(x))
        quantized = self.vq(latent)
        if not isinstance(self.vq, nn.Identity):
            # A non-identity quantizer returns a sequence whose first item is
            # the tensor to decode; the remaining items are discarded.
            quantized, *_ = quantized
        return self.decoder(self.proj_out(quantized))

    def loss(self, x, y):
        """Compute the training loss and monitoring metrics for ``x`` vs ``y``.

        Returns:
            Dict with keys ``loss`` (the BCE term, same tensor as ``bce``),
            ``bce``, ``mse``, ``ssim`` and ``psnr`` (``10 * log10(1 / mse)``).
        """
        bce = F.binary_cross_entropy(x, y)
        mse = F.mse_loss(x, y)
        ssim = ssim_func(x, y)
        psnr = 10 * torch.log10(1 / mse)
        # The LPIPS term is disabled, so BCE alone is the optimisation target.
        return {
            "loss": bce,
            "bce": bce,
            "mse": mse,
            "ssim": ssim,
            "psnr": psnr,
        }

In [5]:
# Displays the raw checkpoint tensors for inspection. `state_dict` is assigned
# in the checkpoint-loading cell (In [4]) further down, so on a fresh kernel
# that cell has to be executed before this one.
state_dict

OrderedDict([('model.encoder.0.layers.0.0.weight', tensor([0.8473])),
             ('model.encoder.0.layers.0.0.bias', tensor([-0.0390])),
             ('model.encoder.0.layers.0.1.0.weight',
              tensor([[[[ 0.0616, -0.0680,  0.0486],
                        [ 0.0842,  0.1275, -0.1216],
                        [ 0.2997, -0.0231,  0.2254]]]])),
             ('model.encoder.0.layers.0.1.0.bias', tensor([-0.2283])),
             ('model.encoder.0.layers.0.1.1.weight',
              tensor([[[[-0.6881]]],
              
              
                      [[[-0.5381]]],
              
              
                      [[[ 0.2543]]],
              
              
                      [[[ 0.3509]]]])),
             ('model.encoder.0.layers.0.1.1.bias',
              tensor([ 0.1899, -0.6852,  0.5866, -0.6415])),
             ('model.encoder.0.layers.1.0.weight',
              tensor([0.8802, 1.0830, 0.8384, 1.0540])),
             ('model.encoder.0.layers.1.0.bias',
          

In [4]:
# Build the model and restore its weights from the downloaded checkpoint.
ae = AutoEncoder(width=4, depth=2)

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code — only load checkpoints that come from a trusted source.
state_dict = torch.load(
    "artifacts/model-s8lnnp8k:v0/model.ckpt", map_location=torch.device("cpu")
)["state_dict"]

# Checkpoint keys start with "model." (e.g. "model.encoder.0..."). Strip only
# that leading prefix; str.replace would also delete any later "model."
# occurrence inside a key.
ckpt_prefix = "model."
model_weights = {
    (key[len(ckpt_prefix):] if key.startswith(ckpt_prefix) else key): tensor
    for key, tensor in state_dict.items()
}

# strict=False tolerates missing/unexpected keys only; a shape mismatch still
# raises RuntimeError. The recorded traceback for this cell shows such
# mismatches, i.e. the checkpoint was saved from a different layer layout than
# the classes defined above.
err = ae.load_state_dict(model_weights, strict=False)

# Use Apple's MPS backend when present; otherwise stay on the CPU rather than
# failing on machines without it.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Inference only: move to the device, switch to eval mode, freeze parameters.
ae = ae.to(device).eval().requires_grad_(False)
print(err)

RuntimeError: Error(s) in loading state_dict for AutoEncoder:
	size mismatch for encoder.0.layers.1.0.weight: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([16]).
	size mismatch for encoder.0.layers.1.0.bias: copying a param with shape torch.Size([4]) from checkpoint, the shape in current model is torch.Size([16]).
	size mismatch for decoder.1.block.1.layers.0.0.weight: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([4]).
	size mismatch for decoder.1.block.1.layers.0.0.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([4]).
	size mismatch for decoder.2.block.1.layers.0.0.weight: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([4]).
	size mismatch for decoder.2.block.1.layers.0.0.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([4]).
	size mismatch for decoder.2.block.1.layers.1.0.weight: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([16]).
	size mismatch for decoder.2.block.1.layers.1.0.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([16]).