In [2]:
import os

import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image

from ViT import *

# Loss imports
from piq import ssim, SSIMLoss

# Images of training epochs
if not os.path.exists('./dc_img'):
    os.mkdir('./dc_img')

# Models Save
if not os.path.exists('./models'):
    os.mkdir('./models')


def to_img(x):
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 3, 256, 256)
    return x


# batch_size = 128

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# dataset = MNIST('./data', transform=img_transform, download=True)
# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

import load_data as ld
dataLoader = ld.ReadData()
import torch

# Spatial size of training images. All images will be resized to this
image_size = (256, 256)

# Batch size during training
batch_size = 32
num_epochs = 100

# Root directory for dataset
dataroot = "../data/train/sunset"

dataloader = dataLoader.create_dataLoader(dataroot, image_size, batch_size)


In [3]:
img_clor = next(iter(dataloader))[0]
img = img_clor[:,:1,:,:]
img.shape

torch.Size([32, 1, 256, 256])

In [4]:
nn.Conv2d(1, 16, 3, stride=2, padding=1)(img).shape

torch.Size([32, 16, 128, 128])

In [5]:
nn.MaxPool2d(2, stride=2)(img).shape

torch.Size([32, 1, 128, 128])

In [6]:
# class ConvDown2d(nn.Module):
#     def __init__(self, in_channel, out_channel, kernel_size=3, stride=2, padding=1) -> None:
#         super(ConvDown2d, self).__init__()
#         self.in_channel = in_channel
#         self.out_channel = out_channel
#         self.kernel_size = kernel_size
#         self.stride = stride
#         self.padding = padding

#     def forward(self, x) -> torch.Tensor:
#         conv = nn.Sequential(
#             nn.Conv2d(self.in_channel, self.out_channel, self.kernel_size, self.stride, self.padding),
#             nn.ReLU(True),
#             nn.MaxPool2d(2, stride=1)
#         )
#         return conv(x)

In [7]:
# conv = ConvDown2d(1,16,3,2,1)
# conv(img).shape

In [8]:
def ConvDown2d(in_channel, out_channel, kernel_size=3, stride=2, padding=1):

    conv = nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding)
    return conv

def ConvUp2d(in_channel, out_channel, kernel_size=3, stride=2, padding=1):

    conv = nn.ConvTranspose2d(in_channel, out_channel, kernel_size, stride, padding)
    return conv

In [9]:
class ColorNetwork(nn.Module):

    def __init__(self, in_channel, out_channel, stride, padding) -> None:
        super(ColorNetwork, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.kernel_size = 3
        self.stride = stride
        self.padding = padding

        #Encoder Network
        self.dw_conv1 = ConvDown2d(1, self.out_channel, self.kernel_size,self.stride,self.padding)
        self.max_pol1 = nn.MaxPool2d(2, stride=1)

        self.dw_conv2 = ConvDown2d(self.out_channel, self.out_channel*2, self.kernel_size,self.stride,self.padding)
        self.max_pol2 = nn.MaxPool2d(2, stride=1)

        self.dw_conv3 = ConvDown2d(self.out_channel*2, self.out_channel*4, self.kernel_size,self.stride,self.padding)
        self.max_pol3 = nn.MaxPool2d(2, stride=1)

        #Decoder
        self.up_conv1 = ConvUp2d(self.out_channel*8, self.out_channel*2, 2, self.stride, 0)

        self.up_conv2 = ConvUp2d(self.out_channel*4, self.out_channel, 2, self.stride,0)

        self.up_conv3 = ConvUp2d(self.out_channel*2, 3, 2, 2, 0)

        #Activation
        self.activation = nn.Tanh()

    def forward(self, x, color_sample) -> torch.Tensor:

        #Encoder
        e1 = self.dw_conv1(x)
        e1 = nn.ReLU(True)(e1)
        e1 = self.max_pol1(e1)
        print(f"e1 shape: {e1.shape}")

        e2 = self.dw_conv2(e1)
        e2 = nn.ReLU(True)(e2)
        e2 = self.max_pol2(e2)
        print(f"e2 shape: {e2.shape}")

        e3 = self.dw_conv3(e2)
        e3 = nn.ReLU(True)(e3)
        e3 = self.max_pol3(e3)
        print(f"e3 shape: {e3.shape}")

        #BottlerNeck
        # neck = vit.forward(color_sample)
        # print(f"neck shape: {neck.shape}")
        # neck = torch.reshape(neck, (e3.shape[0], e3.shape[1], e3.shape[2], e3.shape[3]))
        neck = e3
        
        #Decoder
        e3 = torch.cat((neck, e3), 1)
        d1 = self.up_conv1(e3)
        d1 = nn.ReLU(True)(d1)
        print(f"d1 shape: {d1.shape}")

        d2 = torch.cat((e2, d1), 1)
        d2 = self.up_conv2(d2)
        d2 = nn.ReLU(True)(d2)
        print(f"d2 shape: {d2.shape}")

        d3 = torch.cat((e1, d2), 1)
        d3 = self.up_conv3(d3)
        d3 = nn.ReLU(True)(d3)
        print(f"d3 shape: {d3.shape}")

        #Activation
        out = self.activation(d3)
        return out

In [10]:
#Visual Transformer
from ViT import *
#Visual Transformer
vit = Vit_neck(batch_size, image_size[0], 64*32*32)

a = torch.rand(batch_size, 256, 32, 32)

In [11]:
img_clor.shape

torch.Size([32, 3, 256, 256])

In [12]:
img.shape

torch.Size([32, 1, 256, 256])

In [13]:
net = ColorNetwork(1,32,2,2)
net(img, img_clor.to("cuda")).shape

e1 shape: torch.Size([32, 32, 128, 128])
e2 shape: torch.Size([32, 64, 64, 64])
e3 shape: torch.Size([32, 128, 32, 32])
d1 shape: torch.Size([32, 64, 64, 64])
d2 shape: torch.Size([32, 32, 128, 128])
d3 shape: torch.Size([32, 3, 256, 256])


torch.Size([32, 3, 256, 256])

In [14]:
net = ColorNetwork(1,32,2,2)
net(img, img_clor.to("cuda")).shape

e1 shape: torch.Size([32, 32, 128, 128])
e2 shape: torch.Size([32, 64, 64, 64])
e3 shape: torch.Size([32, 128, 32, 32])
d1 shape: torch.Size([32, 64, 64, 64])
d2 shape: torch.Size([32, 32, 128, 128])
d3 shape: torch.Size([32, 3, 256, 256])


torch.Size([32, 3, 256, 256])

In [20]:
c1 = torch.ones(256,256)
c2 = torch.add(c1, 1)
c3 = torch.add(c2, 1)
# c2 = torch.add()

In [37]:
sample = torch.concat((c1.unsqueeze(0),c2.unsqueeze(0),c3.unsqueeze(0)), 0)

In [45]:
sample[:1,:,:]

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]])

In [None]:
# How make skip connection in pytorch
# https://github.com/pytorch/vision/blob/a9a8220e0bcb4ce66a733f8c03a1c2f6c68d22cb/torchvision/models/resnet.py#L56-L72