In [None]:
import torch
import torch.nn as nn
import torchvision

import math
import random
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from scipy.ndimage.interpolation import rotate
import numpy as np

from IPython.display import HTML
from IPython.display import clear_output
import torch.nn as nn
import torch.nn.functional as F
import torch
from functools import partial
from torch.utils.data import DataLoader, TensorDataset

In [None]:
def set_random_seed(seed):
    """
    Seed every random number generator used in this notebook and make cuDNN deterministic.

    seed : integer seed applied to Python's `random`, NumPy's global RNG and PyTorch (CPU and, when
           available, the current and all other GPUs)
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Current device first, then every visible GPU
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Example of setting the seed
# set_random_seed(42)

In [None]:
# working
# Basic Convolutional Layer
def conv3x3(in_planes, out_planes, stride=1, dilation=1, bias=False):
    """3x3 convolution whose padding equals its dilation, so stride-1 calls keep the spatial size."""
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=bias,
    )

# Conditional Batch Normalization
class ConditionalBatchNorm2d(nn.Module):
    '''
    Batch normalization whose affine parameters are looked up per class.

    num_features : channels of the normalized input
    num_classes  : number of rows of the class-embedding table
    bias         : if True, each class gets a scale (gamma) and a shift (beta); otherwise only a scale

    gamma is initialized uniformly in [0, 1) and beta to zero.
    '''
    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm2d(num_features, affine=False)
        embed_dim = 2 * num_features if bias else num_features
        self.embed = nn.Embedding(num_classes, embed_dim)
        table = self.embed.weight.data
        table[:, :num_features].uniform_()
        if bias:
            table[:, num_features:].zero_()

    def forward(self, x, y):
        '''x: feature map normalized by `self.bn`; y: class indices selecting gamma (and beta).'''
        n = self.num_features
        normalized = self.bn(x)
        params = self.embed(y)
        # The table width tells whether a shift was allocated in __init__
        if self.embed.weight.size(1) == 2 * n:
            gamma, beta = params.chunk(2, dim=1)
        else:
            gamma, beta = params, None
        scaled = gamma.view(-1, n, 1, 1) * normalized
        if beta is None:
            return scaled + 0
        return scaled + beta.view(-1, n, 1, 1)

# Simple Residual Block
class ResidualBlock(nn.Module):
    '''
    Two conditioned 3x3 convolutions with a residual connection:
    act(norm2(conv2(act(norm1(conv1(x), y))), y) + shortcut(x)).

    in_channels  : input channels
    out_channels : output channels; a bias-free 1x1 convolution projects the shortcut when they differ
    num_classes  : number of conditioning classes passed to `norm_layer`
    norm_layer   : factory called as norm_layer(channels, num_classes); the result is called as norm(x, y)
    dilation     : dilation (and padding) of both bias-free 3x3 convolutions, so the spatial size is preserved
    '''
    def __init__(self, in_channels, out_channels, num_classes, norm_layer, dilation=1):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=dilation, dilation=dilation, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.norm1 = norm_layer(out_channels, num_classes)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        self.norm2 = norm_layer(out_channels, num_classes)
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False)
        self.act = nn.ELU()

    def forward(self, x, y):
        hidden = self.act(self.norm1(self.conv1(x), y))
        hidden = self.norm2(self.conv2(hidden), y)
        return self.act(hidden + self.shortcut(x))

# Simplified Refinement Network
class CondRefineNetDilated(nn.Module):
    def __init__(self, device, num_classes, in_channels=4, ngf=64):
        '''
        Class-conditional dilated residual network.

        Layout: begin conv -> three conditional residual blocks (dilation 1, 2, 4) -> ELU -> end conv
        back to `in_channels`. Every convolution pads so the spatial size is preserved.

        device      : device the module is moved to
        num_classes : number of conditioning classes of the conditional batch norms
        in_channels : channels of the input and of the output (default 4, the original hard-coded value)
        ngf         : base feature width; the last two residual blocks use 2 * ngf (default 64)
        '''
        super().__init__()
        self.device = device
        self.norm_layer = ConditionalBatchNorm2d
        self.num_classes = num_classes
        self.ngf = ngf

        self.begin_conv = nn.Conv2d(in_channels, self.ngf, kernel_size=3, padding=1)
        self.end_conv = nn.Conv2d(self.ngf * 2, in_channels, kernel_size=3, stride=1, padding=1)
        self.act = nn.ELU()

        self.res1 = ResidualBlock(self.ngf, self.ngf, num_classes, self.norm_layer)
        self.res2 = ResidualBlock(self.ngf, 2 * self.ngf, num_classes, self.norm_layer, dilation=2)
        self.res3 = ResidualBlock(2 * self.ngf, 2 * self.ngf, num_classes, self.norm_layer, dilation=4)

        self.to(device)

    def forward(self, x, y):
        '''x: input feature map with `in_channels` channels; y: class indices for the conditional norms.'''
        out = self.act(self.begin_conv(x))
        out = self.res1(out, y)
        out = self.res2(out, y)
        out = self.res3(out, y)
        out = self.act(out)
        out = self.end_conv(out)
        return out

In [None]:
class Model(nn.Module):
    def __init__(self, device, n_steps, sigma_min, sigma_max):
        '''
        Noise-conditional score network.

        device    : device the sigma schedule and the network are placed on
        n_steps   : number of noise levels in the perturbation schedule (Langevin dynamic steps);
                    also the number of conditioning classes of the network
        sigma_min : smallest sigma of the perturbation schedule
        sigma_max : largest sigma of the perturbation schedule

        The schedule is geometric, from sigma_max (index 0) down to sigma_min (index n_steps - 1).
        '''
        super().__init__()
        self.device = device
        self.sigmas = torch.exp(torch.linspace(start=math.log(sigma_max), end=math.log(sigma_min), steps = n_steps)).to(device = device)
        # The noise-level index is the class label of the conditional network,
        # so it must be built with exactly n_steps classes.
        self.conv_layer = CondRefineNetDilated(device, n_steps)
        self.to(device = device)

    # Loss Function
    def loss_fn(self, x, idx=None):
        '''
        Denoising score-matching loss (training phase).

        x   : clean data batch
        idx : noise-level index. If None, a random level is drawn for every sample;
              otherwise every sample is perturbed at sigmas[idx].

        Returns the batch mean of sigma^2 * mean((score - target)^2) over the non-batch dimensions.
        '''
        scores, target, sigma = self.forward(x, idx=idx, get_target=True)
        target = target.view(target.shape[0], -1)
        scores = scores.view(scores.shape[0], -1)
        # sigma^2 weighting; view(-1) keeps a (batch,) shape even for a batch of one
        losses = torch.square(scores - target).mean(dim=-1) * sigma.view(-1) ** 2
        return losses.mean(dim=0)

    # S(theta, sigma)
    def forward(self, x, idx=None, get_target=False):
        '''
        Predict the score s_theta(x_tilde, sigma).

        x          : clean data when idx is None or get_target is True; otherwise already-perturbed data
        idx        : noise-level index (int). If None, a random level is drawn for every sample and
                     x is perturbed at that level.
        get_target : if True, x is perturbed at the chosen level(s) and (output, target, sigma) is
                     returned, with target = -noise / sigma (the score of the Gaussian perturbation
                     x_tilde = x + sigma * noise) and sigma shaped (batch, 1, ..., 1).
        '''
        batch_size = x.size(0)
        random_level = idx is None
        if random_level:
            idx = torch.randint(0, len(self.sigmas), (batch_size,)).to(device = self.device)
        else:
            idx = torch.full((batch_size,), int(idx), dtype=torch.long, device=self.device)

        if random_level or get_target:
            # Broadcast one sigma per sample over all non-batch dimensions
            used_sigmas = self.sigmas[idx].view(-1, *([1] * (x.dim() - 1)))
            noise = torch.randn_like(x)
            x_tilde = x + noise * used_sigmas
        else:
            x_tilde = x

        # idx is always 1-D (batch,), even for a batch of one, as the conditional norms expect
        output = self.conv_layer(x_tilde, idx)

        if get_target:
            target = - 1 / used_sigmas * noise
            return output, target, used_sigmas
        return output

In [None]:
class AnnealedLangevinDynamic():
    def __init__(self, sigma_min, sigma_max, n_steps, annealed_step, score_fn, device, eps = 1e-1):
        '''
        Annealed Langevin dynamics sampler.

        sigma_min     : minimum sigma of the perturbation schedule
        sigma_max     : maximum sigma of the perturbation schedule
        n_steps       : number of noise levels; must match the score network's schedule
        annealed_step : Langevin iterations run at each noise level
        score_fn      : trained score network (nn.Module), called as score_fn(x, idx)
        device        : device the initial samples are moved to
        eps           : base step size; level i uses eps * (sigma_i / sigma_L)^2, where sigma_L is the
                        last (smallest) sigma of the schedule
        '''
        self.process = torch.exp(torch.linspace(start=math.log(sigma_max), end=math.log(sigma_min), steps = n_steps))
        self.step_size = eps * (self.process / self.process[-1] ) ** 2
        self.score_fn = score_fn
        self.annealed_step = annealed_step
        self.device = device

    # One iteration of annealed step
    def _one_annealed_step_iteration(self, x, idx):
        '''
        One Langevin update at noise level `idx`:
        x <- x + 0.5 * alpha_idx * score_fn(x, idx) + sqrt(alpha_idx) * z,  z ~ N(0, I)

        x   : current samples
        idx : index into the perturbation schedule
        '''
        step_size = self.step_size[idx]
        # randn_like already allocates z on x's device
        z = torch.randn_like(x)
        return x + 0.5 * step_size * self.score_fn(x, idx) + torch.sqrt(step_size) * z

    # One annealed step
    def _one_annealed_step(self, x, idx):
        '''
        Run `annealed_step` Langevin updates at noise level `idx`.

        x   : current samples
        idx : index into the perturbation schedule
        '''
        for _ in range(self.annealed_step):
            x = self._one_annealed_step_iteration(x, idx)
        return x

    # Sweep over all noise levels
    def _one_diffusion_step(self, x):
        '''
        Generator: anneal x through every noise level (largest sigma first), yielding the samples
        after each level.

        x : samples drawn from the prior
        '''
        for idx in range(len(self.process)):
            x = self._one_annealed_step(x, idx)
            yield x

    @torch.no_grad()
    def sampling(self, sampling_number, only_final=False, shape=(4, 64, 64)):
        '''
        Draw samples by running annealed Langevin dynamics from uniform [0, 1) noise.

        sampling_number : number of samples
        only_final      : if True, return only the samples after the last noise level,
                          shaped (sampling_number, *shape); otherwise stack the samples after every
                          level, shaped (n_steps, sampling_number, *shape)
        shape           : per-sample shape; must match what score_fn expects (default (4, 64, 64))
        '''
        x = torch.rand([sampling_number, *shape]).to(device = self.device)

        # Evaluate in eval mode, then restore the caller's mode so sampling in the middle of
        # training does not silently leave the network in eval mode.
        was_training = self.score_fn.training
        self.score_fn.eval()
        sampling_list = []
        final = None
        try:
            for final in self._one_diffusion_step(x):
                if not only_final:
                    sampling_list.append(final)
        finally:
            self.score_fn.train(was_training)

        return final if only_final else torch.stack(sampling_list)

In [None]:
class AverageMeter(object):
    '''
    Tracks the latest value and the count-weighted running mean of a scalar metric.

    name : label printed by __str__
    fmt  : format spec (e.g. ':.4f') applied to both the latest value and the mean
    '''
    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        '''Forget every recorded value.'''
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        '''Record `val` as the mean of `n` observations.'''
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**self.__dict__)


class ProgressMeter(object):
    '''
    Prints a single progress line, overwritten in place via '\\r':
    "<prefix>[  i/N]<TAB><meter><TAB><meter>..."

    num_batches : total count N; its digit count sets the width of the counter
    meters      : objects whose str() is appended to the line
    prefix      : text placed before the counter
    '''
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        head = self.prefix + self.batch_fmtstr.format(batch)
        line = '\t'.join([head, *map(str, self.meters)])
        print('\r' + line, end = '')

    def _get_batch_fmtstr(self, num_batches):
        width = len(str(num_batches // 1))
        counter = '{:%dd}' % width
        return '[%s/%s]' % (counter, counter.format(num_batches))

In [None]:
# Base step size (epsilon) of annealed Langevin dynamics
eps = 1.5e-5

# Bounds of the geometric noise (sigma) schedule
sigma_min, sigma_max = 0.005, 10

# Number of noise levels, and Langevin iterations run at each level
n_steps, annealed_step = 20, 50

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Score network, its optimizer, and the annealed Langevin sampler built on top of it
model = Model(device=device, n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max)
optim = torch.optim.Adam(params=model.parameters(), lr=5e-3)
dynamic = AnnealedLangevinDynamic(
    sigma_min=sigma_min,
    sigma_max=sigma_max,
    n_steps=n_steps,
    annealed_step=annealed_step,
    score_fn=model,
    device=device,
    eps=eps,
)

In [None]:
# Raw training array. allow_pickle=False (NumPy's default) refuses pickled object arrays,
# so loading this file cannot execute arbitrary code.
data = np.load("data.npy", allow_pickle=False)

In [None]:
# Scale the data into [0, 1] by dividing by its maximum.
# The minimum is pinned at 0 rather than data.min(); this assumes the data is non-negative
# (negative values would map below 0).
data_min = 0 # data.min()
data_max = data.max()
if data_max == data_min:
    raise ValueError("data has zero range (max == min); cannot normalize")

# Cast to float32 so training batches match the float32 network weights: integer or float64
# input would otherwise produce float64 batches that the convolutions reject.
data = ((data - data_min) / (data_max - data_min)).astype(np.float32)

In [None]:
# Split the data 80/20 into training and validation sets
train_size = int(0.8 * len(data))
val_size = len(data) - train_size
train_data, val_data = torch.utils.data.random_split(dataset=data, lengths=[train_size, val_size])

# Loaders: only the training set is shuffled; validation order stays fixed
train_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=32, shuffle=False)

# Persistent iterator consumed by the training loops (they re-create it when exhausted)
dataiterator = iter(train_loader)

In [None]:
# Training schedule: total optimizer steps, start counter, and preview interval
total_iteration, current_iteration, display_iteration = 20000, 0, 1000
sampling_number = 16
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
only_final = True

# Running-loss tracker and the in-place progress line that prints it
losses = AverageMeter('Loss', ':.4f')
progress = ProgressMeter(total_iteration, [losses], prefix='Iteration ')

In [None]:
# Main training loop: one optimizer step per mini-batch. Every `display_iteration` steps,
# draw a 3x3 grid of samples with annealed Langevin dynamics.
while current_iteration < total_iteration:
    model.train()
    try:
        batch = next(dataiterator)
    except StopIteration:
        # Training loader exhausted: start a new (reshuffled) pass over it.
        # Catching only StopIteration keeps KeyboardInterrupt and real loader errors visible.
        dataiterator = iter(train_loader)
        batch = next(dataiterator)
    # `batch` (not `data`) so the full dataset array defined earlier is not overwritten
    batch = batch.to(device = device)
    loss = model.loss_fn(batch)

    optim.zero_grad()
    loss.backward()
    optim.step()

    losses.update(loss.item())
    progress.display(current_iteration)
    current_iteration += 1

    if current_iteration % display_iteration == 0:
        sampling_number = 9  # one sample per cell of the 3x3 grid
        only_final = True
        dynamic = AnnealedLangevinDynamic(sigma_min, sigma_max, n_steps, annealed_step, model, device, eps=eps)
        sample = dynamic.sampling(sampling_number, only_final)
        images = sample.cpu().numpy()  # convert once, not once per panel

        fig, ax = plt.subplots(3, 3, figsize=(8, 8))
        for i in range(3):
            for j in range(3):
                ax[i, j].imshow(images[i * 3 + j][0, :, :], origin="lower")
                ax[i, j].axis("off")
        plt.show()
        plt.close(fig)

In [None]:
# Hyper-parameter sweep: train a fresh model from scratch for every (eps, sigma_min, sigma_max)
# combination and show a 3x3 grid of samples every `display_iteration` iterations.
# NOTE(review): eps is only used by the sampler (AnnealedLangevinDynamic), never by the training
# loss, so runs that differ only in eps train equivalent models. Training once per
# (sigma_min, sigma_max) and sampling with each eps would be far cheaper.
eps_list = [7.5e-4, 5e-4, 2.5e-4, 1e-4, 7.5e-5]
sigma_min_list = [0.05]
sigma_max_list = [1]

# Settings shared by every run of the sweep
n_steps = 20          # number of noise levels
annealed_step = 50    # Langevin iterations per noise level
total_iteration = 5000
display_iteration = 500
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

for eps in eps_list:
    for sigma_min in sigma_min_list:
        for sigma_max in sigma_max_list:
            print(f"Testing with eps={eps}, sigma_min={sigma_min}, sigma_max={sigma_max}")

            model = Model(device, n_steps, sigma_min, sigma_max)
            optim = torch.optim.Adam(model.parameters(), lr = 0.005)

            current_iteration = 0
            losses = AverageMeter('Loss', ':.4f')
            progress = ProgressMeter(total_iteration, [losses], prefix='Iteration ')

            while current_iteration < total_iteration:
                model.train()
                try:
                    batch = next(dataiterator)
                except StopIteration:
                    # Training loader exhausted: start a new (reshuffled) pass over it
                    dataiterator = iter(train_loader)
                    batch = next(dataiterator)
                # `batch` (not `data`) so the full dataset array is not overwritten
                batch = batch.to(device = device)
                loss = model.loss_fn(batch)

                optim.zero_grad()
                loss.backward()
                optim.step()

                losses.update(loss.item())
                progress.display(current_iteration)
                current_iteration += 1

                if current_iteration % display_iteration == 0:
                    sampling_number = 9  # one sample per cell of the 3x3 grid
                    only_final = True
                    dynamic = AnnealedLangevinDynamic(sigma_min, sigma_max, n_steps, annealed_step, model, device, eps=eps)
                    sample = dynamic.sampling(sampling_number, only_final)
                    images = sample.cpu().numpy()  # convert once, not once per panel

                    fig, ax = plt.subplots(3, 3, figsize=(8, 8))
                    for i in range(3):
                        for j in range(3):
                            ax[i, j].imshow(images[i * 3 + j][0, :, :], origin="lower")
                            ax[i, j].axis("off")
                    plt.show()
                    plt.close(fig)

In [None]:
# Full grid sweep: train a fresh model from scratch for every (eps, sigma_min, sigma_max)
# combination (4 x 4 x 4 = 64 runs) and show a 3x3 grid of samples every `display_iteration` iterations.
# NOTE(review): eps is only used by the sampler (AnnealedLangevinDynamic), never by the training
# loss, so runs that differ only in eps train equivalent models. Training once per
# (sigma_min, sigma_max) and sampling with each eps would cut the cost by 4x.
eps_list = [1.5e-4, 5e-5, 1.5e-5, 1.5e-6]
sigma_min_list = [0.05, 0.01, 0.005, 0.0005]
sigma_max_list = [50, 25, 10, 1]

# Settings shared by every run of the sweep
n_steps = 20          # number of noise levels
annealed_step = 50    # Langevin iterations per noise level
total_iteration = 20000
display_iteration = 1000
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

for eps in eps_list:
    for sigma_min in sigma_min_list:
        for sigma_max in sigma_max_list:
            print(f"Testing with eps={eps}, sigma_min={sigma_min}, sigma_max={sigma_max}")

            model = Model(device, n_steps, sigma_min, sigma_max)
            optim = torch.optim.Adam(model.parameters(), lr = 0.005)

            current_iteration = 0
            losses = AverageMeter('Loss', ':.4f')
            progress = ProgressMeter(total_iteration, [losses], prefix='Iteration ')

            while current_iteration < total_iteration:
                model.train()
                try:
                    batch = next(dataiterator)
                except StopIteration:
                    # Training loader exhausted: start a new (reshuffled) pass over it
                    dataiterator = iter(train_loader)
                    batch = next(dataiterator)
                # `batch` (not `data`) so the full dataset array is not overwritten
                batch = batch.to(device = device)
                loss = model.loss_fn(batch)

                optim.zero_grad()
                loss.backward()
                optim.step()

                losses.update(loss.item())
                progress.display(current_iteration)
                current_iteration += 1

                if current_iteration % display_iteration == 0:
                    sampling_number = 9  # one sample per cell of the 3x3 grid
                    only_final = True
                    dynamic = AnnealedLangevinDynamic(sigma_min, sigma_max, n_steps, annealed_step, model, device, eps=eps)
                    sample = dynamic.sampling(sampling_number, only_final)
                    images = sample.cpu().numpy()  # convert once, not once per panel

                    fig, ax = plt.subplots(3, 3, figsize=(8, 8))
                    for i in range(3):
                        for j in range(3):
                            ax[i, j].imshow(images[i * 3 + j][0, :, :], origin="lower")
                            ax[i, j].axis("off")
                    plt.show()
                    plt.close(fig)

In [None]:
# Final samples from the most recently trained `model`, shown as a 2x2 grid.
sampling_number = 4
only_final = True
# The sampler must use the model's own noise schedule. Noise-level indices are the class labels of
# the model's conditional batch norms, so running more levels than the model was built with
# (e.g. 25 for a 20-level model) indexes past its embedding table and fails. The sigma bounds
# are read from the model for the same reason.
n_steps = len(model.sigmas)
annealed_step = 200  # more Langevin iterations per level than the 50 used for training-time previews
dynamic = AnnealedLangevinDynamic(model.sigmas[-1].item(), model.sigmas[0].item(), n_steps, annealed_step, model, device, eps=eps)
sample = dynamic.sampling(sampling_number, only_final)
images = sample.cpu().numpy()  # convert once, not once per panel

fig, ax = plt.subplots(2, 2, figsize=(8, 8))
for i in range(2):
    for j in range(2):
        ax[i, j].imshow(images[i * 2 + j][0, :, :], origin="lower")
        ax[i, j].axis("off")
plt.show()