In [1]:
from typing import Optional, List

import numpy as np
import torch
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


In [2]:
# data
dataset_name = "oxford_flowers102"
dataset_repetitions = 5
num_epochs = 1  # train for at least 50 epochs for good results
image_size = 64
# KID = Kernel Inception Distance, see related section
kid_image_size = 75
kid_diffusion_steps = 5
plot_diffusion_steps = 20

# sampling
min_signal_rate = 0.02
max_signal_rate = 0.95

# architecture
embedding_dims = 32
embedding_max_frequency = 1000.0
widths = [32, 64, 96, 128]
block_depth = 2

# optimization
batch_size = 64
ema = 0.999
learning_rate = 1e-3
weight_decay = 1e-4


In [3]:
class SinusoidalEmbedding(nn.Module):
    def __init__(self):
        super(SinusoidalEmbedding, self).__init__()
        self.embedding_min_frequency = 1.0 #it is defined in the keras
        self.embedding_dims = 32 # it is defined as half of the image's dimension, stays as hyperparameter 
        self.embedding_max_frequency = 1000.0

    def forward(self, x):
        frequencies = torch.exp(
            torch.linspace( 
                torch.log(torch.tensor(self.embedding_min_frequency)),
                torch.log(torch.tensor(self.embedding_max_frequency)),
                self.embedding_dims // 2,
            )
        )
        angular_speeds = 2.0 * math.pi * frequencies
         
        angular= angular_speeds * x
        freq_shapes = angular.shape
        angular = angular.view(freq_shapes[0],freq_shapes[-1],freq_shapes[1],freq_shapes[2])
        embeddings = torch.cat(
            (torch.sin(angular), torch.cos(angular)), dim=1)
         
        return embeddings

In [4]:
class ResidualBlock(nn.Module):
    def __init__(self, width,fi):
        super(ResidualBlock, self).__init__()
   
        self.width = width
        self.activation = nn.SiLU()
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.batch = nn.BatchNorm2d(fi, affine=False)
        self.conv3 = nn.Conv2d(fi, width, kernel_size=3, padding=1)
        self.conv0 =  nn.Conv2d(fi,  width, kernel_size=3, padding=1)
        
    def forward(self, x):
        input_width = x.shape[1]
        if input_width == self.width:
            residual = x
        else:
            residual =  self.conv0(x)

        x = self.batch(x)
        x = self.conv3(x)
        x = self.activation(x)
        x = self.conv2(x)
        x += residual
        
        return x

In [5]:
class DownBlock(nn.Module):
  def __init__(self,width,fi):
      super(DownBlock, self).__init__()
      self.avgpool = nn.AvgPool2d(kernel_size=2)
    
      self.resBlock = ResidualBlock( width,fi)
      self.resBlock2 = ResidualBlock( width,width)
        
  def forward(self, x):
      x, skips = x
      x = self.resBlock(x)
      skips.append(x)
      x = self.resBlock2(x)
      skips.append(x)
      x = self.avgpool(x)
      return x


In [6]:
class UpBlock(nn.Module):
  def __init__(self, width,fi):
      super(UpBlock, self).__init__()
      self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
       
      self.resBlock = ResidualBlock( width,fi + width)
      self.resBlock2 = ResidualBlock(width,width*2)
  def forward(self, x):
      x, skips = x
      x= self.upsample(x) 
      x = torch.cat([x, skips.pop()], dim=1)
     
      x =self.resBlock(x)
      x = torch.cat([x, skips.pop()], dim=1)
      
      x =self.resBlock2(x)
      return x

In [7]:

class ResidualUNet(nn.Module):
    def __init__(self, image_size, widths, block_depth):
        super(ResidualUNet, self).__init__()
        self.conv = nn.Conv2d(widths[0],widths[0], kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(widths[0],3, kernel_size=1, bias=False)
        
        self.downBlock0 = DownBlock(32,64)
        self.downBlock1 = DownBlock(64,32)
        self.downBlock2 = DownBlock(96,64)
   


        self.upBlock0 = UpBlock(96,128) 
        self.upBlock1 = UpBlock(64,96) 
        self.upBlock2 = UpBlock(32,64) 
        
        
     
        
        self.resBlock0 = ResidualBlock(128,96)
        self.resBlock1 = ResidualBlock(128,128)

        self.sin = SinusoidalEmbedding()
    def forward(self, noisy_images, noise_variances):

 
        e =  self.sin(noise_variances)
        
        e = F.interpolate(e, size=image_size, mode='nearest')
         
        x = self.conv(e)
        x = torch.cat([x, e], dim=1)
        skips = []

        x = self.downBlock0([x, skips])
      
        x = self.downBlock1([x, skips])
    
        x = self.downBlock2([x, skips])
 
     
        
        x = self.resBlock0(x)
         
        x = self.resBlock1(x)
      
  
        x = self.upBlock0([x, skips])
        x = self.upBlock1([x, skips])
        x = self.upBlock2([x, skips])
            
        x = self.conv2(x)
        
        return x

In [8]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image
from torch.utils.data import Subset, DataLoader
from torch.utils.data import  ConcatDataset


from  torch.utils.data.sampler import SubsetRandomSampler  
transform = transforms.Compose([
    transforms.CenterCrop(500),
    transforms.Resize(size=(image_size, image_size), interpolation=Image.LANCZOS),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: torch.clamp(x/255.0, min=0.0, max=1.0))
     
])
# Define the data transformation pi
# Load the dataset
dataset = datasets.Flowers102(root='./data', split='test', download=True, transform=transform)
dataset1 = datasets.Flowers102(root='./data', split='train', download=True, transform=transform)
dataset2 = datasets.Flowers102(root='./data', split='val', download=True, transform=transform)

concat_dataset = ConcatDataset([dataset,dataset1,dataset2,dataset,dataset1,dataset2,dataset,dataset1,dataset2,dataset,dataset1,dataset2,dataset,dataset1,dataset2] ) 


data_loader = DataLoader(concat_dataset, batch_size=batch_size  ,      shuffle=True, drop_last=True )




In [9]:
len(concat_dataset)

40945

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils as utils

class DiffusionModel(nn.Module):
    def __init__(self, image_size, widths, block_depth):
        super().__init__()
        self.network = ResidualUNet( image_size, widths, block_depth)
        self.ema_network = ResidualUNet( image_size, widths, block_depth)
        self.ema_decay = 0.99

 
    def denormalize(self, images):
        # convert the pixel values back to 0-1 range
        images = nn.BatchNorm2d(num_features=images.shape[1])(images).mean() + images *  torch.var(nn.BatchNorm2d(num_features=images.shape[1])(images))**0.5
        return torch.clamp(images, 0.0, 1.0)
    
    def diffusion_schedule(self, diffusion_times):
        # diffusion times -> angles
        start_angle = torch.acos(torch.tensor(max_signal_rate))
        end_angle = torch.acos(torch.tensor(min_signal_rate))
        diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)
        # angles -> signal and noise rates
        signal_rates = torch.cos(diffusion_angles)
        noise_rates = torch.sin(diffusion_angles)
        # note that their squared sum is always: sin^2(x) + cos^2(x) = 1
        return noise_rates, signal_rates

    def denoise(self, noisy_images, noise_rates, signal_rates, training):
        # the exponential moving average weights are used at evaluation
        if training:
            network = self.network
        else:
            network = self.ema_network


        # predict noise component and calculate the image component using it
        pred_noises = network(noisy_images, noise_rates.pow(2))
        pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
        return pred_noises, pred_images


    def reverse_diffusion(self, initial_noise, diffusion_steps):
        # reverse diffusion = sampling
        num_images = initial_noise.shape[0]
        step_size = 1.0 / diffusion_steps

        # important line:
        # at the first sampling step, the "noisy image" is pure noise
        # but its signal rate is assumed to be nonzero (min_signal_rate)
        next_noisy_images = initial_noise
        for step in range(diffusion_steps):
            noisy_images = next_noisy_images

            # separate the current noisy image to its components
            diffusion_times = torch.ones((num_images, 1, 1, 1)) - step * step_size
            noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
            pred_noises, pred_images = self.denoise(
                noisy_images, noise_rates, signal_rates, training=True
            )
            # network used in eval mode

            # remix the predicted components using the next signal and noise rates
            next_diffusion_times = diffusion_times - step_size
            next_noise_rates, next_signal_rates = self.diffusion_schedule(
                next_diffusion_times
            )
            next_noisy_images = (                next_signal_rates * pred_images + next_noise_rates * pred_noises )
            # this new noisy image will be used in the next step

        return pred_images

    def generate(self, num_images, diffusion_steps):
        # noise -> images -> denormalized images
        initial_noise = torch.randn(num_images, 3, image_size, image_size)
        generated_images = self.reverse_diffusion(initial_noise, diffusion_steps)
        generated_images = self.denormalize(generated_images)
        return generated_images

    def ema_para(self,network_parameters):
      for param, ema_param in zip(network_parameters,self.ema_network.parameters()):
        ema_param.data = self.ema_decay * ema_param.data + (1 -  self.ema_decay) * param.data

    def forward(self, images):
        # normalize images to have standard deviation of 1, like the noises
        images = nn.BatchNorm2d(num_features=images.shape[1])(images)
        noises = torch.randn(images.shape[0],3,image_size, image_size)

        # sample uniform random diffusion times
        diffusion_times = torch.rand(images.shape[0], 1, 1, 1)
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)       
        noisy_images = signal_rates * images + noise_rates * noises

        pred_noises, pred_images = self.denoise( noisy_images, noise_rates, signal_rates, training=True)
        #self.ema_para(self.network.parameters())
        return pred_noises, noises

    def test_step(self, images ):
        images = nn.BatchNorm2d(num_features=images.shape[1])(images)
        noises = torch.randn(images.shape[0],3,image_size, image_size)

        # sample uniform random diffusion times
        diffusion_times = torch.rand(images.shape[0], 1, 1, 1)
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)       
        noisy_images = signal_rates * images + noise_rates * noises

        pred_noises, pred_images = self.denoise( noisy_images, noise_rates, signal_rates, training=True)

        images = self.denormalize(images)
        generated_images = self.generate( num_images=batch_size, diffusion_steps=kid_diffusion_steps    )
        return generated_images

    def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6):
            # plot random generated images for visual evaluation of generation quality
        generated_images = self.generate(   num_images=num_rows * num_cols,    diffusion_steps=plot_diffusion_steps     )

        plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
        for row in range(num_rows):
            for col in range(num_cols):
                index = row * num_cols + col
                plt.subplot(num_rows, num_cols, index + 1)
                plt.imshow(generated_images[index].permute(1, 2, 0).detach().numpy())
                plt.axis("off")
        plt.tight_layout()
        plt.show()
        plt.close()
      

In [13]:
diffusion_model.parameters

<bound method Module.parameters of DiffusionModel(
  (network): ResidualUNet(
    (conv): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (conv2): Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (downBlock0): DownBlock(
      (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
      (resBlock): ResidualBlock(
        (activation): SiLU()
        (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (batch): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
        (conv3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (resBlock2): ResidualBlock(
        (activation): SiLU()
        (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (batch): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
        (conv3): Conv

In [12]:
diffusion_model  = DiffusionModel(image_size, widths, block_depth)

 

optimizer = torch.optim.AdamW(
    params=diffusion_model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay
)

loss_fn = nn.L1Loss(size_average=None, reduce=None, reduction='mean')

def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, data in enumerate(data_loader):
        # Every data instance is an input + label pair
        images, labels = data

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        pred_noises, noises = diffusion_model(images)
 
        # Compute the loss and its gradients
        loss = loss_fn(pred_noises, noises)
        print(i,loss)
        loss.backward()

        # Adjust learning weights
        optimizer.step()
         
        

        # Gather data and report
        running_loss += loss.item()
        if i % 16 == 15:
            last_loss = running_loss / 16 # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(data_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
             

    return last_loss

In [None]:
# Initializing in a separate cell so we can easily add more epochs to the same run
# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 3

best_vloss = 1_000_000.

diffusion_model.train(True)

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    
    avg_loss = train_one_epoch(epoch_number, writer)
    


    epoch_number += 1

EPOCH 1:
0 tensor(0.8052, grad_fn=<MeanBackward0>)
1 tensor(0.8165, grad_fn=<MeanBackward0>)
2 tensor(0.8131, grad_fn=<MeanBackward0>)
3 tensor(0.8201, grad_fn=<MeanBackward0>)
4 tensor(0.8795, grad_fn=<MeanBackward0>)
5 tensor(0.8309, grad_fn=<MeanBackward0>)
6 tensor(0.8483, grad_fn=<MeanBackward0>)
7 tensor(0.8022, grad_fn=<MeanBackward0>)
8 tensor(0.8035, grad_fn=<MeanBackward0>)
9 tensor(0.8028, grad_fn=<MeanBackward0>)
10 tensor(0.8083, grad_fn=<MeanBackward0>)
11 tensor(0.8094, grad_fn=<MeanBackward0>)
12 tensor(0.8023, grad_fn=<MeanBackward0>)
13 tensor(0.8022, grad_fn=<MeanBackward0>)
14 tensor(0.8011, grad_fn=<MeanBackward0>)
15 tensor(0.8002, grad_fn=<MeanBackward0>)
  batch 16 loss: 0.8153457418084145
16 tensor(0.8022, grad_fn=<MeanBackward0>)
17 tensor(0.8096, grad_fn=<MeanBackward0>)
18 tensor(0.8181, grad_fn=<MeanBackward0>)
19 tensor(0.8242, grad_fn=<MeanBackward0>)
20 tensor(0.8028, grad_fn=<MeanBackward0>)
21 tensor(0.8031, grad_fn=<MeanBackward0>)
22 tensor(0.8315, g

In [None]:
import matplotlib.pyplot as plt
diffusion_model.plot_images()

In [None]:
for data in data_loader:
  img, lbl = data
  print(img)

In [None]:
for data in data_loader:
  img, lbl = data
  print(img)