In [15]:
# prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [16]:
from collections import defaultdict

import torch
from torch.optim.optimizer import Optimizer


class Lookahead(Optimizer):
    r"""PyTorch implementation of the lookahead wrapper.
    Lookahead Optimizer: https://arxiv.org/abs/1907.08610
    """

    def __init__(self, optimizer, la_steps=5, la_alpha=0.8, pullback_momentum="none"):
        """optimizer: inner optimizer
        la_steps (int): number of lookahead steps
        la_alpha (float): linear interpolation factor. 1.0 recovers the inner optimizer.
        pullback_momentum (str): change to inner optimizer momentum on interpolation update
        """
        self.optimizer = optimizer
        self._la_step = 0  # counter for inner optimizer
        self.la_alpha = la_alpha
        self._total_la_steps = la_steps
        pullback_momentum = pullback_momentum.lower()
        assert pullback_momentum in ["reset", "pullback", "none"]
        self.pullback_momentum = pullback_momentum

        self.state = defaultdict(dict)

        # Cache the current optimizer parameters
        for group in optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['cached_params'] = torch.zeros_like(p.data)
                param_state['cached_params'].copy_(p.data)
                if self.pullback_momentum == "pullback":
                    param_state['cached_mom'] = torch.zeros_like(p.data)

    def __getstate__(self):
        return {
            'state': self.state,
            'optimizer': self.optimizer,
            'la_alpha': self.la_alpha,
            '_la_step': self._la_step,
            '_total_la_steps': self._total_la_steps,
            'pullback_momentum': self.pullback_momentum
        }

    def zero_grad(self):
        self.optimizer.zero_grad()

    def get_la_step(self):
        return self._la_step

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)

    def _backup_and_load_cache(self):
        """Useful for performing evaluation on the slow weights (which typically generalize better)
        """
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                param_state['backup_params'] = torch.zeros_like(p.data)
                param_state['backup_params'].copy_(p.data)
                p.data.copy_(param_state['cached_params'])

    def _clear_and_load_backup(self):
        for group in self.optimizer.param_groups:
            for p in group['params']:
                param_state = self.state[p]
                p.data.copy_(param_state['backup_params'])
                del param_state['backup_params']

    @property
    def param_groups(self):
        return self.optimizer.param_groups

    def step(self, closure=None):
        """Performs a single Lookahead optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = self.optimizer.step(closure)
        self._la_step += 1

        if self._la_step >= self._total_la_steps:
            self._la_step = 0
            # Lookahead and cache the current optimizer parameters
            for group in self.optimizer.param_groups:
                for p in group['params']:
                    param_state = self.state[p]
                    p.data.mul_(self.la_alpha).add_(param_state['cached_params'], alpha=1.0 - self.la_alpha)  # crucial line
                    param_state['cached_params'].copy_(p.data)
                    if self.pullback_momentum == "pullback":
                        internal_momentum = self.optimizer.state[p]["momentum_buffer"]
                        self.optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_(self.la_alpha).add_(
                            1.0 - self.la_alpha, param_state["cached_mom"])
                        param_state["cached_mom"] = self.optimizer.state[p]["momentum_buffer"]
                    elif self.pullback_momentum == "reset":
                        self.optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data)

        return loss

In [17]:
bs = 256

# MNIST Dataset
transform = transforms.Compose([transforms.ToTensor(),
  transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transform, download=False)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

In [18]:
class Generator(nn.Module):
    def __init__(self, g_input_dim, g_output_dim):
        super(Generator, self).__init__()       
        self.fc1 = nn.Linear(g_input_dim, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features*2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features*2)
        self.fc4 = nn.Linear(self.fc3.out_features, g_output_dim)
    
    # forward method
    def forward(self, x): 
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.leaky_relu(self.fc3(x), 0.2)
        return torch.tanh(self.fc4(x))
    
class Discriminator(nn.Module):
    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)
    
    # forward method
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3)
        return torch.sigmoid(self.fc4(x))

In [19]:
# build network
z_dim = 100
mnist_dim = train_dataset.data.size(1) * train_dataset.data.size(2)

G = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)

In [20]:
G

Generator(
  (fc1): Linear(in_features=100, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=1024, bias=True)
  (fc4): Linear(in_features=1024, out_features=784, bias=True)
)

In [21]:
D

Discriminator(
  (fc1): Linear(in_features=784, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=1, bias=True)
)

In [22]:
# loss
criterion = nn.BCELoss() 

# optimizer
lr = 0.0002 
G_optimizer = optim.Adam(G.parameters(), lr = lr)
D_optimizer = optim.Adam(D.parameters(), lr = lr)

In [23]:
def D_train(x):
    #=======================Train the discriminator=======================#
    D.zero_grad()

    # train discriminator on real
    x_real, y_real = x.view(-1, mnist_dim), torch.ones(bs, 1)
    x_real, y_real = x_real.to(device), y_real.to(device)

    D_output = D(x_real)
    D_real_loss = criterion(D_output, y_real)
    D_real_score = D_output

    # train discriminator on facke
    z = torch.randn(bs, z_dim).to(device)
    x_fake, y_fake = G(z), torch.zeros(bs, 1).to(device)

    D_output = D(x_fake)
    D_fake_loss = criterion(D_output, y_fake)
    D_fake_score = D_output

    # gradient backprop & optimize ONLY D's parameters
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()
        
    return  D_loss.data.item()

In [24]:
def G_train(x):
    #=======================Train the generator=======================#
    G.zero_grad()

    z = torch.randn(bs, z_dim).to(device)
    y = torch.ones(bs, 1).to(device)

    G_output = G(z)
    D_output = D(G_output)
    G_loss = criterion(D_output, y)

    # gradient backprop & optimize ONLY G's parameters
    G_loss.backward()
    G_optimizer.step()
        
    return G_loss.data.item()

In [10]:
n_epoch = 200
for epoch in range(1, n_epoch+1):           
    D_losses, G_losses = [], []
    for batch_idx, (x, _) in enumerate(train_loader):
        D_losses.append(D_train(x))
        G_losses.append(G_train(x))

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            (epoch), n_epoch, torch.mean(torch.FloatTensor(D_losses)), torch.mean(torch.FloatTensor(G_losses))))

[1/200]: loss_d: 0.969, loss_g: 2.694
[2/200]: loss_d: 1.111, loss_g: 1.382
[3/200]: loss_d: 0.946, loss_g: 2.163
[4/200]: loss_d: 0.759, loss_g: 2.223
[5/200]: loss_d: 0.447, loss_g: 2.711
[6/200]: loss_d: 0.571, loss_g: 2.620
[7/200]: loss_d: 0.673, loss_g: 2.251
[8/200]: loss_d: 0.642, loss_g: 2.466
[9/200]: loss_d: 0.596, loss_g: 2.382
[10/200]: loss_d: 0.649, loss_g: 2.318
[11/200]: loss_d: 0.658, loss_g: 2.186
[12/200]: loss_d: 0.701, loss_g: 2.140
[13/200]: loss_d: 0.693, loss_g: 2.130
[14/200]: loss_d: 0.779, loss_g: 1.972
[15/200]: loss_d: 0.777, loss_g: 1.988
[16/200]: loss_d: 0.760, loss_g: 1.982
[17/200]: loss_d: 0.842, loss_g: 1.757
[18/200]: loss_d: 0.848, loss_g: 1.807
[19/200]: loss_d: 0.849, loss_g: 1.775
[20/200]: loss_d: 0.837, loss_g: 1.775
[21/200]: loss_d: 0.890, loss_g: 1.692
[22/200]: loss_d: 0.888, loss_g: 1.645
[23/200]: loss_d: 0.940, loss_g: 1.522
[24/200]: loss_d: 0.965, loss_g: 1.471
[25/200]: loss_d: 0.951, loss_g: 1.516
[26/200]: loss_d: 0.971, loss_g: 1

In [11]:
with torch.no_grad():
    test_z = torch.randn(bs, z_dim).to(device)
    generated = G(test_z)

    save_image(generated.view(generated.size(0), 1, 28, 28), './samples/sample_' + '.png')

In [25]:
G_loss_per_epoch = []
D_loss_per_epoch = []

D_real_acc_per_epoch = []
D_fake_acc_per_epoch = []
G_acc_per_epoch = []

n_epoch = 200
for epoch in range(1, n_epoch+1):           
    D_losses, G_losses = [], []
    D_real_accs, D_fake_accs, G_accs = [], [], []
    for batch_idx, (x, _) in enumerate(train_loader):
        D_current_loss, D_batch_real_acc, D_batch_fake_acc = D_train(x)
        G_current_loss, G_batch_acc = G_train(x)

        D_real_accs.append(D_batch_real_acc)
        D_fake_accs.append(D_batch_fake_acc)
        G_accs.append(G_batch_acc)

        D_losses.append(D_current_loss)
        G_losses.append(G_current_loss)

    G_loss = torch.mean(torch.FloatTensor(G_losses))
    D_loss = torch.mean(torch.FloatTensor(D_losses))

    G_loss_per_epoch.append(G_loss.item())
    D_loss_per_epoch.append(D_loss.item())

    D_real_acc_per_epoch.append(torch.mean(torch.FloatTensor(D_real_accs)).item())
    D_fake_acc_per_epoch.append(torch.mean(torch.FloatTensor(D_fake_accs)).item())
    G_acc_per_epoch.append(torch.mean(torch.FloatTensor(G_accs)).item())


    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            (epoch), n_epoch, D_loss, G_loss))      
    
    # This is where I'm recording to Tensorboard
    with train_summary_writer.as_default():
        for tag, parm in G.named_parameters():
            tf.summary.histogram('G_{}'.format(tag), parm.grad.data.cpu().numpy(), step=epoch)

        for tag, parm in D.named_parameters():
            tf.summary.histogram('D_{}'.format(tag), parm.grad.data.cpu().numpy(), step=epoch)

        # tf.summary.image(name, data, step=None, max_outputs=3, description=None)
        
        tf.summary.scalar("D Loss/train", D_loss.item(), step=epoch)
        tf.summary.scalar("G Loss/train", G_loss.item(), step=epoch)

        tf.summary.scalar("D real acc/train", torch.mean(torch.FloatTensor(D_real_accs)).item(), step=epoch)
        tf.summary.scalar("D fake acc/train", torch.mean(torch.FloatTensor(D_fake_accs)).item(), step=epoch)
        tf.summary.scalar("G acc/train", torch.mean(torch.FloatTensor(G_accs)).item(), epoch)

        with torch.no_grad():
            test_z = Variable(torch.randn(bs, z_dim).to(device))
            generated = G(test_z)
            img = generated.view(generated.size(0), 1, 28, 28).permute(0, 2, 3, 1)
            tf.summary.image('img_epoch_{}.png'.format(epoch), data=img, step=epoch, max_outputs=3)

    

    # # Save checkpoint at each epoch for G and D.
    # G_checkpoint = {'epoch': epoch,
    #                 'model': G,
    #                 'model_state_dict': G.state_dict(),
    #                 'optimizer_state_dict' : G_optimizer.state_dict(),
    #                 'loss': G_loss}

    # D_checkpoint = {'epoch': epoch,
    #                 'model': D,
    #                 'model_state_dict': D.state_dict(),
    #                 'optimizer_state_dict' : D_optimizer.state_dict(),
    #                 'loss': D_loss}

    # torch.save(G_checkpoint, checkpoint_path + 'G_checkpoint_epoch_{}.pth'.format(epoch))
    # torch.save(D_checkpoint, checkpoint_path + 'D_checkpoint_epoch_{}.pth'.format(epoch))

    samples_per_epoch = 1
    for sample_nbr in range(1, samples_per_epoch+1):
        with torch.no_grad():
            test_z = torch.randn(bs, z_dim).to(device)
            generated = G(test_z)
            save_image(generated.view(generated.size(0), 1, 28, 28), path + '/samples/epoch_{}_sample_{}'.format(epoch, sample_nbr) + '.png')

TypeError: cannot unpack non-iterable float object