# 🔥 Diffusion Models

In [1]:
import os

# Repository root. Set the GDL_WORKING_DIR environment variable to run the notebook
# on another machine instead of editing this machine-specific default path.
working_dir = os.environ.get(
    "GDL_WORKING_DIR",
    "/home/mary/work/repos/generative_deep_Learning_2nd_edition_pytorch",
)
# Folder of this experiment (logs, checkpoints and generated samples live below it).
exp_dir = working_dir + "/notebooks/08_diffusion/01_ddm/"

In [2]:
%load_ext autoreload
%autoreload 2

import sys
import os
import copy

# Make the repository root (for `notebooks.utils`) and this experiment's folder
# importable, without adding duplicate entries when the cell is re-run.
for extra_path in (os.path.abspath(working_dir), os.path.abspath(exp_dir)):
    if extra_path not in sys.path:
        sys.path.append(extra_path)

In [None]:
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
import torch
import torch.nn as nn
from torch.nn import Module
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary

import matplotlib.pyplot as plt
from notebooks.utils import display

## 0. Parameters <a name="parameters"></a>

In [85]:
IMAGE_SIZE = 64  # images are resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 64
DATASET_REPETITIONS = 5  # number of times RepeatedDataset repeats the image folder
LOAD_MODEL = False  # when True, weights and data statistics are restored from a checkpoint

NOISE_EMBEDDING_SIZE = 32  # total length of the noise-variance embedding (sine half + cosine half)
PLOT_DIFFUSION_STEPS = 20  # NOTE(review): not referenced by the visible cells

# optimization
EMA = 0.999  # decay for the EMA copy of the weights (DiffusionModel's ema_momentum)
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
EPOCHS = 100

## 1. Prepare the Data

In [86]:
# Dataset location, relative to the repository root.
data_dir = f"{working_dir}/data"
dataset_dir = f"{data_dir}/pytorch-challange-flower-dataset"

In [87]:
# define a custom dataset to repeat the data
class RepeatedDataset(Dataset):
    """Map-style dataset that presents ``dataset`` ``num_repeats`` times back to back.

    Index ``i`` maps to ``dataset[i % len(dataset)]``. Indices outside
    ``[-len(self), len(self))`` raise ``IndexError``, as the sequence protocol
    requires. Without this, plain Python iteration over the dataset (which falls
    back to ``__getitem__`` until ``IndexError``) would never terminate.

    Args:
        dataset: any object supporting ``len()`` and integer indexing.
        num_repeats: how many times the dataset is repeated.
    """

    def __init__(self, dataset, num_repeats):
        super().__init__()
        self.dataset = dataset
        self.num_repeats = num_repeats

    def __len__(self):
        return len(self.dataset) * self.num_repeats

    def __getitem__(self, index):
        total = len(self)
        # Also covers an empty base dataset, where `index % 0` would otherwise fail.
        if not -total <= index < total:
            raise IndexError(f"index {index} out of range for RepeatedDataset of length {total}")
        original_idx = index % len(self.dataset)
        return self.dataset[original_idx]

In [None]:
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor()
])

dataset = datasets.ImageFolder(dataset_dir, transform=transform)

# Repeat the (small) image folder so every epoch sees each image DATASET_REPETITIONS times.
repeated_dataset = RepeatedDataset(dataset=dataset, num_repeats=DATASET_REPETITIONS)

# Iterate over the repeated dataset, so DATASET_REPETITIONS actually affects training.
train_dataset = DataLoader(dataset=repeated_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

print(f"Size of orginal dataset = {len(dataset)}")
print(f"Size of repeated dataset = {len(repeated_dataset)}")

In [89]:
# Take one batch of training images for the visual and statistical checks below.
data_iter = iter(train_dataset)
sample_images, _ = next(data_iter)

In [None]:
# Show the sample batch with the `display` helper from notebooks.utils.
display(sample_images)

In [None]:
# Pixel range of the batch (ToTensor scales pixels to [0, 1]).
print(torch.min(sample_images))
print(torch.max(sample_images))

# Batch statistics, reused by the normalization check in the next cell.
images_mean = torch.mean(sample_images)
images_var = torch.var(sample_images)
images_std = torch.std(sample_images)
for statistic in (images_mean, images_var, images_std):
    print(statistic)

In [None]:
# Sanity check: standardizing with the batch statistics should give inputs x0
# with mean ~0 and variance ~1.
normalized_images = (sample_images - images_mean) / images_std
print("normalized images mean = ", normalized_images.mean())
print("normalized images var = ", normalized_images.var())

### 1.1 Diffusion schedules <a name="diffusion_schedules"></a>

In [93]:
def linear_diffusion_schduler(diffusion_times):
    """Linear-beta diffusion schedule.

    beta grows linearly from 1e-4 (time 0) to 0.02 (time 1), and
    alpha_bar = cumprod(1 - beta) is taken along dim 0. The cumulative product
    means ``diffusion_times`` is expected to be an ordered grid of times (as
    in the plotting cell). A batch of independent random times would be
    multiplied together across the batch.

    Returns:
        ``(noise_rates, signal_rates)`` = ``(sqrt(1 - alpha_bar), sqrt(alpha_bar))``.
    """
    min_rate, max_rate = 0.0001, 0.02
    betas = min_rate + (diffusion_times * (max_rate - min_rate))
    alpha_bars = torch.cumprod(1 - betas, dim=0)
    return torch.sqrt((1 - alpha_bars)), torch.sqrt(alpha_bars)

In [94]:
def cosine_schduler(diffusion_times):
    """Cosine diffusion schedule.

    Maps t in [0, 1] to the angle t * pi/2.

    Returns:
        ``(noise_rates, signal_rates)`` = ``(sin(angle), cos(angle))``, so
        noise_rate**2 + signal_rate**2 == 1.
    """
    angle = diffusion_times * (torch.pi / 2)
    return torch.sin(angle), torch.cos(angle)

In [95]:
def offset_cosine_schduler(diffusion_times):
    """Offset cosine schedule.

    The angle is interpolated linearly between acos(0.95) and acos(0.02), so
    the signal rate falls from 0.95 at t=0 to 0.02 at t=1 and never reaches
    exactly 1 or 0.

    Returns:
        ``(noise_rates, signal_rates)`` = ``(sin(angle), cos(angle))``.
    """
    first_angle = torch.acos(torch.as_tensor(0.95))  # angle of the maximum signal rate
    last_angle = torch.acos(torch.as_tensor(0.02))   # angle of the minimum signal rate
    angles = first_angle + (diffusion_times * (last_angle - first_angle))
    return torch.sin(angles), torch.cos(angles)

In [None]:
T = 1000

# Evenly spaced times 0, 1/T, ..., (T-1)/T (Python floats, converted to a float32 tensor).
diffusion_times = torch.tensor([step / T for step in range(T)])
print(diffusion_times.shape)

In [97]:
# Evaluate the three schedules on the same ordered time grid for the plots below.
linear_noise_rates, linear_signal_rates  = linear_diffusion_schduler(diffusion_times)

cosine_noise_rates, cosine_signal_rates = cosine_schduler(diffusion_times)

offset_cosine_noise_rates, offset_cosine_signal_rates = offset_cosine_schduler(diffusion_times)

In [None]:
# Signal variance alpha_bar_t for each schedule.
signal_curves = [
    (linear_signal_rates, "linear"),
    (cosine_signal_rates, "cosine"),
    (offset_cosine_signal_rates, "offset_cosine"),
]
for rates, curve_label in signal_curves:
    plt.plot(diffusion_times, rates**2, linewidth=1.5, label=curve_label)

plt.xlabel("t/T")
plt.ylabel(r"$\bar{\alpha_t}$ (signal)")
plt.legend()
plt.show()

In [None]:
# Noise variance 1 - alpha_bar_t for each schedule.
noise_curves = {
    "linear": linear_noise_rates,
    "cosine": cosine_noise_rates,
    "offset cosine": offset_cosine_noise_rates,
}
for curve_label, rates in noise_curves.items():
    plt.plot(diffusion_times, rates**2, linewidth=1.2, label=curve_label)

plt.xlabel("t/T")
plt.ylabel(r"1 - $\bar{\alpha_t}$ (noise)")
plt.legend()
plt.show()


## 2. Build the model <a name="build"></a>

In [100]:
class SinusoidalEmbedding(Module):
    """Module form of the sinusoidal embedding (the UNet uses the function version instead).

    Uses ``l`` frequencies exp(k * log(1000) / (l - 1)), k = 0..l-1. The
    angles 2*pi*freq*x are computed with broadcasting. The result stacks the
    sine slices followed by the cosine slices along dim 0; for a scalar ``x``
    this is a vector of length 2*l.
    """

    def __init__(self, l):
        super().__init__()
        self.l = l

    def forward(self, x):
        step = torch.log(torch.as_tensor(1000)) / (self.l - 1)
        exponents = torch.arange(0, self.l, 1) * step
        angles = 2 * torch.pi * torch.exp(exponents) * x
        # Unpacking iterates over dim 0, so sines come first, then cosines.
        return torch.stack([*torch.sin(angles), *torch.cos(angles)])

In [101]:
def sinusoidal_embedding(x, l, device=torch.device("cpu")):
    """Sinusoidal embedding of ``x`` using ``l`` geometrically spaced frequencies.

    The frequencies are exp(k * log(1000) / (l - 1)) for k = 0..l-1, i.e. 1 .. 1000.
    A new last axis holds the ``l`` sine values followed by the ``l`` cosine
    values, and axis 1 is then squeezed if it has size 1. For example, a
    ``(B, 1)`` input gives ``(B, 2*l)``, and a 0-d input gives ``(1, 2*l)``.

    Args:
        x: tensor of values to embed (e.g. noise variances).
        l: number of frequencies (may be a float such as ``NOISE_EMBEDDING_SIZE / 2``).
        device: device on which the frequency tensors are created.
    """
    step = (torch.log(torch.as_tensor(1000)) / (l - 1)).to(device)
    exponents = torch.arange(0, l, 1).to(device) * step

    angular_speed = 2 * torch.pi * torch.exp(exponents.unsqueeze(0))  # (1, l)
    angles = angular_speed * x.unsqueeze(-1)

    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1).squeeze(1)

In [None]:
# Shape check: a (10, 1, 1, 1) input keeps its extra singleton dims, giving torch.Size([10, 1, 1, 32]).
noise = torch.ones((10, 1, 1, 1)) * 0.1
embedding = sinusoidal_embedding(noise, 16)
print(embedding.shape)


In [None]:
# Embedding vectors (16 frequencies -> 32 values) for noise variances 0.0 .. 0.9.
for x in torch.arange(0, 1, 0.1):
    embedding_x = sinusoidal_embedding(x, 16)[0]
    plt.plot(embedding_x, label=f"{x.item():.1f}")

plt.legend()
plt.xlabel("embedding dimension")
plt.ylabel("embedding value")
# plt.show must be *called*; the bare attribute reference did nothing.
plt.show()

In [104]:
def sinusoidal_embedding_log(x, l):
    """Sinusoidal embedding whose ``l`` frequencies are log-spaced with ``torch.linspace``.

    The frequencies are exp(linspace(log 1, log 1000, l)). Shapes follow the
    same rule as in ``sinusoidal_embedding``: sines then cosines on a new last
    axis, and axis 1 squeezed if it has size 1. ``l`` must be an integer here
    because it is passed to ``linspace``.
    """
    low = torch.log(torch.as_tensor(1.0))
    high = torch.log(torch.as_tensor(1000.0))
    frequencies = torch.exp(torch.linspace(low, high, l)).unsqueeze(0)

    phase = 2 * torch.pi * frequencies * x.unsqueeze(-1)

    return torch.cat([torch.sin(phase), torch.cos(phase)], dim=-1).squeeze(1)


In [None]:
# Same plot as above, but using the linspace-based frequencies.
for x in torch.arange(0, 1, 0.1):
    embedding_x = sinusoidal_embedding_log(x, 16)[0]
    plt.plot(embedding_x, label=f"{x.item():.1f}")

plt.legend()
plt.xlabel("embedding dimension")
plt.ylabel("embedding value")
# plt.show must be *called*; the bare attribute reference did nothing.
plt.show()

In [106]:
# Embedding vectors for 100 noise variances 0.00 .. 0.99, used by the heatmap below.
embedding_list = [
    sinusoidal_embedding(y, NOISE_EMBEDDING_SIZE/2)[0]
    for y in torch.arange(0, 1, 0.01)
]


In [None]:
# Heatmap: rows are embedding dimensions, columns are the 100 noise variances 0.00 .. 0.99.
embedding_array = torch.as_tensor((torch.stack(embedding_list).transpose(0, 1)))
# One tick every 10 columns, labelled with the corresponding noise variance 0.0 .. 0.9.
labels = [f"{value.item():.1f}" for value in torch.arange(0.0, 1.0, 0.1)]
fig, ax = plt.subplots()
ax.set_xticks(
    torch.arange(0, 100, 10), labels=labels
)
ax.set_ylabel("embedding dimension", fontsize=8)
ax.set_xlabel("noise variance", fontsize=8)
plt.pcolor(embedding_array, cmap="coolwarm")
plt.colorbar(orientation="horizontal", label="embedding value")
# NOTE(review): this draws the same array a second time with the default colormap.
# Matplotlib's default zorder places images below collections, so the coolwarm pcolor
# (which matches the colorbar) is presumably what stays visible. Confirm, and consider
# dropping this call.
ax.imshow(embedding_array, interpolation="nearest", origin="lower")
plt.show()

In [108]:
def swish(x):
    """Swish activation x * sigmoid(x) (the same function as SiLU)."""
    gate = torch.sigmoid(x)
    return x * gate

In [None]:
sample_input = torch.arange(-4, 4, 0.1)

# Swish is drawn thicker so the (identical) SiLU curve can be seen on top of it.
activation_curves = [
    (F.sigmoid, {"label": "sigmoid"}),
    (swish, {"linewidth": 3, "label": "swish"}),
    (F.silu, {"label": "silu"}),
    (F.relu, {"label": "relu"}),
]
for activation, plot_kwargs in activation_curves:
    plt.plot(sample_input, activation(sample_input), **plot_kwargs)

plt.xlabel("input value")
plt.ylabel("activation function output")
plt.legend()
plt.show()

Building blocks

In [110]:
class ResdualBlock(Module):
    """Residual block: BatchNorm -> 3x3 conv -> swish (SiLU) -> 3x3 conv, plus a skip path.

    Args:
        input_channels: number of channels of the incoming feature map.
        channels: number of output channels. When this differs from
            ``input_channels``, the skip path is projected by a 1x1 convolution.
    """

    def __init__(self, input_channels, channels):
        super().__init__()
        self.channels = channels
        self.input_channels = input_channels

        # 1x1 projection for the skip path. It is only used when the channel count
        # changes, but is always created so the state_dict layout stays the same.
        self.residual_conv = nn.Conv2d(self.input_channels, self.channels,
                                       kernel_size=1, stride=1, padding="same")

        # Normalization without a learnable scale/shift (affine=False).
        self.bn = nn.BatchNorm2d(self.input_channels, affine=False)
        self.conv1 = nn.Conv2d(self.input_channels, self.channels, kernel_size=3,
                               stride=1, padding="same")
        self.conv2 = nn.Conv2d(self.channels, self.channels, kernel_size=3,
                               stride=1, padding="same")

    def forward(self, x):
        c = x.shape[1]
        assert c == self.input_channels, (
            f"expected {self.input_channels} input channels, got {c}")

        # Identity skip when the width is unchanged, 1x1 projection otherwise.
        residual = x if c == self.channels else self.residual_conv(x)

        x = self.bn(x)
        x = self.conv1(x)
        # Swish/SiLU activation (the one compared above); F.selu was a different function.
        x = F.silu(x)
        x = self.conv2(x)
        return x + residual
    
class DownBlock(Module):
    """U-Net encoder stage: ``block_depth`` residual blocks followed by 2x2 average pooling.

    ``forward`` takes and returns an ``(x, skips)`` pair. The output of every
    residual block is appended to the ``skips`` list (mutated in place), for
    use by the decoder.
    """

    def __init__(self, input_channels, channels, block_depth):
        super().__init__()
        self.channels = channels
        self.input_channels = input_channels
        self.block_depth = block_depth
        self.residual_blocks = nn.ModuleList()
        self.avrg_pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)

        # Only the first block adapts the channel count; the rest keep `channels`.
        block_in = self.input_channels
        for _ in range(self.block_depth):
            self.residual_blocks.append(ResdualBlock(block_in, self.channels))
            block_in = self.channels

    def forward(self, x):
        features, skips = x
        for block in self.residual_blocks:
            features = block(features)
            skips.append(features)
        return self.avrg_pool(features), skips

class upBlock(Module):
    """U-Net decoder stage: 2x bilinear upsampling, then ``block_depth`` residual blocks.

    Before each residual block, the most recent entry popped from ``skips`` is
    concatenated on the channel axis. ``forward`` takes and returns an
    ``(x, skips)`` pair.
    """

    def __init__(self, input_channels, channels, resdual_channels, block_depth):
        super().__init__()
        self.channels = channels
        self.block_depth = block_depth
        self.resdual_channels = resdual_channels
        self.input_channels = input_channels
        self.residual_blocks = nn.ModuleList()

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Each block sees its own input plus `resdual_channels` skip channels.
        block_in = self.input_channels
        for _ in range(self.block_depth):
            self.residual_blocks.append(
                ResdualBlock(block_in + self.resdual_channels, self.channels))
            block_in = self.channels

    def forward(self, x):
        features, skips = x
        features = self.upsample(features)
        for block in self.residual_blocks:
            features = block(torch.cat([features, skips.pop()], dim=1))
        return features, skips

The U-Net implementation

In [111]:
class UNet(Module):
    """U-Net that predicts the noise in ``noisy_image``, conditioned on its noise variance.

    Layout:
    - A 1x1 conv projects the image (3 -> 32 channels).
    - The projection is concatenated with the sinusoidal noise embedding,
      broadcast over all pixels.
    - Three DownBlocks (32, 64, 96 channels).
    - Two 128-channel residual blocks.
    - Three upBlocks (96, 64, 32 channels) consuming the stored skips.
    - A zero-initialised 1x1 conv back to 3 channels.

    Args:
        device: device on which the embedding frequency tensors are created.
        noise_embedding_size: total embedding length (sine half + cosine half).
    """

    def __init__(self, device, noise_embedding_size=32):
        super().__init__()
        self.noise_embedding_size = noise_embedding_size
        self.device = device

        # sinusoidal_embedding returns 2 * l values for l frequencies.
        self.noise_channels = 2 * (noise_embedding_size // 2)

        self.conv_input = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=1,
                                    stride=1, padding="same")

        # First stage sees 32 image-feature channels plus the noise-embedding channels.
        self.down_block_1 = DownBlock(input_channels=32 + self.noise_channels, channels=32, block_depth=2)
        self.down_block_2 = DownBlock(input_channels=32, channels=64, block_depth=2)
        self.down_block_3 = DownBlock(input_channels=64, channels=96, block_depth=2)

        self.resdual_block_1 = ResdualBlock(input_channels=96, channels=128)
        self.resdual_block_2 = ResdualBlock(input_channels=128, channels=128)

        self.up_block_1 = upBlock(input_channels=128, channels=96, resdual_channels=96, block_depth=2)
        self.up_block_2 = upBlock(input_channels=96, channels=64, resdual_channels=64, block_depth=2)
        self.up_block_3 = upBlock(input_channels=64, channels=32, resdual_channels=32, block_depth=2)

        self.conv_output = nn.Conv2d(in_channels=32, out_channels=3, kernel_size=1,
                                     stride=1, padding="same")

        # Zero weights *and* bias, so the initial noise prediction is exactly zero.
        init.zeros_(self.conv_output.weight)
        init.zeros_(self.conv_output.bias)

    def forward(self, noise_var, noisy_image):
        """Predict the noise component of ``noisy_image``.

        Args:
            noise_var: per-sample noise variance; ``DiffusionModel.denoise`` passes shape (B, 1).
            noisy_image: (B, 3, H, W) images. H and W must be divisible by 8,
                because of the three 2x poolings and upsamplings.

        Returns:
            Predicted noise with the same shape as ``noisy_image``.
        """
        noise_emb = sinusoidal_embedding(noise_var, self.noise_embedding_size // 2, device=self.device)
        # (B, C) -> (B, C, 1, 1) -> (B, C, H, W): the same embedding at every pixel,
        # sized to the actual input instead of the global IMAGE_SIZE.
        height, width = noisy_image.shape[-2], noisy_image.shape[-1]
        noise_emb = noise_emb.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, height, width)

        x = torch.cat([self.conv_input(noisy_image), noise_emb], dim=1)

        skips = []
        x, skips = self.down_block_1((x, skips))
        x, skips = self.down_block_2((x, skips))
        x, skips = self.down_block_3((x, skips))

        x = self.resdual_block_1(x)
        x = self.resdual_block_2(x)

        x, skips = self.up_block_1((x, skips))
        x, skips = self.up_block_2((x, skips))
        x, skips = self.up_block_3((x, skips))

        return self.conv_output(x)

In [112]:
# Use the GPU when available; the models and batches below are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [113]:
# Stand-alone UNet, used only for the summary and shape checks in the next cells.
unet = UNet(noise_embedding_size=NOISE_EMBEDDING_SIZE, device=device).to(device)

In [None]:
# torchinfo summary for one (noise variance, image) input pair.
summary(unet, [(1,1), (1, 3, IMAGE_SIZE, IMAGE_SIZE)])

In [115]:
# Forward pass of the freshly initialised UNet on the sample batch. No gradients
# are needed for this check, so skip building the autograd graph.
B = sample_images.shape[0]
noise_var = (torch.ones((B, 1)) * 0.1).to(device)
with torch.no_grad():
    pred_noise = unet(noise_var, sample_images.to(device))

In [None]:
print(pred_noise.shape)
# Hand `display` a CPU tensor with no autograd history, like `sample_images` above;
# pred_noise was produced on `device`.
display(pred_noise.detach().cpu())

Diffusion Model

In [117]:
class DiffusionModel(Module):
    """Denoising diffusion model: a noise-predicting UNet plus an EMA copy of it.

    Args:
        diffusion_schdule: callable that maps diffusion times to a
            ``(noise_rates, signal_rates)`` tuple (e.g. ``offset_cosine_schduler``).
        device: device on which training and sampling tensors are created.
        noise_embedding_size: size of the noise-variance embedding used by the UNet.
        ema_momentum: decay of the exponential moving average of the weights.
        log_dir: TensorBoard log directory for the ``SummaryWriter``.
    """

    def __init__(self, diffusion_schdule, device,
                 noise_embedding_size=32, ema_momentum=0.999, log_dir="./"):
        super().__init__()

        self.diffusion_schdule = diffusion_schdule
        self.device = device
        self.noise_embedding_size = noise_embedding_size
        self.ema_momentum = ema_momentum

        # Dataset-wide normalization statistics, filled by compute_global_stat()
        # or set_global_stat().
        self.data_mean = None
        self.data_std = None

        self.writer = SummaryWriter(log_dir)

        self.model = UNet(noise_embedding_size=self.noise_embedding_size, device=self.device)
        # The EMA copy starts from exactly the same parameters and buffers as the model.
        self.ema_model = copy.deepcopy(self.model)

        torch.manual_seed(42)

    def normalize_images(self, images):
        """Standardize ``images`` to roughly zero mean and unit variance.

        Uses the dataset-wide statistics when they are set. This way training
        uses the same normalization that ``de_normalize_images`` inverts at
        generation time. When they are not set, it falls back to the statistics
        of ``images`` itself.
        """
        mean = getattr(self, "data_mean", None)
        std = getattr(self, "data_std", None)
        if mean is None or std is None:
            mean = torch.mean(images)
            std = torch.std(images)
        return (images - mean) / std

    def de_normalize_images(self, images):
        """Map normalized images back to pixel space with the global statistics, clamped to [0, 1].

        Raises:
            RuntimeError: if the global statistics have not been computed or set.
        """
        data_mean = getattr(self, "data_mean", None)
        data_std = getattr(self, "data_std", None)
        if data_mean is None or data_std is None:
            raise RuntimeError("global data statistics are not set; call "
                               "compute_global_stat() or set_global_stat() first")
        denorm_images = (images * data_std) + data_mean
        return torch.clamp(denorm_images, min=0, max=1)

    def compute_global_stat(self, data_loader):
        """Compute the pixel mean and (unbiased) std over every batch of ``data_loader``.

        Uses a single streaming pass with float64 accumulators, instead of
        concatenating the whole dataset in memory.

        Raises:
            ValueError: if the loader yields fewer than two pixel values.
        """
        total = torch.zeros((), dtype=torch.float64)
        total_sq = torch.zeros((), dtype=torch.float64)
        count = 0
        for batch, _ in data_loader:
            batch = batch.to(torch.float64)
            total += batch.sum()
            total_sq += (batch * batch).sum()
            count += batch.numel()

        if count < 2:
            raise ValueError("compute_global_stat needs at least two pixel values")

        mean = total / count
        var = (total_sq - count * mean * mean) / (count - 1)
        self.data_mean = mean.to(torch.float32)
        self.data_std = torch.sqrt(var.clamp_min(0)).to(torch.float32)

    # to be used for loading checkpoints
    def set_global_stat(self, data_mean, data_std):
        """Set the normalization statistics directly (e.g. from a checkpoint)."""
        self.data_mean = data_mean
        self.data_std = data_std

    def denoise(self, noisy_images, noise_rates, signal_rates, training=True):
        """Predict the noise in ``noisy_images`` and the clean images it implies.

        Args:
            noisy_images: batch of noisy images.
            noise_rates, signal_rates: per-sample rates; callers pass (B, 1, 1, 1) tensors.
            training: if True use the trained model, otherwise the EMA model.

        Returns:
            ``(pred_noise, pred_images)``.
        """
        network = self.model if training else self.ema_model
        # (B, 1, 1, 1) -> (B, 1): the UNet is conditioned on the noise variance.
        noise_variance = noise_rates.squeeze(1).squeeze(1) ** 2
        pred_noise = network(noise_variance, noisy_images)
        pred_images = (noisy_images - (pred_noise * noise_rates)) / signal_rates
        return pred_noise, pred_images

    def train_step(self, train_images):
        """Run one optimization step on a batch, then update the EMA weights.

        Returns:
            The batch loss as a Python float.
        """
        train_images = train_images.to(self.device)
        self.optimizer.zero_grad()

        norm_images = self.normalize_images(train_images)
        noise = torch.randn_like(norm_images)
        # One random diffusion time per image, broadcastable over (C, H, W).
        B = train_images.shape[0]
        diffusion_times = torch.rand((B, 1, 1, 1)).to(self.device)
        noise_rates, signal_rates = self.diffusion_schdule(diffusion_times)

        noisy_images = (norm_images * signal_rates) + (noise * noise_rates)
        pred_noise, _ = self.denoise(noisy_images, noise_rates, signal_rates, training=True)

        loss = self.loss_fucntion(pred_noise, noise)
        loss.backward()
        self.optimizer.step()

        self._update_ema()
        return loss.item()

    def _update_ema(self):
        """Blend the trained model's parameters and float buffers into the EMA model, in place."""
        decay = self.ema_momentum
        with torch.no_grad():
            for ema_parameter, model_parameter in zip(self.ema_model.parameters(), self.model.parameters()):
                # ema = decay * ema + (1 - decay) * model, without allocating a new tensor
                ema_parameter.mul_(decay).add_(model_parameter, alpha=1 - decay)

            for ema_buffer, buffer in zip(self.ema_model.buffers(), self.model.buffers()):
                if ema_buffer.dtype.is_floating_point:
                    ema_buffer.mul_(decay).add_(buffer, alpha=1 - decay)
                else:
                    # Integer buffers (e.g. BatchNorm's num_batches_tracked) are copied as-is.
                    ema_buffer.copy_(buffer)

    def fit(self, train_dataset, optimizer, loss_function, epochs, callbacks=None):
        """Train for ``epochs`` epochs over ``train_dataset`` (an iterable of (images, labels) batches).

        After each epoch the mean batch loss is printed and logged to TensorBoard,
        and each callback's ``on_epoch_end(epoch, logs)`` is called.
        """
        callbacks = [] if callbacks is None else callbacks

        self.optimizer = optimizer
        self.loss_fucntion = loss_function

        # Global statistics drive normalization during training and de-normalization at generation.
        self.compute_global_stat(train_dataset)

        for epoch in range(epochs):
            self.model.train()
            self.ema_model.train()

            acc_loss = 0.0
            num_batches = 0
            for train_images, _ in train_dataset:
                acc_loss += self.train_step(train_images)
                num_batches += 1
            acc_loss /= max(num_batches, 1)

            print(
                f"Epoch {epoch}/{epochs}: Training: loss: {acc_loss :.4f} "
            )
            self.writer.add_scalar("loss", acc_loss, global_step=epoch)

            if callbacks:
                logs = {"epoch": epoch,
                        "model_state_dict": self.model.state_dict(),
                        "ema_model_state_dict": self.ema_model.state_dict(),
                        "data_mean": self.data_mean,
                        "data_std": self.data_std,
                        "loss": acc_loss,
                        "model": self}
                for callback in callbacks:
                    callback.on_epoch_end(epoch, logs)

        self.writer.flush()

    def reverse_diffusion(self, initial_noise, diffusion_steps, use_ema_model=True):
        """Iteratively denoise ``initial_noise`` from time 1 down to time 0.

        Returns:
            The clean-image prediction from the last step, in normalized space.

        Raises:
            ValueError: if ``diffusion_steps`` is smaller than 1.
        """
        if diffusion_steps < 1:
            raise ValueError(f"diffusion_steps must be >= 1, got {diffusion_steps}")

        step_size = 1.0 / diffusion_steps
        current_images = initial_noise.to(self.device)
        B = initial_noise.shape[0]

        # denoise(training=True) selects self.model; training=False selects self.ema_model.
        training_flag = not use_ema_model
        (self.ema_model if use_ema_model else self.model).eval()

        with torch.no_grad():
            for step in range(diffusion_steps):
                curr_diffusion_time = torch.ones((B, 1, 1, 1)).to(self.device) * (1 - (step * step_size))
                curr_noise_rate, curr_signal_rate = self.diffusion_schdule(curr_diffusion_time)

                pred_noise, pred_image = self.denoise(current_images, curr_noise_rate,
                                                      curr_signal_rate, training=training_flag)

                # Re-noise the predicted clean image to the next (smaller) diffusion time.
                next_diffusion_time = curr_diffusion_time - step_size
                next_noise_rate, next_signal_rate = self.diffusion_schdule(next_diffusion_time)

                current_images = (pred_image * next_signal_rate) + (pred_noise * next_noise_rate)

        return pred_image

    def generate_images(self, image_num, diffusion_steps, use_ema_model=True):
        """Sample ``image_num`` images of shape (3, IMAGE_SIZE, IMAGE_SIZE), de-normalized to [0, 1]."""
        noise = torch.randn((image_num, 3, IMAGE_SIZE, IMAGE_SIZE)).to(self.device)
        gen_images = self.reverse_diffusion(noise, diffusion_steps, use_ema_model=use_ema_model)
        return self.de_normalize_images(gen_images)


Create the required callbacks

In [118]:
class Callback:
    """Base class for training callbacks; subclasses override the hooks they need."""

    def on_epoch_end(self, epoch, logs=None):
        """Called by ``DiffusionModel.fit`` after every epoch. Does nothing by default."""
        return None

In [119]:
class SaveCheckpoint(Callback):
    """Callback that saves a training checkpoint every ``save_every`` epochs (starting at epoch 0).

    The checkpoint is written to ``<save_dir>/checkpoint_<epoch>.pth``. It holds
    the epoch, both model state dicts, the data normalization statistics and
    the epoch loss, all taken from ``logs``.
    """

    def __init__(self, save_dir, save_every=10):
        super().__init__()
        self.save_dir = save_dir
        self.save_every = save_every

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.save_every != 0:
            return

        copied_keys = ("model_state_dict", "ema_model_state_dict", "data_mean", "data_std", "loss")
        checkpoint = {"epoch": epoch}
        checkpoint.update({key: logs[key] for key in copied_keys})

        torch.save(checkpoint, self.save_dir + f"/checkpoint_{epoch}.pth")

In [120]:
class ImageGenerator(Callback):
    """Callback that samples images from the diffusion model after every epoch.

    The images come from ``logs["model"].generate_images``. They are shown with
    ``display``, which is also asked to save them to
    ``<save_dir>/epoch_<epoch>.png``.
    """

    def __init__(self, save_dir, num_images, diff_steps, use_ema_model=True):
        super().__init__()
        self.save_dir = save_dir
        self.num_images = num_images
        self.diff_steps = diff_steps
        self.use_ema_model = use_ema_model

    def on_epoch_end(self, epoch, logs=None):
        images = logs["model"].generate_images(
            self.num_images, self.diff_steps, self.use_ema_model
        )
        target_path = self.save_dir + f"/epoch_{epoch}.png"
        display(images, save_to=target_path)

Prepare for training

In [121]:
log_dir = exp_dir + "/log"
sample_dir = exp_dir + "/sample_gen"
checkpoint_dir = exp_dir + "/checkpoints"

# Create every output directory up front (no-op if it already exists).
for output_dir in (log_dir, sample_dir, checkpoint_dir):
    os.makedirs(output_dir, exist_ok=True)

In [122]:
# Save a checkpoint every 10 epochs; after every epoch, sample 10 images from the
# EMA model with 30 diffusion steps.
callbacks = [SaveCheckpoint(save_dir=checkpoint_dir, save_every=10),
             ImageGenerator(save_dir=sample_dir, num_images=10, diff_steps=30, use_ema_model=True)]

Create the model

In [123]:
# Pass the configured hyper-parameters explicitly instead of relying on the defaults.
diff_model = DiffusionModel(offset_cosine_schduler, device,
                            noise_embedding_size=NOISE_EMBEDDING_SIZE,
                            ema_momentum=EMA,
                            log_dir=log_dir).to(device)

In [124]:
# check if we have a checkpoint to load
if LOAD_MODEL:
    checkpoint_file = checkpoint_dir + "/checkpoint_600.pth"
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine (and vice versa).
    checkpoint = torch.load(checkpoint_file, map_location=device)
    diff_model.model.load_state_dict(checkpoint["model_state_dict"])
    diff_model.ema_model.load_state_dict(checkpoint["ema_model_state_dict"])
    diff_model.set_global_stat(checkpoint["data_mean"], checkpoint["data_std"])

Create the optimizer and the loss function (the model was created above)

In [125]:
# we will use the mean absolute error (L1 loss) between the predicted and the true noise
loss_fn = nn.L1Loss()
# Only the trained model's parameters are optimized; the EMA copy is updated inside train_step.
optimizer = optim.AdamW(diff_model.model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

In [None]:
# Inspect the EMA model's buffers (presumably the BatchNorm running statistics and counters).
for buffer in diff_model.ema_model.buffers():
    for item in (type(buffer), buffer.shape, buffer.dtype, buffer.data.dtype, buffer, buffer.data):
        print(item)

In [None]:
# Train. Callbacks save a checkpoint every `save_every` epochs and sample images after every epoch.
diff_model.fit(train_dataset, optimizer=optimizer, loss_function=loss_fn, epochs=EPOCHS, callbacks=callbacks)