In [1]:
import os
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from diffusers import StableDiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPProcessor
from huggingface_hub import login
import pandas as pd
from PIL import Image
import numpy as np
from tqdm.auto import tqdm
import csv

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set device to MPS if available, else CPU
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: mps


In [3]:
# Log in to Hugging Face
login(token="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")

In [4]:
# Load the pipeline
pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5"
    ).to(device)

Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 20.95it/s]


In [5]:
vae = pipeline.vae
unet = pipeline.unet
text_encoder = pipeline.text_encoder
tokenizer = pipeline.tokenizer
scheduler = pipeline.scheduler

In [6]:
# Load the CLIP model (both image and text encoders)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

In [7]:
# Paths to your data (update these paths)
images_dir = "./flickr30k_images/flickr30k_images"  # Update this path
captions_file = "./flickr30k_images/results.csv"    # Update this path

In [8]:
# Define the custom Dataset class
class Flickr30KDataset(Dataset):
    def __init__(self, images_dir, captions_file, tokenizer, clip_tokenizer, clip_transform, sd_transform):
        self.images_dir = images_dir
        self.captions_file = captions_file
        self.tokenizer = tokenizer
        self.clip_tokenizer = clip_tokenizer
        self.clip_transform = clip_transform
        self.sd_transform = sd_transform

        # Load captions from the CSV file
        self.captions_df = pd.read_csv(
            self.captions_file,
            sep="|",
            header=0,
            encoding='utf-8',
            on_bad_lines='skip'  # For pandas >=1.3.0
        )

        # Ensure the image names are strings
        self.captions_df['image_name'] = self.captions_df['image_name'].astype(str)

        # Group captions by image_name
        self.image_captions = self.captions_df.groupby('image_name')['comment'].apply(list).to_dict()
        self.image_names = list(self.image_captions.keys())

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        # Get image name
        image_name = self.image_names[idx]
        image_path = os.path.join(self.images_dir, image_name)

        # Load image
        image = Image.open(image_path).convert("RGB")

        # Apply CLIP transform to image
        image_clip = self.clip_transform(image)
        # Apply Stable Diffusion transform to image
        image_sd = self.sd_transform(image)

        # Get captions for the image
        captions = self.image_captions[image_name]
        # For simplicity, we'll use the first caption
        caption = captions[0]

        # Tokenize the caption using the Stable Diffusion tokenizer
        encoding = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=77,
            return_tensors="pt"
        )

        # Tokenize the caption using the CLIP tokenizer (for the CLIP text encoder)
        clip_encoding = self.clip_tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=77,
            return_tensors="pt"
        )

        return {
            "image_clip": image_clip,  # Tensor
            "image_sd": image_sd,      # Tensor
            "clip_input_ids": clip_encoding["input_ids"].squeeze(),
            "clip_attention_mask": clip_encoding["attention_mask"].squeeze(),
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
        }

In [9]:
# Define image transformations
# For the CLIP image encoder (requires 224x224 images)
clip_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)),
])

# For the Stable Diffusion pipeline (expects 512x512 images)
sd_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

In [10]:
# Initialize the tokenizer (from the pipeline)
sd_tokenizer = tokenizer

# CLIP tokenizer
clip_tokenizer = clip_processor.tokenizer

In [11]:

# Create the dataset
dataset = Flickr30KDataset(
    images_dir=images_dir,
    captions_file=captions_file,
    tokenizer=sd_tokenizer,
    clip_tokenizer=clip_tokenizer,
    clip_transform=clip_transform,
    sd_transform=sd_transform
)

# Create the DataLoader
batch_size = 1  # Adjust based on your hardware capabilities
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Fine-tuning the text_encoder
text_encoder.train()
clip_model.eval()  # We'll use the CLIP image encoder in evaluation mode

# Define the optimizer
optimizer = optim.AdamW(text_encoder.parameters(), lr=5e-5)

# Define the loss function
criterion = nn.CrossEntropyLoss()

In [12]:
# For the Stable Diffusion pipeline (expects 512x512 images)
sd_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

In [13]:
# Create the dataset
dataset = Flickr30KDataset(
    images_dir=images_dir,
    captions_file=captions_file,
    tokenizer=tokenizer,
    transform=None  # We'll apply transforms separately
)

# Create the DataLoader
batch_size = 1  # Adjust based on your hardware capabilities
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Fine-tuning the text_encoder
text_encoder.train()
clip_model.eval()  # We'll use the CLIP image encoder in evaluation mode

# Define the optimizer
optimizer = optim.AdamW(text_encoder.parameters(), lr=5e-5)

# Define the loss function
criterion = nn.CrossEntropyLoss()

TypeError: __init__() got an unexpected keyword argument 'transform'

In [14]:
# Training loop
num_epochs = 1  # Adjust as needed
for epoch in range(num_epochs):
    running_loss = 0.0
    for batch in tqdm(dataloader, desc=f"Fine-tuning Text Encoder Epoch {epoch+1}"):
        optimizer.zero_grad()
        
        # Move data to device
        images_clip = batch["image_clip"].to(device)  # Already a tensor
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        clip_input_ids = batch["clip_input_ids"].to(device)
        clip_attention_mask = batch["clip_attention_mask"].to(device)
        
        # Get image embeddings using CLIP image encoder
        with torch.no_grad():
            image_embeddings = clip_model.get_image_features(images_clip)
        
        # Normalize image embeddings
        image_embeddings = image_embeddings / image_embeddings.norm(p=2, dim=-1, keepdim=True)
        
        # Get text embeddings from the text_encoder (Stable Diffusion's text encoder)
        text_outputs = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        text_embeddings = text_outputs.last_hidden_state
        # Take the mean pooling over the sequence dimension
        text_embeddings = text_embeddings.mean(dim=1)
        # Normalize text embeddings
        text_embeddings = text_embeddings / text_embeddings.norm(p=2, dim=-1, keepdim=True)
        
        # Compute similarity scores
        logits_per_image = image_embeddings @ text_embeddings.t()
        logits_per_text = logits_per_image.t()
        
        # Labels
        batch_size = images_clip.size(0)
        labels = torch.arange(batch_size).to(device)
        
        # Compute loss
        loss_image = criterion(logits_per_image, labels)
        loss_text = criterion(logits_per_text, labels)
        loss = (loss_image + loss_text) / 2
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    avg_loss = running_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

Fine-tuning Text Encoder Epoch 1: 100%|██████████| 1/1 [00:00<00:00,  1.11it/s]

Epoch [1/1], Loss: 0.0000





In [None]:
# Save the fine-tuned text encoder
text_encoder.save_pretrained("fine_tuned_text_encoder")

# Update the pipeline's text encoder
text_encoder = CLIPTextModel.from_pretrained("fine_tuned_text_encoder").to(device)
pipeline.text_encoder = text_encoder

# Proceed to generate images using the fine-tuned pipeline
pipeline.to(device)
pipeline.enable_attention_slicing()

# Generate an image
prompt = "A beautiful landscape with mountains and a lake."

with torch.no_grad():
    image = pipeline(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]

# Display the image
display(image)

100%|██████████| 50/50 [00:50<00:00,  1.01s/it]
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
