In [1]:

import numpy as np

import matplotlib.pyplot as plt
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import Subset
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel, AutoencoderKL
from diffusers.optimization import get_cosine_schedule_with_warmup

from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DATA_PATH = 'data/data0/lsun/bedroom'
BATCH_SIZE = 8

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
class VAE:
    vae_url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/vae-ft-mse-840000-ema-pruned.safetensors"

    def __init__(self, device):
        self.model = AutoencoderKL.from_single_file(self.vae_url).to(device)
        self.device = device

    def to_latent(self, input):
        with torch.no_grad():
            latent = self.model.encode(input.to(self.device))
        return latent.latent_dist.sample()

    def to_image(self, encoded):
        with torch.no_grad():
            output_img = self.model.decode(encoded)
        return output_img.sample

In [5]:
vae = VAE(device)

transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: vae.to_latent(x.unsqueeze(0)).squeeze(0))  
])

image_dataset = ImageFolder(root=DATA_PATH, transform=transform)
image_dataset = Subset(image_dataset, torch.randperm(len(image_dataset))[:1000])
train_dataloader = DataLoader(image_dataset, batch_size=BATCH_SIZE, shuffle=True)

for idx, batch in enumerate(train_dataloader):
    break




In [6]:
def pos_encoding(n, d_model):
    """Positional encoding."""
    assert d_model % 2 == 0, 'd_model must be divisible by 2'
    wk = torch.tensor([1 / 10_000 ** (2 * i / d_model) for i in range(d_model // 2)])
    wk = wk.reshape((1, d_model // 2))
    t = torch.arange(n).reshape((n, 1))
    encoding = torch.zeros(n, d_model)
    encoding[:,::2] = torch.sin(t * wk)
    encoding[:, 1::2] = torch.cos(t * wk)
    return encoding

In [7]:
class EmbeddingBlock(nn.Module):
    """Embedding block for UNet."""

    def __init__(self, n_steps, d_model):
        super(EmbeddingBlock, self).__init__()
        self.t_embed = self.init_pos_encoding(n_steps, d_model)
        self.l1 = nn.Linear(d_model, d_model)
        self.l2 = nn.Linear(d_model, d_model)
        self.silu = nn.SiLU()

    def init_pos_encoding(self, n_steps, d_model):
        t_embed = nn.Embedding(n_steps, d_model)
        t_embed.weight.data = pos_encoding(n_steps, d_model)
        t_embed.requires_grad = False
        return t_embed

    def forward(self, t):
        t = self.t_embed(t)
        t = self.l1(t)
        t = self.silu(t)
        t = self.l2(t)
        return t

class ConvBlock(nn.Module):
    """Convolutional block for UNet."""

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding="same")
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding="same")
        self.relu = nn.ReLU()
        self.bnorm = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bnorm(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bnorm(x)
        x = self.relu(x)
        return x

class EncoderBlock(nn.Module):
    """Encoder block for UNet."""

    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv = ConvBlock(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        x = self.conv(x)
        pool = self.pool(x)
        return x, pool

class DecoderBlock(nn.Module):
    """Decoder block for UNet."""

    def __init__(self, in_channels, out_channels):
        super(DecoderBlock, self).__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0)
        self.conv = ConvBlock(2*out_channels, out_channels)

    def forward(self, x, down_tensor):
        x = self.upconv(x)
        x = torch.cat((x, down_tensor), dim=1)
        print(x.shape)
        x = self.conv(x)
        return x

class UNet(nn.Module):
    """UNet model for diffusion."""

    def __init__(self, batch_size, n_steps, input_size=32, in_channels=4, first_layer_channels=64):
        super(UNet, self).__init__()

        self.batch_size = batch_size

        # input size
        self.s1 = input_size
        self.s2 = self.s1 // 2
        self.s3 = self.s2 // 2

        # number of channels
        self.ch0 = in_channels
        self.ch1 = first_layer_channels
        self.ch2 = self.ch1 * 2
        self.ch3 = self.ch2 * 2

        # positional encoding
        # self.t1 = self.init_pos_encoding(n_steps, d_model = in_channels * self.s1 * self.s1)
        # self.t2 = self.init_pos_encoding(n_steps, d_model = self.ch1 * self.s2 * self.s2)
        # self.t3 = self.init_pos_encoding(n_steps, d_model = self.ch2 * self.s3 * self.s3)

        # embedding blocks
        self.em1 = EmbeddingBlock(n_steps, in_channels * self.s1 * self.s1)
        self.em2 = EmbeddingBlock(n_steps, self.ch1 * self.s2 * self.s2)
        self.em3 = EmbeddingBlock(n_steps, self.ch2 * self.s3 * self.s3)
        self.em4 = EmbeddingBlock(n_steps, self.ch2 * self.s2 * self.s2)
        self.em5 = EmbeddingBlock(n_steps, self.ch1 * self.s1 * self.s1)

        # encoder blocks
        self.e1 = EncoderBlock(self.ch0, self.ch1)
        self.e2 = EncoderBlock(self.ch1, self.ch2)

        # decoder blocks
        self.d1 = DecoderBlock(self.ch3, self.ch2)
        self.d2 = DecoderBlock(self.ch2, self.ch1)

        # middle conv block
        self.middle = ConvBlock(self.ch2, self.ch3)

        # output layer
        self.out = nn.Conv2d(self.ch1, self.ch0, kernel_size=1, padding="same")


    # def init_pos_encoding(self, n_steps, d_model):
    #     t_embed = nn.Embedding(n_steps, d_model)
    #     t_embed.weight.data = pos_encoding(n_steps, d_model)
    #     t_embed.requires_grad = False
    #     return t_embed

    def forward(self, x, t):

        t1 = self.em1(t).view(-1, self.ch0, self.s1, self.s1)
        t2 = self.em2(t).view(-1, self.ch1, self.s2, self.s2)
        t3 = self.em3(t).view(-1, self.ch2, self.s3, self.s3)
        t4 = self.em4(t).view(-1, self.ch2, self.s2, self.s2)
        t5 = self.em5(t).view(-1, self.ch1, self.s1, self.s1)

        x1, pool1 = self.e1(x + t1)
        x2, pool2 = self.e2(pool1 + t2)

        x = self.middle(pool2 + t3)

        x = self.d1(x + t4, x2)
        x = self.d2(x + t5, x1)

        x = self.out(x)

        return x


In [8]:
# from torchinfo import summary
# model = UNet(BATCH_SIZE, 1000)
# summary(model, input_size=[(BATCH_SIZE, 4, 32, 32), (BATCH_SIZE,1)])

In [9]:
unet = UNet(BATCH_SIZE, 100).to(device)
unet(batch[0], torch.arange(BATCH_SIZE).to(device))

RuntimeError: [enforce fail at alloc_cpu.cpp:117] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 17179869184 bytes. Error code 12 (Cannot allocate memory)