In [1]:
#importing dependencies

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import sampler
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import save_image

#import _pickle as cpickle #to store model histories in a file
import os
import imageio
from PIL import Image

# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(use_cuda)
    
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots

True


In [2]:
#specifying parameters
# --- Paths (relative to the notebook's working directory) ---
data_dir = './Train_data'       # root folder handed to ImageFolder
save_dir = './DCGAN_results/'   # not referenced by the cells shown here

# --- Optimisation ---
batch_size = 128
num_epochs = 1000
learning_rate = 2e-4            # not referenced by the cells shown here
betas = (0.5, 0.999)            # not referenced by the cells shown here

# --- Image / network dimensions ---
image_size = 64  # HACK to use MNIST architecture
# NOTE(review): the G_*/D_* sizes and num_filters look like generator /
# discriminator settings; none of the cells shown here read them.
G_input_dim = 100
G_output_dim = D_input_dim = 3
D_output_dim = 1
num_filters = [1024, 512, 256, 128]

In [3]:
# Show the working directory: data_dir and the 'results/...' output paths used
# later in the notebook are relative, so they resolve against this location.
print(os.getcwd())

/home/shreyashpandey/PoseGuided


In [4]:
#loading data

# ToTensor is the only preprocessing step (Hack to make MNIST code work).
# The normalisation below is left disabled.
# NOTE(review): its mean/std look like 0-255 pixel statistics; they would
# presumably need rescaling to ToTensor's output range before re-enabling.
#transforms.Normalize(mean=(214.0466981, 206.55220904, 203.99178198), std=(54.34939265, 55.62690195, 58.85794001))
transform = transforms.Compose([transforms.ToTensor()])

# Images are read from the folder tree rooted at data_dir.
df_data = dsets.ImageFolder(data_dir, transform=transform)

# Shuffled mini-batches of batch_size images.
data_loader = DataLoader(df_data, batch_size=batch_size, shuffle=True)


## VAE Model

In [5]:
class VAE(nn.Module):
    """Fully connected variational autoencoder over flattened inputs.

    The encoder maps a flattened input through three ReLU layers to the mean
    and log-variance of a diagonal Gaussian; the decoder mirrors the encoder
    and ends in a linear layer (no output non-linearity).

    Args:
        input_dim: number of values per flattened sample (default 64*64*3).
        hidden_dims: the three encoder widths, outermost first; the decoder
            uses them in reverse (default (4096, 1000, 400)).
        latent_dim: size of the latent code (default 60).

    The defaults reproduce the original fixed architecture, and the layer
    attribute names (fc1 ... fc7, fc31, fc32) are unchanged.
    """

    def __init__(self, input_dim=64*64*3, hidden_dims=(4096, 1000, 400), latent_dim=60):
        super(VAE, self).__init__()

        if len(hidden_dims) != 3:
            raise ValueError('hidden_dims must contain exactly three layer widths')
        h1, h2, h3 = hidden_dims

        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Encoder trunk followed by the two heads (mean, log-variance).
        self.fc1 = nn.Linear(input_dim, h1)
        self.fc2 = nn.Linear(h1, h2)
        self.fc3 = nn.Linear(h2, h3)
        self.fc31 = nn.Linear(h3, latent_dim)
        self.fc32 = nn.Linear(h3, latent_dim)
        # Decoder: mirror image of the encoder trunk.
        self.fc4 = nn.Linear(latent_dim, h3)
        self.fc5 = nn.Linear(h3, h2)
        self.fc6 = nn.Linear(h2, h1)
        self.fc7 = nn.Linear(h1, input_dim)

    def encode(self, x):
        """Map flattened inputs ``x`` to ``(mu, logvar)`` of the latent Gaussian."""
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        h3 = F.relu(self.fc3(h2))
        return self.fc31(h3), self.fc32(h3)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` while training; return ``mu`` in eval mode."""
        if not self.training:
            return mu
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` back to flattened reconstructions."""
        h4 = F.relu(self.fc4(z))
        h5 = F.relu(self.fc5(h4))
        h6 = F.relu(self.fc6(h5))
        # No sigmoid on the output: the last layer is left linear.
        return self.fc7(h6)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for a batch ``x``.

        ``x`` may have any shape whose element count is a multiple of
        ``input_dim``; it is flattened to ``(-1, input_dim)`` first.
        """
        # reshape (rather than view) also accepts non-contiguous batches.
        mu, logvar = self.encode(x.reshape(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Build the VAE, move it to the selected device, and optimise all weights
# with Adam.
# NOTE(review): the step size is fixed here; `learning_rate` and `betas`
# from the parameters cell are not used by this optimizer.
model = VAE()
model = model.to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

## Loss Function

In [6]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE objective: summed squared error plus KL divergence.

    Args:
        recon_x: decoder output, one flattened reconstruction per row.
        x: input batch with the same number of elements as ``recon_x``
            (e.g. the un-flattened images); it is reshaped to match.
        mu: latent means produced by the encoder.
        logvar: latent log-variances produced by the encoder.

    Returns:
        Scalar tensor: both terms are summed (not averaged) over every
        element and every sample in the batch.
    """
    # Match the target to the reconstruction's shape instead of assuming a
    # fixed 64*64*3 flattening, so other input sizes work as well.
    target = x.reshape(recon_x.shape)
    # reduction='sum' is the supported spelling of the deprecated
    # size_average=False.
    recon_loss = F.mse_loss(recon_x, target, reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + kld

## Set up training and testing functions

In [7]:
def train(epoch, log_interval=10):
    """Run one optimisation pass over ``data_loader`` and report the loss.

    Uses the notebook-level ``model``, ``optimizer``, ``data_loader``,
    ``device`` and ``loss_function``.

    Args:
        epoch: epoch number, used only in the printed messages.
        log_interval: print a progress line every this many batches
            (default 10, as before); 0 or None disables per-batch lines.

    Returns:
        The summed loss divided by the number of samples in the dataset.
    """
    model.train()
    train_loss = 0
    num_samples = len(data_loader.dataset)
    num_batches = len(data_loader)

    for batch_idx, (data, _) in enumerate(data_loader):
        data = data.to(device)
        optimizer.zero_grad()
        # Call the module itself rather than .forward() so nn.Module hooks
        # run, matching how test() invokes the model.
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss
        if log_interval and batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), num_samples,
                100. * batch_idx / num_batches,
                batch_loss / len(data)))

    avg_loss = train_loss / num_samples
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss))
    return avg_loss


def test(epoch):
    """Evaluate the model in eval mode and save a reconstruction grid.

    Uses the notebook-level ``model``, ``data_loader``, ``device`` and
    ``loss_function``. The loss is computed over ``data_loader`` -- the same
    loader ``train`` iterates -- so the printed "Test set loss" is measured
    on the training images, not on held-out data.

    For the first batch, up to 8 inputs and their reconstructions are written
    to ``results/reconstruction_<epoch>.png`` (inputs on the top row).

    Args:
        epoch: epoch number, used in the output file name.

    Returns:
        The summed loss divided by the number of samples in the dataset.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(data_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                # Give the flat reconstructions the input's image shape
                # instead of a hard-coded (3, 64, 64).
                comparison = torch.cat([data[:n],
                                        recon_batch.view(data.shape)[:n]])
                # The output folder may not exist on a fresh checkout;
                # create it so the write below cannot fail on that.
                os.makedirs('results', exist_ok=True)
                save_image(comparison.cpu(),
                           'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(data_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    return test_loss


## Train and Test

In [8]:
# The sample grids are written into 'results/'; create it up front so the
# first save does not fail on a fresh checkout.
os.makedirs('results', exist_ok=True)

# Read the latent size from the encoder's mean head instead of repeating the
# literal 60 used when the layers were defined.
latent_dim = model.fc31.out_features

for epoch in range(1, num_epochs+1):
    train(epoch)
    test(epoch)

    # After every epoch, decode 64 latent vectors drawn from N(0, I) and save
    # them as one image grid.
    with torch.no_grad():
        sample = torch.randn(64, latent_dim).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(64, 3, 64, 64), 'results/sample_' + str(epoch) + '.png')


====> Epoch: 1 Average loss: 7969.7299
====> Test set loss: 491.9377
====> Epoch: 2 Average loss: 419.7873
====> Test set loss: 447.6332
====> Epoch: 3 Average loss: 385.3027
====> Test set loss: 385.8177
====> Epoch: 4 Average loss: 365.8480
====> Test set loss: 370.3429
====> Epoch: 5 Average loss: 342.7465
====> Test set loss: 331.3101


====> Epoch: 6 Average loss: 330.9819
====> Test set loss: 344.5012
====> Epoch: 7 Average loss: 319.7732
====> Test set loss: 326.9131
====> Epoch: 8 Average loss: 312.3764
====> Test set loss: 304.6423
====> Epoch: 9 Average loss: 307.1249
====> Test set loss: 312.8598
====> Epoch: 10 Average loss: 304.1908
====> Test set loss: 304.3378
====> Epoch: 11 Average loss: 301.5076
====> Test set loss: 304.5958


====> Epoch: 12 Average loss: 300.2108
====> Test set loss: 297.3886
====> Epoch: 13 Average loss: 293.1638
====> Test set loss: 290.0884
====> Epoch: 14 Average loss: 289.3482
====> Test set loss: 292.4807
====> Epoch: 16 Average loss: 284.7624
====> Test set loss: 286.5779
====> Epoch: 17 Average loss: 283.3137
====> Test set loss: 282.3890
====> Epoch: 18 Average loss: 280.5664
====> Test set loss: 276.3532


====> Epoch: 19 Average loss: 276.9815
====> Test set loss: 279.6854
====> Epoch: 20 Average loss: 274.6611
====> Test set loss: 276.9785
====> Epoch: 21 Average loss: 273.2161
====> Test set loss: 271.3798
====> Epoch: 22 Average loss: 271.2797
====> Test set loss: 268.2496
====> Epoch: 23 Average loss: 271.9356
====> Test set loss: 272.5535


====> Epoch: 24 Average loss: 269.1417
====> Test set loss: 280.7639
====> Epoch: 25 Average loss: 268.4054
====> Test set loss: 268.6512
====> Epoch: 26 Average loss: 266.3833
====> Test set loss: 264.2683
====> Epoch: 27 Average loss: 271.1645
====> Test set loss: 263.8919
====> Epoch: 28 Average loss: 267.1011
====> Test set loss: 264.4180
====> Epoch: 29 Average loss: 262.5237
====> Test set loss: 261.7360


====> Epoch: 30 Average loss: 261.8509
====> Test set loss: 263.3722
====> Epoch: 31 Average loss: 259.7849
====> Test set loss: 258.5203
====> Epoch: 32 Average loss: 260.8051
====> Test set loss: 263.8849
====> Epoch: 33 Average loss: 258.3441
====> Test set loss: 255.5665
====> Epoch: 35 Average loss: 254.8988
====> Test set loss: 258.6737


====> Epoch: 36 Average loss: 254.9925
====> Test set loss: 251.0610
====> Epoch: 37 Average loss: 253.7522
====> Test set loss: 255.2121
====> Epoch: 38 Average loss: 254.7533
====> Test set loss: 250.2197
====> Test set loss: 202.6127
====> Epoch: 232 Average loss: 205.4375
====> Test set loss: 198.0059
====> Epoch: 233 Average loss: 204.6793
====> Test set loss: 197.7527
====> Epoch: 234 Average loss: 204.9758
====> Test set loss: 198.9602


====> Epoch: 235 Average loss: 204.9401
====> Test set loss: 200.8411
====> Epoch: 236 Average loss: 204.9614
====> Test set loss: 197.5637
====> Epoch: 237 Average loss: 204.8006
====> Test set loss: 199.3333
====> Epoch: 238 Average loss: 204.5066
====> Test set loss: 200.4367
====> Epoch: 239 Average loss: 204.5875
====> Test set loss: 197.5621
====> Epoch: 240 Average loss: 204.6156
====> Test set loss: 198.1703


====> Epoch: 241 Average loss: 204.5773
====> Test set loss: 199.3499
====> Epoch: 242 Average loss: 204.7251
====> Test set loss: 200.3085
====> Epoch: 243 Average loss: 204.5880
====> Test set loss: 197.4510
====> Epoch: 244 Average loss: 204.3175
====> Test set loss: 200.1189
====> Epoch: 245 Average loss: 204.7997
====> Test set loss: 199.4697


====> Epoch: 246 Average loss: 204.5109
====> Test set loss: 202.6573
====> Epoch: 247 Average loss: 204.1944
====> Test set loss: 198.1093
====> Epoch: 248 Average loss: 204.1681
====> Test set loss: 198.5303
====> Epoch: 249 Average loss: 204.4490
====> Test set loss: 197.7406
====> Epoch: 250 Average loss: 204.2915
====> Test set loss: 198.8831
====> Epoch: 251 Average loss: 204.1595
====> Test set loss: 197.9614


====> Epoch: 252 Average loss: 204.0850
====> Test set loss: 199.1685
====> Epoch: 253 Average loss: 204.3019
====> Test set loss: 199.0998
====> Epoch: 254 Average loss: 204.7704
====> Test set loss: 197.3820
====> Epoch: 255 Average loss: 203.9992
====> Test set loss: 198.1915
====> Epoch: 256 Average loss: 204.2348
====> Test set loss: 200.6352
====> Epoch: 257 Average loss: 204.2155


====> Test set loss: 198.9185
====> Epoch: 258 Average loss: 203.9610
====> Test set loss: 197.3048
====> Epoch: 259 Average loss: 203.8997
====> Test set loss: 200.9721
====> Epoch: 260 Average loss: 204.1559
====> Test set loss: 197.4846
====> Epoch: 261 Average loss: 204.0957
====> Test set loss: 197.6391
====> Epoch: 262 Average loss: 203.8165
====> Test set loss: 197.6161


====> Epoch: 263 Average loss: 203.9195
====> Test set loss: 198.2164
====> Epoch: 264 Average loss: 204.0805
====> Test set loss: 200.5243
====> Epoch: 265 Average loss: 203.8274
====> Test set loss: 197.6702
====> Epoch: 266 Average loss: 204.3149
====> Test set loss: 196.3673
====> Epoch: 267 Average loss: 203.7004
====> Test set loss: 198.0727
====> Epoch: 268 Average loss: 203.6882
====> Test set loss: 199.2057


====> Epoch: 269 Average loss: 204.0039
====> Test set loss: 198.5693
====> Epoch: 270 Average loss: 203.8005
====> Test set loss: 197.7100
====> Epoch: 271 Average loss: 204.0320
====> Test set loss: 199.3712
====> Epoch: 272 Average loss: 203.5954
====> Test set loss: 196.6521
====> Epoch: 273 Average loss: 203.6749
====> Test set loss: 197.6862


====> Epoch: 274 Average loss: 203.7529
====> Test set loss: 199.9646
====> Epoch: 275 Average loss: 203.8136
====> Test set loss: 196.9176
====> Epoch: 276 Average loss: 203.8999
====> Test set loss: 197.7784
====> Epoch: 277 Average loss: 203.9208
====> Test set loss: 197.1767
====> Epoch: 278 Average loss: 203.3364
====> Test set loss: 198.8774
====> Epoch: 279 Average loss: 203.2784
====> Test set loss: 198.7787


====> Epoch: 280 Average loss: 203.5376
====> Test set loss: 199.8368
====> Epoch: 281 Average loss: 203.7724
====> Test set loss: 196.2309
====> Epoch: 282 Average loss: 203.1395
====> Test set loss: 198.3470
====> Epoch: 283 Average loss: 203.5736
====> Test set loss: 197.6509
====> Epoch: 284 Average loss: 203.4182
====> Test set loss: 197.3774
====> Epoch: 285 Average loss: 203.2602
====> Test set loss: 196.8402


====> Epoch: 286 Average loss: 203.5302
====> Test set loss: 196.9290
====> Epoch: 287 Average loss: 203.4544
====> Test set loss: 198.3167
====> Epoch: 288 Average loss: 203.6065
====> Test set loss: 196.5107
====> Epoch: 289 Average loss: 203.4126
====> Test set loss: 200.6701
====> Epoch: 290 Average loss: 203.5543
====> Test set loss: 198.4745


====> Epoch: 291 Average loss: 203.0458
====> Test set loss: 199.0848
====> Epoch: 292 Average loss: 203.3259
====> Test set loss: 198.1781
====> Epoch: 293 Average loss: 203.3666
====> Test set loss: 197.1352
====> Epoch: 294 Average loss: 203.5073
====> Test set loss: 197.8107
====> Epoch: 295 Average loss: 203.2222
====> Test set loss: 197.0547
====> Epoch: 296 Average loss: 203.1878
====> Test set loss: 198.4695


====> Epoch: 297 Average loss: 203.0388
====> Test set loss: 196.6735
====> Epoch: 298 Average loss: 202.9547
====> Test set loss: 196.9222
====> Epoch: 299 Average loss: 203.1005
====> Test set loss: 198.0530
====> Epoch: 300 Average loss: 203.1673
====> Test set loss: 201.2162
====> Epoch: 301 Average loss: 202.7548
====> Test set loss: 196.5969
====> Epoch: 302 Average loss: 202.9621
====> Test set loss: 194.9736


====> Epoch: 303 Average loss: 202.5770
====> Test set loss: 199.5712
====> Epoch: 304 Average loss: 203.3870
====> Test set loss: 196.1829
====> Epoch: 305 Average loss: 202.8210
====> Test set loss: 196.8751
====> Epoch: 306 Average loss: 203.1806
====> Test set loss: 196.1509
====> Epoch: 307 Average loss: 202.5711
====> Test set loss: 197.5439


====> Epoch: 308 Average loss: 202.8203
====> Test set loss: 198.6254
====> Epoch: 309 Average loss: 203.0935
====> Test set loss: 199.4469
====> Epoch: 310 Average loss: 202.7018
====> Test set loss: 195.7507
====> Epoch: 311 Average loss: 203.0906
====> Test set loss: 199.0162
====> Epoch: 312 Average loss: 202.7601
====> Test set loss: 198.6202
====> Epoch: 313 Average loss: 202.6122
====> Test set loss: 197.9669


====> Epoch: 314 Average loss: 202.8200
====> Test set loss: 201.0357
====> Epoch: 315 Average loss: 202.8447
====> Test set loss: 196.7811
====> Epoch: 316 Average loss: 202.6301
====> Test set loss: 196.9065
====> Epoch: 317 Average loss: 202.4970
====> Test set loss: 196.9852
====> Epoch: 318 Average loss: 202.8495
====> Test set loss: 195.7923


====> Epoch: 319 Average loss: 202.6939
====> Test set loss: 198.5510
====> Epoch: 320 Average loss: 202.3139
====> Test set loss: 195.6029
====> Epoch: 321 Average loss: 202.5506
====> Test set loss: 197.7145
====> Epoch: 322 Average loss: 202.6477
====> Test set loss: 198.5246
====> Epoch: 323 Average loss: 202.5648
====> Test set loss: 196.2910
====> Epoch: 324 Average loss: 202.4297
====> Test set loss: 194.5816


====> Epoch: 325 Average loss: 202.9940
====> Test set loss: 196.5480
====> Epoch: 326 Average loss: 202.5484
====> Test set loss: 198.8304
====> Epoch: 327 Average loss: 202.2378
====> Test set loss: 194.8911
====> Epoch: 328 Average loss: 202.4386
====> Test set loss: 197.0796
====> Epoch: 329 Average loss: 202.2293
====> Test set loss: 196.0821
====> Epoch: 330 Average loss: 202.2537
====> Test set loss: 195.9781


====> Epoch: 331 Average loss: 202.5389
====> Test set loss: 197.8848
====> Epoch: 332 Average loss: 202.2311
====> Test set loss: 197.1828
====> Epoch: 333 Average loss: 202.6309
====> Test set loss: 195.6359
====> Epoch: 334 Average loss: 202.4229
====> Test set loss: 196.8601
====> Epoch: 335 Average loss: 202.0768
====> Test set loss: 196.8953


====> Epoch: 336 Average loss: 202.6811
====> Test set loss: 196.1925
====> Epoch: 337 Average loss: 202.2517
====> Test set loss: 196.0592
====> Epoch: 338 Average loss: 201.9840
====> Test set loss: 195.2859
====> Epoch: 339 Average loss: 202.3600
====> Test set loss: 195.2525
====> Epoch: 340 Average loss: 202.4517
====> Test set loss: 194.6867
====> Epoch: 341 Average loss: 202.0606
====> Test set loss: 196.9772


====> Epoch: 342 Average loss: 202.1822
====> Test set loss: 196.6707


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



====> Epoch: 417 Average loss: 200.6700
====> Test set loss: 195.1790
====> Epoch: 418 Average loss: 200.6613
====> Test set loss: 196.2136
====> Epoch: 419 Average loss: 200.6839
====> Test set loss: 193.6550
====> Epoch: 420 Average loss: 201.2488
====> Test set loss: 197.7774
====> Epoch: 421 Average loss: 201.0406
====> Test set loss: 193.7495
====> Epoch: 422 Average loss: 200.7593
====> Test set loss: 195.9652


====> Epoch: 423 Average loss: 201.0499
====> Test set loss: 197.0907
====> Epoch: 424 Average loss: 201.0616
====> Test set loss: 193.0602
====> Epoch: 425 Average loss: 200.6570
====> Test set loss: 196.4750
====> Epoch: 426 Average loss: 200.7326
====> Test set loss: 192.6299
====> Epoch: 427 Average loss: 200.5420
====> Test set loss: 195.0120


====> Epoch: 428 Average loss: 201.5154
====> Test set loss: 194.0026
====> Epoch: 429 Average loss: 200.5837
====> Test set loss: 192.8054
====> Epoch: 430 Average loss: 201.0561
====> Test set loss: 193.2554
====> Epoch: 431 Average loss: 201.0456
====> Test set loss: 195.1693
====> Epoch: 432 Average loss: 200.4225
====> Test set loss: 194.5179
====> Epoch: 433 Average loss: 200.4027
====> Test set loss: 195.4787


====> Epoch: 434 Average loss: 201.0636


KeyboardInterrupt: 