In [6]:
import os

import numpy as np

import torch
import torch.nn as nn

from torch.autograd import Variable

import torchvision.utils

from new_data_loader import get_loader
from make_gif import make_gif

from DiscrimiatorModule import Discriminator
from GeneratorModule import Generator

In [4]:
def init_weights(m) :
    """Initialise the parameters of one module; intended for ``model.apply(init_weights)``.

    * Conv3d / ConvTranspose2d / ConvTranspose3d: weight ~ N(0, 0.01), bias = 0.
    * BatchNorm2d / BatchNorm3d: weight ~ N(1, 0.02), bias = 0.
    * Any other module type is left untouched.

    Parameters that do not exist (convolutions built with ``bias=False``,
    batch norms built with ``affine=False``) are skipped instead of raising
    AttributeError on ``None``.

    NOTE(review): nn.Conv2d is not in the handled set although its transposed
    and BatchNorm2d counterparts are — confirm that is intentional.
    """
    if isinstance(m, (nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)) :
        weight_mean, weight_std = 0.0, 0.01
    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d)) :
        weight_mean, weight_std = 1.0, 0.02
    else :
        return

    # weight/bias are None when the layer was built without them.
    if m.weight is not None :
        m.weight.data.normal_(weight_mean, weight_std)
    if m.bias is not None :
        m.bias.data.fill_(0)

In [5]:
pre_train = True   # resume from D.ckpt / G.ckpt instead of building fresh models

batch_size = 64
video_size = 64
epoch_size = 1000

# Check for a GPU; every model/tensor below is cast to the matching float type.
is_gpu = torch.cuda.is_available()
print(is_gpu)

dtype = torch.cuda.FloatTensor if is_gpu else torch.FloatTensor


def _load_checkpoint(path) :
    """Load a whole pickled module written with ``torch.save(module, path)``.

    map_location='cpu' lets a checkpoint saved on a GPU machine be loaded on a
    CPU-only one; the caller then casts it with ``.type(dtype)``, which moves
    it back to the GPU when one is available.
    weights_only=False is required on PyTorch >= 2.6, whose default refuses to
    unpickle full modules; releases that do not know the argument raise
    TypeError, in which case the plain call is used.
    NOTE: this unpickles arbitrary objects — only load trusted checkpoint files.
    """
    try :
        return torch.load(path, map_location='cpu', weights_only=False)
    except TypeError :
        return torch.load(path, map_location='cpu')


if pre_train :
    D = _load_checkpoint('D.ckpt').type(dtype)
    G = _load_checkpoint('G.ckpt').type(dtype)
else :
    D = Discriminator().type(dtype)
    G = Generator().type(dtype)

    D.apply(init_weights)
    G.apply(init_weights)

criterion = nn.BCEWithLogitsLoss().type(dtype)

# Same Adam settings for both networks (lr=2e-4, beta1=0.5).
d_optimizer = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))

True


In [6]:
data_loader = get_loader(data_path='./dataset', image_size=video_size, batch_size=batch_size, num_workers=2)

# Epoch to resume from. The checkpoints only store the modules, not the epoch,
# so this must be updated by hand when resuming (pre_train=True).
start_epoch = 384

for epoch in range(start_epoch, epoch_size + 1) :
    for step, (video, y) in enumerate(data_loader) :
        local_batch_size = video.size(0)

        real_labels = torch.ones(local_batch_size, 2).type(dtype)
        fake_labels = torch.zeros(local_batch_size, 2).type(dtype)

        video_data = video.type(dtype)
        y = y.type(dtype)

        # Give y three trailing singleton dims and broadcast it to
        # (batch, 6, 32, video_size, video_size) for D's condition input.
        # NOTE(review): 6 and 32 come from the original hard-coded expand();
        # confirm they match the loader's label width and clip length.
        y_data = y.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        y_data = y_data.expand(local_batch_size, 6, 32, video_size, video_size)

        # 1. Train Discriminator
        # 1-1. Real Video
        outputs = D(video_data, y_data).view(local_batch_size, 2)
        d_loss_real = criterion(outputs, real_labels)

        # 1-2. Fake Video
        z = (torch.randn(local_batch_size, 100) * 0.01).type(dtype)
        fake_videos = G(z, y)
        # detach(): the discriminator update must not backpropagate into G.
        outputs = D(fake_videos.detach(), y_data).view(local_batch_size, 2)
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake

        D.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # 2. Train Generator: fresh noise, scored by the just-updated D and
        #    pushed toward the "real" label.
        z = (torch.randn(local_batch_size, 100) * 0.01).type(dtype)
        fake_videos = G(z, y)
        outputs = D(fake_videos, y_data).view(local_batch_size, 2)
        g_loss = criterion(outputs, real_labels)

        D.zero_grad()
        G.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        # .item(): the losses are 0-dim tensors; indexing them with [0]
        # raises IndexError on PyTorch >= 0.5.
        print('Epoch [%d/%d], Iter [%d/%d], d_loss: %.4f, g_loss: %.4f' % (epoch, epoch_size, step, len(data_loader), d_loss.item(), g_loss.item()))

    print('Model saving...')

    # Whole modules are pickled and overwritten every epoch; the setup cell
    # reloads them when pre_train is True.
    torch.save(D, 'D.ckpt')
    torch.save(G, 'G.ckpt')

Epoch [264/1000], Iter [0/49], d_loss: 0.4987, g_loss: 7.7083
Epoch [264/1000], Iter [1/49], d_loss: 0.0159, g_loss: 9.3092
Epoch [264/1000], Iter [2/49], d_loss: 0.5341, g_loss: 5.5646
Epoch [264/1000], Iter [3/49], d_loss: 0.0230, g_loss: 3.6369
Epoch [264/1000], Iter [4/49], d_loss: 0.1487, g_loss: 7.6690
Epoch [264/1000], Iter [5/49], d_loss: 0.7244, g_loss: 2.4736
Epoch [264/1000], Iter [6/49], d_loss: 0.1684, g_loss: 3.2955
Epoch [264/1000], Iter [7/49], d_loss: 0.2525, g_loss: 8.5487
Epoch [264/1000], Iter [8/49], d_loss: 0.0043, g_loss: 8.7568
Epoch [264/1000], Iter [9/49], d_loss: 0.3591, g_loss: 5.4538
Epoch [264/1000], Iter [10/49], d_loss: 0.0150, g_loss: 4.7448
Epoch [264/1000], Iter [11/49], d_loss: 0.2033, g_loss: 5.9525
Epoch [264/1000], Iter [12/49], d_loss: 0.0252, g_loss: 5.7419
Epoch [264/1000], Iter [13/49], d_loss: 0.0515, g_loss: 4.3458
Epoch [264/1000], Iter [14/49], d_loss: 0.0060, g_loss: 8.1969
Epoch [264/1000], Iter [15/49], d_loss: 0.0219, g_loss: 6.1315
Ep

Epoch [266/1000], Iter [32/49], d_loss: 0.0456, g_loss: 9.2748
Epoch [266/1000], Iter [33/49], d_loss: 0.1080, g_loss: 7.9031
Epoch [266/1000], Iter [34/49], d_loss: 0.0341, g_loss: 8.0723
Epoch [266/1000], Iter [35/49], d_loss: 0.0206, g_loss: 6.2563
Epoch [266/1000], Iter [36/49], d_loss: 0.1554, g_loss: 5.5564
Epoch [266/1000], Iter [37/49], d_loss: 0.1977, g_loss: 7.8398
Epoch [266/1000], Iter [38/49], d_loss: 0.0295, g_loss: 8.8001
Epoch [266/1000], Iter [39/49], d_loss: 0.1030, g_loss: 8.5424
Epoch [266/1000], Iter [40/49], d_loss: 0.0268, g_loss: 6.7837
Epoch [266/1000], Iter [41/49], d_loss: 0.0600, g_loss: 7.0566
Epoch [266/1000], Iter [42/49], d_loss: 0.0441, g_loss: 8.5073
Epoch [266/1000], Iter [43/49], d_loss: 0.0944, g_loss: 6.2165
Epoch [266/1000], Iter [44/49], d_loss: 0.1189, g_loss: 5.1503
Epoch [266/1000], Iter [45/49], d_loss: 0.0591, g_loss: 4.5497
Epoch [266/1000], Iter [46/49], d_loss: 0.2123, g_loss: 5.0143
Epoch [266/1000], Iter [47/49], d_loss: 0.0679, g_loss:

Process Process-12:
Process Process-11:
Traceback (most recent call last):
  File "/home/dskym0/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/dskym0/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/dskym0/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 34, in _worker_loop
    r = index_queue.get()
  File "/home/dskym0/anaconda3/lib/python3.6/multiprocessing/queues.py", line 335, in get
    res = self._reader.recv_bytes()
  File "/home/dskym0/anaconda3/lib/python3.6/multiprocessing/connection.py", line 216, in recv_bytes
    buf = self._recv_bytes(maxlength)
Traceback (most recent call last):
  File "/home/dskym0/anaconda3/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/home/dskym0/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.

KeyboardInterrupt: 