In [3]:
# !pip install piq

In [10]:
import torch
import yaml
import random
import numpy as np
from piq import FID
from unet import UNet
from ddpm import DDPM
import matplotlib.pyplot as plt


In [7]:
#read yaml file
with open('config.yaml') as file:
  config = yaml.safe_load(file)

In [11]:
# Setting reproducibility
SEED = config['seed']
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fb92e2f8130>

In [13]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [14]:
model = UNet(in_ch=config['in_ch'], 
             out_ch=config['out_ch'], 
             time_dim=config['time_dim'],
             device=device)
model.to(device)

UNet(
  (conv): ResidualConv(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (norm1): GroupNorm(1, 32, eps=1e-05, affine=True)
    (activation1): GELU(approximate='none')
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (activation2): GroupNorm(1, 64, eps=1e-05, affine=True)
  )
  (down1): Down(
    (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (resconv): ResidualConv(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (norm1): GroupNorm(1, 64, eps=1e-05, affine=True)
      (activation1): GELU(approximate='none')
      (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (activation2): GroupNorm(1, 128, eps=1e-05, affine=True)
    )
    (time_emb): Sequential(
      (0): SiLU()
      (1): Linear(in_features=256, out_features=128, bias=True)
    )
  )
  (attn1): Se

In [15]:
PATH = '/home/woody/iwso/iwso089h/diffusion/checkpoints/unet.pth'
model.load_state_dict(torch.load(PATH))

<All keys matched successfully>

In [16]:
ddpm = DDPM(beta_start=config['beta_start'], 
            beta_end=config['beta_end'], 
            steps=config['steps'], 
            color_channels=config['color_channels'], 
            image_size=config['image_size'], 
            device=device)

In [60]:
batch_size = 256
sampled_images = torch.zeros((4096*5, 3, 16, 16)).to(device)

for i in range(16*5):
    sampled_images[i*batch_size:(i+1)*batch_size] = ddpm.backward_diffusion(model, batch_size)
    print(f'Generated Batch No: {i+1}')

Generated Batch No: 1
Generated Batch No: 2
Generated Batch No: 3
Generated Batch No: 4
Generated Batch No: 5
Generated Batch No: 6
Generated Batch No: 7
Generated Batch No: 8
Generated Batch No: 9
Generated Batch No: 10
Generated Batch No: 11
Generated Batch No: 12
Generated Batch No: 13
Generated Batch No: 14
Generated Batch No: 15
Generated Batch No: 16
Generated Batch No: 17
Generated Batch No: 18
Generated Batch No: 19
Generated Batch No: 20
Generated Batch No: 21
Generated Batch No: 22
Generated Batch No: 23
Generated Batch No: 24
Generated Batch No: 25
Generated Batch No: 26
Generated Batch No: 27
Generated Batch No: 28
Generated Batch No: 29
Generated Batch No: 30
Generated Batch No: 31
Generated Batch No: 32
Generated Batch No: 33
Generated Batch No: 34
Generated Batch No: 35
Generated Batch No: 36
Generated Batch No: 37
Generated Batch No: 38
Generated Batch No: 39
Generated Batch No: 40
Generated Batch No: 41
Generated Batch No: 42
Generated Batch No: 43
Generated Batch No: 

In [61]:
sprites = torch.tensor(np.load('sprites_1788_16x16.npy')).to(device)
sprites = sprites.permute(0, 3, 2, 1)/255.
sprites.shape

torch.Size([89400, 3, 16, 16])

In [62]:
sprites.max()

tensor(1., device='cuda:0')

In [63]:
random_indices = torch.randint(low=0, high=sprites.shape[0], size=(sampled_images.shape[0],))
random_indices.shape

torch.Size([20480])

In [64]:
sampled_sprites = sprites[random_indices]
sampled_sprites.shape

torch.Size([20480, 3, 16, 16])

In [66]:
fid = FID()
fid(sampled_images.reshape(4096*5, 16*16*3), sampled_sprites.reshape(4096*5, 16*16*3))

tensor(38.6496, device='cuda:0', dtype=torch.float64)