# Simple Autoencoder
- Autoencoder mnist
    - Encoder: tự làm nhưng dùng kỹ thuật của mobilenet
        - depthwise-separable convolution
        - skip connection with inverted bottleneck
    - Global pooling + flatten + linear -> vector z
    - View (-1, dims, 1, 1)
- Decoder: simple upsampling+conv & no skip connection
- Học được: trick tính mse theo batch thay vì theo pixel


In [1]:
%cd "../"

/workspace/DeepVisualization


In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets

In [3]:
def prepare_data(batch_size, num_workers):
    """Build the MNIST training and test data loaders.

    Args:
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of worker processes each loader uses.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The training loader is
        shuffled; the test loader iterates the held-out split in order.
    """
    # Imported locally because the file-level imports only expose ``datasets``.
    from torchvision import transforms

    # Without a transform MNIST yields PIL images, which the DataLoader's
    # default collate function cannot stack into a batch.
    to_tensor = transforms.ToTensor()

    train_dataset = datasets.MNIST(root='dataset', train=True, download=True,
                                   transform=to_tensor)
    # train=False selects the held-out test split; the previous train=True
    # silently evaluated on the training data. The files are already on disk
    # after the training dataset above has been downloaded.
    test_dataset = datasets.MNIST(root='dataset', train=False,
                                  transform=to_tensor)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              num_workers=num_workers, shuffle=True)

    # Evaluation does not benefit from shuffling; a fixed order keeps
    # reconstructions comparable between runs.
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                             num_workers=num_workers, shuffle=False)

    return train_loader, test_loader

In [28]:
class DepthWise(nn.Module):
    """Depthwise convolution followed by batch normalisation and ReLU6.

    The convolution uses one group per input channel, so every channel is
    filtered independently and the channel count is unchanged.

    Args:
        in_channels: Number of input (and output) channels.
        kernel_size: Side length of the square convolution kernel.
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, kernel_size, stride):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            # Pads by half the kernel (rounded down), e.g. 1 for a 3x3 kernel.
            padding=(kernel_size - 1) // 2,
            groups=in_channels,
            bias=False,  # the batch norm that follows has its own shift term
        )
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU6(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU6 to ``x`` and return the result."""
        out = self.conv(x)
        out = self.bn(out)
        return self.relu(out)

class PointWise(nn.Module):
    """1x1 convolution followed by batch normalisation and optional ReLU6.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        use_relu: When true, a ReLU6 is applied after the batch norm;
            otherwise the block ends with the batch norm output.
    """

    def __init__(self, in_channels, out_channels, use_relu=False):
        super().__init__()
        self.use_relu = use_relu
        self.conv_1x1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        # The activation module only exists when it is actually requested.
        if use_relu:
            self.relu = nn.ReLU6(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm (-> ReLU6 if enabled) to ``x``."""
        out = self.bn(self.conv_1x1(x))
        if not self.use_relu:
            return out
        return self.relu(out)

In [38]:
class InvertedResidual(nn.Module):
    """Inverted bottleneck unit: 1x1 expand -> 3x3 depthwise -> 1x1 project.

    A skip connection is added around the three convolutions:

    * if ``downsample`` is given, the input goes through it before the add;
    * otherwise the input itself is added, but only when the unit keeps the
      tensor shape (``stride == 1`` and ``in_channels == out_channels``);
    * in any other case the unit has no skip connection.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        expand_ratio: Multiplier applied to ``in_channels`` to obtain the
            width of the expanded hidden representation.
        stride: Stride of the depthwise convolution.
        downsample: Optional module applied to the input on the shortcut
            path so that it matches the main path's output, or ``None``.
    """

    def __init__(self, in_channels, out_channels, expand_ratio, stride, downsample):
        super().__init__()
        hidden_dim = int(round(in_channels * expand_ratio))
        self.downsample = downsample
        # An identity shortcut is only well defined when input and output
        # have the same shape. Previously the saved input was discarded
        # whenever ``downsample`` was None, so such units had no skip at all.
        self.use_identity = (
            downsample is None and stride == 1 and in_channels == out_channels
        )
        self.conv1 = PointWise(
            in_channels=in_channels,
            out_channels=hidden_dim,
            use_relu=True,
        )

        self.conv2 = DepthWise(
            in_channels=hidden_dim,
            kernel_size=3,
            stride=stride,
        )

        # Linear projection: no activation after the final 1x1 convolution.
        self.conv3 = PointWise(
            in_channels=hidden_dim,
            out_channels=out_channels,
            use_relu=False,
        )

    def forward(self, x):
        """Run the bottleneck on ``x`` and add the shortcut when one applies."""
        out = self.conv3(self.conv2(self.conv1(x)))
        if self.downsample is not None:
            return self.downsample(x) + out
        if self.use_identity:
            return x + out
        return out
    
class ResBlock(nn.Module):
    """Two stacked ``InvertedResidual`` units.

    The first unit uses stride 2 and changes the channel count from
    ``in_channels`` to ``out_channels``; it is given a strided 1x1 convolution
    as its ``downsample`` module. The second unit uses stride 1, keeps
    ``out_channels`` and is built with ``downsample=None``.

    Args:
        in_channels: Number of channels entering the block.
        out_channels: Number of channels leaving the block.
        expand_ratio: Expansion ratio forwarded to both units.
    """

    def __init__(self, in_channels, out_channels, expand_ratio):
        super().__init__()
        # Shortcut projection for the strided unit: 1x1 conv with stride 2.
        shortcut = nn.Conv2d(in_channels, out_channels,
                             kernel_size=1, stride=2, bias=False)

        self.unit1 = InvertedResidual(
            in_channels=in_channels,
            out_channels=out_channels,
            expand_ratio=expand_ratio,
            stride=2,
            downsample=shortcut,
        )
        self.unit2 = InvertedResidual(
            in_channels=out_channels,
            out_channels=out_channels,
            expand_ratio=expand_ratio,
            stride=1,
            downsample=None,
        )

    def forward(self, x):
        """Pass ``x`` through the strided unit, then the stride-1 unit."""
        return self.unit2(self.unit1(x))
        

class Encoder(nn.Module):
    """Convolutional encoder made of two ``ResBlock`` stages.

    The first stage maps 32 channels to 64 and the second maps 64 to 128;
    both use an expansion ratio of 2.
    """

    def __init__(self):
        super().__init__()
        self.block1 = ResBlock(32, 64, expand_ratio=2)
        self.block2 = ResBlock(64, 128, expand_ratio=2)

    def _make_blocks(self):
        """Placeholder; currently does nothing and returns None."""
        return None

    def forward(self, x):
        """Feed ``x`` through both stages in order and return the result."""
        return self.block2(self.block1(x))
    
class Decoder(nn.Module):
    """Decoder placeholder: it holds no layers and ``forward`` returns None."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Ignore ``x`` and return None (not implemented yet)."""
        return None
    
class Model(nn.Module):
    """Top-level model; at present it only wraps the ``Encoder``."""

    def __init__(self):
        super().__init__()
        # Fixed spelling: ``torch.nn`` has no ``Conv2D`` attribute, so the
        # previous ``nn.Conv2D`` raised AttributeError on construction.
        # NOTE(review): this stores the layer *class*, not an instance, and
        # forward() never uses it -- presumably a placeholder for a stem
        # convolution; confirm the intended channels before instantiating.
        self.conv = nn.Conv2d
        self.encoder = Encoder()

    def forward(self, x):
        """Return the encoder output for ``x``.

        ``x`` is passed straight to the encoder, whose first block is built
        with ``in_channels=32``.
        """
        return self.encoder(x)

In [39]:
# Smoke test: push one random batch through the model and inspect the
# shape of what comes out.
model = Model()
# 32 channels, matching the in_channels of the encoder's first block.
x = torch.randn(12, 32, 28, 28)
y = model(x)
# Trailing expression: the notebook displays it as the cell result.
y.shape

torch.Size([12, 128, 7, 7])