In [1]:
import os

import numpy as np

import torch
import torch.nn as nn

from torch.autograd import Variable

import torchvision.utils

from new_data_loader import get_loader
from make_gif import make_gif

from DiscrimiatorModule import Discriminator
from GeneratorModule import Generator

In [2]:
def init_weights(m) :
    """Re-initialise the parameters of a single layer in place.

    Written to be handed to ``Module.apply``, which calls it once for every
    sub-module of a network.

    * ``nn.Conv3d`` / ``nn.ConvTranspose2d`` / ``nn.ConvTranspose3d``:
      weights drawn from N(0, 0.01), bias set to 0.
    * ``nn.BatchNorm2d`` / ``nn.BatchNorm3d``:
      weights drawn from N(1, 0.02), bias set to 0.
    * Every other module type is left untouched.

    Layers built without a bias (``bias=False``) or without affine parameters
    (``affine=False``) expose ``None`` for the missing parameter; those are
    skipped instead of raising ``AttributeError``.

    Args:
        m: the module to initialise.
    """
    conv_types = (nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)
    norm_types = (nn.BatchNorm2d, nn.BatchNorm3d)

    # Exact type match (not isinstance) so subclasses keep their own initialisation.
    name = type(m)

    if name in conv_types :
        weight_mean, weight_std = 0.0, 0.01
    elif name in norm_types :
        weight_mean, weight_std = 1.0, 0.02
    else :
        return

    # Either parameter may legitimately be absent (None) depending on how the layer was constructed.
    if m.weight is not None :
        m.weight.data.normal_(weight_mean, weight_std)
    if m.bias is not None :
        m.bias.data.fill_(0)

In [None]:
# ---- Run configuration ----
# True  : resume from the saved 'D.ckpt' / 'G.ckpt' files.
# False : build fresh networks and initialise them with init_weights.
pre_train = True

batch_size = 64
video_size = 64     # passed to get_loader as image_size
epoch_size = 1000   # last epoch index of the training loop

#check GPU
is_gpu = torch.cuda.is_available()
print(is_gpu)

# `dtype` is the tensor type every model and tensor below is cast to;
# picking the CUDA type is what places them on the GPU.
if is_gpu :
    dtype = torch.cuda.FloatTensor
else :
    dtype = torch.FloatTensor


def _load_on_cpu(storage, location) :
    """map_location hook for torch.load: keep every storage on the CPU."""
    return storage


if pre_train :
    # Deserialise onto the CPU first, then let .type(dtype) decide the final
    # placement. Without map_location a checkpoint written on a GPU machine
    # cannot be restored on a CPU-only machine.
    # NOTE: torch.load unpickles whole module objects, which can execute
    # arbitrary code -- only load checkpoint files you trust.
    D = torch.load('D.ckpt', map_location=_load_on_cpu).type(dtype)
    G = torch.load('G.ckpt', map_location=_load_on_cpu).type(dtype)
else :
    D = Discriminator()
    D = D.type(dtype)

    G = Generator()
    G = G.type(dtype)

    D.apply(init_weights)
    G.apply(init_weights)

# The loss works on raw logits, so D's outputs are fed to it without a sigmoid in this notebook.
criterion = nn.BCEWithLogitsLoss().type(dtype)

d_optimizer = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))

True


In [None]:
# ---- Conditional GAN training loop ----
# Each iteration performs one discriminator update followed by one generator
# update; both networks are saved to disk after every epoch.
data_loader = get_loader(data_path='./dataset', image_size=video_size, batch_size=batch_size, num_workers=2)

# Epoch index to resume counting from.
# NOTE(review): presumably the epoch at which the loaded checkpoints were saved -- adjust when starting fresh.
start_epoch = 384


def _loss_value(loss) :
    """Return a scalar loss as a Python number.

    Newer PyTorch exposes ``.item()`` on 0-dim tensors (where ``.data[0]``
    raises); older releases only support ``.data[0]``.
    """
    if hasattr(loss, 'item') :
        return loss.item()
    return loss.data[0]


for epoch in range(start_epoch, epoch_size + 1) :
    # `step` rather than `iter`, so the builtin iter() is not shadowed.
    for step, (video, y) in enumerate(data_loader) :
        # The last batch of an epoch may be smaller than batch_size.
        local_batch_size = video.size()[0]

        # D's output is viewed as (batch, 2) below, so the targets use the same shape.
        real_labels = Variable(torch.ones(local_batch_size, 2).type(dtype))
        fake_labels = Variable(torch.zeros(local_batch_size, 2).type(dtype))

        # 1. Train Discriminator
        video_data = Variable(video).type(dtype)
        y = Variable(y).type(dtype)

        # Append three trailing singleton axes to the condition and broadcast it
        # to a (batch, 6, 32, 64, 64) volume for the discriminator.
        # NOTE(review): assumes y is (batch, 6) and videos are 32 x 64 x 64 -- confirm against the data loader.
        y_data = y.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        y_data = y_data.expand(local_batch_size, 6, 32, 64, 64)

        # 1-1. Real Video
        outputs = D(video_data, y_data).view(local_batch_size, 2)
        d_loss_real = criterion(outputs, real_labels)

        # 1-2. Fake Video
        z = Variable(torch.randn(local_batch_size, 100) * 0.01).type(dtype)
        # Detach so the discriminator update does not back-propagate through the
        # generator; those gradients would be discarded anyway.
        fake_videos = G(z, y).detach()
        outputs = D(fake_videos, y_data).view(local_batch_size, 2)
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake

        D.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # 2. Train Generator
        # Fresh noise; this time the graph through G is kept so it receives gradients.
        z = Variable(torch.randn(local_batch_size, 100) * 0.01).type(dtype)
        fake_videos = G(z, y)
        outputs = D(fake_videos, y_data).view(local_batch_size, 2)

        # The generator is rewarded when D labels its samples as real.
        g_loss = criterion(outputs, real_labels)

        # Clear both networks: D still accumulates gradients from this backward
        # pass even though only G's optimizer steps.
        D.zero_grad()
        G.zero_grad()

        g_loss.backward()
        g_optimizer.step()

        print('Epoch [%d/%d], Iter [%d/%d], d_loss: %.4f, g_loss: %.4f' % (epoch, epoch_size, step, len(data_loader), _loss_value(d_loss), _loss_value(g_loss)))

    print('Model saving...')

    # Whole-module checkpoints, overwritten every epoch; these are the files
    # a later run reloads when resuming.
    torch.save(D, 'D.ckpt')
    torch.save(G, 'G.ckpt')

Epoch [384/1000], Iter [0/49], d_loss: 0.2224, g_loss: 2.0415
Epoch [384/1000], Iter [1/49], d_loss: 0.1755, g_loss: 6.3966
Epoch [384/1000], Iter [2/49], d_loss: 0.0084, g_loss: 9.1431
Epoch [384/1000], Iter [3/49], d_loss: 0.0066, g_loss: 8.1735
Epoch [384/1000], Iter [4/49], d_loss: 0.0121, g_loss: 7.2475
Epoch [384/1000], Iter [5/49], d_loss: 0.0454, g_loss: 8.9441
Epoch [384/1000], Iter [6/49], d_loss: 0.0042, g_loss: 8.0692
Epoch [384/1000], Iter [7/49], d_loss: 0.0257, g_loss: 8.9309
Epoch [384/1000], Iter [8/49], d_loss: 0.0479, g_loss: 6.7944
Epoch [384/1000], Iter [9/49], d_loss: 0.0291, g_loss: 6.1598
Epoch [384/1000], Iter [10/49], d_loss: 0.0217, g_loss: 8.0372
Epoch [384/1000], Iter [11/49], d_loss: 0.0011, g_loss: 8.1110
Epoch [384/1000], Iter [12/49], d_loss: 0.0019, g_loss: 6.4925
Epoch [384/1000], Iter [13/49], d_loss: 0.0156, g_loss: 6.9862
Epoch [384/1000], Iter [14/49], d_loss: 0.0098, g_loss: 5.9769
Epoch [384/1000], Iter [15/49], d_loss: 0.0238, g_loss: 7.0272
Ep