In [None]:
from datasets import load_dataset

# Root directory that holds captions.jsonl and the image files it references.
DATA_DIR = "data/dataset"

# Load the caption records. Each JSONL line is read for "file_path" (relative to
# DATA_DIR) and "caption" keys.
dataset = load_dataset("json", data_files=f"{DATA_DIR}/captions.jsonl")

# Prefix DATA_DIR so every file_path can be opened from the notebook's working directory.
# Re-running this cell is idempotent because the dataset is reloaded from disk first.
dataset = dataset.map(lambda record: {"file_path": f"{DATA_DIR}/{record['file_path']}"})

# Preview the first (already prefixed) record, e.g.
# {'file_path': 'data/dataset/images/image1.png', 'caption': 'A description'}
dataset["train"][0]


In [16]:
import torch

# Release cached allocator memory on whichever accelerator is present. The guard keeps
# this cell runnable on machines without MPS, where an unguarded torch.mps call may raise.
if torch.backends.mps.is_available():
    torch.mps.empty_cache()
elif torch.cuda.is_available():
    torch.cuda.empty_cache()


In [None]:
from diffusers import UNet2DConditionModel, StableDiffusionPipeline, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from torch.utils.data import DataLoader
import torch
import os
from PIL import Image
from torchvision import transforms
from torch.cuda.amp import GradScaler, autocast

# Device setup: prefer Apple-silicon MPS, then CUDA, and fall back to the CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Base Stable Diffusion checkpoint whose components are loaded below.
model_name = "CompVis/stable-diffusion-v1-1"

# Denoising UNet (moved to the chosen device) and its DDPM noise schedule.
unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet").to(device)
scheduler = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")

# CLIP tokenizer and text encoder that turn captions into conditioning embeddings.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)


def add_alpha_channel(image_tensor):
    """Append one all-zero channel to ``image_tensor`` along dimension 1.

    The input must have at least four dimensions (it is sliced as ``[:, :1, :, :]``),
    e.g. ``[batch, C, H, W]``. The result has ``C + 1`` channels. The new channel matches
    the input's dtype and device and is filled with zeros.

    NOTE(review): this only pads channels; it does not turn RGB pixels into anything a
    latent-space model understands.
    """
    zero_plane = torch.zeros_like(image_tensor[:, 0:1, :, :])
    padded = torch.cat((image_tensor, zero_plane), dim=1)
    return padded


# Map-style dataset that pairs each preprocessed image tensor with its caption.
class CustomDataset(torch.utils.data.Dataset):
    """Yield ``(transformed_image, caption)`` pairs from caption records.

    Args:
        dataset: Indexable, sized collection (e.g. a Hugging Face dataset split) whose
            items are mappings with ``"file_path"`` and ``"caption"`` keys.
        transform: Callable applied to the RGB ``PIL.Image`` loaded from ``file_path``.
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image_path = item["file_path"]
        caption = item["caption"]

        # Open in a context manager so the file handle is closed as soon as the pixels
        # are decoded. ``convert`` loads the data and returns a new, independent image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        return self.transform(image), caption


from diffusers import AutoencoderKL

# Training preprocessing: resize to 128x128, convert to a [0, 1] tensor, then normalize
# each channel to [-1, 1] (the pixel range used by diffusers' own SD training example).
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Prepare dataset and DataLoader
train_dataset = CustomDataset(dataset["train"], transform)
data_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

# The Stable Diffusion UNet denoises 4-channel VAE latents, not RGB pixels, so images are
# encoded with the checkpoint's (frozen) VAE rather than padded with a fake channel.
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(device)
vae.requires_grad_(False)
vae.eval()
# Latents are multiplied by the VAE's scaling factor before noising (0.18215 for SD v1).
latent_scale = getattr(vae.config, "scaling_factor", 0.18215)

# Optimizer over both fine-tuned models (the VAE stays frozen).
optimizer = torch.optim.AdamW(list(unet.parameters()) + list(text_encoder.parameters()), lr=5e-5)

# fp16 autocast and gradient scaling are only used on CUDA. On MPS/CPU, both are disabled
# and training runs in fp32; a disabled GradScaler passes the loss through unchanged.
use_amp = device == "cuda"
scaler = GradScaler(enabled=use_amp)

# from_pretrained returns the models in eval mode; switch on training behaviour.
unet.train()
text_encoder.train()

# Training loop
num_epochs = 1
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0
    for images, captions in data_loader:
        images = images.to(device)

        # Tokenize captions to CLIP's fixed 77-token context.
        inputs = tokenizer(list(captions), padding="max_length", truncation=True, max_length=77, return_tensors="pt")
        input_ids = inputs.input_ids.to(device)

        # Encode pixels to latents; no gradients are needed through the frozen VAE.
        with torch.no_grad():
            latents = vae.encode(images).latent_dist.sample() * latent_scale

        # Draw one random diffusion timestep per sample. Always using a single fixed
        # timestep would only ever train the model at one noise level.
        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
        ).long()
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
            text_embeddings = text_encoder(input_ids)[0]
            predicted_noise = unet(noisy_latents, timesteps, text_embeddings).sample

        # Compute the MSE in fp32 so the reduction is not done in half precision.
        loss = torch.nn.functional.mse_loss(predicted_noise.float(), noise.float())

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("data_loader produced no batches; check that dataset['train'] is not empty.")
    print(f"Epoch {epoch + 1}/{num_epochs} completed. Mean loss: {epoch_loss / num_batches:.6f}")

# Save the fine-tuned model
unet.save_pretrained("fine_tuned_unet")
text_encoder.save_pretrained("fine_tuned_text_encoder")
print("Fine-tuned model saved.")


In [None]:
import wandb
# This cell runs after training has finished, so it can only record the final values the
# training cell left behind (`epoch`, `loss` from the last batch). It is not a per-step
# loss curve; call `run.log` inside the training loop for that.
run = wandb.init(project="stable-diffusion-fine-tuning")

# Report the epoch 1-based, matching the training cell's progress print.
run.log({"epoch": epoch + 1, "loss": loss.item()})

# Close the run so the summary is flushed and a later wandb.init starts a fresh run.
run.finish()


In [None]:
from PIL import Image

import torch
from diffusers import StableDiffusionPipeline, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

BASE_MODEL_NAME = "CompVis/stable-diffusion-v1-1"

# Same accelerator preference as the training cell: MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    inference_device = "mps"
elif torch.cuda.is_available():
    inference_device = "cuda"
else:
    inference_device = "cpu"

# Rebuild the pipeline from the base checkpoint and swap in the weights the training cell
# saves ("fine_tuned_unet", "fine_tuned_text_encoder"). No cell in this notebook writes
# a full pipeline directory, so the pipeline cannot be loaded from one.
fine_tuned_pipeline = StableDiffusionPipeline.from_pretrained(
    BASE_MODEL_NAME,
    unet=UNet2DConditionModel.from_pretrained("fine_tuned_unet"),
    text_encoder=CLIPTextModel.from_pretrained("fine_tuned_text_encoder"),
    tokenizer=CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14"),
).to(inference_device)

# Generate an image
prompt = "A beautiful painting of a futuristic cityscape"
image = fine_tuned_pipeline(prompt).images[0]

# Save to disk, and display inline by ending the cell with the image (rich PIL repr).
image.save("generated_image.png")
image


In [None]:
from PIL import Image
from torchvision import transforms

# Sanity-check resizing/tensor conversion on one sample image. Distinct names keep this
# cell from overwriting the notebook's global `transform` and `image`.
with Image.open("data/dataset/images/kvg:kanji_0f9a8.png") as raw_image:
    # convert() returns a fully loaded copy, so the file can be closed right away.
    sample_image = raw_image.convert("RGB")

# Transform pipeline (resize + ToTensor only; no normalization).
preview_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

# Apply the transform and add a leading batch dimension.
input_tensor = preview_transform(sample_image).unsqueeze(0)
assert input_tensor.shape == (1, 3, 128, 128), input_tensor.shape
print(input_tensor.shape)
