In [1]:
from IPython.display import display, clear_output
from ipywidgets import Output, VBox

In [472]:
import matplotlib.pyplot as plt
import numpy as np
import time

from tqdm.auto import tqdm

import torchvision
import torch.nn as nn
import torch
import torch.nn.functional as F

import kornia

In [11]:
def printput(*args):
    """Print ``args`` into a fresh ``Output`` widget and return that widget.

    The text is captured by the widget instead of being written to the
    cell's own output area, so the result can be placed inside a container
    such as ``VBox``.
    """
    widget = Output()
    with widget:
        # Anything printed inside the context is routed to the widget.
        print(*args)
    return widget

def dispput(*args):
    """Rich-display ``args`` into a fresh ``Output`` widget and return it.

    Counterpart of ``printput`` that calls IPython's ``display`` rather
    than ``print``, so objects such as matplotlib figures are rendered with
    their rich representation.
    """
    widget = Output()
    with widget:
        # Display calls made inside the context are captured by the widget.
        display(*args)
    return widget

class BaseModule(nn.Module):
    """``nn.Module`` with a convenience ``device`` property."""

    @property
    def device(self):
        """Device on which this module's tensors live.

        Returns the device of the first parameter; if the module has no
        parameters, the device of the first buffer; if it has neither,
        ``torch.device('cpu')``. Only the first tensor is inspected, so a
        module spread across several devices reports just one of them.
        """
        first = next(self.parameters(), None)
        if first is None:
            # Parameter-free modules used to raise StopIteration here.
            first = next(self.buffers(), None)
        if first is None:
            return torch.device('cpu')
        return first.device

In [571]:
class IAE(BaseModule):
    """Generator/decoder pair trained to carry messages through noisy images.

    A random message vector is mapped by ``generator`` to an image of shape
    ``out_shape``; the image is distorted by :meth:`noise`; ``decoder`` then
    tries to reconstruct the original message from the distorted image.
    """

    def __init__(self, msg_size, out_shape):
        """
        Args:
            msg_size: Length of each message vector.
            out_shape: Shape of a single generated image, without the batch
                dimension (the notebook uses ``(1, 32, 32)``).
        """
        super().__init__()
        # np.prod returns a NumPy scalar; hand nn.Linear a plain Python int.
        out_size = int(np.prod(out_shape))
        self.msg_size = msg_size
        self.out_shape = out_shape

        # Message -> flat image, squashed into [0, 1] by the final sigmoid.
        self.generator = nn.Sequential(
            nn.Linear(msg_size, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, out_size),
            nn.Sigmoid(),
        )

        # (Distorted) image -> predicted message.
        self.decoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(out_size, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, msg_size),
        )

        # Created lazily by optim_step() and reused afterwards so that the
        # optimizer's running statistics persist between steps.
        self._optim = None

    def noise(self, img):
        """Distort ``img`` with a random affine transform plus Gaussian noise.

        The affine transform is always applied (``p=1``); the additive noise
        is standard normal scaled by 1/3.

        Args:
            img: Batch of images; assumed to be ``(batch, *out_shape)`` as
                produced by :meth:`generate_img` -- TODO confirm the layout
                expected by kornia.

        Returns:
            The distorted batch, same shape as ``img``.
        """
        augment = kornia.augmentation.RandomAffine(
            degrees=20,
            translate=[0.1, 0.1],
            scale=[0.9, 1.1],
            shear=[-5, 5],
            p=1,
        )
        # Drawn before the affine transform, as in the original ordering.
        normal_noise = torch.randn_like(img) / 3

        img = augment(img)
        img = img + normal_noise
        return img

    def sample_msg(self, bs):
        """Draw ``bs`` messages uniformly from ``[0, 1)`` on the model's device.

        Returns:
            Tensor of shape ``(bs, msg_size)``.
        """
        # Allocate directly on the target device instead of CPU + copy.
        return torch.rand(bs, self.msg_size, device=self.device)

    def generate_img(self, msg):
        """Run the generator on ``msg`` and reshape to ``(-1, *out_shape)``."""
        img = self.generator(msg)
        return img.reshape(-1, *self.out_shape)

    def forward(self, bs):
        """Sample messages, render them, distort the images and decode.

        Args:
            bs: Batch size (number of messages to sample).

        Returns:
            Tuple ``(msg, img, pred_msg)``: the sampled messages, the clean
            generated images, and the messages decoded from the noisy images.
        """
        msg = self.sample_msg(bs)
        img = self.generate_img(msg)
        noise_img = self.noise(img)
        pred_msg = self.decoder(noise_img)
        return msg, img, pred_msg

    def optim_step(self, bs, lr):
        """Run one training step and return ``{'loss': float}``.

        The loss is the MSE between sampled and decoded messages. The Adam
        optimizer is created on the first call and reused on later calls;
        re-creating it every step would throw away its accumulated state.
        ``lr`` is re-applied on every call, so it can be changed between
        steps.

        Note: move the model to its final device *before* the first call;
        optimizer state created on one device is not migrated by ``.to()``.

        Args:
            bs: Batch size for this step.
            lr: Learning rate to use for this step.
        """
        msg, _, pred_msg = self(bs)
        loss = F.mse_loss(pred_msg, msg)

        # getattr keeps this working for instances built without __init__.
        optim = getattr(self, '_optim', None)
        if optim is None:
            optim = torch.optim.Adam(self.parameters(), lr=lr)
            self._optim = optim
        else:
            for group in optim.param_groups:
                group['lr'] = lr

        optim.zero_grad()
        loss.backward()
        optim.step()

        return {'loss': loss.item()}

In [572]:
# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [581]:
# 64-dimensional messages rendered as single-channel 32x32 images,
# with the model placed on DEVICE right after construction.
model = IAE(msg_size=64, out_shape=(1, 32, 32)).to(DEVICE)

In [582]:
# Fixed batch of 10 messages, reused by the training loop so successive
# visualizations show the same messages.
msg = model.sample_msg(10)

In [586]:
# Training loop: one optimization step per iteration; every 10th iteration
# the images generated from the fixed `msg` batch are rendered next to the
# current loss.
epochs = 10000
for i in range(1, epochs + 1):
    info = model.optim_step(bs=32, lr=0.005)
    loss = info['loss']

    if i % 10 == 0:
        # Visualization only: no autograd graph is needed, and the result
        # can be converted to NumPy without detaching.
        with torch.no_grad():
            img = model.generate_img(msg)
        grid = torchvision.utils.make_grid(img, nrow=5, padding=2)
        grid = grid.permute(1, 2, 0).cpu().numpy()

        # Create the figure only when it is actually shown, and close this
        # specific figure so it is neither auto-displayed nor kept open.
        f, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(grid)
        plt.close(f)

        clear_output(wait=True)
        display(VBox([
            printput(f'#{i} | Loss: {loss:0.5f}'),
            dispput(f)
        ]), display_id='stats')

VBox(children=(Output(), Output()))