In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import save_image
import numpy as np
from tqdm import tqdm

In [2]:
### 1. Noise schedule and forward diffusion process

def cosine_beta_schedule(timesteps):
    """Generates a noise schedule."""
    return torch.linspace(1e-4, 0.02, timesteps)

def forward_diffusion(x_0, t, noise_schedule):
    """
    Simulates the forward diffusion process by adding noise to the clean image.
    x_0: the clean image
    t: current timestep
    noise_schedule: pre-calculated schedule for the noise
    """
    noise = torch.randn_like(x_0)
    return torch.sqrt(noise_schedule[t]) * x_0 + torch.sqrt(1 - noise_schedule[t]) * noise


In [3]:
### 2. U-Net model modified to accept embeddings

class EmbeddingConditionalUNet(nn.Module):
    def __init__(self, c_in=3, c_out=3, embedding_dim=512):
        super(EmbeddingConditionalUNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(c_in, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.bottleneck = nn.Sequential(
            nn.Conv2d(128 + embedding_dim, 128, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, c_out, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def forward(self, x, embedding):
        # Encode the image
        x = self.encoder(x)

        # Expand and concatenate the embedding with image features
        embedding = embedding.unsqueeze(-1).unsqueeze(-1)
        embedding = embedding.expand(-1, -1, x.shape[-2], x.shape[-1])
        x = torch.cat([x, embedding], dim=1)
        
        # Process through the bottleneck and decoder
        x = self.bottleneck(x)
        x = self.decoder(x)
        return x

In [4]:
### 3. Training loop with pre-generated embeddings

def train_with_custom_embeddings(model, optimizer, timesteps, num_epochs, dataloader, device):
    noise_schedule = cosine_beta_schedule(timesteps).to(device)
    mse_loss = nn.MSELoss()

    for epoch in range(num_epochs):
        for images, embeddings in tqdm(dataloader):
            images = images.to(device)
            embeddings = embeddings.to(device)

            # Sample a random timestep
            t = torch.randint(0, timesteps, (1,)).long().to(device)

            # Apply forward diffusion to get a noisy image
            noisy_images = forward_diffusion(images, t, noise_schedule)

            # Pass the noisy image and the pre-generated embeddings to the model
            noise_pred = model(noisy_images, embeddings)

            # Compute the loss and update the model
            loss = mse_loss(noise_pred, noisy_images)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch} - Loss: {loss.item()}")

In [5]:
### 4. Sampling process to generate images from embeddings

def sample_with_embeddings(model, embeddings, timesteps, device):
    noise_schedule = cosine_beta_schedule(timesteps)
    with torch.no_grad():
        # Start from pure noise
        x_t = torch.randn(1, 3, 64, 64).to(device)

        for t in reversed(range(timesteps)):
            # Predict the noise, conditioned on the embeddings
            noise_pred = model(x_t, embeddings)

            # Reverse the diffusion process step-by-step
            x_t = (x_t - np.sqrt(1 - noise_schedule[t]) * noise_pred) / np.sqrt(noise_schedule[t])

        return x_t

In [6]:
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# 1. Define a custom dataset that loads both images and embeddings
class ImageEmbeddingDataset(Dataset):
    def __init__(self, image_folder, embedding_folder, transform=None):
        self.image_folder = image_folder
        self.embedding_folder = embedding_folder
        self.transform = transform
        
        # List all image files (.jpg) and only keep ones with corresponding .npy files
        self.image_files = []
        self.embedding_files = []
        
        # Loop through all image files and check if the corresponding embedding exists
        for image_file in tqdm(os.listdir(image_folder)):
            if image_file.endswith('.jpg'):
                embedding_file = image_file.replace('.jpg', '.npy')
                if os.path.exists(os.path.join(embedding_folder, embedding_file)):
                    self.image_files.append(image_file)
                    self.embedding_files.append(embedding_file)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Get the image and embedding filenames
        image_file = self.image_files[idx]
        embedding_file = self.embedding_files[idx]
        
        # Load the image
        image_path = os.path.join(self.image_folder, image_file)
        image = Image.open(image_path).convert('RGB')

        # Apply transformations to the image
        if self.transform:
            image = self.transform(image)

        # Load the embedding (as a numpy array)
        embedding_path = os.path.join(self.embedding_folder, embedding_file)
        embedding = np.load(embedding_path)

        # Convert embedding to a tensor
        embedding = torch.tensor(embedding, dtype=torch.float32)

        return image, embedding

# 2. Define the image transformations (resize, normalize, etc.)
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # Resize images to 64x64 if needed
    transforms.ToTensor(),        # Convert PIL image to Tensor
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # Normalize to [-1, 1]
])

# 3. Create the dataset
image_folder = 'data/img_align_celeba'
embedding_folder = 'data/embeddings'

dataset = ImageEmbeddingDataset(image_folder, embedding_folder, transform=transform)

# 4. Create the DataLoader
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Now you can use the dataloader in your training loop


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 202599/202599 [00:01<00:00, 167661.40it/s]


In [8]:
### 5. Example usage

# Hyperparameters
timesteps = 1000
num_epochs = 5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the model, optimizer, and dataset
embedding_dim = 512  # Adjust this to the size of your embeddings
model = EmbeddingConditionalUNet(embedding_dim=embedding_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Assume you have a dataloader that yields (images, embeddings)
# For example: dataloader = [(image_batch, embedding_batch), ...]

# Train the model
train_with_custom_embeddings(model, optimizer, timesteps, num_epochs, dataloader, device)

# Sample from the model using a specific embedding
sampled_image = sample_with_embeddings(model, [dataloader[0][1]], timesteps, device)

# Save or display the generated image
save_image(sampled_image, 'generated_image.png')

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6330/6330 [05:04<00:00, 20.81it/s]


Epoch 0 - Loss: 0.154750794172287


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6330/6330 [05:05<00:00, 20.74it/s]


Epoch 1 - Loss: 0.15147365629673004


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6330/6330 [05:11<00:00, 20.29it/s]


Epoch 2 - Loss: 0.15259747207164764


  2%|█▉                                                                                                                    | 105/6330 [00:05<04:59, 20.80it/s]


KeyboardInterrupt: 