# Rhomberg Jewelry Multi-Label Classification

Train a multi-label classifier on real Rhomberg jewelry data.

## Features:
- ✅ **Configurable download**: Choose how many images to process
- ✅ **Smart access**: Auto-detects local vs remote execution
- ✅ **HTTP download**: Fast image downloads (vs SSH)
- ✅ **Caching**: Skips already downloaded images
- ✅ **Multi-label**: Category, Gender, Material, Price Range

## 1. Configuration - ADJUST THESE SETTINGS

In [1]:
# ============================================================
# DATASET CONFIGURATION
# ============================================================

# How many images per category?
# A falsy value (None or 0) disables sampling and processes every product
# that has an image link.
MAX_IMAGES_PER_CATEGORY = 200  # Options:
#   200  = ~1,400 images (fast testing, ~5-10 min download)
#   500  = ~3,500 images (medium, ~15-25 min download)
#   None = ALL 12,026 images (full dataset, ~1-2 hours download)

# ============================================================
# TRAINING CONFIGURATION
# ============================================================
IMG_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 20
LEARNING_RATE = 0.001

# ============================================================
# PATHS
# ============================================================
# NOTE(review): "jewlery" looks like a typo, but it must match the file name
# on disk, so it is left unchanged here -- confirm before renaming.
CSV_PATH = "/project/data/jewlery.csv"
OUTPUT_DIR = "/project/data/rhomberg_final"
MODEL_DIR = "/project/models"

# Echo the settings that most affect runtime so they are visible in the
# notebook output.
print("✓ Configuration loaded")
print(f"  Images per category: {MAX_IMAGES_PER_CATEGORY or 'ALL'}")
print(f"  Training epochs: {EPOCHS}")

✓ Configuration loaded
  Images per category: 200
  Training epochs: 20


## 2. Setup and Imports

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms, models
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import requests
import shutil
import warnings
warnings.filterwarnings('ignore')

# Create directories
# parents=True / exist_ok=True make this cell safe to re-run.
Path(OUTPUT_DIR).mkdir(exist_ok=True, parents=True)
Path(MODEL_DIR).mkdir(exist_ok=True, parents=True)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Report the runtime environment so the notebook output records it.
print(f"PyTorch: {torch.__version__}")
print(f"Device: {DEVICE}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print("✓ Imports complete!")

## 3. Download/Access Images

**This will:**
- Auto-detect if on spark (local files) or remote (HTTP download)
- Process CSV and extract metadata
- Download/copy images with caching
- Create training-ready dataset

In [None]:
# Helper functions
def is_on_spark():
    """Return True when the local image tree is present on this machine.

    The presence of ``/mnt/img/jpeg/detailbilder`` is used as the signal
    that images can be read from disk instead of being fetched over HTTP.
    """
    local_image_root = Path('/mnt/img/jpeg/detailbilder')
    return local_image_root.exists()

def clean_category(product_type):
    """Map a raw ``product_type`` value to a short English category label.

    Only the first segment of a ``"A > B > C"`` style path is considered.
    If that segment contains one of the known German category names, the
    matching English label is returned; otherwise the first word of the
    segment (lower-cased) is used as the label.

    Args:
        product_type: Raw product-type value; may be NaN/None.

    Returns:
        The category label, or ``'unknown'`` for missing or empty input.
    """
    if pd.isna(product_type):
        return 'unknown'

    # str.split always yields at least one element, so [0] is safe.
    main = str(product_type).split('>')[0].strip().lower()

    # Keys must be real UTF-8 text: the previous mis-encoded spellings of
    # "anhänger" and "fußketten" could never match the data.
    mapping = {
        'fingerringe': 'rings',
        'ohrschmuck': 'earrings',
        'halsschmuck': 'necklaces',
        'armschmuck': 'bracelets',
        'anhänger': 'pendants',
        'piercing': 'piercing',
        'fußketten': 'anklets',
    }
    for de, en in mapping.items():
        if de in main:
            return en

    # Unmapped category: fall back to its first word.
    words = main.split()
    return words[0] if words else 'unknown'

def extract_material(material_str):
    """Reduce a free-text material description to a single material label.

    The description is lower-cased and checked against keyword groups in
    priority order (platinum, gold, silver, stainless steel, titan); the
    first group with a matching keyword decides the label. If nothing
    matches, the first whitespace-separated word of the description is
    returned.

    Args:
        material_str: Raw material value; may be NaN/None.

    Returns:
        The material label, or ``'unknown'`` for missing or blank input.
    """
    if pd.isna(material_str):
        return 'unknown'

    text = str(material_str).lower()

    # Order matters: earlier entries win when several keywords are present.
    keyword_labels = (
        (('platin',), 'platinum'),
        (('gold',), 'gold'),
        (('silber', 'silver'), 'silver'),
        (('edelstahl', 'stainless'), 'stainless_steel'),
        (('titan',), 'titan'),
    )
    for keywords, label in keyword_labels:
        if any(keyword in text for keyword in keywords):
            return label

    tokens = text.split()
    if tokens:
        return tokens[0]
    return 'unknown'

def download_image_http(url, dest_path):
    """Download ``url`` and store the response body at ``dest_path``.

    The body is first written to a sibling ``<name>.part`` file and then
    renamed into place, so an interrupted or failed write never leaves a
    truncated file at ``dest_path`` (the caller treats an existing file as
    an already-downloaded image).

    Args:
        url: HTTP(S) address of the image.
        dest_path: Target file path (str or Path).

    Returns:
        True if the image was downloaded and saved, False on any network,
        HTTP-status, or filesystem error.
    """
    dest_path = Path(dest_path)
    tmp_path = dest_path.with_name(dest_path.name + '.part')
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        with open(tmp_path, 'wb') as f:
            f.write(response.content)
        tmp_path.replace(dest_path)
        return True
    except (requests.RequestException, OSError):
        # Best-effort download: report failure instead of raising, but do not
        # swallow KeyboardInterrupt/SystemExit the way a bare except would.
        try:
            tmp_path.unlink()
        except OSError:
            pass  # nothing to clean up (temp file was never created)
        return False
def copy_local_image(product_id, dest_path):
    """Copy the locally stored 360px image of ``product_id`` to ``dest_path``.

    Args:
        product_id: Product identifier; the source file is
            ``/mnt/img/jpeg/detailbilder/360/<product_id>.jpg``.
        dest_path: Target file path (str or Path).

    Returns:
        True if the source image exists and was copied, False if it is
        missing or the copy failed with a filesystem error.
    """
    # Use 360px images (good quality, reasonable size)
    source_path = Path('/mnt/img/jpeg/detailbilder') / '360' / f"{product_id}.jpg"
    try:
        if source_path.exists():
            shutil.copy2(source_path, dest_path)
            return True
        return False
    except OSError:
        # Best-effort copy: a filesystem error counts as "image not found".
        return False

# Process data
def _price_to_range(price):
    """Bucket a numeric price into a coarse price-range label.

    Returns 'unknown' for missing (NaN) prices.
    """
    if pd.isna(price):
        return 'unknown'
    if price < 50:
        return 'budget'
    if price < 100:
        return 'mid_range'
    if price < 300:
        return 'premium'
    return 'luxury'


print("="*70)
print("PROCESSING RHOMBERG DATA")
print("="*70)

on_spark = is_on_spark()
print(f"\n{'🎯 ON spark' if on_spark else '🌐 REMOTE'} → {'Local files' if on_spark else 'HTTP download'}")

# The product feed is tab-separated; malformed rows are dropped.
df_raw = pd.read_csv(CSV_PATH, sep='\t', on_bad_lines='skip')
print(f"Total products: {len(df_raw):,}")

# Derive the four label columns from the raw feed columns.
df_raw['category'] = df_raw['product_type'].apply(clean_category)
df_raw['material_clean'] = df_raw['material'].apply(extract_material)
df_raw['gender_clean'] = df_raw['gender'].fillna('unisex')

# Strip the currency suffix and normalise decimal commas, then convert.
# errors='coerce' turns unparseable prices into NaN; the previous
# astype(float, errors='ignore') left the WHOLE column as strings when a
# single value failed, which made the range bucketing below raise TypeError.
price_text = (
    df_raw['price']
    .astype(str)
    .str.replace(' EUR', '', regex=False)
    .str.replace(',', '.', regex=False)
    .str.strip()
)
df_raw['price_clean'] = pd.to_numeric(price_text, errors='coerce')
df_raw['price_range'] = df_raw['price_clean'].apply(_price_to_range)

# Only products with an image link can be used for training.
df_with_images = df_raw[df_raw['image_link'].notna()].copy()

if MAX_IMAGES_PER_CATEGORY:
    print(f"\nSampling {MAX_IMAGES_PER_CATEGORY} per category...")
    sampled = []
    for cat in df_with_images['category'].unique():
        cat_df = df_with_images[df_with_images['category'] == cat]
        # Small categories contribute everything they have.
        n = min(MAX_IMAGES_PER_CATEGORY, len(cat_df))
        # Fixed seed keeps the sample (and therefore the cache) reproducible.
        sampled.append(cat_df.sample(n=n, random_state=42))
        print(f"  {cat:15s}: {n:4d}")
    df_to_process = pd.concat(sampled, ignore_index=True)
else:
    df_to_process = df_with_images

print(f"\nProcessing {len(df_to_process):,} images...")

images_dir = Path(OUTPUT_DIR) / 'images'
images_dir.mkdir(exist_ok=True, parents=True)

results = []
# found counts every usable image (cached ones included); skipped counts
# the subset that was already on disk.
found, not_found, skipped = 0, 0, 0

for _, row in tqdm(df_to_process.iterrows(), total=len(df_to_process)):
    cat, pid = row['category'], row['id']
    cat_dir = images_dir / cat
    cat_dir.mkdir(exist_ok=True)
    dest = cat_dir / f"{cat}_{pid}.jpg"

    if dest.exists():
        # Cache hit: image was fetched by an earlier run.
        skipped += 1
        found += 1
    elif (copy_local_image(pid, dest) if on_spark else download_image_http(row['image_link'], dest)):
        found += 1
    else:
        # No image available: leave this product out of the metadata.
        not_found += 1
        continue

    results.append({
        'filename': f"{cat}_{pid}.jpg",
        'filepath': str(dest),
        'product_id': pid,
        'title': row['title'],
        'category': cat,
        'gender': row['gender_clean'],
        'material': row['material_clean'],
        'price': row['price_clean'],
        'price_range': row['price_range']
    })

# Persist one metadata row per image that is present on disk.
df = pd.DataFrame(results)
metadata_path = Path(OUTPUT_DIR) / 'jewelry_metadata.csv'
df.to_csv(metadata_path, index=False)

print(f"\n✓ Processed: {found:,} | ⚡ Cached: {skipped:,} | ✗ Failed: {not_found:,}")
print(f"✓ Saved: {metadata_path}")

## 4. Analyze Dataset

Check the distribution of categories, materials, gender, and price ranges.