In [1]:
# NOTE(review): `!cmd` runs in a throwaway subshell, so this `cd` does NOT change
# the kernel's working directory; relative paths used later (e.g. "config.yaml")
# are still resolved from wherever the kernel was started. Use `%cd <dir>` if
# changing the kernel's cwd is the intent — confirm first, since the later cells
# currently work with the unchanged cwd.
! cd ~/loutrebleu/menta/stroker

In [2]:
# Re-import edited modules (e.g. stroker.model.encoder / stroker.model.decoder)
# before each cell executes, so package edits are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [3]:
import sys
import os
# Make the `stroker` package importable from this notebook. Guarded so that
# re-running the cell does not keep appending duplicate entries to sys.path.
STROKER_ROOT = os.path.expanduser("~/loutrebleu/menta/stroker")
if STROKER_ROOT not in sys.path:
    sys.path.append(STROKER_ROOT)

In [4]:
import time
import numpy as np
import xtools as xt
import torch
import torch.nn as nn
import torch.optim as toptim
import torch.utils as tutils
import torchvision.datasets as tdatasets
import torchvision.transforms as ttransforms
import matplotlib.pyplot as plt

from typing import Tuple, List
from collections import namedtuple
from IPython.display import clear_output

In [5]:
from stroker.tools.dataloader import new_dataloader

In [6]:
# Load the experiment configuration; the path is relative to the kernel's cwd.
config_file = "config.yaml"
cf = xt.Config(config_file)
# Display the raw parsed config. `_cf` is a private attribute of xt.Config
# (presumably the underlying dict, judging by the displayed output) — TODO: use a
# public accessor if xtools provides one.
cf._cf

{'dataset': {'path': '../images/marker/dev1',
  'transforms': {'crop': [128, 128], 'resize': [64, 64]},
  'batch_size': 16,
  'shuffle': True}}

In [7]:
# Build the image DataLoader from the `dataset` section of the config
# (path, transforms, batch_size, shuffle).
loader = new_dataloader(cf.dataset)

In [8]:
class DEVEncoder(nn.Module):

    def __init__(
            self,
            image_shape: Tuple[int] | List[int] | np.ndarray,
            num_layer: int,
            base_channel: int,
            latent_size: int = 50,
    ) -> None:
        super().__init__()
        self.image_shape = np.asarray(image_shape)
        self.num_layer = num_layer
        ch = np.power(2, np.arange(num_layer)) * base_channel
        ch = np.insert(ch, 0, image_shape[0])
        self.ch = ch

        self.shape_after_conv = shape_after_conv = (self.image_shape[1:] / (2 ** num_layer) / 2)
        self.size_after_conv = size_after_conv = int(shape_after_conv[0] * shape_after_conv[1] * ch[-1])

        self.convs = nn.Sequential(*[
            self._make_layer(ich, och)
            for ich, och in zip(ch[:-1], ch[1:])
        ])

        self.layer_flt = nn.Sequential(
            nn.Flatten(),
            nn.Linear(size_after_conv, 100),
            nn.ReLU(),
        )
        self.layer_ave = nn.Linear(100, latent_size)
        self.layer_dev = nn.Linear(100, latent_size)
    
    def _make_layer(self, in_ch, out_ch=None):
        out_ch = 2 * in_ch if out_ch is None else out_ch
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )
    
    def forward(self, x):
        x = self.convs(x)
        x = self.layer_flt(x)

        mean = self.layer_ave(x)
        log_dev = self.layer_dev(x)

        eps = torch.rand_like(mean)
        feat = mean + torch.exp(log_dev / 2) * eps
        return feat, mean, log_dev


In [9]:
class DEVDecoder(nn.Module):
    """Development VAE decoder that mirrors an encoder's conv stack.

    Maps a latent vector through an MLP to the encoder's flattened conv size,
    reshapes it to ``(ch[0], *encoder.shape_after_conv)``, applies one
    ``Upsample(x2) -> Conv -> BatchNorm -> ReLU`` block per encoder conv layer,
    then a 3x3 conv + sigmoid producing ``encoder.image_shape[0]`` channels.

    Args:
        encoder: the encoder to mirror; only its ``ch``, ``shape_after_conv``,
            ``latent_size``, ``size_after_conv`` and ``image_shape`` are read.
    """

    def __init__(self, encoder: "DEVEncoder") -> None:
        super().__init__()
        # Work on a private copy: for a numpy array, ``encoder.ch[::-1]`` is a
        # *view*, and assigning into it would overwrite the encoder's own ``ch``.
        ch = np.array(encoder.ch)[::-1].copy()
        # The last upsampling block keeps the previous width; layer_recons does
        # the final mapping to image channels.
        ch[-1] = ch[-2]
        self.ch = ch

        self.shape_resize_to = np.asarray([ch[0], *encoder.shape_after_conv]).astype(int)

        self.layer_latent = nn.Sequential(
            nn.Linear(encoder.latent_size, 100),
            nn.ReLU(),
            nn.Linear(100, encoder.size_after_conv),
            nn.ReLU()
        )
        self.convs = nn.Sequential(*[
            self._make_layer(int(ich), int(och))
            for ich, och in zip(ch[:-1], ch[1:])
        ])
        self.layer_recons = nn.Sequential(
            nn.Conv2d(int(ch[-1]), int(encoder.image_shape[0]), kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def _make_layer(self, in_ch, out_ch=None):
        """One ``Upsample(x2) -> Conv -> BatchNorm -> ReLU`` block (halves channels by default)."""
        if out_ch is None:
            out_ch = in_ch // 2
        return nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    def forward(self, x):
        """Decode a batch of latents ``(N, latent_size)`` into images with values in (0, 1)."""
        feat = self.layer_latent(x)
        # Plain Python ints for view(); shape_resize_to holds numpy integers.
        feat = feat.view(x.shape[0], *(int(s) for s in self.shape_resize_to))

        xcv = self.convs(feat)
        recons = self.layer_recons(xcv)

        return recons

In [72]:
# from stroker.model.encoder import Encoder
import stroker.model.encoder as libenc

In [78]:
# Packaged encoder for single-channel 64x64 inputs (matching the 64x64 resize in
# the config and the [N, 1, 64, 64] batches the loader yields). The positional
# args presumably follow the (image_shape, num_layer, base_channel) order that
# VAE uses for libenc.Encoder — confirm against stroker.model.encoder.
encoder = libenc.Encoder([1, 64, 64], 2, 8, latent_size=10)

In [79]:
# Smoke test: push every batch through the encoder. The outputs are only
# inspected, never backpropagated, so skip autograd graph construction.
with torch.no_grad():
    for images, labels in loader:
        latent, mean, log_dev = encoder(images)

In [80]:
# from stroker.model.decoder import Decoder
import stroker.model.decoder as libdec

In [81]:
# Packaged decoder built from the encoder above (presumably mirroring its
# architecture, as DEVDecoder does — confirm against stroker.model.decoder).
decoder = libdec.Decoder(encoder)

In [82]:
# Round-trip every batch through encoder -> decoder and check that each
# reconstruction has exactly the input's shape. Inference only: no autograd.
with torch.no_grad():
    for images, labels in loader:
        latent, mean, log_dev = encoder(images)
        recons = decoder(latent)
        print(images.shape, recons.shape)
        assert recons.shape == images.shape, "decoder output shape must match input"

torch.Size([16, 1, 64, 64]) torch.Size([16, 1, 64, 64])
torch.Size([16, 1, 64, 64]) torch.Size([16, 1, 64, 64])
torch.Size([16, 1, 64, 64]) torch.Size([16, 1, 64, 64])
torch.Size([4, 1, 64, 64]) torch.Size([4, 1, 64, 64])


In [83]:
class VAEData:
    """Output bundle of one VAE forward pass, plus its loss terms once computed.

    ``loss(target)`` computes and caches the summed reconstruction (BCE) and KL
    terms; the ``rc_loss`` / ``kl_loss`` properties then expose them as Python
    numbers (they fail if ``loss`` has not been called yet).
    """

    def __init__(self, recons, latent, mean, log_dev):
        # Tensors produced by the encoder / decoder.
        self.recons, self.latent = recons, latent
        self.mean, self.log_dev = mean, log_dev
        # Cached by loss(); None until it has been called.
        self._loss = self._rc_loss = self._kl_loss = None

    def loss(self, target):
        """Return ``BCE(recons, target) + KL`` with both terms summed over all elements.

        The KL term is the closed form against N(0, I), treating ``log_dev`` as
        the log-variance. Both terms are cached for ``rc_loss`` / ``kl_loss``.
        """
        reconstruction = nn.functional.binary_cross_entropy(
            self.recons, target, reduction="sum"
        )
        kl_divergence = -0.5 * torch.sum(
            1 + self.log_dev - self.mean ** 2 - torch.exp(self.log_dev)
        )
        total = reconstruction + kl_divergence
        self._rc_loss = reconstruction
        self._kl_loss = kl_divergence
        self._loss = total
        return total

    @property
    def kl_loss(self):
        """KL term from the most recent ``loss`` call, as a Python number."""
        return self._kl_loss.item()

    @property
    def rc_loss(self):
        """Reconstruction term from the most recent ``loss`` call, as a Python number."""
        return self._rc_loss.item()

In [84]:
class VAE(nn.Module):

    def __init__(
            self,
            image_shape: Tuple[int] | List[int] | np.ndarray,
            num_layer: int,
            base_channel: int,
            latent_size: int = 50,
    ):
        super().__init__()
        self.encoder = libenc.Encoder(image_shape, num_layer, base_channel, latent_size)
        self.decoder = libdec.Decoder(self.encoder)
    
    def forward(self, x):
        latent, mean, log_dev = self.encoder(x)
        recons = self.decoder(latent)
        data = VAEData(recons, latent, mean, log_dev)
        return data
    
    def encode(self, images):
        if not torch.is_tensor(images):
            images = torch.from_numpy(images).float()
        if images.ndim == 2:
            images = torch.unsqueeze(images, 0)
        if images.ndim == 3:
            images = torch.unsqueeze(images, 0)
        latent, _, _ = self.encoder(images)
        return latent.detach().numpy()
    
    def decode(self, latents):
        if not torch.is_tensor(latents):
            latents = torch.from_numpy(latents)
        if latents.ndim == 1:
            latents = latents.unsqueeze(latents, 0)
        recons = self.decoder(latents)
        return recons.detach().numpy()

In [85]:
# Full VAE for single-channel 64x64 inputs: num_layer=2, base_channel=4,
# latent_size=10. The displayed value is the number of parameter *tensors*
# (weights/biases), not the number of scalar parameters.
vae = VAE([1, 64, 64], 2, 4, 10)
sum(1 for _ in vae.parameters())

28

In [86]:
# Take the first batch only and round-trip a single image through the
# numpy-facing VAE.encode / VAE.decode helpers, reporting shapes and dtypes.
for images, labels in loader:
    image = images[0]
    latent = vae.encode(image)
    print("latent:", latent.shape, latent.dtype)
    recons = vae.decode(latent)
    print("recons:", recons.shape, recons.dtype)
    break  # one batch is enough for this check

latent: (1, 10) float32
recons: (1, 1, 64, 64) float32
