In [1]:
import os
from tqdm import tqdm
import cv2
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from diffusers import DDPMScheduler

In [2]:
from transformers import CLIPProcessor, CLIPModel

# Initialize CLIP model and processor (just an example)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")



In [3]:
from diffusers import StableDiffusionImg2ImgPipeline

pipeline = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipeline = pipeline.to("cuda")  # Move model to GPU
pipeline.enable_model_cpu_offload()
optimizer = torch.optim.AdamW(pipeline.unet.parameters(), lr=5e-6)
noise_scheduler = DDPMScheduler.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="scheduler")

# Transform pipeline
transform = transforms.Compose([
    #transforms.Resize((256, 256)),
    transforms.ToTensor(),
    #transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1] for Stable Diffusion
])

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [4]:
def reeadData(src_dir, target_dir, limit):
    src_list = []
    target_list = []
    i=0
    src = sorted(os.listdir(src_dir), key=lambda x: int(x.split('.')[0]))
    target = sorted(os.listdir(target_dir), key=lambda x: int(x.split('.')[0]))
    images = zip(src, target)
    for src, target in images:
        try:
            src_img, target_img = cv2.imread(os.path.join(src_dir, src)), cv2.imread(os.path.join(target_dir, target))   # read the image
            src_img, target_img= cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB), cv2.cvtColor(target_img, cv2.COLOR_BGR2RGB)
            src_list.append(src_img)
            target_list.append(target_img)
            i +=1
            if i == limit:
                break
        except Exception as e:
            pass
    return src_list, target_list

In [5]:
class PairedDataset(torch.utils.data.Dataset):
    def __init__(self, src_images, target_images, transform=None):
        """
        Args:
            src_images (list): List of source images.
            target_images (list): List of target images.
            transform (callable, optional): Transform to be applied to the images.
            image_limit (int, optional): Maximum number of images to include in the dataset.
        """

        self.src_images = src_images
        self.target_images = target_images
        self.transform = transform

    def __len__(self):
        return len(self.src_images)

    def __getitem__(self, idx):
        src = self.transform(self.src_images[idx]) if self.transform else self.src_images[idx]
        target = self.transform(self.target_images[idx]) if self.transform else self.target_images[idx]
        return src, target

In [6]:
src_images, target_images = reeadData("./dataset/face/", "./dataset/comics/", limit=10)
dataset = PairedDataset(src_images, target_images, transform=transform)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

In [7]:
import torch.nn as nn
class CLIPProjector(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(512, 768)  # Project from 512 to 768 dimensions

    def forward(self, x):
        return self.fc(x)

# Initialize the projector
clip_projector = CLIPProjector()

In [8]:
def encoder_hidden_states(src):
    """
    Processes the input `src` (image tensor) through the CLIP model to get the encoder hidden states.

    Args:
        src (torch.Tensor): A tensor of shape [batch_size, 3, height, width].

    Returns:
        torch.Tensor: Processed encoder hidden states of shape [batch_size, sequence_length, hidden_size].
    """
    # Preprocess the image tensor using the CLIP processor
    inputs = clip_processor(images=src, return_tensors="pt", padding=True, truncation=True)

    # Get image features from the CLIP model
    output = clip_model.get_image_features(**inputs)

    # Print the shape of the output (should be [batch_size, 512])
    print("Output shape from CLIP:", output.shape)

    # Project the output from 512 to 768 channels
    projected_output = clip_projector(output)

    # Print the new shape (should be [batch_size, 768])
    print("Projected output shape:", projected_output.shape)

    # Reshape to match the required format [batch_size, sequence_length, 768]
    # Here sequence_length is 1 (one token per image)
    batch_size = projected_output.shape[0]
    sequence_length = 1  # Sequence length is 1 for image-based input
    hidden_size = 768  # We projected to 768 channels

    # Reshaping output to [batch_size, sequence_length, hidden_size]
    hidden_states = projected_output.view(batch_size, sequence_length, hidden_size)

    # Now the output is 3D with shape [batch_size, 1, 768]
    print("Final hidden states shape:", hidden_states.shape)

    return hidden_states

In [None]:
for epoch in range(5):  # Number of epochs
    pipeline.unet.train()
    for src, target in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
        src, target = src.to("cuda"), target.to("cuda")

        # Add noise to the target image
        noise = torch.randn_like(target).to("cuda")
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (target.size(0),), device="cuda")
        noisy_target = noise_scheduler.add_noise(target, noise, timesteps)
        print(noisy_target.shape)

        # Add an extra channel to noisy_target to match UNet input requirements
        noisy_target = torch.cat([noisy_target, torch.randn_like(noisy_target[:, :1, :, :])], dim=1)  # Add a noise channel
        encoder_hidden_states(src)
        # Forward pass
        unet_output = pipeline.unet(noisy_target, timestep=timesteps, encoder_hidden_states=encoder_hidden_states(src))
        print(unet_output.sample.shape)
        noise_pred = unet_output.sample
        

        # Compute loss (mean squared error)
        loss = torch.nn.functional.mse_loss(noise_pred, noise)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1} completed with loss: {loss.item()}")

Epoch 1:   0%|                                                                                  | 0/10 [00:00<?, ?it/s]It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.


torch.Size([1, 3, 512, 512])
Output shape from CLIP: torch.Size([1, 512])
Projected output shape: torch.Size([1, 768])
Final hidden states shape: torch.Size([1, 1, 768])
Output shape from CLIP: torch.Size([1, 512])
Projected output shape: torch.Size([1, 768])
Final hidden states shape: torch.Size([1, 1, 768])
