In [None]:
import os
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from diffusers import LDMPipeline, UNet2DModel, VQModel, DDIMScheduler
import torch.nn.utils.prune as prune
from tqdm import tqdm
import torch.nn.functional as F
from IPython.display import display

# ============================
# 1. Configuration Parameters
# ============================

# Paths
DATASET_PATH = "dataset/celeba_hq_256"  # folder containing the training images
OUTPUT_DIR = "generated_images"         # created on import of this cell if missing
os.makedirs(OUTPUT_DIR, exist_ok=True)


# Training Hyperparameters
BATCH_SIZE = 4
LEARNING_RATE = 1e-5
NUM_EPOCHS = 8
NUM_WORKERS = 4
IMAGE_SIZE = 256  # Assuming CelebA-HQ images are 256x256

# Device
# The second GPU ("cuda:1") is preferred, as before. On a single-GPU host that
# ordinal does not exist and the first `.to(DEVICE)` would fail, so fall back
# to "cuda:0" there, and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    DEVICE = torch.device("cpu")

# ============================
# 2. Define the Custom Dataset
# ============================

class CelebaHQDataset(Dataset):
    """Image-only dataset over the files of a single flat directory.

    Every entry of ``root_dir`` whose name ends in ``.png``, ``.jpg`` or
    ``.jpeg`` (case-insensitive) becomes one sample. Samples carry no label:
    ``__getitem__`` returns just the (optionally transformed) RGB image.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Path to the dataset directory.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        # os.listdir() returns entries in arbitrary order; sort them so that a
        # given index maps to the same file on every run and every machine.
        self.image_paths = [
            os.path.join(root_dir, name)
            for name in sorted(os.listdir(root_dir))
            if name.lower().endswith(('.png', '.jpg', '.jpeg'))
        ]
        self.transform = transform

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` as an RGB image and apply ``transform`` if set.

        A file that cannot be opened or decoded is reported on stdout and
        replaced by a black IMAGE_SIZE x IMAGE_SIZE image, so one corrupt file
        does not abort the run.
        """
        img_path = self.image_paths[idx]
        try:
            # convert() returns an independent, fully decoded copy, so the file
            # handle can be released right away instead of lingering until GC.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a black image in case of error
            image = Image.new("RGB", (IMAGE_SIZE, IMAGE_SIZE), (0, 0, 0))
        if self.transform:
            image = self.transform(image)
        return image

# ============================
# 3. Prepare DataLoader
# ============================

# Define transformations: resize, convert to a [0, 1] tensor, then map to [-1, 1].
# The pretrained LDM VQ-VAE works on images in [-1, 1] (the pipeline maps decoded
# images back with `image / 2 + 0.5`), so ImageNet mean/std statistics would feed
# the encoder an input distribution it was not trained on.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # [0, 1] -> [-1, 1]
])

# Initialize the dataset and dataloader
dataset = CelebaHQDataset(root_dir=DATASET_PATH, transform=transform)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# ============================
# 4. Load and Prune the Models
# ============================

# Load the pretrained models and scheduler
unet = UNet2DModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="unet")
vqvae = VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae")
# Schedulers are loaded from a hub repo / local path with `from_pretrained`;
# passing a repo id to `from_config` is the deprecated spelling.
scheduler = DDIMScheduler.from_pretrained("CompVis/ldm-celebahq-256", subfolder="scheduler")


# Assemble the unconditional latent-diffusion pipeline from the loaded parts.
# The pipeline holds references to the same `unet` / `vqvae` objects, so later
# fine-tuning of those modules is reflected in what the pipeline samples.
# NOTE(review): no pruning is applied in this section even though the message
# below says "pruned models" - confirm whether a pruning step is missing.
print("Initializing the LDMPipeline with pruned models...")
pipeline = LDMPipeline(
    unet=unet,
    vqvae=vqvae,
    scheduler=scheduler,
).to(DEVICE)

# ============================
# 6. Define the Training Loop
# ============================

def kl_divergence_loss(noise_pred, noise):
    """
    KL divergence between softmax-normalised predicted noise and true noise.

    Both tensors are turned into distributions along their last dimension
    (log-softmax for the prediction, softmax for the target) before being
    passed to ``F.kl_div``. This normalisation is a demonstration choice and
    may need adjusting for real use.

    Parameters:
    - noise_pred (Tensor): Predicted noise by the model.
    - noise (Tensor): True Gaussian noise.

    Returns:
    - Tensor: Scalar KL divergence, summed over all elements and divided by
      the size of the first (batch) dimension.
    """
    return F.kl_div(
        input=F.log_softmax(noise_pred, dim=-1),  # kl_div expects log-probabilities here
        target=F.softmax(noise, dim=-1),          # ...and plain probabilities here
        reduction="batchmean",
    )

def generate_sample_images(pipeline, num_images=1, prompt="A high quality portrait"):
    """Sample images from ``pipeline`` and show them inline in the notebook.

    Args:
        pipeline: Callable pipeline; invoked with ``batch_size`` and
            ``num_inference_steps=100`` and expected to return an object with
            an ``images`` sequence.
        num_images (int): Number of images to sample in one batch.
        prompt (str): Accepted for call-site compatibility but not used; the
            pipeline call below takes no text conditioning.

    The pipeline's ``unet`` / ``vqvae`` sub-modules (when present) are put in
    eval mode while sampling and returned to their previous mode afterwards,
    so calling this from inside a training loop does not disturb training.
    """
    pipeline.to(DEVICE)

    # Collect the sub-models and remember whether each one was training.
    candidates = (getattr(pipeline, "unet", None), getattr(pipeline, "vqvae", None))
    sub_models = [m for m in candidates if isinstance(m, nn.Module)]
    previous_modes = [m.training for m in sub_models]
    for m in sub_models:
        m.eval()

    try:
        with torch.no_grad():
            generated_images = pipeline(batch_size=num_images, num_inference_steps=100).images
    finally:
        # Restore the original train/eval state even if sampling fails.
        for m, was_training in zip(sub_models, previous_modes):
            m.train(was_training)

    for img in generated_images:
        display(img)

# Move models to device
unet.to(DEVICE)
vqvae.to(DEVICE)

# Only the U-Net is fine-tuned. The VQ-VAE is used purely as an encoder inside
# torch.no_grad() below, so it never receives gradients: freeze it, keep it in
# eval mode, and leave its parameters out of the optimizer.
unet.train()
vqvae.eval()
vqvae.requires_grad_(False)

# Define optimizer over the trainable (U-Net) parameters
optimizer = optim.Adam(unet.parameters(), lr=LEARNING_RATE)

# Noise-prediction objective: MSE between the predicted and the injected noise
criterion = nn.MSELoss()

# Conv weights that are exactly zero before training are treated as pruned.
# Snapshot their masks once, so the zeros can be enforced after every update.
# (Re-deriving the mask from the weights after the optimizer step cannot
# restore an entry that has already moved away from zero.)
pruned_convs = []
with torch.no_grad():
    for module in unet.modules():
        if isinstance(module, nn.Conv2d):
            zero_mask = module.weight == 0
            if zero_mask.any():
                pruned_convs.append((module, zero_mask))

# Training Loop
for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    unet.train()  # sampling at the end of the previous epoch switched to eval
    epoch_loss = 0.0
    progress_bar = tqdm(dataloader, desc="Training", leave=False)

    for batch_idx, batch in enumerate(progress_bar):
        # Move batch to device
        images = batch.to(DEVICE)

        # Encode images to latents with the frozen VQ-VAE
        with torch.no_grad():
            vqvae_output = vqvae.encode(images)
            # The VQ-VAE encoder output exposes the latents directly as `.latents`
            latents = vqvae_output.latents
            latents = latents * vqvae.config.scaling_factor

        # Add noise according to the scheduler
        noise = torch.randn_like(latents)
        # Read the value through `.config`; direct attribute access on the
        # scheduler is deprecated and emits a warning.
        timesteps = torch.randint(
            0,
            scheduler.config.num_train_timesteps,
            (latents.shape[0],),
            device=DEVICE,
        ).long()
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        # Forward pass through U-Net
        noise_pred = unet(noisy_latents, timesteps).sample

        # Compute loss
        loss = criterion(noise_pred, noise)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Mask gradients of pruned weights so the optimizer state stays zero for them
        with torch.no_grad():
            for module, zero_mask in pruned_convs:
                if module.weight.grad is not None:
                    module.weight.grad.masked_fill_(zero_mask, 0)

        # Optimizer step
        optimizer.step()

        # Re-apply the stored masks to ensure pruned weights remain exactly zero
        with torch.no_grad():
            for module, zero_mask in pruned_convs:
                module.weight.masked_fill_(zero_mask, 0)

        # Accumulate loss
        batch_loss = loss.item()
        epoch_loss += batch_loss
        progress_bar.set_postfix(loss=batch_loss)

    # Guard against a dataloader that yields no batches
    avg_epoch_loss = epoch_loss / max(len(dataloader), 1)
    print(f"Average Loss: {avg_epoch_loss:.4f}")

    # Save the model at each epoch. The pipeline shares the fine-tuned `unet`
    # object, so saving it writes a directly reloadable checkpoint.
    checkpoint_dir = os.path.join(OUTPUT_DIR, "checkpoints", f"epoch_{epoch+1}")
    pipeline.save_pretrained(checkpoint_dir)
    print(f"Saved fine-tuned models for epoch {epoch+1}")

    # Show one sample after every epoch as a quick visual check; sample in eval mode
    unet.eval()
    generate_sample_images(pipeline, num_images=1, prompt="A high quality portrait")


print("\nTraining complete!")



2024-12-05 21:45:49.670342: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-05 21:45:49.702782: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-05 21:45:49.702819: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-05 21:45:49.703675: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-05 21:45:49.709210: I tensorflow/core/platform/cpu_feature_guar

Initializing the LDMPipeline with pruned models...

Epoch 1/8


  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
                                                                             

KeyboardInterrupt: 

In [None]:
import os
from PIL import Image

result_folder = "result_original_finetuned"


def generate_sample_images(pipeline, num_images=10, num_inference_steps_list=(10, 20, 100),
                           num_samples=50000, output_dir=None):
    """Generate images with ``pipeline`` and save them as PNG files.

    For every entry of ``num_inference_steps_list`` a sub-folder
    ``<output_dir>/<steps>_steps`` is created, and ``num_samples`` images are
    written into it as ``generated_image_<k>.png`` with ``k`` counting from 1.

    Note: this definition shares its name with the display-only helper defined
    in the training cell, and replaces it in the kernel namespace once run.

    Args:
        pipeline: Callable invoked with ``batch_size`` and
            ``num_inference_steps``; must return an object whose ``images``
            attribute is a sequence of objects with a ``save(path)`` method.
        num_images (int): Batch size per pipeline call. Must be >= 1.
        num_inference_steps_list (iterable of int): Step counts to sample with.
        num_samples (int): Total number of images to write per step count. If
            it is not a multiple of ``num_images``, the last batch is smaller
            so that exactly ``num_samples`` images are produced.
        output_dir (str, optional): Root folder for the results. Defaults to
            the module-level ``result_folder``.

    Raises:
        ValueError: If ``num_images`` is smaller than 1.
    """
    if num_images < 1:
        raise ValueError(f"num_images must be >= 1, got {num_images}")

    base_dir = result_folder if output_dir is None else output_dir

    # Iterate over each number of inference steps
    for num_inference_steps in num_inference_steps_list:
        # Create a subfolder for each number of inference steps
        steps_folder = os.path.join(base_dir, f"{num_inference_steps}_steps")
        os.makedirs(steps_folder, exist_ok=True)

        # Walk over the sample range in batch-sized strides; the last stride
        # may be shorter so the remainder is not silently dropped.
        for start in range(0, num_samples, num_images):
            batch_size = min(num_images, num_samples - start)
            with torch.no_grad():
                # Generate images using the pipeline
                generated_images = pipeline(
                    batch_size=batch_size,
                    num_inference_steps=num_inference_steps,
                ).images

            # Save each generated image to the corresponding subfolder
            for offset, img in enumerate(generated_images):
                image_number = start + offset + 1
                img_path = os.path.join(steps_folder, f"generated_image_{image_number}.png")

                # Save the image
                img.save(img_path)
                print(f"Image {image_number} saved at: {img_path}")

# Example usage: write 50,000 images (in batches of 10) for each of the four
# step counts below. This is a long-running cell.
generate_sample_images(
    pipeline,
    num_images=10,
    num_inference_steps_list=[10, 20, 100, 50],
    num_samples=50000,
)

