In [1]:
import os

import numpy as np

import torch
import torch.nn as nn

from torch.autograd import Variable

import torchvision.utils

from new_data_loader import get_loader
from make_gif import make_gif

In [2]:
class Discriminator(nn.Module):
    """Conditional 3D-convolutional discriminator.

    A video tensor and a condition tensor are each pushed through their own
    stride-2 ``Conv3d`` (3 -> 32 and 6 -> 32 channels), concatenated along the
    channel axis (64 channels) and fed to a shared stack of three further
    stride-2 ``Conv3d`` + ``BatchNorm3d`` + ``LeakyReLU`` stages, followed by a
    final ``Conv3d`` with a (2, 4, 4) kernel that emits 2 output channels.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # Geometry shared by every down-sampling convolution in this network.
        down = dict(kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))

        # Separate input branches for the video and for the condition volume.
        self.layer_video = nn.Conv3d(in_channels=3, out_channels=32, **down)
        self.layer_y = nn.Conv3d(in_channels=6, out_channels=32, **down)

        # Shared trunk; it starts with the activation for the concatenated
        # branch outputs, which is why the first entry is a LeakyReLU.
        trunk = [nn.LeakyReLU(negative_slope=0.2, inplace=True)]
        for c_in, c_out in ((64, 128), (128, 256), (256, 512)):
            trunk.append(nn.Conv3d(in_channels=c_in, out_channels=c_out, **down))
            trunk.append(nn.BatchNorm3d(num_features=c_out, eps=1e-03))
            trunk.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))

        # Un-padded, stride-1 head producing the 2-channel score.
        trunk.append(
            nn.Conv3d(in_channels=512, out_channels=2, kernel_size=(2, 4, 4), stride=(1, 1, 1), padding=(0, 0, 0))
        )

        self.discriminator = nn.Sequential(*trunk)

    def forward(self, video, y):
        """Score a batch of videos given their condition volumes.

        Args:
            video: 5-D tensor with 3 channels on dim 1.
            y: 5-D tensor with 6 channels on dim 1; it must share the batch,
                depth, height and width of ``video`` so that the two branch
                outputs can be concatenated on the channel axis.

        Returns:
            The raw (un-activated) output of the final convolution, with 2
            channels on dim 1.
        """
        branches = [self.layer_video(video), self.layer_y(y)]
        return self.discriminator(torch.cat(branches, 1))

In [3]:
class Generator(nn.Module):
    """Two-stream conditional video generator.

    A 100-dimensional noise vector and a 6-dimensional condition vector are
    each projected by a transposed convolution, once in 3D (moving foreground
    stream) and once in 2D (static background stream). The foreground stream
    yields an RGB video and a single-channel sigmoid mask; the background
    stream yields one RGB image that is repeated over time. The result is
    ``foreground * mask + background * (1 - mask)`` with shape
    (batch, 3, 32, 64, 64).
    """

    def __init__(self):
        super(Generator, self).__init__()

        # Initial projections of noise (100 channels) and condition (6 channels).
        self.layer_3d_video = nn.ConvTranspose3d(in_channels=100, out_channels=256, kernel_size=(2, 4, 4))
        self.layer_2d_video = nn.ConvTranspose2d(in_channels=100, out_channels=256, kernel_size=4, stride=1, padding=0)

        self.layer_3d_y = nn.ConvTranspose3d(in_channels=6, out_channels=256, kernel_size=(2, 4, 4))
        self.layer_2d_y = nn.ConvTranspose2d(in_channels=6, out_channels=256, kernel_size=4, stride=1, padding=0)

        # Geometry shared by every up-sampling transposed convolution.
        up3d = dict(kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
        up2d = dict(kernel_size=4, stride=2, padding=1)

        # Foreground trunk: 512 -> 256 -> 128 -> 64 channels.
        video_layers = [nn.BatchNorm3d(num_features=512), nn.ReLU(inplace=True)]
        for c_in, c_out in ((512, 256), (256, 128), (128, 64)):
            video_layers.append(nn.ConvTranspose3d(in_channels=c_in, out_channels=c_out, **up3d))
            video_layers.append(nn.BatchNorm3d(num_features=c_out))
            video_layers.append(nn.ReLU(inplace=True))
        self.net_video = nn.Sequential(*video_layers)

        # Foreground heads: RGB video (tanh) and blending mask (sigmoid).
        self.gen_net = nn.Sequential(
            nn.ConvTranspose3d(in_channels=64, out_channels=3, **up3d),
            nn.Tanh()
        )
        self.mask_net = nn.Sequential(
            nn.ConvTranspose3d(in_channels=64, out_channels=1, **up3d),
            nn.Sigmoid()
        )

        # Background stream: 512 -> 256 -> 128 -> 64 -> 3 channels, tanh output.
        static_layers = [nn.BatchNorm2d(num_features=512), nn.ReLU(inplace=True)]
        for c_in, c_out in ((512, 256), (256, 128), (128, 64)):
            static_layers.append(nn.ConvTranspose2d(c_in, c_out, **up2d))
            static_layers.append(nn.BatchNorm2d(num_features=c_out))
            static_layers.append(nn.ReLU(inplace=True))
        static_layers.append(nn.ConvTranspose2d(64, 3, **up2d))
        static_layers.append(nn.Tanh())
        self.static_net = nn.Sequential(*static_layers)

    def forward(self, z, y):
        """Generate a batch of videos.

        Args:
            z: noise tensor; reshaped with ``view(-1, 100, ...)``, so it must
                hold 100 values per sample. Its first dimension is taken as
                the batch size.
            y: condition tensor; reshaped with ``view(-1, 6, ...)``, so it must
                hold 6 values per sample.

        Returns:
            Tensor of shape (batch, 3, 32, 64, 64).
        """
        batch = z.size(0)
        video_shape = (batch, 3, 32, 64, 64)

        # Foreground seed: noise and condition as 1x1x1 volumes, projected and stacked.
        seed_3d = torch.cat(
            [self.layer_3d_video(z.view(-1, 100, 1, 1, 1)), self.layer_3d_y(y.view(-1, 6, 1, 1, 1))],
            1,
        )
        # Background seed: the same inputs as 1x1 images.
        seed_2d = torch.cat(
            [self.layer_2d_video(z.view(-1, 100, 1, 1)), self.layer_2d_y(y.view(-1, 6, 1, 1))],
            1,
        )

        features = self.net_video(seed_3d)
        foreground = self.gen_net(features)
        # The single-channel mask is repeated over the 3 colour channels.
        mask = self.mask_net(features).expand(*video_shape)

        # The single background image is repeated over the 32 frames.
        background = self.static_net(seed_2d).view(batch, 3, 1, 64, 64).expand(*video_shape)

        return foreground * mask + background * (1 - mask)


In [4]:
def init_weights(m) :
    """Initialise one module in place; intended for ``Module.apply(init_weights)``.

    * ``nn.Conv3d`` / ``nn.ConvTranspose2d`` / ``nn.ConvTranspose3d``:
      weights ~ N(0, 0.01), bias set to 0.
    * ``nn.BatchNorm2d`` / ``nn.BatchNorm3d``: weights ~ N(1, 0.02), bias set to 0.
    * Any other module type is left untouched.

    Layers built without a bias (``bias=False``) or without affine parameters
    (``affine=False``) have ``None`` in place of the tensor; those entries are
    skipped instead of raising ``AttributeError``.

    Args:
        m: the module to initialise. Matching is on the exact type, so
            subclasses of the listed layers are not touched.
    """
    layer_type = type(m)

    if layer_type in (nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d) :
        m.weight.data.normal_(0.0, 0.01)
        if m.bias is not None :
            m.bias.data.fill_(0)
    elif layer_type in (nn.BatchNorm2d, nn.BatchNorm3d) :
        if m.weight is not None :
            m.weight.data.normal_(1.0, 0.02)
        if m.bias is not None :
            m.bias.data.fill_(0)

In [5]:
# ---- Run configuration ----
# pre_train: resume from the D.ckpt / G.ckpt files written by the training loop.
pre_train = True

batch_size = 64
video_size = 64
epoch_size = 1000

#check GPU
is_gpu = torch.cuda.is_available()
print(is_gpu)

# Tensor type used for every model / tensor cast below (GPU when available).
if is_gpu :
    dtype = torch.cuda.FloatTensor
else :
    dtype = torch.FloatTensor

# Only resume when at least one checkpoint is present. On a fresh checkout
# (no checkpoint at all) fall back to newly initialised models instead of
# crashing; if just one of the two files is missing, torch.load still raises
# so that a half-present checkpoint pair is not silently discarded.
checkpoint_paths = ('D.ckpt', 'G.ckpt')
resume = pre_train and any(os.path.exists(p) for p in checkpoint_paths)

if pre_train and not resume :
    print('No checkpoint found, initialising new models.')

if resume :
    # SECURITY: torch.load unpickles arbitrary Python objects; only load
    # checkpoint files that were produced by this notebook / a trusted source.
    # map_location keeps the storages on CPU while loading so that a
    # checkpoint saved on a GPU machine can be opened without CUDA; the
    # following .type(dtype) moves the model to the selected device type.
    D = torch.load('D.ckpt', map_location=lambda storage, loc: storage).type(dtype)
    G = torch.load('G.ckpt', map_location=lambda storage, loc: storage).type(dtype)
else :
    D = Discriminator()
    D = D.type(dtype)

    G = Generator()
    G = G.type(dtype)

    D.apply(init_weights)
    G.apply(init_weights)

# The discriminator returns raw scores, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss().type(dtype)

d_optimizer = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))

True


In [6]:
data_loader = get_loader(data_path='./dataset', image_size=video_size, batch_size=batch_size, num_workers=2)


def _loss_value(loss) :
    """Return a scalar loss as a Python number on both old and new PyTorch."""
    # `.item()` is available from PyTorch 0.4 on; earlier releases expose the
    # value as `loss.data[0]`, which newer releases reject for 0-dim tensors.
    return loss.item() if hasattr(loss, 'item') else loss.data[0]


for epoch in range(1, epoch_size + 1) :

    # `step` (not `iter`) so the builtin iter() is not shadowed in the notebook.
    for step, (video, y) in enumerate(data_loader) :
        # The last batch of an epoch may be smaller than batch_size.
        local_batch_size = video.size()[0]

        # Targets match the discriminator's 2 outputs per sample.
        real_labels = Variable(torch.ones(local_batch_size, 2).type(dtype))
        fake_labels = Variable(torch.zeros(local_batch_size, 2).type(dtype))

        # 1. Train Discriminator
        video_data = Variable(video).type(dtype)
        y = Variable(y).type(dtype)

        # Turn the per-sample condition vector into a volume the discriminator
        # can convolve: add three trailing singleton dims, then repeat the
        # values over 32 x 64 x 64 positions.
        y_data = torch.unsqueeze(torch.unsqueeze(torch.unsqueeze(y, -1), -1), -1)
        y_data = y_data.expand(local_batch_size, 6, 32, 64, 64)

        # 1-1. Real Video
        outputs = D(video_data, y_data).view(local_batch_size, 2)
        d_loss_real = criterion(outputs, real_labels)

        # 1-2. Fake Video
        z = Variable(torch.randn(local_batch_size, 100) * 0.01).type(dtype)
        # detach(): the discriminator update needs no gradients with respect
        # to the generator, so do not backpropagate through G here.
        fake_videos = G(z, y).detach()
        outputs = D(fake_videos, y_data).view(local_batch_size, 2)
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake

        D.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # 2. Train Generator
        z = Variable(torch.randn(local_batch_size, 100) * 0.01).type(dtype)
        fake_videos = G(z, y)
        outputs = D(fake_videos, y_data).view(local_batch_size, 2)

        # The generator is rewarded when D labels its samples as real.
        g_loss = criterion(outputs, real_labels)
        D.zero_grad()
        G.zero_grad()

        g_loss.backward()
        g_optimizer.step()

        print('Epoch [%d/%d], Iter [%d/%d], d_loss: %.4f, g_loss: %.4f' % (epoch, epoch_size, step, len(data_loader), _loss_value(d_loss), _loss_value(g_loss)))

    print('Model saving...')

    # Whole-module checkpoints, overwritten after every epoch.
    torch.save(D, 'D.ckpt')
    torch.save(G, 'G.ckpt')

Epoch [1/1000], Iter [0/49], d_loss: 0.3324, g_loss: 7.5953
Epoch [1/1000], Iter [1/49], d_loss: 0.9974, g_loss: 3.4806
Epoch [1/1000], Iter [2/49], d_loss: 0.2490, g_loss: 3.4503
Epoch [1/1000], Iter [3/49], d_loss: 0.2122, g_loss: 3.7952
Epoch [1/1000], Iter [4/49], d_loss: 0.2244, g_loss: 6.1627
Epoch [1/1000], Iter [5/49], d_loss: 0.4924, g_loss: 3.3667
Epoch [1/1000], Iter [6/49], d_loss: 0.3299, g_loss: 2.8424
Epoch [1/1000], Iter [7/49], d_loss: 0.1645, g_loss: 4.1315
Epoch [1/1000], Iter [8/49], d_loss: 0.1966, g_loss: 4.1595
Epoch [1/1000], Iter [9/49], d_loss: 0.3706, g_loss: 3.8966
Epoch [1/1000], Iter [10/49], d_loss: 0.3196, g_loss: 3.8792
Epoch [1/1000], Iter [11/49], d_loss: 0.3414, g_loss: 4.6363
Epoch [1/1000], Iter [12/49], d_loss: 0.2286, g_loss: 4.3408
Epoch [1/1000], Iter [13/49], d_loss: 0.2796, g_loss: 3.3681
Epoch [1/1000], Iter [14/49], d_loss: 0.1306, g_loss: 5.0401
Epoch [1/1000], Iter [15/49], d_loss: 0.1118, g_loss: 3.7180
Epoch [1/1000], Iter [16/49], d_lo

Epoch [3/1000], Iter [37/49], d_loss: 0.1701, g_loss: 2.9388
Epoch [3/1000], Iter [38/49], d_loss: 0.1932, g_loss: 4.7882
Epoch [3/1000], Iter [39/49], d_loss: 0.1683, g_loss: 4.3391
Epoch [3/1000], Iter [40/49], d_loss: 0.1779, g_loss: 2.0534
Epoch [3/1000], Iter [41/49], d_loss: 0.2938, g_loss: 4.5435
Epoch [3/1000], Iter [42/49], d_loss: 0.1124, g_loss: 5.3844
Epoch [3/1000], Iter [43/49], d_loss: 0.2747, g_loss: 4.7704
Epoch [3/1000], Iter [44/49], d_loss: 0.1851, g_loss: 3.0400
Epoch [3/1000], Iter [45/49], d_loss: 0.3377, g_loss: 3.0620
Epoch [3/1000], Iter [46/49], d_loss: 0.4870, g_loss: 7.1752
Epoch [3/1000], Iter [47/49], d_loss: 0.3256, g_loss: 5.3355
Epoch [3/1000], Iter [48/49], d_loss: 0.1739, g_loss: 2.7685
Model saving...
Epoch [4/1000], Iter [0/49], d_loss: 0.5309, g_loss: 5.8257
Epoch [4/1000], Iter [1/49], d_loss: 0.3759, g_loss: 5.5559
Epoch [4/1000], Iter [2/49], d_loss: 0.3002, g_loss: 2.1874
Epoch [4/1000], Iter [3/49], d_loss: 0.4776, g_loss: 5.0874
Epoch [4/100

Epoch [6/1000], Iter [25/49], d_loss: 0.6563, g_loss: 3.5575
Epoch [6/1000], Iter [26/49], d_loss: 0.6302, g_loss: 6.0390
Epoch [6/1000], Iter [27/49], d_loss: 0.1765, g_loss: 4.6470
Epoch [6/1000], Iter [28/49], d_loss: 0.1855, g_loss: 5.6205
Epoch [6/1000], Iter [29/49], d_loss: 0.3321, g_loss: 4.3181
Epoch [6/1000], Iter [30/49], d_loss: 0.3010, g_loss: 4.9766
Epoch [6/1000], Iter [31/49], d_loss: 0.1300, g_loss: 4.7936
Epoch [6/1000], Iter [32/49], d_loss: 0.1810, g_loss: 4.9556
Epoch [6/1000], Iter [33/49], d_loss: 0.3279, g_loss: 4.6397
Epoch [6/1000], Iter [34/49], d_loss: 0.2410, g_loss: 6.2115
Epoch [6/1000], Iter [35/49], d_loss: 0.1875, g_loss: 3.7305
Epoch [6/1000], Iter [36/49], d_loss: 0.1714, g_loss: 2.7874
Epoch [6/1000], Iter [37/49], d_loss: 0.3025, g_loss: 5.6873
Epoch [6/1000], Iter [38/49], d_loss: 0.2505, g_loss: 3.7471
Epoch [6/1000], Iter [39/49], d_loss: 0.1404, g_loss: 3.7561
Epoch [6/1000], Iter [40/49], d_loss: 0.6376, g_loss: 8.5308
Epoch [6/1000], Iter [41

Epoch [9/1000], Iter [12/49], d_loss: 0.2579, g_loss: 3.7418
Epoch [9/1000], Iter [13/49], d_loss: 0.2514, g_loss: 6.3923
Epoch [9/1000], Iter [14/49], d_loss: 0.2795, g_loss: 4.7162
Epoch [9/1000], Iter [15/49], d_loss: 0.1334, g_loss: 2.2911
Epoch [9/1000], Iter [16/49], d_loss: 0.3086, g_loss: 4.6227
Epoch [9/1000], Iter [17/49], d_loss: 0.3040, g_loss: 3.5036
Epoch [9/1000], Iter [18/49], d_loss: 0.2594, g_loss: 5.2553
Epoch [9/1000], Iter [19/49], d_loss: 0.0721, g_loss: 5.3640
Epoch [9/1000], Iter [20/49], d_loss: 0.5004, g_loss: 2.2456
Epoch [9/1000], Iter [21/49], d_loss: 0.3415, g_loss: 4.8381
Epoch [9/1000], Iter [22/49], d_loss: 0.1825, g_loss: 5.2392
Epoch [9/1000], Iter [23/49], d_loss: 0.1878, g_loss: 4.5068
Epoch [9/1000], Iter [24/49], d_loss: 0.2709, g_loss: 1.4494
Epoch [9/1000], Iter [25/49], d_loss: 0.5328, g_loss: 6.7852
Epoch [9/1000], Iter [26/49], d_loss: 0.2140, g_loss: 6.2857
Epoch [9/1000], Iter [27/49], d_loss: 0.3578, g_loss: 2.8789
Epoch [9/1000], Iter [28

Epoch [11/1000], Iter [47/49], d_loss: 0.4866, g_loss: 6.6515
Epoch [11/1000], Iter [48/49], d_loss: 0.3038, g_loss: 5.1028
Model saving...
Epoch [12/1000], Iter [0/49], d_loss: 0.5216, g_loss: 1.3031
Epoch [12/1000], Iter [1/49], d_loss: 1.1178, g_loss: 7.5680
Epoch [12/1000], Iter [2/49], d_loss: 0.5925, g_loss: 4.1209
Epoch [12/1000], Iter [3/49], d_loss: 0.7594, g_loss: 5.3333
Epoch [12/1000], Iter [4/49], d_loss: 0.1792, g_loss: 5.1187
Epoch [12/1000], Iter [5/49], d_loss: 0.3113, g_loss: 3.4078
Epoch [12/1000], Iter [6/49], d_loss: 0.3250, g_loss: 3.9455
Epoch [12/1000], Iter [7/49], d_loss: 0.2310, g_loss: 3.5190
Epoch [12/1000], Iter [8/49], d_loss: 0.3451, g_loss: 4.8889
Epoch [12/1000], Iter [9/49], d_loss: 0.2915, g_loss: 4.4340
Epoch [12/1000], Iter [10/49], d_loss: 0.1763, g_loss: 3.4459
Epoch [12/1000], Iter [11/49], d_loss: 0.2442, g_loss: 3.6859
Epoch [12/1000], Iter [12/49], d_loss: 0.0721, g_loss: 4.2912
Epoch [12/1000], Iter [13/49], d_loss: 0.1766, g_loss: 4.4499
Ep

Epoch [14/1000], Iter [32/49], d_loss: 0.1581, g_loss: 7.3571
Epoch [14/1000], Iter [33/49], d_loss: 0.8471, g_loss: 1.5320
Epoch [14/1000], Iter [34/49], d_loss: 1.2886, g_loss: 8.0954
Epoch [14/1000], Iter [35/49], d_loss: 0.8428, g_loss: 2.9800
Epoch [14/1000], Iter [36/49], d_loss: 0.5970, g_loss: 5.4226
Epoch [14/1000], Iter [37/49], d_loss: 0.2734, g_loss: 6.2087
Epoch [14/1000], Iter [38/49], d_loss: 0.4667, g_loss: 2.1808
Epoch [14/1000], Iter [39/49], d_loss: 0.2269, g_loss: 3.8969
Epoch [14/1000], Iter [40/49], d_loss: 0.0758, g_loss: 3.5775
Epoch [14/1000], Iter [41/49], d_loss: 0.2732, g_loss: 6.2928
Epoch [14/1000], Iter [42/49], d_loss: 0.1613, g_loss: 4.9160
Epoch [14/1000], Iter [43/49], d_loss: 0.1669, g_loss: 3.9290
Epoch [14/1000], Iter [44/49], d_loss: 0.1989, g_loss: 3.2115
Epoch [14/1000], Iter [45/49], d_loss: 0.2378, g_loss: 5.1395
Epoch [14/1000], Iter [46/49], d_loss: 0.3806, g_loss: 2.6181
Epoch [14/1000], Iter [47/49], d_loss: 0.3271, g_loss: 5.5943
Epoch [1

Epoch [17/1000], Iter [17/49], d_loss: 0.4002, g_loss: 4.9237
Epoch [17/1000], Iter [18/49], d_loss: 0.1965, g_loss: 4.7761
Epoch [17/1000], Iter [19/49], d_loss: 0.3655, g_loss: 3.5323
Epoch [17/1000], Iter [20/49], d_loss: 0.6066, g_loss: 6.0849
Epoch [17/1000], Iter [21/49], d_loss: 0.2491, g_loss: 5.6785
Epoch [17/1000], Iter [22/49], d_loss: 0.1034, g_loss: 3.8759
Epoch [17/1000], Iter [23/49], d_loss: 0.2123, g_loss: 2.7488
Epoch [17/1000], Iter [24/49], d_loss: 0.2065, g_loss: 3.7154
Epoch [17/1000], Iter [25/49], d_loss: 0.1912, g_loss: 5.2869
Epoch [17/1000], Iter [26/49], d_loss: 0.1938, g_loss: 4.8288
Epoch [17/1000], Iter [27/49], d_loss: 0.3042, g_loss: 2.4342
Epoch [17/1000], Iter [28/49], d_loss: 0.5446, g_loss: 7.3325
Epoch [17/1000], Iter [29/49], d_loss: 0.7338, g_loss: 4.0359
Epoch [17/1000], Iter [30/49], d_loss: 0.1211, g_loss: 4.3292
Epoch [17/1000], Iter [31/49], d_loss: 0.7235, g_loss: 6.5645
Epoch [17/1000], Iter [32/49], d_loss: 0.7047, g_loss: 3.5946
Epoch [1

Process Process-36:
Traceback (most recent call last):
  File "/home/dskym0/anaconda3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/home/dskym0/anaconda3/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/dskym0/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 40, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/dskym0/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 40, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/dskym0/KUDL_TermProject/new_data_loader.py", line 60, in __getitem__
    image = self.transform(image)
  File "/home/dskym0/anaconda3/lib/python3.6/site-packages/torchvision-0.1.9-py3.6.egg/torchvision/transforms.py", line 34, in __call__
    img = t(img)
  File "/home/dskym0/anaconda3/lib/python3.6/site-packages/torchvision

KeyboardInterrupt: 

In [None]:
# Generate one sample video per class with the current generator, write it as
# a GIF through make_gif, and remove the intermediate PNG frames afterwards.
path_dir = './testvideo'

if not os.path.exists(path_dir) :
    os.mkdir(path_dir)

num_classes = 6    # size of the one-hot condition vector expected by G
num_frames = 32    # temporal length of the videos produced by G

for classes in range(num_classes) :
    z = Variable(torch.randn(1, 100) * 0.01).type(dtype)

    # One-hot condition of shape (1, num_classes); G reshapes it itself.
    label = torch.zeros(1, num_classes)
    label.scatter_(1, torch.LongTensor([[classes]]), 1)
    label = Variable(label).type(dtype)

    # Run the generator once per class; every frame is sliced from this single
    # result instead of re-running G for each frame.
    fake_video = torch.squeeze(G(z, label)).data

    # NOTE(review): G ends in tanh, so values may be negative, and save_image
    # is called without normalisation - confirm the frames are not clipped.
    frame_paths = []
    for i in range(num_frames) :
        frame_path = path_dir + '/test' + str(classes) + '_' + str(i+1) + '.png'
        # Path passed positionally: the keyword is `filename` in older
        # torchvision releases and `fp` in newer ones.
        torchvision.utils.save_image(fake_video[:, i, :, :], frame_path)
        frame_paths.append(frame_path)

    make_gif(root=path_dir, output='output' + str(classes) + '.gif', fps=16)

    for frame_path in frame_paths :
        os.remove(frame_path)