In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import random
import copy
import math
import os
import argparse


def update_ema(ema_model, model, ema_rate):
    """Blend ``model``'s parameters into ``ema_model`` in place.

    Every EMA parameter becomes ``ema_rate * ema + (1 - ema_rate) * current``.
    Only ``parameters()`` are blended (buffers are not touched); parameters are
    paired positionally, so both models are assumed to share one architecture.
    """
    src_weight = 1 - ema_rate
    ema_params = ema_model.parameters()
    live_params = model.parameters()
    for ema_p, live_p in zip(ema_params, live_params):
        ema_p.data.mul_(ema_rate).add_(live_p.data * src_weight)


class AverageMeter:
    """Running weighted mean of scalar values.

    Attributes (all reset to 0 by ``reset``):
        sum: weighted sum of the values recorded so far.
        cnt: total weight recorded so far.
        avg: ``sum / cnt`` after the latest ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget every recorded value."""
        self.avg, self.sum, self.cnt = 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh ``avg``."""
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt

In [2]:
# CIFAR-10 splits. Training images are randomly flipped horizontally; both
# splits are converted to tensors by ToTensor. download=True fetches the
# archive into DATA_ROOT when it is not already there.
DATA_ROOT = '../data'

test_transform = transforms.Compose([transforms.ToTensor()])
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

train_dataset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=test_transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ../data/cifar-10-python.tar.gz to ../data
Files already downloaded and verified


In [3]:
class GaussianFourierProjection(nn.Module):
    """Random Fourier features of a per-sample scalar (e.g. diffusion time).

    Frequencies are drawn once as ``init_scale * N(0, 1)`` and stored as a
    parameter that is trainable only when ``learnable`` is True. The output
    always has exactly ``embed_dim`` features: the sin half, the cos half, and
    one zero column when ``embed_dim`` is odd.

    Args:
        embed_dim: number of output features.
        init_scale: multiplier applied to the standard-normal frequencies.
        learnable: whether the frequencies receive gradients.
    """

    def __init__(self, embed_dim, init_scale=1., learnable=False):
        super().__init__()
        self.embed_dim = embed_dim
        emb = init_scale * torch.randn(embed_dim // 2)
        self.emb = nn.Parameter(emb, requires_grad=learnable)

    def forward(self, t):
        """Map a 1-D tensor ``t`` of shape (B,) to features of shape (B, embed_dim)."""
        t = 2. * np.pi * t[:, None] * self.emb[None, :]
        out = torch.cat([torch.sin(t), torch.cos(t)], dim=-1)
        if self.embed_dim % 2 == 1:
            # sin/cos only give 2 * (embed_dim // 2) features; pad one zero
            # column so the width always matches the declared embed_dim.
            out = F.pad(out, (0, 1))
        return out


class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal embedding of a per-sample scalar (here, diffusion time).

    The input is scaled by ``0.9 * max_positions`` and multiplied by the
    geometric frequency ladder ``exp(-k * log(max_positions) / (half_dim - 1))``
    for ``k = 0 .. half_dim - 1``. The output always has exactly ``embed_dim``
    features: the sin half, the cos half, and one zero column when
    ``embed_dim`` is odd.

    The frequencies are kept as a frozen ``nn.Parameter`` (not a buffer), so the
    parameter list and state-dict key are unchanged and previously saved
    model/optimizer checkpoints still load.

    Args:
        embed_dim: number of output features.
        max_positions: sets both the input scaling and the frequency range.
    """

    def __init__(self, embed_dim, max_positions=10000):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_positions = max_positions
        self.half_dim = embed_dim // 2
        # Guard half_dim == 1 (embed_dim 2 or 3), where the ladder's
        # denominator (half_dim - 1) would be zero.
        exponent = math.log(max_positions) / max(self.half_dim - 1, 1)
        freqs = torch.exp(torch.arange(self.half_dim).float() * -exponent)
        self.emb = nn.Parameter(freqs, requires_grad=False)

    def forward(self, t):
        """Map a 1-D tensor ``t`` of shape (B,) to features of shape (B, embed_dim)."""
        t = 0.9 * self.max_positions * t[:, None] * self.emb[None, :]
        out = torch.cat([torch.sin(t), torch.cos(t)], dim=-1)
        if self.embed_dim % 2 == 1:
            # sin/cos only give 2 * half_dim features; pad so the width matches
            # embed_dim (downstream layers are built with in_features=embed_dim).
            out = F.pad(out, (0, 1))
        return out

In [4]:
class Upsize(nn.Module):
    """Double the spatial resolution, then apply a 3x3 convolution.

    Nearest-neighbour interpolation with scale factor 2 is followed by a
    stride-1, padding-1 conv from ``cin`` to ``cout`` channels (``cout``
    defaults to ``cin``).
    """

    def __init__(self, cin, cout=None):
        super().__init__()
        if cout is None:
            cout = cin
        self.conv = nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(upsampled)


class Downsize(nn.Module):
    """Halve the spatial resolution with a strided 3x3 convolution.

    The input is zero-padded by one pixel on the right and bottom only, then
    convolved with stride 2 and no implicit padding, mapping ``cin`` to
    ``cout`` channels (``cout`` defaults to ``cin``).
    """

    def __init__(self, cin, cout=None):
        super().__init__()
        out_channels = cin if cout is None else cout
        self.conv = nn.Conv2d(cin, out_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        # Pad tuple is (left, right, top, bottom) for the last two dims.
        padded = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)

In [5]:
def get_signal_scale(log_snr):
    """Signal coefficient ``alpha = sqrt(sigmoid(log_snr))``.

    Paired with ``sigma = sqrt(sigmoid(-log_snr))`` this gives
    ``alpha**2 + sigma**2 == 1`` because ``sigmoid(x) + sigmoid(-x) == 1``.
    """
    return torch.sqrt(torch.sigmoid(log_snr))

def get_noise_scale(log_snr):
    """Noise coefficient ``sigma = sqrt(sigmoid(-log_snr))``."""
    return torch.sqrt(torch.sigmoid(-log_snr))


class LinearLogSnrDiffuser(object):
    """Variance-preserving diffusion whose log-SNR is linear in time.

    ``log_snr(t)`` interpolates linearly from ``log_snr0`` at t=0 (about 9.2,
    almost clean) to ``log_snr1`` at t=1 (about -10, almost pure noise).
    Signal/noise scales come from ``get_signal_scale`` / ``get_noise_scale``.
    """

    def __init__(self):
        # Lower bound for sample_time, so t == 0 is never drawn.
        self.min_time = 1e-5
        # log_snr = -log(expm1(g)) at g = 1e-4 (t=0) and g = 1e-4 + 10 (t=1).
        self.log_snr0 = np.log(1 / (np.exp(1e-4) - 1))
        self.log_snr1 = np.log(1 / (np.exp(1e-4 + 10) - 1))

    def sample_time(self, batch_size):
        """Draw ``batch_size`` times uniformly from [min_time, 1)."""
        return self.min_time + torch.rand(batch_size) * (1. - self.min_time)

    def log_snr(self, t):
        """Log signal-to-noise ratio at time ``t`` (linear in ``t``)."""
        return self.log_snr0 + (self.log_snr1 - self.log_snr0) * t

    def forward_process(self, x, t, noise=None):
        """Sample ``z_t = alpha_t * x + sigma_t * noise``.

        Args:
            x: clean batch with the batch on dim 0.
            t: per-sample times of shape (B,).
            noise: optional pre-drawn noise shaped like ``x``; N(0, I) if None.

        Returns:
            Tuple ``(z_t, noise)``.
        """
        noise = torch.randn_like(x) if noise is None else noise
        # (B,) -> (B, 1, ..., 1) so the scales broadcast over x.
        log_snr = self.log_snr(t).unflatten(-1, (-1, *[1 for _ in x.size()[1:]]))
        alpha = get_signal_scale(log_snr)
        sigma = get_noise_scale(log_snr)
        return alpha * x + sigma * noise, noise

    @torch.no_grad()
    def reverse_process(self, z, model, num_steps):
        """Ancestral sampling from t=1 down to t=0 in ``num_steps`` uniform steps.

        Each step t -> s samples the Gaussian posterior q(z_s | z_t, x), written
        in terms of the model's noise prediction.

        Args:
            z: starting noise, batch on dim 0.
            model: callable ``model(z, time) -> predicted noise`` shaped like ``z``.
            num_steps: number of denoising steps (must be >= 1).

        Returns:
            Clean-sample estimate formed from the final z and the last noise
            prediction.

        Raises:
            ValueError: if ``num_steps < 1``.
        """
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")

        time = torch.ones(z.size()[0]).to(z.device)
        reversed_time = torch.linspace(1., 0., num_steps + 1)
        for i in range(num_steps):  # t -> s
            t = reversed_time[i]
            s = reversed_time[i + 1]
            log_snr_s = self.log_snr(s)
            log_snr_t = self.log_snr(t)
            alpha_s = get_signal_scale(log_snr_s)
            sigma_s = get_noise_scale(log_snr_s)
            alpha_t = get_signal_scale(log_snr_t)
            sigma_t = get_noise_scale(log_snr_t)
            # q(z_t | z_s) = N(alpha_s_to_t * z_s, sigma_s_to_t**2 I)
            alpha_s_to_t = alpha_t / alpha_s
            sigma_s_to_t = torch.sqrt(sigma_t.square() - alpha_s_to_t.square() * sigma_s.square())
            # Standard deviation of the posterior q(z_s | z_t, x).
            sigma_t_to_s = torch.sqrt(sigma_s.square() * sigma_s_to_t.square() / sigma_t.square())
            # Posterior mean in epsilon form:
            # (z_t - sigma_s_to_t**2 / sigma_t * eps) / alpha_s_to_t
            prev_coeff = 1 / alpha_s_to_t
            noise_coeff = - sigma_s_to_t.square() / (alpha_s_to_t * sigma_t)

            time.fill_(t)
            noise_pred = model(z, time)
            z = prev_coeff * z + noise_coeff * noise_pred + torch.randn_like(z) * sigma_t_to_s

        # Only the last step's estimate is returned, so compute it once.
        # NOTE(review): this pairs the new z_s with the eps predicted at time t.
        x_pred = (z - sigma_s * noise_pred) / alpha_s
        return x_pred

In [6]:
class Resnet(nn.Module):
    """Pre-activation residual block with optional time conditioning.

    Residual path: GroupNorm -> SiLU -> 3x3 conv (in_ch -> cmid); add a linear
    projection of the time embedding (when given); GroupNorm -> SiLU ->
    Dropout -> 3x3 conv (cmid -> out_ch). The shortcut is the identity when
    ``in_ch == out_ch`` and a 1x1 conv otherwise.

    Args:
        in_ch, out_ch: input / output channel counts.
        num_groups: GroupNorm groups; must divide both ``in_ch`` and ``cmid``.
        time_dim: width of the time embedding, or None for no conditioning.
        dropout_ratio: dropout probability before the final conv.
        bottleneck_ratio: ``cmid = int(bottleneck_ratio * in_ch)``.
    """

    def __init__(self, in_ch, out_ch, num_groups=8, time_dim=None, dropout_ratio=0.05, bottleneck_ratio=0.5):
        super().__init__()
        cmid = int(bottleneck_ratio * in_ch)

        self.res_in = nn.Sequential(
            nn.GroupNorm(num_groups, in_ch),
            nn.SiLU(inplace=True),
            nn.Conv2d(in_ch, cmid, 3, 1, 1),
        )

        if time_dim is None:
            self.time = None
        else:
            self.time = nn.Linear(time_dim, cmid)

        self.res_out = nn.Sequential(
            nn.GroupNorm(num_groups, num_channels=cmid),
            nn.SiLU(inplace=True),
            nn.Dropout(p=dropout_ratio),
            nn.Conv2d(cmid, out_ch, 3, 1, 1),
        )

        if in_ch == out_ch:
            self.skip = None
        else:
            self.skip = nn.Conv2d(in_ch, out_ch, 1, 1, 0)

    def forward(self, x, t=None):
        """Apply the block; ``t`` is the time embedding (e.g. (B, time_dim)) or None."""
        h = self.res_in(x)
        if t is not None:
            # Trailing (1, 1) dims let the per-channel shift broadcast over H, W.
            h = h + self.time(t).unflatten(-1, (-1, 1, 1))
        h = self.res_out(h)
        shortcut = x if self.skip is None else self.skip(x)
        return shortcut + h

In [7]:
class Attention(nn.Module):
    """Multi-head self-attention over spatial positions, with a residual add.

    ``x`` is group-normalised and projected by a bias-free 1x1 conv into q, k, v
    of ``hidden_size = int(bottleneck_ratio * channels)`` channels each. These
    are split into ``num_heads`` heads and attended over all H*W positions,
    then projected back to ``channels`` by a 1x1 conv.

    Note: the softmax runs over dim -2 (the q index), so output position j is
    ``sum_i softmax_i(scale * q_i . k_j) * v_i``. This is standard attention
    with the roles of q and k exchanged.
    """

    def __init__(self, channels, num_heads=4, num_groups=8, bottleneck_ratio=0.5):
        super().__init__()
        self.hidden_size = int(bottleneck_ratio * channels)
        self.channels = channels
        self.num_heads = num_heads
        self.group_norm = nn.GroupNorm(num_groups, channels)
        self.proj = nn.Conv2d(channels, 3 * self.hidden_size, 1, 1, 0, bias=False)
        self.out = nn.Conv2d(self.hidden_size, channels, 1, 1, 0)
        head_dim = self.hidden_size // self.num_heads
        self.scale = head_dim ** -0.5

    def forward(self, x):
        batch, _, height, width = x.shape
        qkv = self.proj(self.group_norm(x))
        # (B, 3*hidden, H, W) -> (B, 3*heads, head_dim, H*W); the first `heads`
        # slices along dim 1 are q, the next k, the last v.
        qkv = qkv.reshape(batch, 3 * self.num_heads, -1, height * width)
        q, k, v = qkv.chunk(3, dim=1)
        # logits[..., i, j] = scale * <q_i, k_j>
        logits = torch.matmul(q.transpose(-2, -1), k * self.scale)
        attn = logits.softmax(dim=-2)
        # (B, heads, head_dim, H*W) -> (B, hidden, H, W)
        mixed = torch.matmul(v, attn).reshape(batch, -1, height, width)
        return x + self.out(mixed)

In [8]:
class UNet(nn.Module):
    """Time-conditioned U-Net used as the noise-prediction network.

    Reads from ``args.arch``: ``input_channels``, ``output_channels``,
    ``channels``, ``time_dim`` and the per-level lists ``ch_mults``,
    ``n_layers``, ``use_attn``, ``resample`` (indexed by level ``l``).

    Encoder level ``l``:
      * Downsize (if ``resample[l]``) or a 1x1 conv to ``ch_mults[l] * channels``.
      * Then ``n_layers[l]`` x [Resnet, Attention if ``use_attn[l]``].
      * Every encoder Resnet output is pushed as a skip connection.

    Decoder, visiting the levels in reverse:
      * ``n_layers[l]`` x [Attention if ``use_attn[l]``, Resnet on the channel
        concatenation with the popped skip].
      * Then Upsize (if ``resample[l]``) or a 1x1 conv to the previous level's width.
    """

    def __init__(self, args):
        super().__init__()

        self.channels = args.arch.channels
        self.time_dim = args.arch.time_dim

        self.input_conv = nn.Sequential(
            nn.Conv2d(args.arch.input_channels, self.channels, 3, 1, padding='same'),
            nn.SiLU())

        # Sinusoidal time features followed by a 2-layer MLP; the result
        # conditions every Resnet.
        self.time_embed = nn.Sequential(
            PositionalEmbedding(self.channels),
            nn.Linear(self.channels, self.time_dim),
            nn.SiLU(inplace=True),
            nn.Linear(self.time_dim, self.time_dim),
            nn.SiLU())

        self.enc_tower = []
        for l, n_layers in enumerate(args.arch.n_layers):
            in_ch = self.channels if l == 0 else args.arch.ch_mults[l - 1] * self.channels
            ch = args.arch.ch_mults[l] * self.channels
            block = Downsize(in_ch, ch) if args.arch.resample[l] else nn.Conv2d(in_ch, ch, 1, 1, 0)
            self.enc_tower += [block]

            for _ in range(n_layers):
                self.enc_tower += [Resnet(ch, ch, time_dim=self.time_dim)]
                self.enc_tower += [Attention(ch)] if args.arch.use_attn[l] else []

        self.enc_tower = nn.ModuleList(self.enc_tower)
        mid_ch = args.arch.ch_mults[-1] * self.channels
        self.mid1 = Resnet(mid_ch, mid_ch, time_dim=self.time_dim)
        self.mid2 = Attention(mid_ch, bottleneck_ratio=0.25)
        self.mid3 = Resnet(mid_ch, mid_ch, time_dim=self.time_dim)

        self.dec_tower = []
        for l in reversed(range(len(args.arch.n_layers))):
            # Use this level's own layer count so the decoder consumes exactly
            # the skips the encoder produced at this level. Reusing the encoder
            # loop's leftover value mis-pairs skips when n_layers varies by level.
            n_layers = args.arch.n_layers[l]
            ch = args.arch.ch_mults[l] * self.channels
            out_ch = self.channels if l == 0 else args.arch.ch_mults[l - 1] * self.channels

            for _ in range(n_layers):
                self.dec_tower += [Attention(ch)] if args.arch.use_attn[l] else []
                self.dec_tower += [Resnet(2 * ch, ch, time_dim=self.time_dim)]

            block = Upsize(ch, out_ch) if args.arch.resample[l] else nn.Conv2d(ch, out_ch, 1, 1, 0)
            self.dec_tower += [block]

        self.dec_tower = nn.ModuleList(self.dec_tower)
        self.output_conv = nn.Sequential(
            nn.SiLU(),
            nn.Conv2d(self.channels, args.arch.output_channels, 1, 1, 0))

    def forward(self, x, t):
        """Predict the output map for ``x`` at times ``t``.

        Args:
            x: (B, input_channels, H, W) input; H and W are presumably divisible
                by 2 per resampling level so skip shapes line up.
            t: (B,) times, embedded by ``time_embed``.

        Returns:
            Tensor with ``output_channels`` channels.
        """
        x = self.input_conv(x)
        t = self.time_embed(t)

        skips = []
        for enc in self.enc_tower:
            if isinstance(enc, Resnet):
                x = enc(x, t)
                skips.append(x)
            else:
                x = enc(x)

        x = self.mid1(x, t)
        x = self.mid2(x)
        x = self.mid3(x, t)

        for dec in self.dec_tower:
            if isinstance(dec, Resnet):
                x = dec(torch.cat([x, skips.pop()], dim=1), t)
            else:
                x = dec(x)

        return self.output_conv(x)

In [9]:
args = argparse.Namespace()
args.arch = argparse.Namespace()
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.seed = 0

# Seed every RNG the notebook touches (weight init, noise draws) so a
# Restart & Run All reproduces the same numbers.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Optimisation
args.batch_size = 64
args.learning_rate = 1e-4
args.num_epochs = 100
args.ema_rate = 0.9999

# Architecture: ch_mults / n_layers / use_attn / resample have one entry per
# resolution level.
args.arch.input_channels = 3
args.arch.output_channels = 3
args.arch.channels = 128
args.arch.time_dim = args.arch.channels * 4
args.arch.ch_mults = [1, 2, 4, 8]
args.arch.n_layers = [4, 4, 4, 4]
args.arch.use_attn = [False, False, True, False]
args.arch.resample = [False, True, True, True]

model = UNet(args).to(args.device)
# The EMA copy is kept out of autograd; the optimizer below only sees model.parameters().
ema_model = copy.deepcopy(model)
ema_model.requires_grad_(False)

diffuser = LinearLogSnrDiffuser()
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
loss_meter = AverageMeter()
loss_meter = AverageMeter()

In [10]:
# Restore model, optimizer and EMA weights (Colab Google Drive path).
CKPT_PATH = os.path.join("/content/drive/My Drive", "vdm_cifar10.pth")
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime too.
# torch.load unpickles the file: only load checkpoints from a trusted source.
state_dict = torch.load(CKPT_PATH, map_location=args.device)
model.load_state_dict(state_dict['model'])
optimizer.load_state_dict(state_dict['optimizer'])
ema_model.load_state_dict(state_dict['ema_model'])

<All keys matched successfully>

In [11]:
@torch.no_grad()

def evaluate_bpd(data, model, num_steps, scheduler):
    time = torch.ones(data.size()[0]).to(data.device)
    reversed_time = torch.linspace(1., 0., num_steps + 1)
    dt = reversed_time[0] - reversed_time[1]


    elbo = 0.

    # z = torch.randn_like(data)

    for i in range(num_steps): # t -> s
        t = reversed_time[i]        
        s = reversed_time[i + 1]
        log_snr_s = scheduler.log_snr(s)
        log_snr_t = scheduler.log_snr(t)
        alpha_s = get_signal_scale(log_snr_s)
        sigma_s = get_noise_scale(log_snr_s)
        alpha_t = get_signal_scale(log_snr_t)
        sigma_t = get_noise_scale(log_snr_t)
        alpha_s_to_t = alpha_t / alpha_s
        sigma_s_to_t = torch.sqrt(sigma_t.square() - alpha_s_to_t.square() * sigma_s.square())
        sigma_t_to_s = torch.sqrt(sigma_s.square() * sigma_s_to_t.square() / sigma_t.square())
        
        z = alpha_t * data + sigma_t * torch.randn_like(data)
        true_mu = ((alpha_s_to_t / sigma_s_to_t.square()) * z + (alpha_s / sigma_s.square()) * data) * sigma_t_to_s.square()

        time.fill_(t)
        noise_pred = model(z, time)
        pred_x0 = x_pred = (z - sigma_t * noise_pred) / alpha_t
        pred_mu = ((alpha_s_to_t / sigma_s_to_t.square()) * z + (alpha_s / sigma_s.square()) * pred_x0) * sigma_t_to_s.square()

        if i != num_steps - 1:
            kld = 0.5 * (true_mu - pred_mu).square().mean(dim=0).sum() / sigma_t_to_s.square()
            elbo += kld
            # z = true_mu + sigma_t_to_s * torch.randn_like(data)
        
        else:
            log_prob = - 0.5 * (data - pred_mu).square().mean(dim=0).sum() / sigma_t_to_s.square()
            log_prob = log_prob - 0.5 * torch.log(np.sqrt(2. * np.pi) * sigma_t_to_s)
            elbo -= log_prob  

    bpd = elbo / (3 * 32 * 32 * np.log(2))
    return bpd


tensor(3.3233, device='cuda:0')


In [13]:
# Bits/dim of the EMA model over the full CIFAR-10 test split.
num_eval_steps = 25
loader = DataLoader(test_dataset, batch_size=100, shuffle=False)

ema_model.eval()
avg_bpd = AverageMeter()
for x, _ in loader:
    x = x.to(args.device)
    x = 2. * x - 1.  # ToTensor yields [0, 1]; rescale to [-1, 1] before scoring
    bpd = evaluate_bpd(x, ema_model, num_eval_steps, diffuser)
    # Weight by batch size so a smaller final batch is not over-counted.
    avg_bpd.update(bpd.item(), n=x.size(0))

avg_bpd.avg

3.2841094566296927
