# Rhomberg Jewelry Multi-Label Classification

Train a multi-label classifier on real Rhomberg jewelry data.

## Features:
- ‚úÖ **Configurable download**: Choose how many images to process
- ‚úÖ **Smart access**: Auto-detects local vs remote execution  
- ‚úÖ **HTTP download**: Fast image downloads (vs SSH)
- ‚úÖ **Caching**: Skips already downloaded images
- ‚úÖ **Multi-label**: Category, Gender, Material, Price Range

## 1. Configuration - ADJUST THESE SETTINGS

In [4]:
# ============================================================
# DATASET CONFIGURATION
# ============================================================

# How many images per category?
MAX_IMAGES_PER_CATEGORY = None  # Options:
#   200  = ~1,400 images (fast testing, ~5-10 min download)
#   500  = ~3,500 images (medium, ~15-25 min download)
#   None = ALL 12,026 images (full dataset, ~1-2 hours download)

# ============================================================
# TRAINING CONFIGURATION
# ============================================================
IMG_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 20
LEARNING_RATE = 0.001

# ============================================================
# PATHS
# ============================================================
CSV_PATH = "/project/data/jewlery.csv"
OUTPUT_DIR = "/project/data/rhomberg_final"  # Persists on host
MODEL_DIR = "/project/models"                 # Persists on host

# NOTE: Images ARE tracked by git in AI Studio
# To avoid this, either:
# 1. Add 'data/rhomberg_final/' to .gitignore (done)
# 2. Don't commit these files to git
# 3. Use git-lfs for large files

print("‚úì Configuration loaded")
print(f"  Images per category: {MAX_IMAGES_PER_CATEGORY if MAX_IMAGES_PER_CATEGORY else 'ALL'}")
print(f"  Training epochs: {EPOCHS}")
print(f"‚ö†Ô∏è  Images in git-tracked folder - add to .gitignore!")

‚úì Configuration loaded
  Images per category: ALL
  Training epochs: 20
‚ö†Ô∏è  Images in git-tracked folder - add to .gitignore!


## 2. Setup and Imports

In [5]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms, models
from PIL import Image
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import requests
import shutil
import warnings
warnings.filterwarnings('ignore')

# Create directories
Path(OUTPUT_DIR).mkdir(exist_ok=True, parents=True)
Path(MODEL_DIR).mkdir(exist_ok=True, parents=True)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"PyTorch: {torch.__version__}")
print(f"Device: {DEVICE}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print("‚úì Imports complete!")

PyTorch: 2.6.0a0+df5bbc09d1.nv24.11
Device: cuda
CUDA: True
GPU: NVIDIA GB10
‚úì Imports complete!


## 3. Download/Access Images

**This will:**
- Auto-detect if on spark (local files) or remote (HTTP download)
- Process CSV and extract metadata
- Download/copy images with caching
- Create training-ready dataset

In [None]:
# Helper functions
def is_on_spark():
    return Path('/mnt/img/jpeg/detailbilder').exists()

def clean_category(product_type):
    if pd.isna(product_type): return 'unknown'
    parts = str(product_type).split('>')
    if len(parts) > 0:
        main = parts[0].strip().lower()
        mapping = {'fingerringe': 'rings', 'ohrschmuck': 'earrings', 'halsschmuck': 'necklaces',
                   'armschmuck': 'bracelets', 'anh√§nger': 'pendants', 'piercing': 'piercing', 'fu√üketten': 'anklets'}
        for de, en in mapping.items():
            if de in main: return en
        return main.split()[0] if main else 'unknown'
    return 'unknown'

def extract_material(material_str):
    """
    Extract and normalize material names.
    Simply takes the first word and lowercases it.
    This automatically handles all materials: Silber, Gold, Edelstahl, Titan, 
    Platin, Wolfram, Tantal, Keramik, Palladium, etc.
    """
    if pd.isna(material_str): 
        return 'unknown'
    
    material = str(material_str).strip()
    if not material:
        return 'unknown'
    
    # Get first word and normalize to lowercase
    first_word = material.split()[0].lower()
    return first_word

def download_image_http(url, dest_path):
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        with open(dest_path, 'wb') as f:
            f.write(response.content)
        return True
    except:
        return False

def copy_local_image(product_id, dest_path):
    try:
        source_path = Path('/mnt/img/jpeg/detailbilder') / '360' / f"{product_id}.jpg"
        if source_path.exists():
            shutil.copy2(source_path, dest_path)
            return True
        return False
    except:
        return False

# Process data
print("="*70)
print("PROCESSING RHOMBERG DATA")
print("="*70)

on_spark = is_on_spark()
print(f"\n{'üéØ ON spark' if on_spark else 'üåê REMOTE'} ‚Üí {'Local files' if on_spark else 'HTTP download'}")

df_raw = pd.read_csv(CSV_PATH, sep='\t', on_bad_lines='skip')
print(f"Total products: {len(df_raw):,}")

df_raw['category'] = df_raw['product_type'].apply(clean_category)
df_raw['material_clean'] = df_raw['material'].apply(extract_material)
df_raw['gender_clean'] = df_raw['gender'].fillna('unisex')
df_raw['price_clean'] = df_raw['price'].str.replace(' EUR', '').str.replace(',', '.').astype(float, errors='ignore')
df_raw['price_range'] = df_raw['price_clean'].apply(lambda p: 'budget' if p < 50 else ('mid_range' if p < 100 else ('premium' if p < 300 else 'luxury')) if not pd.isna(p) else 'unknown')

df_with_images = df_raw[df_raw['image_link'].notna()].copy()

if MAX_IMAGES_PER_CATEGORY:
    print(f"\nSampling {MAX_IMAGES_PER_CATEGORY} per category...")
    sampled = []
    for cat in df_with_images['category'].unique():
        cat_df = df_with_images[df_with_images['category'] == cat]
        n = min(MAX_IMAGES_PER_CATEGORY, len(cat_df))
        sampled.append(cat_df.sample(n=n, random_state=42))
        print(f"  {cat:15s}: {n:4d}")
    df_to_process = pd.concat(sampled, ignore_index=True)
else:
    df_to_process = df_with_images

print(f"\nProcessing {len(df_to_process):,} images...")

images_dir = Path(OUTPUT_DIR) / 'images'
images_dir.mkdir(exist_ok=True, parents=True)

results = []
found, not_found, skipped = 0, 0, 0

for _, row in tqdm(df_to_process.iterrows(), total=len(df_to_process)):
    cat, pid = row['category'], row['id']
    cat_dir = images_dir / cat
    cat_dir.mkdir(exist_ok=True)
    dest = cat_dir / f"{cat}_{pid}.jpg"
    
    if dest.exists():
        skipped += 1
        found += 1
    elif (copy_local_image(pid, dest) if on_spark else download_image_http(row['image_link'], dest)):
        found += 1
    else:
        not_found += 1
        continue
    
    results.append({
        'filename': f"{cat}_{pid}.jpg",
        'filepath': str(dest),
        'product_id': pid,
        'title': row['title'],
        'category': cat,
        'gender': row['gender_clean'],
        'material': row['material_clean'],
        'price': row['price_clean'],
        'price_range': row['price_range']
    })

df = pd.DataFrame(results)
metadata_path = Path(OUTPUT_DIR) / 'jewelry_metadata.csv'
df.to_csv(metadata_path, index=False)

print(f"\n‚úì Processed: {found:,} | ‚ö° Cached: {skipped:,} | ‚úó Failed: {not_found:,}")
print(f"‚úì Saved: {metadata_path}")

PROCESSING RHOMBERG DATA

üéØ ON spark ‚Üí Local files
Total products: 12,026

Processing 12,026 images...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12026/12026 [00:07<00:00, 1677.83it/s]


‚úì Processed: 46 | ‚ö° Cached: 46 | ‚úó Failed: 11,980
‚úì Saved: /project/data/rhomberg_final/jewelry_metadata.csv





## 4. Analyze Dataset

Check the distribution of categories, materials, gender, and price ranges.

In [7]:
print(f"\n{'='*70}")
print("DATASET STATISTICS")
print(f"{'='*70}\n")

print(f"Total images: {len(df):,}\n")

print("CATEGORY DISTRIBUTION:")
for cat, count in df['category'].value_counts().items():
    print(f"  {cat:15s}: {count:4d} ({count/len(df)*100:.1f}%)")

print("\nGENDER DISTRIBUTION:")
for gender, count in df['gender'].value_counts().items():
    print(f"  {gender:15s}: {count:4d} ({count/len(df)*100:.1f}%)")

print("\nMATERIAL DISTRIBUTION:")
for mat, count in df['material'].value_counts().items():
    print(f"  {mat:15s}: {count:4d} ({count/len(df)*100:.1f}%)")

print("\nPRICE RANGE DISTRIBUTION:")
for pr, count in df['price_range'].value_counts().items():
    print(f"  {pr:15s}: {count:4d} ({count/len(df)*100:.1f}%)")

# Create label encodings
category_labels = {cat: idx for idx, cat in enumerate(sorted(df['category'].unique()))}
gender_labels = {g: idx for idx, g in enumerate(sorted(df['gender'].unique()))}
material_labels = {m: idx for idx, m in enumerate(sorted(df['material'].unique()))}
price_labels = {p: idx for idx, p in enumerate(sorted(df['price_range'].unique()))}

print(f"\n‚úì Label encodings created")
print(f"  Categories: {len(category_labels)}")
print(f"  Genders: {len(gender_labels)}")
print(f"  Materials: {len(material_labels)}")
print(f"  Price ranges: {len(price_labels)}")


DATASET STATISTICS

Total images: 46

CATEGORY DISTRIBUTION:
  bracelets      :   10 (21.7%)
  pendants       :   10 (21.7%)
  rings          :   10 (21.7%)
  necklaces      :    9 (19.6%)
  earrings       :    6 (13.0%)
  fussketten     :    1 (2.2%)

GENDER DISTRIBUTION:
  female         :   33 (71.7%)
  male           :    8 (17.4%)
  unisex         :    5 (10.9%)

MATERIAL DISTRIBUTION:
  silver         :   27 (58.7%)
  stainless_steel:    9 (19.6%)
  gold           :    7 (15.2%)
  unknown        :    2 (4.3%)
  titan          :    1 (2.2%)

PRICE RANGE DISTRIBUTION:
  budget         :   17 (37.0%)
  mid_range      :   16 (34.8%)
  premium        :    9 (19.6%)
  luxury         :    4 (8.7%)

‚úì Label encodings created
  Categories: 6
  Genders: 3
  Materials: 5
  Price ranges: 4


## 5. Dataset Class

Custom PyTorch Dataset with multi-label support.

In [None]:
class MultiLabelJewelryDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        # Load image
        img_path = row['filepath']
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        # Create labels
        labels = {
            'category': category_labels[row['category']],
            'gender': gender_labels[row['gender']],
            'material': material_labels[row['material']],
            'price_range': price_labels[row['price_range']]
        }
        
        return image, labels

# Create transforms
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Filter out categories with less than 2 samples (required for stratified split)
category_counts = df['category'].value_counts()
categories_to_keep = category_counts[category_counts >= 2].index
df_filtered = df[df['category'].isin(categories_to_keep)].copy()

removed_categories = set(category_counts.index) - set(categories_to_keep)
if removed_categories:
    removed_count = len(df) - len(df_filtered)
    print(f"‚ö†Ô∏è  Removed {removed_count} image(s) from {len(removed_categories)} category/categories with <2 samples: {removed_categories}")

# Split data with stratification
train_df, val_df = train_test_split(df_filtered, test_size=0.2, random_state=42, stratify=df_filtered['category'])

train_dataset = MultiLabelJewelryDataset(train_df, transform=train_transform)
val_dataset = MultiLabelJewelryDataset(val_df, transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

print(f"‚úì Datasets created")
print(f"  Train: {len(train_dataset):,} images")
print(f"  Val: {len(val_dataset):,} images")
print(f"  Batches per epoch: {len(train_loader)}")

ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.

## 6. Model Definition

Multi-head classifier based on MobileNetV2.

In [None]:
class MultiLabelJewelryModel(nn.Module):
    def __init__(self, num_categories, num_genders, num_materials, num_price_ranges):
        super().__init__()
        
        # Load pretrained MobileNetV2
        mobilenet = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        
        # Extract features (remove classifier)
        self.features = mobilenet.features
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        
        # Shared layer
        self.shared = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(1280, 512),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        
        # Task-specific heads
        self.head_category = nn.Linear(512, num_categories)
        self.head_gender = nn.Linear(512, num_genders)
        self.head_material = nn.Linear(512, num_materials)
        self.head_price_range = nn.Linear(512, num_price_ranges)
        
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.shared(x)
        
        return {
            'category': self.head_category(x),
            'gender': self.head_gender(x),
            'material': self.head_material(x),
            'price_range': self.head_price_range(x)
        }

# Create model
model = MultiLabelJewelryModel(
    num_categories=len(category_labels),
    num_genders=len(gender_labels),
    num_materials=len(material_labels),
    num_price_ranges=len(price_labels)
).to(DEVICE)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"‚úì Model created")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Device: {DEVICE}")

## 7. Training Loop

Train the model with multi-task learning.

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = {task: 0 for task in ['category', 'gender', 'material', 'price_range']}
    total = 0
    
    pbar = tqdm(loader, desc='Training')
    for images, labels in pbar:
        images = images.to(device)
        labels = {k: v.to(device) for k, v in labels.items()}
        
        optimizer.zero_grad()
        outputs = model(images)
        
        # Calculate loss for each task
        loss = sum(criterion(outputs[task], labels[task]) for task in outputs.keys())
        
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        total += images.size(0)
        
        # Calculate accuracy for each task
        for task in outputs.keys():
            _, predicted = outputs[task].max(1)
            correct[task] += predicted.eq(labels[task]).sum().item()
        
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    avg_loss = total_loss / len(loader)
    accuracies = {task: 100. * correct[task] / total for task in correct.keys()}
    
    return avg_loss, accuracies

def validate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = {task: 0 for task in ['category', 'gender', 'material', 'price_range']}
    total = 0
    
    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Validation'):
            images = images.to(device)
            labels = {k: v.to(device) for k, v in labels.items()}
            
            outputs = model(images)
            loss = sum(criterion(outputs[task], labels[task]) for task in outputs.keys())
            
            total_loss += loss.item()
            total += images.size(0)
            
            for task in outputs.keys():
                _, predicted = outputs[task].max(1)
                correct[task] += predicted.eq(labels[task]).sum().item()
    
    avg_loss = total_loss / len(loader)
    accuracies = {task: 100. * correct[task] / total for task in correct.keys()}
    
    return avg_loss, accuracies

# Training loop
print(f"\n{'='*70}")
print("TRAINING")
print(f"{'='*70}\n")

history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
best_val_loss = float('inf')

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
    val_loss, val_acc = validate(model, val_loader, criterion, DEVICE)
    
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)
    
    print(f"\nTrain Loss: {train_loss:.4f}")
    print(f"Train Acc - Category: {train_acc['category']:.2f}%, Gender: {train_acc['gender']:.2f}%, "
          f"Material: {train_acc['material']:.2f}%, Price: {train_acc['price_range']:.2f}%")
    print(f"Val Loss: {val_loss:.4f}")
    print(f"Val Acc - Category: {val_acc['category']:.2f}%, Gender: {val_acc['gender']:.2f}%, "
          f"Material: {val_acc['material']:.2f}%, Price: {val_acc['price_range']:.2f}%")
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), Path(MODEL_DIR) / 'best_multilabel_model.pth')
        print("‚úì Saved best model")

print(f"\n{'='*70}")
print("TRAINING COMPLETE!")
print(f"{'='*70}")

## 8. Visualize Results

Plot training history.

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot loss
axes[0].plot(history['train_loss'], label='Train Loss', marker='o')
axes[0].plot(history['val_loss'], label='Val Loss', marker='o')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True)

# Plot accuracy for each task
tasks = ['category', 'gender', 'material', 'price_range']
for task in tasks:
    train_accs = [epoch_acc[task] for epoch_acc in history['train_acc']]
    val_accs = [epoch_acc[task] for epoch_acc in history['val_acc']]
    axes[1].plot(val_accs, label=task.replace('_', ' ').title(), marker='o')

axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Validation Accuracy by Task')
axes[1].legend()
axes[1].grid(True)

plt.tight_layout()
plt.savefig(Path(MODEL_DIR) / 'training_history.png', dpi=150, bbox_inches='tight')
plt.show()

print("‚úì Training curves saved")

## 9. Test Predictions

Visualize predictions on sample images.

In [None]:
# Create reverse mappings
category_names = {v: k for k, v in category_labels.items()}
gender_names = {v: k for k, v in gender_labels.items()}
material_names = {v: k for k, v in material_labels.items()}
price_names = {v: k for k, v in price_labels.items()}

# Load best model
model.load_state_dict(torch.load(Path(MODEL_DIR) / 'best_multilabel_model.pth'))
model.eval()

# Get some samples
sample_indices = np.random.choice(len(val_dataset), 8, replace=False)

fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

for idx, sample_idx in enumerate(sample_indices):
    image, labels = val_dataset[sample_idx]
    
    # Predict
    with torch.no_grad():
        image_batch = image.unsqueeze(0).to(DEVICE)
        outputs = model(image_batch)
        
        predictions = {
            'category': category_names[outputs['category'].argmax(1).item()],
            'gender': gender_names[outputs['gender'].argmax(1).item()],
            'material': material_names[outputs['material'].argmax(1).item()],
            'price_range': price_names[outputs['price_range'].argmax(1).item()]
        }
        
        true_labels = {
            'category': category_names[labels['category']],
            'gender': gender_names[labels['gender']],
            'material': material_names[labels['material']],
            'price_range': price_names[labels['price_range']]
        }
    
    # Display
    img_path = val_df.iloc[sample_idx]['filepath']
    img = Image.open(img_path)
    axes[idx].imshow(img)
    axes[idx].axis('off')
    
    title = f"True: {true_labels['category']}\n"
    title += f"Pred: {predictions['category']}\n"
    title += f"Mat: {predictions['material']} | {predictions['gender']}"
    
    color = 'green' if predictions['category'] == true_labels['category'] else 'red'
    axes[idx].set_title(title, fontsize=9, color=color)

plt.tight_layout()
plt.savefig(Path(MODEL_DIR) / 'predictions_sample.png', dpi=150, bbox_inches='tight')
plt.show()

print("‚úì Sample predictions visualized")