In [9]:
import os
import cv2
import glob
import torch
from tqdm import tqdm
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from PIL import Image
# Prefer the CUDA device when torch reports one as usable; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
import glob
import os
import random
import torch
import torchvision
from PIL import Image
from tqdm import tqdm
from torch.utils.data.dataset import Dataset

class MnistDataset(Dataset):
    r"""
    Nothing special here. Just a simple dataset class for mnist images.
    Created a dataset class rather than using torchvision to allow
    replacement with any other image dataset that is laid out as
    one sub-folder per integer label.
    """
    
    def __init__(self, split, im_path, im_ext='png'):
        r"""
        Init method for initializing the dataset properties
        :param split: train/test. Only used in the summary message
        printed after the images are indexed.
        :param im_path: root folder of images, holding one
        sub-folder per label
        :param im_ext: image extension. assumes all
        images would be this type.
        """
        self.split = split
        self.im_ext = im_ext
        self.images, self.labels = self.load_images(im_path)
    
    def load_images(self, im_path):
        r"""
        Collects the path of every image below the path specified,
        together with a label taken from the name of the sub-folder
        that contains it. Only paths are gathered here; pixels are
        read later, one image at a time, in __getitem__.
        Sub-folders and files are visited in sorted order so that a
        given index maps to the same image on every run and machine.
        :param im_path: root folder holding one sub-folder per label
        :return: (list of image paths, list of int labels), aligned by index
        :raises AssertionError: if im_path does not exist
        :raises ValueError: if a sub-folder that contains images has
        a name that cannot be parsed as an integer
        """
        assert os.path.exists(im_path), "images path {} does not exist".format(im_path)
        ims = []
        labels = []
        # os.listdir and glob.glob return entries in arbitrary, filesystem
        # dependent order, hence the explicit sorting.
        for d_name in tqdm(sorted(os.listdir(im_path))):
            class_dir = os.path.join(im_path, d_name)
            # Stray files next to the label folders cannot contain images.
            if not os.path.isdir(class_dir):
                continue
            pattern = os.path.join(class_dir, '*.{}'.format(self.im_ext))
            for fname in sorted(glob.glob(pattern)):
                ims.append(fname)
                labels.append(int(d_name))
        print('Found {} images for split {}'.format(len(ims), self.split))
        return ims, labels
    
    def __len__(self):
        r"""
        :return: number of images found by load_images
        """
        return len(self.images)
    
    def __getitem__(self, index):
        r"""
        Loads one image from disk and rescales it for training.
        The label stays available in self.labels[index] but is not
        returned.
        :param index: position of the image in self.images
        :return: image tensor produced by ToTensor, mapped through
        2 * x - 1 (i.e. from the 0 to 1 range into the -1 to 1 range)
        """
        # The context manager closes the file even when decoding fails.
        # ToTensor has to run inside it because PIL reads pixels lazily.
        with Image.open(self.images[index]) as im:
            im_tensor = torchvision.transforms.ToTensor()(im)
        
        # Uncomment below 4 lines for colored mnist images
        # a = (im_tensor[0]*random.uniform(0.2, 1.0)).unsqueeze(0)
        # b = (im_tensor[0]*random.uniform(0.2, 1.0)).unsqueeze(0)
        # c = (im_tensor[0]*random.uniform(0.2, 1.0)).unsqueeze(0)
        # im_tensor = torch.cat([a, b, c], dim=0)
        
        # Convert input to -1 to 1 range.
        im_tensor = (2 * im_tensor) - 1
        return im_tensor


In [None]:
# Training split: index the image folder, then serve it in shuffled batches of 16.
mnist_train = MnistDataset(split='train', im_path='../data/train/images')
mnist_loader_train = DataLoader(
    dataset=mnist_train,
    batch_size=16,
    shuffle=True,
    num_workers=0,
)

# Test split: built with exactly the same loader settings as the training split.
mnist_test = MnistDataset(split='test', im_path='../data/test/images')
mnist_loader_test = DataLoader(
    dataset=mnist_test,
    batch_size=16,
    shuffle=True,
    num_workers=0,
)

#### Model

In [29]:
import torch
import numpy as np
import torch.nn as nn


'''
Source code from 
https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html 
and 
https://arxiv.org/pdf/1511.06434.pdf

'''
# Shared DCGAN hyper-parameters. Generator reads each list front to back,
# Discriminator reads the same lists back to front.

# Channel count of the noise tensor fed to the generator.
latent_dim = 100
# Feature-map widths after each of the generator's first four layers.
in_channels = [512, 256, 128, 64]
# One entry per convolution (five in total): a 4x4 kernel everywhere, ...
kernel_size = [4] * 5
# ... stride 1 for the first generator layer and stride 2 for the rest, ...
stride = [1] + [2] * 4
# ... and no padding for the first generator layer, padding 1 for the rest.
padding = [0] + [1] * 4
# Channels of the generated image, and of the image the discriminator accepts.
# NOTE(review): MnistDataset returns the ToTensor output as-is unless its
# colouring lines are uncommented - confirm the images really have 3 channels.
out_channels = 3


class Generator(nn.Module):
    """
    DCGAN generator: a stack of five transposed convolutions.

    The first four are each followed by batch normalisation and an in-place
    ReLU; the last one is followed by Tanh. None of the convolutions has a
    bias term.

    With the default configuration, an input of shape
    (N, latent_dim, 1, 1) grows 1 -> 4 -> 8 -> 16 -> 32 -> 64 per side and
    comes out as (N, out_channels, 64, 64).

    :param latent_dim: channel count of the input noise tensor
    :param in_channels: output widths of the first four layers (indices 0-3 are read)
    :param kernel_size: kernel size per layer (indices 0-4 are read)
    :param stride: stride per layer (indices 0-4 are read)
    :param padding: padding per layer (indices 0-4 are read)
    :param out_channels: channel count of the generated image
    """

    def __init__(self,latent_dim = latent_dim, in_channels = in_channels, kernel_size = kernel_size, stride = stride, padding = padding, out_channels = out_channels):
        super().__init__()
        self.latent_dim = latent_dim
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.out_channels = out_channels

        layers = []
        fan_in = self.latent_dim
        # Four identical stages: up-sampling convolution, batch norm, ReLU.
        # Layers are appended in the order they run, so the positions inside
        # the Sequential (and therefore the state_dict keys) stay fixed.
        for idx in range(4):
            fan_out = self.in_channels[idx]
            layers.append(nn.ConvTranspose2d(fan_in, fan_out, self.kernel_size[idx], self.stride[idx], self.padding[idx], bias=False))
            layers.append(nn.BatchNorm2d(fan_out))
            layers.append(nn.ReLU(True))
            fan_in = fan_out
        # Output stage: project to the image channels, with no batch norm.
        layers.append(nn.ConvTranspose2d(fan_in, self.out_channels, self.kernel_size[4], self.stride[4], self.padding[4], bias=False))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the input tensor through the layer stack and return the result."""
        generated = self.main(input)
        return generated
    
class Discriminator(nn.Module):
    """
    DCGAN discriminator: a stack of five strided convolutions that mirrors
    the generator configuration.

    The per-layer lists are given in generator order and are reversed here.
    The first convolution is followed by LeakyReLU(0.2) only, the next three
    by batch normalisation and LeakyReLU(0.2), and the last one by Sigmoid.
    None of the convolutions has a bias term.

    With the default configuration, an input of shape
    (N, out_channels, 64, 64) shrinks 64 -> 32 -> 16 -> 8 -> 4 -> 1 per side
    and comes out as (N, 1, 1, 1).

    :param in_channels: generator-order layer widths; exactly four are used
    :param kernel_size: generator-order kernel size per layer; five are used
    :param stride: generator-order stride per layer; five are used
    :param padding: generator-order padding per layer; five are used
    :param out_channels: channel count of the images being judged
    """

    def __init__(self, in_channels = in_channels, kernel_size = kernel_size, stride = stride, padding = padding, out_channels = out_channels):
        super(Discriminator, self).__init__()
        # reversed() flips the layer order only. np.flip() without an axis
        # flips every axis, so a per-layer tuple such as (3, 5) would have
        # been turned into (5, 3); it also produced numpy scalars, which
        # then ended up inside the torch layers' hyper-parameters.
        self.in_channels = [int(width) for width in reversed(in_channels)]
        self.kernel_size = list(reversed(kernel_size))
        self.stride = list(reversed(stride))
        self.padding = list(reversed(padding))
        self.out_channels = out_channels

        self.main = nn.Sequential(
            ### input stage: no batch norm directly on the image
            nn.Conv2d(self.out_channels,self.in_channels[0],self.kernel_size[0], self.stride[0], self.padding[0], bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            ###
            nn.Conv2d(self.in_channels[0],self.in_channels[1],self.kernel_size[1], self.stride[1], self.padding[1], bias=False),
            nn.BatchNorm2d(self.in_channels[1]),
            nn.LeakyReLU(0.2, inplace=True),
            ###
            nn.Conv2d(self.in_channels[1],self.in_channels[2],self.kernel_size[2], self.stride[2], self.padding[2], bias=False),
            nn.BatchNorm2d(self.in_channels[2]),
            nn.LeakyReLU(0.2, inplace=True),
            ###
            nn.Conv2d(self.in_channels[2],self.in_channels[3],self.kernel_size[3], self.stride[3], self.padding[3], bias=False),
            nn.BatchNorm2d(self.in_channels[3]),
            nn.LeakyReLU(0.2, inplace=True),
            ### output stage: a single channel squashed into (0, 1)
            nn.Conv2d(self.in_channels[3],1,self.kernel_size[4], self.stride[4], self.padding[4], bias=False),
            nn.Sigmoid()
        )
    def forward(self, input):
        """Run the image batch through the layer stack and return the scores."""
        return self.main(input)