In [3]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import imageio
import time
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore
from tqdm import tqdm

from model import *
from data import *

# Initialize the model
t_range = 100  # Number of steps
image_size = (1, 3, 32, 32)  # Example input image size (batch, channels, height, width)
img_depth = 3  # Number of channels in the image
dataset_choice = "Cifar-10"
batch_size=256
device = 'cuda'
# Example image (normalized between 0 and 1)

train_loader = get_dataloader(dataset_name=dataset_choice, batch_size=batch_size)
validation_loader = get_dataloader(dataset_name=dataset_choice, batch_size=batch_size, split='validation')

model = DiffusionModel(in_size=32 * 32, t_range=t_range, img_depth=img_depth, device=device).to(device)
optimizer = optim.Adam(model.parameters(), lr=2e-4)
num_epochs = 5  # Number of training epochs

Files already downloaded and verified
Files already downloaded and verified


In [8]:
print(sum([p.numel() for p in model.parameters()]))

12920323


In [6]:
model.load_state_dict(torch.load(f'../model_{t_range}_float16_precision.pth'))

gif_shape = [8, 8]
sample_batch_size = gif_shape[0] * gif_shape[1]
n_hold_final = 10

# Generate samples from denoising process
gen_samples = []
x = torch.randn((sample_batch_size, img_depth, 32, 32)).to(device)
sample_steps = torch.arange(model.t_range-1, 0, -1).to(device)
for t in sample_steps:
    with torch.autocast(device_type=device, dtype=torch.float16):
        x = model.denoise_sample(x, t)
    if t % 50 == 0:
        gen_samples.append(x)
for _ in range(n_hold_final):
    gen_samples.append(x)
gen_samples = torch.stack(gen_samples, dim=0).moveaxis(2, 4).squeeze(-1)
gen_samples = (gen_samples.clamp(-1, 1) + 1) / 2

In [7]:
# Process samples and save as gif
gen_samples = gen_samples.cpu()
gen_samples = (gen_samples * 255).type(torch.uint8)
gen_samples = gen_samples.reshape(-1, gif_shape[0], gif_shape[1], 32, 32, img_depth)

def stack_samples(gen_samples, stack_dim):
    gen_samples = list(torch.split(gen_samples, 1, dim=1))
    for i in range(len(gen_samples)):
        gen_samples[i] = gen_samples[i].squeeze(1)
    return torch.cat(gen_samples, dim=stack_dim)

gen_samples = stack_samples(gen_samples, 2)
gen_samples = stack_samples(gen_samples, 2)

imageio.mimsave(
    f"pred_{t_range}_float16.gif",
    list(gen_samples),
    fps=5,
)

In [None]:
# Process samples and save as gif
gen_samples = gen_samples.cpu()
gen_samples = (gen_samples * 255).type(torch.uint8)
gen_samples = gen_samples.reshape(-1, gif_shape[0], gif_shape[1], 32, 32, img_depth)

def stack_samples(gen_samples, stack_dim):
    gen_samples = list(torch.split(gen_samples, 1, dim=1))
    for i in range(len(gen_samples)):
        gen_samples[i] = gen_samples[i].squeeze(1)
    return torch.cat(gen_samples, dim=stack_dim)

gen_samples = stack_samples(gen_samples, 2)
gen_samples = stack_samples(gen_samples, 2)

imageio.mimsave(
    f"pred_{t_range}_float16.gif",
    list(gen_samples),
    fps=5,
)

In [2]:
fid = FrechetInceptionDistance(feature=2048).to(device)  # or feature=64 based on the choice
inception = InceptionScore().to(device)
model.load_state_dict(torch.load(f'../model_{t_range}_bfloat16_precision.pth'))

# Update FID metric with all real training images
iteration = 1
for batch, _ in tqdm(validation_loader):  # Loop through the entire training dataset
    real_images = inverse_transform(batch).byte().to(device)
    fid.update(real_images, real=True)
    if iteration == 20:
        break
    iteration +=1 

# Generate the same number of images as the total training dataset
generated_images = []
start = time.time()
with torch.no_grad():
    for i in tqdm(range(20)):
        noise = torch.randn(batch_size, 3, 32, 32).to(device)  # Start with noise
        gen_images = noise
        sample_steps = torch.arange(model.t_range-1, 0, -1).to(device)
        for t in sample_steps:
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                gen_images = model.denoise_sample(gen_images, torch.tensor([t]).to(device))
        gen_images = inverse_transform(gen_images).byte()
        fid.update(gen_images, real=False)
        inception.update(gen_images)
end = time.time()
print(f"Time elapsed {end - start}")

# Compute the FID score
fid_score = fid.compute()
print(f"FID Score: {fid_score}")
inception_score = inception.compute()
print(f"{inception_score=}")

 48%|████▊     | 19/40 [00:24<00:27,  1.32s/it]
100%|██████████| 20/20 [52:06<00:00, 156.32s/it]


Time elapsed 3126.3549897670746
FID Score: 25.88897705078125
inception_score=(tensor(7.5745, device='cuda:0'), tensor(0.3724, device='cuda:0'))


In [None]:
sum([p.numel() for p in model.parameters()])

In [None]:
print(len(generated_images))
generated_images

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

#model.load_state_dict(torch.load('model.pth'))

img, _ = next(iter(train_loader))
img = img[0].squeeze()
img = inverse_transform(img).byte()
print(img.dtype)
print(img.shape)

with torch.no_grad():
    noise = torch.randn(1, 3, 32, 32).to(device)  # Start with noise
    gen_images = noise
    sample_steps = torch.arange(t_range-1, 0, -1).to(device)
    for t in sample_steps:
        gen_images = model.denoise_sample(gen_images, torch.tensor([t]).to(device))
    img = gen_images.squeeze(0).cpu()
    img = inverse_transform(img).byte()
    plt.imshow(img.permute(1,2,0))
