## Import Dependencies

In [6]:
import torch
from simple_gan import Generator,Discriminator
import torch.optim as optim
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

## Set up hyperparameters

In [2]:
# Hyperparameters for the run. Early (vanilla) GANs are very sensitive to
# these values.
num_epochs = 50
batch_size = 32
image_dim = 1 * 28 * 28  # channels * height * width of one flattened image
z_dimension = 64  # other sizes noted for experimentation: 128, 256
lr = 3e-4  # learning rate for Adam, as recommended by Andrej Karpathy

# Use the GPU when PyTorch reports one is available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Initialize the GAN and set up the dataset

In [4]:
# Build both networks and place them on the selected device.
discriminator = Discriminator(image_dimension=image_dim)
discriminator = discriminator.to(device=device)
generator = Generator(z_dimension=z_dimension, image_dimension=image_dim)
generator = generator.to(device=device)

# One batch of latent vectors drawn up front and kept unchanged afterwards
# (presumably reused to compare generated samples over time; the consumer is
# not in this section).
fixed_noise = torch.randn(batch_size, z_dimension).to(device)

# Preprocessing: convert each image to a tensor, then standardize it with the
# MNIST mean and standard deviation.
# NOTE(review): this does not map pixels into [-1, 1]. If the generator ends
# in tanh, (0.5,), (0.5,) would match its output range; confirm against
# simple_gan.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# MNIST is downloaded into ./data/ if missing and served in shuffled batches.
dataset = datasets.MNIST('./data/', transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Binary cross-entropy loss, plus a separate Adam optimizer per network
# sharing the same learning rate.
criterion = nn.BCELoss()
optim_disc = optim.Adam(discriminator.parameters(), lr=lr)
optim_gen = optim.Adam(generator.parameters(), lr=lr)


In [7]:
# Counter starting at zero (presumably the TensorBoard global step; its use
# is outside this section).
step = 0

# Separate TensorBoard log directories so generated and real images show up
# as two distinct runs.
writer_fake = SummaryWriter(log_dir='./runs/GAN_MNIST/fake')  # generated images
writer_real = SummaryWriter(log_dir='./runs/GAN_MNIST/real')  # dataset images