In [1]:
import os

import numpy as np

import torch
import torch.nn as nn

from torch.autograd import Variable

import torchvision.utils

from new_data_loader import get_loader
from make_gif import make_gif

from DiscrimiatorModule import Discriminator
from GeneratorModule import Generator

In [2]:
def init_weights(m) :
    """Initialise one module's parameters in place; meant for ``Module.apply``.

    * ``nn.Conv3d`` / ``nn.ConvTranspose2d`` / ``nn.ConvTranspose3d``:
      weights ~ N(0, 0.01), bias set to 0.
    * ``nn.BatchNorm2d`` / ``nn.BatchNorm3d``:
      weights ~ N(1, 0.02), bias set to 0.

    Any other module type is left untouched. Matching is on the exact type,
    so subclasses of the layers above are skipped as well.

    Args:
        m: the module handed in by ``Module.apply``.
    """
    kind = type(m)

    # NOTE(review): plain nn.Conv2d is not in this list - confirm that is intentional.
    if kind in (nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d) :
        m.weight.data.normal_(0.0, 0.01)
        # Layers created with bias=False expose bias as None; nothing to reset.
        if m.bias is not None :
            m.bias.data.fill_(0)
    elif kind in (nn.BatchNorm2d, nn.BatchNorm3d) :
        # affine=False batch-norm layers have neither weight nor bias.
        if m.weight is not None :
            m.weight.data.normal_(1.0, 0.02)
        if m.bias is not None :
            m.bias.data.fill_(0)

In [None]:
# ---- Run configuration -------------------------------------------------
# True: resume D and G from the rolling checkpoints 'D2.ckpt' / 'G2.ckpt'.
# False: build fresh networks and initialise them with init_weights.
pre_train = True

batch_size = 64
video_size = 64
epoch_size = 1000

#check GPU
is_gpu = torch.cuda.is_available()
print(is_gpu)

# Tensor type that every model, loss and input below is cast to.
if is_gpu :
    dtype = torch.cuda.FloatTensor
else :
    dtype = torch.FloatTensor

if pre_train :
    # SECURITY: torch.load unpickles whole Module objects and can execute
    # arbitrary code - only load checkpoint files you produced yourself.
    #
    # Keep every storage on the CPU while unpickling so that a checkpoint
    # written on a GPU machine can also be restored on a CPU-only one;
    # .type(dtype) then moves the model to the GPU when one is available.
    def _keep_on_cpu(storage, location) :
        return storage

    D = torch.load('D2.ckpt', map_location=_keep_on_cpu).type(dtype)
    G = torch.load('G2.ckpt', map_location=_keep_on_cpu).type(dtype)
else :
    D = Discriminator()
    D = D.type(dtype)

    G = Generator()
    G = G.type(dtype)

    D.apply(init_weights)
    G.apply(init_weights)

# The loss applies the sigmoid itself, so D is fed to it un-squashed.
criterion = nn.BCEWithLogitsLoss().type(dtype)

# One Adam optimiser per network, same hyper-parameters for both.
d_optimizer = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))

True


In [None]:
data_loader = get_loader(data_path='./dataset', image_size=video_size, batch_size=batch_size, num_workers=2)

# First epoch of this run (hard-coded resume point).
# NOTE(review): presumably the epoch following the one stored in
# D2.ckpt / G2.ckpt - confirm and update whenever training is restarted.
start_epoch = 258


def _loss_value(loss) :
    """Return a scalar loss as a Python float.

    ``loss.data[0]`` only works while the loss is a 1-element tensor; on
    PyTorch versions where it is 0-dimensional that indexing raises
    IndexError. Flattening first works for both layouts.
    """
    return float(loss.data.view(-1)[0])


for epoch in range(start_epoch, epoch_size + 1) :
    # 'step' rather than 'iter' so the builtin iter() is not shadowed.
    for step, (video, y) in enumerate(data_loader) :
        local_batch_size = video.size()[0]

        # Targets match the (batch, 2) view taken of D's output below.
        real_labels = Variable(torch.ones(local_batch_size, 2).type(dtype))
        fake_labels = Variable(torch.zeros(local_batch_size, 2).type(dtype))

        # 1. Train Discriminator
        video_data = Variable(video).type(dtype)
        y = Variable(y).type(dtype)

        # Append three singleton axes to y and broadcast it to the volume
        # that D takes as its conditioning input.
        # assumes y is (batch, 6) so the result is (batch, 6, 32, 64, 64) - TODO confirm
        y_data = torch.unsqueeze(torch.unsqueeze(torch.unsqueeze(y, -1), -1), -1)
        y_data = y_data.expand(local_batch_size, 6, 32, 64, 64)

        # 1-1. Real Video
        outputs = D(video_data, y_data).view(local_batch_size, 2)
        d_loss_real = criterion(outputs, real_labels)

        # 1-2. Fake Video
        z = Variable(torch.randn(local_batch_size, 100) * 0.01).type(dtype)
        # detach(): only D is updated in this step, so do not back-propagate
        # through G (those gradients were discarded before G's update anyway).
        fake_videos = G(z, y).detach()
        outputs = D(fake_videos, y_data).view(local_batch_size, 2)
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake

        D.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # 2. Train Generator
        z = Variable(torch.randn(local_batch_size, 100) * 0.01).type(dtype)
        fake_videos = G(z, y)
        outputs = D(fake_videos, y_data).view(local_batch_size, 2)

        # G is rewarded when D labels its samples as real.
        g_loss = criterion(outputs, real_labels)
        D.zero_grad()
        G.zero_grad()

        g_loss.backward()
        g_optimizer.step()

        print('Epoch [%d/%d], Iter [%d/%d], d_loss: %.4f, g_loss: %.4f' % (epoch, epoch_size, step, len(data_loader), _loss_value(d_loss), _loss_value(g_loss)))

    print('Model saving...')

    # Always refresh the rolling checkpoint that the pre_train path loads,
    # so a resume never falls back to an older epoch.
    torch.save(D, 'D2.ckpt')
    torch.save(G, 'G2.ckpt')

    # Additionally keep a numbered snapshot every 100 epochs.
    if epoch % 100 == 0 :
        torch.save(D, 'D_' + str(epoch) + '.ckpt')
        torch.save(G, 'G_' + str(epoch) + '.ckpt')

Epoch [258/1000], Iter [0/49], d_loss: 0.0196, g_loss: 2.9811
Epoch [258/1000], Iter [1/49], d_loss: 0.6080, g_loss: 7.1604
Epoch [258/1000], Iter [2/49], d_loss: 0.0280, g_loss: 9.4297
Epoch [258/1000], Iter [3/49], d_loss: 0.0635, g_loss: 9.9971
Epoch [258/1000], Iter [4/49], d_loss: 1.0498, g_loss: 3.4247
Epoch [258/1000], Iter [5/49], d_loss: 0.0452, g_loss: 2.6584
Epoch [258/1000], Iter [6/49], d_loss: 0.4451, g_loss: 5.3827
Epoch [258/1000], Iter [7/49], d_loss: 0.0660, g_loss: 8.3923
Epoch [258/1000], Iter [8/49], d_loss: 0.0175, g_loss: 5.4445
Epoch [258/1000], Iter [9/49], d_loss: 0.0449, g_loss: 7.7921
Epoch [258/1000], Iter [10/49], d_loss: 0.0320, g_loss: 5.6586
Epoch [258/1000], Iter [11/49], d_loss: 0.0188, g_loss: 6.5327
Epoch [258/1000], Iter [12/49], d_loss: 0.2696, g_loss: 4.9968
Epoch [258/1000], Iter [13/49], d_loss: 0.0160, g_loss: 4.5912
Epoch [258/1000], Iter [14/49], d_loss: 0.2482, g_loss: 7.2264
Epoch [258/1000], Iter [15/49], d_loss: 0.0908, g_loss: 6.2738
Ep

Epoch [260/1000], Iter [32/49], d_loss: 0.0530, g_loss: 8.5669
Epoch [260/1000], Iter [33/49], d_loss: 0.5933, g_loss: 3.1082
Epoch [260/1000], Iter [34/49], d_loss: 0.3357, g_loss: 7.3612
Epoch [260/1000], Iter [35/49], d_loss: 0.0073, g_loss: 10.4734
Epoch [260/1000], Iter [36/49], d_loss: 0.1463, g_loss: 9.9986
Epoch [260/1000], Iter [37/49], d_loss: 0.0370, g_loss: 7.3659
Epoch [260/1000], Iter [38/49], d_loss: 0.0140, g_loss: 6.7709
Epoch [260/1000], Iter [39/49], d_loss: 0.1209, g_loss: 8.1265
Epoch [260/1000], Iter [40/49], d_loss: 0.0058, g_loss: 8.7497
Epoch [260/1000], Iter [41/49], d_loss: 0.1296, g_loss: 7.3109
Epoch [260/1000], Iter [42/49], d_loss: 0.0153, g_loss: 6.6893
Epoch [260/1000], Iter [43/49], d_loss: 0.0220, g_loss: 6.5138
Epoch [260/1000], Iter [44/49], d_loss: 0.0170, g_loss: 4.5035
Epoch [260/1000], Iter [45/49], d_loss: 0.0111, g_loss: 6.1438
Epoch [260/1000], Iter [46/49], d_loss: 0.0091, g_loss: 5.9378
Epoch [260/1000], Iter [47/49], d_loss: 0.0180, g_loss

Epoch [263/1000], Iter [15/49], d_loss: 0.0390, g_loss: 6.9227
Epoch [263/1000], Iter [16/49], d_loss: 0.0079, g_loss: 6.9530
Epoch [263/1000], Iter [17/49], d_loss: 0.0713, g_loss: 8.1117
Epoch [263/1000], Iter [18/49], d_loss: 0.0086, g_loss: 6.9184
Epoch [263/1000], Iter [19/49], d_loss: 0.1083, g_loss: 6.0773
Epoch [263/1000], Iter [20/49], d_loss: 0.0479, g_loss: 4.7783
Epoch [263/1000], Iter [21/49], d_loss: 0.1523, g_loss: 4.9918
Epoch [263/1000], Iter [22/49], d_loss: 0.0962, g_loss: 6.0984
Epoch [263/1000], Iter [23/49], d_loss: 0.0111, g_loss: 5.6742
Epoch [263/1000], Iter [24/49], d_loss: 0.0616, g_loss: 5.4713
Epoch [263/1000], Iter [25/49], d_loss: 0.0218, g_loss: 6.1650
Epoch [263/1000], Iter [26/49], d_loss: 0.0034, g_loss: 5.8384
Epoch [263/1000], Iter [27/49], d_loss: 0.0095, g_loss: 5.7115
Epoch [263/1000], Iter [28/49], d_loss: 0.0163, g_loss: 5.2961
Epoch [263/1000], Iter [29/49], d_loss: 0.0153, g_loss: 5.3504
Epoch [263/1000], Iter [30/49], d_loss: 0.0389, g_loss:

Epoch [265/1000], Iter [47/49], d_loss: 0.0172, g_loss: 6.2966
Epoch [265/1000], Iter [48/49], d_loss: 0.0067, g_loss: 4.5649
Model saving...
Epoch [266/1000], Iter [0/49], d_loss: 0.0160, g_loss: 5.4025
Epoch [266/1000], Iter [1/49], d_loss: 0.0093, g_loss: 5.3753
Epoch [266/1000], Iter [2/49], d_loss: 0.0527, g_loss: 5.6143
Epoch [266/1000], Iter [3/49], d_loss: 0.0227, g_loss: 5.1602
Epoch [266/1000], Iter [4/49], d_loss: 0.0891, g_loss: 7.3392
Epoch [266/1000], Iter [5/49], d_loss: 0.0076, g_loss: 8.0974
Epoch [266/1000], Iter [6/49], d_loss: 1.7748, g_loss: 0.0192
Epoch [266/1000], Iter [7/49], d_loss: 4.5615, g_loss: 17.1963
Epoch [266/1000], Iter [8/49], d_loss: 2.3972, g_loss: 11.1998
Epoch [266/1000], Iter [9/49], d_loss: 0.0205, g_loss: 6.8033
Epoch [266/1000], Iter [10/49], d_loss: 0.1529, g_loss: 7.2698
Epoch [266/1000], Iter [11/49], d_loss: 0.1253, g_loss: 6.5965
Epoch [266/1000], Iter [12/49], d_loss: 0.3091, g_loss: 6.0900
Epoch [266/1000], Iter [13/49], d_loss: 0.0931,