In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchsummary import summary
from torchvision.datasets import MNIST

  warn(f"Failed to load image Python extension: {e}")


In [2]:
plt.style.use("ggplot")
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
CHANNELS, IMG_ROWS, IMG_COLS = 1, 28, 28
IMG_SHAPE = (CHANNELS, IMG_ROWS, IMG_COLS)
Z_DIM = 100
NUM_CLASSES = 10

In [3]:
class Dataset:
    def __init__(self, num_labeled):
        self.num_labeled = num_labeled
        download_root = "./MNIST_DATASET"
        train_dataset = MNIST(download_root, train=True, download=True)
        test_dataset = MNIST(download_root, train=False, download=True)
        self.x_train, self.y_train = train_dataset.data.to(DEVICE), train_dataset.targets.to(DEVICE)
        self.x_test, self.y_test = test_dataset.data.to(DEVICE), test_dataset.targets.to(DEVICE)

        self.x_train, self.x_test = self.x_train / 127.5 - 1.0, self.x_test / 127.5 - 1.0
        self.x_train, self.x_test = torch.unsqueeze(self.x_train, 1), torch.unsqueeze(self.x_test, 1)

    def batch_labeled(self, batch_size):
        idx = np.random.randint(low=0, high=self.num_labeled, size=batch_size)
        imgs = self.x_train[idx]
        labels = self.y_train[idx]
        return imgs, labels

    def batch_unlabeled(self, batch_size):
        idx = np.random.randint(low=self.num_labeled, high=self.x_train.shape[0], size=batch_size)
        imgs = self.x_train[idx]
        return imgs

    def training_set(self):
        x_train = self.x_train[range(self.num_labeled)]
        y_train = self.y_train[range(self.num_labeled)]
        return x_train, y_train

    def test_set(self):
        return self.x_test, self.y_test


NUM_LABELED = 1000
dataset = Dataset(num_labeled=NUM_LABELED)

In [4]:
class Generator(nn.Module):
    def __init__(self, z_dim):
        super(Generator, self).__init__()
        self.fc = nn.Linear(z_dim, 256 * 7 * 7)
        self.model = nn.Sequential(
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.fc(x)
        x = x.reshape(-1, 256, 7, 7)
        output = self.model(x)
        return output


generator = Generator(z_dim=Z_DIM).to(DEVICE)
summary(generator, (Z_DIM,))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                [-1, 12544]       1,266,944
   ConvTranspose2d-2          [-1, 128, 14, 14]         295,040
       BatchNorm2d-3          [-1, 128, 14, 14]             256
         LeakyReLU-4          [-1, 128, 14, 14]               0
   ConvTranspose2d-5           [-1, 64, 14, 14]          73,792
       BatchNorm2d-6           [-1, 64, 14, 14]             128
         LeakyReLU-7           [-1, 64, 14, 14]               0
   ConvTranspose2d-8            [-1, 1, 28, 28]             577
              Tanh-9            [-1, 1, 28, 28]               0
Total params: 1,636,737
Trainable params: 1,636,737
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.97
Params size (MB): 6.24
Estimated Total Size (MB): 7.21
---------------------------------------

In [5]:
class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=img_shape[0], out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=0),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
            nn.Dropout(p=0.5),
            nn.Flatten(),
            nn.Linear(3 * 3 * 128, 10),
        )

    def forward(self, x):
        x = self.model(x)
        return x


discriminator_semi = Discriminator(img_shape=IMG_SHAPE).to(DEVICE)
summary(discriminator_semi, IMG_SHAPE)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 14, 14]             320
         LeakyReLU-2           [-1, 32, 14, 14]               0
            Conv2d-3             [-1, 64, 7, 7]          18,496
         LeakyReLU-4             [-1, 64, 7, 7]               0
            Conv2d-5            [-1, 128, 3, 3]          73,856
         LeakyReLU-6            [-1, 128, 3, 3]               0
           Dropout-7            [-1, 128, 3, 3]               0
           Flatten-8                 [-1, 1152]               0
            Linear-9                   [-1, 10]          11,530
Total params: 104,202
Trainable params: 104,202
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.18
Params size (MB): 0.40
Estimated Total Size (MB): 0.58
-------------------------------------------

In [6]:
optimizer_d_sl = optim.Adam(discriminator_semi.parameters(), lr=0.0002)
optimizer_d_ul = optim.Adam(discriminator_semi.parameters(), lr=0.0002)
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)
criterion4sl = nn.CrossEntropyLoss()
criterion4ul = nn.BCELoss()

In [8]:
supervised_losses, iteration_checkpoints = [], []
iterations = 8000
batch_size = 128
sample_interval = 800

real = torch.ones(batch_size, 1).to(DEVICE)
fake = torch.zeros(batch_size, 1).to(DEVICE)

for iteration in range(iterations):
    imgs, labels = dataset.batch_labeled(batch_size)
    imgs_unlabeled = dataset.batch_unlabeled(batch_size)

    z = torch.randn(batch_size, Z_DIM).to(DEVICE)
    gen_imgs = generator(z).detach()
    
    optimizer_d_sl.zero_grad()
    d_pred_sl = discriminator_semi(imgs)
    d_loss_sl = criterion4sl(d_pred_sl, labels)
    d_loss_sl.backward()
    optimizer_d_sl.step()

    optimizer_d_ul.zero_grad()
    d_pred_real, d_pred_fake = discriminator_semi(imgs_unlabeled), discriminator_semi(gen_imgs)
    d_pred_real = 1.0 - (1.0 / (torch.sum(torch.exp(d_pred_real), dim=-1, keepdim=True) + 1.0))
    d_pred_fake = 1.0 - (1.0 / (torch.sum(torch.exp(d_pred_fake), dim=-1, keepdim=True) + 1.0))
    
    # d_pred_real, d_pred_fake = d_pred_real.detach(), d_pred_fake.detach()
    
    d_loss_real, d_loss_fake = criterion4ul(d_pred_real, real), criterion4ul(d_pred_fake, fake)
    d_loss_ul = (d_loss_real + d_loss_fake) * 0.5
    d_loss_ul.backward()
    optimizer_d_ul.step()

    z = torch.randn(batch_size, Z_DIM).to(DEVICE)
    gen_imgs = generator(z).detach()

    optimizer_g.zero_grad()
    d_pred_fake = discriminator_semi(gen_imgs)
    d_pred_fake = 1.0 - (1.0 / (torch.sum(torch.exp(d_pred_fake), dim=-1, keepdim=True) + 1.0))
    g_loss = criterion4ul(d_pred_fake, real)
    g_loss.backward()
    optimizer_g.step()

    if (iteration + 1) % sample_interval == 0:
        supervised_losses.append(d_loss_sl.item())
        iteration_checkpoints.append(iteration + 1)
        print(f"{iteration + 1} [D(SL) loss: {d_loss_sl.item():.4f}] "
              f"[D(UL) loss: {d_loss_ul.item():.4f}] [G loss: {g_loss.item():.4f}]")

800 [D(SL) loss: 0.0794] [D(UL) loss: 0.0014] [G loss: 56.2097]
1600 [D(SL) loss: 0.0217] [D(UL) loss: 0.0000] [G loss: 90.6599]
2400 [D(SL) loss: 0.0190] [D(UL) loss: 0.0000] [G loss: 94.5943]
3200 [D(SL) loss: 0.0049] [D(UL) loss: 0.0000] [G loss: 99.3379]
4000 [D(SL) loss: 0.0031] [D(UL) loss: 0.0000] [G loss: 100.0000]
4800 [D(SL) loss: 0.0003] [D(UL) loss: 0.0000] [G loss: 100.0000]
5600 [D(SL) loss: 0.0047] [D(UL) loss: 0.0000] [G loss: 100.0000]
6400 [D(SL) loss: 0.0034] [D(UL) loss: 0.0000] [G loss: 100.0000]
7200 [D(SL) loss: 0.0155] [D(UL) loss: 0.0000] [G loss: 100.0000]
8000 [D(SL) loss: 0.0002] [D(UL) loss: 0.0000] [G loss: 99.2524]


In [10]:
with torch.no_grad():
    X, y = dataset.test_set()
    discriminator_semi.eval()
    prediction = discriminator_semi(X)
    correct = torch.argmax(prediction, 1) == y
    accuracy = correct.float().mean()
    print(accuracy.item())

0.9386000037193298
