In [None]:
# Modified from https://github.com/Harrypotterrrr/DVD-GAN
# Warning: this notebook needs about 25 GB of RAM; on a machine with less memory it will crash.

In [2]:
!pip install torchdiffeq

Collecting torchdiffeq
  Downloading torchdiffeq-0.2.2-py3-none-any.whl (31 kB)
Installing collected packages: torchdiffeq
Successfully installed torchdiffeq-0.2.2


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

from torchdiffeq import odeint_adjoint, odeint

# Normalization

In [4]:
class SpectralNorm(nn.Module):
    """Wrapper that rescales a weight of ``module.layer`` by a spectral-norm estimate.

    The wrapped ``module`` must expose a ``layer`` attribute that owns the
    parameter called ``name``. On construction that parameter is replaced by
    three parameters on ``layer``: ``<name>_bar`` (the raw weight) and
    ``<name>_u`` / ``<name>_v`` (power-iteration vectors, ``requires_grad=False``).
    Every forward pass refreshes ``u``/``v`` and sets ``layer.<name>`` to
    ``<name>_bar / sigma`` before delegating to ``module.forward``.

    Args:
        module: module whose ``layer`` sub-module holds the weight.
        name: name of the weight attribute on ``module.layer``.
        power_iterations: number of power-iteration steps per forward pass.
    """

    def __init__(self, module, name='weight', power_iterations=1):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        # Only split the weight once, even if the layer was wrapped before.
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Run the power iteration and write the rescaled weight onto the layer."""
        layer = self.module.layer
        u = getattr(layer, self.name + "_u")
        v = getattr(layer, self.name + "_v")
        w = getattr(layer, self.name + "_bar")

        height = w.data.shape[0]
        w_mat = w.view(height, -1)  # weight flattened to a 2-D matrix
        for _ in range(self.power_iterations):
            # The iteration itself works on detached data (``.data``).
            v.data = l2normalize(torch.mv(torch.t(w_mat.data), u.data))
            u.data = l2normalize(torch.mv(w_mat.data, v.data))

        # sigma = u^T W v, computed on the non-detached weight.
        sigma = u.dot(w_mat.mv(v))
        setattr(layer, self.name, w / sigma.expand_as(w))

    def _made_params(self):
        """Return True when ``_u``, ``_v`` and ``_bar`` already exist on the layer."""
        try:
            layer = self.module.layer
            for suffix in ("_u", "_v", "_bar"):
                getattr(layer, self.name + suffix)
        except AttributeError:
            return False
        return True

    def _make_params(self):
        """Replace ``layer.<name>`` with the ``_u``, ``_v`` and ``_bar`` parameters."""
        layer = self.module.layer
        w = getattr(layer, self.name)

        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]

        # Random start vectors for the power iteration, then scaled to unit norm.
        u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
        v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
        u.data = l2normalize(u.data)
        v.data = l2normalize(v.data)
        w_bar = Parameter(w.data)

        # Drop the original parameter; forward() re-creates it as a plain attribute.
        del layer._parameters[self.name]

        for suffix, param in (("_u", u), ("_v", v), ("_bar", w_bar)):
            layer.register_parameter(self.name + suffix, param)

    def forward(self, *args):
        """Refresh the normalised weight, then call the wrapped module's forward."""
        self._update_u_v()
        return self.module.forward(*args)

def l2normalize(v, eps=1e-12):
    """Divide ``v`` by its norm; ``eps`` is added to the norm to avoid dividing by zero."""
    denominator = v.norm() + eps
    return v / denominator


class ConditionalNorm(nn.Module):
    """Batch normalisation whose scale and shift are predicted from a condition vector.

    The input is normalised by a non-affine ``BatchNorm2d``; a linear layer maps
    the condition to ``2 * in_channel`` values that are split into a per-channel
    scale (gamma) and shift (beta).

    Args:
        in_channel: number of channels of the feature map to normalise.
        n_condition: length of the condition vector fed to the linear layer.
    """

    def __init__(self, in_channel, n_condition=96):
        super().__init__()

        self.in_channel = in_channel
        self.bn = nn.BatchNorm2d(in_channel, affine=False)

        self.embed = nn.Linear(n_condition, 2 * in_channel)
        # NOTE(review): nn.Linear stores its weight as (out_features, in_features),
        # so these slices split along the condition (input) axis, not along the
        # gamma/beta output rows -- confirm this initialisation is intended.
        weight = self.embed.weight.data
        weight[:, :in_channel].normal_(1, 0.02)
        weight[:, in_channel:].zero_()

    def forward(self, x, class_id):
        """Return ``gamma * bn(x) + beta`` with gamma/beta computed from ``class_id``."""
        normalized = self.bn(x)
        # First half of the embedding is the scale, second half the shift.
        gamma, beta = self.embed(class_id).chunk(2, 1)
        per_channel = (-1, self.in_channel, 1, 1)  # broadcast over the two trailing dims
        return gamma.view(per_channel) * normalized + beta.view(per_channel)

# ODE function

In [5]:
class Conv2dODE(nn.Module):
    """2-D convolution with an ``(t, x)`` call signature and optional zero augmentation.

    The convolution itself always maps ``out_channel`` to ``out_channel``
    channels. When ``in_channel < out_channel`` the input is first padded with
    ``out_channel - in_channel`` all-zero channels. The (padded) input is
    multiplied by ``t`` before the convolution is applied.

    Args:
        in_channel: number of channels expected in the incoming tensor.
        out_channel: number of channels the convolution consumes and produces.
        ksize: kernel size forwarded to ``nn.Conv2d``.
        stride: stride forwarded to ``nn.Conv2d``.
        padding: padding forwarded to ``nn.Conv2d``.
        bias: whether the convolution has a bias term.
    """

    def __init__(self, in_channel, out_channel, ksize=3, stride=1,
                 padding=0, bias=True):
        super().__init__()
        # Kept so forward() knows how many zero channels to append.
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.layer = nn.Conv2d(out_channel, out_channel, ksize, stride, padding, bias=bias)

    def forward(self, t, x):
        """Apply the convolution to ``x * t`` after optional zero-channel padding.

        Args:
            t: value the input is multiplied by (a scalar or broadcastable tensor).
            x: 4-D input tensor with channels on dimension 1.

        Returns:
            The output of the wrapped ``nn.Conv2d``.
        """
        if self.in_channel < self.out_channel:
            batch, _, dim2, dim3 = x.size()
            # new_zeros inherits x's device and dtype, so the concatenation also
            # works when x is not a default-dtype CPU tensor.
            zeros_aug = x.new_zeros((batch, self.out_channel - self.in_channel, dim2, dim3))
            x = torch.cat((x, zeros_aug), 1)
        x = x * t
        return self.layer(x)


In [6]:
class ODEFunc(nn.Module):
    """Function evaluated by the ODE solver: conv -> (conditional norm) -> activation -> conv.

    Both convolutions map ``out_channel`` to ``out_channel`` channels and are
    wrapped in ``SpectralNorm``. ``in_channel`` is only stored on the instance;
    it is not used by the layers built here.

    Args:
        in_channel: channel count stored as ``self.in_channel``.
        out_channel: channel count of both convolutions.
        kernel_size: convolution kernel size; defaults to ``[3, 3]``.
        padding: convolution padding.
        stride: convolution stride.
        n_class: length of the condition vector given to ``ConditionalNorm``.
        bn: whether to apply conditional normalisation; forced to ``False``
            when ``downsample_factor != 1``.
        activation: callable applied between the two convolutions.
        upsample_factor: stored as ``self.upsample_factor``; forced to ``1``
            when ``downsample_factor != 1``.
        downsample_factor: when not ``1``, the output is average-pooled by
            this factor.
    """

    def __init__(self, in_channel, out_channel, kernel_size=None,
                 padding=1, stride=1, n_class=96, bn=True,
                 activation=F.relu, upsample_factor=2, downsample_factor=1):
        super().__init__()

        # Compare by value (not identity) so this agrees with the `!= 1`
        # test in forward() for any numeric type.
        no_downsample = downsample_factor == 1
        self.upsample_factor = upsample_factor if no_downsample else 1
        self.downsample_factor = downsample_factor
        self.activation = activation
        self.bn = bn if no_downsample else False
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.n_class = n_class

        # Number of function evaluations; incremented on every forward call.
        self.nfe = 0

        if kernel_size is None:
            kernel_size = [3, 3]

        self.conv0 = SpectralNorm(Conv2dODE(out_channel, out_channel,
                                            kernel_size, stride, padding,
                                            bias=True))

        self.conv1 = SpectralNorm(Conv2dODE(out_channel, out_channel,
                                            kernel_size, stride, padding,
                                            bias=True))

        # Keyed on the constructor argument rather than self.bn, so the layer
        # (and its state_dict entries) exists even when downsampling turns
        # normalisation off in forward().
        if bn:
            self.CBNorm2 = ConditionalNorm(out_channel, n_class)

    def forward(self, t, x, condition):
        """Evaluate the function at time ``t`` for state ``x``.

        Args:
            t: time value passed on to both convolutions.
            x: current state tensor.
            condition: condition tensor for the conditional normalisation.

        Returns:
            The transformed tensor, average-pooled when ``downsample_factor != 1``.
        """
        self.nfe += 1
        out = self.conv0(t, x)
        if self.bn:
            out = self.CBNorm2(out, condition)
        out = self.activation(out)
        out = self.conv1(t, out)
        if self.downsample_factor != 1:
            out = F.avg_pool2d(out, self.downsample_factor)
        return out


# ODE Block

In [7]:
class ODEBlock(nn.Module):
    """Pre-processes the input and integrates ``odefunc`` over the time points ``[0, 1]``.

    Pre-processing, in order: optional conditional normalisation, the
    activation of ``odefunc``, optional upsampling by
    ``odefunc.upsample_factor``, and zero-channel padding from
    ``odefunc.in_channel`` up to ``odefunc.out_channel``.

    Args:
        odefunc: module called as ``odefunc(t, x, condition)``; it must expose
            ``bn``, ``activation``, ``upsample_factor``, ``in_channel``,
            ``out_channel``, ``n_class`` and ``nfe``.
    """

    def __init__(self, odefunc):
        super().__init__()
        self.odefunc = odefunc
        # Time points handed to the solver; cast to the input's type in forward().
        self.integration_time = torch.tensor([0, 1]).float()

        if self.odefunc.bn:
            self.CBNorm1 = ConditionalNorm(self.odefunc.in_channel, self.odefunc.n_class)  # TODO 2 x noise.size[1]

    def forward(self, x, condition):
        """Return the solver output at the second time point.

        Args:
            x: 4-D input tensor with channels on dimension 1.
            condition: condition tensor for the normalisation layers.

        Returns:
            Entry ``[1]`` of the ``odeint`` result.
        """
        odefunc = self.odefunc
        out = x
        if odefunc.bn:
            out = self.CBNorm1(out, condition)

        out = odefunc.activation(out)
        if odefunc.upsample_factor != 1:
            out = F.interpolate(out, scale_factor=odefunc.upsample_factor)

        if odefunc.in_channel < odefunc.out_channel:
            batch, _, dim2, dim3 = out.size()
            # new_zeros inherits out's device and dtype, so the concatenation
            # also works when the input is not a default-dtype CPU tensor.
            zeros_aug = out.new_zeros((batch, odefunc.out_channel - odefunc.in_channel, dim2, dim3))
            out = torch.cat((out, zeros_aug), 1)

        self.integration_time = self.integration_time.type_as(x)

        def dynamics(t, state):
            # Bind the condition so the solver sees a two-argument function.
            return odefunc(t, state, condition)

        out = odeint(dynamics, out, self.integration_time)
        return out[1]

    @property
    def nfe(self):
        """Number of function evaluations recorded by ``odefunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value

In [None]:
# Smoke test: push a random batch of frames through one ODE block and print the shapes.
n_class, batch_size, n_frames = 96, 4, 20

gResBlock = ODEFunc(3, 100, [3, 3])
odeGResBlock = ODEBlock(gResBlock)

# One row per frame: batch_size * n_frames images of 3 x 64 x 64.
x = torch.rand(batch_size * n_frames, 3, 64, 64)
# One condition vector per batch item, tiled n_frames times to match x's first dimension.
condition = torch.rand(batch_size, n_class).repeat(n_frames, 1)

print(x.shape)
y = odeGResBlock(x, condition)
print(x.shape)
print(y.shape)

torch.Size([80, 3, 64, 64])
torch.Size([80, 3, 128, 128])
out torch.Size([80, 100, 128, 128])
