In [None]:
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

In [None]:
transform = transforms.Compose(
    [
        # Resize every sample to a fixed 64x64 using bicubic interpolation.
        transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BICUBIC),
        # PIL image -> float tensor laid out as (C, H, W).
        transforms.ToTensor(),
    ]
)

In [None]:
# STL10's train split (5,000 labelled 96x96 colour images) is large enough to be
# resized to 64x64 and then downsampled several times below. The archive is
# downloaded to the path below on first run (a multi-GB download) and reused after.
SEED = 0  # fixes the shuffle order so Restart & Run All reproduces the same batches

train_data = torchvision.datasets.STL10("~/.pytorch/STL10_data", split="train", download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=16,
    shuffle=True,
    # A dedicated seeded generator makes the shuffle deterministic without
    # touching torch's global RNG (used e.g. for layer initialisation).
    generator=torch.Generator().manual_seed(SEED),
)

In [None]:
def show_images(images, n):
    """Display the first ``n`` images of a batch, one matplotlib figure each.

    Args:
        images: tensor shaped (batch, channels, height, width), e.g. a batch
            from ``train_loader`` or a conv layer's output. It is detached and
            moved to the CPU before plotting, so tensors that require grad or
            live on a GPU are accepted.
        n: how many images to show. Values larger than the batch are clamped
            to the batch size (instead of raising IndexError); n <= 0 shows
            nothing.
    """
    batch = images.detach().cpu().numpy()
    for chw in batch[:max(n, 0)]:
        # (C, H, W) -> (H, W, C): imshow expects channels last.
        hwc = np.transpose(chw, (1, 2, 0))
        if hwc.shape[-1] == 1:
            # imshow rejects (H, W, 1); draw single-channel maps as grayscale.
            plt.imshow(hwc[..., 0], cmap="gray")
        else:
            plt.imshow(hwc)
        plt.show()

In [None]:
# Preview the first three images of one shuffled training batch (index 0 of the
# batch tuple is the image tensor; index 1 holds the labels, unused here).
show_images(images=next(iter(train_loader))[0], n=3)

In [None]:
# Build two fixed (non-learned) 2x downsamplers to visualise what a stride-2
# conv does to an image. Each applies a 4x4 box (mean) filter to every RGB
# channel independently: with stride=2 and padding=1 an HxW input becomes
# (H/2)x(W/2) for even H, W. The zero padding makes border outputs average in
# zeros, so image edges come out slightly darker.
BOX_CHANNELS = 3
BOX_SIZE = 4

# kernel[out_c, in_c] is a BOX_SIZE x BOX_SIZE window filled with 1/BOX_SIZE**2
# when out_c == in_c and all zeros otherwise (no cross-channel mixing).
kernel = (
    torch.eye(BOX_CHANNELS).view(BOX_CHANNELS, BOX_CHANNELS, 1, 1)
    * torch.ones(BOX_SIZE, BOX_SIZE)
    / BOX_SIZE**2
)
bias = torch.zeros(BOX_CHANNELS)


def make_box_downsampler():
    """Return a new stride-2 Conv2d whose weight/bias hold copies of `kernel`/`bias`.

    Copying into the layer's own parameters keeps layers independent, so
    changing one layer's weights never alters another's. Wrapping the same
    `kernel` tensor in `Parameter` for every layer would make them share storage.
    """
    conv = nn.Conv2d(BOX_CHANNELS, BOX_CHANNELS, BOX_SIZE, stride=2, padding=1)
    assert kernel.shape == conv.weight.shape
    with torch.no_grad():
        conv.weight.copy_(kernel)
        conv.bias.copy_(bias)
    return conv


conv1 = make_box_downsampler()
conv2 = make_box_downsampler()

In [None]:
# Push one batch through the two fixed downsamplers and show each stage:
# 64x64 input -> 32x32 after conv1 -> 16x16 after conv2. Only visualisation is
# needed, so run under no_grad to avoid building an autograd graph.
imgs, labels = next(iter(train_loader))
with torch.no_grad():
    output1 = conv1(imgs)
    output2 = conv2(output1)
print(imgs.shape, output1.shape, output2.shape)
for stage in (imgs, output1, output2):
    show_images(stage, 3)

In [None]:
# Adapted from
# https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/cyclegan/models.py
class Discriminator(nn.Module):
    """PatchGAN discriminator: maps an image to a grid of real/fake scores.

    Four stride-2 conv blocks shrink the spatial size by 2**4 = 16. A final
    conv (preceded by asymmetric zero padding) keeps that size and outputs a
    single channel. For a (C, H, W) input, each sample's output has shape
    ``output_shape`` = (1, H // 16, W // 16).

    Args:
        input_shape: (channels, height, width) of the images to score.
    """

    # Output channels of the successive downsampling blocks.
    _BLOCK_FILTERS = (64, 128, 256, 512)

    def __init__(self, input_shape):
        super().__init__()

        in_channels, height, width = input_shape
        shrink = 2 ** len(self._BLOCK_FILTERS)  # each block halves H and W
        self.output_shape = (1, height // shrink, width // shrink)

        layers = []
        prev = in_channels
        for position, filters in enumerate(self._BLOCK_FILTERS):
            # The first block skips normalisation, as in the reference model.
            layers.extend(self._downsampling_block(prev, filters, normalize=position != 0))
            prev = filters
        # Pad one pixel on the left/top only, so the 4x4, padding=1 conv below
        # keeps the (H // 16, W // 16) spatial size.
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(prev, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    @staticmethod
    def _downsampling_block(in_filters, out_filters, normalize=True):
        """Return [Conv2d(4x4, stride 2), optional InstanceNorm2d, LeakyReLU(0.2)]."""
        block = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
        if normalize:
            block.append(nn.InstanceNorm2d(out_filters))
        block.append(nn.LeakyReLU(0.2, inplace=True))
        return block

    def forward(self, img):
        """Score ``img`` (assumed (N, C, H, W)); returns a (N, 1, H // 16, W // 16) score map."""
        return self.model(img)

In [None]:
# Discriminator sized for the 3-channel 64x64 tensors produced by `transform`.
disc = Discriminator(input_shape=(3, 64, 64))

In [None]:
# Sanity-check the discriminator on one real batch: each image should yield a
# score map whose per-sample shape equals disc.output_shape. No gradients are
# needed for this check, so skip building an autograd graph.
imgs, labels = next(iter(train_loader))
with torch.no_grad():
    output = disc(imgs)
print(output.shape)
assert tuple(output.shape[1:]) == disc.output_shape, (output.shape, disc.output_shape)