### Load CLIP Model

In [None]:
from transformers import CLIPTextModel, CLIPTokenizer
import torch
import os

# Local cache directory for the CLIP text encoder and its tokenizer.
CLIP_MODEL_PATH = "./models/clip-vit-large-patch14"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Loading CLIP model from: {CLIP_MODEL_PATH}...")

# Treat the cache as usable only when config.json is present. A bare
# os.path.exists() check also passes for an empty or half-written directory
# (e.g. an interrupted first download), which would then fail on every run
# below because local_files_only=True forbids re-fetching.
if not os.path.isfile(os.path.join(CLIP_MODEL_PATH, "config.json")):
    print("\n⚠️  Model not found locally. Downloading from Hugging Face...")
    print("This model is ~1.7GB and will take a few minutes.")
    print("Please be patient...\n")

    # Download tokenizer and text encoder from the Hub
    clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    clip_model = CLIPTextModel.from_pretrained(
        "openai/clip-vit-large-patch14",
        torch_dtype=torch.bfloat16  # Use bfloat16 to match FLUX
    )

    # Persist both so later runs can work offline
    print(f"Saving model to {CLIP_MODEL_PATH}...")
    clip_tokenizer.save_pretrained(CLIP_MODEL_PATH)
    clip_model.save_pretrained(CLIP_MODEL_PATH)
    print("✓ Model downloaded and saved locally!\n")
else:
    print("✓ Loading from local folder...\n")

# Always (re)load from the local folder so both code paths end up with the
# same objects, and so a freshly saved copy is verified to be loadable.
clip_tokenizer = CLIPTokenizer.from_pretrained(CLIP_MODEL_PATH, local_files_only=True)
clip_model = CLIPTextModel.from_pretrained(
    CLIP_MODEL_PATH,
    torch_dtype=torch.bfloat16,  # Use bfloat16 to match FLUX
    local_files_only=True
).to(device)
clip_model.eval()  # Inference only: put the module in evaluation mode

# Summary of what was loaded
print("✓ CLIP loaded successfully!")
print(f"  Embedding dimension: {clip_model.config.hidden_size}")
print(f"  Max sequence length: {clip_tokenizer.model_max_length}")
print(f"  Loaded from: {CLIP_MODEL_PATH}")
print(f"  Model dtype: {next(clip_model.parameters()).dtype}")

In [None]:
import ipywidgets as widgets
from IPython.display import display

# Multi-line prompt box, pre-filled with a sample prompt
clip_prompt_input = widgets.Textarea(
    description="Prompt:",
    placeholder="Enter your prompt here",
    value="an elephant",
    layout=widgets.Layout(height="80px", width="80%"),
)

# Button the user clicks to compute the embedding
clip_generate_button = widgets.Button(
    button_style="success",
    description="Generate CLIP Embedding",
)

# Output widget that collects everything the click handler prints
clip_output_area = widgets.Output()

# Notebook-level state holding the most recent embedding and its token
# strings; both start out empty until an embedding has been generated.
current_clip_embedding = current_clip_tokens = None

def generate_clip_embedding(b=None, max_length=77):
    """Tokenize the prompt from the text box and compute its CLIP text embedding.

    Designed as the click handler for ``clip_generate_button``; all printed
    output is routed into ``clip_output_area``.

    Args:
        b: The button instance passed by ``on_click``. Unused.
        max_length: Sequence length the prompt is padded/truncated to.
            Defaults to 77, the value this notebook uses for CLIP.

    Side effects:
        Sets the globals ``current_clip_embedding`` (float32 numpy array of
        shape ``[seq_len, embedding_dim]``) and ``current_clip_tokens`` (list
        of decoded token strings, one per position).
    """
    global current_clip_embedding, current_clip_tokens

    with clip_output_area:
        clip_output_area.clear_output()

        prompt = clip_prompt_input.value
        print(f"Generating CLIP embedding for: '{prompt}'\n")

        # Tokenize to a fixed-length sequence
        tokens = clip_tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt"
        )

        # Decode each id individually so tokens can be shown per position
        input_ids = tokens['input_ids'][0]
        token_strings = [clip_tokenizer.decode([tid]) for tid in input_ids.tolist()]
        seq_len = input_ids.shape[0]

        # Count non-padding positions from the attention mask. Comparing ids
        # against pad_token_id undercounts when the tokenizer reuses its
        # end-of-text token as the pad token (as CLIP tokenizers do): the
        # real end-of-text token would be mistaken for padding.
        if "attention_mask" in tokens:
            num_real_tokens = int(tokens["attention_mask"][0].sum().item())
        else:
            num_real_tokens = (input_ids != clip_tokenizer.pad_token_id).sum().item()

        print(f"Tokenized into {num_real_tokens} real tokens (+ {seq_len - num_real_tokens} padding):")
        print("First 10 tokens:", token_strings[:10])
        print()

        # Forward pass without gradient tracking
        with torch.no_grad():
            model_inputs = {k: v.to(device) for k, v in tokens.items()}
            outputs = clip_model(**model_inputs)
            embedding = outputs.last_hidden_state  # Shape: [1, seq_len, embedding_dim]

        # numpy has no bfloat16, so upcast to float32 before converting
        current_clip_embedding = embedding.float().cpu().numpy()[0]  # Shape: [seq_len, embedding_dim]
        current_clip_tokens = token_strings

        total_numbers = current_clip_embedding.shape[0] * current_clip_embedding.shape[1]

        print("✓ CLIP embedding generated!")
        print(f"  Shape: {current_clip_embedding.shape}")
        print(f"  Total numbers: {total_numbers:,}")
        print(f"  Size: {current_clip_embedding.nbytes / 1024:.2f} KB")
        print()
        print(f"First token '{token_strings[0]}' embedding (first 10 values):")
        print(current_clip_embedding[0, :10])

# Run generate_clip_embedding on every button click, then render the prompt
# box, the button, and the output area as this cell's output.
clip_generate_button.on_click(generate_clip_embedding)
display(clip_prompt_input, clip_generate_button, clip_output_area)

In [None]:
import json

def save_clip_embedding(filename="clip_embedding.json"):
    """Write the most recently generated CLIP embedding to a JSON file.

    The file contains the embedding as nested lists, the decoded token
    strings, the embedding shape, and the prompt text.

    NOTE: the prompt is read from ``clip_prompt_input`` at save time. If the
    text box was edited after the embedding was generated, the saved prompt
    will not match the saved embedding.

    Args:
        filename: Destination path. Missing parent directories are created.

    Returns:
        None. Prints a message and returns early if no embedding has been
        generated yet.
    """
    if current_clip_embedding is None:
        print("❌ No CLIP embedding to save! Generate one first.")
        return

    data = {
        "embedding": current_clip_embedding.tolist(),
        "tokens": current_clip_tokens,
        "shape": list(current_clip_embedding.shape),
        "prompt": clip_prompt_input.value
    }

    # Allow targets such as "out/run1/clip.json" without a prior mkdir
    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Explicit encoding keeps the output independent of the platform default
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(data, f)

    file_size = os.path.getsize(filename) / (1024 * 1024)
    print(f"✓ CLIP embedding saved to '{filename}' ({file_size:.2f} MB)")

# Button that triggers the save, plus an Output widget that collects the
# messages printed while saving.
clip_save_button = widgets.Button(description='Save CLIP Embedding', button_style='info')
clip_save_output = widgets.Output()

def on_clip_save_click(b):
    """Click handler: clear the save output area, then save the embedding.

    Args:
        b: The button instance passed by ``on_click``. Unused.
    """
    output = clip_save_output
    # Everything printed inside this context is shown in the output widget
    with output:
        output.clear_output()
        save_clip_embedding()

# Run on_clip_save_click on every button click, then render the save button
# and its output area as this cell's output.
clip_save_button.on_click(on_clip_save_click)
display(clip_save_button, clip_save_output)