In [1]:
import os

import numpy as np

import torch
import torch.nn as nn

from torch.autograd import Variable

import torchvision.utils

from new_data_loader import get_loader
from make_gif import make_gif

In [2]:
class Discriminator(nn.Module):
    """Conditional 3D-convolutional video discriminator.

    The video and the condition volume each go through their own strided
    Conv3d stem (32 channels each). The two results are concatenated along
    the channel axis (64 channels) and passed through a shared Conv3d stack
    that ends in a 2-channel output.

    forward(video, y):
        video: 3-channel video volume (N, 3, T, H, W).
        y:     6-channel condition volume with the same T, H, W.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # Independent stems for the RGB video and the 6-channel condition;
        # both use kernel 4, stride 2, padding 1 in every dimension.
        self.layer_video = nn.Conv3d(in_channels=3, out_channels=32, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
        self.layer_y = nn.Conv3d(in_channels=6, out_channels=32, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))

        # Shared trunk: activation on the concatenated stems, then three
        # Conv3d -> BatchNorm3d -> LeakyReLU stages doubling the channels.
        trunk = [nn.LeakyReLU(negative_slope=0.2, inplace=True)]
        for channels_in, channels_out in ((64, 128), (128, 256), (256, 512)):
            trunk.extend([
                nn.Conv3d(in_channels=channels_in, out_channels=channels_out, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
                nn.BatchNorm3d(num_features=channels_out, eps=1e-03),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
            ])
        # Final unpadded (2, 4, 4) convolution down to 2 output channels.
        trunk.append(nn.Conv3d(in_channels=512, out_channels=2, kernel_size=(2, 4, 4), stride=(1, 1, 1), padding=(0, 0, 0)))

        self.discriminator = nn.Sequential(*trunk)

    def forward(self, video, y):
        """Score a (video, condition) pair; returns the raw 2-channel output of the trunk."""
        merged = torch.cat([self.layer_video(video), self.layer_y(y)], 1)
        return self.discriminator(merged)

In [3]:
class Generator(nn.Module):
    """Conditional two-stream video generator.

    A noise vector ``z`` and a condition vector ``y`` are each projected by
    transposed convolutions into two streams:

    * a 3D (spatio-temporal) stream that produces a foreground video
      (tanh, 3 channels) and a blending mask (sigmoid, 1 channel);
    * a 2D stream that produces a single static background frame
      (tanh, 3 channels), repeated along the time axis.

    The output is ``foreground * mask + background * (1 - mask)``.
    With the layer configuration below this is a ``(batch, 3, 32, 64, 64)``
    tensor.

    Args:
        z_dim: length of the noise vector ``z`` (default 100).
        num_classes: length of the condition vector ``y`` (default 6).
    """

    def __init__(self, z_dim=100, num_classes=6):
        super(Generator, self).__init__()

        # Project the 1x1(x1) inputs to 256 channels: 2x4x4 for the 3D stream,
        # 4x4 for the 2D stream. The z- and y-projections are concatenated
        # (256 + 256 = 512 channels) in forward().
        self.layer_3d_video = nn.ConvTranspose3d(in_channels=z_dim, out_channels=256, kernel_size=(2,4,4))
        self.layer_2d_video = nn.ConvTranspose2d(in_channels=z_dim, out_channels=256, kernel_size=4, stride=1, padding=0)

        self.layer_3d_y = nn.ConvTranspose3d(in_channels=num_classes, out_channels=256, kernel_size=(2,4,4))
        self.layer_2d_y = nn.ConvTranspose2d(in_channels=num_classes, out_channels=256, kernel_size=4, stride=1, padding=0)

        # Shared 3D trunk: each transposed conv (kernel 4, stride 2, pad 1)
        # doubles T, H and W: (512, 2, 4, 4) -> (64, 16, 32, 32).
        self.net_video = nn.Sequential(
            nn.BatchNorm3d(num_features=512),
            nn.ReLU(inplace=True),

            nn.ConvTranspose3d(in_channels=512, out_channels=256, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(num_features=256),
            nn.ReLU(inplace=True),

            nn.ConvTranspose3d(in_channels=256, out_channels=128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(num_features=128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose3d(in_channels=128, out_channels=64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(num_features=64),
            nn.ReLU(inplace=True)
        )

        # Foreground head: 3-channel video, tanh-bounded to [-1, 1].
        self.gen_net = nn.Sequential(
            nn.ConvTranspose3d(in_channels=64, out_channels=3, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.Tanh()
        )

        # Mask head: 1-channel blend weight in [0, 1].
        self.mask_net = nn.Sequential(
            nn.ConvTranspose3d(in_channels=64, out_channels=1, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.Sigmoid()
        )

        # Background stream: (512, 4, 4) -> one 3-channel 64x64 frame in [-1, 1].
        self.static_net = nn.Sequential(
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z, y):
        """Generate a batch of videos.

        Args:
            z: noise, one row of ``z_dim`` values per sample.
            y: condition, ``num_classes`` values per sample (any shape that
               views to ``(-1, num_classes)``).

        Returns:
            The blended video; same shape as the foreground head's output.
        """
        # Read the sizes back from the layers instead of from new attributes:
        # modules restored via torch.load() of an older pickle only carry the
        # attributes that existed when they were saved.
        z_dim = self.layer_3d_video.in_channels
        num_classes = self.layer_3d_y.in_channels

        z_foreground = z.view(-1, z_dim, 1, 1, 1)
        z_background = z.view(-1, z_dim, 1, 1)

        y_foreground = y.view(-1, num_classes, 1, 1, 1)
        y_background = y.view(-1, num_classes, 1, 1)

        out_cat_3d = torch.cat([self.layer_3d_video(z_foreground), self.layer_3d_y(y_foreground)], 1)
        out_cat_2d = torch.cat([self.layer_2d_video(z_background), self.layer_2d_y(y_background)], 1)

        m_net_video = self.net_video(out_cat_3d)

        foreground = self.gen_net(m_net_video)
        # Broadcast the 1-channel mask over the RGB channels, and the single
        # background frame over time, using the foreground's actual shape
        # rather than a hard-coded (B, 3, 32, 64, 64).
        mask = self.mask_net(m_net_video).expand_as(foreground)
        background = self.static_net(out_cat_2d).unsqueeze(2).expand_as(foreground)

        return foreground * mask + background * (1 - mask)


In [4]:
def init_weights(m) :
    """Re-initialise a single module in place; meant for ``nn.Module.apply``.

    * Conv3d / ConvTranspose2d / ConvTranspose3d: weight ~ N(0, 0.01), bias = 0.
    * BatchNorm2d / BatchNorm3d: weight ~ N(1, 0.02), bias = 0.
    * Anything else is left untouched.

    Matching is on the exact module type (subclasses are not matched).
    """
    conv_types = (nn.Conv3d, nn.ConvTranspose2d, nn.ConvTranspose3d)
    norm_types = (nn.BatchNorm2d, nn.BatchNorm3d)
    module_type = type(m)

    if module_type in conv_types:
        weight_mean, weight_std = 0.0, 0.01
    elif module_type in norm_types:
        weight_mean, weight_std = 1.0, 0.02
    else:
        return

    m.weight.data.normal_(weight_mean, weight_std)
    m.bias.data.fill_(0)

In [None]:
# Resume from D.ckpt / G.ckpt when True (falls back to fresh weights if they are missing).
pre_train = True

batch_size = 64
video_size = 64
epoch_size = 1000

# check GPU
is_gpu = torch.cuda.is_available()
print(is_gpu)

if is_gpu:
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor

checkpoints_found = os.path.exists('D.ckpt') and os.path.exists('G.ckpt')
if pre_train and not checkpoints_found:
    # A fresh checkout has no checkpoints yet; start from scratch instead of crashing.
    print('Checkpoints D.ckpt / G.ckpt not found; training from scratch.')

if pre_train and checkpoints_found:
    # torch.load unpickles whole modules (arbitrary code) -- only load trusted files.
    # map_location keeps every storage on the CPU so checkpoints saved from a GPU
    # also load on CPU-only machines; .type(dtype) then moves them to the GPU if present.
    D = torch.load('D.ckpt', map_location=lambda storage, loc: storage).type(dtype)
    G = torch.load('G.ckpt', map_location=lambda storage, loc: storage).type(dtype)
else:
    D = Discriminator().type(dtype)
    G = Generator().type(dtype)

    D.apply(init_weights)
    G.apply(init_weights)

criterion = nn.BCEWithLogitsLoss().type(dtype)

d_optimizer = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))

True


In [None]:
data_loader = get_loader(data_path='./dataset', image_size=video_size, batch_size=batch_size, num_workers=2)


def _loss_value(loss):
    """Return a scalar loss as a Python float on old and new PyTorch.

    ``.item()`` is available from PyTorch 0.4 on, where a reduced loss is a
    0-dim tensor and ``.data[0]`` raises; older versions only support
    ``.data[0]``.
    """
    if hasattr(loss, 'item'):
        return loss.item()
    return loss.data[0]


for epoch in range(1, epoch_size + 1):

    for step, (video, y) in enumerate(data_loader):
        local_batch_size = video.size()[0]

        # D emits 2 logits per sample (see .view(local_batch_size, 2) below),
        # so the targets are 2 wide as well.
        real_labels = Variable(torch.ones(local_batch_size, 2).type(dtype))
        fake_labels = Variable(torch.zeros(local_batch_size, 2).type(dtype))

        # 1. Train Discriminator
        video_data = Variable(video).type(dtype)
        y = Variable(y).type(dtype)

        # Tile the condition vector over every voxel of the video so D can
        # concatenate it channel-wise; size comes from the video itself.
        y_data = torch.unsqueeze(torch.unsqueeze(torch.unsqueeze(y, -1), -1), -1)
        y_data = y_data.expand(local_batch_size, y_data.size(1),
                               video_data.size(2), video_data.size(3), video_data.size(4))

        # 1-1. Real Video
        outputs = D(video_data, y_data).view(local_batch_size, 2)
        d_loss_real = criterion(outputs, real_labels)

        # 1-2. Fake Video. detach(): G is not updated in this step, so there is
        # no reason for d_loss to back-propagate through the generator.
        z = Variable(torch.randn(local_batch_size, 100) * 0.01).type(dtype)
        fake_videos = G(z, y).detach()
        outputs = D(fake_videos, y_data).view(local_batch_size, 2)
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake

        D.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        # 2. Train Generator: make D label fresh fakes as real.
        z = Variable(torch.randn(local_batch_size, 100) * 0.01).type(dtype)
        fake_videos = G(z, y)
        outputs = D(fake_videos, y_data).view(local_batch_size, 2)

        g_loss = criterion(outputs, real_labels)
        # Clear both: D's grads from this backward are discarded (only G steps here).
        D.zero_grad()
        G.zero_grad()

        g_loss.backward()
        g_optimizer.step()

        print('Epoch [%d/%d], Iter [%d/%d], d_loss: %.4f, g_loss: %.4f'
              % (epoch, epoch_size, step, len(data_loader), _loss_value(d_loss), _loss_value(g_loss)))

    print('Model saving...')

    # Whole modules are pickled; reloading needs Discriminator/Generator defined.
    torch.save(D, 'D.ckpt')
    torch.save(G, 'G.ckpt')

Epoch [1/10000], Iter [0/49], d_loss: 0.7521, g_loss: 5.3961
Epoch [1/10000], Iter [1/49], d_loss: 0.9891, g_loss: 3.0680
Epoch [1/10000], Iter [2/49], d_loss: 0.4858, g_loss: 2.0556
Epoch [1/10000], Iter [3/49], d_loss: 0.6345, g_loss: 3.2320
Epoch [1/10000], Iter [4/49], d_loss: 0.4586, g_loss: 2.8988
Epoch [1/10000], Iter [5/49], d_loss: 0.5534, g_loss: 2.9233
Epoch [1/10000], Iter [6/49], d_loss: 0.4146, g_loss: 3.7920
Epoch [1/10000], Iter [7/49], d_loss: 0.5793, g_loss: 2.1055
Epoch [1/10000], Iter [8/49], d_loss: 0.4427, g_loss: 3.6782
Epoch [1/10000], Iter [9/49], d_loss: 0.2598, g_loss: 3.7094
Epoch [1/10000], Iter [10/49], d_loss: 0.3048, g_loss: 3.0607
Epoch [1/10000], Iter [11/49], d_loss: 0.3314, g_loss: 3.3919
Epoch [1/10000], Iter [12/49], d_loss: 0.3834, g_loss: 2.7632
Epoch [1/10000], Iter [13/49], d_loss: 0.8121, g_loss: 1.4459
Epoch [1/10000], Iter [14/49], d_loss: 0.6899, g_loss: 5.4126
Epoch [1/10000], Iter [15/49], d_loss: 0.4852, g_loss: 4.0320
Epoch [1/10000], I

Epoch [3/10000], Iter [35/49], d_loss: 0.2453, g_loss: 4.7086
Epoch [3/10000], Iter [36/49], d_loss: 0.5176, g_loss: 1.6137
Epoch [3/10000], Iter [37/49], d_loss: 1.1267, g_loss: 8.8252
Epoch [3/10000], Iter [38/49], d_loss: 2.2276, g_loss: 3.1307
Epoch [3/10000], Iter [39/49], d_loss: 0.5762, g_loss: 3.3963
Epoch [3/10000], Iter [40/49], d_loss: 0.6711, g_loss: 6.9963
Epoch [3/10000], Iter [41/49], d_loss: 1.4074, g_loss: 1.9071
Epoch [3/10000], Iter [42/49], d_loss: 1.2308, g_loss: 5.3016
Epoch [3/10000], Iter [43/49], d_loss: 0.3131, g_loss: 4.8159
Epoch [3/10000], Iter [44/49], d_loss: 0.3525, g_loss: 3.1993
Epoch [3/10000], Iter [45/49], d_loss: 0.3591, g_loss: 3.1240
Epoch [3/10000], Iter [46/49], d_loss: 0.3516, g_loss: 4.0472
Epoch [3/10000], Iter [47/49], d_loss: 0.2681, g_loss: 3.6088
Epoch [3/10000], Iter [48/49], d_loss: 0.3067, g_loss: 4.0160
Model saving...
Epoch [4/10000], Iter [0/49], d_loss: 0.3087, g_loss: 3.7456
Epoch [4/10000], Iter [1/49], d_loss: 0.3424, g_loss: 3

Epoch [6/10000], Iter [20/49], d_loss: 0.6170, g_loss: 7.8693
Epoch [6/10000], Iter [21/49], d_loss: 1.7399, g_loss: 1.6196
Epoch [6/10000], Iter [22/49], d_loss: 1.1316, g_loss: 6.7049
Epoch [6/10000], Iter [23/49], d_loss: 0.5272, g_loss: 6.2304
Epoch [6/10000], Iter [24/49], d_loss: 0.1276, g_loss: 4.0962
Epoch [6/10000], Iter [25/49], d_loss: 0.4036, g_loss: 4.3632
Epoch [6/10000], Iter [26/49], d_loss: 0.4709, g_loss: 3.3230
Epoch [6/10000], Iter [27/49], d_loss: 0.6419, g_loss: 5.7629
Epoch [6/10000], Iter [28/49], d_loss: 0.5965, g_loss: 2.8098
Epoch [6/10000], Iter [29/49], d_loss: 0.8905, g_loss: 6.5237
Epoch [6/10000], Iter [30/49], d_loss: 1.2170, g_loss: 2.5453
Epoch [6/10000], Iter [31/49], d_loss: 0.7717, g_loss: 4.0208
Epoch [6/10000], Iter [32/49], d_loss: 0.3901, g_loss: 3.2748
Epoch [6/10000], Iter [33/49], d_loss: 0.5028, g_loss: 4.9836
Epoch [6/10000], Iter [34/49], d_loss: 0.3515, g_loss: 4.6002
Epoch [6/10000], Iter [35/49], d_loss: 0.4490, g_loss: 1.9904
Epoch [6

Epoch [9/10000], Iter [5/49], d_loss: 0.5526, g_loss: 3.6429
Epoch [9/10000], Iter [6/49], d_loss: 0.5913, g_loss: 2.4635
Epoch [9/10000], Iter [7/49], d_loss: 0.6773, g_loss: 6.6100
Epoch [9/10000], Iter [8/49], d_loss: 0.3074, g_loss: 5.7124
Epoch [9/10000], Iter [9/49], d_loss: 0.4419, g_loss: 1.7256
Epoch [9/10000], Iter [10/49], d_loss: 1.3677, g_loss: 9.0129
Epoch [9/10000], Iter [11/49], d_loss: 2.4616, g_loss: 3.0459
Epoch [9/10000], Iter [12/49], d_loss: 0.7164, g_loss: 5.0575
Epoch [9/10000], Iter [13/49], d_loss: 1.0295, g_loss: 2.0562
Epoch [9/10000], Iter [14/49], d_loss: 1.3401, g_loss: 8.9758
Epoch [9/10000], Iter [15/49], d_loss: 1.4941, g_loss: 4.3493
Epoch [9/10000], Iter [16/49], d_loss: 0.5564, g_loss: 4.2749
Epoch [9/10000], Iter [17/49], d_loss: 0.1293, g_loss: 4.7963
Epoch [9/10000], Iter [18/49], d_loss: 0.7139, g_loss: 2.8683
Epoch [9/10000], Iter [19/49], d_loss: 0.9977, g_loss: 5.6841
Epoch [9/10000], Iter [20/49], d_loss: 1.0763, g_loss: 2.5701
Epoch [9/1000

Epoch [11/10000], Iter [38/49], d_loss: 0.8476, g_loss: 5.1880
Epoch [11/10000], Iter [39/49], d_loss: 0.3970, g_loss: 4.0160
Epoch [11/10000], Iter [40/49], d_loss: 0.4465, g_loss: 3.5910
Epoch [11/10000], Iter [41/49], d_loss: 0.3609, g_loss: 4.3051
Epoch [11/10000], Iter [42/49], d_loss: 0.4326, g_loss: 3.8293
Epoch [11/10000], Iter [43/49], d_loss: 0.5134, g_loss: 1.6011
Epoch [11/10000], Iter [44/49], d_loss: 1.0232, g_loss: 8.5373
Epoch [11/10000], Iter [45/49], d_loss: 1.3354, g_loss: 4.2694
Epoch [11/10000], Iter [46/49], d_loss: 0.1867, g_loss: 1.6243
Epoch [11/10000], Iter [47/49], d_loss: 0.9568, g_loss: 7.8603
Epoch [11/10000], Iter [48/49], d_loss: 1.2480, g_loss: 3.4000
Model saving...
Epoch [12/10000], Iter [0/49], d_loss: 0.6897, g_loss: 4.1980
Epoch [12/10000], Iter [1/49], d_loss: 0.4198, g_loss: 4.8722
Epoch [12/10000], Iter [2/49], d_loss: 0.4221, g_loss: 2.5688
Epoch [12/10000], Iter [3/49], d_loss: 0.4946, g_loss: 4.5396
Epoch [12/10000], Iter [4/49], d_loss: 0.56

In [None]:
path_dir = './testvideo'

if not os.path.exists(path_dir):
    os.mkdir(path_dir)

# Sample in eval mode so BatchNorm uses its running statistics instead of the
# statistics of a single sample (and does not overwrite those running stats).
# The previous mode is restored afterwards so re-running training still works.
was_training = G.training
G.eval()

try:
    for class_idx in range(6):
        z = Variable(torch.randn(1, 100) * 0.01).type(dtype)

        # One-hot condition vector of shape (1, 6) selecting `class_idx`.
        label = torch.zeros(1, 6)
        label[0, class_idx] = 1
        label = Variable(label).type(dtype)

        # Generate the clip once; [0] drops the batch axis -> (channels, frames, H, W).
        video = G(z, label).data[0]
        num_frames = video.size(1)

        frame_paths = [os.path.join(path_dir, 'test' + str(class_idx) + '_' + str(i + 1) + '.png')
                       for i in range(num_frames)]

        # NOTE(review): the generator ends in tanh; if the training data is scaled
        # to [-1, 1], negative values are clipped here -- confirm with the data loader.
        for i, frame_path in enumerate(frame_paths):
            # Positional path: the keyword is `filename` in old torchvision, `fp` in new.
            torchvision.utils.save_image(video[:, i], frame_path)

        make_gif(root=path_dir, output='output' + str(class_idx) + '.gif', fps=16)

        for frame_path in frame_paths:
            os.remove(frame_path)
finally:
    G.train(was_training)