LazyCIPSGenerator

A mash-up of the CIPS generator (Conditionally-Independent Pixel Synthesis) and lazy diffusion.

The new LazyCIPSGenerator class offers several benefits:

  1. Efficiency: It only generates content for specified regions, reducing computation for partial updates.
  2. Context-awareness: It maintains global context while focusing on local generation.
  3. Flexibility: It can be used for various tasks like inpainting, local editing, or progressive image generation.
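
As a quick sketch of the intended call pattern (mirroring the constructor and forward signature used in the training example below; the exact interface lives in this repository's code), regenerating a single 64x64 patch of a 256x256 image looks roughly like this:

import torch

# LazyCIPSGenerator is defined in this repository
model = LazyCIPSGenerator(input_dim=256, max_resolution=256).to('cuda')
model.eval()

image = torch.randn(1, 3, 256, 256, device='cuda')   # stand-in for a real image batch
mask = torch.zeros(1, 1, 256, 256, device='cuda')
mask[:, :, 96:160, 96:160] = 1                        # 64x64 region to (re)generate
latent = torch.randn(1, 256, device='cuda')           # matches input_dim

with torch.no_grad():
    # latent code, visible context (masked image), mask, output resolution
    output = model(latent, image * (1 - mask), mask, 256)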

To illustrate these benefits, let's create a simple training example. We'll set up a scenario where we're training the model to perform inpainting tasks.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

class InpaintingDataset(Dataset):
    def __init__(self, image_paths, mask_size=(64, 64), transform=None):
        self.image_paths = image_paths
        self.mask_size = mask_size
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            img = self.transform(img)  # transform is expected to include ToTensor, giving a CHW tensor
        
        # Create a random rectangular mask (1 = region to inpaint)
        mask = torch.zeros(1, img.shape[1], img.shape[2])
        top = np.random.randint(0, img.shape[1] - self.mask_size[0])
        left = np.random.randint(0, img.shape[2] - self.mask_size[1])
        mask[:, top:top+self.mask_size[0], left:left+self.mask_size[1]] = 1
        
        return img, mask

def train_lazy_cips(model, dataloader, num_epochs=10, device='cuda'):
    optimizer = optim.Adam(model.parameters(), lr=0.0002, betas=(0.5, 0.999))
    criterion = nn.MSELoss()
    
    for epoch in range(num_epochs):
        for i, (images, masks) in enumerate(dataloader):
            images = images.to(device)
            masks = masks.to(device)
            
            # Create masked images
            masked_images = images * (1 - masks)
            
            # Generate random latent vectors
            latents = torch.randn(images.shape[0], model.input_dim).to(device)
            
            # Generate inpainted images
            generated_images = model(latents, masked_images, masks, images.shape[2])
            
            # Compute loss
            loss = criterion(generated_images * masks, images * masks)
            
            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if i % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}")
        
        # Visualize results
        if (epoch + 1) % 5 == 0:
            visualize_results(model, images, masks, masked_images, generated_images)

def visualize_results(model, images, masks, masked_images, generated_images):
    def to_img(t):
        # Undo the (-1, 1) normalization from the training transform for display
        return (t.detach().cpu() * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0).numpy()

    # Show the first four samples of the batch (assumes batch size >= 4)
    fig, axs = plt.subplots(4, 4, figsize=(20, 20))
    for i in range(4):
        axs[i, 0].imshow(to_img(images[i]))
        axs[i, 0].set_title("Original")
        axs[i, 1].imshow(to_img(masked_images[i]))
        axs[i, 1].set_title("Masked")
        axs[i, 2].imshow(to_img(generated_images[i]))
        axs[i, 2].set_title("Generated")
        axs[i, 3].imshow(masks[i].detach().cpu().squeeze().numpy(), cmap='gray')
        axs[i, 3].set_title("Mask")
    plt.tight_layout()
    plt.show()

# Usage example
if __name__ == "__main__":
    # Assume we have a list of image paths
    image_paths = ["path/to/image1.jpg", "path/to/image2.jpg", ...]
    
    # Create dataset and dataloader
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    dataset = InpaintingDataset(image_paths, transform=transform)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)
    
    # Initialize model (LazyCIPSGenerator is defined in this repository; import it from the repo's module)
    model = LazyCIPSGenerator(input_dim=256, max_resolution=256).to('cuda')
    
    # Train model
    train_lazy_cips(model, dataloader)

This training example demonstrates the benefits of the LazyCIPSGenerator:

  1. Efficiency: The model only generates content for the masked regions, which is typically a small portion of the image. This is more computationally efficient than generating the entire image, especially for tasks like inpainting or local editing.

  2. Context-awareness: The context encoder processes the entire image and mask, allowing the generator to maintain awareness of the global context even when only generating a small region. This helps in creating coherent and contextually appropriate inpaintings.

  3. Flexibility: The same model can be used for various tasks. In this example, we're using it for inpainting with random masks, but it could easily be adapted for user-guided editing or progressive image generation by changing how we create and apply the masks.
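
For example, swapping the random training mask for a user-specified rectangle is enough to turn the same setup into a local-editing tool; a minimal sketch (the coordinates and the 256x256 size are illustrative assumptions):

import torch

def rect_mask(height, width, top, left, h, w):
    # Binary mask with 1s over the user-selected rectangle to regenerate
    mask = torch.zeros(1, 1, height, width)
    mask[:, :, top:top+h, left:left+w] = 1
    return mask

# e.g. the user selects a 64x96 patch of a 256x256 image
user_mask = rect_mask(256, 256, top=100, left=80, h=64, w=96)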

Key components of the training example:

  1. InpaintingDataset: A custom dataset that loads images and creates random masks for inpainting tasks.

  2. train_lazy_cips function: The main training loop that:

    • Generates masked images
    • Creates random latent vectors
    • Uses the LazyCIPSGenerator to inpaint the masked regions
    • Computes the loss only on the generated (masked) regions
    • Performs backpropagation and optimization

  3. visualize_results function: Periodically visualizes the original images, masked images, generated (inpainted) images, and masks to track the model's progress.

To use this example:

  1. Replace the image_paths list with actual paths to your training images (e.g. collected with glob, as sketched after this list).
  2. Adjust hyperparameters like num_epochs, learning rate, etc., as needed.
  3. Run the script to train the LazyCIPSGenerator on the inpainting task.
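
For step 1, one simple way to collect the paths (the data/train directory is just a placeholder for your own image folder):

import glob

image_paths = sorted(glob.glob("data/train/*.jpg"))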

This training setup showcases how the LazyCIPSGenerator can efficiently learn to perform context-aware inpainting, demonstrating its benefits for partial image generation tasks. The model learns to generate only the masked regions while considering the surrounding context, making it more efficient and flexible than traditional approaches that generate entire images.

CelebAMask-HQ training

Segmentation masks for training can be taken from the CelebAMask-HQ dataset:

https://github.com/switchablenorms/CelebAMask-HQ?tab=readme-ov-file
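
One way to use it here is to replace the random rectangles with semantic part masks (hair, skin, etc.), so the model learns to regenerate whole facial regions. A rough sketch, assuming you have already paired each CelebA-HQ image path with one of its CelebAMask-HQ part-annotation paths (the pairing and file layout are left to the reader; the class name and mask handling below are illustrative, assuming each part annotation is stored as a white region on black):

from PIL import Image
import torch
from torch.utils.data import Dataset
import torchvision.transforms.functional as TF

class CelebAMaskInpaintingDataset(Dataset):
    def __init__(self, image_paths, mask_paths, transform=None):
        # image_paths[i] and mask_paths[i] must refer to the same face
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('RGB')
        part = Image.open(self.mask_paths[idx]).convert('L')
        if self.transform:
            img = self.transform(img)  # expected to include ToTensor, giving a CHW tensor
        # Resize the part annotation to the image resolution and binarize it
        part = part.resize((img.shape[2], img.shape[1]), Image.NEAREST)
        mask = (TF.to_tensor(part) > 0.5).float()  # 1 = region to inpaint
        return img, mask

A dataset built this way can be dropped into the same DataLoader and train_lazy_cips call as above.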
