In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import os
from skimage import io
%load_ext tensorboard

In [8]:
# Training configuration: everything a reader might want to tune lives here.
LEARNING_RATE = 2e-4    # step size handed to both Adam optimizers
BATCH_SIZE = 64         # number of images per DataLoader batch
IMAGE_SIZE = 24         # side length of the square training images, in pixels
CHANNELS_IMG = 4        # channels per image (four means/stds are used when normalizing)
CHANNELS_NOISE = 100    # channel count of the latent noise tensor fed to the generator
NUM_EPOCHS = 30         # number of passes over the dataset
SEED = 42               # fixed seed so a fresh Restart & Run All reproduces the same run

# Seed PyTorch's random number generators before any weights or noise are
# created; without this every kernel restart yields a different model init
# and a different fixed noise batch.
torch.manual_seed(SEED)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [9]:
class PunkDataset(Dataset):
    """Dataset of sequentially numbered PNG images stored in a single folder.

    Item ``idx`` is read from ``<image_folder>/<idx>.png``, so the folder is
    expected to contain ``0.png``, ``1.png``, ... without gaps.

    Args:
        image_folder: Directory that holds the PNG files.
        transform: Optional callable applied to every loaded image before it
            is returned.

    Raises:
        FileNotFoundError: If ``image_folder`` does not exist or cannot be
            listed.
    """

    def __init__(self, image_folder, transform=None):
        self.transform = transform
        self.image_folder = image_folder
        # Scan the folder that was actually passed in. The previous version
        # always walked the hard-coded 'imgs' directory, so the reported
        # length was wrong (or construction failed) for any other folder.
        top_level = next(os.walk(image_folder), None)
        if top_level is None:
            # os.walk yields nothing for a missing/unreadable directory;
            # surface that as a clear error instead of a bare StopIteration.
            raise FileNotFoundError(
                f"image folder not found or unreadable: {image_folder!r}"
            )
        _, _, files = top_level
        # Only '.png' files can be served by __getitem__, so stray files
        # (e.g. '.DS_Store') must not inflate the dataset length.
        self.n = sum(1 for name in files if name.endswith('.png'))

    def __len__(self):
        """Return the number of PNG images found in ``image_folder``."""
        return self.n

    def __getitem__(self, idx):
        """Load ``<image_folder>/<idx>.png`` and apply the optional transform."""
        img_path = os.path.join(self.image_folder, str(idx) + '.png')
        image = io.imread(img_path)
        if self.transform:
            image = self.transform(image)
        return image

In [None]:
# 24x24x4
class Discriminator(nn.Module):
    """DCGAN-style convolutional discriminator.

    Maps a batch of ``N x channels_img x image_size x image_size`` images to
    ``N x 1 x 1 x 1`` sigmoid scores.

    Args:
        channels_img: Number of channels in the input images.
        features_d: Base width multiplier for the convolutional feature maps.
        image_size: Side length of the square input images. The network halves
            the resolution twice, so this must be a positive multiple of 4.
            Defaults to 24.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 4.
    """

    def __init__(self, channels_img, features_d, image_size=24):
        super(Discriminator, self).__init__()
        if image_size <= 0 or image_size % 4 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 4, got {image_size}"
            )
        # Spatial size left after the two stride-2 convolutions. The final
        # convolution uses it as its kernel size so the map collapses to 1x1.
        # (The previous hard-coded kernel of 7 only fit 28x28 inputs and
        # raised on 24x24 images, whose final feature map is 6x6.)
        final_size = image_size // 4
        self.disc = nn.Sequential(
            # input: N x channels_img x image_size x image_size
            nn.Conv2d(
                channels_img, features_d * 2, kernel_size=4, stride=2, padding=1
            ),  # -> image_size/2 (24 -> 12)
            nn.LeakyReLU(0.2),
            self._block(features_d * 2, features_d * 4, 3, 1, 1),  # size unchanged
            self._block(features_d * 4, features_d * 8, 3, 1, 1),  # size unchanged
            self._block(features_d * 8, features_d * 16, 4, 2, 1),  # -> image_size/4 (12 -> 6)
            # Kernel covers the whole remaining map, producing one score per image.
            nn.Conv2d(
                features_d * 16, 1, kernel_size=final_size, stride=2, padding=0
            ),
            nn.Sigmoid(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a Conv2d -> BatchNorm2d -> LeakyReLU(0.2) stage.

        The convolution has no bias because BatchNorm supplies its own shift.
        """
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Score a batch of images; returns a tensor of shape ``N x 1 x 1 x 1``."""
        return self.disc(x)

class Generator(nn.Module):
    """DCGAN-style transposed-convolution generator.

    Maps latent noise of shape ``N x channels_noise x 1 x 1`` to images of
    shape ``N x channels_img x image_size x image_size`` with values in
    ``[-1, 1]`` (tanh output).

    Args:
        channels_noise: Channel count of the latent noise tensor.
        channels_img: Number of channels in the generated images.
        features_g: Base width multiplier for the feature maps.
        image_size: Side length of the square output images. The network
            doubles the resolution twice, so this must be a positive multiple
            of 4. Defaults to 24.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 4.
    """

    def __init__(self, channels_noise, channels_img, features_g, image_size=24):
        super(Generator, self).__init__()
        if image_size <= 0 or image_size % 4 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 4, got {image_size}"
            )
        # The first transposed convolution expands the 1x1 noise to this size;
        # two stride-2 layers then upsample it by 4x to image_size.
        seed_size = image_size // 4
        self.gen = nn.Sequential(
            # Input: N x channels_noise x 1 x 1
            # (previously hard-coded to 100 channels, ignoring channels_noise)
            self._block(channels_noise, features_g * 32, seed_size, 1, 0),  # -> image_size/4 (6)
            self._block(features_g * 32, features_g * 16, 4, 2, 1),  # -> image_size/2 (12)
            self._block(features_g * 16, features_g * 8, 3, 1, 1),  # size unchanged
            self._block(features_g * 8, features_g * 4, 3, 1, 1),  # size unchanged
            # Final upsampling to full resolution; emits channels_img channels
            # (previously hard-coded to a single channel).
            nn.ConvTranspose2d(
                features_g * 4, channels_img, kernel_size=4, stride=2, padding=1
            ),
            # Output: N x channels_img x image_size x image_size, squashed to [-1, 1]
            nn.Tanh(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        """Return a ConvTranspose2d -> BatchNorm2d -> ReLU stage.

        The transposed convolution has no bias because BatchNorm supplies its
        own shift.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Generate images from a batch of ``N x channels_noise x 1 x 1`` noise."""
        return self.gen(x)

def initialize_weights(model):
    """Initialize ``model`` in place following the reference DCGAN scheme.

    * ``Conv2d`` / ``ConvTranspose2d`` weights are drawn from N(0, 0.02).
    * ``BatchNorm2d`` scale (weight) is drawn from N(1, 0.02) and its shift
      (bias) is set to 0. Drawing the scale from a zero-centered distribution,
      as the previous version did, starts every normalization layer with a
      gain near zero and random sign.

    BatchNorm layers built with ``affine=False`` have no weight/bias and are
    skipped. Convolution biases are left untouched.

    Args:
        model: Any ``nn.Module``; all of its sub-modules are visited.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            # weight/bias are None when the layer was created with affine=False
            if m.weight is not None:
                nn.init.normal_(m.weight, 1.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

In [10]:
# Input pipeline: convert each loaded image to a tensor, then apply
# (x - 0.5) / 0.5 to each of the four channels.
# NOTE(review): this maps values to [-1, 1] only if ToTensor yields [0, 1],
# which presumably holds for uint8 PNG data — confirm the image dtype.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5, 0.5), (0.5, 0.5, 0.5, 0.5))
])
# Images are read from the local 'imgs' folder and served in shuffled batches.
dataset = PunkDataset('imgs', transform=transform)
trainloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Build both networks on the selected device and apply the custom weight init.
# NOTE(review): IMAGE_SIZE is passed in the features_g / features_d position,
# i.e. it is used as the channel-width multiplier — confirm this is intended.
gen = Generator(CHANNELS_NOISE, CHANNELS_IMG, IMAGE_SIZE).to(device)
disc = Discriminator(CHANNELS_IMG, IMAGE_SIZE).to(device)
initialize_weights(gen)
initialize_weights(disc)

# One Adam optimizer per network, both with beta1 = 0.5 and the same learning
# rate; binary cross-entropy is the loss criterion.
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()

# A fixed batch of 32 noise tensors plus two TensorBoard writers (logs/real,
# logs/fake). `step` starts at 0; presumably the logging step counter — it is
# not used in the code visible here.
fixed_noise = torch.randn(32, CHANNELS_NOISE, 1, 1).to(device) # 32 latent vectors (Zs) for TensorBoard
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0

In [11]:
def train(loader=None, target_device=None):
    """Sanity-check the input pipeline on the first batch.

    Moves the first batch to the target device, prints its maximum and minimum
    values, and returns. No optimization step is performed.

    Args:
        loader: Iterable yielding image batches (tensors). Defaults to the
            module-level ``trainloader``.
        target_device: Device the batch is moved to. Defaults to the
            module-level ``device``.

    Returns:
        None. Also returns immediately if the loader yields no batches.
    """
    if loader is None:
        loader = trainloader
    if target_device is None:
        target_device = device
    for image in loader:
        image = image.to(target_device)
        print(image.max(), image.min())
        # Stop after the first batch. This replaces a leftover debugging
        # `exit()`, which terminates the interpreter (in a notebook: shuts
        # down the kernel) and so breaks Restart & Run All.
        return