In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.transforms import *
import numpy as np
import matplotlib.pyplot as pp
import os
from STN import Stn, Stn11, Stn2_1, Stn2_2, Stn2_3
import models

In [10]:
def conv3x3(in_channels, out_channels):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3,
                     stride=1, padding=1, bias=True)

def conv2x2(in_channels, out_channels):
    return nn.Conv2d(in_channels, out_channels, kernel_size=2,
                     stride=2, padding=0, bias=True)

def conv1x1(in_channels, out_channels):
    return nn.Conv2d(in_channels, out_channels, kernel_size=1,
                     stride=1, padding=0, bias=True)

In [11]:
class Encoder1(nn.Module):
    def __init__(self, in_channels):
        super(Encoder1, self).__init__()

        self.in_channels = in_channels

        # 1 - Ori Size
        self.conv3_1_1 = conv3x3(in_channels=self.in_channels, out_channels=16)
        self.relu1_1 = nn.LeakyReLU(0.1)
        self.conv3_1_2 = conv3x3(in_channels=16, out_channels=16)
        self.relu1_2 = nn.LeakyReLU(0.1)
        self.conv3_1_3 = conv3x3(in_channels=16, out_channels=16)
        self.relu1_3 = nn.LeakyReLU(0.1)

        # 2 - 1/2 Size
        self.conv2_2_1 = conv2x2(in_channels=16, out_channels=32)
        self.relu2_1 = nn.LeakyReLU(0.1)
        self.conv3_2_1 = conv3x3(in_channels=32, out_channels=32)
        self.relu2_2 = nn.LeakyReLU(0.1)
        self.conv3_2_2 = conv3x3(in_channels=32, out_channels=32)
        self.relu2_3 = nn.LeakyReLU(0.1)

        # 3 - 1/4 Size
        self.conv2_3_1 = conv2x2(in_channels=32, out_channels=64)
        self.relu3_1 = nn.LeakyReLU(0.1)
        self.conv3_3_1 = conv3x3(in_channels=64, out_channels=64)
        self.relu3_2 = nn.LeakyReLU(0.1)
        self.conv3_3_2 = conv3x3(in_channels=64, out_channels=64)
        self.relu3_3 = nn.LeakyReLU(0.1)

        # 4 - 1/8 Size
        self.conv2_4_1 = conv2x2(in_channels=64, out_channels=128)
        self.relu4_1 = nn.LeakyReLU(0.1)
        self.conv3_4_1 = conv3x3(in_channels=128, out_channels=128)
        self.relu4_2 = nn.LeakyReLU(0.1)
        self.conv3_4_2 = conv3x3(in_channels=128, out_channels=128)
        self.relu4_3 = nn.LeakyReLU(0.1)

    def forward(self, x):
        # 1 - Ori Size
        # 1 채널로 바꾸기
        # x = x[:, 1, :, :]
        # x = x.unsqueeze(1)
        y1_1 = self.conv3_1_1(x)
        y1_1 = self.relu1_1(y1_1)
        y1_2 = self.conv3_1_2(y1_1)
        y1_2 = self.relu1_2(y1_2)
        y1_3 = self.conv3_1_3(y1_2)
        y1_3 = self.relu1_3(y1_3)  # 32 channel

        # 2 - 1/2 Size
        y2_1 = self.conv2_2_1(y1_3)  # size는 감소, 채널 증가(16->32)
        y2_1 = self.relu2_1(y2_1)
        y2_2 = self.conv3_2_1(y2_1)
        y2_2 = self.relu2_2(y2_2)
        y2_3 = self.conv3_2_2(y2_2)
        out1 = self.relu2_3(y2_3)

        # 3 - 1/4 Size
        y3_1 = self.conv2_3_1(out1)  # size는 감소, 채널 증가(32->64)
        y3_1 = self.relu3_1(y3_1)
        y3_2 = self.conv3_3_1(y3_1)
        y3_2 = self.relu3_2(y3_2)
        y3_3 = self.conv3_3_2(y3_2)
        y3_3 = self.relu3_3(y3_3)

        # 4 - 1/8 Size
        y4_1 = self.conv2_4_1(y3_3)  # size는 감소, 채널 증가(64->128)
        y4_1 = self.relu4_1(y4_1)
        y4_2 = self.conv3_4_1(y4_1)
        y4_2 = self.relu4_2(y4_2)
        y4_3 = self.conv3_4_2(y4_2)
        out2 = self.relu4_3(y4_3)
        return out1, out2  # 1/2 size with 32 channels, 1/8 size with 128 channels

In [12]:
class Encoder2(nn.Module):
    def __init__(self, in_channels):
        super(Encoder2, self).__init__()

        self.in_channels = in_channels

        # 1 - Ori Size
        self.conv3_1_1 = conv3x3(in_channels=self.in_channels, out_channels=16)
        self.relu1_1 = nn.LeakyReLU(0.1)
        self.conv3_1_2 = conv3x3(in_channels=16, out_channels=16)
        self.relu1_2 = nn.LeakyReLU(0.1)
        self.conv3_1_3 = conv3x3(in_channels=16, out_channels=16)
        self.relu1_3 = nn.LeakyReLU(0.1)

        # 2 - 1/2 Size
        self.conv2_2_1 = conv2x2(in_channels=16, out_channels=32)
        self.relu2_1 = nn.LeakyReLU(0.1)
        self.conv3_2_1 = conv3x3(in_channels=32, out_channels=32)
        self.relu2_2 = nn.LeakyReLU(0.1)
        self.conv3_2_2 = conv3x3(in_channels=32, out_channels=32)
        self.relu2_3 = nn.LeakyReLU(0.1)

        # 3 - 1/4 Size
        self.conv2_3_1 = conv2x2(in_channels=32, out_channels=64)
        self.relu3_1 = nn.LeakyReLU(0.1)
        self.conv3_3_1 = conv3x3(in_channels=64, out_channels=64)
        self.relu3_2 = nn.LeakyReLU(0.1)
        self.conv3_3_2 = conv3x3(in_channels=64, out_channels=64)
        self.relu3_3 = nn.LeakyReLU(0.1)

        # 4 - 1/8 Size
        self.conv2_4_1 = conv2x2(in_channels=64, out_channels=128)
        self.relu4_1 = nn.LeakyReLU(0.1)
        self.conv3_4_1 = conv3x3(in_channels=128, out_channels=128)
        self.relu4_2 = nn.LeakyReLU(0.1)
        self.conv3_4_2 = conv3x3(in_channels=128, out_channels=128)
        self.relu4_3 = nn.LeakyReLU(0.1)

    def forward(self, x):
        # 1 - Ori Size
        y1_1 = self.conv3_1_1(x)
        y1_1 = self.relu1_1(y1_1)
        y1_2 = self.conv3_1_2(y1_1)
        y1_2 = self.relu1_2(y1_2)
        y1_3 = self.conv3_1_3(y1_2)
        y1_3 = self.relu1_3(y1_3)
        out1 = y1_3  # 16 channels

        # 2 - 1/2 Size
        y2_1 = self.conv2_2_1(y1_3)
        y2_1 = self.relu2_1(y2_1)
        y2_2 = self.conv3_2_1(y2_1)
        y2_2 = self.relu2_2(y2_2)
        y2_3 = self.conv3_2_2(y2_2)
        y2_3 = self.relu2_3(y2_3)
        out2 = y2_3  # 32 channels

        # 3 - 1/4 Size
        y3_1 = self.conv2_3_1(y2_3)
        y3_1 = self.relu3_1(y3_1)
        y3_2 = self.conv3_3_1(y3_1)
        y3_2 = self.relu3_2(y3_2)
        y3_3 = self.conv3_3_2(y3_2)
        y3_3 = self.relu3_3(y3_3)
        out3 = y3_3  # 64 channels

        # 4 - 1/8 Size
        y4_1 = self.conv2_4_1(y3_3)
        y4_1 = self.relu4_1(y4_1)
        y4_2 = self.conv3_4_1(y4_1)
        y4_2 = self.relu4_2(y4_2)
        y4_3 = self.conv3_4_2(y4_2)
        y4_3 = self.relu4_3(y4_3)
        out4 = y4_3  # 128 channels
        return out1, out2, out3, out4  # stage1 ~4

In [13]:
class Decoder1(nn.Module):
    def __init__(self, out_channels):
        super(Decoder1, self).__init__()

        self.out_channels = out_channels
        self.PS = nn.PixelShuffle(2)  # (B, C x 2^2, H, W) -> (B, C, H x 2, W x 2)

        # 4 - 1/8 Size
        self.conv3_4_3 = conv3x3(in_channels=512, out_channels=512)
        self.relu4_6 = nn.LeakyReLU(0.1)
        self.conv3_4_4 = conv3x3(in_channels=512, out_channels=1024)
        self.relu4_7 = nn.LeakyReLU(0.1)

        # 5 - 1/4 Size
        self.conv3_5_1 = conv3x3(in_channels=256, out_channels=256)
        self.relu5_1 = nn.LeakyReLU(0.1)
        self.conv3_5_2 = conv3x3(in_channels=256, out_channels=512)
        self.relu5_2 = nn.LeakyReLU(0.1)

        # 6 - 1/2 Size
        self.conv3_6_1 = conv3x3(in_channels=256, out_channels=256)  # 128 + 128
        self.relu6_1 = nn.LeakyReLU(0.1)
        self.conv3_6_2 = conv3x3(in_channels=256, out_channels=256)
        self.relu6_2 = nn.LeakyReLU(0.1)

        # 7 - Ori Size
        self.conv3_7_1 = conv3x3(in_channels=64, out_channels=64)
        self.relu7_1 = nn.LeakyReLU(0.1)
        self.conv3_7_2 = conv3x3(in_channels=64, out_channels=64)
        self.relu7_2 = nn.LeakyReLU(0.1)
        self.conv1_7_1 = conv1x1(in_channels=64, out_channels=self.out_channels)

    def forward(self, x1, x2):  # 128, 512 channels
        # 4 - 1/8 Size
        y4_6 = self.conv3_4_3(x2)
        y4_6 = self.relu4_6(y4_6)
        y4_7 = self.conv3_4_4(y4_6)
        y4_7 = self.relu4_7(y4_7)  # 512 channels

        # 5 - 1/4 Size
        y5_1 = self.PS(y4_7)  # 128 channels
        y5_2 = self.conv3_5_1(y5_1)
        y5_2 = self.relu5_1(y5_2)
        y5_3 = self.conv3_5_2(y5_2)
        y5_3 = self.relu5_2(y5_3)

        # 6 - 1/2 Size
        y6_1 = self.PS(y5_3)
        y6_2 = torch.cat((x1, y6_1), 1)
        y6_3 = self.conv3_6_1(y6_2)
        y6_3 = self.relu6_1(y6_3)
        y6_4 = self.conv3_6_2(y6_3)
        y6_4 = self.relu6_2(y6_4)

        # 7 - Ori Size
        y7_1 = self.PS(y6_4)
        y7_2 = self.conv3_7_1(y7_1)
        y7_2 = self.relu7_1(y7_2)
        y7_3 = self.conv3_7_2(y7_2)
        y7_3 = self.relu7_2(y7_3)
        out = self.conv1_7_1(y7_3)
        return out

In [14]:
class Decoder2(nn.Module):
    def __init__(self, out_channels):
        super(Decoder2, self).__init__()

        self.out_channels = out_channels
        self.PS = nn.PixelShuffle(2)

        # 4 - 1/8 Size
        self.conv3_4_3 = conv3x3(in_channels=128, out_channels=128)
        self.relu4_6 = nn.LeakyReLU(0.1)
        self.conv3_4_4 = conv3x3(in_channels=128, out_channels=256)
        self.relu4_7 = nn.LeakyReLU(0.1)

        # 5 - 1/4 Size
        self.conv3_5_1 = conv3x3(in_channels=128, out_channels=128)  # 64+64 (Concat)
        self.relu5_1 = nn.LeakyReLU(0.1)
        self.conv3_5_2 = conv3x3(in_channels=128, out_channels=128)
        self.relu5_2 = nn.LeakyReLU(0.1)

        # 6 - 1/2 Size
        self.conv3_6_1 = conv3x3(in_channels=64, out_channels=64)  # 32+32 (Concat)
        self.relu6_1 = nn.LeakyReLU(0.1)
        self.conv3_6_2 = conv3x3(in_channels=64, out_channels=64)
        self.relu6_2 = nn.LeakyReLU(0.1)

        # 7 - Ori Size
        self.conv3_7_1 = conv3x3(in_channels=32, out_channels=32)  # 16+16 (Concat)
        self.relu7_1 = nn.LeakyReLU(0.1)
        self.conv3_7_2 = conv3x3(in_channels=32, out_channels=32)
        self.relu7_2 = nn.LeakyReLU(0.1)
        self.conv1_7_1 = conv1x1(in_channels=32, out_channels=self.out_channels)

    def forward(self, x1, x2, x3, x4):  # 16, 32, 64, 128 channels
        # 4 - 1/8 Size with 128 channels
        y4_1 = self.conv3_4_3(x4)
        y4_1 = self.relu4_6(y4_1)
        y4_2 = self.conv3_4_4(y4_1)  # 256 channels
        y4_2 = self.relu4_7(y4_2)

        # 5 - 1/4 Size with 64 channels
        y5_1 = self.PS(y4_2)  # 256 -> 64 channels
        y5_2 = torch.cat((x3, y5_1), 1)  # Concat -> 128 channels
        y5_3 = self.conv3_5_1(y5_2)
        y5_3 = self.relu5_1(y5_3)
        y5_4 = self.conv3_5_2(y5_3)
        y5_4 = self.relu5_2(y5_4)

        # 6 - 1/2 Size with 32 channels
        y6_1 = self.PS(y5_4)  # 128 -> 32 channels
        y6_2 = torch.cat((x2, y6_1), 1)  # Concat -> 64 channels
        y6_3 = self.conv3_6_1(y6_2)
        y6_3 = self.relu6_1(y6_3)
        y6_4 = self.conv3_6_2(y6_3)
        y6_4 = self.relu6_2(y6_4)

        # 7 - Ori Size
        y7_1 = self.PS(y6_4)  # 64 -> 16 channels
        y7_2 = torch.cat((x1, y7_1), 1)  # Concat -> 32 channels
        y7_3 = self.conv3_7_1(y7_2)
        y7_3 = self.relu7_1(y7_3)
        y7_4 = self.conv3_7_2(y7_3)
        y7_4 = self.relu7_2(y7_4)
        out = self.conv1_7_1(y7_4)  # 32 -> 3 channels

        return out

In [15]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
#from base_networks import *
from torchvision.transforms import *
import torch.nn.functional as F
from ConvLSTM import ConvLSTMCell
import numpy as np
# from vit_pytorch import ViT

class Stn(nn.Module):
    def __init__(self):
        super(Stn, self).__init__()

        # self.v = ViT(
        #     image_size=128,
        #     patch_size=16,
        #     num_classes=1000,
        #     dim=1024,
        #     depth=6,
        #     heads=16,
        #     mlp_dim = 2048,
        #     dropout=0.1,
        #     emb_dropout=0.1)

        self.st = nn.Sequential(
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
            nn.Conv2d(1, 250, kernel_size=5, stride=1, padding=2),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
            nn.Conv2d(250, 250, kernel_size=5, stride=1, padding=2),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True)
        )

        self.FC_ = nn.Sequential(
            nn.Linear(64000, 250),
            nn.ReLU(True),
            nn.Linear(250, 6)
        )
        self.FC_[2].weight.data.zero_()
        self.FC_[2].bias.data.copy_(torch.tensor([0.9, 0, 0, 0, 0.9, 0], dtype=torch.float))

    def forward(self, x):

        h = self.st(x)
        h = h.view(-1, 64000)
        h = self.FC_(h)
        theta = h.view(-1, 2, 3)
        grid = F.affine_grid(theta,x.size(),align_corners=True)
        out = F.grid_sample(x, grid,align_corners=True)
        return out

class Stn11(nn.Module):
    def __init__(self):
        super(Stn11, self).__init__()

        # self.v = ViT(
        #     image_size=128,
        #     patch_size=16,
        #     num_classes=1000,
        #     dim=1024,
        #     depth=6,
        #     heads=16,
        #     mlp_dim = 2048,
        #     dropout=0.1,
        #     emb_dropout=0.1)

        self.st = nn.Sequential(
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
            nn.Conv2d(3, 250, kernel_size=5, stride=1, padding=2),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
            nn.Conv2d(250, 250, kernel_size=5, stride=1, padding=2),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True)
        )

        self.FC_ = nn.Sequential(
            nn.Linear(64000, 250),
            nn.ReLU(True),
            nn.Linear(250, 6)
        )
        self.FC_[2].weight.data.zero_()
        self.FC_[2].bias.data.copy_(torch.tensor([0.9, 0, 0, 0, 0.9, 0], dtype=torch.float))

    def forward(self, x):

        h = self.st(x)
        h = h.view(-1, 64000)
        h = self.FC_(h)
        theta = h.view(-1, 2, 3)
        grid = F.affine_grid(theta,x.size(),align_corners=True)
        out = F.grid_sample(x, grid,align_corners=True)
        return out

class Stn2(nn.Module):
    def __init__(self):
        super(Stn2, self).__init__()

        self.st = nn.Sequential(
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
            nn.Conv2d(3, 250, kernel_size=5, stride=1, padding=2),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
            nn.Conv2d(250, 250, kernel_size=5, stride=1, padding=2),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True)
        )
        self.FC_ = nn.Sequential(
            nn.Linear(784000, 250),
            nn.ReLU(True),
            nn.Linear(250, 6)
        )
        self.FC_[2].weight.data.zero_()
        self.FC_[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def forward(self, x):
        h = self.st(x)
        h = h.view(-1, 784000)
        h = self.FC_(h)
        theta = h.view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size())
        out = F.grid_sample(x, grid)
        return out, theta


class Stn2_1(nn.Module):
    def __init__(self):
        super(Stn2_1, self).__init__()

        self.st = nn.Sequential(
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
            nn.Conv2d(1, 250, kernel_size=5, stride=1, padding=2),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
            nn.Conv2d(250, 250, kernel_size=5, stride=1, padding=2),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True)
        )
        self.FC_ = nn.Sequential(
            nn.Linear(16000, 250),
            nn.ReLU(True),
            nn.Linear(250, 6)
        )
        self.FC_[2].weight.data.zero_()
        self.FC_[2].bias.data.copy_(torch.tensor([0.9, 0, 0, 0, 0.9, 0], dtype=torch.float))

    def forward(self, x):
        h = self.st(x)
        h = h.view(-1, 16000)
        h = self.FC_(h)
        theta = h.view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size())
        out = F.grid_sample(x, grid)
        return out

class Stn2_2(nn.Module):
    def __init__(self):
        super(Stn2_2, self).__init__()

        self.st = nn.Sequential(
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
            nn.Conv2d(1, 250, kernel_size=5, stride=1, padding=2),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
            nn.Conv2d(250, 250, kernel_size=5, stride=1, padding=2),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True)
        )
        self.FC_ = nn.Sequential(
            nn.Linear(32000, 250),
            nn.ReLU(True),
            nn.Linear(250, 6)
        )
        self.FC_[2].weight.data.zero_()
        self.FC_[2].bias.data.copy_(torch.tensor([0.9, 0, 0, 0, 0.9, 0], dtype=torch.float))

    def forward(self, x):
        h = self.st(x)
        h = h.view(-1, 32000)
        h = self.FC_(h)
        theta = h.view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size())
        out = F.grid_sample(x, grid)
        return out

class Stn2_3(nn.Module):
    def __init__(self):
        super(Stn2_3, self).__init__()

        self.st = nn.Sequential(
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
            nn.Conv2d(1, 250, kernel_size=5, stride=1, padding=2),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
            nn.Conv2d(250, 250, kernel_size=5, stride=1, padding=2),
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2, ceil_mode=True)
        )
        self.FC_ = nn.Sequential(
            nn.Linear(1000, 250),
            nn.ReLU(True),
            nn.Linear(250, 6)
        )
        self.FC_[2].weight.data.zero_()
        self.FC_[2].bias.data.copy_(torch.tensor([0.9, 0, 0, 0, 0.9, 0], dtype=torch.float))

    def forward(self, x):
        h = self.st(x)
        h = h.view(-1, 1000)
        h = self.FC_(h)
        theta = h.view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size())
        out = F.grid_sample(x, grid)
        return out, theta



class FC(nn.Module):
    def __init__(self):
        super(FC, self).__init__()

        self.FC = nn.Sequential(
            nn.Linear(3, 6),
            nn.ReLU(True),
            nn.Linear(6, 3),
            nn.ReLU(True),
            nn.Linear(3, 1)
        )
        self.FC[0].weight.data.zero_()
        self.FC[0].bias.data.copy_(torch.tensor([1], dtype=torch.float))

    def forward(self, x):
        h = self.FC(x)
        return h

In [21]:
def partition4(x):  # X : (B, C, H, W) 가정
    H, W = x.shape[2], x.shape[3]
    pivot_H, pivot_W = H // 2, W // 2

    left_upper = x[:, :, 0:pivot_H, 0:pivot_W]  # 좌상단
    right_upper = x[:, :, 0:pivot_H, pivot_W:]  # 우상단
    left_lower = x[:, :, pivot_H:, 0:pivot_W]  # 좌하단
    right_lower = x[:, :, pivot_H:, pivot_W:]  # 우하단 

    return left_upper, right_upper, left_lower, right_lower
    
# 좌우 2분할 : (B, C, H, W) -> (B, C, H, 2/W) for each patch
def partition2_vertical(x):  
    H, W = x.shape[2], x.shape[3]
    pivot = W // 2

    left = x[:, :, :, 0:pivot]
    right = x[:, :, :, pivot:]

    return left, right
    
# 상하 2분할 : (B, C, H, W) -> (B, C, 2/H, W) for each patch
def partition2_horizontal(x):
    H, W = x.shape[2], x.shape[3]
    pivot = W // 2

    upper = x[:, :, 0:pivot, :]
    lower = x[:, :, pivot:, :]

    return upper, lower

## Forward

In [33]:
encoder_lv1 = models.Encoder()  # Original
encoder_lv2 = models.Encoder()  # 상하 2분할
encoder_lv3 = models.Encoder()  # 4분할

decoder_lv1 = models.Decoder()
decoder_lv2 = models.Decoder()
decoder_lv3 = models.Decoder()

In [34]:
stn4_1 = Stn2_1()  # theta : 16000 -> 250 -> 6
stn4_2 = Stn2_1()
stn4_3 = Stn2_1()
stn4_4 = Stn2_1()

stn2_1 = Stn2_2()  # theta : 32000 -> 250 -> 6
stn2_2 = Stn2_2()

stn1_1 = Stn()
stn_rgb = Stn11()

conv5_1 = conv3x3(in_channels=16, out_channels=16)
conv5_2 = conv3x3(in_channels=32, out_channels=32)
conv5_3 = conv3x3(in_channels=64, out_channels=64)
conv5_4 = conv3x3(in_channels=128, out_channels=128)

E5 = Encoder2(3)  # Input : RGB 3 channels
E6 = Encoder1(1)  # Input : Grayscale 1 channel
D2 = Decoder2(3)  # Output : RGB 3 channels

### Inputs

In [37]:
x4 = torch.randn(1, 1, 128, 128)  # grayscale

shape1 = x4  # 원본
upper, lower = partition2_horizontal(shape1)  # 상하 2 분할
lu, ru, ll, rl = partition4(shape1)  # 4 분할

print(shape1.shape)
print(upper.shape, lower.shape)
print(lu.shape, ru.shape, ll.shape, rl.shape)

torch.Size([1, 1, 128, 128])
torch.Size([1, 1, 64, 128]) torch.Size([1, 1, 64, 128])
torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 64, 64]) torch.Size([1, 1, 64, 64])


### STN Outputs

In [38]:
shape1 = stn1_1(shape1)
print('shape1 :', shape1.shape)
print()

shape21 = stn2_1(upper) 
shape22 = stn2_2(lower)
print('shape21 :', shape21.shape)
print('shape22 :', shape22.shape)
print()

shape41 = stn4_1(lu)
shape42 = stn4_2(ru)
shape43 = stn4_3(ll)
shape44 = stn4_4(rl)
print('shape41 :', shape41.shape)
print('shape42 :', shape42.shape)
print('shape43 :', shape43.shape)
print('shape44 :', shape44.shape)

shape1 : torch.Size([1, 1, 128, 128])

shape21 : torch.Size([1, 1, 64, 128])
shape22 : torch.Size([1, 1, 64, 128])

shape41 : torch.Size([1, 1, 64, 64])
shape42 : torch.Size([1, 1, 64, 64])
shape43 : torch.Size([1, 1, 64, 64])
shape44 : torch.Size([1, 1, 64, 64])


  "Default grid_sample and affine_grid behavior has changed "
  "Default grid_sample and affine_grid behavior has changed "


### Encoding & Decoding 1  

4 분할 입력의 STN outputs 처리

In [41]:
feature_lv41 = encoder_lv3(shape41)
feature_lv42 = encoder_lv3(shape42)
feature_lv43 = encoder_lv3(shape43)
feature_lv44 = encoder_lv3(shape44)

print('feature_lv41 :', feature_lv41.shape)
print('feature_lv42 :', feature_lv42.shape)
print('feature_lv43 :', feature_lv43.shape)
print('feature_lv44 :', feature_lv44.shape)

feature_lv41 : torch.Size([1, 128, 16, 16])
feature_lv42 : torch.Size([1, 128, 16, 16])
feature_lv43 : torch.Size([1, 128, 16, 16])
feature_lv44 : torch.Size([1, 128, 16, 16])


In [42]:
out_lv41 = decoder_lv3(feature_lv41)
out_lv42 = decoder_lv3(feature_lv42)
out_lv43 = decoder_lv3(feature_lv43)
out_lv44 = decoder_lv3(feature_lv44)

print('out_lv41 :', out_lv41.shape)
print('out_lv42 :', out_lv42.shape)
print('out_lv43 :', out_lv43.shape)
print('out_lv44 :', out_lv44.shape)

out_lv41 : torch.Size([1, 1, 64, 64])
out_lv42 : torch.Size([1, 1, 64, 64])
out_lv43 : torch.Size([1, 1, 64, 64])
out_lv44 : torch.Size([1, 1, 64, 64])


### 다음 계층에 전달할 사항 정리

2 분할 STN Outputs의 처리 계층의 encoder, decoder에 전달할 사항들

In [52]:
# 다음 계층 encoder output에 전달
feature_lv412 = torch.cat([feature_lv41, feature_lv42], dim=-1)  # 좌상단, 우상단 concat
feature_lv434 = torch.cat([feature_lv43, feature_lv44], dim=-1)  # 좌하단, 우하단 concat
print('feature_lv412 :', feature_lv412.shape)  
print('feature_lv434 :', feature_lv434.shape)  

feature_lv412 : torch.Size([1, 128, 16, 32])
feature_lv434 : torch.Size([1, 128, 16, 32])


In [53]:
# 다음 계층 encoder input에 전달
out_lv412 = torch.cat([out_lv41, out_lv42], dim=-1)  # 좌상단, 우상단 concat
out_lv434 = torch.cat([out_lv43, out_lv44], dim=-1)  # 좌하단, 우하단 concat
print('out_lv412 :', out_lv412.shape)  
print('out_lv434 :', out_lv434.shape)  

out_lv412 : torch.Size([1, 1, 64, 128])
out_lv434 : torch.Size([1, 1, 64, 128])


### Encoding & Decoding - 2

2 분할 입력의 STN Outputs 처리

 이전 계층의 입력을 전달 받음

In [57]:
# 이전 계층의 decoder 출력 + 현재 계층의 입력(STN Outputs)
shape21_ = shape21 + out_lv412
print('shape21 :', shape21.shape)
print('out_lv412 :', out_lv412.shape)
print('shape21_ :', shape21_.shape)
print()

shape22_ = shape22 + out_lv434
print('shape22 :', shape22.shape)
print('out_lv434 :', out_lv434.shape)
print('shape22_ :', shape22_.shape)

shape21 : torch.Size([1, 1, 64, 128])
out_lv412 : torch.Size([1, 1, 64, 128])
shape21_ : torch.Size([1, 1, 64, 128])

shape22 : torch.Size([1, 1, 64, 128])
out_lv434 : torch.Size([1, 1, 64, 128])
shape22_ : torch.Size([1, 1, 64, 128])


In [59]:
# Encoder 통과
feature_lv21 = encoder_lv2(shape21_)
feature_lv22 = encoder_lv2(shape22_)

# 이전 계층의 Encoder output을 전달 받음 -> Decoder 입력
feature_lv21_ = feature_lv21 + feature_lv412
feature_lv22_ = feature_lv22 + feature_lv434

print('feature_lv21 :', feature_lv21.shape)
print('feature_lv22 :', feature_lv22.shape)
print('feature_lv21_ :', feature_lv21_.shape)
print('feature_lv22_ :', feature_lv22_.shape)

feature_lv21 : torch.Size([1, 128, 16, 32])
feature_lv22 : torch.Size([1, 128, 16, 32])
feature_lv21_ : torch.Size([1, 128, 16, 32])
feature_lv22_ : torch.Size([1, 128, 16, 32])


In [60]:
out_lv21 = decoder_lv2(feature_lv21_)
out_lv22 = decoder_lv2(feature_lv22_)

print('out_lv21 :', out_lv21.shape)
print('out_lv22 :', out_lv22.shape)

out_lv21 : torch.Size([1, 1, 64, 128])
out_lv22 : torch.Size([1, 1, 64, 128])


### 다음 계층에 전달할 사항 정리

In [64]:
# 다음 계층 encoder output에 전달
feature_lv212 = torch.cat([feature_lv21_, feature_lv22_], dim=-2)
print('feature_lv21 :', feature_lv21.shape)
print('feature_lv22 :', feature_lv22.shape)
print('feature_lv212 :', feature_lv212.shape)

feature_lv21 : torch.Size([1, 128, 16, 32])
feature_lv22 : torch.Size([1, 128, 16, 32])
feature_lv212 : torch.Size([1, 128, 32, 32])


In [66]:
# 다음 계층 encoder input에 전달
out_lv212 = torch.cat([out_lv21, out_lv22], dim=-2)
print('out_lv21 :', out_lv21.shape)
print('out_lv22 :', out_lv22.shape)
print('out_lv212 :', out_lv212.shape)

out_lv21 : torch.Size([1, 1, 64, 128])
out_lv22 : torch.Size([1, 1, 64, 128])
out_lv212 : torch.Size([1, 1, 128, 128])


### Encoding & Decoding - 3

원본 입력의 STN Outputs 처리

이전 계층의 입력을 전달 받음

In [68]:
shape1_ = shape1 + out_lv212
print('shape1 :', shape1.shape)
print('out_lv212 :', out_lv212.shape)
print('shape1_ :', shape1_.shape)

shape1 : torch.Size([1, 1, 128, 128])
out_lv212 : torch.Size([1, 1, 128, 128])
shape1_ : torch.Size([1, 1, 128, 128])


In [70]:
# Encoder 통과
feature_lv1 = encoder_lv1(shape1_)

# 이전 계층의 Encoder ouput을 전달 받음
feature_lv1_ = feature_lv1 + feature_lv212
print('feature_lv1 :', feature_lv1.shape)
print('feature_lv212 :', feature_lv212.shape)
print('feature_lv1_ :', feature_lv1_.shape)

feature_lv1 : torch.Size([1, 128, 32, 32])
feature_lv212 : torch.Size([1, 128, 32, 32])
feature_lv1_ : torch.Size([1, 128, 32, 32])


In [72]:
out_lv1 = decoder_lv1(feature_lv1_)  # S_t+1 hat (predicted gray-scale shape image - 1 channel)
print(out_lv1.shape)

torch.Size([1, 1, 128, 128])


### Encoding & Decoding 4

out_lv1을 입력으로 함

 RGB 이미지에 대한 stage1~4 encoder output을 전달 받음

In [74]:
# Encoding S_t+1 hat
temp, y5_4 = E6(out_lv1)
print('temp :', temp.shape)  # 1/2 size
print('y5_4 :', y5_4.shape)  # 1/8 size

temp : torch.Size([1, 32, 64, 64])
y5_4 : torch.Size([1, 128, 16, 16])


In [78]:
# RGB 입력
x5 = torch.randn(1, 3, 128, 128)  

# STN
x5_5 = stn_rgb(x5)  # STN - 3 channels

# 4 stage convolutional encoding
y5_1, y5_2, y5_3, y5_4 = E5(x5_5)  # 16, 32, 64, 128 channels (Outputs of 4 stages)

print('x5 :', x5.shape)
print('x5_5 :', x5_5.shape)
print()
print('y5_1 :', y5_1.shape)  # Original size with 16 channels 
print('y5_2 :', y5_2.shape)  # 1/2 size with 32 channels
print('y5_3 :', y5_3.shape)  # 1/4 size with 64 channels
print('y5_4 :', y5_4.shape)  # 1/8 size with 128 channels

x5 : torch.Size([1, 3, 128, 128])
x5_5 : torch.Size([1, 3, 128, 128])

y5_1 : torch.Size([1, 16, 128, 128])
y5_2 : torch.Size([1, 32, 64, 64])
y5_3 : torch.Size([1, 64, 32, 32])
y5_4 : torch.Size([1, 128, 16, 16])


In [81]:
# Decoding 4 stage encoded outputs
D2_in1 = conv5_1(y5_1)  # 16 channels
D2_in2 = conv5_2(y5_2)  # 32 channels
D2_in3 = conv5_3(y5_3)  # 64 channels
D2_in4 = conv5_4(y5_4)  # 128 channels

print('D2_in1 :', D2_in1.shape)
print('D2_in2 :', D2_in2.shape)
print('D2_in3 :', D2_in3.shape)
print('D2_in4 :', D2_in4.shape)

D2_in1 : torch.Size([1, 16, 128, 128])
D2_in2 : torch.Size([1, 32, 64, 64])
D2_in3 : torch.Size([1, 64, 32, 32])
D2_in4 : torch.Size([1, 128, 16, 16])


In [82]:
out2 = D2(D2_in1, D2_in2, D2_in3, D2_in4)  
print('out2 :', out2.shape)

out2 : torch.Size([1, 3, 128, 128])
