In [1]:
import os
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

In [2]:
#!pip install -q -U einops datasets matplotlib tqdm

import math
from inspect import isfunction
from functools import partial

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F
from pathlib import Path
from torch.optim import AdamW
from PIL import Image
import requests
from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize
import numpy as np
from torch.utils.data import DataLoader
import time
import pandas as pd

import dnnlib
#import tor# Removed num_to_groups
from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))
from torchvision.utils import save_image

from utils.ema_pytorch import EMA
from torch.cuda.amp import autocast, GradScaler
from cleanfid import fid

#from utils.losses_samples import *
from utils.blocks import *
from utils.elucidating import *

from collections import namedtuple
from einops import reduce


import cv2
import gc
gc.collect()
torch.cuda.empty_cache()

%matplotlib inline

"random seed initialization to guarantee reproducability"
random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [3]:
device='cuda'

### Model

In [4]:
class Unet(nn.Module):
    "Unet Module based on the implementation of https://github.com/lucidrains/denoising-diffusion-pytorch"
    "Extendedy by Skip-SE layers, additional conditional embedding for augmentation based on https://arxiv.org/abs/2206.00364"
    "and patch-based diffusion https://arxiv.org/abs/2207.04316. However, we used 1 patch, as higher number of patches increased training time."
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        use_convnext=True,
        convnext_mult=3,
        num_patches=1
    ):
        super().__init__()

        # determine dimensions, also for patch-based diffusion
        self.channels = channels
        self.patch_size = num_patches
        
        self.init_dim = channels * num_patches**2
        init_dim2 = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(self.init_dim, init_dim2, 7, padding=3)

        dims = [init_dim2, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        print("init_dim:\t", init_dim)
        print("dims:\t\t", dims)
        print("in_out:\t\t", in_out)
        
        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time and augmentation embeddings embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
            augm_dim = dim * 4
            self.aug_mlp = nn.Sequential(
                nn.Linear(12, dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers with additional SE-layers starting from the second block
        # SE-layers are used upwards and downwards
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            
            if ind < 1:
                self.downs.append(
                    nn.ModuleList(
                        [
                            block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                            block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                            Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                            Downsample(dim_out) if not is_last else nn.Identity(),
                            None
                        ]
                    )
                )
            else:
                self.downs.append(
                    nn.ModuleList(
                        [
                            block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                            block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                            Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                            Downsample(dim_out) if not is_last else nn.Identity(),
                            SEBlock(in_out[ind-1][1], dim_out)
                        ]
                    )
                )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)
            
            if ind < 1:
                self.ups.append(
                    nn.ModuleList(
                        [
                            block_klass(dim_out * 2, dim_out, time_emb_dim=time_dim),
                            block_klass(dim_out, dim_in, time_emb_dim=time_dim),
                            Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                            Upsample(dim_in) if not is_last else nn.Identity(),
                            None
                        ]
                    )
                )
            else:
                self.ups.append(
                    nn.ModuleList(
                        [
                            block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                            block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                            Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                            Upsample(dim_in) if not is_last else nn.Identity(),
                            SEBlock(in_out[len(in_out)-ind][1], dim_in)
                        ]
                    )
                )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim, dim), nn.Conv2d(dim, self.init_dim, 1)
        )
        
    def convert_image_to_patches(self, x):
        "patch-based dfiffusion function at the beginning of the forward function"
        p = self.patch_size
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1) #BHWC format, bc reshape is done on last 2 axes
        x = x.reshape(B, H, W//p, C*p) #reshape from width axis to channel axis
        x = x.permute(0, 2, 1, 3) #now height & channel should be last 2 axes
        x = x.reshape(B, W//p, H//p, C*p*p) #reshape from height axis to channel axis
        return x.permute(0, 3, 2, 1) #convert to channels-first format
    
    def convert_patches_to_image(self, x):
        "patch-based difffusion fuknction at the end of the forward function"
        p = self.patch_size
        B, C, H, W = x.shape
        x = x.permute(0,3,2,1) #BWHC; from_patches starts w/ height axis, not width
        x.reshape(B, W, H*p, C//p) #reshape from channel axis to height axis
        x = x.permute(0,2,1,3) #now width & channel should be last 2 axes
        x = x.reshape(B, H*p, W*p, C//(p*p)) #reshape from channel axis to width axis
        return x.permute(0, 3, 1, 2) #convert to channels-first format


    def forward(self, x, time, augm):
        x = self.convert_image_to_patches(x)
        x = self.init_conv(x)
        t = self.time_mlp(time) if exists(self.time_mlp) else None
        aug = self.aug_mlp(augm)
        h = []
        se_up = []
        se_down = []

        # downsample
        for block1, block2, attn, downsample, se_layer in self.downs:
            x = block1(x, t, aug)
            x = block2(x, t, aug)
            x = attn(x)
            h.append(x)
            se_down.append(x) 
            x = downsample(x)
            if se_layer is not None:
                x = se_layer(se_down.pop(0), x)
            

        # bottleneck
        x = self.mid_block1(x, t, aug)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t, aug)
        se_up.append(x)
        
        #upsample
        for block1, block2, attn, upsample, se_layer in self.ups:
            
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t, aug)
            x = block2(x, t, aug)
            x = attn(x)
            se_up.append(x)
            x = upsample(x)
            if se_layer is not None:
                x = se_layer(se_up.pop(0), x)

        x = self.final_conv(x)
        x = self.convert_patches_to_image(x)
        return x


### Sample and loss functions

In [5]:
"The following code snippet describes the sampling functions of DDPM, implementation based on https://github.com/lucidrains/denoising-diffusion-pytorch"

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

# Forward diffusion process, based on the property of the alphas (We do not need to calculate each forward step as for backward diffusion (inference),
# but we can utilized the mathematical property that directly allows ous to apply noise of any timestep (noise schedule step) on the image
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t.long(), x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise


def get_noisy_image(x_start, t):
    # add noise
    x_noisy = q_sample(x_start, t=t)

    # turn back into PIL image
    noisy_image = reverse_transform(x_noisy.squeeze())
    return noisy_image



#SAMPLING_IMGAES
#----------------------------------

@torch.no_grad()
def p_sample(model, x, t, t_index, batch_size):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    # Equation 11 in the paper
    # Use our model (noise predictor) to predict the mean
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t, torch.zeros(batch_size, 12).to(device)) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # Algorithm 2 line 4:
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

##New
def p_mean_variance(x, t, x_self_cond = None, clip_denoised = True):
        preds = model_predictions(x, t, x_self_cond)
        x_start = preds.pred_x_start

        if clip_denoised:
            x_start.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

@torch.no_grad()
def p_sample_new(model, x, t: int, x_self_cond = None, augm=None):
    b, *_, device = *x.shape, x.device
    batched_times = torch.full((b,), t, device = x.device, dtype = torch.long)
    model_mean, _, model_log_variance, x_start = p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)
    noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
    pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
    return pred_img, x_start

# Algorithm 2 (including returning all images)
@torch.no_grad()
def p_sample_loop(model, shape, batch_size):
    device = next(model.parameters()).device

    b = shape[0]
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)
    imgs = []

    for i in reversed(range(0, timesteps)):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i, batch_size)
        #imgs.append(img.cpu().numpy())
        imgs.append(img)
    return imgs

@torch.no_grad()
def p_sample_loop_new(model, shape, return_all_timesteps = False, device='cuda', augm=None, self_condition=False, num_timesteps=1000):
    batch, device = shape[0], device

    img = torch.randn(shape, device = device)
    imgs = [img]

    x_start = None

    for t in tqdm(reversed(range(0, num_timesteps)), desc = 'sampling loop time step', total = num_timesteps):
        self_cond = x_start if self_condition else None
        img, x_start = p_sample_new(model, img, t, augm=augm)
        imgs.append(img)

    ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)

    #ret = self.unnormalize(ret)
    return ret

@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size), batch_size=batch_size)

def p_losses_old(denoise_model, x_start, t, noise=None, loss_type="l1", augm=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t, augm)

    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss

def p_losses(model, x_start, t, noise = None, augm=None):
        b, c, h, w = x_start.shape

        noise = default(noise, lambda: torch.randn_like(x_start))

        # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise


        if offset_noise_strength > 0.:
            offset_noise = torch.randn(x_start.shape[:2], device = device)
            noise += offset_noise_strength * rearrange(offset_noise, 'b c -> b c 1 1')

        # noise sample

        x = q_sample(x_start = x_start, t = t, noise = noise)

        # if doing self-conditioning, 50% of the time, predict x_start from current set of times
        # and condition with unet with that
        # this technique will slow down training by 25%, but seems to lower FID significantly

        x_self_cond = None
        if self_condition and random() < 0.5:
            with torch.no_grad():
                x_self_cond = self.model_predictions(x, t).pred_x_start
                x_self_cond.detach_()

        # predict and take gradient step

        model_out = model(x.half(), t.half(), augm.half())

        if objective == 'pred_noise':
            target = noise
        elif objective == 'pred_x0':
            target = x_start
        elif objective == 'pred_v':
            v = predict_v(x_start, t, noise)
            target = v
        else:
            raise ValueError(f'unknown objective {objective}')

        loss = F.mse_loss(model_out, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b (...)', 'mean')

        loss = loss * extract(loss_weight, t, loss.shape)
        return loss.mean()

In [6]:


ModelPrediction =  namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start'])


def predict_start_from_noise(x_t, t, noise):
    return (
        extract(sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
        extract(sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
    )

def predict_noise_from_start(x_t, t, x0):
    return (
        (extract(sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
        extract(sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    )

def predict_v(self, x_start, t, noise):
    return (
        extract(sqrt_alphas_cumprod, t, x_start.shape) * noise -
        extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
    )

def predict_start_from_v(self, x_t, t, v):
    return (
        extract(sqrt_alphas_cumprod, t, x_t.shape) * x_t -
        extract(sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
    )


def model_predictions(model, x, t, x_self_cond=None, clip_x_start=False, rederive_pred_noise=False, objective='pred_noise'):
    model_output = model(x.float(), t.float(), torch.zeros(x.shape[0], 12).to(device).float())
    maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity

    if objective == 'pred_noise':
        pred_noise = model_output
        x_start = predict_start_from_noise(x, t, pred_noise)
        x_start = maybe_clip(x_start)


        if clip_x_start and rederive_pred_noise:
            pred_noise = predict_noise_from_start(x, t, x_start)

    elif objective == 'pred_x0':
        x_start = model_output
        x_start = maybe_clip(x_start)
        pred_noise = predict_noise_from_start(x, t, x_start)

    elif objective == 'pred_v':
        v = model_output
        x_start = predict_start_from_v(x, t, v)
        x_start = maybe_clip(x_start)
        pred_noise = predict_noise_from_start(x, t, x_start)
        
    return ModelPrediction(pred_noise, x_start)

        

def ddim_sample(model, shape, device, total_timesteps, sampling_timesteps, ddim_sampling_eta, return_all_timesteps = False, objective='pred_noise'):
        "Functio for faster sampling based on diffusion denoising implicit models (DDIM), https://arxiv.org/abs/2010.02502"
        model.eval()
        
        batch = shape[0]

        times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)   # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        img = torch.randn(shape, device = device)
        imgs = [img]

        x_start = None

        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            #self_cond = x_start if self.self_condition else None
            pred_noise, x_start, *_ = model_predictions(model, img, time_cond, x_self_cond=None, clip_x_start = True, rederive_pred_noise = True, objective=objective)

            if time_next < 0:
                img = x_start
                imgs.append(img)
                continue

            alpha = alphas_cumprod[time]
            alpha_next = alphas_cumprod[time_next]

            sigma = ddim_sampling_eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()

            noise = torch.randn_like(img)

            img = x_start * alpha_next.sqrt() + \
                  c * pred_noise + \
                  sigma * noise

            imgs.append(img)

        ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)

        #ret = unnormalize(ret)
        return ret

### Training

In [7]:


def run(dataset="pokemon", batch_size=12, image_size=64, ts=1000, noise_schedule="sigmoid", dim=32, channels=3, dim_mults=(1,2,4,8), aug_factor=0.12, device="cuda", sample_iteration=50000, save_iteration=50000, num_samples=20,
       optim="Adam", lr=1e-4, loss_type="huber", momentum=0.95, max_iter=15000, checkpoint=None, num_patches=1):
    '''This function specifies the whole run, including the batch size, the dataset, the size of the model, the augmentation factor and so on. 
    It defines the alphas for the sampling functions based on the chosen number of timesteps.
    Loads the data and creates a result folder, if it does not exist already.'''
    
    
    #Load the data and generate the directory path
    dataloader = get_data(r"./train_data/"+dataset+"/", batch_size, image_size)
    #directory = os.fsencode(r"../First_Try/"+dataset+"/img/")
    
    #Each training run gets a unique identifier based on the chosen properties of the run
    identifier = dataset+str(image_size)+"_"+noise_schedule+"_"+str(batch_size)+"_"+str(dim)+"_"+str(dim_mults)+"_"+str(ts)+"_"+str(aug_factor)+"_"+str(loss_type)+str(lr)+"_"+str(num_patches)
    
    #
    results_top_folder = Path("./model_weights/"+dataset+"/")
    results_top_folder.mkdir(exist_ok = True)
    
    #Create folder for this run
    results_folder = Path("./model_weights/"+dataset+"/"+identifier+"/")
    results_folder.mkdir(exist_ok = True)
    
    
    #create the model
    model = Unet(
        dim=dim,
        channels=channels,
        dim_mults=dim_mults,
        num_patches=num_patches
    )
    model.to(device)
    
    num_params = sum([p.numel() for p in model.parameters()])   
    
    
    #Defining the learning rate schedule
    if optim == "Adam":
        optimizer = AdamW(model.parameters(), lr=lr, betas=(0.9,0.99))
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    #Defining the augmentation pipeline using the chosen augmentation factor
    eluAug = AugmentPipe(aug_factor, xflip=1, yflip=1, rotate_int=1, translate_int=1, scale=1, rotate_frac=1, aniso=1, saturation=1)
    
    global timesteps
    global betas
    global alphas
    global alphas_cumprod
    global alphas_cumprod_prev
    global sqrt_recip_alphas
    global sqrt_recip_alphas_cumprod
    global sqrt_recipm1_alphas_cumprod
    global sqrt_alphas_cumprod
    global sqrt_one_minus_alphas_cumprod
    global posterior_variance
    
    #new
    global snr
    global maybe_clipped_snr
    global loss_weight
    global offset_noise_strength
    global self_condition
    global objective

    #Chosing the noise schedule variant
    #Even though it is reportat that cosine outperforms the linear schedule, we could not observe performance gains
    #We sticked to the linear schedule
    if noise_schedule == "sigmoid":
        betas = sigmoid_beta_schedule(timesteps=ts)
    elif noise_schedule == "linear":
        betas = linear_beta_schedule(timesteps=ts)
    elif noise_schedule == "cosine":
        betas = cosine_beta_schedule(timesteps=ts)
    elif noise_schedule == "squared":
        betas = sigmoid_beta_schedule(timesteps=ts)
    else:
        print("No valid noise schedule picked")
        
    timesteps = ts
    
    #calc_schedules(betas, ts, device)
    # define alphas 
    alphas = 1. - betas
    alphas_cumprod = torch.cumprod(alphas, axis=0)
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
    sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
    sqrt_recip_alphas_cumprod = torch.sqrt(1.0/alphas_cumprod)

    # calculations for diffusion q(x_t | x_{t-1}) and others
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
    sqrt_recipm1_alphas_cumprod =  torch.sqrt(1. / alphas_cumprod - 1)

    # calculations for posterior q(x_{t-1} | x_t, x_0)
    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
    
    #snr_weighting
    #https://arxiv.org/abs/2303.09556
    snr = alphas_cumprod / (1 - alphas_cumprod)
    maybe_clipped_snr = snr.clone()
    min_snr_gamma = 5
    maybe_clipped_snr.clamp_(max = min_snr_gamma)
    
    #offset noise
    offset_noise_strength = 0.1
    
    #self-conditioning
    self_condition = False
    
    #Objective
    #We find that any other than predicting the noise leads to much worse convergence times
    objective = 'pred_noise'
    
    if objective == 'pred_noise':
        loss_weight = maybe_clipped_snr / snr
    elif objective == 'pred_x0':
        loss_weight = maybe_clipped_snr
    elif objective == 'pred_v':
        loss_weight = maybe_clipped_snr / (snr+1)


    
    print("Run_Properties: ", identifier)
    print("Number of parameters: ",num_params)
    
    #Generate the checkpoint, if chosen to continue training the model
    if checkpoint is not None:
        cp = str(results_folder)+"/"+checkpoint
    else:
        cp = None
    
    #Call the train function
    train(dataset, max_iter=max_iter, batch_size=batch_size, image_size=image_size, model=model, dataloader=dataloader, optimizer=optimizer, device=device, augPlan=eluAug, folders=[results_folder], save_iteration=save_iteration,
          sample_iteration=sample_iteration, num_samples=num_samples, loss_type=loss_type, checkpoint=cp)

In [12]:
def train(dataset, max_iter=150000, batch_size=12, image_size=64, model=None, dataloader=None, checkpoint=None, optimizer=None, device='cuda', augPlan=None, folders=None, loss_type="huber",
         sample_iteration=500000, save_iteration=50000, num_samples=50):
    '''Final training loop function that is called by the run function. '''

    
    iteration = 0
    losses = list()
    duration = list()
    ep = list()
    full_tens = list()
    iterations = list() 
    kid_mu = list()
    kid_sigma = list()
    #fid = list()
    batch_lpips = list()
    avg_single_lpips = list()
    min_single_lpips = list()
    
    loss_dict = {'epoch': [],
                'iteration': [],
                'loss': [],
                'sliding_loss': []
                }
    score_dict = {'epoch': [],
                'iteration': [],
                'min_lpips': [],
                'avg_lpips': [],
                'fid': [],
                'kid': [],
                }
    
    #Directly load the images into memory (was a little faster due to small size of dataset)
    full_tens = list()
    for batch in dataloader:
        batch = batch[0].to(device)
        full_tens.append(batch)
    
    
    print("FULLTENS: ", torch.cat(full_tens,dim=0).shape)
    full_tens = torch.cat(full_tens,dim=0)
    print("RANDTENS: ", full_tens[torch.randint(0, full_tens.size(0), (16,))].shape)
    
    starting_epoch=0
    kid_mean = -1
    kid_std = -1
    lpips_val = -1
    avg_lpips = -1
    min_lpips = -1
    
    #Mixed precision
    scaler = GradScaler()
    
    #Load checkpoint, if wanted
    if checkpoint is not None:
        model.load_state_dict(torch.load(checkpoint))
        iteration = int(checkpoint.split("-")[-1][:-4])+1
        
        cp_optim = torch.load(checkpoint.replace('model','optimizer'))
        optimizer.load_state_dict(torch.load(checkpoint.replace('model','optimizer')))
    
    #Exponentially moving average
    ema = EMA(
        model,
        beta = 0.999,                 # exponential moving average factor
        update_after_step = 20000,    # only after this number of .update() calls will it start updating
        update_every = 10,            # how often to actually update, to save on compute (updates every 10th .update() call)
    )
    if checkpoint is not None:
        ema.load_state_dict(torch.load(checkpoint.replace('model','ema'),map_location=torch.device('cuda')))

    sliding_loss = []
    for epoch in range(starting_epoch, 10000000):
        model.train()
        if iteration > max_iter:
            break
            
        #10 iterations per epoch, however, epochs can be neglected, as dataset size is so small
        pbar = tqdm(range(10))
        total_loss = 0

        for step in pbar:
            batch = full_tens[torch.randint(0, full_tens.size(0), (batch_size,))]
            iteration += 1
            optimizer.zero_grad()
            
            #augment the images based on the chosen augmentation plan
            augm, labels = augPlan(batch)

            # Algorithm 1 line 3: sample t uniformally for every example in the batch
            t = torch.randint(0, timesteps, (batch_size,), device=device).long()
            
            #Autocast and scaler for mixed-precision training
            with autocast():
                loss = p_losses(model, augm, t.float(), augm=labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            ema.update()
            
            total_loss += loss.item()
            pbar.set_postfix({"Loss":total_loss/(step+1),"epoch":epoch,"iteration":iteration})
            
            #Sliding loss contains the loss of the last 50 iteration to account for outliers and give a more accurate loss representation
            if len(sliding_loss) < 50:
                sliding_loss.insert(0, loss.item())
            else:
                sliding_loss.insert(0, loss.item())
                sliding_loss.pop()
                
                
            #Save the loss dictionary as csv file
            if iteration % 1000 == 1 or iteration == max_iter:
                loss_dict['epoch'].append(epoch)
                loss_dict['iteration'].append(iteration)
                loss_dict['loss'].append(loss.item())
                loss_dict['sliding_loss'].append(sum(sliding_loss)/len(sliding_loss))
                result_df = pd.DataFrame(loss_dict)
                result_df.to_csv(str(folders[0] / f'loss_scores.csv'))  

            # Sample images and save
            if iteration % sample_iteration == 0:

                with torch.no_grad():
                    sample_image_ema = ddim_sample(ema, (4,3,image_size,image_size), device='cuda', total_timesteps=1000, sampling_timesteps=50, ddim_sampling_eta=0, objective=objective)
                    sample_image_model = ddim_sample(model, (4,3,image_size,image_size), device='cuda', total_timesteps=1000, sampling_timesteps=50, ddim_sampling_eta=0, objective=objective)

                print("\nSAMPLED_IMAGES")
                sample_image_ema = map_interval(sample_image_ema)
                sample_image_model = map_interval(sample_image_model)
                plot_images(sample_image_ema)
                plot_images(sample_image_model)


                #SAVE IMAGES
                save_image((sample_image_ema+1)*0.5, str(folders[0] / f'final_ema-{iteration}.png'), nrow = 15)
                save_image((sample_image_model+1)*0.5, str(folders[0] / f'final_model-{iteration}.png'), nrow = 15)

            #save checpoints
            #if (iteration % save_iteration == 1) or (iteration == max_iter):
            #    torch.save(model.state_dict(), str(folders[0]/ f'model-{iteration}.tar'))
            #    torch.save(ema.state_dict(), str(folders[0]/ f'ema-{iteration}.tar'))
            #    torch.save(optimizer.state_dict(), str(folders[0]/ f'optimizer-{iteration}.tar'))
                
            return None
        
        

In [15]:
#Call the run function and train the model
l1 = [(1,2),(1,2,2),(1,2,4)]
l2 = [0.25, 0.5, 0.5]
for dataset in os.listdir('./test_data/'):
    for i in range (len(l1)):
        run(dataset=dataset, dim=64, dim_mults=l1[i], loss_type='l2', num_samples=20, max_iter=650000, sample_iteration=10000, save_iteration=10000, aug_factor=l2[i], batch_size=16, image_size=96,
                checkpoint=None, lr=3e-4, noise_schedule='linear', ts=1000)

init_dim:	 None
dims:		 [42, 64, 128]
in_out:		 [(42, 64), (64, 128)]
Run_Properties:  LSUNBed96_linear_16_64_(1, 2)_1000_0.25_l20.0003_1
Number of parameters:  6879527
FULLTENS:  torch.Size([100, 3, 96, 96])
RANDTENS:  torch.Size([16, 3, 96, 96])


  0%|          | 0/10 [00:00<?, ?it/s]

init_dim:	 None
dims:		 [42, 64, 128, 128]
in_out:		 [(42, 64), (64, 128), (128, 128)]
Run_Properties:  LSUNBed96_linear_16_64_(1, 2, 2)_1000_0.5_l20.0003_1
Number of parameters:  11434727
FULLTENS:  torch.Size([100, 3, 96, 96])
RANDTENS:  torch.Size([16, 3, 96, 96])


  0%|          | 0/10 [00:00<?, ?it/s]

init_dim:	 None
dims:		 [42, 64, 128, 256]
in_out:		 [(42, 64), (64, 128), (128, 256)]
Run_Properties:  LSUNBed96_linear_16_64_(1, 2, 4)_1000_0.5_l20.0003_1
Number of parameters:  26803303
FULLTENS:  torch.Size([100, 3, 96, 96])
RANDTENS:  torch.Size([16, 3, 96, 96])


  0%|          | 0/10 [00:00<?, ?it/s]

FileNotFoundError: Found no valid file for the classes Obama96_linear_16_64_(1, 2)_1000_0.25_l20.0003_1. Supported extensions are: .jpg, .jpeg, .png, .ppm, .bmp, .pgm, .tif, .tiff, .webp