In [14]:
# Automatically re-import edited project modules (e.g. models, utils) before
# running cells, so changes in .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
import cv2
from IPython.display import HTML
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
from pathlib import Path
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms, datasets
from torchvision.datasets import ImageFolder, DatasetFolder
import torchvision.utils as vutils

from htools import hdir
from config import *
from models import BaseModel, conv_block
import models
from utils import render_samples, show_img

In [17]:
# TEMPORARY, IMPORTED FROM CONFIG - JUST FOR EASY REFERENCE
# NOTE(review): this cell runs after `from config import *`, so every name
# below REBINDS the value imported from config for the rest of the notebook.
# If config.py changes, these copies must be updated too (or this cell removed).

# --- Data loading ---
bs = 64                    # Batch size (paper uses 128).
img_size = 64              # Height/width of input images (64 x 64).
workers = 2                # Number of DataLoader worker processes.
nc = 3                     # Number of channels of input image.

# --- Model sizes ---
input_c = 100              # Depth of input noise (1 x 1 x input_c), a.k.a. nz.
ngf = 64                   # Generator filters in last layer before reducing to 3 channels.
ndf = 64                   # Discriminator filters in first conv layer.

# --- Optimizer ---
lr = 2e-4                  # Recommended learning rate of .0002.
beta1 = .5                 # Recommended beta1 for Adam.

# --- Hardware & output locations ---
ngpu = 1                   # Number of GPUs to use (0 forces CPU).
sample_dir = 'samples'     # Directory to store sample images from G.
weight_dir = 'weights'     # Directory to store model weights.
device = torch.device(
    'cuda:0' if torch.cuda.is_available() and ngpu > 0 else 'cpu'
)

In [54]:
class ResBlock(nn.Module):
    """Shape-preserving residual block computing ``x + conv_bn(x)``.

    The residual branch is one 3x3, stride-1, pad-1 conv (no bias) followed by
    batch norm, so both the channel count and the spatial size are unchanged
    and the skip connection can be added directly.

    NOTE(review): the branch contains no activation of its own; any
    nonlinearity must be applied by the caller.

    Parameters
    ----------
    c_in : int
        Number of input channels (equal to the number of output channels).
    """

    def __init__(self, c_in):
        super().__init__()
        # In == out channels, 3x3 kernel with pad 1: output matches input shape.
        self.conv = conv_block(False, c_in, c_in,
                               f=3, stride=1, pad=1,
                               bias=False, bn=True)

    def forward(self, x):
        residual = self.conv(x)
        return x + residual

In [68]:
class CycleGenerator(BaseModel):
    """CycleGAN-style generator (work in progress).

    Only the encoder and the residual "transformer" stages are built so far.
    The decoder is not implemented, so `forward` returns a feature map rather
    than an image. In this notebook, a 2 x 3 x 64 x 64 input produced a
    2 x 128 x 16 x 16 output.

    Parameters
    ----------
    img_c : int
        Number of channels in the input image.
    ngf : int
        Number of filters in the first encoder conv; later stages use ngf*2.
    """

    def __init__(self, img_c, ngf):
        super().__init__()

        # ENCODER: conv_block(False, ...) builds a regular (non-transposed)
        # conv; each stride-2 layer halves the spatial size.
        # 3 x 64 x 64 -> 64 x 32 x 32
        conv1 = conv_block(False, img_c, ngf, f=4, stride=2, pad=1)
        # 64 x 32 x 32 -> 128 x 16 x 16
        conv2 = conv_block(False, ngf, ngf*2, 4, 2, 1)
        self.encoder = nn.Sequential(conv1, conv2)

        # TRANSFORMER: shape-preserving residual blocks.
        # 128 x 16 x 16 -> 128 x 16 x 16
        res1 = ResBlock(ngf*2)
        # 128 x 16 x 16 -> 128 x 16 x 16
        res2 = ResBlock(ngf*2)
        self.transformer = nn.Sequential(res1, res2)

        # DECODER
        # TODO: not implemented yet. Presumably transposed conv_block(True, ...)
        # layers that go from ngf*2 back to img_c channels; once built, add
        # the decoder to `self.groups` so `forward` runs it.

        # NOTE(review): encoder/transformer are registered both as attributes
        # and through this ModuleList, so state_dict() lists their tensors
        # under both prefixes. Left unchanged to keep existing keys stable.
        self.groups = nn.ModuleList([self.encoder, self.transformer])

    def forward(self, x):
        """Run the input through every stage, with a ReLU after each layer.

        Parameters
        ----------
        x : torch.Tensor
            Image batch; presumably (batch, img_c, H, W) -- TODO confirm.

        Returns
        -------
        torch.Tensor
            Feature map produced by the last implemented stage.
        """
        for group in self.groups:
            # Apply the activation after every layer inside a group, not only
            # after the whole group. Otherwise conv1/conv2 (and res1/res2)
            # would be stacked with no nonlinearity between them.
            for layer in group:
                x = F.relu(layer(x), inplace=True)
        return x

In [69]:
class CycleDiscriminator(BaseModel):
    """CycleGAN-style discriminator (stub, work in progress).

    Only the first conv block is defined so far, and `forward` is not
    implemented yet.

    Parameters
    ----------
    img_c : int
        Number of channels in the input image.
    ndf : int
        Number of filters in the first conv layer.
    """

    def __init__(self, img_c=3, ndf=64):
        super().__init__()
        # Same settings as the generator's first encoder conv: a stride-2 conv
        # that halves the spatial size (e.g. 64 x 64 -> 32 x 32).
        self.conv1 = conv_block(False, img_c, ndf, f=4, stride=2, pad=1)

    def forward(self, x):
        """Score a batch of images (not implemented yet).

        The previous stub took no input and silently returned None, so
        `D(x)` failed with a confusing TypeError. The method now accepts the
        input and fails loudly until the architecture is written.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError('CycleDiscriminator.forward is not implemented yet.')

In [70]:
# NOTE(review): `img_c` is not set in the reference config cell above (which
# sets `nc`); it presumably comes from `from config import *` -- confirm.
G = CycleGenerator(img_c=img_c, ngf=ngf)

In [71]:
# Build the sample batch here instead of relying on `x` from a later cell
# (In [56]), so this cell also works under Restart & Run All.
x = torch.randn(2, img_c, img_size, img_size)
G(x).shape

torch.Size([2, 128, 16, 16])

In [65]:
# Display the module tree for a quick architecture check.
# Re-run after redefining CycleGenerator: saved output can predate the current class.
G

CycleGenerator(
  (encoder): Sequential(
    (0): Sequential(
      (0): ConvTranspose2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Sequential(
      (0): ConvTranspose2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (transformer): Sequential(
    (0): ResBlock(
      (conv): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): ResBlock(
      (conv): Sequential(
        (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
 

In [66]:
# Inspect parameter shapes. `dims()` comes from the project's BaseModel
# (models module); judging by its output, it returns one torch.Size per
# parameter tensor.
G.dims()

[torch.Size([3, 64, 4, 4]),
 torch.Size([64]),
 torch.Size([64]),
 torch.Size([64, 128, 4, 4]),
 torch.Size([128]),
 torch.Size([128]),
 torch.Size([128, 128, 3, 3]),
 torch.Size([128]),
 torch.Size([128]),
 torch.Size([128, 128, 3, 3]),
 torch.Size([128]),
 torch.Size([128])]

In [55]:
# Standalone residual block on a 3-channel input, used for a shape smoke test.
res = ResBlock(c_in=3)

In [56]:
# Smoke test: a ResBlock must return a tensor of exactly the input's shape.
# The assertion makes Restart & Run All fail loudly if that breaks.
x = torch.randn(2, 3, 64, 64)
res_out = res(x)
assert res_out.shape == x.shape, 'ResBlock must preserve the input shape'
res_out.shape

torch.Size([2, 3, 64, 64])