In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as utils 

import cv2
import glob

from PIL import Image
from torchvision import transforms, utils
import random

In [4]:
mnist_test_input = np.load("CIFAR10/testImages.npy")

mnist_train_input = np.load("CIFAR10/trainImages.npy")




train_input = torch.Tensor(mnist_train_input.reshape(mnist_train_input.shape[0], 3, 32, 32)/128.0 - 1.0)#transform_to_tensor( transform, mnist_train_input)
test_input  = torch.Tensor(mnist_test_input.reshape(mnist_test_input.shape[0], 3, 32, 32)/128.0 - 1.0)#transform_to_tensor( transform, mnist_test_input)


In [5]:
# motivated by https://github.com/pytorch/examples/blob/master/dcgan/main.py
class Generator(nn.Module): 
    # input_size -> 28 x 28
    def __init__(self, input_size, channels):
        super(Generator, self).__init__()
        
        
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(     input_size, channels * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(channels * 8),
            nn.ReLU(True),
            # state size. (channels*8) x 4 x 4
            nn.ConvTranspose2d(channels * 8, channels * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(channels * 4),
            nn.ReLU(True),
            # state size. (channels*4) x 8 x 8
            nn.ConvTranspose2d(channels * 4, channels * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(channels * 2),
            nn.ReLU(True),
            # state size. (channels*2) x 16 x 16
            nn.ConvTranspose2d(channels * 2, 3, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. 1 x 32 x 32
        )
    
    def forward(self, x):       
        return self.main(x)
    

class Discriminator(nn.Module):
    def __init__(self, channels):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # input is 1 x 32 x 32
            nn.Conv2d(3, channels * 2, 4, 2, 1, bias=False),
            #nn.BatchNorm2d(channels * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (channels*2) x 16 x 16
            nn.Conv2d(channels * 2, channels * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(channels * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (channels*4) x 8 x 8
            nn.Conv2d(channels * 4, channels * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(channels * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (channels*8) x 4 x 4
            nn.Conv2d(channels * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )
        
    def forward(self, x):       
        return self.main(x)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")    

chanNum = 32
nz = 100
    
netG = Generator(nz, chanNum).to(device)
netD = Discriminator(chanNum).to(device)
batch_size = 500
epoch = 300

trlr = 0.0001

optimizerD = optim.Adam(netD.parameters(), lr=trlr)
optimizerG = optim.Adam(netG.parameters(), lr=trlr)

criterion = nn.BCELoss()#F.binary_cross_entropy_with_logits

counter = 0


fn = torch.randn(1, nz, 1, 1, device=device)
cpu = torch.device("cpu")


for lll in range(epoch):
    print("epoch :" + str(lll + 1) )
    
    a = train_input.to(cpu).numpy()
    np.random.shuffle(a)
    train_input = torch.Tensor(a).to(device)
    
    for i in range(0, train_input.shape[0], batch_size):
        
    
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        
        data = train_input[i:i + batch_size]
        
        netD.zero_grad()
        real_cpu = data.to(device)
        
        label = torch.full((batch_size,), 1, device=device)
        label = label.view(-1, 1, 1, 1)
        
        output = netD(real_cpu)
        
        #output = output.view(-1)
        
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(0)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(1)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()
        
        if i%(batch_size*10) == 0:
            print(errD.item(), errG.item())
            
    
    output_sample = netG(fn).to(cpu).detach().numpy().reshape(3, 32, 32)   
        
    output_sample_R = output_sample[0].reshape(32, 32, 1)
    output_sample_G = output_sample[1].reshape(32, 32, 1)
    output_sample_B = output_sample[2].reshape(32, 32, 1)
        
    output_sample = np.concatenate([output_sample_R, output_sample_G, output_sample_B], axis = 2)
    
    
    
    print(np.max(output_sample))
    print(np.min(output_sample))
    cv2.imwrite("output/" + str(counter) + ".jpg",(output_sample + 1) * 128)
    counter = counter + 1
    


epoch :1
1.3721816539764404 0.8305067420005798
1.0921283960342407 1.0568556785583496
0.9667689800262451 1.3049755096435547
0.7540966272354126 1.806807518005371
0.4109864830970764 2.6540119647979736
0.5040451884269714 2.4388279914855957
0.5147438049316406 2.7157437801361084
0.2919873595237732 3.4333441257476807
0.24003596603870392 3.458007335662842
0.19190123677253723 3.587157964706421
0.985793
-0.9838035
epoch :2
0.1811167597770691 3.915347099304199
0.242074653506279 3.8507332801818848
0.1619994342327118 4.013177871704102
0.11705789715051651 4.43682336807251
0.10431542992591858 4.175917148590088
0.11075007915496826 4.0650153160095215
0.07478498667478561 4.668224811553955
0.0669635757803917 4.22370719909668
0.08933380246162415 3.8646533489227295
0.08205150067806244 4.11275053024292
0.95546234
-0.9907515
epoch :3
0.08684545755386353 4.244515419006348
0.07696449011564255 4.228545665740967
0.06835443526506424 4.246151447296143
0.05851692706346512 4.552626132965088
0.0496147982776165 4.4319

0.0319831408560276 6.538575649261475
0.019833644852042198 8.137663841247559
0.0227101631462574 7.699166297912598
0.02938196063041687 7.229325294494629
0.012531884014606476 7.365367889404297
0.024419451132416725 7.017804145812988
0.02667050063610077 6.834583282470703
0.9996765
-0.99785954
epoch :41
0.021881671622395515 6.146587371826172
0.027798902243375778 6.722303867340088
0.02083549275994301 7.361138820648193
0.031261444091796875 6.580589294433594
0.030887391418218613 6.908348560333252
0.02799331396818161 7.094851493835449
0.02574852481484413 7.680047988891602
0.02452632412314415 6.673507213592529
0.013848491944372654 6.727615833282471
0.018680747598409653 7.032238483428955
0.9988257
-0.99701524
epoch :42
0.014651878736913204 7.1739277839660645
0.01628519594669342 6.9967875480651855
0.03158891201019287 7.116988658905029
0.025028564035892487 6.953043460845947
0.0167701318860054 7.075355529785156
0.021570954471826553 6.826303482055664
0.012630992569029331 7.397534370422363
0.0114118270

0.021743331104516983 6.655756950378418
0.04247589036822319 9.165924072265625
0.009117946028709412 7.915401458740234
0.010354312136769295 7.672308444976807
0.9993817
-0.9999289
epoch :80
0.04197125881910324 8.953377723693848
0.010246324352920055 7.603000164031982
0.04106676205992699 7.232892990112305
0.03167461231350899 6.434295177459717
0.04240650683641434 6.775207996368408
0.027831990271806717 7.123608589172363
0.012758694589138031 7.838242053985596
0.02104315161705017 6.381778717041016
0.011186454445123672 7.399298667907715
0.01632087305188179 7.9620490074157715
0.9998498
-0.9999664
epoch :81
0.06058500334620476 7.4756035804748535
0.05070018395781517 7.074698448181152
0.022147998213768005 7.23087215423584
0.017065610736608505 7.063760757446289
0.020776528865098953 7.083844184875488
0.022606322541832924 6.635052680969238
0.020973503589630127 7.24269962310791
0.029234640300273895 6.322280406951904
0.037816256284713745 7.265191078186035
0.03108658641576767 6.63020133972168
0.99920803
-0

0.045381106436252594 6.902756214141846
0.99999124
-0.999773
epoch :119
0.06344807147979736 6.049901008605957
0.02910863235592842 7.90648889541626
0.07409214973449707 5.882226943969727
0.04301435500383377 7.248344898223877
0.041028596460819244 7.161391735076904
0.0639963299036026 6.21159553527832
0.05145709216594696 6.666985034942627
0.06157347559928894 5.94896125793457
0.04119129106402397 6.770534515380859
0.054476942867040634 6.017340183258057
0.9998563
-0.99999756
epoch :120
0.057458944618701935 6.018069744110107
0.0735093355178833 8.111467361450195
0.04768306389451027 5.3199992179870605
0.10030637681484222 6.019060134887695
0.09315904974937439 6.160342693328857
0.08809076994657516 5.947210788726807
0.06253018975257874 6.619600772857666
0.04945506155490875 6.562094688415527
0.03470974415540695 8.374733924865723
0.07363662868738174 6.723249912261963
0.9999819
-0.99936813
epoch :121
0.05549377575516701 6.965423107147217
0.038801588118076324 7.468525409698486
0.16862916946411133 7.00990

0.05827245116233826 5.720331192016602
0.060866475105285645 6.164531230926514
0.06558012962341309 5.40336799621582
0.031780194491147995 6.608888626098633
0.06345383077859879 4.020276069641113
0.9994936
-0.9980756
epoch :159
0.11001457273960114 5.047728538513184
0.08133739233016968 5.734658241271973
0.04658080264925957 5.48207950592041
0.02922132797539234 6.087210655212402
0.18873873353004456 4.985194683074951
0.04432370886206627 6.423791408538818
0.07813719660043716 6.2093186378479
0.08540699630975723 6.157628059387207
0.04564676433801651 5.1150617599487305
0.04271090403199196 6.677136421203613
0.9999875
-0.9996556
epoch :160
0.1511111855506897 5.1605544090271
0.05178075283765793 7.144436836242676
0.0855308547616005 7.817714691162109
0.051585033535957336 6.128313064575195
0.12978911399841309 5.6204400062561035
0.022555438801646233 7.054366111755371
0.06958325952291489 5.103635787963867
0.07962334901094437 4.825456142425537
0.07318615913391113 6.019903182983398
0.05764653533697128 5.9996

0.02745141088962555 6.011499881744385
0.03302355483174324 6.1224470138549805
0.03185373544692993 6.411496162414551
0.0631304681301117 5.580774784088135
0.03242143243551254 6.491003036499023
0.05067933723330498 5.427751541137695
0.043736476451158524 5.205402851104736
0.04676260054111481 5.978537559509277
0.08322256803512573 5.353850841522217
0.9999922
-0.99842435
epoch :199
0.10459408909082413 8.017783164978027
0.044949859380722046 5.18049430847168
0.03656242415308952 6.614071846008301
0.0751219093799591 5.856424808502197
0.049453556537628174 5.945441722869873
0.08652905374765396 4.721554279327393
0.16945227980613708 5.6758575439453125
0.0567963607609272 6.187199115753174
0.10943682491779327 5.789163112640381
0.0576409175992012 6.391719818115234
0.9999808
-0.9908975
epoch :200
0.03773210942745209 6.196399211883545
0.04385971650481224 6.198055744171143
0.07771310210227966 5.321039199829102
0.13080160319805145 5.824838638305664
0.06756193935871124 7.190119743347168
0.06038479506969452 4.9

0.054704781621694565 5.79271936416626
0.09154470264911652 4.938061714172363
0.99996555
-0.9969893
epoch :238
0.06547819077968597 5.427984237670898
0.04438888281583786 6.240123271942139
0.09334685653448105 5.575257778167725
0.048607274889945984 6.3024187088012695
0.04928700625896454 6.588996887207031
0.09103616327047348 5.369673252105713
0.09437458217144012 4.874366283416748
0.07050972431898117 5.216878414154053
0.05198992043733597 6.403525352478027
0.05321358144283295 5.242717742919922
0.99969673
-0.99920416
epoch :239
0.12620776891708374 4.536661624908447
0.05090548098087311 5.710284233093262
0.05287438631057739 6.482214450836182
0.11382569372653961 4.7259392738342285
0.07005788385868073 5.296439170837402
0.0669974610209465 4.947232246398926
0.04067530483007431 5.633033752441406
0.049750205129384995 5.617534637451172
0.1785895824432373 4.451351642608643
0.08180249482393265 5.3009538650512695
0.99986786
-0.9992569
epoch :240
0.052886154502630234 6.640350818634033
0.07674521207809448 4.

0.10790103673934937 5.037833213806152
0.1687641143798828 5.023283004760742
0.07816406339406967 4.475741863250732
0.08143773674964905 7.495866298675537
0.3060069978237152 6.4788079261779785
0.41365551948547363 8.306564331054688
0.9996396
-0.9925043
epoch :278
0.11284049600362778 7.6622724533081055
0.05427724868059158 6.000535011291504
0.21933889389038086 3.565422534942627
0.17736080288887024 4.696997165679932
0.08540942519903183 3.9376368522644043
0.2063502073287964 5.626060485839844
0.029254067689180374 7.525140762329102
0.16144193708896637 5.258558750152588
0.12364242970943451 4.839536190032959
0.16789346933364868 4.984260559082031
0.9999745
-0.996267
epoch :279
0.13113997876644135 4.36065149307251
0.17076213657855988 4.73642110824585
0.4367619752883911 6.254117488861084
0.02810649946331978 6.175607204437256
0.0830635279417038 5.580404758453369
0.1367352157831192 4.955729007720947
0.10495270788669586 5.252305030822754
0.07963666319847107 5.418091773986816
0.14718255400657654 3.9205868

In [12]:
cpu = torch.device( "cpu")    

noise = torch.randn(1, nz, 1, 1, device=device)


fake = netG(noise)

fake = fake.view(32, 32)

fakenp = fake.to(cpu).detach().numpy()



cv2.imwrite( "example.jpg", (fakenp + 1.0) * 128 )

True

In [14]:
torch.save(netG.state_dict(), "netG_CFIAR10_100")

