In [1]:
import os
import time
import torch
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [2]:
# Run configuration used by the cells below.
seed = 1             # value handed to torch.manual_seed at the end of this cell
batch_size = 100     # samples per DataLoader batch
epochs = 5           # number of passes over the training set
log_interval = 10    # print the batch loss every `log_interval` batches

# Seed PyTorch's global RNG; the returned Generator is what the cell displays.
torch.manual_seed(seed)

<torch._C.Generator at 0x20b8f5ebb10>

In [3]:
# Root folder of the training images. It can be overridden through the
# KANJI_DATA_DIR environment variable so the notebook is not tied to a single
# machine; the default is the original location.
data_dir = os.environ.get("KANJI_DATA_DIR", r"C:\Users\shiva\Kuzushiji-Kanji\kanjivgmain")

# ImageFolder indexes samples class folder by class folder, so without
# shuffling every epoch would replay the same class-ordered batches.
# shuffle=True draws a fresh random order each epoch.
train_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(data_dir, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True,
)

In [4]:
class VAE(nn.Module):
    """Convolutional variational auto-encoder.

    The encoder is a stack of four convolutions that must reduce the input to
    a ``(N, 1024, 1, 1)`` feature map (a 28x28 input does: 28 -> 14 -> 7 -> 4
    -> 1). Two linear heads turn that feature vector into the mean and
    log-variance of the latent Gaussian. The decoder maps a latent code back
    through two linear layers and four transposed convolutions, which expand a
    1x1 map to 28x28 (1 -> 4 -> 7 -> 14 -> 28) and end in a sigmoid.

    Args:
        latent_dim: width of the latent code (default 64).
        image_channels: channel count of the input and reconstructed images
            (default 3).
    """

    def __init__(self, latent_dim=64, image_channels=3):
        super().__init__()
        self.latent_dim = latent_dim
        self.image_channels = image_channels

        # NOTE: sub-modules are created in a fixed order (encoder, decoder,
        # fc1, fc21, fc22, fc3, fc4) so that seeded weight initialisation is
        # reproducible; keep this order when editing.
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 3, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 1024, 4, 1, 0, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 3, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, image_channels, 4, 2, 1, bias=False),
            nn.Sigmoid()
            )

        self.fc1 = nn.Linear(1024, 512)
        self.fc21 = nn.Linear(512, latent_dim)   # mean head
        self.fc22 = nn.Linear(512, latent_dim)   # log-variance head
        self.fc3 = nn.Linear(latent_dim, 512)
        self.fc4 = nn.Linear(512, 1024)
        self.lrelu = nn.LeakyReLU()  # not used by the methods below; kept for compatibility
        self.relu = nn.ReLU()

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent distribution for images ``x``.

        Raises:
            ValueError: if the encoder output is not ``(..., 1024, 1, 1)``,
                i.e. the input images are too large or too small for this
                architecture. Without this check the reshape below would
                silently mix spatial positions into the batch dimension.
        """
        conv = self.encoder(x)
        if tuple(conv.shape[-3:]) != (1024, 1, 1):
            raise ValueError(
                "encoder produced a feature map of shape {}; expected "
                "(N, 1024, 1, 1) (e.g. from 28x28 inputs)".format(tuple(conv.shape)))
        h1 = self.fc1(conv.reshape(-1, 1024))
        return self.fc21(h1), self.fc22(h1)

    def decode(self, z):
        """Map latent codes ``z`` of width ``latent_dim`` to images in [0, 1]."""
        h3 = self.relu(self.fc3(z))
        # Treat the 1024-vector as a 1x1 feature map for the transposed convs.
        deconv_input = self.fc4(h3).view(-1, 1024, 1, 1)
        return self.decoder(deconv_input)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable with respect to
        ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Returns:
            ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


# Build the network with its default configuration and let Adam optimise
# every one of its parameters (learning rate 1e-3).
model = VAE()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)


# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE training loss: reconstruction BCE plus KL divergence.

    Args:
        recon_x: reconstructed images with values in [0, 1].
        x: target images with values in [0, 1]; must hold the same number of
            elements as ``recon_x``.
        mu: latent means.
        logvar: latent log-variances.

    Returns:
        A scalar tensor: both terms are summed (not averaged) over every
        element of the batch.
    """
    # The summed BCE does not depend on tensor layout, so there is no need to
    # force a fixed (-1, 1, 28, 28) shape; matching the target to the
    # reconstruction works for any image size. reshape (unlike view) also
    # accepts non-contiguous tensors.
    target = x.reshape(recon_x.shape)
    BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')

    # KL divergence between N(mu, exp(logvar)) and the standard normal prior:
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD


def train(epoch):
    """Train the global ``model`` for one pass over ``train_loader``.

    Every ``log_interval`` batches the per-sample loss of the current batch is
    printed; after the last batch the loss averaged over the whole dataset is
    printed.

    Args:
        epoch: epoch number, used only in the printed messages.
    """
    model.train()
    running_loss = 0
    for step, (images, _) in enumerate(train_loader):
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(images)
        loss = loss_function(reconstruction, images, mu, logvar)
        loss.backward()
        step_loss = loss.item()
        running_loss += step_loss
        optimizer.step()

        # Only report on every `log_interval`-th batch.
        if step % log_interval != 0:
            continue
        n_in_batch = len(images)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, step * n_in_batch, len(train_loader.dataset),
            100. * step / len(train_loader),
            step_loss / n_in_batch))

    epoch_average = running_loss / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, epoch_average))

In [None]:
n_samples = 50                       # images written per epoch
latent_dim = model.fc3.in_features   # latent width the decoder's first layer expects

for epoch in range(1, epochs + 1):
    train(epoch)
    # train() leaves the model in training mode. Switch to eval mode before
    # sampling so the decoder's BatchNorm layers use their running statistics
    # instead of the statistics of these random samples (and do not update
    # their running averages from them). train() calls model.train() again at
    # the start of the next epoch.
    model.eval()
    with torch.no_grad():
        sample = torch.randn(n_samples, latent_dim)
        sample = model.decode(sample)
        save_image(sample.view(n_samples, 3, 28, 28), r'C:\Users\shiva\Kuzushiji-Kanji\_' + str(epoch) + '.png')

