In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

from pathlib import Path

from data.dataset import Dataset

In [None]:
from datetime import datetime

# Capture the current local date and time as a datetime object.
now = datetime.now()
 
print("now =", now)

# Format as dd-mm-YYYY_HH.MM.SS (dashes and dots instead of slashes and
# colons; presumably so the stamp can be embedded in a file name).
dt_string = now.strftime("%d-%m-%Y_%H.%M.%S")
print("date and time =", dt_string)

In [None]:
# Show the installed PyTorch version for reference.
print(torch.__version__)

In [None]:
# Root directory used to build the dataset and config paths further down:
# the "SageMaker" folder under the current user's home directory.
base_path = f'{Path.home()}/SageMaker'
print(base_path)

In [None]:
# Images are read from a single flat folder (with_subfolder=False). How
# image_shape and random_crop are applied is decided by the project's Dataset
# class, which is not shown in this notebook.
dataset = Dataset(
    data_path=f'{base_path}/dataset/flickr8k/Images',
    with_subfolder=False,
    image_shape=[256, 256, 3],
    random_crop=True,
)

# Shuffled mini-batches of four samples, fetched by one worker process.
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=1,
)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# functions to show an image


def imshow(img):
    """Display a channels-first image tensor with matplotlib.

    The tensor is rescaled with ``img / 2 + 0.5`` before plotting, which maps
    values in [-1, 1] onto [0, 1] (assumes the dataset normalises images to
    [-1, 1] -- TODO confirm against ``data.dataset``).

    Args:
        img: ``torch.Tensor`` laid out as (channels, height, width), e.g. the
            output of ``torchvision.utils.make_grid``.
    """
    img = img / 2 + 0.5     # unnormalize
    # Tensor.numpy() raises for CUDA tensors and for tensors that require
    # grad, so detach and copy to host memory first.
    npimg = img.detach().cpu().numpy()
    # matplotlib expects channels last: (C, H, W) -> (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(loader)
# Use the built-in next(): the iterator's Python-2 style ``.next()`` alias is
# not available in recent PyTorch releases and raises AttributeError there.
images = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))

## Instantiate the model, and test gating.

In [None]:
from trainer import Trainer
from data.dataset import Dataset
from utils.tools import get_config, random_bbox, mask_image
from model.networks import Generator, LocalDis, GlobalDis
from utils.tools import get_model_list, local_patch, spatial_discounting_mask

# from utils.logger import get_logger

from argparse import ArgumentParser

%load_ext autoreload
%autoreload 2

In [None]:
# Declare the two supported options, then parse an explicit empty argument
# list so the kernel's own command line is ignored and the defaults apply.
parser = ArgumentParser()
parser.add_argument(
    '--config',
    type=str,
    default='configs/config.yaml',
    help="training configuration",
)
parser.add_argument('--seed', type=int, help='manual seed')

args = parser.parse_args([])
config = get_config(args.config)


## Grab some data and prepare masks.

In [None]:
# Reuse the batch that was fetched for the preview as the ground truth.
ground_truth = images

# Prepare the inputs: one set of boxes for the whole batch, then the masked
# images and masks derived from them (exact semantics live in utils.tools).
bboxes = random_bbox(config, batch_size=ground_truth.size(0))
x, mask = mask_image(ground_truth, bboxes, config)

# Move all three tensors to the GPU when the config enables CUDA.
if config['cuda']:
    x, mask, ground_truth = x.cuda(), mask.cuda(), ground_truth.cuda()

## Generator - non-gated

In [None]:
from model.networks import Generator
# Re-load the config from its absolute path and build the non-gated generator.
config = get_config(f'{base_path}/generative-inpainting-pytorch/configs/config.yaml')
# Place the network on the same device as the prepared inputs instead of
# calling .cuda() unconditionally: the inputs are only moved to the GPU when
# config['cuda'] is set, so a hard-coded .cuda() here breaks CPU-only runs.
netG = Generator(config['netG'], config['cuda'], config['gpu_ids']).to(x.device)

In [None]:
# Full forward pass through the generator; it returns three values, and the
# shapes of the first two are printed.
x1, x2, offset_flow = netG(x, mask)
print("x1: ", x1.shape)
print("x2: ", x2.shape)


## Inspect individual outputs from layers.

In [None]:
# Call only the generator's coarse_generator sub-module on the same inputs.
x1 = netG.coarse_generator(x, mask)

In [None]:
# All-ones channel with the same batch and spatial size as x, created
# directly on x's device. An unconditional .cuda() fails without a GPU and
# otherwise costs an extra host-to-device copy.
ones = torch.ones(x.size(0), 1, x.size(2), x.size(3), device=x.device)
print("ones: ", ones.shape)

# Stack image, ones channel and mask along the channel axis, then run the
# result through the first coarse-generator convolution only.
_in = torch.cat([x, ones, mask], dim=1)
print("_in: ", _in.shape)
_x1 = netG.coarse_generator.conv1(_in)
print("_x1: ", _x1.shape)


## Gated Generator

In [None]:
from model.networks import Generator
# Drop a previous instance (if any) before building a new one. A bare
# ``del netG_gated`` raises NameError the first time this cell runs in a
# fresh kernel, so tolerate the name being absent.
try:
    del netG_gated
except NameError:
    pass

config_gated = get_config(f'{base_path}/generative-inpainting-pytorch/configs/config-gated.yaml')
# Keep the model on the same device as the prepared inputs rather than
# forcing CUDA regardless of where x and mask live.
netG_gated = Generator(config_gated['netG'], config_gated['cuda'], config_gated['gpu_ids']).to(x.device)

In [None]:
# Only the coarse stage of the gated generator is exercised here; the full
# forward pass and the x2 print are left commented out.
# x1, x2, offset_flow = netG_gated(x, mask)
print("x:, ", x.shape, " mask: ", mask.shape)
x1 = netG_gated.coarse_generator(x, mask)
print("x1: ", x1.shape)
# print("x2: ", x2.shape)

In [None]:
# All-ones channel matching x's batch and spatial size, allocated on x's
# device instead of through an unconditional .cuda() call (which fails on a
# CPU-only machine and adds a host-to-device copy otherwise).
ones = torch.ones(x.size(0), 1, x.size(2), x.size(3), device=x.device)
print("ones: ", ones.shape)

# Concatenate image, ones channel and mask along the channel axis and step
# through the first two layers of the gated coarse generator by hand.
_in = torch.cat([x, ones, mask], dim=1)
print("_in: ", _in.shape)
_x1 = netG_gated.coarse_generator.conv1(_in)
print("_x1: ", _x1.shape)

_x2 = netG_gated.coarse_generator.conv2_downsample(_x1)
print("_x2: ", _x2.shape)

In [None]:
import torch.nn as nn

class Conv2dBlockGated(nn.Module):
    """Padding -> convolution -> optional norm -> gated activation.

    The convolution produces ``output_dim`` channels. Unless gating is
    skipped (see below) those channels are split into two equal halves along
    the channel axis: the first half goes through ``activation``, the second
    through a sigmoid, and their element-wise product is returned. A gated
    block therefore emits ``output_dim // 2`` channels, not ``output_dim``.

    Gating is skipped, and the (optionally normalised) convolution output is
    returned with no activation at all, when ``activation == 'none'`` or
    ``output_dim == 3``.

    Args:
        input_dim: number of input channels of the convolution.
        output_dim: number of channels produced by the convolution.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: amount of padding applied by the explicit padding layer
            selected with ``pad_type``.
        conv_padding: ``padding`` argument of the convolution itself; for a
            transposed convolution it is also passed as ``output_padding``.
        dilation: convolution dilation.
        weight_norm: ``'sn'`` (spectral norm), ``'wn'`` (weight norm) or
            ``'none'``; applied to the convolution layer.
        norm: ``'bn'`` (BatchNorm2d), ``'in'`` (InstanceNorm2d) or ``'none'``.
        activation: ``'relu'``, ``'elu'``, ``'lrelu'``, ``'prelu'``,
            ``'selu'``, ``'tanh'`` or ``'none'``.
        pad_type: ``'reflect'``, ``'replicate'``, ``'zero'`` or ``'none'``.
        transpose: use ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``.
        verbose: print intermediate tensor shapes on every forward pass
            (debugging aid, off by default).

    Raises:
        AssertionError: if ``pad_type``, ``norm``, ``weight_norm`` or
            ``activation`` is not one of the supported names.
        ValueError: if gating applies and ``output_dim`` is odd, because the
            channels could not be split into two equal halves.
    """

    def __init__(self, input_dim, output_dim, kernel_size, stride, padding=0,
                 conv_padding=0, dilation=1, weight_norm='none', norm='none',
                 activation='relu', pad_type='zero', transpose=False,
                 verbose=False):

        super(Conv2dBlockGated, self).__init__()
        self.output_dim = output_dim
        self.use_bias = True
        self.verbose = verbose

        # Unsupported option names raise AssertionError explicitly (the same
        # type a failing ``assert`` produces) so the check survives ``-O``.

        # initialize padding
        pad_layers = {
            'reflect': nn.ReflectionPad2d,
            'replicate': nn.ReplicationPad2d,
            'zero': nn.ZeroPad2d,
        }
        if pad_type == 'none':
            self.pad = None
        elif pad_type in pad_layers:
            self.pad = pad_layers[pad_type](padding)
        else:
            raise AssertionError(
                "Unsupported padding type: {}".format(pad_type))

        # initialize normalization
        norm_layers = {'bn': nn.BatchNorm2d, 'in': nn.InstanceNorm2d}
        if norm == 'none':
            self.norm = None
        elif norm in norm_layers:
            self.norm = norm_layers[norm](output_dim)
        else:
            raise AssertionError("Unsupported normalization: {}".format(norm))

        # initialize weight normalization; imported locally (and aliased so
        # the ``weight_norm`` argument is not shadowed) because this notebook
        # never brings these helpers into scope.
        if weight_norm == 'sn':
            from torch.nn.utils import spectral_norm as spectral_norm_fn
            self.weight_norm = spectral_norm_fn
        elif weight_norm == 'wn':
            from torch.nn.utils import weight_norm as weight_norm_fn
            self.weight_norm = weight_norm_fn
        elif weight_norm == 'none':
            self.weight_norm = None
        else:
            raise AssertionError(
                "Unsupported normalization: {}".format(weight_norm))

        # initialize activation. The modules are deliberately NOT in-place:
        # the activation is applied to one output of torch.chunk, and autograd
        # rejects in-place edits of views that come from a multi-output op.
        activation_factories = {
            'relu': lambda: nn.ReLU(),
            'elu': lambda: nn.ELU(),
            'lrelu': lambda: nn.LeakyReLU(0.2),
            'prelu': lambda: nn.PReLU(),
            'selu': lambda: nn.SELU(),
            'tanh': lambda: nn.Tanh(),
        }
        if activation == 'none':
            self.activation = None
        elif activation in activation_factories:
            self.activation = activation_factories[activation]()
        else:
            raise AssertionError(
                "Unsupported activation: {}".format(activation))

        # Fail early when the feature/gate split cannot be even; otherwise
        # the mismatch would only surface as a shape error inside forward().
        if self.activation is not None and output_dim != 3 \
                and output_dim % 2 != 0:
            raise ValueError(
                "output_dim must be even for a gated block, got {}".format(
                    output_dim))

        # initialize convolution
        if transpose:
            self.conv = nn.ConvTranspose2d(input_dim, output_dim,
                                           kernel_size, stride,
                                           padding=conv_padding,
                                           output_padding=conv_padding,
                                           dilation=dilation,
                                           bias=self.use_bias)
        else:
            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride,
                                  padding=conv_padding, dilation=dilation,
                                  bias=self.use_bias)

        if self.weight_norm is not None:
            self.conv = self.weight_norm(self.conv)

    def forward(self, x):
        """Apply the block to ``x`` and return the gated (or raw) output."""
        if self.pad is not None:
            x = self.pad(x)
        x = self.conv(x)

        if self.norm is not None:
            x = self.norm(x)

        if self.verbose:
            print("x: ", x.shape)

        # Output is an image (3 channels) or no activation was requested:
        # return the convolution output untouched, without gating.
        if self.activation is None or self.output_dim == 3:
            return x

        # Split the channels in half: first half = features, second = gate.
        feat, gate = torch.chunk(x, 2, 1)
        if self.verbose:
            print("feat: ", feat.shape)
            print("gate: ", gate.shape)

        # Activated features, scaled element-wise by a sigmoid gate in (0, 1).
        return self.activation(feat) * torch.sigmoid(gate)


class Conv2dBlock(nn.Module):
    """Padding -> convolution -> optional norm -> optional activation.

    Args:
        input_dim: number of input channels of the convolution.
        output_dim: number of output channels of the convolution.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: amount of padding applied by the explicit padding layer
            selected with ``pad_type``.
        conv_padding: ``padding`` argument of the convolution itself; for a
            transposed convolution it is also passed as ``output_padding``.
        dilation: convolution dilation.
        weight_norm: ``'sn'`` (spectral norm), ``'wn'`` (weight norm) or
            ``'none'``; applied to the convolution layer.
        norm: ``'bn'`` (BatchNorm2d), ``'in'`` (InstanceNorm2d) or ``'none'``.
        activation: ``'relu'``, ``'elu'``, ``'lrelu'``, ``'prelu'``,
            ``'selu'``, ``'tanh'`` or ``'none'``.
        pad_type: ``'reflect'``, ``'replicate'``, ``'zero'`` or ``'none'``.
        transpose: use ``nn.ConvTranspose2d`` instead of ``nn.Conv2d``.

    Raises:
        AssertionError: if ``pad_type``, ``norm``, ``weight_norm`` or
            ``activation`` is not one of the supported names.
    """

    def __init__(self, input_dim, output_dim, kernel_size, stride, padding=0,
                 conv_padding=0, dilation=1, weight_norm='none', norm='none',
                 activation='relu', pad_type='zero', transpose=False):

        super(Conv2dBlock, self).__init__()
        self.use_bias = True

        # Unsupported option names raise AssertionError explicitly (the same
        # type a failing ``assert`` produces) so the check survives ``-O``.

        # initialize padding
        pad_layers = {
            'reflect': nn.ReflectionPad2d,
            'replicate': nn.ReplicationPad2d,
            'zero': nn.ZeroPad2d,
        }
        if pad_type == 'none':
            self.pad = None
        elif pad_type in pad_layers:
            self.pad = pad_layers[pad_type](padding)
        else:
            raise AssertionError(
                "Unsupported padding type: {}".format(pad_type))

        # initialize normalization
        norm_layers = {'bn': nn.BatchNorm2d, 'in': nn.InstanceNorm2d}
        if norm == 'none':
            self.norm = None
        elif norm in norm_layers:
            self.norm = norm_layers[norm](output_dim)
        else:
            raise AssertionError("Unsupported normalization: {}".format(norm))

        # initialize weight normalization; imported locally (and aliased so
        # the ``weight_norm`` argument is not shadowed) because this notebook
        # never brings these helpers into scope.
        if weight_norm == 'sn':
            from torch.nn.utils import spectral_norm as spectral_norm_fn
            self.weight_norm = spectral_norm_fn
        elif weight_norm == 'wn':
            from torch.nn.utils import weight_norm as weight_norm_fn
            self.weight_norm = weight_norm_fn
        elif weight_norm == 'none':
            self.weight_norm = None
        else:
            raise AssertionError(
                "Unsupported normalization: {}".format(weight_norm))

        # initialize activation (in-place where the module supports it)
        activation_factories = {
            'relu': lambda: nn.ReLU(inplace=True),
            'elu': lambda: nn.ELU(inplace=True),
            'lrelu': lambda: nn.LeakyReLU(0.2, inplace=True),
            'prelu': lambda: nn.PReLU(),
            'selu': lambda: nn.SELU(inplace=True),
            'tanh': lambda: nn.Tanh(),
        }
        if activation == 'none':
            self.activation = None
        elif activation in activation_factories:
            self.activation = activation_factories[activation]()
        else:
            raise AssertionError(
                "Unsupported activation: {}".format(activation))

        # initialize convolution
        if transpose:
            self.conv = nn.ConvTranspose2d(input_dim, output_dim,
                                           kernel_size, stride,
                                           padding=conv_padding,
                                           output_padding=conv_padding,
                                           dilation=dilation,
                                           bias=self.use_bias)
        else:
            self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride,
                                  padding=conv_padding, dilation=dilation,
                                  bias=self.use_bias)

        if self.weight_norm is not None:
            self.conv = self.weight_norm(self.conv)

    def forward(self, x):
        """Run ``x`` through pad, conv, norm and activation, in that order."""
        # Compare against None explicitly rather than relying on the
        # truthiness of nn.Module instances.
        if self.pad is not None:
            x = self.pad(x)
        x = self.conv(x)
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x
    
    
    
def gen_conv(input_dim, output_dim, kernel_size=3, stride=1, padding=0, rate=1,
             activation='elu', gated=False):
    """Build one generator convolution block.

    Args:
        input_dim: number of input channels.
        output_dim: number of channels produced by the convolution.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: forwarded to the block as ``conv_padding``.
        rate: forwarded to the block as ``dilation``.
        activation: activation name understood by the block classes.
        gated: build a ``Conv2dBlockGated`` when true, a plain
            ``Conv2dBlock`` otherwise.

    Returns:
        The constructed block.
    """
    block_cls = Conv2dBlockGated if gated else Conv2dBlock
    return block_cls(input_dim, output_dim, kernel_size, stride,
                     conv_padding=padding, dilation=rate,
                     activation=activation)




In [None]:
# Read the generator hyper-parameters from the gated config.
input_dim = config_gated['netG']['input_dim']
# NOTE(review): `gated` is read here but never used -- the gen_conv call
# below hard-codes gated=True; confirm which one is intended.
gated = config_gated['netG']['gated']
cnum = config_gated['netG']['ngf']

# Stand-alone first convolution built with the notebook's own gen_conv.
# The +2 input channels presumably match the ones and mask channels that are
# concatenated into `_in` -- TODO confirm.
conv1 = gen_conv(input_dim + 2, cnum, 5, 1, 2, gated=True).cuda()

In [None]:
# Run the concatenated input through the stand-alone gated block and show
# the resulting shape.
_x1 = conv1(_in)
print(_x1.shape)

In [None]:
# Dummy all-ones batch, halved along the channel axis (dim 1) to check what
# chunk() returns.
test = torch.ones(4, 32, 256, 256)
a, b = test.chunk(2, dim=1)

In [None]:
# Element-wise product of the two halves; the result keeps their shape.
c = a * b
print("c: ", c.shape)

In [None]:
# Show the shape of one half as the cell's output.
b.shape