In [4]:
import os

import numpy as np

import torch
import torch.nn as nn

from torch.autograd import Variable

import torchvision.utils

from new_data_loader import get_loader
from make_gif import make_gif

In [None]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.layer_video = nn.Conv3d(in_channels=3, out_channels=32, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
        self.layer_y = nn.Conv3d(in_channels=6, out_channels=32, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
        
        self.discriminator = nn.Sequential(
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Conv3d(in_channels=64, out_channels=128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(num_features=128, eps=1e-03),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Conv3d(in_channels=128, out_channels=256, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(num_features=256, eps=1e-03),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Conv3d(in_channels=256, out_channels=512, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(num_features=512, eps=1e-03),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),

            nn.Conv3d(in_channels=512, out_channels=2, kernel_size=(2, 4, 4), stride=(1, 1, 1), padding=(0, 0, 0)),
        )

    
    def forward(self, video, y):
        out_video = self.layer_video(video)
        out_y = self.layer_y(y)
                             
        out_cat = torch.cat([out_video, out_y], 1)
                             
        out = self.discriminator(out_cat)
                             
        return out

In [None]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.layer_3d_video = nn.ConvTranspose3d(in_channels=100, out_channels=256, kernel_size=(2,4,4))
        self.layer_2d_video = nn.ConvTranspose2d(in_channels=100, out_channels=256, kernel_size=4, stride=1, padding=0)

        self.layer_3d_y = nn.ConvTranspose3d(in_channels=6, out_channels=256, kernel_size=(2,4,4))        
        self.layer_2d_y = nn.ConvTranspose2d(in_channels=6, out_channels=256, kernel_size=4, stride=1, padding=0)
        
        self.net_video = nn.Sequential(
            nn.BatchNorm3d(num_features=512),
            nn.ReLU(inplace=True),

            nn.ConvTranspose3d(in_channels=512, out_channels=256, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(num_features=256),
            nn.ReLU(inplace=True),

            nn.ConvTranspose3d(in_channels=256, out_channels=128, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(num_features=128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose3d(in_channels=128, out_channels=64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.BatchNorm3d(num_features=64),
            nn.ReLU(inplace=True)
        )

        self.gen_net = nn.Sequential(
            nn.ConvTranspose3d(in_channels=64, out_channels=3, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.Tanh()
        )

        self.mask_net = nn.Sequential(
            nn.ConvTranspose3d(in_channels=64, out_channels=1, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1)),
            nn.Sigmoid()
        )

        self.static_net = nn.Sequential(
            nn.BatchNorm2d(num_features=512),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=256),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z, y):
        
        z_forward =  z.view(-1, 100, 1, 1, 1)
        z_backward = z.view(-1, 100, 1, 1)
        
        y_forward =  y.view(-1, 6, 1, 1, 1)
        y_backward = y.view(-1, 6, 1, 1)
        
        out_3d_video = self.layer_3d_video(z_forward)
        out_2d_video = self.layer_2d_video(z_backward)
        
        out_3d_y = self.layer_3d_y(y_forward)
        out_2d_y = self.layer_2d_y(y_backward)

        out_cat_3d = torch.cat([out_3d_video, out_3d_y],1)
        out_cat_2d = torch.cat([out_2d_video, out_2d_y],1)
        
        m_net_video = self.net_video(out_cat_3d)
        
        m_gen_net = self.gen_net(m_net_video)
        m_mask_net = self.mask_net(m_net_video)
        
        m_static_net = self.static_net(out_cat_2d)

        
        foreground = m_gen_net

        mask = m_mask_net.expand(-1, 3, 32, 64, 64)

        background = m_static_net.view(-1, 3, 1, 64, 64).expand(-1, 3, 32, 64, 64)
        
        video = foreground * mask + background * (1 - mask)

        return video


In [None]:
def init_weights(m) :
    name = type(m)

    if name == nn.Conv3d or name == nn.ConvTranspose2d or name == nn.ConvTranspose3d :
        m.weight.data.normal_(0.0, 0.01)
        m.bias.data.fill_(0)
    elif name == nn.BatchNorm2d or name == nn.BatchNorm3d :
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [None]:
pre_train = True

batch_size = 64
video_size = 64
epoch_size = 10
        
#check GPU
is_gpu = torch.cuda.is_available()
print(is_gpu)

if is_gpu :
    dtype = torch.cuda.FloatTensor
else :
    dtype = torch.FloatTensor

if pre_train :
    D = torch.load('D.ckpt').type(dtype)
    G = torch.load('G.ckpt').type(dtype)
else :
    D = Discriminator()
    D = D.type(dtype)

    G = Generator()
    G = G.type(dtype)

    D.apply(init_weights)
    G.apply(init_weights)

    criterion = nn.BCEWithLogitsLoss().type(dtype)

    d_optimizer = torch.optim.Adam(D.parameters(), lr=2e-3, betas=(0.5, 0.999))
    g_optimizer = torch.optim.Adam(G.parameters(), lr=2e-3, betas=(0.5, 0.999))

In [None]:
data_loader = get_loader(data_path='./dataset', image_size=video_size, batch_size=batch_size, num_workers=2)

for epoch in range(1, epoch_size + 1) :

    for iter, (video, y) in enumerate(data_loader) :
        local_batch_size = video.size()[0]
        
        if is_gpu :
            real_labels = Variable(torch.ones(local_batch_size, 2).type(dtype))
            fake_labels = Variable(torch.zeros(local_batch_size, 2).type(dtype))
        else :
            real_labels = Variable(torch.ones(local_batch_size, 2).type(dtype))
            fake_labels = Variable(torch.zeros(local_batch_size, 2).type(dtype))
        
        # 1. Train Discriminator
        video_data = Variable(video).type(dtype)
        y = Variable(y).type(dtype)
        
        y_data = torch.unsqueeze(torch.unsqueeze(torch.unsqueeze(y, -1), -1), -1)
        y_data = y_data.expand(local_batch_size, 6, 32, 64, 64)
       
        
        
        # 1-1. Real Video
        outputs = D(video_data, y_data).view(local_batch_size, 2)
        d_loss_real = criterion(outputs, real_labels)

        
        
        # 1-2. Fake Video
        z = Variable(torch.randn(local_batch_size, 100) * 0.01).type(dtype)
        fake_videos = G(z, y)
        outputs = D(fake_videos, y_data).view(local_batch_size, 2)
        d_loss_fake = criterion(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        
        D.zero_grad()
        G.zero_grad()
        d_loss.backward()
        d_optimizer.step()





        # 2. Train Generator
        z = Variable(torch.randn(local_batch_size, 100) * 0.01).type(dtype)
        fake_videos = G(z, y)
        outputs = D(fake_videos, y_data).view(local_batch_size, 2)

        g_loss = criterion(outputs, real_labels)
        D.zero_grad()
        G.zero_grad()

        g_loss.backward()
        g_optimizer.step()
        
        if iter % 300 == 0 :       
            print('Epoch [%d/%d], Iter [%d/%d], d_loss: %.4f, g_loss: %.4f' % (epoch, epoch_size, iter, len(data_loader), d_loss.data[0], g_loss.data[0]))
            
            torch.save(D, 'D.ckpt')
            torch.save(G, 'G.ckpt')

In [40]:
path_dir = './testvideo'

if not os.path.exists(path_dir) :
    os.mkdirs(path_dir)

for classes in range(6) :        
    z = Variable(torch.randn(100) * 0.01).type(dtype)
    
    label =  torch.zeros(1, 6)
    label.scatter_(1, torch.LongTensor([[i]]), 1)
    label.transpose_(1, 0)    
    label = Variable(label).type(dtype)
    
    for i in range(32) :
        fake_video = torch.squeeze(G(label, z))[:,i,:,:]
        torchvision.utils.save_image(tensor=fake_video.data, filename=path_dir + '/test' + str(classes) + "_" + str(i+1) + ".png")
        
    make_gif(root=path_dir, output='output' + str(classes) + '.gif', fps=16)
    
    os.remove(path_dir + '/*.png')    

SyntaxError: invalid syntax (<ipython-input-40-47096b9e7ec5>, line 3)