In [1]:
# prerequisites
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

In [2]:
bs = 100
seed = 3407
torch.manual_seed(seed)
# MNIST Dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

In [None]:
def try_gpu(i=0):
    if torch.cuda.device_count() >= i + 1:
        return torch.device(f'cuda:{i}')
    return torch.device('cpu')

def save_model(model, path, optimizer=None):
    """
    Saves the state_dict of a torch model and optional optimizer to 'path'
    Returns: None
    """
    state = {"model": model.state_dict()}
    if optimizer is not None:
        state["optimizer"] = optimizer.state_dict()
    torch.save(state, path)


def load_model(model, path, optimizer=None):
    """
    Loads the state_dict of a torch model and optional optimizer from 'path'
    Returns: None
    """
    state = torch.load(path)
    model.load_state_dict(state["model"])
    if optimizer is not None:
        optimizer.load_state_dict(state["optimizer"])

In [3]:
class Generator(nn.Module):
    def __init__(self, g_input_dim, g_output_dim):
        super(Generator, self).__init__()       
        self.fc1 = nn.Linear(g_input_dim, 256)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features*2)
        self.fc3 = nn.Linear(self.fc2.out_features, g_output_dim)

    def forward(self, x): 
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.leaky_relu(self.fc2(x), 0.2)
        return torch.tanh(self.fc3(x))
    
class Discriminator(nn.Module):
    def __init__(self, d_input_dim):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 512)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 64)
        self.fc5 = nn.Linear(64, 1)
    
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc4(x), 0.2)
        x = F.dropout(x, 0.3)        
        return torch.sigmoid(self.fc5(x))

In [4]:
# build network
z_dim = 100
mnist_dim = train_dataset.data.size(1) * train_dataset.data.size(2)

G = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device=try_gpu())
D = Discriminator(784).to(device=try_gpu())
os.makedirs("./samples", exist_ok=True)



In [6]:
criterion = nn.BCELoss() 
lr = 0.0005
G_optimizer = optim.Adam(G.parameters(), lr = lr)
D_optimizer = optim.Adam(D.parameters(), lr = lr)

In [7]:
def D_train(x):
    # train discriminator on real
    x_real, y_real = x.view(-1, mnist_dim), torch.ones(bs, 1)
    x_real, y_real = Variable(x_real.to(device=try_gpu())), Variable(y_real.to(device=try_gpu()))

    D_output = D(x_real)
    D_real_loss = criterion(D_output, y_real)

    # train discriminator on facke
    z = Variable(torch.randn(bs, z_dim).to(device=try_gpu()))
    x_fake, y_fake = G(z), Variable(torch.zeros(bs, 1).to(device=try_gpu()))

    D_output = D(x_fake)
    D_fake_loss = criterion(D_output, y_fake)

    # gradient backprop & optimize ONLY D's parameters
    D_loss = D_real_loss + D_fake_loss
    D_optimizer.zero_grad()
    D_loss.backward()
    D_optimizer.step()
        
    return  D_loss.data.item()

In [8]:
def G_train(x):
    #=======================Train the generator=======================#
    z = Variable(torch.randn(bs, z_dim).to(device=try_gpu()))
    y = Variable(torch.ones(bs, 1).to(device=try_gpu()))

    G_output = G(z)
    D_output = D(G_output)
    G_loss = criterion(D_output, y)

    # gradient backprop & optimize ONLY G's parameters
    G_optimizer.zero_grad()
    G_loss.backward()
    G_optimizer.step()
        
    return G_loss.data.item()

In [9]:
n_epoch = 100
D_losses, G_losses = [], []
for epoch in range(1, n_epoch+1):
    ed_loss, eg_loss = 0, 0        
    for batch_idx, (x, _) in enumerate(train_loader):
        ed_loss += D_train(x)
        eg_loss += G_train(x)
    with torch.no_grad():
        test_z = Variable(torch.randn(bs, z_dim).to(device=try_gpu()))
        generated = G(test_z)
        save_image(generated.view(generated.size(0), 1, 28, 28), f'./samples/sample_{epoch}.png')

    D_losses.append(ed_loss / batch_idx)
    G_losses.append(eg_loss / batch_idx)
    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            (epoch), n_epoch, ed_loss / batch_idx, eg_loss / batch_idx))
    if (epoch) % 50 == 0:
        torch.save(D.state_dict(), f"./gan_D_epoch_{epoch}.ckpt")

np.save("gan_discriminator_loss", D_losses)
np.save("gan_gen_loss", G_losses)

In [10]:
class GanClassifier(nn.Module):
    def __init__(self, d_input_dim):
        super(GanClassifier, self).__init__()

        self.fc1 = nn.Linear(d_input_dim, 512)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        
        self.current = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(True),
            nn.Linear(128, 10)
        )
    
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = F.dropout(x, 0.3)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3)
        x = self.current(x)
        return x

In [12]:
classifier = GanClassifier(d_input_dim=784)
print(classifier)
pre_trained = Discriminator(mnist_dim)
pre_trained.to(device=try_gpu())
classifier.to(device=try_gpu())
pre_trained.load_state_dict(torch.load(f"./gan_D_epoch_{n_epoch}.ckpt"))
classifier_dict = classifier.state_dict()
pre = {k: v for k, v in pre_trained.state_dict().items() if k in classifier_dict}
classifier_dict.update(pre)
classifier.load_state_dict(classifier_dict)

GanClassifier(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (current): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=128, out_features=10, bias=True)
  )
)


<All keys matched successfully>

In [13]:
def train_classifier(net, train_iter, criterion, optimizer):
    if isinstance(net, nn.Module):
        net.train()
    loss_record = 0
    total = 0
    for _, data in enumerate (train_iter):
        img, labels = data
        img = img.view(-1, 784)
        img = Variable(img)
        label = Variable(labels)
        img = img.to(device=try_gpu())
        label = label.to(device=try_gpu())
        # ===================forward=====================
        output = net(img)
        loss = criterion(output, label)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_record += loss
        total += 1
    # ===================log========================
    return loss_record / total

def eval_accuracy(net, data_iter):
    if isinstance(net, nn.Module):
        net.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for data, target in data_iter:
            eval_imgs = Variable(data.view(-1, 784)).to(device=try_gpu())
            target = Variable(target).to(device=try_gpu())
            output = net(eval_imgs)
            pred = output.max(dim=1)[1]
            correct += (pred == target).sum().item()
            total += target.size(0)
        return correct / total

In [14]:
for name, param in classifier.named_parameters():
    if param.requires_grad and ("fc" ) in name:
        param.requires_grad = False

classifier.to(device=try_gpu())
for name, param in classifier.named_parameters():
    print(name, param.requires_grad)

num_epochs = 10
classifier_criterion = nn.CrossEntropyLoss().to(device=try_gpu())
classifier_optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001, weight_decay=1e-5)

train_loss = []
val_acc =[]
for epoch in range(num_epochs):
    epoch_train_loss = train_classifier(classifier, train_loader, classifier_criterion, classifier_optimizer)
    epoch_val_acc = eval_accuracy(classifier, test_loader, classifier_criterion)
    print('epoch [{}/{}], train_loss:{:.4f}'.format(epoch+1, num_epochs, epoch_train_loss.item()))    
    print('epoch [{}/{}], val_acc:{:.4f}'.format(epoch+1, num_epochs, epoch_val_acc))
    train_loss.append(epoch_train_loss.item())
    val_acc.append(epoch_val_acc)

np.save("ganclassvalaccfixed.npy", val_acc)

fc1.weight False
fc1.bias False
fc2.weight False
fc2.bias False
current.0.weight True
current.0.bias True
current.2.weight True
current.2.bias True
epoch [1/10], train_loss:0.3209
epoch [1/10], val_acc:0.9267
epoch [2/10], train_loss:0.2261
epoch [2/10], val_acc:0.9335
epoch [3/10], train_loss:0.2020
epoch [3/10], val_acc:0.9429
epoch [4/10], train_loss:0.1840
epoch [4/10], val_acc:0.9424
epoch [5/10], train_loss:0.1789
epoch [5/10], val_acc:0.9479
epoch [6/10], train_loss:0.1711
epoch [6/10], val_acc:0.9434
epoch [7/10], train_loss:0.1696
epoch [7/10], val_acc:0.9461
epoch [8/10], train_loss:0.1630
epoch [8/10], val_acc:0.9477
epoch [9/10], train_loss:0.1641
epoch [9/10], val_acc:0.9487
epoch [10/10], train_loss:0.1602
epoch [10/10], val_acc:0.9446


: 