## Improved GAN

Semi-supervised learning with a GAN, trained with the feature-matching technique.

Original code (Theano): https://github.com/openai/improved-gan

Paper: https://arxiv.org/abs/1606.03498

**Not implemented, or different from the original implementation:**

* No data-based initialization (the original computes initialization statistics from 100 training examples)
* No learning-rate decay (the original decays the learning rate from epoch 800 to epoch 1200)
* Batch size is 25 instead of 100 (mainly so that the EMA copy of the discriminator is updated more often)

In [1]:
# Reload edited local modules (such as utils) automatically before code runs,
# and render matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torchvision.utils as vutils
import torch.nn.functional as F
import matplotlib.image as mpimg
import numpy as np

from tqdm import tqdm
from torch.nn.utils import weight_norm
from torch.autograd import Variable
from utils import get_data_loaders, plot_grid
from utils import log_sum_exp

In [3]:
# Pick GPU 1 only if the caller has not already chosen devices through the
# environment. This must run before CUDA is first initialised; torch does that
# lazily, on first CUDA use.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

In [4]:
# --- Configuration ----------------------------------------------------------
lr = 0.0003          # Adam learning rate for both D and G
train_size = 400     # size of the labelled subset requested from get_data_loaders
batch_size = 25      # labelled, unlabelled and generated examples per step
z_dim = 100          # generator latent dimension
ema_coeff = 0.0001   # weight of the live D in each EMA update of D_avg
n = 32               # images are resized to n x n
num_channels = 3     # image channels
seed = 0             # fixed RNG seed so runs are reproducible

# torch.manual_seed seeds the CPU and every CUDA device generator.
torch.manual_seed(seed)
np.random.seed(seed)

use_cuda = torch.cuda.is_available()
print('use_cuda: {}'.format(use_cuda))

use_cuda: True


In [5]:
# Resize to n x n, convert to a float tensor in [0, 1], then map to [-1, 1]
# to match the generator's Tanh output range.
# Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5 == 2x - 1, exactly like
# the former `Lambda(lambda x: x * 2. - 1.)`. Unlike a lambda, it can be pickled,
# so it also works should get_data_loaders use DataLoader worker processes
# started with the "spawn" method.
transform = transforms.Compose([
    transforms.Resize((n, n)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * num_channels, std=[0.5] * num_channels),
])

In [6]:
# Each unlabelled batch carries 2 * batch_size images. The training loop splits
# it into one half for the discriminator update and one for feature matching.
train_loader, test_loader = get_data_loaders(
    'cifar10',
    batch_size=batch_size * 2,
    transform=transform,
    use_cuda=use_cuda,
)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
# Labelled loader for the supervised loss. Pass the configured `train_size`
# instead of a duplicated literal, so changing the config cell takes effect here.
# Presumably get_data_loaders restricts the labelled pool to `train_size`
# examples; confirm in utils.
train_loader_ssl, _ = get_data_loaders('cifar10', transform=transform, batch_size=batch_size,
                                       use_cuda=use_cuda, train_size=train_size)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
def _weight_norm_magnitude(v, g):
    """Return the L2 norm of ``v`` reduced to the shape of the weight-norm gain ``g``."""
    if g.dim() == 0:
        # A 0-dim gain means the norm is taken over the whole tensor.
        return v.pow(2).sum().sqrt()
    # The gain keeps size 1 on every dimension the norm is taken over.
    reduce_dims = tuple(d for d in range(v.dim()) if g.size(d) == 1)
    if not reduce_dims:
        return v.abs()
    return v.pow(2).sum(dim=reduce_dims, keepdim=True).sqrt()


def _init_weight(module, std):
    """Draw ``module``'s effective weight from N(0, std), handling weight_norm layers."""
    if hasattr(module, 'weight_g') and hasattr(module, 'weight_v'):
        # With torch.nn.utils.weight_norm the trainable parameters are
        # weight_v (direction) and weight_g (gain). `weight` is recomputed from
        # them before every forward, so writing into weight.data would be
        # discarded. Sample v and set g = ||v||, so the effective weight equals v.
        module.weight_v.data.normal_(0.0, std)
        module.weight_g.data.copy_(
            _weight_norm_magnitude(module.weight_v.data, module.weight_g.data))
    else:
        module.weight.data.normal_(0.0, std)


def _zero_bias(module):
    """Set ``module.bias`` to zero when the layer has a bias."""
    if module.bias is not None:
        module.bias.data.fill_(0)


def weights_init(model, std=0.05):
    """Initialise one module's weights from N(0, std); intended for ``net.apply``.

    - Conv2d: weight ~ N(0, std); the bias keeps its existing initialisation.
    - Linear and ConvTranspose2d: weight ~ N(0, std), bias = 0.

    Layers wrapped with ``torch.nn.utils.weight_norm`` are initialised through
    their ``weight_v``/``weight_g`` parameters, so the sampled values actually
    reach the forward pass. Other module types are left unchanged.

    Args:
        model: a single ``nn.Module``; ``apply`` calls this once per sub-module.
        std: standard deviation of the normal initialisation (default 0.05).
    """
    classname = model.__class__.__name__
    if classname.find('Conv2d') != -1:
        _init_weight(model, std)
    elif classname.find('Linear') != -1:
        _init_weight(model, std)
        _zero_bias(model)
    elif classname.find('ConvTranspose2d') != -1:
        _init_weight(model, std)
        _zero_bias(model)

In [9]:
def test(net, testloader):
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs, _ = net(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    acc = correct / total
    net.train()
    return acc

In [10]:
def update_avg_nn(model_avg, model, coeff):
    """Blend ``model``'s parameters into ``model_avg`` as an exponential moving average.

    For every parameter pair: ``p_avg <- (1 - coeff) * p_avg + coeff * p``.
    The update is done in place, without autograd. Buffers are not touched.

    Parameters are paired by registration order. Their names must match, so two
    modules with different architectures raise instead of silently averaging
    unrelated tensors.

    Args:
        model_avg: module holding the running average (modified in place).
        model: module whose current parameters are blended in.
        coeff: weight given to ``model``'s parameters (e.g. 1e-4 means decay 0.9999).

    Raises:
        ValueError: if the parameter names of the two modules differ.
    """
    avg_params = list(model_avg.named_parameters())
    new_params = list(model.named_parameters())
    if [name for name, _ in avg_params] != [name for name, _ in new_params]:
        raise ValueError('model_avg and model do not have matching parameter names')
    with torch.no_grad():
        for (_, p_avg), (_, p) in zip(avg_params, new_params):
            # lerp_: p_avg + coeff * (p - p_avg) == (1 - coeff) * p_avg + coeff * p
            p_avg.lerp_(p, coeff)

In [11]:
class Dc(nn.Module):
    """Weight-normalised convolutional discriminator / classifier.

    Structure:
    - Input dropout.
    - Three conv blocks with 96 -> 192 -> 192 channels, LeakyReLU(0.2), two
      stride-2 downsamplings and dropout.
    - Global average pooling.
    - A weight-normalised linear layer with ``num_classes`` outputs.

    ``forward(x)`` returns ``(logits, features)``:
    - ``features`` is the ``(batch, 192)`` pooled representation fed to the
      linear layer.
    - ``logits`` is ``(batch, num_classes)``.

    Args:
        num_channels: channels of the input images.
        img_size: kept for backward compatibility. The final pooling adapts to
            whatever spatial size reaches it, so any input of at least 9x9 works.
            (For 32x32 inputs it is identical to the former AvgPool2d(6).)
        num_classes: number of output logits.
    """

    def __init__(self, num_channels=3, img_size=(32, 32), num_classes=10):
        super(Dc, self).__init__()
        self.conv = nn.Sequential(
            nn.Dropout(0.2),
            # block 1
            weight_norm(nn.Conv2d(in_channels=num_channels, out_channels=96, kernel_size=3, padding=1)),
            nn.LeakyReLU(0.2),
            weight_norm(nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, padding=1)),
            nn.LeakyReLU(0.2),
            weight_norm(nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5),
            # block 2
            weight_norm(nn.Conv2d(in_channels=96, out_channels=192, kernel_size=3, padding=1)),
            nn.LeakyReLU(0.2),
            weight_norm(nn.Conv2d(in_channels=192, out_channels=192, kernel_size=3, padding=1)),
            nn.LeakyReLU(0.2),
            weight_norm(nn.Conv2d(in_channels=192, out_channels=192, kernel_size=3, stride=2, padding=1)),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5),
            # block 3 (the first conv is unpadded: 8x8 -> 6x6 for 32x32 inputs)
            weight_norm(nn.Conv2d(in_channels=192, out_channels=192, kernel_size=3)),
            nn.LeakyReLU(0.2),
            weight_norm(nn.Conv2d(in_channels=192, out_channels=192, kernel_size=1)),
            nn.LeakyReLU(0.2),
            weight_norm(nn.Conv2d(in_channels=192, out_channels=192, kernel_size=1)),
            nn.LeakyReLU(0.2),
            # Global average pooling to 1x1 for any remaining spatial size.
            nn.AdaptiveAvgPool2d(1),
        )
        self.logits = weight_norm(nn.Linear(192, num_classes))

    def forward(self, x):
        x = self.conv(x)
        # (B, 192, 1, 1) -> (B, 192), keeping the batch dimension intact. The
        # former view(-1, C) silently folded leftover spatial positions into
        # the batch dimension.
        features = x.view(x.size(0), -1)
        logits = self.logits(features)
        return logits, features

In [12]:
class Gc(nn.Module):
    """Generator mapping a latent vector to a 32x32 image with values in [-1, 1].

    Structure:
    - Linear -> BatchNorm1d -> ReLU produces a ``dim_features x 4 x 4`` map.
    - Three stride-2 transposed convolutions upsample it 4 -> 8 -> 16 -> 32 pixels.
    - The last of these is weight-normalised and followed by Tanh.

    Args:
        dim_z: length of the latent input vector.
        dim_features: channels of the 4x4 seed map; halved, then quartered, by
            the first two upsampling layers.
        num_channels: channels of the generated image.
    """

    def __init__(self, dim_z, dim_features=512, num_channels=3):
        super(Gc, self).__init__()
        self.dim_features = dim_features
        self.num_channels = num_channels

        def upsample(in_ch, out_ch):
            # kernel 5, stride 2, padding 2, output_padding 1 doubles H and W.
            return nn.ConvTranspose2d(in_ch, out_ch, kernel_size=5, stride=2,
                                      padding=2, output_padding=1)

        seed_units = dim_features * 4 * 4
        self.l1 = nn.Sequential(
            nn.Linear(dim_z, seed_units),
            nn.BatchNorm1d(seed_units),
            nn.ReLU(),
        )

        half, quarter = dim_features // 2, dim_features // 4
        self.l2 = nn.Sequential(
            upsample(dim_features, half),
            nn.BatchNorm2d(half),
            nn.ReLU(),
            upsample(half, quarter),
            nn.BatchNorm2d(quarter),
            nn.ReLU(),
            weight_norm(upsample(quarter, num_channels)),
            nn.Tanh(),
        )

    def forward(self, x):
        batch = x.size(0)
        feat_map = self.l1(x).view(batch, self.dim_features, 4, 4)
        return self.l2(feat_map)

In [13]:
# D: discriminator/classifier being trained.
# D_avg: a second instance that will hold an exponential moving average of D's parameters.
# G: the generator.
D = Dc()
D_avg = Dc()
G = Gc(z_dim)
D.train()
G.train()
D_avg.train()
print()




In [14]:
# Re-initialise D with weights_init (Module.apply returns the module itself),
# start D_avg from exactly D's weights, then re-initialise G.
D_avg.load_state_dict(D.apply(weights_init).state_dict())
G.apply(weights_init)
print()




In [15]:
softplus = nn.Softplus()              # used in the unlabelled / generated GAN terms
criterion = nn.CrossEntropyLoss()     # supervised loss on the labelled batch
if use_cuda:
    # Module.cuda() moves the module in place and returns it.
    D, G, D_avg = (module.cuda() for module in (D, G, D_avg))

In [None]:
# One Adam optimizer per player, both with lr=lr and betas=(0.5, 0.999).
D_opt, G_opt = (optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999)) for net in (D, G))

In [None]:
# r = real, f = fake, ul = unlabelled, l = labelled, s = sample
# Each step uses one labelled batch, two unlabelled half-batches and one
# generated batch:
# - D minimises loss_l + 0.5 * loss_ul.
# - G minimises a feature-matching loss on the features D feeds to its logits layer.
# - D_avg tracks an EMA of D's parameters.


def _cycle(loader):
    """Yield batches from ``loader`` endlessly, starting a new pass after each one.

    Raises:
        ValueError: if a full pass over ``loader`` yields nothing (which would
            otherwise loop forever).
    """
    while True:
        yielded = False
        for batch in loader:
            yielded = True
            yield batch
        if not yielded:
            raise ValueError('loader yielded no batches')


final_acc = []
print()
D.zero_grad()
G.zero_grad()
# A single endless stream of labelled batches. The former counter-based restart
# (`c == len(train_loader_ssl)`) re-created the iterator before the last batch of
# each pass was drawn, so that batch was never used.
labelled_batches = _cycle(train_loader_ssl)
for ne in range(1000):
    eloss, gloss = [], []
    for x, _ in train_loader:
        # setting up input data: the 2 * batch_size unlabelled images are split
        # into one half for the D update and one half for feature matching
        x_ul_1 = x[:batch_size]
        x_ul_2 = x[batch_size:]
        z = torch.randn(batch_size, z_dim)
        x_l, y_l = next(labelled_batches)
        if use_cuda:
            x_l = x_l.cuda()
            y_l = y_l.cuda()
            x_ul_1 = x_ul_1.cuda()
            x_ul_2 = x_ul_2.cuda()
            z = z.cuda()
        # train discriminator
        D_r_l, _ = D(x_l)
        D_r_ul, _ = D(x_ul_1)
        G_s = G(z)
        D_f, _ = D(G_s.detach())  # detached: the D step must not backprop into G
        loss_l = criterion(D_r_l, y_l)
        ul = log_sum_exp(D_r_ul)
        lf = log_sum_exp(D_f)
        # softplus(ul) - ul = -log(sigmoid(ul)): push unlabelled real data towards "real".
        # softplus(lf) = -log(1 - sigmoid(lf)): push generated data towards "fake".
        loss_ul = -ul.mean() + softplus(ul).mean() + softplus(lf).mean()
        D_loss = loss_l + 0.5 * loss_ul
        D_loss.backward()
        D_opt.step()
        D.zero_grad()
        G.zero_grad()
        update_avg_nn(D_avg, D, ema_coeff)
        # train generator: match the mean D features of real and generated data
        _, layer_r = D(x_ul_2)
        _, layer_f = D(G_s)
        G_loss = torch.abs(layer_r.detach().mean(dim=0) - layer_f.mean(dim=0)).mean()
        G_loss.backward()  # also fills D's .grad; cleared below without stepping D_opt
        G_opt.step()
        D.zero_grad()
        G.zero_grad()
        # stats
        eloss.append(D_loss.item())
        gloss.append(G_loss.item())
    acc = test(D, test_loader)
    acc_avg = test(D_avg, test_loader)
    final_acc.append(acc_avg)
    print('Epoch: {}; D_loss: {:.2f}; G_loss: {:.2f}; Accuracy: {:.2f}%; Accuracy Avg: {:.2f}%;'
          .format(ne, np.mean(eloss), np.mean(gloss), acc * 100., acc_avg * 100.))


Epoch: 0; D_loss: 2.58; G_loss: 0.37; Accuracy: 34.47%; Accuracy Avg: 10.00%;
Epoch: 1; D_loss: 2.18; G_loss: 0.24; Accuracy: 41.51%; Accuracy Avg: 10.00%;
Epoch: 2; D_loss: 1.99; G_loss: 0.27; Accuracy: 39.25%; Accuracy Avg: 10.15%;
Epoch: 3; D_loss: 1.82; G_loss: 0.28; Accuracy: 41.49%; Accuracy Avg: 21.46%;
Epoch: 4; D_loss: 1.70; G_loss: 0.27; Accuracy: 41.83%; Accuracy Avg: 25.61%;
Epoch: 5; D_loss: 1.57; G_loss: 0.30; Accuracy: 48.51%; Accuracy Avg: 29.35%;
Epoch: 6; D_loss: 1.45; G_loss: 0.32; Accuracy: 51.67%; Accuracy Avg: 33.03%;
Epoch: 7; D_loss: 1.36; G_loss: 0.32; Accuracy: 46.86%; Accuracy Avg: 36.55%;
Epoch: 8; D_loss: 1.28; G_loss: 0.33; Accuracy: 49.44%; Accuracy Avg: 40.12%;
Epoch: 9; D_loss: 1.19; G_loss: 0.34; Accuracy: 50.31%; Accuracy Avg: 43.32%;
Epoch: 10; D_loss: 1.09; G_loss: 0.34; Accuracy: 51.22%; Accuracy Avg: 45.93%;
Epoch: 11; D_loss: 1.03; G_loss: 0.34; Accuracy: 51.61%; Accuracy Avg: 47.62%;
Epoch: 12; D_loss: 0.95; G_loss: 0.35; Accuracy: 53.77%; Accu

In [None]:
# Show 64 generator samples.
# - G is left in training mode, as during training, so its BatchNorm layers
#   use batch statistics. The removed `G.training = True` lines only set G's own
#   flag (not its submodules') and G was already in training mode.
# - Sampling runs without autograd, on whichever device G lives on.
# - G ends in Tanh, so samples lie in [-1, 1]. save_image clamps to [0, 1], so
#   shift to [0, 1] first instead of clipping every negative value to black.
plt.rcParams['figure.figsize'] = [8, 8]
with torch.no_grad():
    z = torch.randn(64, z_dim, device=next(G.parameters()).device)
    samples = (G(z) + 1) / 2
vutils.save_image(samples, 'temp.png')
img = mpimg.imread('temp.png')
plt.imshow(img)

In [None]:
# Show one batch of real training images for comparison with the samples.
# The training transform maps pixels to [-1, 1]. Shift them to [0, 1] before
# save_image, which would otherwise clamp negative values to black.
plt.rcParams['figure.figsize'] = [8, 8]
real_images, _ = next(iter(train_loader))
vutils.save_image((real_images + 1) / 2, 'temp2.png')
img = mpimg.imread('temp2.png')
plt.imshow(img)

In [None]:
# List every parameter of the EMA discriminator with its shape. The loop names
# must not be `n`: that would overwrite the image-size constant from the config
# cell and break re-running the transform cell.
for name, param in D_avg.named_parameters():
    print(name, param.size())