In [1]:
from dataset.VaseDataset import VaseDataset
from torchvision import transforms
from torch.utils.data import DataLoader

In [2]:
# Not sure if we need transformations?
transform = transforms.Compose([
    transforms.Resize((256, 256)), 
    transforms.ToTensor(), # Convert images to PyTorch tensors
])

# Make datasets
train_dataset = VaseDataset(root_dir="dataset/train", captions_file="captions.csv", transform=transform)

val_dataset = VaseDataset(root_dir="dataset/val", captions_file="captions.csv", transform=transform)


# Make dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False)

In [3]:
for val in val_dataset:
    print(val)
    break

{'masked_images': tensor([[[0.9020, 0.9020, 0.9020,  ..., 0.8667, 0.8667, 0.8667],
         [0.9020, 0.9020, 0.9020,  ..., 0.8667, 0.8667, 0.8667],
         [0.9020, 0.9020, 0.9020,  ..., 0.8667, 0.8667, 0.8667],
         ...,
         [0.8667, 0.8667, 0.8667,  ..., 0.8157, 0.8157, 0.8157],
         [0.8667, 0.8667, 0.8667,  ..., 0.8157, 0.8157, 0.8157],
         [0.8667, 0.8667, 0.8667,  ..., 0.8118, 0.8118, 0.8157]],

        [[0.9020, 0.9020, 0.9020,  ..., 0.8667, 0.8667, 0.8667],
         [0.9020, 0.9020, 0.9020,  ..., 0.8667, 0.8667, 0.8667],
         [0.9020, 0.9020, 0.9020,  ..., 0.8667, 0.8667, 0.8667],
         ...,
         [0.8667, 0.8667, 0.8667,  ..., 0.8157, 0.8157, 0.8157],
         [0.8667, 0.8667, 0.8667,  ..., 0.8157, 0.8157, 0.8157],
         [0.8667, 0.8667, 0.8667,  ..., 0.8118, 0.8118, 0.8157]],

        [[0.9020, 0.9020, 0.9020,  ..., 0.8667, 0.8667, 0.8667],
         [0.9020, 0.9020, 0.9020,  ..., 0.8667, 0.8667, 0.8667],
         [0.9020, 0.9020, 0.9020,  ..., 

  mask = torch.load(mask_path)  # Load the mask tensor (H, W)


In [None]:
from diffusers import StableDiffusionInpaintPipeline
from accelerate import Accelerator
import torch
from torch.utils.data import DataLoader
from transformers import AdamW

# Load pipeline
pipe = StableDiffusionInpaintPipeline.from_pretrained("stabilityai/stable-diffusion-2-inpainting")
device = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"

# Freeze VAE parameters
for param in pipe.vae.parameters():
    param.requires_grad = False

# Fine-tune only the U-Net and text encoder
for param in pipe.unet.parameters():
    param.requires_grad = True
for param in pipe.text_encoder.parameters():
    param.requires_grad = True

pipe.to(device)

# Optimizer
optimizer = AdamW(
    [{"params": pipe.unet.parameters()}, {"params": pipe.text_encoder.parameters()}],
    lr=5e-5
)

# DataLoader placeholder (replace `train_dataloader` with your actual DataLoader)
# train_dataloader = DataLoader(...)

# Use Accelerator for distributed training
accelerator = Accelerator()
pipe, optimizer, train_dataloader = accelerator.prepare(pipe, optimizer, train_dataloader)

NUM_EPOCHS = 1

# Training loop
for epoch in range(NUM_EPOCHS):
    pipe.unet.train()
    pipe.text_encoder.train()

    for batch in train_dataloader:
        # Get inputs
        masked_images = batch["masked_images"].to(device)
        full_images = batch["full_images"].to(device)
        masks = batch["masks"].to(device)  # Binary masks
        prompts = batch["text"]

        # Tokenize text prompts
        tokenized_prompts = pipe.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(device)
        text_embeddings = pipe.text_encoder(**tokenized_prompts).last_hidden_state

        # Encode masked images into latent space
        latents = pipe.vae.encode(masked_images).latent_dist.sample()
        latents = latents * pipe.vae.config.scaling_factor

        # Assert latent dimensions
        assert latents.shape[1] == 4, f"Latent channels should be 4, got {latents.shape[1]}"
        assert latents.shape[2] % 8 == 0 and latents.shape[3] % 8 == 0, \
            "Latent dimensions should be divisible by 8 for the UNet"

        # Add noise to the latents
        batch_size = latents.size(0)
        timesteps = torch.randint(0, pipe.scheduler.num_train_timesteps, (batch_size,), device=device).long()
        noise = torch.randn_like(latents)
        noisy_latents = pipe.scheduler.add_noise(latents, noise, timesteps)

        # Resize mask to match latent spatial dimensions
        latent_masks = torch.nn.functional.interpolate(masks, size=noisy_latents.shape[-2:])  # Resize mask
        latent_masks = latent_masks[:, None, :, :]  # Ensure shape is (B, 1, H, W)
        
        # Fix latent_masks shape by squeezing the extra singleton dimension
        if latent_masks.ndim == 5:  # Check if there's an extra dimension
            latent_masks = latent_masks.squeeze(2)  # Remove the extra dimension
        # Assert mask shape matches expected dimensions
        print(f"Noisy latents shape: {noisy_latents.shape}")
        print(f"Latent masks shape: {latent_masks.shape}")
        
        assert latent_masks.ndim == 4, f"Mask should have 4 dimensions, got {latent_masks.ndim}"
        assert latent_masks.shape[1] == 1, f"Mask must have 1 channel, got {latent_masks.shape[1]}"
        assert latent_masks.shape[2:] == noisy_latents.shape[2:], \
            f"Mask spatial dimensions {latent_masks.shape[2:]} must match latents {noisy_latents.shape[2:]}"

        # Generate spatial encodings
        batch_size, _, height, width = noisy_latents.shape
        x = torch.linspace(-1, 1, steps=width, device=device).view(1, 1, 1, -1).expand(batch_size, 1, height, width)
        y = torch.linspace(-1, 1, steps=height, device=device).view(1, 1, -1, 1).expand(batch_size, 1, height, width)
        spatial_encodings = torch.cat([x, y], dim=1)  # Shape: (B, 2, H, W)

        # Concatenate noisy latents, mask, and spatial encodings
        unet_input = torch.cat([noisy_latents, latent_masks, spatial_encodings], dim=1)

        # Add extra dummy channels (if required)
        extra_channels = torch.zeros(unet_input.shape[0], 2, unet_input.shape[2], unet_input.shape[3], device=device)
        unet_input = torch.cat([unet_input, extra_channels], dim=1)

        # Assert the input shape
        assert unet_input.shape[1] == 9, f"UNet input must have 9 channels, got {unet_input.shape[1]}"

        # Forward pass through UNet
        unet_output = pipe.unet(
            sample=unet_input,
            timestep=timesteps,
            encoder_hidden_states=text_embeddings
        ).sample

        # Assert UNet output shape matches latent input
        assert unet_output.shape == latents.shape, \
            f"UNet output shape mismatch: {unet_output.shape} != {latents.shape}"

        # Decode the output latents back to image space
        reconstructed_images = pipe.vae.decode(unet_output / pipe.vae.config.scaling_factor).sample

        # Assert decoded images match the size of full images
        assert reconstructed_images.shape == full_images.shape, \
            f"Decoded images shape mismatch: {reconstructed_images.shape} != {full_images.shape}"

        # Compute pixel-wise loss
        loss = torch.nn.functional.mse_loss(reconstructed_images, full_images)

        # Backpropagation
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    print(f"Epoch {epoch + 1}/{NUM_EPOCHS} completed. Loss: {loss.item():.4f}")


  from .autonotebook import tqdm as notebook_tqdm
Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00, 17.47it/s]


ValueError: fp16 mixed precision requires a GPU (not 'mps').