In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets
from torchvision import transforms

In [2]:
# Mnist Train_dataset
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(root='./mnist_data', train=True, transform=transform, download=True)
train_dataset = datasets.MNIST(root='./mnist_data', train=False, transform=transform, download=True)

BATCH_SIZE = 256
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)

In [3]:
class CVAE(nn.Module):

    def __init__(self, input_dim=784, hidden_dim1=512, hidden_dim2=256, z_dim=2, c_dim=10):
        super(CVAE, self).__init__()

        # encoder
        self.en_fc1 = nn.Linear(input_dim + c_dim, hidden_dim1)
        self.en_fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.en_fc3_u = nn.Linear(hidden_dim2, z_dim) # u, mean
        self.en_fc3_var = nn.Linear(hidden_dim2, z_dim) # s, log_var

        # decoder

        self.de_fc1 = nn.Linear(z_dim + c_dim, hidden_dim2)
        self.de_fc2 = nn.Linear(hidden_dim2, hidden_dim1)
        self.de_fc3 = nn.Linear(hidden_dim1, input_dim)

        #
        self.flatten = nn.Flatten()

    def encoder(self, inputs, conditions):
        x = self.flatten(inputs)
        x = torch.concat([x, conditions], dim=-1)


        x = F.relu(self.en_fc1(x))
        x = F.relu(self.en_fc2(x))

        mu = self.en_fc3_u(x)
        log_var = self.en_fc3_var(x)

        return mu, log_var
    
    def reparameterization(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        z = mu + std * eps
        return z
    
    def decoder(self, z, conditions):

        x = torch.concat([z, conditions], dim=-1)
        x = F.relu(self.de_fc1(x))
        x = F.relu(self.de_fc2(x))
        recon_x = torch.sigmoid(self.de_fc3(x)) # 0 ~ 1
        return recon_x
    
    def forward(self, inputs, conditions):
        mu, log_var = self.encoder(inputs, conditions)
        z = self.reparameterization(mu, log_var)
        recon_x = self.decoder(z, conditions)
        return recon_x , mu, z
    
cvae = CVAE()
cvae

CVAE(
  (en_fc1): Linear(in_features=794, out_features=512, bias=True)
  (en_fc2): Linear(in_features=512, out_features=256, bias=True)
  (en_fc3_u): Linear(in_features=256, out_features=2, bias=True)
  (en_fc3_var): Linear(in_features=256, out_features=2, bias=True)
  (de_fc1): Linear(in_features=12, out_features=256, bias=True)
  (de_fc2): Linear(in_features=256, out_features=512, bias=True)
  (de_fc3): Linear(in_features=512, out_features=784, bias=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
)

In [4]:
optimizer = torch.optim.Adam(cvae.parameters())

def loss_function(recon_x, x, mu, log_var):
    flatten = nn.Flatten()
    bce = nn.BCELoss(reduction='sum')

    x = flatten(x)

    BCE_Loss = bce(recon_x, x)
    KLD_Loss = 0.5 * torch.sum(mu**2 + torch.exp(log_var) - log_var -1)
    return BCE_Loss + KLD_Loss

def one_of_k_encoding(x, class_num=10):
    vocab = list(range(class_num))

    output = []
    for i in x:
        encodinig = list(map(lambda s: float(s==i), vocab))
        output.append(encodinig)
    return torch.tensor(output, dtype=torch.float32)

In [5]:
def train(epoch):
    cvae.train()
    train_losses = 0.0
    for data, label in train_dataloader:
        optimizer.zero_grad()
        conditions = one_of_k_encoding(label)

        recon_batch, mu, log_var = cvae(data, conditions)

        loss = loss_function(recon_batch, data, mu, log_var)
        loss.backward()
        optimizer.step()
        train_losses += loss.item()
    print(f'>>>>>EPOCH {epoch} Average_loss = {train_losses / len(train_dataloader.dataset)}')


def test():
    cvae.eval()
    test_losses = 0.0
    with torch.no_grad():
        for data, label in test_dataloader:
            conditions = one_of_k_encoding(label)

            recon_batch, mu, log_var = cvae(data, conditions)
            test_losses += loss_function(recon_batch, data, mu, log_var).item()
    test_losses /= len(test_dataloader.dataset)
    print(f'>>>>> Test set loss {test_losses}')

In [6]:
for epoch in range(1, 21):
    train(epoch)
    test()

>>>>>EPOCH 1 Average_loss = 297.184584375
>>>>> Test set loss 204.51787109375
>>>>>EPOCH 2 Average_loss = 195.15897265625
>>>>> Test set loss 182.2676255126953
>>>>>EPOCH 3 Average_loss = 172.8423453125
>>>>> Test set loss 166.47366264648437
>>>>>EPOCH 4 Average_loss = 161.7261953125
>>>>> Test set loss 157.20613024902343
>>>>>EPOCH 5 Average_loss = 154.07563984375
>>>>> Test set loss 151.39239956054686
>>>>>EPOCH 6 Average_loss = 149.63446875
>>>>> Test set loss 148.14986274414062
>>>>>EPOCH 7 Average_loss = 146.963215234375
>>>>> Test set loss 145.7650412841797
>>>>>EPOCH 8 Average_loss = 144.687115234375
>>>>> Test set loss 143.5314244628906
>>>>>EPOCH 9 Average_loss = 142.438504296875
>>>>> Test set loss 141.52418823242186
>>>>>EPOCH 10 Average_loss = 140.292949609375
>>>>> Test set loss 139.25463903808594
>>>>>EPOCH 11 Average_loss = 138.550091796875
>>>>> Test set loss 137.82870463867187
>>>>>EPOCH 12 Average_loss = 137.3149265625
>>>>> Test set loss 136.57165769042967
>>>>>EPOCH

In [7]:
from torchvision.utils import save_image

with torch.no_grad():
    z = torch.randn(10, 2)
    conditions = torch.eye(10)

    sample = cvae.decoder(z, conditions)
    save_image(sample.view(10, 1, 28, 28), 'tutorial_cvae_result.png')