In [None]:
!pip3 install kaggle

In [None]:
# Fetch the pixel-art sprite dataset archive (pixel-art.zip) from Kaggle.
# NOTE(review): the Kaggle CLI needs API credentials configured beforehand — not handled in this notebook.
!kaggle datasets download -d ebrahimelgazar/pixel-art
# Pull out only the two arrays the notebook uses: `unzip -p` streams a member to stdout,
# which is redirected into a local file of the same name.
!unzip -p pixel-art.zip sprites.npy > sprites.npy
!unzip -p pixel-art.zip sprites_labels.npy > sprites_labels.npy

In [87]:
import numpy as np
import torch
from torch import nn, optim, Tensor
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader

from PIL import Image

import matplotlib.pyplot as plt

from tqdm.auto import tqdm

In [2]:
# Collapse each label row to its argmax column index, giving one int64 class id per sprite.
# NOTE(review): rows are presumably one-hot class vectors — confirm against the dataset docs.
labels = torch.as_tensor(np.load('sprites_labels.npy').argmax(axis=1), dtype=torch.long)
labels.shape

torch.Size([89400])

In [3]:
# Load the sprites (stored as (N, H, W, C), per the resulting shape below), scale to [0, 1],
# swap axes 1 and -1 so channels come first (layout (N, C, W, H)), then shift to [-1, 1].
# Everything is one expression so no intermediate full-size array stays referenced.
image = torch.from_numpy(
    (np.load('sprites.npy').astype(np.float32) / 255).swapaxes(-1, 1).copy()
) * 2 - 1
image.shape

torch.Size([89400, 3, 16, 16])

In [115]:
import math, random
from m_model import Model
from m_diffusion import Diffusion, cosine_schedule, linear_shedule

In [119]:
def genereate_image(model:nn.Module, diffusion:Diffusion, device='cpu', n:int=4, steps:int=1000):
  """Sample ``n`` sprites with DDPM ancestral sampling and show them in a grid.

  Starts from Gaussian noise of shape (n, 3, 16, 16) and applies the reverse
  update of Ho et al. (2020), Algorithm 2, for timesteps ``steps - 1`` down to
  1. Each step uses ``diffusion.alpha``, ``diffusion.alpha_hat`` and
  ``diffusion.beta`` indexed by timestep. The model is called without labels,
  i.e. as ``model(x, t)``.

  The model is put in eval mode while sampling. Its previous train/eval mode
  is restored afterwards, even if sampling raises.

  Args:
    model: noise-prediction network, called as ``model(x, t)``.
    diffusion: holds per-timestep ``alpha``, ``alpha_hat`` and ``beta`` tensors.
      Their indexed values are combined arithmetically with ``x``, so they must
      live on ``device``.
    device: device on which the noise and timesteps are created. Presumably it
      must also be where ``model`` keeps its parameters — TODO confirm.
    n: number of images to sample (previously fixed at 4).
    steps: number of diffusion timesteps (previously fixed at 1000, the value
      of ``steps`` used to build ``diffusion`` in this notebook).

  Raises:
    ValueError: if ``n`` is less than 1.
  """
  import math

  if n < 1:
    raise ValueError(f'n must be >= 1, got {n}')

  was_training = model.training
  model.eval()
  try:
    with torch.inference_mode():
      x = torch.randn((n, 3, 16, 16), device=device)

      # The loop stops at i == 1; timestep 0 is never fed to the model.
      for i in tqdm(reversed(range(1, steps)), total=steps - 1, desc='Sampling', position=0, colour='yellow'):
        # Build t on the same device as x so model(x, t) and the schedule lookups never mix devices.
        t = torch.full((n,), i, dtype=torch.long, device=device)
        predicted:Tensor = model(x, t)
        alpha = diffusion.alpha[t][:, None, None, None]
        alpha_hat = diffusion.alpha_hat[t][:, None, None, None]
        beta = diffusion.beta[t][:, None, None, None]
        # No fresh noise on the last update, so the final step is deterministic.
        noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)
        x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / torch.sqrt(1 - alpha_hat)) * predicted) + torch.sqrt(beta) * noise
  finally:
    model.train(was_training)

  # Map from [-1, 1] back to 8-bit pixel values.
  images = ((x.clamp(-1, 1) + 1) / 2 * 255).type(torch.uint8).cpu()

  # Near-square grid; n=4 gives the original 2x2 layout.
  ncols = math.ceil(math.sqrt(n))
  nrows = math.ceil(n / ncols)
  _, axes = plt.subplots(nrows, ncols, squeeze=False)
  for ax in axes.flat:
    ax.axis('off')
  for ax, img in zip(axes.flat, images):
    # The loading cell stored sprites via swapaxes(-1, 1), i.e. as (C, W, H);
    # swapping axes 0 and -1 undoes that and yields (H, W, C) for imshow.
    ax.imshow(img.swapaxes(0, -1).numpy())
  plt.show()

In [120]:
# Model / diffusion hyper-parameters (meanings inferred from names and usage below).
c_in, c_out = 3, 3        # channels in (RGB sprites) and out (predicted noise, same shape as the input)
embedding = 256           # passed to Model as its embedding size — presumably the timestep/label embedding width
steps = 1000              # number of diffusion timesteps
categories = 5            # number of label classes — presumably matches the columns of sprites_labels.npy
w_size, h_size = 16, 16   # sprite width / height in pixels

In [116]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Diffusion process over `steps` timesteps with the cosine noise schedule.
diffusion = Diffusion(steps, cosine_schedule(steps), device=DEVICE)

In [117]:
# Build the noise-prediction network. Arguments are positional, in the constructor's order
# (meanings inferred from the variable names): channels in/out, sprite width/height,
# timestep count, number of label classes, embedding size.
model = Model(
  c_in, c_out,
  w_size, h_size,
  steps,
  categories,
  embedding,
  device=DEVICE,
)

In [118]:
# Objective: mean-squared error between predicted and true noise.
loss_fn = nn.MSELoss(reduction='mean')

# Shuffled mini-batches of (image, label) pairs; any incomplete last batch is dropped.
dataset = TensorDataset(image, labels)
dl = DataLoader(dataset=dataset, batch_size=8, shuffle=True, drop_last=True)

# AdamW over all model parameters with a very small decoupled weight decay.
optimizer = optim.AdamW(params=model.parameters(), lr=1e-4, weight_decay=1e-8)

In [110]:
# Train the noise-prediction model.
# For each batch: noise the clean sprites at random timesteps, predict the injected noise,
# and minimise the MSE. A checkpoint is saved after every epoch.
epochs = 10
model.train()
for epoch in range(epochs):
  loss_total = 0.0
  # drop_last=True can discard samples, so count the ones actually trained on
  # instead of dividing by len(dl.dataset).
  n_seen = 0

  pbar = tqdm(total=len(dl), desc=f'Train ({epoch+1}/{epochs})', position=0, colour='yellow')
  for X, Y in dl:
    X = X.to(DEVICE)
    Y = Y.to(DEVICE)

    # Sample timesteps on the batch's device so model/diffusion inputs never mix devices.
    t = torch.randint(low=0, high=steps, size=(X.shape[0],), device=DEVICE)
    x_t, noise = diffusion.noise_image(X, t)
    # Classifier-free guidance training: drop the label (pass None) ~10% of the time, so one
    # model learns both the conditional and the unconditional prediction.
    # (The previous condition did the opposite and used the label only 10% of the time.)
    predicted = model(x_t, t, None if random.random() < 0.1 else Y)

    # MSELoss(input, target): the prediction is the input, the injected noise is the target.
    loss:Tensor = loss_fn(predicted, noise)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    batch_loss = loss.item()
    loss_total += batch_loss * X.shape[0]
    n_seen += X.shape[0]

    pbar.set_postfix(MSE=batch_loss)
    pbar.update(1)

  # Sample-weighted mean of the per-batch losses for this epoch.
  loss_total /= max(n_seen, 1)

  pbar.set_postfix(MSE=loss_total)
  pbar.close()

  torch.save(model.state_dict(), f'model-{epoch+1}.pt')

Train (1/10):   0%|[33m          [0m| 0/11175 [00:00<?, ?it/s]

Train (1/10):  18%|[33m█▊        [0m| 1971/11175 [20:07<1:38:14,  1.56it/s, MSE=0.177] 

KeyboardInterrupt: 