In [1]:
from typing import *
from typing_extensions import *

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm, trange

import hcrot
from hcrot import layers

In [None]:
# Configuration
batch_size = 512
num_epochs = 5
timesteps = 1000
lr = 1e-3
seed = 42

# Seed NumPy's global RNG so the np.random.randn / np.random.randint draws used
# throughout this notebook are repeatable after a kernel restart.
np.random.seed(seed)

# Noise schedule: betas increase linearly from 1e-4 to 0.02 over `timesteps` steps;
# alphas_cumprod[t] is the running product of (1 - beta) up to and including step t.
betas = np.linspace(1e-4, 0.02, timesteps)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)

# Forward process q(x_t | x_0)
def q_sample(x_start, t):
    """Noise `x_start` to timestep(s) `t` using the module-level `alphas_cumprod` schedule.

    Args:
        x_start: Clean input array; Gaussian noise of the same shape is drawn for it.
        t: Integer index (or array of indices) into `alphas_cumprod`.

    Returns:
        Tuple `(noisy_img, noise)` where
        `noisy_img = sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise`.
        The per-timestep scales are reshaped to (-1, 1, 1, 1) before broadcasting.
    """
    eps = np.random.randn(*x_start.shape)
    alpha_bar_t = alphas_cumprod[t]
    signal_scale = np.sqrt(alpha_bar_t).reshape(-1, 1, 1, 1)
    noise_scale = np.sqrt(1 - alpha_bar_t).reshape(-1, 1, 1, 1)
    return signal_scale * x_start + noise_scale * eps, eps

# reverse diffusion
def reverse_diffusion(model, x, shape=(1, 1, 28, 28), record_steps=[99, 75, 50, 25, 0], labels=None):
    """Run the reverse (denoising) process over every timestep, starting from `x`.

    Iterates t from the module-level `timesteps - 1` down to 0, using the
    module-level `betas`, `alphas` and `alphas_cumprod` schedule.

    Args:
        model: Callable noise predictor. It is called as `model(x, t_tensor)` when
            `labels` is None and as `model(x, t_tensor, labels)` otherwise.
        x: Starting array to denoise.
        shape: Only `shape[0]` is read; it sets the length of the timestep vector
            passed to the model.
        record_steps: Timesteps whose denoised result is appended to the output.
            This default is never mutated here.
        labels: Optional conditioning labels. The `Model` class in this notebook
            requires them in `forward`, so pass them when using that model.

    Returns:
        List of copies of `x` taken after the update at each recorded timestep,
        ordered from the largest recorded t to the smallest.
    """
    images = []
    # Extra positional arguments for the model; empty when running unconditioned.
    cond_args = () if labels is None else (labels,)
    for t in reversed(range(timesteps)):
        t_tensor = np.full((shape[0],), t, dtype=np.int16)
        noise_pred = model(x, t_tensor, *cond_args)
        alpha = alphas[t]
        alpha_bar = alphas_cumprod[t]
        beta = betas[t]

        # No fresh noise is injected on the final step (t == 0).
        noise = np.random.randn(*x.shape) if t > 0 else np.zeros_like(x)
        x = (1 / np.sqrt(alpha)) * (x - (1 - alpha) / np.sqrt(1 - alpha_bar) * noise_pred) + np.sqrt(beta) * noise

        if t in record_steps:
            images.append(x.copy())
    return images

def average_pooling(img, pool_size=2):
    """Downsample a (B, C, H, W) array by averaging non-overlapping square windows.

    Rows/columns that do not fill a complete `pool_size` window at the bottom/right
    edge are cropped away before pooling.

    Args:
        img: 4-D array laid out as (batch, channels, height, width).
        pool_size: Side length of each pooling window.

    Returns:
        Array of shape (B, C, H // pool_size, W // pool_size).
    """
    batch, channels, height, width = img.shape
    out_h = height // pool_size
    out_w = width // pool_size
    cropped = img[:, :, :out_h * pool_size, :out_w * pool_size]
    # Split each spatial axis into (window index, offset inside window) and
    # average over the two offset axes.
    windows = cropped.reshape(batch, channels, out_h, pool_size, out_w, pool_size)
    return windows.mean(axis=(3, 5))

# Load the MNIST data
df = pd.read_csv('./datasets/mnist_test.csv')
# NOTE(review): the label column is addressed by the name '7', which suggests the CSV
# has no header row and its first data row is being consumed as the header — confirm.
label = df['7'].to_numpy()
df = df.drop(columns='7')
dat = df.to_numpy()

# Keep the first 10 batches' worth of samples and their labels.
train_label = label[:batch_size * 10]
mnist = dat[:batch_size * 10].reshape(-1, 1, 28, 28).astype(np.float32)
# Scale pixels from [0, 255] to [-1, 1], then halve the resolution (28x28 -> 14x14).
mnist = average_pooling((mnist / 255.) * 2. - 1., 2)
dataloader = hcrot.dataset.Dataloader(mnist, train_label, batch_size=batch_size, shuffle=True)

# Model definition
class Model(layers.Module):
    """Label-conditioned noise predictor that wraps a single `layers.UNetModel`."""

    def __init__(self):
        super().__init__()
        # sample_size=14 matches the 14x14 images produced by the 2x average pooling
        # applied when the data is loaded; num_class_embeds=10 presumably corresponds
        # to the ten digit classes.
        self.unet = layers.UNetModel(
            sample_size=14,
            in_channels=1,
            out_channels=1,
            block_out_channels=(32,64,32),
            num_class_embeds=10,
        )

    def __call__(self, *args, **kwargs):
        """Forward every positional and keyword argument to `forward`."""
        # Keyword arguments must be expanded with `**`; expanding the dict with a
        # single `*` would pass its keys as extra positional arguments.
        return self.forward(*args, **kwargs)

    def forward(self, x_noisy, t, labels):
        """Return the U-Net's output for the noisy input, timesteps and class labels."""
        noise_pred = self.unet(x_noisy, t, labels)
        return noise_pred

# Model, Optimizer, Loss
model = Model()
optimizer = hcrot.optim.AdamW(model, lr_rate=lr)
criterion = layers.MSELoss()

# Training loop
# The commented-out lines below are a disabled forward-diffusion preview. They reference
# `steps` and `show_images`, which are only defined in a later cell of this notebook.
# x_vis, _ = next(iter(hcrot.dataset.Dataloader(mnist, train_label, batch_size=1, shuffle=True)))
# _, label = next(iter(dataloader))
# print('label:', label)

# noisy_imgs = [q_sample(x_vis[0], np.array([t])) for t in steps]
# show_images(f"Forward Diffusion @ Epoch {num_epochs}", noisy_imgs, steps)

pbar = trange(num_epochs)
for epoch in pbar:
    # Sum of per-batch losses for this epoch.
    total_loss = 0
    # Pre-set so the average below is still defined if the dataloader yields no batches.
    i = 0
    # Note: the loop variable `label` overwrites the module-level `label` array
    # created when the CSV was loaded.
    for i, (x, label) in enumerate(dataloader):
        # Draw one random diffusion timestep per sample in the batch.
        t = np.random.randint(0, timesteps, (x.shape[0],))
        # Noise the clean batch to timestep t; `noise` is the regression target.
        x_noisy, noise = q_sample(x, t)
        noise_pred = model(x_noisy, t, label)
        loss = criterion(noise_pred, noise)
        total_loss += loss.item()
        # presumably backpropagates from the loss computed by the preceding
        # criterion(...) call and returns the gradient for the optimizer — confirm against hcrot
        dz = criterion.backward()
        optimizer.update(dz)
    # Show the mean per-batch loss of this epoch on the progress bar.
    pbar.set_postfix(loss=total_loss/(i+1))

    # Disabled per-epoch reverse-diffusion preview (depends on the disabled
    # `noisy_imgs` above and on `steps` / `show_images` from a later cell).
    # if (epoch+1) % 1 == 0:
    #     print(f"\n🖼 Visualizing at epoch {epoch+1}...")
    #     restored_imgs = reverse_diffusion(model, noisy_imgs[-1], shape=(1, 1, 28, 28), record_steps=steps)
    #     show_images(f"Reverse Diffusion @ Epoch {epoch+1}", restored_imgs, list(reversed(steps)))
    # break

100%|██████████| 5/5 [01:11<00:00, 14.37s/it, loss=0.129]


In [3]:
# Alternative output location, currently disabled:
# hcrot.utils.save(model.parameters, 'datasets/artifact.pkl')
# Save the model's parameters to 'artifact.pkl' (path is relative to the working directory).
# NOTE(review): `model.parameters` is passed without being called — presumably it is an
# attribute/property holding the weights rather than a method; confirm against hcrot.
hcrot.utils.save(model.parameters, 'artifact.pkl')

In [5]:
def sample(model, label, shape=(1, 1, 14, 14), record_steps=[99, 75, 50, 25, 0], timesteps=50):
    """Generate samples by running the reverse diffusion process from Gaussian noise.

    Starts from `np.random.randn(*shape)` and iterates t from `timesteps - 1` down
    to 0 using the module-level `betas`, `alphas` and `alphas_cumprod` schedule.

    Args:
        model: Callable invoked as `model(x, t_tensor, label)`; expected to return
            the predicted noise for `x`.
        label: Conditioning label passed to the model unchanged.
        shape: Shape of the generated array; `shape[0]` also sets the length of
            the timestep vector given to the model.
        record_steps: Timesteps whose result is recorded. Entries >= `timesteps`
            are never reached, so they record nothing. This default is never
            mutated here.
        timesteps: Number of reverse steps to run. This parameter shadows the
            module-level `timesteps` used to build the schedule.
            NOTE(review): when this is smaller than the schedule length, sampling
            starts from pure noise at an intermediate point of the schedule —
            confirm this shortcut is intended.

    Returns:
        List of copies of the sample after the update at each recorded timestep,
        ordered from the largest recorded t to the smallest.
    """
    images = []
    x = np.random.randn(*shape)
    for t in reversed(trange(timesteps)):
        # One timestep entry per batch element, so shape[0] > 1 is handled the same
        # way reverse_diffusion handles it.
        t_tensor = np.full((shape[0],), t, dtype=np.int16)
        noise_pred = model(x, t_tensor, label)
        alpha = alphas[t]
        alpha_bar = alphas_cumprod[t]
        beta = betas[t]

        # No fresh noise is injected on the final step (t == 0).
        noise = np.random.randn(*x.shape) if t > 0 else np.zeros_like(x)
        x = (1 / np.sqrt(alpha)) * (x - (1 - alpha) / np.sqrt(1 - alpha_bar) * noise_pred) + np.sqrt(beta) * noise
        if t in record_steps:
            images.append(x.copy())

    return images

def show_images(title, images, steps):
    """Display a row of grayscale images, each titled with its entry from `steps`.

    Args:
        title: Figure-level title; nothing is drawn for an empty string.
        images: Sequence of arrays with values nominally in [-1, 1]. Each one is
            mapped to [0, 255] and squeezed before display (assumed to be 2-D
            after squeezing — confirm for batched inputs).
        steps: Per-panel titles, matched to `images` by position. Its length also
            sets the figure width.
    """
    if len(images) == 0:
        # plt.subplots cannot create a grid with zero columns.
        return

    # squeeze=False always returns a 2-D array of Axes, so a single image works too.
    fig, axs = plt.subplots(1, len(images), figsize=(len(steps), 3), squeeze=False)
    for i, image in enumerate(images):
        ax = axs[0, i]
        image = np.clip((image / 2 + 0.5), 0, 1)
        image = np.squeeze(image)
        # uint8 holds the full 0..255 range; int8 would wrap values above 127 to negatives.
        image = np.round(image * 255.).astype(np.uint8)
        # Fixed intensity range so every panel shares the same gray scale.
        ax.imshow(image, cmap='gray', vmin=0, vmax=255)
        ax.axis("off")
        if i < len(steps):
            ax.set_title(f"{steps[i]}")
    if title:
        fig.suptitle(title)
    plt.tight_layout()
    plt.show()

# Timesteps at which intermediate samples are recorded; every entry must be smaller
# than the `timesteps` value passed to sample() or it is never reached.
# To visualise every class instead of only label 0, wrap the two calls below in
# `for label in range(10):` and pass `label` to sample().
steps = [0, 25, 50, 75, 100, 125, 150, 175, 200, 225, 249]
restored_imgs = sample(model, 0, record_steps=steps, timesteps=250)
# sample() walks t from high to low, so the images come back in descending-t order;
# pass the step labels in that same order so each panel gets the right title.
show_images("", restored_imgs, sorted(steps, reverse=True))

  0%|          | 0/400 [00:00<?, ?it/s]

 49%|████▉     | 195/400 [00:13<00:13, 14.95it/s]


KeyboardInterrupt: 