In [None]:
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from datasets import load_dataset
from torch import nn
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from datasets import load_dataset
import matplotlib.pyplot as plt
from unet_architecture import UNetWithTimeEmbedding
from vae_architecture import SpatialVAE

# Loading the data

In [None]:
# Pull the training split of the cartoon-faces dataset from the Hugging Face Hub.
dataset = load_dataset("Skiittoo/cartoon-faces", split="train")

# Preprocessing pipeline applied to every PIL image before batching.
transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),  # 64x64 input resolution for the VAE
        transforms.ToTensor(),        # PIL image -> float tensor with values in [0, 1]
    ]
)

class CartoonFacesDataset(Dataset):
    """Expose an image/caption collection as a PyTorch ``Dataset``.

    Each item is an ``(image, caption)`` pair taken from the ``'image'`` and
    ``'text'`` entries of the wrapped dataset's items.
    """

    def __init__(self, dataset, transform=None):
        """
        Args:
            dataset: indexable, sized collection whose items are dicts with an
                ``'image'`` (PIL-like image exposing ``mode``/``convert``) and a
                ``'text'`` entry.
            transform: optional callable applied to the RGB image.
        """
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for index ``idx``.

        The image is always converted to 3-channel RGB before ``transform``
        runs, so every sample has the same channel count and batches collate.
        """
        item = self.dataset[idx]
        img, caption = item['image'], item['text']

        # Normalise every non-RGB mode (RGBA, grayscale 'L', palette 'P', CMYK, ...)
        # to RGB. Handling only RGBA let grayscale images through as 1-channel
        # tensors, which breaks default batch collation.
        if img.mode != 'RGB':
            img = img.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, caption

cartoon_dataset = CartoonFacesDataset(dataset, transform=transform)

# Shuffled mini-batches of 32 (image, caption) pairs for training.
dataloader = DataLoader(
    cartoon_dataset,
    batch_size=32,
    shuffle=True,
)

# VAE training (architecture imported from `vae_architecture`)

In [None]:
def vae_loss(recon, x, z_mean, z_log_var, kl_weight=1.0):
    """Per-sample VAE loss: summed BCE reconstruction plus weighted KL divergence.

    Args:
        recon: decoder output; must lie in [0, 1] (required by
            ``binary_cross_entropy``).
        x: target images in [0, 1], same shape as ``recon``.
        z_mean: mean of the diagonal-Gaussian posterior.
        z_log_var: log-variance of the diagonal-Gaussian posterior.
        kl_weight: multiplier on the KL term. 1.0 gives the standard ELBO;
            smaller values weaken the latent regularisation.

    Returns:
        Scalar tensor ``(recon_sum + kl_weight * kl_sum) / x.size(0)``, where
        both terms are summed over every element of the batch.
    """
    batch_size = x.size(0)
    recon_loss = nn.functional.binary_cross_entropy(recon, x, reduction='sum')
    # Closed-form KL(q(z|x) || N(0, I)). It is *summed* so it has the same
    # scale as the summed reconstruction term. A mean here would divide it by
    # the number of latent elements (batch * C * H * W), making it negligible.
    kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp())
    return (recon_loss + kl_weight * kl_loss) / batch_size


def show_reconstructions(vae, imgs, epoch, num=5, device='cuda'):
    """Plot originals (top row) against their VAE reconstructions (bottom row).

    Args:
        vae: model whose forward pass returns ``(reconstructed, z_mean, z_log_var)``.
        imgs: image batch of shape (B, C, H, W); values are clipped to [0, 1]
            for display.
        epoch: epoch number shown in the figure title.
        num: number of image pairs to show; clamped to the batch size.
        device: device the batch is moved to for the forward pass.

    The model's previous train/eval mode is restored before returning.
    """
    was_training = vae.training
    vae.eval()
    with torch.no_grad():
        reconstructed, _, _ = vae(imgs.to(device))
    # Plotting should not leave the model in eval mode as a side effect.
    vae.train(was_training)

    imgs = imgs.detach().cpu()
    reconstructed = reconstructed.detach().cpu()

    # Never index past the batch (e.g. a final batch smaller than `num`).
    num = min(num, imgs.size(0))
    if num == 0:
        return

    # squeeze=False keeps `axs` 2-D (2 x num) even when num == 1, so the
    # axs[row, col] indexing below always works.
    fig, axs = plt.subplots(2, num, figsize=(2 * num, 4), squeeze=False)
    for i in range(num):
        axs[0, i].imshow(imgs[i].permute(1, 2, 0).clip(0, 1))
        axs[0, i].set_title("Original")
        axs[0, i].axis("off")

        axs[1, i].imshow(reconstructed[i].permute(1, 2, 0).clip(0, 1))
        axs[1, i].set_title("Reconstructed")
        axs[1, i].axis("off")

    plt.suptitle(f"Reconstructions after epoch {epoch}")
    plt.show()

def train(vae, optimizer, epoch, loader, device='cuda'):
    """Run one VAE training epoch, print the mean loss, and plot reconstructions.

    Args:
        vae: model whose forward pass returns ``(reconstructed, z_mean, z_log_var)``.
        optimizer: optimizer over ``vae``'s parameters.
        epoch: epoch number, used only for logging and the plot title.
        loader: iterable of batches. Each batch is either an image tensor or an
            ``(images, captions)`` pair, as produced from ``CartoonFacesDataset``.
        device: device each batch is moved to.

    Returns:
        Mean of the per-batch losses over the epoch (``nan`` if ``loader`` is empty).

    NOTE: calls the 5-argument ``show_reconstructions(vae, imgs, epoch, num, device)``
    defined in this cell. A later cell redefines that name with a different
    signature, so re-run this cell before training again.
    """
    vae.train()
    total_loss = 0.0
    num_batches = 0
    imgs_to_show = None

    for batch in loader:
        # The default collate turns the dataset's (image, caption) items into
        # [images, captions]; the VAE only needs the images.
        imgs = batch[0] if isinstance(batch, (list, tuple)) else batch
        imgs = imgs.to(device)

        optimizer.zero_grad()
        reconstructed, z_mean, z_log_var = vae(imgs)
        loss = vae_loss(reconstructed, imgs, z_mean, z_log_var)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        if imgs_to_show is None:
            # Keep the first batch for the end-of-epoch visualisation.
            imgs_to_show = imgs.detach().cpu()

    # vae_loss is already averaged over each batch, so average over batches.
    # Dividing by the dataset size would shrink it by roughly batch_size again.
    avg_loss = total_loss / num_batches if num_batches else float('nan')
    print(f'Epoch {epoch} Loss {avg_loss}')

    if imgs_to_show is not None:
        show_reconstructions(vae, imgs_to_show, epoch, num=5, device=device)
    return avg_loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fresh VAE and Adam optimizer. To resume an earlier run instead, load a saved
# state dict right after creating the model, e.g.
#   vae.load_state_dict(torch.load("/kaggle/working/vae_epoch30.pth", map_location=device))
vae = SpatialVAE().to(device)
optimizer = optim.Adam(vae.parameters(), lr=1e-4)

epochs = 50
# Epochs are numbered 1..epochs (inclusive). A checkpoint named after the
# epoch is written when each one finishes.
for epoch in range(1, epochs + 1):
    train(vae, optimizer, epoch, dataloader, device=device)
    torch.save(vae.state_dict(), f"spatial_vae_epoch{epoch}.pth")

print("Training finished.")

In [None]:
import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F

def show_reconstructions(originals, reconstructed, n=5):
    """Plot the first ``n`` originals (top row) above their reconstructions (bottom row).

    NOTE: this redefines the ``show_reconstructions(vae, imgs, epoch, num, device)``
    from the VAE training cell. That cell's ``train()`` calls the old signature,
    so re-run the training cell before training again.

    Args:
        originals: image batch of shape (B, C, H, W); clipped to [0, 1] for display.
        reconstructed: reconstructions with the same layout as ``originals``.
        n: number of columns to show; clamped to the available batch size.
    """
    originals = originals.detach().cpu().permute(0, 2, 3, 1)
    reconstructed = reconstructed.detach().cpu().permute(0, 2, 3, 1)

    # Never index past either batch.
    n = min(n, originals.shape[0], reconstructed.shape[0])
    if n == 0:
        return

    # squeeze=False keeps `axs` 2-D even when n == 1.
    fig, axs = plt.subplots(2, n, figsize=(2 * n, 4), squeeze=False)
    for i in range(n):
        axs[0, i].imshow(originals[i].clip(0, 1))
        axs[0, i].axis("off")
        axs[1, i].imshow(reconstructed[i].clip(0, 1))
        axs[1, i].axis("off")
    plt.suptitle("Top row: Original, Bottom row: Reconstructed")
    plt.show()


# Reload the trained VAE checkpoint and visually check its reconstructions on
# one batch. Requires `dataloader` and `SpatialVAE` from earlier cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = SpatialVAE().to(device)
vae.load_state_dict(torch.load("/kaggle/input/stable_vae/pytorch/default/1/spatial_vae_epoch50.pth", map_location=device))
vae.eval()

# The DataLoader yields [images, captions]; only the images are needed here.
# Using the batch directly would call `.to()` on a list and fail.
batch = next(iter(dataloader))
imgs = batch[0] if isinstance(batch, (list, tuple)) else batch

with torch.no_grad():
    imgs = imgs.to(device)
    reconstructed, z_mean, z_log_var = vae(imgs)

# Display side by side
show_reconstructions(imgs, reconstructed)


# Stable Diffusion-style latent diffusion pipeline

In [None]:
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Build a linearly increasing schedule of per-step noise variances (betas).

    Args:
        timesteps: number of diffusion steps; length of the returned tensor.
        beta_start: beta at the first step.
        beta_end: beta at the last step.

    Returns:
        1-D tensor of ``timesteps`` evenly spaced values from ``beta_start``
        to ``beta_end`` inclusive.
    """
    betas = torch.linspace(start=beta_start, end=beta_end, steps=timesteps)
    return betas

def forward_diffusion_process(x0, t, beta_schedule, noise=None):
    """Sample a noised input in closed form: sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps.

    ``abar_t`` is the cumulative product of ``(1 - beta)`` up to step ``t``.

    Args:
        x0: clean batch with the batch on dim 0. Any rank is accepted
            (e.g. (B, C, H, W) latents or (B, D) vectors).
        t: integer timestep indices, one per sample (shape (B,)).
        beta_schedule: 1-D tensor of per-step betas, e.g. from ``linear_beta_schedule``.
        noise: optional pre-drawn noise with the same shape as ``x0``. When
            None, standard-normal noise is drawn with ``torch.randn_like``.

    Returns:
        ``(noisy, epsilon)``: the noised batch and the noise that was added.
    """
    beta = beta_schedule.to(x0.device)
    alpha_cumprod = torch.cumprod(1 - beta, dim=0)

    # Reshape the per-sample coefficients to (B, 1, ..., 1) so they broadcast
    # over every non-batch dim of x0. A hard-coded view(-1, 1, 1, 1) only
    # broadcasts correctly for 4-D inputs.
    coeff_shape = (-1,) + (1,) * (x0.dim() - 1)
    a_cumprod_t = alpha_cumprod[t].view(coeff_shape)
    # Keep both coefficients strictly positive before taking square roots.
    a_cumprod_t = torch.clamp(a_cumprod_t, min=1e-8, max=1.0)
    a_sqrts = torch.sqrt(a_cumprod_t)
    one_minus_sqrts = torch.sqrt(torch.clamp(1 - a_cumprod_t, min=1e-8, max=1.0))

    epsilon = torch.randn_like(x0) if noise is None else noise

    noisy = a_sqrts * x0 + one_minus_sqrts * epsilon
    return noisy, epsilon

In [None]:
import torch 
from transformers import CLIPProcessor, CLIPModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Trained VAE in eval mode; its encoder maps images to latents for diffusion training.
vae = SpatialVAE().to(device)
vae.load_state_dict(
    torch.load(
        "/kaggle/input/stable_vae/pytorch/default/1/spatial_vae_epoch50.pth",
        map_location=device,
    )
)
vae.eval()

# Pretrained CLIP. Its text model provides the caption embeddings used to
# condition the UNet.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip_processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)
clip_model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
clip_model.to(device)

In [None]:
# Train the text-conditioned latent-diffusion UNet.
# Each step: encode images to VAE latents, noise them at a random timestep,
# and regress the added noise from the noisy latents + caption embedding.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = 1000
epochs = 15
lr = 2e-4
latent_channels = 4  # must match the VAE encoder's latent channel count
UNET_CHECKPOINT = "unet_weights.pth"

beta_schedule = linear_beta_schedule(timesteps).to(device)
# Schedule tables; not used in this cell (presumably kept for sampling later).
alpha = 1 - beta_schedule
alpha_cumprod = torch.cumprod(alpha, dim=0)
alpha_cumprod_prev = torch.cat([torch.ones(1, device=device), alpha_cumprod[:-1]])

model = UNetWithTimeEmbedding(latent_channels=latent_channels).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()

for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0

    for images, captions in dataloader:
        images = images.to(device)
        batch_size = images.shape[0]

        # Only the CLIP text embedding is used for conditioning, so tokenize
        # the captions alone. Running the CLIP image preprocessor on every
        # batch produced pixel values that were never used.
        text_inputs = clip_processor(
            text=captions,
            return_tensors='pt',
            padding=True,
            truncation=True,
        )
        text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

        with torch.no_grad():
            text_features = clip_model.text_model(
                input_ids=text_inputs["input_ids"],
                attention_mask=text_inputs["attention_mask"],
            ).pooler_output

            # Encode images with the VAE (no gradients) and sample latents
            # via mu + std * eps.
            mu, log_var = vae.encoder(images)
            std = torch.exp(log_var).clamp(min=1e-6).sqrt()
            latents = mu + std * torch.randn_like(std)

        # One random diffusion timestep per sample, then noise the latents.
        t = torch.randint(0, timesteps, (batch_size,), device=device)
        noisy_latents, noise = forward_diffusion_process(latents, t, beta_schedule)

        # Predict the added noise, conditioned on the text embedding.
        noise_pred = model(noisy_latents, t, text=text_features)
        loss = criterion(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss / len(dataloader):.6f}")

    # Checkpoint after every epoch. Saving only once after the loop meant an
    # interrupted run (this cell previously ended in KeyboardInterrupt) kept nothing.
    torch.save(model.state_dict(), UNET_CHECKPOINT)

KeyboardInterrupt: 