### Imports and Kaggle setup

In [None]:
!pip install POT -q
!pip install wandb -q

from torchvision import datasets
import torchvision.datasets.utils as dataset_utils
import torchvision

from ot.bregman import sinkhorn
from ot.lp import emd

from ignite.metrics import SSIM, PSNR

import yaml
import sys
import os

from PIL import Image
import matplotlib.pyplot as plt

import cv2
import numpy as np
import os
import random
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset, DataLoader

from functools import partial

import torch
import torch.nn as nn

from abc import abstractmethod
import math
import torch.nn.functional as F
import torch.optim as optim

from tqdm import tqdm
import copy

import wandb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(1234)
random.seed(1234)
np.random.seed(1234)
torch.cuda.manual_seed(1234)

torch.backends.cudnn.deterministic=True

import warnings
warnings.filterwarnings("ignore")

#for metrics
from ignite.engine import *
from ignite.handlers import *
from ignite.metrics import *
from ignite.utils import *
from ignite.contrib.metrics.regression import *
from ignite.contrib.metrics import *

def eval_step(engine, batch):
    """Identity process function for an ignite Engine: its output is the batch itself."""
    del engine  # unused; only present to match the Engine process-function signature
    return batch

default_evaluator = Engine(eval_step)

In [None]:
ls

In [None]:
%%writefile kaggle.json
{"username":"kaggle3223","key":"de631c18f08a192190f69129affaaa2e"}

In [None]:
!mkdir ~/.kaggle/
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json

In [None]:
device

In [None]:
# !wandb login

### Create config

In [None]:
%%writefile conf.yml

MODE : 2   # 1 Train, 2 Metrics
IMAGE_SIZE : [32, 32] #[224,224]
CHANNEL_X : 3
CHANNEL_Y : 3
TIMESTEPS : 1000 #2000
MODEL_CHANNELS : 128
NUM_RESBLOCKS : 4
ATTENTION_RESOLUTIONS : [2,4,8]
DROPOUT : 0
CHANNEL_MULT : [1,2,4,8]
CONV_RESAMPLE : 'True'
USE_CHECKPOINT : 'False'
USE_FP16 : 'False'
NUM_HEADS : 1
NUM_HEAD_CHANNELS : 64
NUM_HEAD_UPSAMPLE : -1
USE_SCALE_SHIFT_NORM : 'False'
RESBLOCK_UPDOWN : 'False'
USE_NEW_ATTENTION_ORDER : 'False'
PATH_COLOR : "./ab/ab/ab1.npy" #'color'
PATH_GREY : "./l/gray_scale.npy"  #'grayscale'
BATCH_SIZE : 64
BATCH_SIZE_VAL : 8
ITERATION_MAX : 20000
LR : 0.0001
LOSS : 'L2'
VALIDATION_EVERY : 1000
EMA_EVERY : 100
START_EMA : 2000
SAVE_MODEL_EVERY : 5000 #10000
PLOT_EVERY : 50
EXP_NAME: "shoes_bags_fixed_unregularized"

In [None]:
class Config(dict):
    """
    Attribute-style, read-only access to a YAML configuration file.

    ``config.KEY`` returns the parsed value for ``KEY``, or ``None`` when the
    key is missing or null. A ``PATH`` entry holding the directory of the
    config file is added on load. The dict base class itself is left empty;
    values live in ``self._dict``.

    NOTE(review): conf.yml stores flags such as ``'False'`` as quoted strings,
    which are truthy in Python; they are returned unchanged here.
    """

    def __init__(self, config_path):
        with open(config_path, 'r', encoding='utf-8') as f:
            self._yaml = f.read()
        # An empty file parses to None; treat it as an empty mapping.
        self._dict = yaml.safe_load(self._yaml) or {}
        self._dict['PATH'] = os.path.dirname(config_path)

    def __getattr__(self, name):
        # __getattr__ only runs after normal attribute lookup fails. Private and
        # dunder names must raise AttributeError. Otherwise protocol probes
        # (e.g. copy/pickle asking for '__setstate__' on an instance whose
        # '_dict' is not set yet) would re-enter this method through
        # self._dict and recurse until RecursionError.
        if name.startswith('_'):
            raise AttributeError(name)
        return self._dict.get(name)

    def print(self):
        print('Model configurations:')
        print('---------------------------------')
        print(self._yaml)
        print('')
        print('---------------------------------')
        print('')


def load_config(path):
    """Parse the YAML file at ``path`` and return it wrapped in a ``Config``."""
    return Config(path)

### Load data and build a paired dataset with optimal-transport (OT) matching

In [None]:
# for shoes handbags
!FILEID='1i1F462P45I2w3lIFL-u8gwmcPm3iEAGb' && \
FILENAME='shoes_bags_data.zip' && \
FILEDEST="https://docs.google.com/uc?export=download&id=${FILEID}" && \
wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate ${FILEDEST} -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=${FILEID}" -O $FILENAME && rm -rf /tmp/cookies.txt

!unzip shoes_bags_data.zip

In [None]:
def load_shoes_bags(data_path="./shoes_tensor_32.torch", total_size=11000, test_size=1000):
    """
    Load a saved tensor dataset and return a random train/test split.

    :param data_path: file written with ``torch.save`` containing a tensor whose
        first dimension indexes samples.
    :param total_size: number of samples drawn without replacement from the file.
    :param test_size: how many of the drawn samples go to the test split; the
        remaining ``total_size - test_size`` form the train split.
    :return: ``(train_data, test_data)`` tensors.
    :raises ValueError: if ``test_size`` is negative or not smaller than
        ``total_size`` (the train split would be empty). ``np.random.choice``
        also raises ValueError if ``total_size`` exceeds the number of samples.
    """
    if not 0 <= test_size < total_size:
        raise ValueError(
            f"test_size must satisfy 0 <= test_size < total_size, "
            f"got test_size={test_size}, total_size={total_size}"
        )
    data = torch.load(data_path)

    random_idx = np.random.choice(np.arange(len(data)), size=total_size, replace=False)
    data = data[random_idx]

    # random_idx is already in random order and has been applied above, so this
    # shuffle does not affect the split. It is kept only so the global NumPy RNG
    # advances exactly as before and seeded runs remain reproducible.
    np.random.shuffle(random_idx)

    n_train = total_size - test_size
    return data[:n_train], data[n_train:]


class ShoeBagDataset(Dataset):
    """Index-aligned paired dataset: item ``i`` is ``(data_from[i], data_to[i])``."""

    def __init__(self, data_from, data_to):
        self.data_from, self.data_to = data_from, data_to

    def __len__(self):
        # The source collection defines the dataset length.
        return len(self.data_from)

    def __getitem__(self, idx: int):
        """Return the (source, target) pair stored at position ``idx``."""
        source = self.data_from[idx]
        target = self.data_to[idx]
        return source, target

In [None]:
def collate_fn(batch) -> tuple:
    """
    Collate (source, target) pairs, re-pairing targets via exact optimal transport.

    Builds the cost C[i, j] = ||from_i - to_j||^2 / 2, divided by the number of
    elements per sample, between every source and target in the batch. It then
    solves the discrete OT problem with uniform marginals using POT's ``emd``.
    For each source i it samples one target index with probability
    proportional to row i of the transport plan.

    :param batch: sequence of ``(data_from, data_to)`` pairs; all items must
        share one shape so they can be stacked.
    :return: ``(data_from, data_to)`` float32 tensors of shape [B, ...]; row i
        of the second tensor is the target transported to source i.
    """
    batch_size = len(batch)

    data_from = torch.stack([torch.as_tensor(pair[0]) for pair in batch]).to(torch.float32)
    data_to = torch.stack([torch.as_tensor(pair[1]) for pair in batch]).to(torch.float32)

    flat_from = data_from.reshape(batch_size, -1)
    flat_to = data_to.reshape(batch_size, -1)
    distance = (torch.cdist(flat_from, flat_to) ** 2) / 2
    # Normalise by the per-sample element count (3*32*32 for the 32x32 RGB
    # data this was written for). A constant positive rescaling does not change
    # the optimal plan; it only keeps the cost magnitudes moderate.
    distance = distance.numpy() / flat_from.shape[1]

    marginal = np.ones(batch_size) / batch_size
    map_probs = emd(marginal, marginal, distance)

    # One draw per source row, in row order, so the global NumPy RNG is
    # consumed exactly as by one np.random.choice(size=1) call per row.
    transport_idx = [
        np.random.choice(np.arange(batch_size), size=1, replace=True, p=row / row.sum())[0]
        for row in map_probs
    ]
    return data_from, data_to[torch.as_tensor(transport_idx, dtype=torch.long)]


class ColoredDataset(Dataset):
    """Paired dataset where item ``i`` is ``(data_from[i], data_to[i])``."""

    def __init__(self, data_from, data_to):
        self.data_from = data_from
        self.data_to = data_to

    def __len__(self):
        """Number of pairs, taken from the source collection."""
        return len(self.data_from)

    def __getitem__(self, idx: int):
        """Return the pair at ``idx`` as a ``(source, target)`` tuple."""
        return tuple(collection[idx] for collection in (self.data_from, self.data_to))

In [None]:
def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

class GaussianDiffusion(nn.Module):
    def __init__(           #Remplacer par fichier config
        self,
        image_size = (224,224),
        channel_y = 3,
        channel_x = 1,
        timesteps = 2000
        ):
        
        super().__init__()
        
        self.image_size = image_size
        self.channel_y = channel_y
        self.channel_x = channel_x
        self.timesteps = timesteps
        
        betas = np.linspace(1e-6,0.01,timesteps)
        alphas = 1. - betas
        gammas = np.cumprod(alphas,axis=0)
        
        to_torch = partial(torch.tensor, dtype=torch.float32)
        
        #calculation for q(y_t|y_{t-1})
        self.register_buffer('gammas',to_torch(gammas))
        self.register_buffer('sqrt_one_minus_gammas',to_torch(np.sqrt(1-gammas)))
        self.register_buffer('sqrt_gammas',to_torch(np.sqrt(gammas)))
    
    def noisy_image(self,t,y):
        ''' Compute y_noisy according to (6) p15 of [2]'''
        noise = torch.randn_like(y)
        y_noisy = extract(self.gammas,t,y.shape)*y + extract(self.sqrt_one_minus_gammas,t,noise.shape)*noise
        return y_noisy, noise
        
    def noise_prediction(self,denoise_fn,y_noisy,x,t):
        ''' Use the NN to predict the noise added between y_{t-1} and y_t'''
        noise_pred = denoise_fn(y_noisy,x,t)
        return(noise_pred)

### Model

In [None]:
class GroupNorm32(nn.GroupNorm):
    """GroupNorm that always normalises in float32, then casts back to the input dtype."""

    def forward(self, x):
        input_dtype = x.dtype
        normalised = super().forward(x.float())
        return normalised.type(input_dtype)

def normalization(channels):
    """
    Build the standard normalisation layer used throughout the UNet.

    :param channels: number of input channels (GroupNorm requires it to be
        divisible by the group count, 32).
    :return: a ``GroupNorm32`` module with 32 groups.
    """
    num_groups = 32
    return GroupNorm32(num_groups, channels)

def zero_module(module):
    """
    Fill every parameter of ``module`` with zeros, in place.

    :return: the same module object, so it can be used inline when building layers.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module

def checkpoint(func, inputs, params, flag):
    """
    Call ``func(*inputs)``, optionally without storing intermediate activations.

    When ``flag`` is truthy, activations are recomputed in the backward pass
    (via ``CheckpointFunction``) to save memory at the cost of extra compute.

    :param func: the function to evaluate.
    :param inputs: argument sequence passed positionally to ``func``.
    :param params: parameters ``func`` depends on without taking them as
        arguments (needed so their gradients are produced).
    :param flag: if falsy, just call ``func`` directly.
    """
    if not flag:
        return func(*inputs)
    packed = (*inputs, *params)
    return CheckpointFunction.apply(func, len(inputs), *packed)

def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Build sinusoidal timestep embeddings.

    :param timesteps: 1-D tensor with one (possibly fractional) timestep per
        batch element.
    :param dim: width of each embedding.
    :param max_period: controls the lowest frequency used.
    :return: [N x dim] float32 tensor laid out as [cos(...) | sin(...)], plus
        one zero column when ``dim`` is odd.
    """
    half = dim // 2
    exponents = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponents).to(device=timesteps.device)
    phases = timesteps[:, None].float() * freqs[None]
    parts = [torch.cos(phases), torch.sin(phases)]
    if dim % 2:
        # Odd width: pad a single zero column so the result is exactly `dim` wide.
        parts.append(torch.zeros_like(phases[:, :1]))
    return torch.cat(parts, dim=-1)


class CheckpointFunction(torch.autograd.Function):
    """
    Custom autograd Function implementing activation (gradient) checkpointing.

    forward() runs ``run_function`` under ``torch.no_grad()``, so no
    intermediate activations are kept. backward() re-runs it with grad enabled
    and differentiates the recomputed outputs w.r.t. the saved inputs and
    params. Called as ``apply(run_function, length, *inputs, *params)`` by
    ``checkpoint(...)``.

    NOTE(review): unlike ``torch.utils.checkpoint``, RNG state is not saved and
    restored around the recomputation. Stochastic ops (e.g. dropout with p > 0)
    may therefore see different randomness in backward than in forward; confirm
    this is acceptable if DROPOUT is ever non-zero.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        # The first `length` extra args are the function inputs; the rest are
        # parameters the function uses implicitly, kept only for their grads.
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Detached copies that require grad, so the recomputed graph can be
        # differentiated with respect to the original inputs.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        # Release references so the recomputed activations can be freed.
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        # No gradients for run_function and length; the rest line up with *args.
        return (None, None) + input_grads

class EMA:
    """Exponential moving average of parameters: ma <- beta * ma + (1 - beta) * current."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        """Blend every parameter of ``current_model`` into ``ma_model`` (rebinds ``.data``)."""
        for averaged, live in zip(ma_model.parameters(), current_model.parameters()):
            averaged.data = self.update_average(averaged.data, live.data)

    def update_average(self, old, new):
        """Return the EMA step of ``old`` toward ``new``; ``new`` itself if ``old`` is None."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

class TimestepBlock(nn.Module):
    """
    Base class for modules whose forward() also receives timestep embeddings.

    TimestepEmbedSequential checks ``isinstance(layer, TimestepBlock)`` to decide
    whether a child is called as ``layer(x, emb)`` or ``layer(x)``. Because
    nn.Module does not use ABCMeta, ``@abstractmethod`` is not enforced here.
    """

    @abstractmethod
    def forward(self, x, emb):
        """Process ``x`` conditioned on the timestep embeddings ``emb``."""


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    nn.Sequential that forwards the timestep embedding to the children that
    accept it (instances of TimestepBlock); other children get only ``x``.
    """

    def forward(self, x, emb):
        out = x
        for child in self:
            out = child(out, emb) if isinstance(child, TimestepBlock) else child(out)
        return out

class Upsample(nn.Module):
    """
    2x nearest-neighbour upsampling, optionally followed by a 3x3 convolution.

    :param channels: channels in the inputs.
    :param use_conv: if True, apply a 3x3 conv after upsampling.
    :param out_channels: output channels of that conv (defaults to ``channels``);
        without the conv the channel count is unchanged.
    """

    def __init__(self, channels, use_conv, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        if use_conv:
            self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, x):
        assert x.shape[1] == self.channels
        upsampled = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(upsampled) if self.use_conv else upsampled

class Downsample(nn.Module):
    """
    2x spatial downsampling by a strided 3x3 conv or by 2x2 average pooling.

    :param channels: channels in the inputs.
    :param use_conv: if True, use a stride-2 conv (may change channels);
        otherwise average-pool, which requires ``out_channels == channels``.
    :param out_channels: output channels (defaults to ``channels``).
    """

    def __init__(self, channels, use_conv, out_channels=None):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        factor = 2
        if not use_conv:
            assert self.channels == self.out_channels
            self.op = nn.AvgPool2d(kernel_size=factor, stride=factor)
        else:
            self.op = nn.Conv2d(
                self.channels, self.out_channels, 3, stride=factor, padding=1
            )

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)


class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: if True, the timestep embedding is projected
        to a (scale, shift) pair applied as out_norm(h) * (1 + scale) + shift;
        otherwise the projected embedding is added to h before out_layers.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.

    NOTE: attribute names and nn.Sequential indices below determine the
    state_dict keys; renaming or reordering them breaks loading saved weights.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        # norm -> SiLU -> 3x3 conv (channels -> out_channels). With up/down,
        # _forward splits this so resampling happens before the conv.
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv2d(channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # Parameter-free resamplers (use_conv=False) for the hidden path
        # (h_upd) and the skip path (x_upd), keeping their spatial sizes equal.
        if up:
            self.h_upd = Upsample(channels, False)
            self.x_upd = Upsample(channels, False)
        elif down:
            self.h_upd = Downsample(channels, False)
            self.x_upd = Downsample(channels, False)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # Projects the timestep embedding to out_channels, or to
        # 2 * out_channels (scale and shift) when use_scale_shift_norm is set.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        # The final conv is zero-initialised, so at initialisation this branch
        # outputs zeros and the block returns skip_connection(x).
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        # Skip path: identity when channel counts match, otherwise a 3x3
        # (use_conv) or 1x1 conv to reach out_channels.
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = nn.Conv2d(
                channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        if self.updown:
            # norm + SiLU, then resample, then conv; the skip input x is
            # resampled the same way so the residual sum lines up.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append trailing singleton dims so emb_out broadcasts over h's spatial dims.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # FiLM-style conditioning applied right after the output GroupNorm.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h

class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.

    The input [B, C, *spatial] is flattened to a length-T sequence, passed
    through multi-head self-attention and added back residually. The output has
    the input's shape.

    :param channels: number of input channels C.
    :param num_heads: number of heads, used when num_head_channels == -1.
    :param num_head_channels: if not -1, fixes the width per head and sets
        num_heads = channels // num_head_channels.
    :param use_checkpoint: if True, recompute this block's activations in the
        backward pass instead of storing them.
    :param use_new_attention_order: use QKVAttention (split qkv before heads)
        instead of QKVAttentionLegacy (split heads before qkv).
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        use_checkpoint=False,
        use_new_attention_order=False,
    ):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert (
                channels % num_head_channels == 0
            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        self.norm = normalization(channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        if use_new_attention_order:
            # split qkv before split heads
            self.attention = QKVAttention(self.num_heads)
        else:
            # split heads before split qkv
            self.attention = QKVAttentionLegacy(self.num_heads)

        # Zero-initialised projection: the block starts out as the identity.
        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))

    def forward(self, x):
        # Honour use_checkpoint. It was previously hard-coded to True, which
        # silently ignored the flag. Checkpointing only trades memory for
        # recomputation; the forward result is the same either way.
        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

    def _forward(self, x):
        b, c, *spatial = x.shape
        # Flatten all spatial dims into one sequence axis: [B, C, T].
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)

def count_flops_attn(model, _x, y):
    """
    FLOP counter hook for the `thop` package covering one attention operation.
    Meant to be used like:
        macs, params = thop.profile(
            model,
            inputs=(inputs, timestamps),
            custom_ops={QKVAttention: QKVAttention.count_flops},
        )
    Adds the cost to ``model.total_ops``.
    """
    batch, channels, *spatial_dims = y[0].shape
    tokens = int(np.prod(spatial_dims))
    # Two matmuls of equal cost: one builds the attention weights, the other
    # combines the value vectors with them.
    per_matmul = batch * (tokens ** 2) * channels
    model.total_ops += torch.DoubleTensor([2 * per_matmul])

class QKVAttentionLegacy(nn.Module):
    """
    QKV self-attention that splits heads first, then q/k/v within each head
    (the legacy channel layout).
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, length = qkv.shape
        heads = self.n_heads
        assert width % (3 * heads) == 0
        head_ch = width // (3 * heads)
        per_head = qkv.reshape(batch * heads, head_ch * 3, length)
        q, k, v = per_head.split(head_ch, dim=1)
        # Scale q and k by ch**-0.25 each (product 1/sqrt(ch)) before the
        # matmul; more stable in fp16 than dividing afterwards.
        scale = 1 / math.sqrt(math.sqrt(head_ch))
        logits = torch.einsum("bct,bcs->bts", q * scale, k * scale)
        attn = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        mixed = torch.einsum("bts,bcs->bct", attn, v)
        return mixed.reshape(batch, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)


class QKVAttention(nn.Module):
    """
    QKV self-attention that splits q/k/v first, then heads within each.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Apply QKV attention.
        :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, width, length = qkv.shape
        heads = self.n_heads
        assert width % (3 * heads) == 0
        head_ch = width // (3 * heads)
        q, k, v = qkv.chunk(3, dim=1)
        # ch**-0.25 on both q and k: more stable in fp16 than dividing afterwards.
        scale = 1 / math.sqrt(math.sqrt(head_ch))
        q_heads = (q * scale).view(batch * heads, head_ch, length)
        k_heads = (k * scale).view(batch * heads, head_ch, length)
        logits = torch.einsum("bct,bcs->bts", q_heads, k_heads)
        attn = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        v_heads = v.reshape(batch * heads, head_ch, length)
        mixed = torch.einsum("bts,bcs->bct", attn, v_heads)
        return mixed.reshape(batch, -1, length)

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)

class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

    The noisy target ``y`` and the conditioning input ``x`` are concatenated
    along the channel dimension (``x`` first) before the first convolution, so
    ``in_channels`` must equal the channel count of x plus that of y.

    NOTE: submodule names and their construction order define state_dict keys;
    changing them breaks loading previously saved weights.

    :param image_size: stored on the module; not used to build any layer here.
    :param in_channels: channels in the input Tensor, for image colorization : Y_channels + X_channels .
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param use_fp16: only sets ``self.dtype``; forward() still casts the input
        to float32.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if not -1, ignore num_heads and instead use
                              a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
                                    increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = torch.float16 if use_fp16 else torch.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample

        # MLP applied to the sinusoidal timestep embedding (model_channels wide).
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim),
        )

        # Stem conv lifting the concatenated [x, y] input to the first level's width.
        ch = input_ch = int(channel_mult[0] * model_channels)
        self.input_blocks = nn.ModuleList(
            [TimestepEmbedSequential(nn.Conv2d(in_channels, ch, 3, padding=1))]
        )
        self._feature_size = ch
        # Output width of every encoder block, consumed LIFO by the decoder skips.
        input_block_chans = [ch]
        # ds = current downsampling factor; attention is added wherever
        # ds is in attention_resolutions.
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=int(mult * model_channels),
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(mult * model_channels)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # Downsample between levels (not after the deepest one).
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # Bottleneck: ResBlock -> self-attention -> ResBlock at the lowest resolution.
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # Decoder: num_res_blocks + 1 blocks per level, deepest level first;
        # each block's input is the running h concatenated with one skip (ich).
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=int(model_channels * mult),
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = int(model_channels * mult)
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                # Upsample at the end of every level except level 0.
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # Output head. input_ch equals ch here: the decoder ends at level 0,
        # whose width is int(channel_mult[0] * model_channels). The final conv
        # is zero-initialised, so the untrained model outputs zeros.
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(nn.Conv2d(input_ch, out_channels, 3, padding=1)),
        )

    def forward(self,y, x, timesteps):
        """
        Apply the model to an input batch.
        :param y: an [N x C_y x ...] Tensor of noisy target images.
        :param x: an [N x C_x x ...] Tensor of conditioning images (the original
            colorization setup used 1-channel greyscale); concatenated before y.
        :param timesteps: a 1-D batch of timesteps.
        :return: an [N x out_channels x ...] Tensor of outputs.
        """

        z = torch.cat([x,y],dim = 1)

        # Encoder activations saved for the decoder's skip connections.
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        # Computation runs in float32 regardless of use_fp16 / self.dtype.
        h = z.type(torch.float32)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)

        h = self.middle_block(h, emb)

        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb)
            
        h = h.type(z.dtype)
        return self.out(h)

### Plotting

In [None]:
def fig2data(fig):
    """
    @brief Render a Matplotlib figure and return its pixels as an RGBA numpy array
    @param fig a matplotlib figure whose canvas is Agg-based (provides buffer_rgba)
    @return a uint8 numpy array of shape (height, width, 4), channels RGBA
    """
    # Draw the renderer so the pixel buffer reflects the current figure.
    fig.canvas.draw()
    # buffer_rgba() is already RGBA and row-major (height, width, 4). The old
    # code reshaped the ARGB string to (width, height, 4), which scrambles
    # non-square figures. It also relied on np.fromstring / tostring_argb,
    # which are deprecated or removed in current NumPy / Matplotlib.
    # np.array copies, so the result does not alias canvas memory reused on redraw.
    return np.array(fig.canvas.buffer_rgba(), dtype=np.uint8)

def fig2img(fig):
    """
    @brief Convert a Matplotlib figure to a PIL RGBA image
    @param fig a matplotlib figure (Agg-based canvas)
    @return a PIL.Image in RGBA mode with the figure's pixel size
    """
    buf = fig2data(fig)
    h, w, _ = buf.shape
    # PIL sizes are (width, height); buf is row-major (height, width, 4).
    return Image.frombytes("RGBA", (w, h), buf.tobytes())


def plot_y(Grey, Y, Y_ema, Color, n_cols=8):
    """
    Plot a 4-row comparison grid: conditioning images ('Grey'), network samples
    ('Y'), EMA-network samples ('Y_ema') and ground truth ('Color').

    Each argument is an [N x C x H x W] CPU tensor, assumed to lie in [-1, 1]; it is
    mapped with x * 0.5 + 0.5 and clipped to [0, 1] for display. Only the first
    ``n_cols`` images of each batch are drawn, so every row always shows exactly one
    argument, whatever the batch size. When a batch has fewer than ``n_cols``
    images, the remaining cells are left empty.

    :return: (fig, axes) with ``axes`` of shape (4, n_cols).
    """
    rows = (('Grey', Grey), ('Y', Y), ('Y_ema', Y_ema), ('Color', Color))

    fig, axes = plt.subplots(4, n_cols, figsize=(15, 4.5), dpi=150, squeeze=False)
    for row_axes, (label, batch) in zip(axes, rows):
        imgs = batch[:n_cols].permute(0, 2, 3, 1).mul(0.5).add(0.5).numpy().clip(0, 1)
        for col, ax in enumerate(row_axes):
            if col < len(imgs):
                ax.imshow(imgs[col])
            ax.get_xaxis().set_visible(False)
            # Hide y ticks but keep the y axis so the row label stays visible.
            ax.set_yticks([])
        row_axes[0].set_ylabel(label, fontsize=24)

    fig.tight_layout(pad=0.001)
    return fig, axes


def plot_noise(ref, pred, n_cols=10):
    """
    Plot reference noise (top row, 'Reference') against predicted noise
    (bottom row, 'Prediction').

    Both arguments are [N x C x H x W] CPU tensors; they are displayed as
    x * 0.5 + 0.5 clipped to [0, 1]. At most ``n_cols`` samples per row are drawn.
    When a batch has fewer samples, the remaining cells stay empty instead of
    spilling images from one row into the other.

    :return: (fig, axes) with ``axes`` of shape (2, n_cols).
    """
    rows = (('Reference', ref), ('Prediction', pred))

    fig, axes = plt.subplots(2, n_cols, figsize=(15, 4.5), dpi=150, squeeze=False)
    for row_axes, (label, batch) in zip(axes, rows):
        imgs = batch[:n_cols].permute(0, 2, 3, 1).mul(0.5).add(0.5).numpy().clip(0, 1)
        for col, ax in enumerate(row_axes):
            if col < len(imgs):
                ax.imshow(imgs[col])
            ax.get_xaxis().set_visible(False)
            # Hide y ticks but keep the y axis so the row label stays visible.
            ax.set_yticks([])
        row_axes[0].set_ylabel(label, fontsize=24)

    fig.tight_layout(pad=0.001)
    return fig, axes

In [None]:
def print_stat(st):
    """Print the min, max, mean and standard deviation of ``st`` on one line."""
    # Evaluated in the same order as they are printed.
    summary = (
        ("Min", st.min()),
        ("Max", st.max()),
        ("Mean", st.mean()),
        ("Std", st.std()),
    )
    print(", ".join(f"{label}: {value}" for label, value in summary))

### Training

In [None]:
class Trainer():
    """Train the conditional noise-prediction UNet, keep an EMA copy of its
    weights, and periodically log, plot, checkpoint and validate.

    All hyper-parameters are read from ``config`` (loaded from ``conf.yml``).
    """

    def __init__(self,config):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.diffusion = GaussianDiffusion(config.IMAGE_SIZE,config.CHANNEL_X,config.CHANNEL_Y,config.TIMESTEPS)
        # The network receives the conditioning image and the noisy target stacked
        # channel-wise, and outputs as many channels as the target.
        in_channels = config.CHANNEL_X + config.CHANNEL_Y
        out_channels = config.CHANNEL_Y
        self.network = UNetModel(
            config.IMAGE_SIZE,
            in_channels,
            config.MODEL_CHANNELS,
            out_channels,
            config.NUM_RESBLOCKS,
            config.ATTENTION_RESOLUTIONS,
            config.DROPOUT,
            config.CHANNEL_MULT,
            config.CONV_RESAMPLE,
            config.USE_CHECKPOINT,
            config.USE_FP16,
            config.NUM_HEADS,
            config.NUM_HEAD_CHANNELS,
            config.NUM_HEAD_UPSAMPLE,
            config.USE_SCALE_SHIFT_NORM,
            config.RESBLOCK_UPDOWN,
            config.USE_NEW_ATTENTION_ORDER,
            ).to(self.device)

        # Variables are named after the file they come from. The datasets receive
        # the shoes split first and the handbags split second.
        train_set_shoes, test_set_shoes = load_shoes_bags(data_path="./shoes_tensor_32.torch")
        train_set_bags, test_set_bags = load_shoes_bags(data_path="./handbag_tensor_32.torch")

        train_dataset = ShoeBagDataset(train_set_shoes, train_set_bags)
        test_dataset = ShoeBagDataset(test_set_shoes, test_set_bags)

        self.batch_size = config.BATCH_SIZE
        self.batch_size_val = config.BATCH_SIZE_VAL

        self.dataloader_train = torch.utils.data.DataLoader(
                                    train_dataset,
                                    batch_size=self.batch_size,
                                    shuffle=True,
                                    num_workers=4,
                                    collate_fn=collate_fn,
                                    drop_last=True
                                )

        self.dataloader_validation = torch.utils.data.DataLoader(
                                    test_dataset,
                                    batch_size=self.batch_size_val,
                                    shuffle=False,
                                    num_workers=4,
                                    collate_fn=collate_fn,
                                    drop_last=True
                                )

        self.iteration_max = config.ITERATION_MAX
        self.EMA = EMA(0.9999)
        self.LR = config.LR
        if config.LOSS == 'L1':
            self.loss = nn.L1Loss()
        elif config.LOSS == 'L2':
            self.loss = nn.MSELoss()
        else:
            # Fall back to L2 as announced, instead of leaving self.loss unset.
            print('Loss not implemented, setting the loss to L2 (default one)')
            self.loss = nn.MSELoss()
        self.num_timesteps = config.TIMESTEPS
        self.validation_every = config.VALIDATION_EVERY
        self.ema_every = config.EMA_EVERY
        self.start_ema = config.START_EMA
        self.save_model_every = config.SAVE_MODEL_EVERY
        self.ema_model = copy.deepcopy(self.network).to(self.device)
        self.plot_every = config.PLOT_EVERY
        self.optimizer = optim.Adam(self.network.parameters(),lr=self.LR)
        self.iteration = 0

    def save_model(self,name,EMA=False):
        """Save a checkpoint to ``name``.

        The checkpoint holds the current iteration, the optimizer state and the
        weights of the EMA model (``EMA=True``) or of the trained network
        (default). The optimizer state is always the one of the trained network.
        """
        model = self.ema_model if EMA else self.network
        torch.save({"iteration": self.iteration,
                    "model": model.state_dict(),
                    "optimizer": self.optimizer.state_dict()
                   }, name)

    @staticmethod
    def _sampling_schedule(num_steps):
        """Return the per-step coefficient tensors used for validation sampling.

        The schedule uses betas linearly spaced in [1e-6, 0.01]. The result is
        ``(1/sqrt(alpha_t), (1-alpha_t)/sqrt(1-gamma_t), sqrt(1-alpha_t))``, where
        gamma_t is the cumulative product of the alphas.
        """
        # NOTE(review): this schedule is independent of GaussianDiffusion's;
        # confirm the two agree.
        betas = np.linspace(1e-6, 0.01, num_steps)
        alphas = 1. - betas
        gammas = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
        alphas = torch.tensor(alphas, dtype=torch.float32)
        return torch.sqrt(1 / alphas), (1 - alphas) / torch.sqrt(1 - gammas), torch.sqrt(1 - alphas)

    def _sample(self, grey, color, num_steps=1000):
        """Run reverse diffusion for ``num_steps`` steps with both networks.

        Both chains start from the same Gaussian noise and share the per-step
        noise ``z``, but each model denoises its own trajectory.

        :param grey: conditioning batch passed to the models as ``x``.
        :param color: ground-truth batch; only its shape, dtype and batch size are used.
        :return: ``(y, y_ema, y_mean_norms)``. The two samples are CPU tensors.
                 ``y_mean_norms`` holds the mean per-sample L2 norm of ``y``
                 after each step.
        """
        inv_sqrt_alpha, eps_coef, sigma = self._sampling_schedule(num_steps)
        grey_dev = grey.to(self.device)  # loop-invariant transfer

        y = torch.randn_like(color)
        # The EMA chain previously reused the non-EMA state; it now has its own.
        y_ema = y.clone()
        y_mean_norms = []
        for t in reversed(range(num_steps)):
            z = torch.zeros_like(color) if t == 0 else torch.randn_like(color)
            time = torch.full((color.shape[0],), t, dtype=torch.long, device=self.device)
            eps = self.network(y.to(self.device), grey_dev, time).detach().cpu()
            eps_ema = self.ema_model(y_ema.to(self.device), grey_dev, time).detach().cpu()
            y = inv_sqrt_alpha[t] * (y - eps_coef[t] * eps) + sigma[t] * z
            y_ema = inv_sqrt_alpha[t] * (y_ema - eps_coef[t] * eps_ema) + sigma[t] * z
            y_mean_norms.append(torch.norm(y.reshape(y.shape[0], -1), dim=-1, p=2).mean().item())
        return y, y_ema, y_mean_norms

    def _validate(self, wandb_step, num_steps=1000):
        """Sample one validation batch with both networks and log the results to W&B.

        Logs the norm curve, the comparison image grid and the losses against
        the ground truth.
        """
        tq_val = tqdm(self.dataloader_validation)
        with torch.no_grad():
            self.network.eval()
            self.ema_model.eval()
            for grey, color in tq_val:
                tq_val.set_description(f'Iteration {self.iteration} / {self.iteration_max}')
                y, y_ema, y_mean_norms = self._sample(grey, color, num_steps)

                data = [[x_coord, y_coord] for x_coord, y_coord in enumerate(y_mean_norms)]
                table = wandb.Table(data=data, columns = ["x", "y"])
                # Pass the explicit step like every other log call, so W&B steps stay consistent.
                wandb.log({"my_custom_plot_id" : wandb.plot.line(table, "x", "y",
                           title="Custom Y vs X Line Plot")}, step=wandb_step)

                fig, _ = plot_y(grey.cpu().detach(), y.cpu().detach(), y_ema.cpu().detach(), color.cpu().detach())
                wandb.log({'Ys': [wandb.Image(fig2img(fig))]}, step=wandb_step)
                plt.close(fig)  # free the figure; it is never reused

                loss = self.loss(color,y)
                loss_ema = self.loss(color,y_ema)
                tq_val.set_postfix({'loss': loss.item(),'loss ema': loss_ema.item()})

                wandb.log({'y_loss_valid': loss.item()}, step=wandb_step)
                wandb.log({'ema_loss_valid': loss_ema.item()}, step=wandb_step)

                # Validation takes too long, so only one batch is used.
                break

    def train(self):
        """Train until ``self.iteration_max`` optimizer steps have been taken.

        Along the way it updates the EMA weights, plots noise predictions,
        checkpoints both models and runs validation at the configured intervals.
        """
        print('Starting Training')

        wandb_step = 0

        while self.iteration < self.iteration_max:

            print(f"Start of interation no. {self.iteration+1}")

            tq = tqdm(self.dataloader_train)

            for grey, color in tq:
                tq.set_description(f'Iteration {self.iteration} / {self.iteration_max}')
                self.network.train()
                self.optimizer.zero_grad()

                t = torch.randint(0, self.num_timesteps, (self.batch_size,)).long()

                noisy_image,noise_ref = self.diffusion.noisy_image(t,color)
                noise_pred = self.diffusion.noise_prediction(self.network,noisy_image.to(self.device),grey.to(self.device),t.to(self.device))
                loss = self.loss(noise_ref.to(self.device),noise_pred)
                loss.backward()
                self.optimizer.step()
                tq.set_postfix(loss = loss.item())

                wandb.log({'Noise_loss_train' : loss.item()}, step=wandb_step)

                self.iteration+=1

                if self.iteration%self.ema_every == 0 and self.iteration>self.start_ema:
                    print('EMA update')
                    self.EMA.update_model_average(self.ema_model,self.network)

                if self.iteration%self.plot_every == 0:
                    fig, _ = plot_noise(noise_ref.cpu().detach(), noise_pred.cpu().detach())
                    wandb.log({'Noise': [wandb.Image(fig2img(fig))]}, step=wandb_step)
                    plt.close(fig)  # figures were previously leaked on every plot

                if self.iteration%self.save_model_every == 0:
                    print('Saving models')
                    os.makedirs('models', exist_ok=True)
                    self.save_model('models/model_best.pth')
                    self.save_model('models/model_ema_best.pth',EMA=True)

                wandb_step += 1

                if self.iteration%self.validation_every == 0:
                    self._validate(wandb_step)

                # Respect the iteration budget instead of finishing the epoch.
                if self.iteration >= self.iteration_max:
                    break

### Metric calculation

In [None]:
#get model weights
!kaggle kernels output kaggle3223/notebook04a4ac77c4-metrics -p ./

In [None]:
def metrics(config):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    diffusion = GaussianDiffusion(config.IMAGE_SIZE,config.CHANNEL_X,config.CHANNEL_Y,config.TIMESTEPS)
    in_channels = config.CHANNEL_X + config.CHANNEL_Y
    out_channels = config.CHANNEL_Y
    network = UNetModel(
        config.IMAGE_SIZE,
        in_channels,
        config.MODEL_CHANNELS,
        out_channels,
        config.NUM_RESBLOCKS,
        config.ATTENTION_RESOLUTIONS,
        config.DROPOUT,
        config.CHANNEL_MULT,
        config.CONV_RESAMPLE,
        config.USE_CHECKPOINT,
        config.USE_FP16,
        config.NUM_HEADS,
        config.NUM_HEAD_CHANNELS,
        config.NUM_HEAD_UPSAMPLE,
        config.USE_SCALE_SHIFT_NORM,
        config.RESBLOCK_UPDOWN,
        config.USE_NEW_ATTENTION_ORDER,
        ).to(device)


    to_torch = partial(torch.tensor, dtype=torch.float32)
    batch_size_val = 64

    _, test_set_bags = load_shoes_bags(data_path="./shoes_tensor_32.torch")
    _, test_set_shoes = load_shoes_bags(data_path="./handbag_tensor_32.torch")

    test_dataset = ShoeBagDataset(test_set_bags, test_set_shoes)

    dataloader_validation = torch.utils.data.DataLoader(
                                test_dataset,
                                batch_size=batch_size_val,
                                shuffle=False,
                                num_workers=4,
                                collate_fn=collate_fn,
                                drop_last=True
                            )

    LR = config.LR

    num_timesteps = config.TIMESTEPS

    ema_model = copy.deepcopy(network).to(device)


    network.load_state_dict(torch.load("models/model_best.pth")["model"])
    ema_model.load_state_dict(torch.load("models/model_ema_best.pth")["model"])

    tq_val = tqdm(dataloader_validation)

    with torch.no_grad():
        network.eval()
        
        psnrs = []
        psnrs_ema = []
        ssims = []
        ssims_ema = []
        
        i = 0
        for val_step, (grey, color) in enumerate(tq_val):

            T = 1000

            betas = np.linspace(1e-6,0.01,T)
            alphas = 1. - betas
            gammas = to_torch(np.cumprod(alphas,axis=0))
            alphas = to_torch(alphas)
            betas = to_torch(betas)

            y = torch.randn_like(color)
            y_norm = []
            y_norm_ema = []
            for t in reversed(range(T)):
                if t == 0 :
                    z = torch.zeros_like(color)
                else:
                    z = torch.randn_like(color)


                time = (torch.ones((batch_size_val,)) * t).long()
                y = extract(to_torch(np.sqrt(1/alphas)),time,y.shape)*(y-(extract(to_torch((1-alphas)/np.sqrt(1-gammas)),time,y.shape))*network(y.to(device),grey.to(device),time.to(device)).detach().cpu()) + extract(to_torch(np.sqrt(1-alphas)),time,z.shape)*z
                y_ema = extract(to_torch(np.sqrt(1/alphas)),time,y.shape)*(y-(extract(to_torch((1-alphas)/np.sqrt(1-gammas)),time,y.shape))*ema_model(y.to(device),grey.to(device),time.to(device)).detach().cpu()) + extract(to_torch(np.sqrt(1-alphas)),time,z.shape)*z
                
                
            psnr = PSNR(data_range=1.0)
            psnr.attach(default_evaluator, 'psnr')
            state = default_evaluator.run([[y.mul(0.5).add(0.5), color.mul(0.5).add(0.5)]])
            psnrs.append(state.metrics['psnr'])
            print("PSNR: ", psnrs[-1])
            
            state = default_evaluator.run([[y_ema.mul(0.5).add(0.5), color.mul(0.5).add(0.5)]])
            psnrs_ema.append(state.metrics['psnr'])
            print("PSNR ema: ", psnrs_ema[-1])
            
            metric = SSIM(data_range=1.0)
            metric.attach(default_evaluator, 'ssim')
            state = default_evaluator.run([[y.mul(0.5).add(0.5), color.mul(0.5).add(0.5)]])
            ssims.append(state.metrics['ssim'])
            print("SSIM: ", ssims[-1])
            
            state = default_evaluator.run([[y_ema.mul(0.5).add(0.5), color.mul(0.5).add(0.5)]])
            ssims_ema.append(state.metrics['ssim'])
            print("SSIM ema: ", ssims_ema[-1])
            
            
            i += 1
            if i == 15:
                break
                
    print("Total PSNR: ", np.mean(psnrs))
    print("Total PSNR ema: ", np.mean(psnrs_ema))
    print("Total SSIM: ", np.mean(ssims))
    print("Total SSIM ema: ", np.mean(ssims_ema))

### Training/Metrics

In [None]:
import argparse

def train(config):
    """Open a W&B run named ``config.EXP_NAME``, train a ``Trainer`` on ``config``,
    and always close the run afterwards."""
    wandb.init(name=config.EXP_NAME, project='Diffusion_Colorization', config=config)
    try:
        trainer = Trainer(config)
        trainer.train()
    finally:
        # Finish the run even if training fails. A later call in the same
        # notebook kernel then starts a fresh run instead of reusing this one.
        wandb.finish()
    print('training complete')


# Entry point: MODE in conf.yml selects training (1) or metric calculation (2).
config = load_config("./conf.yml")
print('Config loaded')
mode = config.MODE
if mode == 1:
    train(config)
elif mode == 2:
    print("performing metric calculations")
    metrics(config)
else:
    # Previously any value other than 1 silently ran the metrics.
    raise ValueError(f"Unknown MODE {mode!r} in conf.yml: expected 1 (train) or 2 (metrics)")