In [4]:
import torch
import numpy as np
!pip install diffusers
from diffusers import StableDiffusionPipeline
from scipy.stats import ttest_rel
from PIL import Image
from torchvision import transforms

import torch.nn as nn
from diffusers import StableDiffusionPipeline, DDPMScheduler



Collecting diffusers
  Downloading diffusers-0.30.3-py3-none-any.whl.metadata (18 kB)
Downloading diffusers-0.30.3-py3-none-any.whl (2.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m28.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: diffusers
Successfully installed diffusers-0.30.3


The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive (Colab-only import).
# NOTE(review): the cells visible in this notebook read the input image from a
# /kaggle/input path rather than from this mount — confirm this cell is still needed.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:

def preprocess_image(image_path, pipeline, device):
    """Load an image file and encode it into scaled VAE latents.

    The image is converted to RGB, resized to 512x512, turned into a tensor and
    normalized with mean 0.5 / std 0.5, then passed through ``pipeline.vae``.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        pipeline: Object exposing ``vae.encode``, ``vae.dtype`` and
            ``vae.config.scaling_factor`` (e.g. a ``StableDiffusionPipeline``).
        device: Device the image tensor is moved to before encoding.

    Returns:
        The mean of the VAE latent distribution multiplied by the VAE scaling
        factor, with a leading batch dimension of 1.
    """
    to_model_input = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    # The context manager releases the file handle; convert() returns an
    # independent in-memory image, so it stays usable after the file is closed.
    with Image.open(image_path) as source_image:
        rgb_image = source_image.convert("RGB")

    # Follow the VAE's own weight dtype instead of hard-coding float16, so the
    # input matches the model whether it was loaded in half or full precision.
    pixels = to_model_input(rgb_image).unsqueeze(0).to(device, dtype=pipeline.vae.dtype)

    with torch.no_grad():
        # Use the distribution mean (not a random sample) so the encoding is deterministic.
        latents = pipeline.vae.encode(pixels).latent_dist.mean
        latents = latents * pipeline.vae.config.scaling_factor
    return latents

def add_noise(latents, timestep, noise_scheduler):
    """Apply the scheduler's forward noising to ``latents`` at ``timestep``.

    Args:
        latents: Tensor to be noised.
        timestep: Single timestep value; wrapped into a one-element tensor on
            the same device as ``latents``.
        noise_scheduler: Object providing ``add_noise(latents, noise, timesteps)``.

    Returns:
        Whatever ``noise_scheduler.add_noise`` returns for freshly drawn
        standard-normal noise of the same shape/dtype/device as ``latents``.
    """
    step_index = torch.tensor([timestep], device=latents.device)
    gaussian = torch.randn_like(latents)
    return noise_scheduler.add_noise(latents, gaussian, step_index)

def denoise_latents(noisy_latents, timestep, encoder_hidden_states, pipeline):
    """Run one conditioned UNet forward pass on ``noisy_latents``.

    Args:
        noisy_latents: Latent tensor fed to the UNet.
        timestep: Single timestep value; wrapped into a one-element ``long``
            tensor on the same device as ``noisy_latents``.
        encoder_hidden_states: Conditioning tensor passed to the UNet
            (e.g. text-encoder hidden states).
        pipeline: Object exposing a callable ``unet``.

    Returns:
        The ``.sample`` attribute of the UNet output, detached from autograd.
        NOTE(review): despite the function name, for noise-prediction models
        this is presumably the predicted noise rather than a denoised latent —
        confirm against the scheduler's prediction type.
    """
    timestep_tensor = torch.tensor([timestep], dtype=torch.long, device=noisy_latents.device)

    # Inference only: disable gradient tracking so no autograd graph (and its
    # saved activations) is built for the UNet forward pass.
    with torch.no_grad():
        unet_output = pipeline.unet(
            noisy_latents,
            timestep_tensor,
            encoder_hidden_states=encoder_hidden_states,
        ).sample
    return unet_output

def compute_cosine_similarity(original_latents, denoised_latents):
    """Average cosine similarity between two tensors, taken along the last axis.

    Both inputs are L2-normalized along ``dim=-1``; the element-wise product is
    summed over that axis to give one cosine value per last-axis vector, and
    those values are averaged.

    Args:
        original_latents: Reference tensor.
        denoised_latents: Tensor compared against the reference.

    Returns:
        The mean cosine similarity as a Python float.
    """
    unit_reference = nn.functional.normalize(original_latents, dim=-1)
    unit_candidate = nn.functional.normalize(denoised_latents, dim=-1)
    per_vector_similarity = (unit_reference * unit_candidate).sum(dim=-1)
    return per_vector_similarity.mean().item()


In [2]:
def zero_shot_classification(image_path, categories, num_timesteps=10, weights=None, pipeline=None):
    """Pick the text prompt that best matches an image using Stable Diffusion.

    The image is encoded to latents, noised at ``num_timesteps`` evenly spaced
    training timesteps, and for every category prompt the UNet output is
    compared to the clean latents with ``compute_cosine_similarity``. The
    weighted sum over timesteps is the category score, and the category with
    the LOWEST score is returned.

    Args:
        image_path: Path of the image to classify.
        categories: Sequence of text prompts, one per candidate class.
        num_timesteps: Number of timesteps sampled between 0 and the
            scheduler's ``num_train_timesteps - 1`` (inclusive).
        weights: Optional per-timestep weights; must have ``num_timesteps``
            entries. ``None`` weights every timestep by 1.0.
        pipeline: Optional already-loaded ``StableDiffusionPipeline`` to reuse
            across calls instead of reloading the checkpoint each time. Its
            weight dtype must match what ``preprocess_image`` feeds the VAE;
            the default pipeline is loaded with ``torch.float16``.

    Returns:
        The element of ``categories`` with the lowest weighted score.

    Raises:
        ValueError: If ``categories`` is empty or ``num_timesteps`` < 1.
        AssertionError: If ``weights`` does not have ``num_timesteps`` entries.
    """
    # Validate cheap arguments before paying for the model load.
    categories = list(categories)
    if not categories:
        raise ValueError("categories must contain at least one prompt.")
    if num_timesteps < 1:
        raise ValueError("num_timesteps must be at least 1.")
    if weights is not None:
        assert len(weights) == num_timesteps, "Weights length must match number of timesteps."

    if pipeline is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        pipeline = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            torch_dtype=torch.float16
        ).to(device)
    else:
        device = pipeline.device

    noise_scheduler = pipeline.scheduler

    original_latents = preprocess_image(image_path, pipeline, device)

    total_timesteps = noise_scheduler.config.num_train_timesteps
    selected_timesteps = torch.linspace(0, total_timesteps - 1, steps=num_timesteps, dtype=torch.long).tolist()

    if weights is None:
        weights = [1.0] * len(selected_timesteps)

    # Draw the noisy latents once per timestep and reuse them for every
    # category: all prompts are then scored on identical inputs, so score
    # differences reflect the conditioning rather than different noise draws.
    noisy_latents_per_step = [
        add_noise(original_latents, timestep, noise_scheduler)
        for timestep in selected_timesteps
    ]

    scores = []

    # Pure inference: keep autograd off for the text encoder and the UNet.
    with torch.no_grad():
        for category in categories:
            text_inputs = pipeline.tokenizer(
                category,
                padding="max_length",
                max_length=pipeline.tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt"
            ).to(device)

            text_embeddings = pipeline.text_encoder(**text_inputs).last_hidden_state

            total_score = 0.0
            for timestep, weight, noisy_latents in zip(selected_timesteps, weights, noisy_latents_per_step):
                # NOTE(review): the UNet output is compared directly against the
                # clean latents; confirm this is the intended scoring target.
                unet_output = denoise_latents(noisy_latents, timestep, text_embeddings, pipeline)
                total_score += weight * compute_cosine_similarity(original_latents, unet_output)

            scores.append(total_score)

    # Lowest weighted score wins.
    predicted_category_index = torch.argmin(torch.tensor(scores)).item()
    return categories[predicted_category_index]

In [33]:


# Classification draws random noise (torch.randn_like in add_noise), so seed
# PyTorch's generators to make those draws — and the printed prediction —
# repeatable from one run of this cell to the next.
torch.manual_seed(0)

image_path = "/kaggle/input/cat-dataset/download (1).jpeg"
categories = ["a photo of a house", "a photo of a horse", "a photo of a cat", "a photo of an umbrella"]

# None -> every sampled timestep is weighted 1.0.
weights = None

predicted_category = zero_shot_classification(image_path, categories, num_timesteps=10, weights=weights)
print(f"Predicted category: {predicted_category}")


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Predicted category: a photo of a cat
