# ReNet
Implementation repository: https://github.com/NisTa24/ReNet-Implementation.git

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchdata.datapipes.map import SequenceWrapper
from torchvision.datasets import MNIST
from torchvision.utils import save_image
from torchvision.transforms import Compose, ToTensor, Normalize, RandomCrop

import numpy as np
import matplotlib.pyplot as plt

# Data

In [38]:
# Training split: ToTensor followed by a random 14x14 crop; files are downloaded
# into ./data if they are not already there.
# NOTE(review): ReNet below documents 28x28 inputs and the test split is not
# cropped, so this 14x14 crop looks inconsistent — confirm the intended size.
train = MNIST(
    root='./data',
    train=True,
    download=True,
    transform=Compose([ToTensor(), RandomCrop(14)]),
)
# Test split: tensor conversion only. No download is requested here, so this
# presumably relies on the files already being present under ./data.
test = MNIST(
    root='./data',
    train=False,
    transform=Compose([ToTensor()]),
)

In [98]:
from copy import deepcopy
import random

def do_kfold(dataset, k=2, batch_size=128, seed=None):
    '''
    Randomly hold out ``k`` mini-batches of ``dataset`` for validation.

    Despite the name this is a single random hold-out split, not k-fold
    cross-validation: ``k`` counts validation *batches*, not folds.

    Input
        dataset    : Dataset exposing indexable ``.data`` and ``.targets``
                     of equal length (e.g. torchvision ``MNIST``)
        k          : The number of mini-batches in the validation set
        batch_size : The size of a mini-batch
        seed       : Optional int seeding a private RNG so the split is
                     reproducible; ``None`` uses the global ``random`` state

        => valid_size = k * batch_size
        => train_size = len(dataset.data) - k * batch_size
           (the train loader drops a trailing incomplete batch)
    Return
        train_loader, # DataLoader over the training pairs, reshuffled each epoch
        valid_loader, # DataLoader over the validation pairs, fixed order
    Raises
        ValueError : if k * batch_size leaves either split empty
    '''
    # NOTE(review): ``.data`` / ``.targets`` are the raw stored samples, so the
    # dataset's ``transform`` (ToTensor, RandomCrop, ...) is NOT applied here.
    # Confirm the model is meant to receive untransformed samples.
    pairs = list(zip(dataset.data, dataset.targets))
    # Shuffling (sample, target) pairs keeps every sample aligned with its label.
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(pairs)

    n_valid = batch_size * k
    if not 0 < n_valid < len(pairs):
        raise ValueError(
            f'k * batch_size = {n_valid} must be in [1, {len(pairs) - 1}] '
            'so that both the train and validation splits are non-empty'
        )
    valid_set, train_set = pairs[:n_valid], pairs[n_valid:]

    # A list of (sample, target) tuples is already a valid map-style dataset,
    # so no torchdata DataPipe wrapper is needed (DataPipes are deprecated).
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)
    # The validation split was already drawn at random and its size is an exact
    # multiple of batch_size: no shuffling needed, and drop_last discards nothing.
    valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, drop_last=True)

    return train_loader, valid_loader

# Model

In [None]:
class ReNet(nn.Module):
    def __init__(self):
        '''
        X       : [batch_size, 1, 28(h), 28(w)]
        Layer 1 :
            Input_shape  : [b, 1, 28, 28]
            Patch_size   : [1, 2, 2]
            Vertically   : hidden_dim = 128, bidirectional = True(LSTM)
            Patch_size   : [256, 1, 1]
            Horizontally : hidden_dim = 128, bidirectional = True(LSTM)
            Output_shape : [b, 256, 14, 14] # The hidden_dim * 2 is the number of output-channels
        
        Layer 2 :
            Input_shape  : [b, 256, 14, 14]
            Patch_size   : [256, 2, 2]
            Vertically   : hidden_dim = 128, bidirectional = True(LSTM)
            Patch_size   : [256, 1, 1]
            Horizontally : hidden_dim = 128, bidirectional = True(LSTM)
            Output_shape : [b, 256, 7, 7]
        
        Layer 3 :
            [FC_Layer]
            Input_shape  : [b, 256 * 7 * 7]
            Output_shape : [b, 256 * 7 * 7]
        
        Layer 4 :
            [FC_Layer]
            Input_shape  : [b, 256 * 7 * 7]
            Output_shape : [b, 10]
        '''
        super().__init__()
        self.hidden_dim = 128
        
        self.layer1_v = nn.LSTM(1*2*2, self.hidden_dim, 1, True, True, 0, True)
        self.layer1_h = nn.LSTM(256*1*1, self.hidden_dim, 1, True, True, 0, True)
        
        self.layer2_v = nn.LSTM(256*2*2, self.hidden_dim, 1, True, True, 0, True)
        self.layer2_h = nn.LSTM(256*1*1, self.hidden_dim, 1, True, True, 0, True)
        
        self.fc_1 = nn.Linear(256*7*7, 256*7*7)
        self.fc_2 = nn.Linear(256*7*7, 10)
    
    def _flatten_input(self, x, patch_size, direction):
        '''
        direction : 
            v : vertically
            h : horizontally
        Output :
            shape : [[batch_size, seq_len, input_size], ...]
        '''
        x_b, _, x_h, x_w = x.shape
        _, p_h, p_w = patch_size
        assert (x_h % p_h) != 0 and (x_w % p_w) != 0
        
        match direction:
            case 'v':
                x = torch.split(x, x_h // p_h, dim=2)
                x = [tmp.reshape(x_b, )]
            case 'h':
                x = torch.split(x, x_w // p_w, dim=3)
        
        
        return inputs
    
    def forward(self, x):
        '''
        Input :
            x : [batch_size(b), channel(c), height, width]
        Flow. :
            1 : Divide x into pathes, [b, c, Hp, Wp], Pij
            2 : Flatten pathes, Pij : [b, Hp * c * Wp]
            3 : Sweep pathes Vertically and Bidirectionally
              : Patch size = [2, 2] fixed
            5 : Sweep pathes Horizontally and Bidirectionally
              : Patch size = [1, 1] fixed
            6 : Repeat 3-4
            7 : Flatten and FC Layer
        
        Additional
            - By 3, implement pooling
            - Each patch isn't overlapped at this model
            - The number of FC-layer is 2
        '''
        inputs = self._flatten_input(x, (2, 2))
        
        
        

In [178]:
# Scratch check: a 2x5 int64 tensor holding the digits 1..9 then 0.
x = torch.tensor([[1, 2, 3, 4, 5],
                  [6, 7, 8, 9, 0]])
x

tensor([[1, 2, 3, 4, 5],
        [6, 7, 8, 9, 0]])