In [None]:
import gc
import math

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import make_grid
from tqdm import tqdm

from models.DiffusionModel import DiffusionUNet

In [None]:
# Free memory left over from a previous run of this notebook.
# Collect unreachable Python objects (including stale tensors held in reference
# cycles) FIRST, so that empty_cache() can then hand their now-unused cached
# blocks back to the system. Guarded so the cell also runs without Apple MPS.
gc.collect()
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

In [None]:
out_res = (16, 16)  # (H, W) every training image is resized to
# Use Apple's MPS backend when present, otherwise fall back to the CPU so the
# notebook can still be executed on other machines.
device = "mps" if torch.backends.mps.is_available() else "cpu"
learning_rate = 3E-4
num_epochs = 70
batch_size = 64
in_channels = 1  # MNIST images are single-channel
# Seed torch's RNGs for reproducibility; the trailing ";" suppresses the
# returned Generator object from being displayed as cell output.
torch.manual_seed(42);

In [None]:
# Preprocessing: resize to the model resolution, then convert PIL images to
# float tensors in [0, 1]. While out_res is a full (H, W) tuple the CenterCrop
# is a no-op, but it keeps the pipeline correct if out_res becomes a single int.
transform = transforms.Compose(
    [
        transforms.Resize(out_res),
        transforms.CenterCrop(out_res),
        transforms.ToTensor(),
        # Optional [-1, 1] normalisation (3-channel statistics, so not applicable to MNIST as-is):
        # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ]
)

# Alternative dataset (RGB faces):
# train_set = datasets.ImageFolder("Data/celeb", transform=transform)
train_set = datasets.MNIST(root=".", download=True, transform=transform)

### Cosine diffusion schedule

$$ x_t = \cos\left( \frac{t}{T} \frac{\pi}{2} \right) x_0 + \sin\left( \frac{t}{T} \frac{\pi}{2} \right) \epsilon $$

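where $\epsilon \sim \mathcal{N}(0, I)$ is standard Gaussian noise. The signal rate $\cos(\cdot)$ and noise rate $\sin(\cdot)$ always satisfy $\cos^2 + \sin^2 = 1$. The *offset* cosine variant defined below restricts the angle to $[\arccos(0.95), \arccos(0.02)]$, so the signal rate never reaches exactly 1 or 0.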

In [None]:
def cosine_diffusion_schedule(diffusion_times):
    """Plain cosine schedule mapping diffusion time t in [0, 1] to rates.

    Args:
        diffusion_times: tensor of normalised diffusion times.

    Returns:
        ``(signal_rates, noise_rates)`` = ``(cos(t*pi/2), sin(t*pi/2))``, each with
        the same shape as ``diffusion_times``; their squares sum to 1.
    """
    angles = diffusion_times * torch.pi / 2
    return torch.cos(angles), torch.sin(angles)

def offset_cosine_diffusion_schedule(diffusion_times, min_signal_rate=0.02, max_signal_rate=0.95):
    """Cosine schedule whose angle range is offset so rates never hit exactly 0 or 1.

    The angle is interpolated linearly from ``acos(max_signal_rate)`` at t = 0 to
    ``acos(min_signal_rate)`` at t = 1.

    Args:
        diffusion_times: tensor of normalised diffusion times (any shape, dtype
            and device).
        min_signal_rate: signal rate reached at t = 1.
        max_signal_rate: signal rate at t = 0.

    Returns:
        ``(signal_rates, noise_rates)`` with the shape of ``diffusion_times``;
        their squares sum to 1.
    """
    # Python-float angles broadcast against tensors on any device. The previous
    # 1-element CPU tensors raised a device-mismatch error for accelerator inputs.
    start_angle = math.acos(max_signal_rate)
    end_angle = math.acos(min_signal_rate)

    diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)
    signal_rates = torch.cos(diffusion_angles)
    noise_rates = torch.sin(diffusion_angles)
    return signal_rates, noise_rates


In [None]:
T = 1000
# Diffusion times t/T for t = 0 .. T-1, i.e. evenly spaced over [0, 1).
diffusion_times = torch.Tensor([step / T for step in range(T)])
s_rates_cos, n_rates_cos = cosine_diffusion_schedule(diffusion_times)
s_rates_off, n_rates_off = offset_cosine_diffusion_schedule(diffusion_times)

# Left panel compares signal rates, right panel compares noise rates.
fig, axs = plt.subplots(1, 2, figsize=(10, 3))
panels = (
    ("Signal rate", s_rates_cos, s_rates_off),
    ("Noise rate", n_rates_cos, n_rates_off),
)
for ax, (ylabel, cos_rates, off_rates) in zip(axs, panels):
    ax.plot(diffusion_times, cos_rates, label="Cosine")
    ax.plot(diffusion_times, off_rates, label="Offset cos")
    ax.set_ylabel(ylabel)
    ax.set_xlabel("Diffusion time")
    ax.legend()
plt.show()

In [None]:
img, _ = train_set[0]

n_samples = 10
# One fixed noise draw, so the panels differ only in the diffusion time.
noises = torch.randn_like(img)
# Show n_samples evenly spaced timesteps: 0, T/n_samples, 2*T/n_samples, ...
step = T // n_samples

fig, axs = plt.subplots(1, n_samples, figsize=(18, 2))
fig.suptitle("Cosine diffusion schedule")
for i, ax in enumerate(axs):
    t = i * step
    # Forward diffusion: x_t = signal_rate * x_0 + noise_rate * eps.
    noise_img = s_rates_cos[t] * img + n_rates_cos[t] * noises
    # Inputs are unnormalised [0, 1] images, so clip back to that display range.
    noise_img = torch.clip(noise_img, min=0, max=1)
    # Grey colormap with a fixed [0, 1] scale so brightness is comparable across panels.
    ax.imshow(noise_img.permute(1, 2, 0), cmap="gray", vmin=0, vmax=1)
    ax.set_title(f"t={t}", fontsize=8)
    ax.set_xticks([])
    ax.set_yticks([])
plt.show()

In [None]:
# Noise-prediction U-Net sized for out_res images. The keyword meanings are defined
# in models.DiffusionModel; presumably attn_resolutions selects the feature-map size(s)
# that get self-attention, ch_mult the per-level channel multipliers and channels
# the base width — confirm there.
model = DiffusionUNet(
    in_channels=in_channels,
    resolution=out_res[0],  # assumes square images (out_res[0] == out_res[1])
    attn_resolutions=[8],
    ch_mult=(1, 2, 2),
    channels=16,
    time_steps=T,
).to(device)

In [None]:
# Training objective: mean squared error between predicted and true noise.
loss_fn = nn.MSELoss()
optim = AdamW(model.parameters(), lr=learning_rate)

In [None]:
def epoch(imgs, schedule=offset_cosine_diffusion_schedule):
    """Run ONE optimisation step of the noise-prediction objective on a batch.

    Despite its name, this performs a single gradient step, not a full pass over
    the dataset. It uses the notebook-level ``model``, ``optim``, ``loss_fn``,
    ``T`` and ``device``.

    Args:
        imgs: batch of clean images; assumed (B, C, H, W) on the CPU, as yielded
            by ``data_loader`` — TODO confirm for other callers.
        schedule: callable mapping diffusion times in [0, 1) to
            ``(signal_rates, noise_rates)``.

    Returns:
        The batch MSE between predicted and true noise, as a Python float.
    """
    optim.zero_grad()

    B = imgs.shape[0]
    # One integer timestep per image. Use the actual batch size: the last batch
    # of an epoch can be smaller than ``batch_size``.
    times = torch.randint(0, T, (B,))
    # Normalise to [0, 1) and add singleton dims so the rates broadcast over (C, H, W).
    diff_times = (times / T)[:, None, None, None]

    noises = torch.randn(size=imgs.shape)
    s_rates, n_rates = schedule(diff_times)
    # Forward diffusion: x_t = signal_rate * x_0 + noise_rate * eps.
    noisy_imgs = s_rates * imgs + n_rates * noises

    pred_noises = model(noisy_imgs.to(device), times)

    # Compute the loss where the predictions live instead of copying them back to the CPU.
    loss = loss_fn(pred_noises, noises.to(pred_noises.device))
    loss.backward()
    optim.step()

    return loss.item()

In [None]:
# Fixed noise used to monitor sample quality at the same starting point over training.
gen_noise = torch.randn(size=(n_samples, in_channels, *out_res))
# Single source of truth for the noise schedule. epoch() defaults to the offset
# schedule, so pass this explicitly to keep training and sampling consistent.
schedule = cosine_diffusion_schedule
# drop_last=True keeps every batch exactly batch_size images.
data_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)

for e in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    # Iterate the loader once per epoch so each image is used once. The previous
    # next(iter(data_loader)) re-created and reshuffled the iterator every step and
    # only ever yielded the first batch of each fresh permutation.
    for imgs, _ in tqdm(data_loader, desc=f"epoch {e + 1}/{num_epochs}"):
        # epoch() ends with loss.item(), which already waits for the device.
        epoch_loss += epoch(imgs, schedule=schedule)
    epoch_loss = epoch_loss / len(data_loader)
    print(f"epoch {e+1:3.0f} \t  loss {epoch_loss:4.4f}")

    # Sample every 10 epochs and after the final epoch.
    if e % 10 == 0 or e == num_epochs - 1:
        model.eval()
        # Sampling needs no gradients; avoid building an autograd graph over all steps.
        with torch.no_grad():
            out_imgs, denoising = model.reverse_diffusion(gen_noise.to(device), diffusion_steps=100, schedule=schedule)
        # matplotlib needs CPU data; harmless if reverse_diffusion already returned CPU tensors.
        out_imgs = out_imgs.detach().cpu()

        fig, axs = plt.subplots(1, n_samples, figsize=(18, 2))
        fig.suptitle(f"Samples after epoch {e + 1}")
        for k, ax in enumerate(axs):
            ax.imshow(out_imgs[k].permute(1, 2, 0), cmap="gray")
            ax.set_xticks([])
            ax.set_yticks([])
        plt.show()
