## Exploring the Farfetch catalog

In particular, tagging and describing the product images with local vision models.

## Package imports

In [None]:
from tqdm.notebook import tqdm
import itertools


In [None]:
import numpy as np
import pandas as pd

import ollama

import matplotlib.pyplot as plt
import plotly_express as px

In [None]:
import re
import json
import time
import base64
import io

from pathlib import Path
from PIL import Image

In [None]:
from typing import Dict, List

In [None]:
# IPython autoreload extension, mode 2: re-import changed modules before each cell runs.
%load_ext autoreload
%autoreload 2

## Farfetch dataset

In [None]:
# Get images. Path.glob yields files in arbitrary filesystem order, so sort them:
# "the first image" and the batch processing order are then reproducible.
image_dir = Path('../data/farfetch/images')
image_paths = image_dir.glob('*.jpg')
image_list = sorted(image_paths)
if image_list:
    print(f"Found {len(image_list)} images, ", image_list[0])
else:
    # Avoid an IndexError on image_list[0] when the folder is empty or missing.
    print(f"Found no images in {image_dir.resolve()}")

In [None]:
# Product catalogue metadata (one record per product) shipped with the image dump.
catalog_file = "../data/farfetch/farfetch_fashion_dataset_images_crawlfeeds.json"
df_farfetch = pd.read_json(catalog_file)
df_farfetch.head()

In [None]:
import matplotlib.image as mpimg

# Preview: the first (up to) 25 catalogue images on a 5x5 grid, titled by file name.
n_preview = min(len(image_list), 25)
fig, axes = plt.subplots(5, 5, figsize=(15, 15))
axes = axes.flatten()  # 1-D view so every subplot has a single index

for i, ax in enumerate(axes):
    if i >= n_preview:
        # Fewer images than grid cells: leave the remaining cells blank.
        ax.axis('off')
        continue
    picture_path = image_list[i]
    ax.imshow(mpimg.imread(str(picture_path)))
    ax.set_title(f"{picture_path.name}", fontsize=8)
    ax.axis('off')

plt.tight_layout()
plt.show()

## Testing Ollama for image description

In [None]:
# Connect to the local Ollama server and print the names of the models it reports.
client = ollama.Client()
installed = client.list()['models']
print("Available models:", [entry['model'] for entry in installed])

In [None]:


def encode_image_to_base64(image_path):
    """Return the file at *image_path* base64-encoded as a UTF-8 string.

    The bytes are sent untouched (no decoding or resizing), in the text form
    expected for an image attached to an Ollama message.
    """
    with open(image_path, "rb") as handle:
        payload = handle.read()
    encoded = base64.b64encode(payload)
    return encoded.decode("utf-8")
    
def resize_and_encode_image(image_path, max_width=256):
    """Downscale an image to at most *max_width* pixels wide and return it as base64 JPEG.

    The aspect ratio is preserved; images that are already narrow enough are
    re-encoded without resizing (never upscaled).

    Args:
        image_path: path to an image file readable by PIL.
        max_width: maximum output width in pixels.

    Returns:
        The JPEG-encoded image as a UTF-8 base64 string.
    """
    # The context manager closes the file handle PIL keeps open for lazy loading.
    with Image.open(image_path) as image:
        # JPEG only stores L/RGB/CMYK; saving RGBA, LA or palette ("P") images
        # (e.g. many PNGs) as JPEG raises OSError, so convert those to RGB.
        if image.mode not in ("RGB", "L", "CMYK"):
            image = image.convert("RGB")

        # Resize if needed
        if image.width > max_width:
            ratio = max_width / image.width
            # max(1, ...) keeps very wide, short images from collapsing to 0 px.
            new_height = max(1, int(image.height * ratio))
            image = image.resize((max_width, new_height), Image.Resampling.LANCZOS)

        # Convert to base64
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")

    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def test_model_performance(
    model_name: str, encoded_image: str, system_prompt: str, user_prompt: str
) -> Dict:
    """Query one Ollama model with an image and report the outcome and latency.

    Args:
        model_name: Ollama model tag, e.g. ``"llava:7b"``.
        encoded_image: base64-encoded image attached to the user message.
        system_prompt: instructions sent as the system message.
        user_prompt: question sent together with the image.

    Returns:
        Dict with ``model``, ``success``, ``duration`` (seconds, ``None`` on
        failure), ``response`` (text or ``None``), ``response_length`` and
        ``error`` (exception message or ``None``). Exceptions never escape.
    """
    try:
        # perf_counter is monotonic; time.time() follows the system clock,
        # which can be adjusted mid-call and distort the elapsed time.
        start_time = time.perf_counter()

        response = ollama.chat(
            model=model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt, "images": [encoded_image]},
            ],
        )

        duration = time.perf_counter() - start_time
        content = response.message.content

        return {
            "model": model_name,
            "success": True,
            "duration": duration,
            "response": content,
            "response_length": len(content),
            "error": None,
        }

    except Exception as e:
        # Best-effort benchmark: record the failure instead of aborting the comparison.
        return {
            "model": model_name,
            "success": False,
            "duration": None,
            "response": None,
            "response_length": 0,
            "error": str(e),
        }


def run_model_comparison(image_path: str, models: List[str]) -> List[Dict]:
    """Run the free-text tagging prompt on one image with every model in *models*.

    The image is resized and encoded once, then reused for all models. For each
    model, the model name, timing and a short response preview are printed.

    Args:
        image_path: path to the image, as ``str`` or ``pathlib.Path``.
        models: Ollama model names to query, in order.

    Returns:
        One result dict per model (as produced by ``test_model_performance``)
        with an extra ``"image"`` key holding the image's file name.
    """

    # Define prompts
    system_prompt = """You are a helpful AI assistant specialized in clothing and fashion analysis. 
    Follow instructions precisely, be concise and objective. 
    Only describe what is clearly visible in the image. 
    Provide structured responses exactly as requested."""

    user_prompt = """Analyze the clothing and fashion items in this image and provide:

    1. ONE-WORD TAGS for each category (use commas to separate multiple tags):
    - Colors: (e.g., red, blue, white)
    - Clothing types: (e.g., shirt, pants, jacket, dress)
    - Styles: (e.g., casual, formal, vintage, sporty)
    - Occasions: (e.g., party, work, outdoor, evening)
    - Materials: (e.g., cotton, leather, silk, denim)
    - Patterns: (e.g., striped, floral, solid, plaid)
    - Fits: (e.g., slim, loose, regular)
    - Seasons: (e.g., summer, winter, all-season)

    2. DESCRIPTIONS: One concise sentence per distinct clothing item, highlighting color, style, and key features.

    Only include tags you can clearly observe. Be precise and objective. Format your response clearly."""

    results = []

    print(f"Testing {len(models)} models on image: {image_path}")
    print("=" * 60)

    encoded_image = resize_and_encode_image(image_path)
    # Path(...) accepts both str and Path; a plain str has no `.name` attribute.
    image_name = Path(image_path).name

    for model_name in models:
        # Name the model so the timings below can be attributed to it.
        print(f"\n🔍 Testing: {model_name}")

        result = test_model_performance(
            model_name, encoded_image, system_prompt, user_prompt
        )
        result["image"] = image_name
        results.append(result)

        if result["success"]:
            print(f"⏱️  Time: {result['duration']:.2f}s")
            print(f"📝 Response length: {result['response_length']} characters")
            print(f"💬 Response preview: {result['response'][:100]}...")
        else:
            print(f"❌ Error: {result['error']}")

        print("-" * 40)

    return results


# Usage
def main(image_path):
    """Compare all candidate vision models on *image_path* and print a summary.

    The summary ranks successful models by response time, then lists their
    response lengths in that same fastest-first order.

    Returns:
        The per-model result dicts from ``run_model_comparison``, unsorted.
    """
    candidate_models = [
        "gemma3:4b",
        "granite3.2-vision:2b",
        "qwen2.5vl:3b",
        "llava:7b",
        "qwen2.5vl:7b",
        "llama3.2-vision:latest",
    ]

    results = run_model_comparison(image_path, candidate_models)

    banner = "=" * 60
    print("\n" + banner)
    print("📊 PERFORMANCE SUMMARY")
    print(banner)

    ranked = sorted(
        (entry for entry in results if entry["success"]),
        key=lambda entry: entry["duration"],
    )
    if not ranked:
        return results

    print("\n🏃 Speed Ranking (fastest first):")
    for position, entry in enumerate(ranked, start=1):
        print(f"{position}. {entry['model']}: {entry['duration']:.2f}s")

    print("\n📝 Response Length Comparison:")
    for entry in ranked:
        print(f"{entry['model']}: {entry['response_length']} chars")

    return results


In [None]:
# Smoke-test every candidate model on the first catalogue image.
first_image = image_list[0]
results = main(first_image)

#### Comparing models

In [None]:
# Load previously saved comparison results (they were written with
# all_results.to_csv("../results/model_comparison_results.csv", index=False)).
comparison_csv = "../results/model_comparison_results.csv"
all_results = pd.read_csv(comparison_csv)
all_results.head()

In [None]:
# Response time of every (model, image) call, one colour per model.
axis_labels = {
    "duration": "Response Time (seconds)",
    "index": "Test Case Index",
    "model": "Model",
}
hover_columns = ["image", "success", "response_length"]

fig = px.scatter(
    all_results,
    x=all_results.index,
    y="duration",
    color="model",
    title="Model Performance Over Test Cases",
    labels=axis_labels,
    hover_data=hover_columns,
    range_y=[0, 60],
)
fig.update_layout(height=500)
fig.show()


In [None]:
# Section headers the free-text prompt asks each model to fill in.
tags = [
    "Colors",
    "Clothing types",
    "Styles",
    "Occasions",
    "Materials",
    "Patterns",
    "Fits",
    "Seasons",
]

In [None]:
def extract_tag_values(response, tag_keys):
    """Parse ``"<Tag>: a, b, c"`` lines from a free-text model response.

    For each tag in *tag_keys*, the text after the first ``"<Tag>:"`` is cut at
    the next ``";"`` or newline and split on commas. Values are stripped of
    whitespace, markdown emphasis (``*``) and trailing periods, then
    lower-cased; empty values are dropped.

    Args:
        response: raw model reply. ``None``, NaN and other non-strings are
            treated as "no tags".
        tag_keys: tag headers to look for (matched case-sensitively).

    Returns:
        ``{tag: [values]}`` for every tag in *tag_keys*; missing tags map to ``[]``.
    """
    extracted = {tag: [] for tag in tag_keys}  # every tag present, default empty

    # Handle None, NaN (a float), or other non-string values
    if not isinstance(response, str) or not response:
        return extracted

    for tag in tag_keys:
        segments = response.split(tag + ":", 2)
        if len(segments) < 2:
            continue
        # Keep the text up to the next ";" or end of line after the header.
        value = segments[1].split(";")[0].split("\n", 1)[0]
        # Strip "**" left over from markdown headers like "**Colors:** red" and
        # trailing periods; drop empties so "Colors:" alone yields [] not [""].
        values = [v.strip().strip("*.").strip().lower() for v in value.split(",")]
        extracted[tag] = [v for v in values if v]
    return extracted


# Parse each stored response into {tag: [values]}; the parser copes with NaN
# responses from failed calls. Then expand the dicts into one column per tag.
all_tags = all_results['response'].map(lambda text: extract_tag_values(text, tags))
all_tags_df = pd.DataFrame(list(all_tags))


In [None]:
# Inspect the per-tag value lists parsed from each response.
all_tags_df

In [None]:

# Attach the parsed tag columns to the raw results. The index is reset so the
# frames line up positionally with all_tags_df's 0..n-1 index.
base = all_results.reset_index(drop=True)
all_results_with_tags = pd.concat([base, all_tags_df], axis=1)
all_results_with_tags

In [None]:
from IPython.display import display, HTML
import os
from PIL import Image as PILImage
import io
import base64
import pandas as pd

def pil_image_to_base64_str(img):
    """Serialise *img* as PNG and return the bytes base64-encoded as ``str``.

    The result is meant for a ``data:image/png;base64,...`` URI in inline HTML.
    """
    with io.BytesIO() as sink:
        img.save(sink, format="PNG")
        png_bytes = sink.getvalue()
    return base64.b64encode(png_bytes).decode()

def create_simple_comparison_table(df, image_name, tag_columns):
    """Build a tag-by-model table for one image.

    Rows are *tag_columns*; columns are the models that have a result row for
    *image_name*. A tag that is not a column of *df* is shown as ``"No data"``.
    If a model appears more than once, its last row wins.
    """
    rows_for_image = df.loc[df['image'] == image_name]
    per_model = {}
    for _, record in rows_for_image.iterrows():
        per_model[record['model']] = {
            tag: (record[tag] if tag in record else "No data")
            for tag in tag_columns
        }
    return pd.DataFrame(per_model)

def display_image_and_table_side_by_side(image_path, comparison_df, width=300):
    """Render a thumbnail of *image_path* next to *comparison_df* as one HTML block.

    Args:
        image_path: path to the image file. If it does not exist, a message is
            printed and only the table is displayed.
        comparison_df: table to show, e.g. from ``create_simple_comparison_table``.
        width: maximum thumbnail width and height in pixels.
    """
    import html  # only needed to escape text interpolated into the markup

    if not os.path.isfile(image_path):
        print(f"Image not found: {image_path}")
        display(comparison_df)
        return

    # Load and resize image; the context manager releases the file handle.
    with PILImage.open(image_path) as img:
        img.thumbnail((width, width), PILImage.Resampling.LANCZOS)
        img_data = pil_image_to_base64_str(img)

    # The table holds model-generated text, so escape it (and the file name):
    # a stray "<" or "&" must not break or inject markup.
    title = html.escape(os.path.basename(image_path))
    table_html = comparison_df.to_html(
        border=0, table_id='comparison_table', classes='table table-striped', escape=True
    )

    # Create side-by-side HTML layout
    html_block = f'''
    <div style="display: flex; align-items: flex-start; gap: 20px; margin-bottom: 40px; 
                border: 1px solid #ddd; padding: 15px; border-radius: 8px;">
      <div style="flex-shrink: 0;">
        <h4 style="margin-top: 0; color: #333;">{title}</h4>
        <img src="data:image/png;base64,{img_data}" alt="Clothing Image" 
             style="max-width: {width}px; height: auto; border-radius: 8px; 
                    border: 2px solid #eee; box-shadow: 0 2px 8px rgba(0,0,0,0.1);"/>
      </div>
      <div style="flex-grow: 1; overflow-x: auto;">
        <h4 style="margin-top: 0; color: #333;">Model Comparison</h4>
        {table_html}
      </div>
    </div>
    '''
    display(HTML(html_block))

tags = ["Colors", "Clothing types", "Styles", "Occasions", "Materials", "Patterns", "Fits", "Seasons"]
unique_images = all_results_with_tags['image'].unique()
image_root = "../data/farfetch/images"

# Image + model-by-tag table for (at most) the first 50 distinct images.
for img in unique_images[:50]:
    simple_table = create_simple_comparison_table(all_results_with_tags, img, tags)
    image_path = f"{image_root}/{img}"
    display_image_and_table_side_by_side(image_path, simple_table)


In [None]:
# Count every colour tag across all responses. The isinstance guard sits on the
# outer loop so non-list cells (e.g. NaN) are skipped *before* being iterated.
all_colors = [
    color
    for sublist in all_results_with_tags['Colors']
    if isinstance(sublist, list)
    for color in sublist
]
color_counts = pd.Series(all_colors).value_counts().reset_index()
color_counts.columns = ['color', 'count']

# Bar chart of colours mentioned more than 5 times, one bar colour per name.
fig = px.bar(
    color_counts[color_counts["count"] > 5],
    x='color',
    y='count',
    title='Frequency of Clothing Colors Across All Responses',
    labels={'color': 'Color', 'count': 'Frequency'},
    color='color',
)
fig.show()

In [None]:
# Count every clothing-type tag across all responses. The guard on the outer loop
# skips non-list cells (e.g. NaN) before they are iterated.
all_types = [
    t
    for sublist in all_results_with_tags['Clothing types']
    if isinstance(sublist, list)
    for t in sublist
]
type_counts = pd.Series(all_types).value_counts().reset_index()
type_counts.columns = ['type', 'count']

# Bar chart of types mentioned more than 5 times, one bar colour per type.
fig = px.bar(
    type_counts[type_counts["count"] > 5],
    x='type',
    y='count',
    title='Frequency of Clothing Types Across All Responses',
    labels={'type': 'Type', 'count': 'Frequency'},
    color='type',
)
fig.show()

In [None]:
# Count every season tag across all responses. The guard on the outer loop skips
# non-list cells (e.g. NaN) before they are iterated.
all_seasons = [
    t
    for sublist in all_results_with_tags['Seasons']
    if isinstance(sublist, list)
    for t in sublist
]
season_counts = pd.Series(all_seasons).value_counts().reset_index()
season_counts.columns = ['season', 'count']

# Bar chart of seasons mentioned more than 5 times, one bar colour per season.
fig = px.bar(
    season_counts[season_counts["count"] > 5],
    x='season',
    y='count',
    title='Frequency of Clothing Seasons Across All Responses',
    labels={'season': 'Season', 'count': 'Frequency'},
    color='season',
)
fig.show()

In [None]:
# Text dump: the colour tags each model produced, grouped per image.
unique_images = all_results_with_tags['image'].unique()
for img in unique_images:
    print(f"Image: {img}")
    same_image = all_results_with_tags['image'] == img
    pairs = all_results_with_tags.loc[same_image, ['model', 'Colors']]
    for model_name, colour_list in pairs.itertuples(index=False, name=None):
        print(f"  {model_name} -> {colour_list}")
    print("-" * 20)


In [None]:
# Leading list markers, possibly stacked (e.g. "- 1. "): bullet characters
# ("-", "*", "•", which also covers "**" bold openers) or a numbered-item marker
# ("1." / "2)") followed by whitespace. A bare leading number such as
# "3 pairs of socks" or "1.5 m" is content and is kept.
_LIST_MARKER_RE = re.compile(r'^(?:[-*•]+\s*|\d+[.)](?=\s|$)\s*)+')
# Any remaining markdown emphasis markers (*, **, ***).
_EMPHASIS_RE = re.compile(r'\*+')


def clean_descriptions(desc_text):
    """Flatten a bulleted/numbered description block into one plain paragraph.

    Lines are trimmed, stripped of list markers and ``*`` emphasis, and joined
    with single spaces; a final period is added if the text does not already
    end in ``.``, ``!`` or ``?``.

    Args:
        desc_text: raw description text. Non-strings (``None``, NaN) give ``""``.

    Returns:
        The cleaned single-paragraph string (``""`` if nothing remains).
    """
    if not isinstance(desc_text, str):
        return ""

    # Remove empty lines and split by newlines
    lines = [line.strip() for line in desc_text.splitlines() if line.strip()]

    # Remove bullet points, asterisks, dashes, and numbered-list markers
    lines = [_LIST_MARKER_RE.sub('', line) for line in lines]

    # Remove any remaining formatting markers like ** or ***
    lines = [_EMPHASIS_RE.sub('', line).strip() for line in lines]

    # Drop lines that were only markup (e.g. "---"), which would leave double spaces.
    lines = [line for line in lines if line]

    # Join all sentences into a single paragraph
    cleaned_text = ' '.join(lines)

    # Add period at the end if missing
    if cleaned_text and cleaned_text[-1] not in '.!?':
        cleaned_text += '.'

    return cleaned_text

In [None]:
# Preview description extraction on a few sample rows. Failed model calls leave
# a NaN response, and the sample indices may exceed the number of rows.
for i in [0, 50, 100, 150]:
    if i >= len(all_results):
        break  # indices are ascending, so every later one is out of range too
    response = all_results.iloc[i]["response"]
    if not isinstance(response, str):
        print(f"Row {i}: no response")
        print("---")
        continue
    parts = re.split(r"descriptions:", response, flags=re.IGNORECASE, maxsplit=1)
    # Text after the first "DESCRIPTIONS:" header, or the whole reply if absent.
    desc = parts[-1].strip()
    print(len(desc))
    if len(desc) > 1:
        print(desc)
        print(clean_descriptions(desc))
    print("---")

## More structured tags

### Run analysis on full catalogue

In [None]:
# Allowed colour tags for the "colours" category (multi-word names included);
# the catch-all "Unknown" is deliberately left out.
colours = [
    "Black",
    "White",
    "Off White",
    "Light Beige",
    "Beige",
    "Grey",
    "Light Blue",
    "Light Grey",
    "Dark Blue",
    "Dark Grey",
    "Pink",
    "Dark Red",
    "Greyish Beige",
    "Light Orange",
    "Silver",
    "Gold",
    "Light Pink",
    "Dark Pink",
    "Yellowish Brown",
    "Blue",
    "Light Turquoise",
    "Yellow",
    "Greenish Khaki",
    "Dark Yellow",
    "Other Pink",
    "Dark Purple",
    "Red",
    "Transparent",
    "Dark Green",
    "Other Red",
    "Turquoise",
    "Dark Orange",
    "Orange",
    "Dark Beige",
    "Other Yellow",
    "Light Green",
    "Other Orange",
    "Purple",
    "Light Red",
    "Light Yellow",
    "Green",
    "Light Purple",
    "Dark Turquoise",
    "Other Purple",
    "Bronze/Copper",
    "Other Turquoise",
    "Other Green",
    "Other Blue",
    # "Unknown",
]

In [None]:
# Allowed tags for the "genres" category (target department).
genres = ["Men", "Women"]

In [None]:
# Allowed garment-type tags for the "types" category; the catch-all "Other" is
# deliberately left out.
types = [
    "Trousers",
    "Dress",
    "Sweater",
    "T-shirt",
    "Top",
    "Blouse",
    "Jacket",
    "Shorts",
    "Shirt",
    "Vest top",
    "Underwear bottom",
    "Skirt",
    "Hoodie",
    "Bra",
    "Socks",
    "Leggings/Tights",
    "Shoes",
    "Cardigan",
    "Hat/beanie",
    "Pyjama set",
    "Blazer",
    "Scarf",
    "Swimsuit",
    "Coat",
    "Belt",
    "Polo shirt",
    "Gloves",
    "Tie",
    "Robe",
    # "Other",
]


In [None]:

# Allowed "styles" tags. Presumably curated from the distinct values in the
# free-text "Styles" answers, which the two commented lines below list; the
# catch-all 'unknown' is deliberately left out.
# flat_list = list(itertools.chain.from_iterable(all_results_with_tags["Styles"].to_list()))
# set(flat_list)
styles = {
    'artsy', 'athletic', 'beachwear', 'biker', 'bohemian', 'business',
    'casual', 'chic', 'classic', 'cocktail', 'edgy', 'elegant',
    'evening', 'fitted', 'flowy', 'formal', 'gothic', 'graphic',
    'grunge', 'knit', 'maxi', 'minimalist', 'modern', 'monogrammed',
    'oversized', 'preppy', 'punk', 'romantic', 'retro', 'sneaker',
    'sporty', 'streetwear', 'tropical', 'urban', 'vintage', 'workwear',
    # 'unknown'
}



In [None]:
# Allowed "occasions" tags. Presumably curated from the distinct free-text
# "Occasions" values, which the two commented lines below list; 'unknown' is
# deliberately left out.
# flat_list = list(itertools.chain.from_iterable(all_results_with_tags["Occasions"].to_list()))
# set(flat_list)

occasions = {
    'athletic', 'brunch', 'casual', 'cocktail', 'date', 'daytime',
    'dinner', 'evening', 'everyday', 'formal', 'funeral',
    'graduation', 'holiday', 'interview', 'loungewear', 'maternity',
    'networking', 'outdoor', 'party', 'religious', 'sleepwear',
    'travel', 'vacation', 'wedding', 'work',
    # 'unknown'
}

In [None]:
# Allowed "fits" tags. Presumably curated from the distinct free-text "Fits"
# values, which the two commented lines below list; 'unknown' is deliberately
# left out.
# flat_list = list(itertools.chain.from_iterable(all_results_with_tags["Fits"].to_list()))
# set(flat_list)

fits = {
    'athletic', 'boxy', 'compact', 'cropped', 'fitted', 'flowing',
    'form-fitting', 'oversized', 'regular', 'relaxed', 'skinny',
    'slim', 'straight', 'tailored', 'wide',
    # 'unknown'
}


In [None]:
# Allowed "seasons" tags: the four seasons only (no "all-season" tag).
seasons = {'spring', 'summer', 'fall', 'winter'}

In [None]:
# Allowed "materials" tags, curated from the distinct free-text "Materials"
# values (regenerate the raw list with the two commented lines below).
# 'plaid' is a pattern (it belongs to `patterns`), and 'unknown' contradicts the
# prompt's "if unsure, use empty list []" rule. Both are left out, consistent
# with the other categories, where 'unknown' is commented out.
# flat_list = list(itertools.chain.from_iterable(all_results_with_tags["Materials"].to_list()))
# set(flat_list)

materials = {
    'bamboo', 'canvas', 'cashmere', 'chiffon', 'corduroy', 'cotton',
    'denim', 'fleece', 'hemp', 'jersey', 'knit', 'leather', 'linen',
    'lycra', 'mesh', 'metal', 'modal', 'neoprene', 'nylon',
    'plastic', 'polyester', 'rayon', 'rubber', 'satin', 'silk',
    'spandex', 'suede', 'synthetic', 'tencel', 'twill', 'velvet',
    'viscose', 'wool',
    # 'unknown'
}

In [None]:
# Allowed "patterns" tags. Presumably curated from the distinct free-text
# "Patterns" values, which the two commented lines below list.
# flat_list = list(itertools.chain.from_iterable(all_results_with_tags["Patterns"].to_list()))
# set(flat_list)

patterns = {
    'abstract', 'argyle', 'checkered', 'floral', 'graphic',
    'logo', 'monogram', 'paisley', 'plaid', 'solid', 'striped'
}


In [None]:
# Tag vocabulary per category, keyed by the JSON field name the model must
# return. Insertion order fixes the order of the prompt sections and of the
# expected keys.
CLOTHING_CATEGORIES = dict(
    types=types,
    genres=genres,
    colours=colours,
    styles=styles,
    occasions=occasions,
    fits=fits,
    materials=materials,
    patterns=patterns,
    seasons=seasons,
)


In [None]:
def create_system_prompt(clothing_categories):
    """Build the system prompt for structured JSON tagging.

    Each category whose allowed values are a ``set`` or ``list`` is listed as
    ``NAME: ["a", "b", ...]`` (upper-cased name, values sorted). Other values
    are skipped. A fixed block of rules follows, including a ``description``
    field requirement and the list of category keys.
    """
    header = (
        "You are a clothing analysis AI. Return ONLY valid JSON, no other text.\n"
        "\n"
        "STRICT RULES: You MUST only use tags from these exact lists and include a description field as a string.\n"
        "\n"
    )

    sections = [
        "{}: [{}]\n\n".format(
            name.upper(), ", ".join('"{}"'.format(value) for value in sorted(values))
        )
        for name, values in clothing_categories.items()
        if isinstance(values, (set, list))
    ]

    footer = (
        "CRITICAL:\n"
        f"- Use ONLY the tags from the list: {list(clothing_categories.keys())}\n"
        "- Each category must be a list of strings\n"
        '- Add a "description" field with a short textual description of the clothing in the image\n'
        "- If unsure about a category, use empty list []\n"
        "- Return JSON only, no explanations\n"
        "\n"
        "STRICT: Do not use any keys other than the exact category names provided.\n"
        'Do not use combined or generic keys like "categories".\n'
        "Every category must be present, even if empty like [].\n"
    )

    return header + "".join(sections) + footer


def create_user_prompt():
    """Return the fixed user message for the structured JSON tagging task."""
    instruction = (
        'Analyze this clothing image. Return JSON with the 9 categories as lists of strings '
        'and a "description" field with a concise summary of the clothing. '
        'Use only allowed tags. JSON only, no other text.'
    )
    return instruction

In [None]:

def parse_json_simple(response_text: str) -> dict:
    """Parse *response_text* as JSON exactly as given (no cleanup or extraction).

    Returns:
        ``{'success': True, 'data': <parsed value>}`` on success, otherwise
        ``{'success': False, 'data': None}``. Failure covers malformed JSON
        and non-text input such as ``None``.
    """
    try:
        data = json.loads(response_text)
    except (ValueError, TypeError):
        # ValueError covers JSONDecodeError; TypeError covers None/non-text input.
        # A bare `except:` would also swallow KeyboardInterrupt/SystemExit.
        return {
            'success': False,
            'data': None
        }
    return {
        'success': True,
        'data': data
    }


In [None]:
def resize_and_encode_image(image_path, max_width=256):
    """Downscale an image to at most *max_width* pixels wide and return it as base64 JPEG.

    The aspect ratio is preserved, and narrower images are not upscaled.

    Args:
        image_path: path to an image file readable by PIL.
        max_width: maximum output width in pixels.

    Returns:
        The JPEG-encoded image as a UTF-8 base64 string.
    """
    # The context manager closes the file handle PIL keeps open for lazy loading.
    with Image.open(image_path) as image:
        # JPEG only stores L/RGB/CMYK; RGBA/LA/palette images would raise OSError.
        if image.mode not in ("RGB", "L", "CMYK"):
            image = image.convert("RGB")
        if image.width > max_width:
            ratio = max_width / image.width
            # max(1, ...) keeps very wide, short images from collapsing to 0 px.
            new_height = max(1, int(image.height * ratio))
            image = image.resize((max_width, new_height), Image.Resampling.LANCZOS)

        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')

In [None]:
def test_model(model_name: str, encoded_image: str, system_prompt: str, user_prompt: str) -> dict:
    """Send one image plus prompts to *model_name* via Ollama and time the call.

    Returns:
        On success: ``model``, ``duration`` (seconds), ``response`` (raw text),
        and ``data``/``json_success`` from parsing the raw text as JSON.
        On failure: the same keys, with ``duration``, ``response`` and ``data``
        set to ``None``, ``json_success`` False, and ``error`` holding the
        exception message. Every key is therefore present, so callers can read
        ``result["response"]`` without a KeyError.
    """
    try:
        start_time = time.perf_counter()  # monotonic clock for elapsed time

        response = ollama.chat(
            model=model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt, "images": [encoded_image]},
            ],
        )

        duration = time.perf_counter() - start_time
        content = response.message.content
        parsed = parse_json_simple(content)

        return {
            "model": model_name,
            "duration": duration,
            "response": content,
            "data": parsed['data'],
            "json_success": parsed['success']
        }

    except Exception as e:
        return {
            "model": model_name,
            "duration": None,
            "response": None,
            "error": str(e),
            "data": None,
            "json_success": False
        }

In [None]:
def extract_json_block(response_text):
    """Extract and parse the outermost ``{...}`` object in a model reply.

    Models often surround their JSON with extra text, so the span from the first
    ``{`` to the last ``}`` is used. Before parsing, the placeholder lists
    ``[""]`` and ``["]`` are replaced by ``[]`` and the whole text is
    lower-cased, so keys and tag values come back lower-case.

    Args:
        response_text: raw model reply.

    Returns:
        The parsed JSON object.

    Raises:
        ValueError: if *response_text* is not a string, contains no ``{...}``
            span, or that span is not valid JSON.
    """
    if not isinstance(response_text, str):
        raise ValueError("JSON block not found")
    match = re.search(r'(\{.*\})', response_text, re.DOTALL)
    if not match:
        raise ValueError("JSON block not found")
    json_str = match.group(1).replace('[""]', '[]').replace('["]', '[]').lower()
    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
        raise ValueError(f"Failed to parse JSON: {e}") from e

In [None]:
# Maximum attempts call_ai_and_validate makes per (image, model) before giving up.
MAX_RETRIES = 10

def save_results(result, filename):
    """Append *result* to *filename* as one JSON line (UTF-8, non-ASCII kept as-is)."""
    with open(filename, mode="a", encoding="utf-8") as out:
        out.write(json.dumps(result, ensure_ascii=False) + "\n")

def is_valid_structure(data, expected_keys):
    """Return True iff the keys of *data* are exactly *expected_keys* (no extras, none missing)."""
    wanted = frozenset(expected_keys)
    return wanted == frozenset(data.keys())

def call_ai_and_validate(model_name, encoded_image, system_prompt, user_prompt, expected_keys):
    """Query *model_name* until it returns JSON with exactly *expected_keys*.

    Each attempt calls ``test_model`` and extracts the ``{...}`` block from the
    raw reply. A failed call, a reply without parseable JSON, and a reply with
    missing or extra keys all trigger another attempt, up to ``MAX_RETRIES``.

    Returns:
        The ``test_model`` result dict merged with the parsed JSON fields.

    Raises:
        ValueError: if no attempt produced a valid structure. The message ends
            with the last failure reason.
    """
    last_problem = "no attempts made"
    for attempt in range(MAX_RETRIES):
        result = test_model(model_name, encoded_image, system_prompt, user_prompt)
        raw_text = result.get("response")
        if not isinstance(raw_text, str):
            # Failed model call: its error dict carries no usable reply.
            last_problem = f"model call failed: {result.get('error')}"
            print(f"Model call failed on attempt {attempt+1}, retrying...")
            continue
        try:
            json_data = extract_json_block(raw_text)
        except ValueError as e:
            # Unparseable reply: retry instead of aborting the whole loop.
            last_problem = str(e)
            print(f"Unparseable JSON on attempt {attempt+1}, retrying...")
            continue

        if isinstance(json_data, dict) and is_valid_structure(json_data, expected_keys):
            return result | json_data
        last_problem = "unexpected keys"
        print(f"Invalid response structure on attempt {attempt+1}, retrying...")
    raise ValueError(f"Failed to get valid JSON structure after retries ({last_problem})")

def run_analysis_batch(image_paths: List[str], clothing_categories: dict, output_file: str, models=None) -> List[dict]:
    """Tag every image with every model and append each outcome to *output_file*.

    Each (image, model) pair produces exactly one record. It is appended to the
    returned list and written as a JSON line via ``save_results``, whether the
    pair succeeded or failed. Progress is shown with a tqdm bar.

    Args:
        image_paths: images to analyse.
        clothing_categories: category -> allowed-tags mapping. It drives both
            the system prompt and the JSON keys the reply must contain (plus
            ``"description"``).
        output_file: JSONL file that records are appended to.
        models: Ollama models to query; defaults to ``["llava:7b", "qwen2.5vl:7b"]``.

    Returns:
        List of result dicts. Failures have ``json_success`` False and an
        ``error`` message.
    """
    if models is None:
        models = ["llava:7b", "qwen2.5vl:7b"]
    # Expected keys must match the categories actually placed in the prompt.
    expected_keys = list(clothing_categories.keys()) + ["description"]

    system_prompt = create_system_prompt(clothing_categories)
    user_prompt = create_user_prompt()

    results = []

    for image_path in tqdm(image_paths, desc="Processing images"):
        image_name = Path(image_path).name

        # Only the image decoding is guarded here, so model-side failures are
        # not mislabelled as image errors and earlier records are not duplicated.
        try:
            encoded_image = resize_and_encode_image(image_path)
        except Exception as e:
            for model_name in models:
                result = {
                    "model": model_name,
                    "image_name": image_name,
                    "image_path": str(image_path),
                    "error": f"Image processing error: {str(e)}",
                    "data": None,
                    "json_success": False
                }
                results.append(result)
                save_results(result, filename=output_file)
            continue

        for model_name in models:
            try:
                result = call_ai_and_validate(model_name, encoded_image, system_prompt, user_prompt, expected_keys)
            except ValueError as e:
                # Max retries reached, log and continue
                result = {
                    "model": model_name,
                    "image_name": image_name,
                    "image_path": str(image_path),
                    "error": f"Max retries reached: {str(e)}",
                    "data": None,
                    "json_success": False
                }
            except Exception as e:
                # Any other model-side failure: record it and move on to the next model.
                result = {
                    "model": model_name,
                    "image_name": image_name,
                    "image_path": str(image_path),
                    "error": f"Model call error: {str(e)}",
                    "data": None,
                    "json_success": False
                }
            else:
                result["image_name"] = image_name
                result["image_path"] = str(image_path)
                result["json_success"] = True
            results.append(result)
            save_results(result, filename=output_file)

    return results

In [None]:
# results = run_analysis_batch(image_list[:], CLOTHING_CATEGORIES, output_file="../results/farfetch.jsonl")

### Read results

In [None]:
# One JSON object per line, as appended by save_results during the batch run.
results_file = "../results/farfetch.jsonl"
df = pd.read_json(results_file, lines=True)
df.shape

In [None]:
# First rows of the saved tagging results.
df.head()

In [None]:
def replace_nan_with_empty_list(x):
    """Map a missing tag value to ``[]`` and return anything else unchanged.

    Treated as missing: ``None``, float NaN (including ``numpy.float64`` NaN),
    and the string ``"NaN"`` in any case, ignoring surrounding whitespace.
    """
    if x is None:
        return []
    if isinstance(x, float) and np.isnan(x):
        return []
    # Handle string "NaN" (case-insensitive, with whitespace handling)
    if isinstance(x, str) and x.strip().lower() == 'nan':
        return []
    return x

# Apply only to tag columns: the category columns plus any column holding lists.
# Numeric/text columns ("duration", "description", ...) keep NaN, so that
# e.g. df["duration"] stays numeric instead of becoming a mix of floats and [].
tag_columns = [
    col for col in df.columns
    if col in CLOTHING_CATEGORIES or df[col].map(lambda v: isinstance(v, list)).any()
]
for col in tag_columns:
    df[col] = df[col].apply(replace_nan_with_empty_list)

In [None]:
# Spot-check one random row.
df.sample()

In [None]:
def calculate_match_stats(df, column="colours"):
    """Measure how often the first two results per image agree on *column*.

    For every image with at least two rows, the tag lists of its first two rows
    (in DataFrame order) are compared as sets. Identical sets, including two
    empty ones, count as a perfect match. Otherwise the Jaccard similarity
    ``|A∩B| / |A∪B|`` is recorded. A non-iterable value (e.g. NaN) gives
    match_type ``"error"`` with similarity 0.

    Args:
        df: results table with an ``image_name`` column and a list-valued *column*.
        column: tag column to compare; defaults to ``"colours"``.

    Returns:
        dict with ``perfect_matches``, ``total_images`` (images with at least
        two rows), ``perfect_percentage`` (0-100) and ``results`` (per-image details).
    """
    perfect_matches = 0
    total_images = 0
    results = []  # Store detailed results

    # A single groupby pass (sort=False keeps first-appearance order) instead of
    # re-filtering the whole frame once per image.
    for image_name, image_tags in df.groupby("image_name", sort=False)[column]:
        if len(image_tags) < 2:
            continue

        total_images += 1

        try:
            set1 = set(image_tags.iloc[0])
            set2 = set(image_tags.iloc[1])
        except (TypeError, AttributeError):
            results.append({"image_name": image_name, "match_type": "error", "similarity": 0})
            continue

        if set1 == set2:
            perfect_matches += 1
            results.append({"image_name": image_name, "match_type": "perfect", "similarity": 1.0})
        else:
            union = set1 | set2
            similarity = len(set1 & set2) / len(union) if union else 0
            results.append({"image_name": image_name, "match_type": "partial", "similarity": similarity})

    perfect_percentage = (perfect_matches / total_images * 100) if total_images > 0 else 0

    return {
        "perfect_matches": perfect_matches,
        "total_images": total_images,
        "perfect_percentage": perfect_percentage,
        "results": results
    }

# Colour agreement between the first two model results per image.
stats = calculate_match_stats(df)
summary_lines = [
    f"Perfect match percentage: {stats['perfect_percentage']:.2f}%",
    f"Perfect matches: {stats['perfect_matches']}",
    f"Total valid images: {stats['total_images']}",
]
for line in summary_lines:
    print(line)


In [None]:
import pandas as pd


def create_image_match_database(df, categories=None):
    """Build one row per image with the per-category agreement of its first two results.

    For each category, the tag lists of the image's first two rows are compared
    as sets, and three columns are written: ``<cat>_intersection`` and
    ``<cat>_union`` (sorted lists, so output is deterministic across runs) and
    ``<cat>_match_percentage`` = ``100 * |A∩B| / |A∪B|`` (0.0 when both are
    empty). Images with fewer than two rows, or with non-iterable values, get
    empty lists and 0.0.

    Args:
        df: results table with an ``image_name`` column and one list-valued
            column per (lower-cased) category name.
        categories: category names to compare; defaults to the keys of
            ``CLOTHING_CATEGORIES``.

    Returns:
        pandas DataFrame with one row per distinct ``image_name`` value, in
        order of first appearance.
    """
    if categories is None:
        categories = CLOTHING_CATEGORIES.keys()
    tag_columns = [tag.lower() for tag in categories]

    # Group once instead of re-filtering the whole frame for every (image, tag).
    rows_by_image = {name: group for name, group in df.groupby("image_name", sort=False)}
    no_rows = df.iloc[0:0]

    records = []
    for image_name in df["image_name"].unique():
        image_rows = rows_by_image.get(image_name, no_rows)
        record = {"image_name": image_name}  # Start with image name

        for tag in tag_columns:
            intersect, union, match_percentage = [], [], 0.0
            # Handle cases with insufficient data by keeping the defaults above
            if len(image_rows) >= 2:
                try:
                    set1 = set(image_rows[tag].iloc[0])
                    set2 = set(image_rows[tag].iloc[1])
                except (TypeError, AttributeError):
                    pass  # non-iterable value (e.g. NaN): keep the empty defaults
                else:
                    # Sort so the lists do not depend on hash-randomised set order.
                    intersect = sorted(set1 & set2, key=str)
                    union = sorted(set1 | set2, key=str)
                    match_percentage = (len(intersect) / len(union)) * 100 if union else 0.0

            record[f"{tag}_intersection"] = intersect
            record[f"{tag}_union"] = union
            record[f"{tag}_match_percentage"] = match_percentage

        records.append(record)  # one complete record per image

    return pd.DataFrame(records)


# One row per image: per-category intersection, union and match percentage.
image_match_db = create_image_match_database(df=df)

In [None]:
# Spot-check one random image's match record.
image_match_db.sample()

In [None]:
# Columns holding the per-category match percentage.
categories = list(filter(lambda name: name.endswith('_match_percentage'), image_match_db.columns))
categories

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Up to 9 "<category>_match_percentage" columns, one histogram each on a 3x3 grid.
categories = [col for col in image_match_db.columns if col.endswith('_match_percentage')][:9]
panel_titles = [
    cat.replace('_match_percentage', '').replace('_', ' ').title() for cat in categories
]

fig = make_subplots(
    rows=3,
    cols=3,
    subplot_titles=panel_titles,
    specs=[[{"secondary_y": False}] * 3 for _ in range(3)],
    horizontal_spacing=0.08,
    vertical_spacing=0.12,
)

# Plotly's default qualitative palette: 9 distinct colours, one per panel.
colors = ['#636EFA', '#EF553B', '#00CC96', '#AB63FA', '#FFA15A',
          '#19D3F3', '#FF6692', '#B6E880', '#FF97FF']

for i, (cat, panel_colour) in enumerate(zip(categories, colors)):
    grid_row, grid_col = divmod(i, 3)
    histogram = go.Histogram(
        x=image_match_db[cat],
        nbinsx=15,
        histnorm='percent',  # each panel's bars sum to 100%
        name=cat.replace('_match_percentage', ''),
        showlegend=False,
        marker_color=panel_colour,
        opacity=0.8,
        marker_line_width=1,
        marker_line_color='white',
    )
    fig.add_trace(histogram, row=grid_row + 1, col=grid_col + 1)

fig.update_layout(
    title_text="Match Percentage Distributions (Normalized)",
    height=900,
    width=1200,
    showlegend=False,
)

fig.show()


In [None]:
# Persist the per-image match table.
match_db_path = "../results/farfetch_image_match_database.csv"
image_match_db.to_csv(match_db_path, index=False)

### Further tests: single dominant-colour extraction

In [None]:
# Show the colour vocabulary.
colours

In [None]:
# Preview: every colour wrapped in brackets, each followed by a blank line.
prompt = "".join(f' [{items}]\n\n' for items in colours)
prompt

In [None]:
def create_system_prompt(colour_options=None):
    """Create the system prompt for single-word dominant-colour detection.

    The rules demand exactly one word with no descriptors. Only the single-word
    entries of the palette are offered, lower-cased, de-duplicated and in their
    original order. Multi-word names such as "Dark Blue" or "Bronze/Copper"
    would contradict those rules.

    Args:
        colour_options: colour names to draw from; defaults to the notebook's
            ``colours`` list.

    Returns:
        The system prompt string.
    """
    if colour_options is None:
        colour_options = colours
    one_word_colours = list(dict.fromkeys(
        name.strip().lower()
        for name in colour_options
        if len(name.split()) == 1 and "/" not in name
    ))
    allowed = ", ".join(one_word_colours)

    prompt = f"""You are a clothing color detection AI. You MUST return EXACTLY one word.

STRICT RULES:
- Return ONLY one color from this list: {allowed}
- NO punctuation, NO explanations, NO sentences
- NO multiple colors (e.g., NOT "red and blue")
- NO descriptors (e.g., NOT "dark blue", just "blue")
- NO phrases (e.g., NOT "reddish brown", just "red")
- If uncertain between colors, pick the most dominant one
- ONLY output the single color word

Example responses: "red", "blue", "green" (NOT "The main color is red")
"""
    return prompt

def create_user_prompt():
    """Return the fixed user message sent alongside every image."""
    user_message = (
        "What is the main color? Answer with ONE WORD ONLY from the approved color list."
    )
    return user_message


def parse_json_simple(response_text: str) -> dict:
    """Parse ``response_text`` as JSON as-is, with no cleaning.

    Returns:
        ``{"success": True, "data": <parsed value>}`` when the text is valid
        JSON, otherwise ``{"success": False, "data": None}``.
    """
    try:
        data = json.loads(response_text)
    # ValueError covers json.JSONDecodeError (and bad UTF-8 in bytes input);
    # TypeError covers non-text input such as None. A bare ``except`` would
    # also swallow KeyboardInterrupt/SystemExit.
    except (ValueError, TypeError):
        return {"success": False, "data": None}
    return {"success": True, "data": data}


def resize_and_encode_image(image_path, max_width=256):
    """Downscale an image to at most ``max_width`` px wide and return it as base64 JPEG.

    The aspect ratio is preserved; images already at most ``max_width`` wide
    are re-encoded at their original size.

    Args:
        image_path: Path (str or Path) of the image file to read.
        max_width: Maximum output width in pixels.

    Returns:
        The JPEG bytes, base64-encoded and decoded to a UTF-8 ``str``.
    """
    # The context manager closes the file handle opened by Image.open.
    with Image.open(image_path) as image:
        if image.width > max_width:
            ratio = max_width / image.width
            # Keep at least 1 px so very wide, short images don't collapse to height 0.
            new_height = max(1, int(image.height * ratio))
            image = image.resize((max_width, new_height), Image.Resampling.LANCZOS)

        # JPEG cannot store alpha/palette modes (RGBA, LA, P, ...) and saving them
        # raises, so convert those to RGB. RGB and greyscale (L) are saved unchanged.
        if image.mode not in ("RGB", "L"):
            image = image.convert("RGB")

        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def test_model(
    model_name: str, encoded_image: str, system_prompt: str, user_prompt: str
) -> dict:
    """Send one base64 image plus prompts to a single Ollama model.

    On success returns a dict with ``model``, ``duration`` (wall-clock
    seconds for the chat call), the raw ``response`` text, and the ``data``
    / ``json_success`` fields produced by ``parse_json_simple`` on that text.
    Any exception is caught and reported as ``error`` instead, with
    ``data=None`` and ``json_success=False``.
    """
    try:
        started = time.time()
        reply = ollama.chat(
            model=model_name,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt, "images": [encoded_image]},
            ],
        )
        elapsed = time.time() - started
        content = reply.message.content
        parsed = parse_json_simple(content)
    except Exception as exc:
        return {
            "model": model_name,
            "error": str(exc),
            "data": None,
            "json_success": False,
        }

    return {
        "model": model_name,
        "duration": elapsed,
        "response": content,
        "data": parsed["data"],
        "json_success": parsed["success"],
    }


def call_ai_and_validate(model_name, encoded_image, system_prompt, user_prompt):
    """Query one model via ``test_model`` and return its result dict unchanged.

    Despite the name, no extra validation is performed here at present.
    """
    return test_model(
        model_name,
        encoded_image,
        system_prompt,
        user_prompt,
    )


def run_analysis_batch(
    image_paths: List[str], models=("llava:7b", "qwen2.5vl:7b")
) -> List[dict]:
    """Run every model on every image, with a tqdm progress bar.

    Args:
        image_paths: Image file paths (str or Path).
        models: Ollama model names to query for each image. Defaults to the
            two vision models that used to be hard-coded here.

    Returns:
        One dict per (image, model) pair. Successful calls carry the keys
        returned by ``call_ai_and_validate`` plus ``image_name`` and
        ``image_path``. Failures carry an ``error`` message, ``data=None`` and
        ``json_success=False``.
    """
    models = list(models)  # accept any iterable; it is re-used for every image

    system_prompt = create_system_prompt()
    user_prompt = create_user_prompt()

    results = []

    for image_path in tqdm(image_paths, desc="Processing images"):
        image_name = Path(image_path).name

        # Keep the try narrow: only a failure to load/encode the image should
        # mark every model as failed for this image.
        try:
            encoded_image = resize_and_encode_image(image_path)
        except Exception as e:
            for model_name in models:
                results.append(
                    {
                        "model": model_name,
                        "image_name": image_name,
                        "image_path": str(image_path),
                        "error": f"Image processing error: {str(e)}",
                        "data": None,
                        "json_success": False,
                    }
                )
            continue

        for model_name in models:
            try:
                result = call_ai_and_validate(
                    model_name, encoded_image, system_prompt, user_prompt
                )
            except ValueError as e:
                # Record the failure for this (image, model) pair and keep going.
                result = {
                    "model": model_name,
                    "image_name": image_name,
                    "image_path": str(image_path),
                    "error": f"Max retries reached: {str(e)}",
                    "data": None,
                    "json_success": False,
                }
            else:
                result["image_name"] = image_name
                result["image_path"] = str(image_path)
                # json_success is deliberately NOT overwritten: the returned
                # result already carries its own flag, and forcing it to True
                # hid unparseable replies and caught model errors.
            results.append(result)

    return results

In [None]:
# Try the colour prompt on the first ten catalogue images
run_analysis_batch(image_paths=image_list[:10])

## kNN search

In [None]:
# Reload the per-image match table written earlier in the notebook
catalogue_csv = "../results/farfetch_image_match_database.csv"
df_catalogue = pd.read_csv(catalogue_csv)
df_catalogue.head()

In [None]:
import ast # Convert all string list columns to actual lists upfront
def convert_string_lists_to_actual_lists(df, list_columns):
    """Return a copy of ``df`` in which list-looking strings are parsed into lists.

    Within each column of ``list_columns``, a cell is passed through
    ``ast.literal_eval`` only when it is a ``str`` starting with ``[``. All
    other values (NaN, numbers, already-parsed lists, other strings) are left
    untouched. ``df`` itself is not modified.
    """
    def _parse(value):
        # Guard clause: anything that is not a "[..."-string passes through.
        if not isinstance(value, str) or not value.startswith('['):
            return value
        return ast.literal_eval(value)

    converted = df.copy()
    for column in list_columns:
        converted[column] = converted[column].apply(_parse)
    return converted

# Select the list-valued tag columns (names containing "union" or "intersection") and parse them
column_names = df_catalogue.columns
list_mask = column_names.str.contains("union") | column_names.str.contains("intersection")
cols = column_names[list_mask].to_list()

df_clean = convert_string_lists_to_actual_lists(df_catalogue, list_columns=cols)
df_clean.head()


In [None]:
set(df["genres"].sum())

In [None]:
set(df["colours"].sum())

In [None]:
# Tag groups = column-name prefixes before the first "_", excluding the image_* metadata columns
tags = {name.partition('_')[0] for name in df_clean.columns}
tags.remove('image')
tags

In [None]:
from sklearn.preprocessing import MultiLabelBinarizer

def encode_list_column(df, tag, suffix='_union'):
    """One-hot encode the list-valued column ``tag + suffix`` of ``df``.

    Every distinct element found across the lists becomes its own 0/1 column
    named ``f"{tag}_{element}"``, in ``MultiLabelBinarizer.classes_`` order.
    The returned frame shares ``df``'s index.
    """
    binarizer = MultiLabelBinarizer()
    source_column = tag + suffix
    indicator_matrix = binarizer.fit_transform(df[source_column])

    return pd.DataFrame(
        indicator_matrix,
        columns=[f"{tag}_{label}" for label in binarizer.classes_],
        index=df.index,
    )

def encode_all_list_columns(df, tags, suffix='_union'):
    """One-hot encode ``tag + suffix`` for each tag and join the results side by side.

    Prints the total number of feature columns and returns the combined frame.
    """
    features_df = pd.concat(
        [encode_list_column(df, tag, suffix=suffix) for tag in tags],
        axis=1,
    )
    print(f"\nTotal features created: {features_df.shape[1]}")
    return features_df

# Build the one-hot feature matrix from the *_union tag columns (the default suffix)
df_encoded = encode_all_list_columns(df=df_clean, tags=tags)
df_encoded.head()


In [None]:
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# Index the one-hot catalogue features for cosine-distance lookups, 10 neighbours per query
knn = NearestNeighbors(metric='cosine', n_neighbors=10)
knn.fit(df_encoded)

In [None]:
def encode_query_tags(user_tags, df=None):
    """Encode user query tags into the same one-hot layout as the training data.

    Args:
        user_tags: Mapping of tag group -> list of values, e.g.
            ``{'colours': ['black'], 'types': ['pants']}``. Each pair maps to
            the column ``f"{group}_{value}"``.
        df: Frame whose columns define the feature layout. Defaults to the
            notebook-level ``df_encoded``, looked up at call time. The old
            ``df=df_encoded`` default was frozen when the cell ran, so it went
            stale if the encoding cell was re-run.

    Returns:
        A single-row (index ``[0]``) integer DataFrame with exactly ``df``'s
        columns: 1 for each requested tag present in the layout, 0 elsewhere.

    Warns:
        UserWarning listing requested tags that are not columns of ``df``.
        Such tags are ignored. Previously ``.loc`` silently appended them as
        new columns, so the query no longer matched the fitted kNN features.
    """
    import warnings

    if df is None:
        df = df_encoded

    df_query = pd.DataFrame(0, index=[0], columns=df.columns)
    unknown = []
    for group, values in user_tags.items():
        for value in values:
            column = f"{group}_{value}"
            if column in df_query.columns:
                df_query.loc[0, column] = 1
            else:
                unknown.append(column)

    if unknown:
        warnings.warn(
            f"Ignoring query tags not present in the feature columns: {unknown}",
            stacklevel=2,
        )
    return df_query


In [None]:
# Example query: one value per tag group
user_tags = {
    'colours': ['black'],
    'types': ['pants'],
    'styles': ['casual'],
    'seasons': ['summer'],
}

In [None]:
# One-hot encode the example query against the catalogue's feature columns
query_encoded = encode_query_tags(user_tags=user_tags)
query_encoded

In [None]:
# Cosine distances and row positions (into df_encoded) of the nearest catalogue items
distances, indices = knn.kneighbors(X=query_encoded)
distances, indices

In [None]:
import matplotlib.image as mpimg

# Show the retrieved neighbours in a grid of 5 columns; rows grow with the number of results.
n_results = indices.shape[1]
n_cols = 5
n_rows = max(1, -(-n_results // n_cols))  # ceiling division
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(2 * n_cols, 2.5 * n_rows), squeeze=False
)
axes = axes.flatten()  # 2D grid -> 1D for simple indexing

for i, ax in enumerate(axes):
    if i < n_results:
        # NOTE(review): kNN indices are row positions in df_encoded; this assumes the
        # catalogue rows are in the same order as image_list — confirm (the image_*
        # columns of df_clean could be used instead).
        neighbour = image_list[indices[0, i]]
        ax.imshow(mpimg.imread(str(neighbour)))
        # Title shows the neighbour's own file (previously image_list[i], the wrong image)
        ax.set_title(f"{neighbour.name}\nd={distances[0, i]:.3f}", fontsize=8)
    ax.axis('off')  # no axes on images; hides unused panels entirely

plt.tight_layout()
plt.show()

In [None]:
# Re-run the same query and display the raw (distances, indices) pair
knn.kneighbors(X=query_encoded)