In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable

import numpy as np
from modules import Generator,Discriminator,Encoder

In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
latent_dim=100
batch_size=32
input_size=28
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (1.0,))])

In [3]:
device.type

'cuda'

In [4]:
# MNIST dataset 
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform,download=True)
test_dataset = torchvision.datasets.MNIST(root='./data',train=False,transform=transform)

# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False)


In [5]:
img_shape=(input_size,input_size)
E = Encoder(latent_dim,img_shape).to(device)
G = Generator(latent_dim,img_shape).to(device)
D = Discriminator(latent_dim,img_shape).to(device)
Tensor = torch.cuda.FloatTensor if device.type== 'cuda' else torch.FloatTensor
adversarial_criterion = torch.nn.BCELoss()

In [6]:
E.model

Sequential(
  (0): Linear(in_features=784, out_features=512, bias=True)
  (1): Linear(in_features=512, out_features=512, bias=True)
  (2): BatchNorm1d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  (3): LeakyReLU(negative_slope=0.2)
  (4): Linear(in_features=512, out_features=256, bias=True)
  (5): BatchNorm1d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  (6): LeakyReLU(negative_slope=0.2)
  (7): Linear(in_features=256, out_features=128, bias=True)
  (8): BatchNorm1d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  (9): LeakyReLU(negative_slope=0.2)
  (10): Linear(in_features=128, out_features=100, bias=True)
  (11): Tanh()
)

In [7]:
G.model

Sequential(
  (0): Linear(in_features=100, out_features=128, bias=True)
  (1): BatchNorm1d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): Linear(in_features=128, out_features=256, bias=True)
  (4): BatchNorm1d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU()
  (6): Linear(in_features=256, out_features=512, bias=True)
  (7): BatchNorm1d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  (8): ReLU()
  (9): Linear(in_features=512, out_features=512, bias=True)
  (10): BatchNorm1d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  (11): ReLU()
  (12): Linear(in_features=512, out_features=784, bias=True)
  (13): Tanh()
)

In [8]:
D.model

Sequential(
  (0): Linear(in_features=884, out_features=1024, bias=True)
  (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): LeakyReLU(negative_slope=0.2)
  (3): Dropout(p=0.4, inplace=False)
  (4): Linear(in_features=1024, out_features=512, bias=True)
  (5): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): LeakyReLU(negative_slope=0.2)
  (7): Dropout(p=0.4, inplace=False)
  (8): Linear(in_features=512, out_features=256, bias=True)
  (9): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (10): LeakyReLU(negative_slope=0.2)
  (11): Dropout(p=0.4, inplace=False)
  (12): Linear(in_features=256, out_features=1, bias=True)
)

In [8]:
for epoch in range(1):
    for i, (images, _) in enumerate(train_loader):
        images = images.reshape(-1,np.prod(img_shape)).to(device)
        
        print('original image shape:',images.shape)
        (original_img,enc_img)= E(images)
        print(original_img.shape,enc_img.shape)
        # Sample noise as generator input
        real_noise = Variable(Tensor(np.random.normal(0, 1, (images.shape[0],latent_dim))))
        (gen_img,real_noise)=G(real_noise)
        print(gen_img.shape,real_noise.shape)
        gen_img= gen_img.view((gen_img.size(0),-1))
        print(gen_img.shape)
        fake = Variable(Tensor(images.size(0), 1).fill_(0.0), requires_grad=False)
        predict = D(gen_img,real_noise)
        loss = adversarial_criterion(predict,fake)
        print(loss)
        loss.backward()
        print(loss)
        break

tensor([[[[-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
          ...,
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000]]],


        [[[-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
          ...,
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
          [-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000]]],


        [[[-0.5000, -0.5000, -0.5000,  ..., -0.5000, -0.5000, -0.5000],
          [-0.5000, -0.500

TypeError: mean() received an invalid combination of arguments - got (out=NoneType, axis=NoneType, dtype=NoneType, ), but expected one of:
 * (torch.dtype dtype)
 * (tuple of names dim, bool keepdim, torch.dtype dtype)
      didn't match because some of the keywords were incorrect: out, axis
 * (tuple of ints dim, bool keepdim, torch.dtype dtype)
      didn't match because some of the keywords were incorrect: out, axis


In [11]:
predict.shape

torch.Size([32, 1])

In [10]:
gen_img.shape

torch.Size([32, 28, 28])

In [29]:
real_noise.shape

torch.Size([32, 100])

In [35]:
torch.cat((gen_img,real_noise),1).shape

torch.Size([32, 884])

In [32]:
gen_img.view(gen_img.size(0), -1).shape

torch.Size([32, 784])

In [9]:
28*28

784

In [27]:
img_shape=(input_size,input_size)

SyntaxError: can't use starred expression here (<ipython-input-28-2bf152c73045>, line 1)