In [1]:
%load_ext autoreload
%autoreload 2

import sys
import os 

os.chdir("../")

from einops import rearrange
import einops 
import numpy as np
import torch

In [23]:
q = np.random.uniform(size=[32, 8, 128, 10])
k = np.random.uniform(size=[32, 8, 128, 3])

In [27]:
sim2 = np.einsum("bhdi, bhdj -> bhdji", q, k)

In [72]:
from datasets import load_dataset

dataset = load_dataset("cifar10")
IMAGE_SIZE = 28
CHANNELS = 3
BATCH_SIZE = 64
TIMESTEPS = 300

UPSCALE_IMAGE_SIZE = 32

device = "cuda" if torch.cuda.is_available() else "cpu"

Found cached dataset cifar10 (/home/johannes/.cache/huggingface/datasets/cifar10/plain_text/1.0.0/447d6ec4733dddd1ce3bb577c7166b986eaa4c538dcd9e805ba61f35674a9de4)


  0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
from diffusion import Diffusion

diffusion = Diffusion(
    timesteps=TIMESTEPS,
)

In [4]:
from unet import UnetConditional

unet_conditional = UnetConditional(
    channels=[32, 64, 128, 256],
    in_channels=CHANNELS,
    resnet_block_groups=8,
    use_convnext=False,
    convnext_mult=2,
    init_channel_mult=32,
    nclasses=10,
)
unet_conditional

UnetConditional(
  (_init_conv): Conv2d(3, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
  (_time_mlp): Sequential(
    (0): SinusoidalPositionEmbeddings()
    (1): Linear(in_features=96, out_features=384, bias=True)
    (2): GELU(approximate=none)
    (3): Linear(in_features=384, out_features=384, bias=True)
  )
  (_down_sampling_layers): ModuleList(
    (0): ModuleList(
      (0): ResnetBlock(
        (_time_embedder): Sequential(
          (0): SiLU()
          (1): Linear(in_features=384, out_features=32, bias=True)
        )
        (_block1): ConvBlock(
          (_conv): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (_gnorm): GroupNorm(8, 32, eps=1e-05, affine=True)
          (_actfunc): SiLU()
        )
        (_block2): ConvBlock(
          (_conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (_gnorm): GroupNorm(8, 32, eps=1e-05, affine=True)
          (_actfunc): SiLU()
        )
        (_res_conv): Conv

In [17]:
from unet import Unet

unet = Unet(
    channels=[32, 64, 128, 256],
    in_channels=1,
    resnet_block_groups=8,
    use_convnext=False,
    convnext_mult=2,
    init_channel_mult=32,
)



functools.partial(<class 'utils.ResnetBlock'>, norm_groups=8)


In [5]:
from torchvision import transforms
from torch.utils.data import DataLoader

transform = transforms.Compose(
    [   
        transforms.Resize(UPSCALE_IMAGE_SIZE),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: (t * 2) - 1)
    ]
)

def normalize(examples):
    
    examples["x"] = [transform(image) for image in examples["img"]]
    del examples["img"]

    return examples

transformed_dataset = dataset.with_transform(normalize)

# create dataloader
dataloader = DataLoader(transformed_dataset["train"], batch_size=BATCH_SIZE, shuffle=True)

In [73]:
from typing import List

from torchvision.transforms import Compose, ToTensor, Lambda

reverse_transform = transforms.Compose([
     # transforms.Lambda(lambda t: t[0, :, :, :]),
     transforms.Lambda(lambda t: (t + 1) / 2),
     transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
     transforms.Lambda(lambda t: t * 255.),
     transforms.Lambda(lambda t: t.to("cpu"),),
     transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
     # transforms.ToPILImage(),
])

reverse_transform2 = transforms.Compose([
     # transforms.Lambda(lambda t: t[0, :, :, :]),
     transforms.Resize(128),
     transforms.Lambda(lambda t: (t + 1) / 2),
     # transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
     transforms.Lambda(lambda t: t * 255.),
     transforms.Lambda(lambda t: t.to("cpu"),),
     # transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
     # transforms.ToPILImage(),
])


def _generate_image(
    model: torch.nn.Module,
    diffusion: Diffusion,
    timesteps: int,
    shape: List[int],
    device: str,
    klass: int,
):
    
    klass = torch.Tensor([klass]).to(device)
    
    with torch.no_grad():

        x = torch.randn(shape, device=device)
        timesteps_iter = list(range(timesteps))
        timesteps_iter.reverse()
        
        Y = [x]
        
        for t in timesteps_iter:

            t = torch.Tensor([t]).long().to(device)

            predicted_noise = model.forward(x, t, klass)

            x = diffusion.backward(x, predicted_noise, t)

            Y.append(x)
            
        return torch.cat(Y, dim=0)

In [98]:
y = _generate_image(
    model=unet,
    diffusion=diffusion,
    timesteps=TIMESTEPS,
    shape=[1, 1, UPSCALE_IMAGE_SIZE, UPSCALE_IMAGE_SIZE],
    device=device,
)

[299, 298, 297, 296, 295, 294, 293, 292, 291, 290, 289, 288, 287, 286, 285, 284, 283, 282, 281, 280, 279, 278, 277, 276, 275, 274, 273, 272, 271, 270, 269, 268, 267, 266, 265, 264, 263, 262, 261, 260, 259, 258, 257, 256, 255, 254, 253, 252, 251, 250, 249, 248, 247, 246, 245, 244, 243, 242, 241, 240, 239, 238, 237, 236, 235, 234, 233, 232, 231, 230, 229, 228, 227, 226, 225, 224, 223, 222, 221, 220, 219, 218, 217, 216, 215, 214, 213, 212, 211, 210, 209, 208, 207, 206, 205, 204, 203, 202, 201, 200, 199, 198, 197, 196, 195, 194, 193, 192, 191, 190, 189, 188, 187, 186, 185, 184, 183, 182, 181, 180, 179, 178, 177, 176, 175, 174, 173, 172, 171, 170, 169, 168, 167, 166, 165, 164, 163, 162, 161, 160, 159, 158, 157, 156, 155, 154, 153, 152, 151, 150, 149, 148, 147, 146, 145, 144, 143, 142, 141, 140, 139, 138, 137, 136, 135, 134, 133, 132, 131, 130, 129, 128, 127, 126, 125, 124, 123, 122, 121, 120, 119, 118, 117, 116, 115, 114, 113, 112, 111, 110, 109, 108, 107, 106, 105, 104, 103, 102, 101, 100,

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0]])

In [52]:
image = reverse_transform(y[-1, :, :, :])

image.save("tmp.png")

In [77]:
x_noised, noise, time = diffusion.forward(batch["x"])

y = unet_conditional.forward(
    x=x_noised.to(device),
    time=time.to(device),
    classes=batch["label"].to(device),
)

In [53]:
def plot_process(
    model: torch.nn.Module,
    x: torch.Tensor,
    label: torch.Tensor,
    diffusion: Diffusion,
    device: str,
):
    
    model.to(device)

    assert len(x.shape) == 3
    
    X, Y = [], []
    
    with torch.no_grad():

        for timestep in range(diffusion._timesteps):

            t = torch.Tensor([timestep]).long()

            x_noised, noise, _ = diffusion.forward(
                x=x,
                t=t,
            )

            y = model.forward(
                x=x_noised.to(device),
                time=t.to(device),
                classes=label.to(device),
            )

            Y.append(y.to("cpu")[0, :, :, :])
            X.append(x_noised.to("cpu"))

    
    return X, Y 

NameError: name 'device' is not defined

In [60]:
x_noised, noise, time = diffusion.forward(x)

In [None]:
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

writer = SummaryWriter("runs/test3")

device = "cuda" if torch.cuda.is_available() else "cpu"

unet_conditional.to(device)

optimizer = torch.optim.Adam(unet_conditional.parameters(), lr=1e-3)

iteration = 0

for e in range(500):

    for step, batch in enumerate(dataloader):

        iteration += 1

        x = batch["x"].to(device)
        l = batch["label"].to(device)

        x_noised, noise, time = diffusion.forward(x)

        noise_predicted = unet_conditional.forward(
            x=x_noised.to(device),
            time=time.to(device),
            classes=l.to(device),
        )

        loss = diffusion.loss(noise, noise_predicted)
        
        unet_conditional.zero_grad()
        loss.backward()
        optimizer.step()
        
        writer.add_scalar("loss", loss, iteration)
    
    print(e)
    
    y = _generate_image(
        model=unet_conditional,
        diffusion=diffusion,
        timesteps=TIMESTEPS,
        shape=[1, CHANNELS, UPSCALE_IMAGE_SIZE, UPSCALE_IMAGE_SIZE],
        device=device,
        klass=0,
    )
        
    writer.add_image(f'images',np.asarray(reverse_transform(y[-1, :, :, :])), e, dataformats='HWC')

0
1
2
3
4
5
6
7
8
9


In [74]:
batch = next(iter(dataloader))


idx = 0

X, Y = get_process(
    model=unet_conditional,
    x=batch["x"][idx, :, :, :],
    label=batch["label"][idx],
    diffusion=diffusion,
    device="cuda",
)

import matplotlib.pyplot as plt

from matplotlib.pyplot import imshow

%matplotlib notebook

indicies = np.linspace(0, len(X) -1 , 10, dtype=int)

fig, axs = plt.subplots(2, len(indicies), gridspec_kw={'height_ratios': [1, 2]}, figsize=(10,3))

i = 0

for i, process_i in enumerate(indicies):
    
    x = reverse_transform(X[process_i])
    y = reverse_transform(Y[process_i])
    
    axs[0][i].imshow(x)
    
    axs[1][i].imshow(y)
    
    axs[0][i].axes.get_xaxis().set_visible(False)
    axs[0][i].axes.get_yaxis().set_visible(False)

    axs[1][i].axes.get_xaxis().set_visible(False)
    axs[1][i].axes.get_yaxis().set_visible(False)
    
    i +=1

In [71]:
batch

{'label': tensor([8, 9, 2, 8, 3, 2, 8, 0, 1, 9, 7, 2, 3, 2, 2, 8, 1, 1, 8, 9, 3, 8, 1, 8,
         5, 4, 1, 6, 2, 9, 7, 2, 5, 6, 8, 7, 6, 7, 4, 9, 9, 6, 0, 4, 0, 7, 2, 2,
         9, 2, 1, 5, 5, 8, 1, 8, 7, 2, 2, 7, 7, 4, 2, 9]),
 'x': tensor([[[[ 0.1059,  0.1137,  0.1294,  ...,  0.0824,  0.0588,  0.0588],
           [ 0.1373,  0.1294,  0.1294,  ...,  0.1216,  0.0824,  0.0275],
           [ 0.1059,  0.1059,  0.1059,  ...,  0.0980,  0.0745,  0.0667],
           ...,
           [-0.4196, -0.4039, -0.3804,  ..., -0.8118, -0.8431, -0.8824],
           [-0.4275, -0.4039, -0.3647,  ..., -0.7569, -0.7725, -0.7961],
           [-0.4039, -0.3804, -0.3569,  ..., -0.7255, -0.7255, -0.7412]],
 
          [[ 0.2000,  0.2078,  0.2235,  ...,  0.1373,  0.1137,  0.1137],
           [ 0.2314,  0.2314,  0.2235,  ...,  0.1765,  0.1373,  0.0824],
           [ 0.2000,  0.2000,  0.2000,  ...,  0.1529,  0.1294,  0.1216],
           ...,
           [-0.3176, -0.3020, -0.2784,  ..., -0.7725, -0.8118, -0.8431],
