In [1]:
from torch_snippets import *
device = "cuda" if torch.cuda.is_available() else "cpu"
from torchvision.utils import make_grid
from torchvision.datasets import MNIST
from torchvision import transforms

# ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) computes (x - 0.5) / 0.5,
# mapping them to [-1, 1] -- the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# drop_last=True discards the final partial batch so every batch has 128 images.
data_loader = torch.utils.data.DataLoader(
    MNIST('~/data', train=True, download=True, transform=transform),
    batch_size=128,
    shuffle=True,
    drop_last=True,
)

  from .autonotebook import tqdm as notebook_tqdm


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to C:\Users\olive/data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████████████████████████████████████████████████████████████| 9912422/9912422 [00:00<00:00, 62534066.00it/s]


Extracting C:\Users\olive/data\MNIST\raw\train-images-idx3-ubyte.gz to C:\Users\olive/data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to C:\Users\olive/data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|███████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 1925019.37it/s]


Extracting C:\Users\olive/data\MNIST\raw\train-labels-idx1-ubyte.gz to C:\Users\olive/data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to C:\Users\olive/data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████████████████████████████████████████████████████████████| 1648877/1648877 [00:00<00:00, 24489613.69it/s]


Extracting C:\Users\olive/data\MNIST\raw\t10k-images-idx3-ubyte.gz to C:\Users\olive/data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to C:\Users\olive/data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████████████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<?, ?it/s]

Extracting C:\Users\olive/data\MNIST\raw\t10k-labels-idx1-ubyte.gz to C:\Users\olive/data\MNIST\raw






In [6]:
class Discriminator(nn.Module):
    """Fully connected GAN discriminator.

    For an input of shape ``(batch, 784)`` (flattened MNIST images in this
    notebook), returns a ``(batch, 1)`` tensor of probabilities in (0, 1) that
    each row is a real image.
    """

    def __init__(self):
        super().__init__()
        layers = []
        # Three hidden stages, each Linear -> LeakyReLU(0.2) -> Dropout(0.3),
        # narrowing 784 -> 1024 -> 512 -> 256.
        for width_in, width_out in ((784, 1024), (1024, 512), (512, 256)):
            layers += [
                nn.Linear(width_in, width_out),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
            ]
        # Single-unit head squashed by a sigmoid so the output is a probability,
        # which is what nn.BCELoss consumes.
        self.model = nn.Sequential(*layers, nn.Linear(256, 1), nn.Sigmoid())

    def forward(self, x):
        """Return the real-image probability for each row of ``x``."""
        return self.model(x)

In [13]:
from torchsummary import summary

# Exploratory instance. The training-setup cell further down builds its own
# fresh discriminator, so this object is replaced before training starts.
# NOTE(review): `summary` is imported but not called in any visible cell.
discriminator = Discriminator()
discriminator = discriminator.to(device)


In [16]:
class Generator(nn.Module):
    """Fully connected GAN generator.

    Maps latent vectors of width 100 to 784-feature outputs bounded by Tanh,
    i.e. flattened images in the same [-1, 1] range as the normalised data.
    """

    def __init__(self):
        super().__init__()
        widths = (100, 256, 512, 1024)
        layers = []
        # Widen 100 -> 256 -> 512 -> 1024, each Linear followed by LeakyReLU(0.2).
        for width_in, width_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(width_in, width_out))
            layers.append(nn.LeakyReLU(0.2))
        # Project to the 784-feature image space; Tanh bounds every value.
        self.model = nn.Sequential(*layers, nn.Linear(1024, 784), nn.Tanh())

    def forward(self, x):
        """Generate one flattened sample per latent row of ``x``."""
        return self.model(x)

In [17]:
# Exploratory instance; the training-setup cell further down builds a fresh one.
generator = Generator()
generator = generator.to(device)


In [27]:
def noise(size):
    """Draw ``size`` standard-normal latent vectors of width 100.

    The width matches the generator's first ``nn.Linear(100, ...)`` layer.
    Samples are drawn with ``torch.randn`` and then moved to the notebook-global
    ``device``.
    """
    return torch.randn(size, 100).to(device)

In [49]:
def discriminator_train_step(real_data, fake_data):
    """Run one optimisation step for the discriminator.

    Real samples are scored against a target of 1 and fake samples against a
    target of 0. Gradients from both BCE terms accumulate before a single
    ``d_optimizer.step()``.

    Args:
        real_data: batch of real samples (flattened images in this notebook).
        fake_data: batch of generator samples, expected to be detached from the
            generator's graph so this step does not backprop into it.

    Returns:
        The summed real + fake loss tensor, used by the caller for logging.

    Reads the notebook globals ``d_optimizer``, ``discriminator``, ``loss`` and
    ``device``.
    """
    d_optimizer.zero_grad()

    real_target = torch.ones(len(real_data), 1).to(device)
    real_error = loss(discriminator(real_data), real_target)
    real_error.backward()

    fake_target = torch.zeros(len(fake_data), 1).to(device)
    fake_error = loss(discriminator(fake_data), fake_target)
    fake_error.backward()

    d_optimizer.step()
    return real_error + fake_error

In [50]:

def generator_train_step(fake_data):
    """Run one optimisation step for the generator.

    The generator improves when the discriminator classifies its samples as
    real. The BCE target is therefore all ones, with the same shape as the
    discriminator's prediction.

    Args:
        fake_data: batch of generator outputs that is still attached to the
            generator's autograd graph, so the loss can backprop into it.

    Returns:
        The generator loss tensor for this step.

    Reads the notebook globals ``discriminator``, ``loss`` and ``g_optimizer``.
    Only the parameters owned by ``g_optimizer`` are stepped here. The
    gradients this leaves on the discriminator are cleared by the next
    ``d_optimizer.zero_grad()``.
    """
    g_optimizer.zero_grad()
    prediction = discriminator(fake_data)
    # Size the target from this call's own prediction rather than a global batch.
    # Reading `real_data` here breaks whenever that name is undefined or holds a
    # batch of a different size.
    target = torch.ones_like(prediction)
    error = loss(prediction, target)
    error.backward()
    g_optimizer.step()
    return error
     

In [51]:

# Build fresh models, optimisers and loss for training. Re-running this cell
# resets the GAN, independent of any earlier exploratory instances.
# Seed torch's RNGs first so weight init, latent noise, dropout masks and the
# shuffled batch order are repeatable from run to run.
torch.manual_seed(0)
discriminator = Discriminator().to(device)
generator = Generator().to(device)
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
# The discriminator ends in a Sigmoid, so BCELoss (which takes probabilities) fits.
loss = nn.BCELoss()
num_epochs = 200
# NOTE(review): `Report` presumably comes from `from torch_snippets import *`;
# it is sized to the number of epochs for progress logging.
log = Report(num_epochs)

In [None]:

# Alternate one discriminator step and one generator step per mini-batch.
N = len(data_loader)  # batches per epoch; loop-invariant, so compute it once
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        # Flatten each image to a vector to match the MLP discriminator's input.
        real_data = images.view(len(images), -1).to(device)
        # noise() already returns tensors on `device` and the generator lives
        # there too, so no extra .to(device) is needed. This batch only trains
        # the discriminator, so skip building the generator's autograd graph.
        with torch.no_grad():
            fake_data = generator(noise(len(real_data)))
        d_loss = discriminator_train_step(real_data, fake_data)
        # Fresh samples with gradients enabled so the generator step can
        # backprop through the discriminator into the generator.
        fake_data = generator(noise(len(real_data)))
        g_loss = generator_train_step(fake_data)
        log.record(epoch + (1 + i) / N, d_loss=d_loss.item(), g_loss=g_loss.item(), end='\r')
    log.report_avgs(epoch + 1)
log.plot_epochs(['d_loss', 'g_loss'])


EPOCH: 1.000  d_loss: 0.668  g_loss: 4.285  (12.77s - 2541.54s remaining))
EPOCH: 2.000  d_loss: 0.558  g_loss: 5.788  (25.31s - 2505.85s remaining))
EPOCH: 3.000  d_loss: 0.857  g_loss: 2.331  (37.69s - 2475.01s remaining)
EPOCH: 4.000  d_loss: 0.634  g_loss: 3.144  (50.05s - 2452.34s remaining)
EPOCH: 5.000  d_loss: 0.359  g_loss: 3.767  (62.24s - 2427.54s remaining)
EPOCH: 6.000  d_loss: 0.326  g_loss: 4.171  (74.57s - 2411.19s remaining)
EPOCH: 7.000  d_loss: 0.459  g_loss: 3.265  (87.11s - 2401.78s remaining)
EPOCH: 8.000  d_loss: 0.404  g_loss: 3.507  (99.69s - 2392.52s remaining)
EPOCH: 9.000  d_loss: 0.429  g_loss: 3.394  (112.26s - 2382.33s remaining)
EPOCH: 10.000  d_loss: 0.505  g_loss: 2.935  (124.54s - 2366.31s remaining)
EPOCH: 11.000  d_loss: 0.530  g_loss: 2.867  (136.97s - 2353.42s remaining)
EPOCH: 12.000  d_loss: 0.501  g_loss: 2.870  (149.66s - 2344.66s remaining)
EPOCH: 13.000  d_loss: 0.599  g_loss: 2.521  (162.38s - 2335.72s remaining)
EPOCH: 14.000  d_loss: 0.61