"INSTALLING DEPENDENCIES"
# %pip install

In [1]:
"IMPORTING LIBRARIES"
import pandas as pd
import torch
import torch.nn as nn
import json
from diffusers import StableDiffusionPipeline
from transformers import GPT2LMHeadModel, GPT2Tokenizer, CLIPProcessor, CLIPModel
import numpy as np
import os
import re
from tqdm import tqdm
import time
import gradio as gr
from accelerate import Accelerator
from pympler import asizeof as size
import sys
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
"INITIALIZING THE DATASET"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'device used is {device}')
try:
    os.environ["TRANSFORMERS_OFFLINE"] = "1"
    print("Downloading models...")
    for step in tqdm(["processor", "model"], desc="Downloading"):
        if step == "processor":
            if os.path.exists('./clip_processor'):
                print("Processor already exists, skipping download.")
                pass
            else:
                clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
                clip_processor.save_pretrained('./clip_processor')
        else:
            if os.path.exists('./clip_model'):
                print("Model already exists, skipping download.")
                pass
            else:
                clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
                clip_model.save_pretrained('./clip_model')                
        time.sleep(0.1)

    for step in tqdm(["tokenizer", "model"], desc="Downloading"):
        if step == "tokenizer":
            if os.path.exists('./gpt2_tokenizer'):
                print("gpt2 tokenizer already saved, skipping download.")
                pass
            else:
                gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
                gpt2_tokenizer.save_pretrained('./gpt2_tokenizer')
        else:
            if os.path.exists('./gpt2_model'):
                print(" gpt2 Model already saved, skipping download.")
                pass
            else:
                gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
                gpt2_model.save_pretrained('./gpt2_model')
        time.sleep(0.1)

    for step in tqdm(["diffuse"], desc="Downloading"):
        if step == "diffuse":
            if os.path.exists('./stable_diffuse'):
                print("stable diffuse already saved, skipping download.")
                pass
            else:
                stable_diffuse = StableDiffusionPipeline.from_pretrained(
                    "runwayml/stable-diffusion-v1-5",
                    torch_dtype=torch.float16
                ).to(device)
                stable_diffuse.save_pretrained('./stable_diffuse')
        time.sleep(0.1)
except Exception as e:
    print(f"An error occurred: {e}")
    print("Please check your internet connection or the model names.")
BASE_DIR = './data/'
sub_dirs = [d for d in os.listdir(BASE_DIR) if os.path.isdir(os.path.join(BASE_DIR, d))] 
TEXT_PATH = [os.path.join(BASE_DIR, sub_dir, 'text.json')
             for sub_dir in sub_dirs if os.path.isdir(os.path.join(BASE_DIR, sub_dir))]
IMG_PATH = [os.path.join(BASE_DIR, sub_dir, 'img', 'meta.json')
            for sub_dir in sub_dirs if os.path.isdir(os.path.join(BASE_DIR, sub_dir))]

device used is cpu
Downloading models...


Downloading:  50%|█████     | 1/2 [00:00<00:00,  9.80it/s]

Processor already exists, skipping download.
Model already exists, skipping download.


Downloading: 100%|██████████| 2/2 [00:00<00:00,  9.39it/s]
Downloading:  50%|█████     | 1/2 [00:00<00:00,  9.23it/s]

gpt2 tokenizer already saved, skipping download.
 gpt2 Model already saved, skipping download.


Downloading: 100%|██████████| 2/2 [00:00<00:00,  9.16it/s]
Downloading: 100%|██████████| 1/1 [00:00<00:00,  8.88it/s]


stable diffuse already saved, skipping download.


In [3]:
'READING THE DATA AND CREATING THE MMD'
def mm_data(txt_paths=TEXT_PATH, img_paths=IMG_PATH):
    """Create a multimodal dataset (MMD) from text and image files."""
    for i, (txt_path, img_path) in enumerate(zip(txt_paths, img_paths)):
        print(f"n--- Processing pair {i+1} ---")
        print(f"Text file: {txt_path}")
        print(f"Image file: {img_path}")
        
        if not txt_path.endswith(".json") or not img_path.endswith(".json"):
            continue
         
        combined_entries = {}
        
        # Process text file
        if txt_path:
            try:
                with open(txt_path, 'r', encoding='utf-8') as f:
                    text_data = json.load(f)
                    print(f"✓ Text data loaded: {len(text_data) if hasattr(text_data, '__len__') else 'N/A'} items")
                    combined_entries['text_data'] = text_data
                    print(f'size of text data: {size.asizeof(combined_entries['text_data'])/ (1024):.2f}kb')
            except Exception as e:
                print(f"✗ Error reading text file: {e}")

        # Process image file
        if img_path:
            try:
                with open(img_path, 'r', encoding='utf-8') as f:
                    img_data = json.load(f)
                    print(f"✓ Image data loaded: {len(img_data) if hasattr(img_data, '__len__') else 'N/A'} items")
                    combined_entries['image_data'] = img_data
                    print(f'size of image data: {size.asizeof(combined_entries['image_data'])/ (1024):.2f}kb')
            except Exception as e:
                print(f"✗ Error reading image file: {e}")

        if combined_entries:
            yield combined_entries
        else:
            print(f"✗ No data found for pair {i+1}, skipping.")

In [4]:
'CLEANING THE DATA'
def clean_data(data):
    """Comprehensive text cleaning function"""
    
    # Remove HTML tags
    data = re.sub(r'<[^>]+>', '', data)
    
    # Remove URLs
    data = re.sub(r'httpS+|wwwS+|httpsS+', '', data, flags=re.MULTILINE)

    # Remove excessive punctuation but keep basic ones
    data = re.sub(r'[^a-zA-Z0-9s.,;:!?-]', '', data)
    
    # Remove email addresses
    data = re.sub(r'S+@S+', '', data)
    
    # Convert to lowercase
    data = data.lower()
    
    # Remove extra whitespace
    data = re.sub(r's+', ' ', data).strip()
    
    return data

In [5]:
def infer_feature_dim(df):
    """Infers the dimension of the image features from the first valid entry."""
    if df.empty:
        print("Warning: DataFrame is empty. using default of 512.")
        return 512
    for _, entry in df.iterrows():
        if 'image_data' in entry and pd.notna(entry['image_data']):
            image_data = entry['image_data']
            if isinstance(image_data, dict) and 'img_meta' in image_data:
                for img_item in image_data['img_meta']:
                    features = img_item.get('features', [])
                    if features and len(features) > 0:
                        print(f'feature_dim generated, length is {len(features)}')
                        # Return the length of the first valid feature list
                        return len(features)
    raise ValueError("Could not infer feature dimension using default of 512.")

In [6]:
'DATA TOKENIZATION AND EMMBEDDING USING TRANSFORMERS'
class TFMultimodalProcessor:
    """A simple multimodal processor that extracts text and image features."""
    def __init__(self, result_df, feature_dim):
        self.df = result_df
        if os.path.isdir("./clip_processor") and os.path.isdir("./clip_model"):
            self.clip_processor = CLIPProcessor.from_pretrained("./clip_processor", local_files_only=True)
            self.clip_model = CLIPModel.from_pretrained("./clip_model", local_files_only=True)
        else:
            self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.feature_dim = feature_dim
        self.text_embedding_dim = self.clip_model.config.text_config.hidden_size
    
    def extract_text(self, entry):
        """Extract all text from both text_data and image_data"""
        all_text = []
        
        # From text_data
        if 'text_data' in entry and pd.notna(entry['text_data']):
            text_data = entry['text_data']
            if isinstance(text_data, str):
                cleaned_text = clean_data(text_data)
                all_text.append(cleaned_text)
            elif isinstance(text_data, dict):
                # Extract common text fields
                for key in ['title', 'content', 'description', 'text', 'summary', 'wikitext']:
                    if key in text_data and text_data[key]:
                        cleaned_text = clean_data(str(text_data[key]))[:1500]
                        all_text.append(cleaned_text)
        
        # From image_data (descriptions, captions, etc.)
        if 'image_data' in entry and pd.notna(entry['image_data']):
            image_data = entry['image_data']
            if isinstance(image_data, dict) and 'img_meta' in image_data:
                for img_item in image_data['img_meta']:
                    for key in ['description', 'caption', 'title', 'parsed_title']:
                        if key in img_item and img_item[key]:
                            value = img_item[key]
                            if isinstance(value, list):
                                cleaned_text = [clean_data(str(v)) for v in value if v]
                                all_text.extend(cleaned_text)
                            else:
                                cleaned_text = clean_data(str(value))
                                all_text.append(cleaned_text)
        
        return " ".join(all_text)
    
    def process_and_embed(self, df, BATCH_SIZE=64):
        """Process all entry to create unified embedding"""        
        for i in range(0, len(df), BATCH_SIZE):
            # GET BATCH DATA
            print(f"--- Processing batch {i // BATCH_SIZE + 1} ---")
            batch_df = df.iloc[i: i + BATCH_SIZE]

            extracted_text = [self.extract_text(entry) for _, entry in batch_df.iterrows()]
            text_embedding = None
            if extracted_text:
                inputs = self.clip_processor(text=extracted_text, return_tensors="pt", padding=True, truncation=True)
                with torch.no_grad():
                    text_embedding = self.clip_model.get_text_features(**inputs).cpu()

            # Get image features if available
            image_features = []
            for _, entry in batch_df.iterrows():
                image_data = entry.get('image_data', {})
                features = []
                if isinstance(image_data, dict) and 'img_meta' in image_data: 
                    for img_item in image_data['img_meta']:
                        features.extend(img_item.get('features', []))

                # tokenize the features
                if features:
                    float_features = [float(f) for f in features]
                    image_features.append(torch.tensor(float_features, dtype=torch.float32))
                else:
                    image_features.append(torch.zeros(self.feature_dim, dtype=torch.float32))

            if not image_features:
                image_features = [torch.zeros(self.feature_dim, dtype=torch.float32) for _ in range(len(batch_df))]  
            max_dim = max([f.shape[0] for f in image_features]) if image_features else self.feature_dim
            max_dim = max(max_dim, self.feature_dim)
            # Pad all image features to the max dimension of the batch
            padded_image_features = []
            for f in image_features:
                if f.shape[0] < max_dim:
                    padded = torch.cat([f, torch.zeros(max_dim - f.shape[0], dtype=torch.float32)])
                else:
                    padded = f[:max_dim]
                padded_image_features.append(padded)
            image_embedding = torch.stack(padded_image_features).cpu()

            # 3. Concatenate and yield the combined embedding
            # The embeddings are concatenated as-is. This is the correct multimodal approach.
            # The `MultimodalBot` we designed will have an `input_projection` layer that learns
            # how to handle this combined, heterogeneous embedding.
            combined_embeddings = torch.cat((text_embedding, image_embedding), dim=1).numpy()

            # Ensure shapes match before yielding
            if combined_embeddings.shape[0] == len(batch_df):
                yield batch_df, combined_embeddings
            else:
                print("Warning: Skipping this batch due to mismatched embedding and dataframe sizes.")

            del text_embedding, image_embedding, combined_embeddings
            if torch.cuda.is_available():
                torch.cuda.empty_cache()


In [7]:
# 'MULTIMODAL GENERATION PIPELINE'
class MultimodalGenerator:
    def __init__(self, bot_model, processor, device=device):
        # The bot_model is the new, trainable multimodal model
        self.bot_model = bot_model
        self.processor = processor
        self.device = device
        self.accelerator = Accelerator()
        print(f"Using device for clip: {self.accelerator.device}")
        
        # We get the LLM tokenizer from the bot_model, which is our new bot.
        self.llm_tokenizer = self.bot_model.tokenizer

        # Load Stable Diffusion
        if os.path.isdir("./stable_diffuse"):
            self.pipe = StableDiffusionPipeline.from_pretrained(
                "./stable_diffuse", local_files_only=True, torch_dtype=torch.float16,
                device_map="balanced"
            )
        else:
            self.pipe = StableDiffusionPipeline.from_pretrained(
                "runwayml/stable-diffusion-v1-5",
                torch_dtype=torch.float16,
                device_map="balanced"
            )
        self.pipe = self.accelerator.prepare(self.pipe)

        # Load CLIP for evaluation from the processor
        self.clip_processor = self.processor.clip_processor
        self.clip_model = self.processor.clip_model

        # Create output folder
        os.makedirs("generated_outputs", exist_ok=True)

    def create_combined_embedding(self, user_text, user_image=None):
        # 1. Get text embedding
        text_inputs = self.clip_processor(text=[user_text], return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            text_embedding = self.clip_model.get_text_features(**text_inputs).cpu()
        
        # 2. Get image embedding
        if user_image:
            image_inputs = self.clip_processor(images=user_image, return_tensors="pt", padding=True, truncation=True)
            with torch.no_grad():
                image_embedding = self.clip_model.get_image_features(**image_inputs).cpu()
        else:
            # If no image is provided, use a zero vector of the correct size
            clip_image_dim = self.clip_model.config.vision_config.hidden_size
            image_embedding = torch.zeros(1, clip_image_dim)
        
        # 3. Concatenate the embeddings
        combined_embedding = torch.cat((text_embedding, image_embedding), dim=1)
        return combined_embedding.to(self.device)
    
    def refine_prompt_multimodal(self, combined_embedding):
        with torch.no_grad():
            refined_prompts = self.bot_model(combined_embedding.to(self.accelerator.device))
        return refined_prompts[0]

    def generate_image(self, prompt):
        image = self.pipe(prompt, guidance_scale=7.5).images[0]
        return image

    def evaluate_clip_similarity(self, image, text):
        inputs = self.clip_processor(text=[text], images=image, return_tensors="pt", padding=True).to(self.accelerator.device)
        outputs = self.clip_model(**inputs)
        similarity = torch.cosine_similarity(outputs.image_embeds, outputs.text_embeds).item()
        return similarity

    def run(self, data, limit=5):
        processed_batches = self.processor.process_and_embed(data)
        limit_count = 0
        
        # Loop through the generator to get batch-level data and embeddings
        for batch_df, batch_embeddings in processed_batches:
            if limit_count >= limit:
                print("Limit reached, stopping.")
                return

            # Feed the embeddings into our new bot to get prompts
            refined_prompts = self.bot_model(torch.tensor(batch_embeddings).float().unsqueeze(1))
            
            # Now, iterate through the individual entries and use the generated prompts
            batch_index = 0
            for i, row in batch_df.iterrows():
                if limit_count >= limit:
                    break

                raw_text = self.processor.extract_text(row)
                if not raw_text.strip():
                    print(f"⚠️ Skipping empty text at index {limit_count}")
                    limit_count += 1
                    continue

                refined_prompt = refined_prompts[batch_index % len(refined_prompts)] # Get the correct prompt for the row
                image = self.generate_image(refined_prompt)
                similarity = self.evaluate_clip_similarity(image, raw_text)

                # Save image and log
                image_path = f"generated_outputs/image_{limit_count+1}.png"
                image.save(image_path)

                print(f"n📌 Entry {limit_count+1}")
                print(f"🔤 Original Text: {raw_text[:100]}...")
                print(f"🧠 Refined Prompt: {refined_prompt}")
                print(f"🎯 CLIP Similarity: {similarity:.4f}")
                print(f"🖼️ Saved to: {image_path}")
                
                limit_count += 1


In [8]:
'TRAINING THE MODEL'
class MultimodalBotSoftPrompt(nn.Module):
    """
    Multimodal bot that uses the combined embedding as a soft prompt.
    The combined embedding is prepended to the input token embeddings,
    allowing the LLM's attention mechanism to process it directly.
    """
    def __init__(self, embedding_dim):
        super().__init__()
        # Load the pre-trained LLM and tokenizer
        if os.path.isdir("./gpt2_tokenizer") and os.path.isdir("./gpt2_model"):
            self.tokenizer = GPT2Tokenizer.from_pretrained("./gpt2_tokenizer", local_files_only=True)
            self.llm = GPT2LMHeadModel.from_pretrained("./gpt2_model", local_files_only=True)
        else:
            self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            self.llm = GPT2LMHeadModel.from_pretrained("gpt2")
        
        # We need to project the combined embedding to the same dimension as the
        # LLM's token embeddings (GPT-2's hidden size is 768)
        self.input_projection = nn.Linear(embedding_dim, self.llm.config.hidden_size)

        # Set the tokenizer's padding token
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.llm.config.pad_token_id = self.llm.config.eos_token_id
        
        # Freeze the main LLM weights to focus training on the projection layer
        for param in self.llm.parameters():
            param.requires_grad = False
        
        # The projection layer is trainable
        self.input_projection.requires_grad = True

    def forward(self, combined_embedding):
        # Check if the input is 2D and unsqueeze it to make it 3D
        # Expected shape: (batch_size, 1, embedding_dim)
        if combined_embedding.dim() == 2:
            combined_embedding = combined_embedding.unsqueeze(1)
        # 1. Project the combined embedding to the LLM's embedding space
        # Shape: (batch_size, embedding_dim) -> (batch_size, 1, hidden_size)
        projected_embedding = self.input_projection(combined_embedding)
        
        # 2. Start with a simple "starter" prompt. This is not strictly necessary but
        # can help guide the generation.
        starter_prompt_text = "Generate a vivid prompt for an image based on the input features:"
        
        # 3. Tokenize the text prompt
        input_tokens = self.tokenizer(starter_prompt_text, return_tensors="pt")
        
        # 4. Get the original token embeddings from the LLM
        # This requires manually accessing the LLM's embedding layer
        text_embeddings = self.llm.get_input_embeddings()(input_tokens.input_ids.to(projected_embedding.device))
        
        # 5. Concatenate the projected embedding with the text embeddings
        # This is the "soft prompt" fusion!
        # The multimodal information is now at the very beginning of the sequence.
        batch_size = projected_embedding.size(0)
        if batch_size > 1:
            text_embeddings_expanded = text_embeddings.expand(batch_size, -1, -1)
        else:
            text_embeddings_expanded = text_embeddings
        fused_input_embeddings = torch.cat([projected_embedding, text_embeddings_expanded], dim=1)

        # 6. Create an attention mask for the fused input
        # The mask should be all 1s since all tokens are relevant
        # The length is the combined length of the soft prompt and the starter prompt
        attention_mask = torch.ones(
        projected_embedding.size(0), 
        fused_input_embeddings.size(1), 
        dtype=torch.long,
        device=projected_embedding.device
        )

        # 7. Generate text using the fused embeddings
        # The `inputs_embeds` parameter allows us to bypass the standard token lookup
        output = self.llm.generate(
            inputs_embeds=fused_input_embeddings,
            attention_mask=attention_mask,
            max_new_tokens=50,
            do_sample=True,
            top_k=50,
            top_p=0.9,
            temperature=0.8,
            pad_token_id=self.llm.config.pad_token_id,
            eos_token_id=self.llm.config.eos_token_id
        )

        # Decode and return the generated prompts
        refined_prompts = [self.tokenizer.decode(ids, skip_special_tokens=True) for ids in output]
        return refined_prompts

In [None]:
# MEMORY-EFFICIENT SOLUTION: Create sample dataset instead of loading all data
# IMPROVED DATASET CREATION WITH SMART SAMPLING
# Add this function to replace your current create_limited_dataframe

def create_smart_limited_dataframe(max_files=5, max_items_per_file=100, sampling_strategy='diverse'):
    """Create limited DataFrame with smart sampling to preserve feature diversity"""
    import random
    sample_data = []
    file_count = 0
    
    print(f"Creating smart limited dataset: {max_files} files, {max_items_per_file} items each")
    print(f"Sampling strategy: {sampling_strategy}")
    
    for data_entry in mm_data(TEXT_PATH, IMG_PATH):
        if file_count >= max_files:
            print(f"Reached limit of {max_files} files, stopping...")
            break
            
        limited_entry = {}
        
        # Handle text data
        if 'text_data' in data_entry:
            text_data = data_entry['text_data']
            if isinstance(text_data, list) and len(text_data) > max_items_per_file:
                if sampling_strategy == 'diverse':
                    # Sample evenly across the dataset
                    step = len(text_data) // max_items_per_file
                    sampled_text = [text_data[i] for i in range(0, len(text_data), step)][:max_items_per_file]
                    limited_entry['text_data'] = sampled_text
                    print(f"  → Text: sampled {len(sampled_text)} from {len(text_data)} (diverse)")
                else:
                    limited_entry['text_data'] = text_data[:max_items_per_file]
            else:
                limited_entry['text_data'] = text_data
        
        # Handle image data with SMART SAMPLING - KEY IMPROVEMENT!
        if 'image_data' in data_entry:
            img_data = data_entry['image_data']
            if isinstance(img_data, dict) and 'img_meta' in img_data:
                img_meta = img_data['img_meta']
                
                if len(img_meta) > max_items_per_file:
                    if sampling_strategy == 'diverse':
                        # DIVERSE SAMPLING: Get features from across the entire dataset
                        step = len(img_meta) // max_items_per_file
                        systematic_indices = list(range(0, len(img_meta), step))[:max_items_per_file * 3 // 4]
                        
                        # Add random samples for extra diversity
                        remaining_count = max_items_per_file - len(systematic_indices)
                        if remaining_count > 0:
                            remaining_indices = [i for i in range(len(img_meta)) if i not in systematic_indices]
                            if remaining_indices:
                                random_indices = random.sample(remaining_indices, 
                                                             min(remaining_count, len(remaining_indices)))
                                all_indices = systematic_indices + random_indices
                            else:
                                all_indices = systematic_indices
                        else:
                            all_indices = systematic_indices
                        
                        # Sort and sample
                        final_indices = sorted(list(set(all_indices)))[:max_items_per_file]
                        sampled_img_meta = [img_meta[i] for i in final_indices]
                        
                        limited_entry['image_data'] = {'img_meta': sampled_img_meta}
                        print(f"  → Images: sampled {len(sampled_img_meta)} from {len(img_meta)} (DIVERSE - preserves visual diversity!)")
                    else:
                        # Simple truncation (your current method)
                        limited_entry['image_data'] = {'img_meta': img_meta[:max_items_per_file]}
                        print(f"  → Images: truncated to {max_items_per_file} from {len(img_meta)} (WARNING: may lose diversity)")
                else:
                    limited_entry['image_data'] = {'img_meta': img_meta}
                    print(f"  → Images: kept all {len(img_meta)} features")
            else:
                limited_entry['image_data'] = img_data
        
        sample_data.append(limited_entry)
        file_count += 1
        print(f"✓ Added file {file_count}/{max_files}")
    
    return pd.DataFrame(sample_data)

# REPLACE YOUR CURRENT DATASET CREATION WITH THIS:
print("=== CREATING DATASET WITH SMART SAMPLING ===")
combined_dataframe = create_smart_limited_dataframe(
    max_files=3, 
    max_items_per_file=50, 
    sampling_strategy='diverse'  # This preserves image feature diversity!
)

print(f"\n=== SMART DATASET CREATED ===")
print(f"DataFrame shape: {combined_dataframe.shape}")
print(f"Memory usage: {combined_dataframe.memory_usage(deep=True).sum() / (1024 * 1024):.2f} MB")

# PROGRESSIVE SCALING STRATEGY
# Start small for testing, then scale up for training

def get_recommended_dataset_size(purpose='development'):
    """Get recommended dataset parameters based on purpose"""
    configs = {
        'development': {'files': 3, 'items': 50, 'total': 150},
        'testing': {'files': 10, 'items': 100, 'total': 1000},
        'training_small': {'files': 25, 'items': 200, 'total': 5000},
        'training_medium': {'files': 50, 'items': 400, 'total': 20000},
        'training_large': {'files': 100, 'items': 500, 'total': 50000},
        'production': {'files': 200, 'items': 500, 'total': 100000}
    }
    return configs.get(purpose, configs['development'])

def create_scalable_dataset(purpose='development'):
    """Create dataset based on purpose with memory monitoring"""
    config = get_recommended_dataset_size(purpose)
    
    print(f"\n=== CREATING {purpose.upper()} DATASET ===")
    print(f"Target: {config['files']} files × {config['items']} items = ~{config['total']} samples")
    
    try:
        df = create_smart_limited_dataframe(
            max_files=config['files'], 
            max_items_per_file=config['items'],
            sampling_strategy='diverse'
        )
        memory_in_mb = df.memory_usage(deep=True).sum() / 1024 / 1024
        print(f"\n✅ SUCCESS: {purpose} dataset created")
        print(f"Shape: {df.shape}")
        print(f"Memory: {memory_in_mb:.2f} MB")

        if memory_in_mb > 1000:  # > 1GB
            print("⚠️  WARNING: High memory usage detected")
        
        return df
    except MemoryError:
        print(f"❌ MEMORY ERROR: {purpose} dataset too large")
        print("Falling back to smaller dataset...")
        return create_scalable_dataset('development')

# STEP 1: Start with development dataset for initial testing
print("STEP 1: Creating development dataset for initial testing...")
combined_dataframe = create_scalable_dataset('development')

# STEP 2: Function to upgrade dataset size when ready
def upgrade_dataset(current_df, target_purpose='testing'):
    """Upgrade to larger dataset when ready"""
    print(f"\n🔄 UPGRADING DATASET TO: {target_purpose}")
    return create_scalable_dataset(target_purpose)

print(f"\n📊 CURRENT DATASET INFO:")
print(f"Purpose: Development (for initial testing)")
print(f"Samples: ~{combined_dataframe.shape[0] * 50} (estimated)")
print(f"\n🚀 TO SCALE UP FOR TRAINING:")
print(f"   combined_dataframe = upgrade_dataset(combined_dataframe, 'training_small')")
print(f"   # This will give you ~5,000 samples for basic training")

=== CREATING DATASET WITH SMART SAMPLING ===
Creating smart limited dataset: 3 files, 50 items each
Sampling strategy: diverse
n--- Processing pair 1 ---
Text file: ./data/...And_Justice_for_All_(album)\text.json
Image file: ./data/...And_Justice_for_All_(album)\img\meta.json
✓ Text data loaded: 5 items
size of text data: 667.54kb
✓ Image data loaded: 1 items
size of image data: 130.54kb
  → Images: kept all 1 features
✓ Added file 1/3
n--- Processing pair 2 ---
Text file: ./data/0.999\text.json
Image file: ./data/0.999\img\meta.json
✓ Text data loaded: 420314 items
size of text data: 410.51kb
✓ Image data loaded: 169008 items
size of image data: 165.09kb
✓ Added file 2/3
n--- Processing pair 3 ---
Text file: ./data/1080┬░_Snowboarding\text.json
Image file: ./data/1080┬░_Snowboarding\img\meta.json
✓ Text data loaded: 145973 items
size of text data: 142.59kb
✓ Image data loaded: 16 items
size of image data: 0.06kb
✓ Added file 3/3
n--- Processing pair 4 ---
Text file: ./data/1257_Samala

: 

In [None]:
# 'DATA TOKENIZATION AND EMMBEDDING USING TRANSFORMERS' and 'MultimodalBotSoftPrompt' classes would be placed here.

# 0. Initialize feature_dim
feature_dim = infer_feature_dim(combined_dataframe)

# 1. Initialize the data processor
tf_processor = TFMultimodalProcessor(combined_dataframe, feature_dim=feature_dim)

# 2. Initialize the multimodal bot (this is the new, trainable model)
embedding_dim = tf_processor.text_embedding_dim + tf_processor.feature_dim
multimodal_bot = MultimodalBotSoftPrompt(embedding_dim=embedding_dim)  

# 3. Create the generator with the trained bot
generator = MultimodalGenerator(bot_model=multimodal_bot, processor=tf_processor)
  
# 4. Run the pipeline with the generator
generator.run(combined_dataframe, limit=5)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


feature_dim generated, length is 2048
Using device for clip: cpu


Loading pipeline components...:  57%|█████▋    | 4/7 [00:01<00:01,  1.56it/s]

In [None]:

# 4. Set up the optimizer to train the bot
optimizer = torch.optim.AdamW(multimodal_bot.parameters(), lr=1e-5)
num_epochs = 3
BATCH_SIZE = 32
device = "cuda" if torch.cuda.is_available() else "cpu"
multimodal_bot.to(device)

for epoch in range(num_epochs):
    print(f"--- Epoch {epoch + 1} of True Multimodal Training ---")
    epoch_loss = 0.0
    batch_count = 0
    try:
       # The generator yields batches of data and embeddings
        data_stream = tf_processor.process_and_embed(combined_dataframe, BATCH_SIZE=BATCH_SIZE)
        for i, (batch_of_entries, batch_of_embeddings) in enumerate(data_stream):
            optimizer.zero_grad()  # Clear gradients at the start
            # Move embeddings to the correct device
            combined_embeddings_tensor = torch.from_numpy(batch_of_embeddings).float().to(device)
            # The bot generates refined prompts from the embeddings
            refined_prompts = multimodal_bot(combined_embeddings_tensor)
   
            batch_losses = []
            batch_idx = 0

            for j, row in batch_of_entries.iterrows():
                try:
                    raw_text = tf_processor.extract_text(row)
                    if not raw_text.strip():
                        continue
                        
                    refined_prompt = refined_prompts[batch_idx % len(refined_prompts)]
                    batch_idx += 1

                    # Generate image and evaluate similarity using the generator's methods
                    image = generator.generate_image(refined_prompt)
                    similarity_score = generator.evaluate_clip_similarity(image, raw_text)
                    
                    # Calculate loss based on CLIP similarity (maximize similarity)
                    loss = -torch.tensor(similarity_score, device=device, requires_grad=True)
                    batch_losses.append(loss)
                    
                except Exception as e:
                    print(f"Error processing entry {j}: {e}")
                    continue
            
            if batch_losses:
                # Average the losses and backpropagate
                batch_loss = torch.stack(batch_losses).mean()
                batch_loss.backward()
                
                # Gradient clipping for stability
                torch.nn.utils.clip_grad_norm_(multimodal_bot.parameters(), max_norm=1.0)
                
                optimizer.step()
                
                epoch_loss += batch_loss.item()
                batch_count += 1
                
                print(f"Batch {i+1} Loss: {batch_loss.item():.4f}")
                
                # Memory cleanup
                del batch_loss, batch_losses
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
    
    except Exception as e:
        print(f"Error in epoch {epoch + 1}: {e}")
        continue
    
    avg_epoch_loss = epoch_loss / max(batch_count, 1)
    print(f"Epoch {epoch + 1} Average Loss: {avg_epoch_loss:.4f}")
    print("True multimodal bot training complete!")
    # Save the new bot model's state dictionary
torch.save(multimodal_bot.state_dict(), "multimodal_bot.pth")

In [None]:
# This function pads the embedding history to a fixed length
def pad_embeddings(embeddings_list, max_history_length=10):
    """Pads a list of embeddings to a fixed sequence length."""
    # Stack the embeddings to create a single tensor
    stacked_embeddings = torch.stack(embeddings_list, dim=1).squeeze(0)
    
    # Get the current sequence length and embedding dimension
    current_length, embedding_dim = stacked_embeddings.shape
    
    if current_length >= max_history_length:
        # Truncate if the history is too long
        return stacked_embeddings[:, -max_history_length:]
    else:
        # Pad with zeros if the history is too short
        padding_tensor = torch.zeros(max_history_length - current_length, embedding_dim)
        padded = torch.cat([stacked_embeddings, padding_tensor], dim=0)
        return padded.unsqueeze(0) # Reshape for batch size of 1

# A global variable to store the conversation history (embeddings)
conversation_history_embeddings = []

def multimodal_app(user_text, uploaded_image=None):
    global conversation_history_embeddings

    if not user_text.strip():
        return None, "⚠️ Please enter a valid prompt.", ""
    
    # Add a reset command
    if user_text.lower() == "/reset":
        conversation_history_embeddings.clear()
        return None, "Conversation history has been reset.", ""

    # 1. Create a new combined embedding for the current turn
    # This requires the `create_combined_embedding_single_turn` method in your generator
    new_embedding = generator.create_combined_embedding(user_text, uploaded_image)
    
    # 2. Append the new embedding to the conversation history
    conversation_history_embeddings.append(new_embedding)
    
    # 3. Pad the history to a fixed sequence length
    padded_embeddings = pad_embeddings(conversation_history_embeddings)
    
    # 4. Feed the entire history to the conversational bot
    # This requires the `refine_prompt_conversational` method in your generator
    refined_prompt = generator.refine_prompt_multimodal(padded_embeddings)
    
    # 5. Generate and evaluate the image
    generated_image = generator.generate_image(refined_prompt)
    similarity_generated = generator.evaluate_clip_similarity(generated_image, user_text)
    
    # Evaluate uploaded image if available
    similarity_uploaded = None
    if uploaded_image is not None:
        similarity_uploaded = generator.evaluate_clip_similarity(uploaded_image, user_text)

    # Prepare the output message
    similarity_msg = f"🧠 Generated Image Similarity: {similarity_generated:.3f}"
    if similarity_uploaded is not None:
        similarity_msg += f"n📷 Uploaded Image Similarity: {similarity_uploaded:.3f}"
        
    return generated_image, similarity_msg, refined_prompt

In [None]:
'RUNNING WITH GRADIO'
demo = gr.Interface(
    fn=multimodal_app,
    inputs=[
        gr.Textbox(lines=3, label="Enter a prompt"),
        gr.Image(type="pil", label="Upload an image (optional)")
    ],
    outputs=[
        gr.Image(type="pil", label="Generated Image"),
        gr.Textbox(label="Similarity Scores"),
        gr.Textbox(label="Refined Prompt")
    ],
    title="🖼️ Multimodal Generator with CLIP Evaluation",
    description="Enter a prompt to generate an image and compare it with an uploaded image using CLIP similarity."
)

demo.launch()