In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import torchvision

In [2]:
def nonlinearity(x):
    """Apply the SiLU (swish) activation element-wise and return the result."""
    activated = F.silu(x)
    return activated

In [3]:
def normalize(x, temb, name):
    return nn.GroupNorm(num_groups=32, num_channels=x.shape[1], eps=1e-6, affine=True)(x)

In [4]:
def conv2d(x, num_units, kernel_size=3, stride=1, init_scale=1.0):
    """Apply a freshly built, weight-normalised 2-D convolution to `x`.

    Args:
        x: input tensor whose dim 1 is the channel dimension.
        num_units: number of output channels.
        kernel_size: square kernel size; padding is `kernel_size // 2`.
        stride: convolution stride.
        init_scale: multiplier applied to the initial kernel magnitude;
            `0.` yields an all-zero kernel (the output is then just the bias).

    Returns:
        The convolution output.

    Note:
        A new layer is constructed on every call, so its parameters are not
        retained between calls.
    """
    in_channels = x.shape[1]
    conv = nn.Conv2d(in_channels, num_units, kernel_size, stride, padding=kernel_size // 2)
    # Initialise BEFORE wrapping: weight_norm re-derives `weight` from
    # weight_g / weight_v on every forward, so edits made to `conv.weight`
    # after wrapping are overwritten and never reach the output.
    nn.init.kaiming_normal_(conv.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
    conv = weight_norm(conv)
    with torch.no_grad():
        # Scale the magnitude (g) rather than the direction (v) so that
        # init_scale=0 gives a zero kernel instead of a 0/0 direction.
        conv.weight_g.mul_(init_scale)
    return conv(x)

In [5]:
def upsample(x, with_conv):
    """Double the spatial size of a 4-D tensor with nearest-neighbour interpolation.

    When `with_conv` is truthy the upsampled tensor is additionally passed
    through `conv2d` (3x3, stride 1) keeping the channel count of `x`.
    """
    _, channels, _, _ = x.shape
    doubled = F.interpolate(x, scale_factor=2, mode='nearest')
    if not with_conv:
        return doubled
    return conv2d(doubled, num_units=channels, kernel_size=3, stride=1)

In [6]:
def downsample(x, with_conv):
    """Halve the spatial size of `x`.

    Uses `conv2d` (3x3, stride 2, same channel count) when `with_conv` is
    truthy, otherwise 2x2 average pooling with stride 2.
    """
    if not with_conv:
        return F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)
    channels = x.shape[1]
    return conv2d(x, num_units=channels, kernel_size=3, stride=2)

In [7]:
def nin(x, num_units):
    """Apply a freshly built, weight-normalised 1x1 convolution to a 4-D tensor.

    The input channel count is read from dim 1 of `x`; the output has
    `num_units` channels. A new layer is created on each call.
    """
    _, in_channels, _, _ = x.shape
    pointwise = nn.Conv2d(in_channels, num_units, kernel_size=1, stride=1, padding=0)
    return weight_norm(pointwise)(x)

In [8]:
def resnet_block(x, temb, out_ch=None, conv_shortcut=False, dropout=0.0, training=True):
    """Residual block conditioned on a timestep embedding.

    Args:
        x: 4-D input tensor; dim 1 is the channel dimension.
        temb: 2-D timestep embedding; it is projected to `out_ch` features
            and added to the feature map as a per-channel bias.
        out_ch: number of output channels; defaults to the input channel count.
        conv_shortcut: when the channel count changes, use a 3x3 `conv2d`
            for the skip path instead of the 1x1 `nin`.
        dropout: dropout probability applied before the second convolution.
        training: whether dropout is active. Defaults to True, matching the
            previous always-on behaviour; pass False for evaluation.

    Returns:
        `x` (possibly projected to `out_ch` channels) plus the residual branch.
    """
    B, C, H, W = x.shape
    if out_ch is None:
        out_ch = C

    h = nonlinearity(normalize(x, temb, name='norm1'))
    h = conv2d(h, num_units=out_ch)
    # `nn.linear` does not exist in torch.nn; project the embedding with the
    # file's `dense` helper and broadcast it over the spatial dimensions.
    h = h + dense(nonlinearity(temb), out_ch)[:, :, None, None]

    h = nonlinearity(normalize(h, temb, name='norm2'))
    h = F.dropout(h, p=dropout, training=training)
    h = conv2d(h, num_units=out_ch, init_scale=0.)

    # Match the skip path to the residual branch when the width changes.
    if C != out_ch:
        if conv_shortcut:
            x = conv2d(x, num_units=out_ch)
        else:
            x = nin(x, out_ch)

    return x + h

In [9]:
def dense(x, num_units):
    """Apply a freshly built, weight-normalised linear layer to the last dim of `x`.

    The input feature count is taken from `x.shape[-1]`; the output has
    `num_units` features. A new layer is created on each call.
    """
    in_features = x.shape[-1]
    layer = weight_norm(nn.Linear(in_features, num_units))
    return layer(x)

In [10]:
def attn_block(x, temb):
    """Single-head spatial self-attention over a (B, C, H, W) feature map.

    Args:
        x: 4-D input tensor laid out as (batch, channels, height, width).
        temb: timestep embedding; only forwarded to `normalize`.

    Returns:
        `x` plus the attention output, with the same shape as `x`.
    """
    B, C, H, W = x.shape
    h = normalize(x, temb=temb, name='norm')
    q = nin(h, C)
    k = nin(h, C)
    v = nin(h, C)

    # einsum labels are case-sensitive: the channel axis must carry the SAME
    # letter in both operands to be contracted as a dot product. Lower-case
    # h/w index query positions, upper-case H/W index key positions.
    w = torch.einsum('bchw,bcHW->bhwHW', q, k) * (C ** -0.5)
    # reshape (not view): the einsum result is not guaranteed to be contiguous.
    w = w.reshape(B, H, W, H * W)
    w = F.softmax(w, dim=-1)  # normalise over all key positions
    w = w.reshape(B, H, W, H, W)

    # `v` is channels-first like q and k, and the result is kept channels-first
    # so that the 1x1 projection and the residual add see (B, C, H, W).
    h = torch.einsum('bhwHW,bcHW->bchw', w, v)
    h = nin(h, C)

    return x + h

In [11]:
def get_timestep_embedding(t, dim):
    """Build sinusoidal embeddings for a batch of timesteps.

    Args:
        t: 1-D tensor of timesteps, shape (B,).
        dim: embedding width.

    Returns:
        Float tensor of shape (B, dim) on the same device as `t`: the first
        `dim // 2` columns are sines, the next `dim // 2` are cosines, and an
        odd `dim` is completed with one trailing zero column.
    """
    half_dim = dim // 2
    # Create the frequencies on t's device so the product below does not mix devices.
    positions = torch.arange(half_dim, dtype=torch.float32, device=t.device)
    log_base = torch.log(torch.tensor(10000.0, device=t.device))
    freqs = torch.exp(positions * -(log_base / half_dim))
    angles = t.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if dim % 2 == 1:
        # sin/cos halves only give dim - 1 columns; pad so the width equals `dim`.
        emb = F.pad(emb, (0, 1))
    return emb


In [12]:
def _wn_conv2d(in_ch, out_ch, kernel_size=3, stride=1, init_scale=1.0):
    """Build a weight-normalised Conv2d with Kaiming-initialised kernel and scaled magnitude."""
    conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding=kernel_size // 2)
    nn.init.kaiming_normal_(conv.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
    conv = weight_norm(conv)
    with torch.no_grad():
        # Scaling g (not v) lets init_scale=0 produce a zero kernel without a 0/0 direction.
        conv.weight_g.mul_(init_scale)
    return conv


class _AttnBlock(nn.Module):
    """Single-head spatial self-attention with a residual connection."""

    def __init__(self, ch):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups=32, num_channels=ch, eps=1e-6, affine=True)
        self.q = weight_norm(nn.Conv2d(ch, ch, kernel_size=1))
        self.k = weight_norm(nn.Conv2d(ch, ch, kernel_size=1))
        self.v = weight_norm(nn.Conv2d(ch, ch, kernel_size=1))
        self.proj_out = weight_norm(nn.Conv2d(ch, ch, kernel_size=1))

    def forward(self, x, temb=None):
        # `temb` is accepted only so every layer can be called the same way.
        B, C, H, W = x.shape
        h = self.norm(x)
        q = self.q(h).reshape(B, C, H * W)
        k = self.k(h).reshape(B, C, H * W)
        v = self.v(h).reshape(B, C, H * W)
        # i indexes query positions, j indexes key positions.
        w = torch.einsum('bci,bcj->bij', q, k) * (C ** -0.5)
        w = F.softmax(w, dim=-1)
        h = torch.einsum('bij,bcj->bci', w, v).reshape(B, C, H, W)
        return x + self.proj_out(h)


class _ResnetBlock(nn.Module):
    """Timestep-conditioned residual block, optionally followed by attention."""

    def __init__(self, in_ch, out_ch, temb_ch, dropout, with_attn=False):
        super().__init__()
        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_ch, eps=1e-6, affine=True)
        self.conv1 = _wn_conv2d(in_ch, out_ch)
        self.temb_proj = weight_norm(nn.Linear(temb_ch, out_ch))
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_ch, eps=1e-6, affine=True)
        self.drop = nn.Dropout(dropout)  # follows model.train() / model.eval()
        self.conv2 = _wn_conv2d(out_ch, out_ch, init_scale=0.)
        # 1x1 projection on the skip path only when the width changes.
        if in_ch != out_ch:
            self.shortcut = weight_norm(nn.Conv2d(in_ch, out_ch, kernel_size=1))
        else:
            self.shortcut = nn.Identity()
        self.attn = _AttnBlock(out_ch) if with_attn else None

    def forward(self, x, temb):
        h = self.conv1(F.silu(self.norm1(x)))
        h = h + self.temb_proj(F.silu(temb))[:, :, None, None]
        h = self.conv2(self.drop(F.silu(self.norm2(h))))
        h = self.shortcut(x) + h
        if self.attn is not None:
            h = self.attn(h)
        return h


class _Downsample(nn.Module):
    """Halve the spatial size with a stride-2 conv or 2x2 average pooling."""

    def __init__(self, ch, with_conv):
        super().__init__()
        self.conv = _wn_conv2d(ch, ch, kernel_size=3, stride=2) if with_conv else None

    def forward(self, x, temb=None):
        if self.conv is not None:
            return self.conv(x)
        return F.avg_pool2d(x, kernel_size=2, stride=2, padding=0)


class _Upsample(nn.Module):
    """Double the spatial size with nearest-neighbour interpolation and an optional 3x3 conv."""

    def __init__(self, ch, with_conv):
        super().__init__()
        self.conv = _wn_conv2d(ch, ch, kernel_size=3, stride=1) if with_conv else None

    def forward(self, x, temb=None):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        if self.conv is not None:
            x = self.conv(x)
        return x


class Model(nn.Module):
    """U-Net denoiser conditioned on a diffusion timestep.

    Args:
        num_classes: stored on the instance; class conditioning is not implemented.
        ch: base channel width. Every GroupNorm uses 32 groups, so `ch` must
            be a multiple of 32.
        out_ch: number of output channels.
        ch_mult: per-resolution channel multipliers; level `i` uses
            `ch * ch_mult[i]` channels.
        num_res_blocks: residual blocks per level on the way down
            (`num_res_blocks + 1` on the way up).
        attn_resolutions: attention follows each residual block at level `i`
            when `2 ** i` is contained in this collection.
        dropout: dropout probability inside residual blocks.
        resamp_with_conv: use convolutions (rather than pooling / plain
            interpolation) when changing resolution.

    The input is expected to have 3 channels (fixed by `conv_in`).
    """

    def __init__(self, num_classes, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks=2, attn_resolutions=(), dropout=0.0, resamp_with_conv=True):
        super(Model, self).__init__()
        self.num_classes = num_classes
        self.ch = ch
        self.out_ch = out_ch
        self.ch_mult = ch_mult
        self.num_res_blocks = num_res_blocks
        self.attn_resolutions = attn_resolutions
        self.dropout = dropout
        self.resamp_with_conv = resamp_with_conv

        temb_ch = ch * 4
        self.temb_dense_0 = weight_norm(nn.Linear(ch, temb_ch))
        self.temb_dense_1 = weight_norm(nn.Linear(temb_ch, temb_ch))

        self.conv_in = weight_norm(nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1))

        self.down = nn.ModuleList()
        self.up = nn.ModuleList()
        num_resolutions = len(ch_mult)

        # `skip_chs` mirrors, in order, the channel count of every tensor that
        # forward() pushes onto the skip stack, so the up path can size its inputs.
        cur_ch = ch
        skip_chs = [ch]
        for i_level in range(num_resolutions):
            level_ch = ch * ch_mult[i_level]
            use_attn = 2 ** i_level in attn_resolutions
            for _ in range(num_res_blocks):
                self.down.append(_ResnetBlock(cur_ch, level_ch, temb_ch, dropout, with_attn=use_attn))
                cur_ch = level_ch
                skip_chs.append(cur_ch)
            if i_level != num_resolutions - 1:
                self.down.append(_Downsample(cur_ch, resamp_with_conv))
                skip_chs.append(cur_ch)

        self.mid = nn.ModuleList([
            _ResnetBlock(cur_ch, cur_ch, temb_ch, dropout),
            _AttnBlock(cur_ch),
            _ResnetBlock(cur_ch, cur_ch, temb_ch, dropout),
        ])

        for i_level in reversed(range(num_resolutions)):
            level_ch = ch * ch_mult[i_level]
            use_attn = 2 ** i_level in attn_resolutions
            for _ in range(num_res_blocks + 1):
                # Each up block consumes the current features concatenated with one skip.
                self.up.append(_ResnetBlock(cur_ch + skip_chs.pop(), level_ch, temb_ch, dropout, with_attn=use_attn))
                cur_ch = level_ch
            if i_level != 0:
                self.up.append(_Upsample(cur_ch, resamp_with_conv))

        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=cur_ch, eps=1e-6, affine=True)
        self.conv_out = weight_norm(nn.Conv2d(cur_ch, out_ch, kernel_size=3, stride=1, padding=1))

    def forward(self, x, t, y=None):
        """Run the network on images `x` at timesteps `t`.

        Args:
            x: image batch of shape (B, 3, H, W).
            t: 1-D tensor of timesteps, shape (B,).
            y: must be None; class conditioning is not supported.

        Returns:
            Tensor of shape (B, out_ch, H, W).
        """
        assert y is None, 'not supported'
        temb = get_timestep_embedding(t, self.ch)
        return self._forward_from_embedding(x, temb)

    def _forward_from_embedding(self, x, temb):
        """Run the U-Net given the raw (B, ch) timestep embedding."""
        temb = self.temb_dense_0(temb)
        temb = self.temb_dense_1(F.silu(temb))

        h = self.conv_in(x)
        hs = [h]

        for layer in self.down:
            h = layer(h, temb)
            hs.append(h)

        for layer in self.mid:
            h = layer(h, temb)

        for layer in self.up:
            if isinstance(layer, _Upsample):
                # Resolution changes do not consume a skip connection.
                h = layer(h, temb)
            else:
                h = layer(torch.cat([h, hs.pop()], dim=1), temb)
        assert not hs, 'unconsumed skip connections'

        h = F.silu(self.norm_out(h))
        h = self.conv_out(h)
        return h

In [15]:
# CIFAR-10 training split read from a local directory; no download is attempted.
dataset = torchvision.datasets.CIFAR10(
    root='../../dataset',
    train=True,
    download=False,
)

In [17]:
import functools

import fire
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusion_pytorch import utils
from diffusion_pytorch.diffusion_utils_2 import get_beta_schedule, GaussianDiffusion2
from diffusion_pytorch.models import unet
from diffusion_pytorch.tpu_utils import tpu_utils, datasets, simple_eval_worker


class Model(nn.Module):
    """Couples a `GaussianDiffusion2` process with a named denoiser network.

    Exposes training loss, sampling, progressive sampling and bits-per-dim
    evaluation, each of which hands `_denoise` to the diffusion object.

    NOTE(review): an earlier cell of this notebook also defines a class named
    `Model`; executing this cell rebinds that name.
    """

    def __init__(self, *, model_name, betas, model_mean_type: str, model_var_type: str, loss_type: str,
                 num_classes: int, dropout: float, randflip):
        """Store the configuration and build the diffusion process from `betas`.

        Args:
            model_name: key selecting the denoiser architecture in `_denoise`.
            betas: passed to `GaussianDiffusion2`.
            model_mean_type: passed to `GaussianDiffusion2`.
            model_var_type: passed to `GaussianDiffusion2`; the value
                'learned' doubles the denoiser's output channels.
            loss_type: passed to `GaussianDiffusion2`.
            num_classes: forwarded to the denoiser.
            dropout: dropout rate used by `train_fn`.
            randflip: when truthy, `train_fn` randomly mirrors training images.
        """
        super(Model, self).__init__()
        self.model_name = model_name
        self.diffusion = GaussianDiffusion2(
            betas=betas, model_mean_type=model_mean_type, model_var_type=model_var_type, loss_type=loss_type)
        self.num_classes = num_classes
        self.dropout = dropout
        self.randflip = randflip

    def _denoise(self, x, t, y, dropout):
        """Validate the inputs and dispatch to the denoiser selected by `model_name`.

        Raises:
            NotImplementedError: if `model_name` is not a known architecture.
        """
        B, C, H, W = x.shape
        assert x.dtype == torch.float32
        assert t.shape == (B,) and t.dtype in [torch.int32, torch.int64]
        assert y.shape == (B,) and y.dtype in [torch.int32, torch.int64]
        out_ch = (C * 2) if self.diffusion.model_var_type == 'learned' else C
        # Labels are validated above but deliberately not passed on.
        y = None
        if self.model_name == 'unet2d16b2':  # 35.7M
            return unet.model(
                x, t=t, y=y, name='model', ch=128, ch_mult=(1, 2, 2, 2), num_res_blocks=2, attn_resolutions=(16,),
                out_ch=out_ch, num_classes=self.num_classes, dropout=dropout
            )
        raise NotImplementedError(self.model_name)

    def train_fn(self, x, y):
        """Compute the mean training loss for one batch at random timesteps.

        Args:
            x: image batch of shape (B, C, H, W).
            y: labels, forwarded to `_denoise`.

        Returns:
            dict with key 'loss' holding the scalar mean loss.
        """
        B, C, H, W = x.shape
        if self.randflip:
            # Random horizontal flip: mirror each image independently with
            # probability 0.5 (previously every image was always mirrored).
            flip_mask = torch.rand(B, device=x.device) < 0.5
            x = torch.where(flip_mask[:, None, None, None], torch.flip(x, dims=[-1]), x)
            assert x.shape == (B, C, H, W)
        t = torch.randint(0, self.diffusion.num_timesteps, (B,), dtype=torch.int32, device=x.device)
        losses = self.diffusion.training_losses(
            denoise_fn=functools.partial(self._denoise, y=y, dropout=self.dropout), x_start=x, t=t)
        assert losses.shape == t.shape == (B,)
        return {'loss': losses.mean()}

    def samples_fn(self, dummy_noise, y):
        """Sample with the diffusion's `p_sample_loop`; only the shape of `dummy_noise` is used."""
        return {
            'samples': self.diffusion.p_sample_loop(
                denoise_fn=functools.partial(self._denoise, y=y, dropout=0),
                shape=dummy_noise.shape,
                noise_fn=torch.randn
            )
        }

    def progressive_samples_fn(self, dummy_noise, y):
        """Sample with `p_sample_loop_progressive`, returning final and intermediate samples."""
        samples, progressive_samples = self.diffusion.p_sample_loop_progressive(
            denoise_fn=functools.partial(self._denoise, y=y, dropout=0),
            shape=dummy_noise.shape,
            noise_fn=torch.randn
        )
        return {'samples': samples, 'progressive_samples': progressive_samples}

    def bpd_fn(self, x, y):
        """Run the diffusion's `calc_bpd_loop` on `x` and return its four outputs by name."""
        total_bpd_b, terms_bpd_bt, prior_bpd_b, mse_bt = self.diffusion.calc_bpd_loop(
            denoise_fn=functools.partial(self._denoise, y=y, dropout=0),
            x_start=x
        )
        return {
            'total_bpd': total_bpd_b,
            'terms_bpd': terms_bpd_bt,
            'prior_bpd': prior_bpd_b,
            'mse': mse_bt
        }


ModuleNotFoundError: No module named 'fire'