## Data set creation

In [2]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class ImageDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        """
        Args:
            image_dir (str): Path to the directory containing images.
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.image_dir = image_dir
        self.image_files = os.listdir(image_dir)
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_name)

        if self.transform:
            image = self.transform(image)

        return image  # Returning image as tensor



image_dir = "../images"
batch_size = 2

transform = transforms.Compose([
    transforms.Resize((512, 512)),  # Resize images to match model input size
    transforms.ToTensor(),          # Convert image to tensor (scaled between 0 and 1)
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
    ])

# Create dataset and dataloader
dataset = ImageDataset(image_dir=image_dir, transform=transform)
dataloader = DataLoader(dataset, shuffle=True)



## Train step

In [10]:
def rescale(x, old_range, new_range, clamp=False):
    old_min, old_max = old_range
    new_min, new_max = new_range
    x -= old_min
    x *= (new_max - new_min) / (old_max - old_min)
    x += new_min
    if clamp:
        x = x.clamp(new_min, new_max)
    return x

def get_time_embedding(timestep):
    # Shape: (160,)
    freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160) 
    # Shape: (1, 160)
    x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
    # Shape: (1, 160 * 2)
    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)

In [12]:
import torch
import numpy as np
from tqdm import tqdm
from ddpm import DDPMSampler

WIDTH = 512
HEIGHT = 512
LATENTS_WIDTH = WIDTH // 8
LATENTS_HEIGHT = HEIGHT // 8

def train_step(
    inputs,
    targets,
    models,
    tokenizer,
    optimizer,
    device,
    clip,
    diffusion,
    sampler,
    strength=0.8,
    do_cfg=True,
    cfg_scale=7.5,
    seed=None,
    n_inference_steps=50,
):
    """
    A single training step where the encoder and decoder are trained while the diffusion model is frozen.
    """
    # Freeze diffusion model (no gradient updates)
    for param in diffusion.parameters():
        param.requires_grad = False

    # Initialize random number generator according to the seed specified
    generator = torch.Generator(device=device)
    if seed is not None:
        generator.manual_seed(seed)
    else:
        generator.seed()

    # Prepare CLIP and inputs
    clip.to(device)
    if do_cfg:
        # Conditional and Unconditional tokens for prompt
        cond_tokens = tokenizer.batch_encode_plus([inputs["prompt"]], padding="max_length", max_length=77).input_ids
        cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
        cond_context = clip(cond_tokens)

        uncond_tokens = tokenizer.batch_encode_plus([inputs["uncond_prompt"]], padding="max_length", max_length=77).input_ids
        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
        uncond_context = clip(uncond_tokens)

        context = torch.cat([cond_context, uncond_context])  # (2 * Batch_Size, Seq_Len, Dim)
    else:
        tokens = tokenizer.batch_encode_plus([inputs["prompt"]], padding="max_length", max_length=77).input_ids
        tokens = torch.tensor(tokens, dtype=torch.long, device=device)
        context = clip(tokens)  # (Batch_Size, Seq_Len, Dim)

    # Prepare latent space from the input image or random latents
    latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)

    if inputs.get("input_image"):
        input_image_tensor = inputs["input_image"].resize((WIDTH, HEIGHT))
        input_image_tensor = np.array(input_image_tensor)
        input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
        input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
        input_image_tensor = input_image_tensor.unsqueeze(0).permute(0, 3, 1, 2)

        encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
        latents = models["encoder"](input_image_tensor, encoder_noise)

        # Add noise to latents
        sampler.set_strength(strength=strength)
        latents = sampler.add_noise(latents, sampler.timesteps[0])

    else:
        latents = torch.randn(latents_shape, generator=generator, device=device)

    # Begin training loop for one step
    optimizer.zero_grad()

    timesteps = tqdm(sampler.timesteps)
    for i, timestep in enumerate(timesteps):
        time_embedding = get_time_embedding(timestep).to(device)

        # Prepare model input: latents and context
        model_input = latents
        if do_cfg:
            model_input = model_input.repeat(2, 1, 1, 1)

        # Model output is predicted noise
        model_output = diffusion(model_input, context, time_embedding)

        if do_cfg:
            output_cond, output_uncond = model_output.chunk(2)
            model_output = cfg_scale * (output_cond - output_uncond) + output_uncond

        latents = sampler.step(timestep, latents, model_output)

    # Use decoder to decode latents into final images
    decoder = models["decoder"]
    decoder.to(device)
    images = decoder(latents)

    # Compute the loss for this step (mean squared error example)
    loss_fn = torch.nn.MSELoss()
    loss = loss_fn(images, targets)

    # Backpropagation and optimization
    loss.backward()
    optimizer.step()

    return loss.item()


In [None]:
def train(
    models,
    dataloader,
    tokenizer,
    optimizer,
    device,
    clip,
    diffusion,
    sampler,
    num_epochs=10,
):
    for epoch in range(num_epochs):
        epoch_loss = 0
        for batch_idx, inputs in enumerate(dataloader):
            inputs = inputs.to(device)  # Move inputs to device
            targets = inputs  # In autoencoders, targets are typically the same as inputs

            # Create the input dictionary for training
            input_dict = {
                "prompt": "A description for image generation",  # Replace with actual prompt
                "uncond_prompt": "A negative description",      # Replace with actual negative prompt
                "input_image": inputs  # Actual image batch for training
            }

            # Perform one training step
            loss = train_step(
                input_dict,
                targets,
                models,
                tokenizer,
                optimizer,
                device,
                clip,
                diffusion,
                sampler,
                strength=0.8,  # Set desired strength
                do_cfg=True,   # Conditional or not
                cfg_scale=7.5,
                seed=None,
                n_inference_steps=50
            )
            epoch_loss += loss

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss / len(dataloader)}")


In [14]:
import model_loader
import pipeline
from PIL import Image
from pathlib import Path
from transformers import CLIPTokenizer
import torch

# Example setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load models and tokenizer
tokenizer = CLIPTokenizer("../data/vocab.json", merges_file="../data/merges.txt")
model_file = "../data/v1-5-pruned-emaonly.ckpt"
models = model_loader.preload_models_from_standard_weights(model_file, device)


optimizer = torch.optim.Adam(
    list(models["encoder"].parameters()) + list(models["decoder"].parameters()),
    lr=1e-4
)

# Create the DataLoader
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

for batch_idx, inputs in enumerate(dataloader):
    print(batch_idx)
    print(inputs)
# Train the model



0
tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[0.9451, 0.9451, 0.9451,  ..., 0.9922, 0.9922, 0.9922],
          [0.9451, 0.9451, 0.9451,  ..., 0.9922, 0.9922, 0.9922],
          [0.9451, 0.9451, 0.9529,  ..., 0.9922, 0.9922, 0.9922],
          ...,
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
          [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],

         [[1.0000, 1.0000, 1.0000,  ..., 0.9686, 0.9686, 0.9686],
          [1.0000, 1.0000, 1.0000,  ..., 0.9686, 0.9686, 0.9686],
          [1.0000, 1.0000, 1.0000,  ...,

In [20]:
train(models, dataloader, tokenizer, optimizer, device, models['clip'], models['diffusion'], "ddpm")

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

## Testing zone

In [4]:
for batch_idx, inputs in enumerate(dataloader):
    print(batch_idx)
    print(inputs.shape) # Move inputs to device
    inputs = inputs.squeeze(0)
    print(inputs.shape) 


0
torch.Size([1, 3, 512, 512])
torch.Size([3, 512, 512])
1
torch.Size([1, 3, 512, 512])
torch.Size([3, 512, 512])
2
torch.Size([1, 3, 512, 512])
torch.Size([3, 512, 512])
3
torch.Size([1, 3, 512, 512])
torch.Size([3, 512, 512])
4
torch.Size([1, 3, 512, 512])
torch.Size([3, 512, 512])
5
torch.Size([1, 3, 512, 512])
torch.Size([3, 512, 512])
6
torch.Size([1, 3, 512, 512])
torch.Size([3, 512, 512])
7
torch.Size([1, 3, 512, 512])
torch.Size([3, 512, 512])
8
torch.Size([1, 3, 512, 512])
torch.Size([3, 512, 512])
9
torch.Size([1, 3, 512, 512])
torch.Size([3, 512, 512])


In [None]:
WIDTH = 512
HEIGHT = 512
LATENTS_WIDTH = WIDTH // 8
LATENTS_HEIGHT = HEIGHT // 8
for batch_idx, inputs in enumerate(dataloader):
    inputs = inputs.squeeze(0)
    inputs = inputs.to(device)  # Move inputs to device
    targets = inputs  # In autoencoders, targets are typically the same as inputs

            # Create the input dictionary for training
    input_dict = {
        "prompt": "A description for image generation",  # Replace with actual prompt
        "uncond_prompt": "A negative description",      # Replace with actual negative prompt
        "input_image": inputs  # Actual image batch for training
    }

    # Freeze diffusion model (no gradient updates)
    diffusion = models["diffusion"]
    for param in diffusion.parameters():
        param.requires_grad = False
    
    # Initialize random number generator according to the seed specified
    seed = 42
    generator = torch.Generator(device=device)
    if seed is not None:
        generator.manual_seed(seed)
    else:
        generator.seed()

    clip = models["clip"]
    do_cfg = True

    clip.to(device)
    if do_cfg:
        # Conditional and Unconditional tokens for prompt
        cond_tokens = tokenizer.batch_encode_plus([input_dict["prompt"]], padding="max_length", max_length=77).input_ids
        cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
        cond_context = clip(cond_tokens)

        uncond_tokens = tokenizer.batch_encode_plus([input_dict["uncond_prompt"]], padding="max_length", max_length=77).input_ids
        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
        uncond_context = clip(uncond_tokens)

        context = torch.cat([cond_context, uncond_context])  # (2 * Batch_Size, Seq_Len, Dim)
    
    else:
        tokens = tokenizer.batch_encode_plus([input_dict["prompt"]], padding="max_length", max_length=77).input_ids
        tokens = torch.tensor(tokens, dtype=torch.long, device=device)
        context = clip(tokens)  # (Batch_Size, Seq_Len, Dim)

    # Prepare latent space from the input image or random latents
    latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)
    print(input_dict.get("input_image"))
    
    input_image = input_dict.get("input_image")

    # Check if input_image exists and has elements
    if input_image is not None and input_image.numel() > 0:
        #input_image_tensor = input_image.resize((WIDTH, HEIGHT))
        input_image_tensor = input_image
        #print(input_image_tensor.shape)
        input_image_tensor = np.array(input_image_tensor)
        #print(input_image_tensor.shape)
        input_image_tensor = torch.tensor(input_image_tensor, dtype=torch.float32, device=device)
        #print(input_image_tensor.shape)
        input_image_tensor = rescale(input_image_tensor, (0, 255), (-1, 1))
        #print(input_image_tensor.shape)
        input_image_tensor = input_image_tensor.unsqueeze(0)
        print(input_image_tensor.shape)
        encoder_noise = torch.randn(latents_shape, generator=generator, device=device)
        latents = models["encoder"](input_image_tensor, encoder_noise)

        sampler = "ddpm"
        if sampler == "ddpm":
            sampler = DDPMSampler(generator)
            sampler.set_inference_timesteps(50)
        else:
            raise ValueError("Unknown sampler value %s. ")
        strength=0.8
        # Add noise to latents
        sampler.set_strength(strength=strength)
        latents = sampler.add_noise(latents, sampler.timesteps[0])

    else:
        latents = torch.randn(latents_shape, generator=generator, device=device)
    ### Working code ###
    optimizer.zero_grad()

    timesteps = tqdm(sampler.timesteps)
    for i, timestep in enumerate(timesteps):
        time_embedding = get_time_embedding(timestep).to(device)

        # Prepare model input: latents and context
        model_input = latents
        if do_cfg:
            model_input = model_input.repeat(2, 1, 1, 1)

        # Model output is predicted noise
        model_output = diffusion(model_input, context, time_embedding)
        cfg_scale = 7.5
        if do_cfg:
            output_cond, output_uncond = model_output.chunk(2)
            model_output = cfg_scale * (output_cond - output_uncond) + output_uncond

        latents = sampler.step(timestep, latents, model_output)

    # Use decoder to decode latents into final images
    decoder = models["decoder"]
    decoder.to(device)
    images = decoder(latents)

    # Compute the loss for this step (mean squared error example)
    loss_fn = torch.nn.MSELoss()
    loss = loss_fn(images, targets)
    print("################")
    print(loss)
    print("################")
    loss.backward()
    optimizer.step()

NameError: name 'dataloader' is not defined