In [2]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image

In [10]:
# Run on the default CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [11]:
# Hyper-parameters and output location for the GAN run.
latent_size = 64        # length of the noise vector z fed to the generator
hidden_size = 256       # width of the hidden layers in both D and G
image_size = 28 * 28    # MNIST images are flattened to 784 features
epochs = 200
batch_size = 100
sample_dir = 'samples'  # where real/fake image grids are written

# Seed torch's RNGs (all devices) so weight init, data shuffling and noise
# sampling repeat across runs (GPU kernels may still add small nondeterminism).
seed = 0
torch.manual_seed(seed)

In [12]:
# Create the directory for sample image grids. exist_ok=True keeps the cell
# idempotent on re-runs without a separate exists() check, which could race.
os.makedirs(sample_dir, exist_ok=True)

In [13]:
# ToTensor gives pixels in [0, 1]; Normalize(0.5, 0.5) rescales them to [-1, 1]
# so real images share the range of the generator's Tanh output.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]
)

# MNIST training split, downloaded into ../data/ on first use.
mnist = torchvision.datasets.MNIST(
    root='../data/', train=True, download=True, transform=transform,
)

# Reshuffled every epoch; batch_size comes from the hyper-parameter cell.
data_loader = torch.utils.data.DataLoader(
    mnist, batch_size=batch_size, shuffle=True,
)

In [14]:
# Discriminator: flattened image -> single score squashed to (0, 1) by Sigmoid,
# read as the probability that the input is a real image.
D = nn.Sequential(
    nn.Linear(image_size, hidden_size), nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, hidden_size), nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1), nn.Sigmoid(),
)

# Generator: latent noise vector -> flattened image with values in [-1, 1] (Tanh).
G = nn.Sequential(
    nn.Linear(latent_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, hidden_size), nn.ReLU(),
    nn.Linear(hidden_size, image_size), nn.Tanh(),
)

In [15]:
# Move both networks' parameters onto the selected device (D first, then G).
D, G = D.to(device), G.to(device)

In [16]:
# Binary cross-entropy on D's Sigmoid output; the training loop uses target 1
# for real images and 0 for generated ones.
criterion = nn.BCELoss()

# One Adam optimizer per network (same learning rate), so each step updates
# only the weights of the network currently being trained.
d_optimizer, g_optimizer = (torch.optim.Adam(net.parameters(), lr=0.0002)
                            for net in (D, G))

In [17]:
def denorm(x):
    """Map a tensor from the normalized [-1, 1] range back to [0, 1].

    Inverts ``Normalize(mean=0.5, std=0.5)`` so tensors can be written as
    images; anything outside the range is clamped to [0, 1].
    """
    shifted = x.add(1).div(2)
    return shifted.clamp(min=0, max=1)


def reset_grad():
    """Clear accumulated gradients in both the discriminator and generator optimizers."""
    for opt in (d_optimizer, g_optimizer):
        opt.zero_grad()

In [18]:
total_step = len(data_loader)
for epoch in range(epochs):
    for i, (images, _) in enumerate(data_loader):
        # Size everything by the actual batch length: the last batch is smaller
        # whenever batch_size does not divide the dataset size.
        n = images.size(0)
        images = images.reshape(n, -1).to(device)

        # Allocate directly on the target device (no host allocation + copy).
        real_labels = torch.ones(n, 1, device=device)
        fake_labels = torch.zeros(n, 1, device=device)

        # ---- Discriminator step: minimise -log D(x) - log(1 - D(G(z))) ----
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        z = torch.randn(n, latent_size, device=device)
        fake_images = G(z)
        # detach(): the discriminator loss must not backpropagate through G.
        outputs = D(fake_images.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        # ---- Generator step: minimise -log D(G(z)) (labels flipped to "real") ----
        z = torch.randn(n, latent_size, device=device)
        fake_images = G(z)
        outputs = D(fake_images)

        g_loss = criterion(outputs, real_labels)

        reset_grad()
        g_loss.backward()
        g_optimizer.step()

        if (i + 1) % 200 == 0:
            print(epoch, epochs, i + 1, total_step, d_loss.item(), g_loss.item(),
                  real_score.mean().item(), fake_score.mean().item())

    # Save one grid of real images (last batch of the first epoch) for comparison.
    if epoch == 0:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

    # Save the generator's samples from the last batch of this epoch.
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images.detach()), os.path.join(
        sample_dir, 'fake_images-{}.png'.format(epoch + 1)))

torch.save(G.state_dict(), 'G.ckpt')
torch.save(D.state_dict(), 'D.ckpt')

0 200 200 600 0.059341445565223694 4.109435558319092 0.9890259504318237 0.04678564518690109
0 200 400 600 0.042105838656425476 5.756846904754639 0.9871971011161804 0.027017049491405487
0 200 600 600 0.03686777874827385 5.6250691413879395 0.9893426299095154 0.025321951135993004
1 200 200 600 0.12100980430841446 4.588721752166748 0.9390741586685181 0.03446531295776367
1 200 400 600 0.08876954019069672 5.021081447601318 0.9662787914276123 0.03780444711446762
1 200 600 600 0.14142343401908875 4.035740852355957 0.9595312476158142 0.08282560855150223
2 200 200 600 0.14945241808891296 4.590148448944092 0.9379166960716248 0.03772294521331787
2 200 400 600 0.4698045551776886 2.5062475204467773 0.8366687297821045 0.14511968195438385
2 200 600 600 0.16813039779663086 5.337667465209961 0.9421612620353699 0.039320170879364014
3 200 200 600 0.48595303297042847 3.6646909713745117 0.8544695973396301 0.1589338779449463
3 200 400 600 0.5057373642921448 4.333642482757568 0.8871980309486389 0.212190583348

30 200 200 600 0.18228980898857117 4.733157157897949 0.9311180114746094 0.05302654951810837
30 200 400 600 0.5174337029457092 3.652967929840088 0.8248844742774963 0.0910499170422554
30 200 600 600 0.4020730257034302 3.590503454208374 0.8731974363327026 0.07364524900913239
31 200 200 600 0.3296677768230438 4.42137336730957 0.877406895160675 0.037781428545713425
31 200 400 600 0.4397886395454407 3.9752755165100098 0.8315312266349792 0.0681825652718544
31 200 600 600 0.2618735730648041 3.4853813648223877 0.9189217686653137 0.09897331893444061
32 200 200 600 0.4391173720359802 3.8182687759399414 0.8661733865737915 0.10481429845094681
32 200 400 600 0.3261282444000244 4.712432861328125 0.9335322976112366 0.15373462438583374
32 200 600 600 0.3549455404281616 4.306268215179443 0.8541050553321838 0.0915822759270668
33 200 200 600 0.45095279812812805 4.233772277832031 0.9122800827026367 0.17620258033275604
33 200 400 600 0.5479370951652527 2.7826931476593018 0.9064078330993652 0.213873788714408

60 200 200 600 0.41765639185905457 3.301039695739746 0.8845176696777344 0.17305056750774384
60 200 400 600 0.4538530111312866 3.715430974960327 0.8454084396362305 0.16543518006801605
60 200 600 600 0.35122162103652954 2.5153167247772217 0.8667677044868469 0.12375468760728836
61 200 200 600 0.5887960195541382 2.251800775527954 0.8411731719970703 0.21013770997524261
61 200 400 600 0.42723312973976135 3.758854866027832 0.869932234287262 0.17835241556167603
61 200 600 600 0.8793547749519348 2.110034465789795 0.7456060647964478 0.23229442536830902
62 200 200 600 0.6246877312660217 2.263483762741089 0.9074590802192688 0.31346800923347473
62 200 400 600 0.44499754905700684 3.571059465408325 0.8463990688323975 0.12222184985876083
62 200 600 600 0.47203880548477173 2.7066586017608643 0.8368905186653137 0.15410026907920837
63 200 200 600 0.6079087257385254 2.644644260406494 0.7667187452316284 0.15944838523864746
63 200 400 600 0.4442415237426758 2.3310956954956055 0.8135270476341248 0.1198848411

90 200 400 600 0.669735312461853 1.9789228439331055 0.7055895328521729 0.15411829948425293
90 200 600 600 0.8229697942733765 1.425915241241455 0.680898904800415 0.21355636417865753
91 200 200 600 0.7348275184631348 1.3934494256973267 0.7929279208183289 0.3034500181674957
91 200 400 600 0.8593859672546387 1.760433554649353 0.6713584065437317 0.19780784845352173
91 200 600 600 0.6675841808319092 2.2012507915496826 0.7531459927558899 0.1890282928943634
92 200 200 600 0.7077271938323975 1.5364959239959717 0.7639878392219543 0.2091187983751297
92 200 400 600 0.8478935956954956 2.344122886657715 0.7463641166687012 0.2807302474975586
92 200 600 600 0.8352941274642944 2.4041075706481934 0.7322332859039307 0.2509329915046692
93 200 200 600 0.6216518878936768 2.866236686706543 0.7581455707550049 0.17030860483646393
93 200 400 600 0.7828507423400879 2.2062408924102783 0.9113222360610962 0.38373246788978577
93 200 600 600 0.767910897731781 2.707284927368164 0.6655935645103455 0.14267222583293915
9

120 200 400 600 0.7930848598480225 1.9271759986877441 0.7052201628684998 0.23313495516777039
120 200 600 600 0.6931952238082886 1.6495623588562012 0.7368989586830139 0.23634590208530426
121 200 200 600 0.7318360805511475 2.2262253761291504 0.7072795033454895 0.19512759149074554
121 200 400 600 0.9577162861824036 2.0001251697540283 0.6827889680862427 0.2780825197696686
121 200 600 600 0.8111404776573181 1.8273075819015503 0.7561988830566406 0.28473109006881714
122 200 200 600 0.9379547834396362 1.838623046875 0.6881136894226074 0.29489070177078247
122 200 400 600 0.9514013528823853 1.521095633506775 0.6937612891197205 0.3165075182914734
122 200 600 600 1.0364303588867188 1.4438958168029785 0.6565397381782532 0.3023958206176758
123 200 200 600 0.7911785244941711 1.7292609214782715 0.756639301776886 0.2829208970069885
123 200 400 600 0.7174006104469299 1.6532968282699585 0.8324599862098694 0.3207569718360901
123 200 600 600 0.834651529788971 1.1601238250732422 0.7998014688491821 0.3288826

150 200 400 600 0.9411622285842896 1.837670087814331 0.6351266503334045 0.26700541377067566
150 200 600 600 0.7993757724761963 1.678339958190918 0.7911180853843689 0.33114904165267944
151 200 200 600 0.9912204146385193 1.3445875644683838 0.718887984752655 0.3849577307701111
151 200 400 600 1.0099523067474365 1.4193458557128906 0.7150618433952332 0.347490519285202
151 200 600 600 1.090887427330017 1.1734235286712646 0.7177258133888245 0.4020940363407135
152 200 200 600 0.8542239665985107 1.6474957466125488 0.6789631247520447 0.25520774722099304
152 200 400 600 0.9199838638305664 1.6093770265579224 0.6843539476394653 0.2679237127304077
152 200 600 600 0.9865818619728088 1.3369921445846558 0.753254234790802 0.3969036042690277
153 200 200 600 1.1207208633422852 1.5303139686584473 0.5604654550552368 0.23217135667800903
153 200 400 600 1.1126536130905151 1.6397864818572998 0.6201470494270325 0.2966502010822296
153 200 600 600 0.8569012880325317 1.6187514066696167 0.6768829226493835 0.2686634

180 200 400 600 1.0112946033477783 1.486090064048767 0.626899242401123 0.29892203211784363
180 200 600 600 0.9590332508087158 1.9522829055786133 0.6817440390586853 0.2851439118385315
181 200 200 600 1.1348915100097656 1.675836443901062 0.5713383555412292 0.23763762414455414
181 200 400 600 0.886161208152771 1.5686756372451782 0.6503607034683228 0.23135071992874146
181 200 600 600 0.8456092476844788 1.4933390617370605 0.763116717338562 0.32561659812927246
182 200 200 600 0.9952126145362854 1.2415238618850708 0.6812238693237305 0.3283674418926239
182 200 400 600 1.059198021888733 1.5205689668655396 0.6314235329627991 0.3095804452896118
182 200 600 600 0.935788631439209 1.7481553554534912 0.62775057554245 0.24876049160957336
183 200 200 600 1.048467993736267 1.525883674621582 0.702978253364563 0.38618430495262146
183 200 400 600 0.8729828596115112 1.4563980102539062 0.7283393740653992 0.3276992738246918
183 200 600 600 0.9079015254974365 1.438686728477478 0.7361348271369934 0.332357108592