# Two-Tower Clothing Preference Model + TinyLLaMA Chat Interface

This notebook implements an end-to-end pipeline for clothing preference prediction with LLM reasoning and a Gradio web interface.

Environment: Optimized for RunPod GPU instances (≤16GB VRAM).

In [None]:
# =============================================================================
# DEPENDENCY INSTALLATION FOR RUNPOD (Including TinyLLaMA)
# =============================================================================
import subprocess, sys

def pip_install(pkg):
    """Install or upgrade ``pkg`` with pip for the running interpreter.

    Executes ``<python> -m pip install <pkg> --upgrade --no-cache-dir``.
    A non-zero pip exit status is reported on stdout instead of being raised.
    """
    command = [sys.executable, '-m', 'pip', 'install', pkg, '--upgrade', '--no-cache-dir']
    try:
        subprocess.check_call(command)
    except subprocess.CalledProcessError as e:
        print(f'❌ Failed: {pkg} -> {e}')
    else:
        print(f'✅ Installed: {pkg}')

# IMPORTANT: Pin Pillow to a compatible version to avoid EXIF import errors
packages = [
    'torch', 'torchvision', 'numpy', 'pandas', 'matplotlib', 'seaborn',
    'pillow>=10.2.0,<11',
    'transformers>=4.30.0', 'accelerate>=0.20.0', 'tokenizers>=0.13.0', 'sentencepiece',
    'typing_extensions>=4.12.2', 'gradio>=4.0.0', 'fastapi>=0.100.0', 'uvicorn[standard]', 'websockets', 'httpx'
]
# Requirements whose version constraint matters and cannot be verified by a
# plain import check, so they are (re)installed on every run:
#   - pillow: importable as `PIL`, never as `pillow`; the pin must be enforced
#   - typing_extensions: must be new enough to provide TypeIs
_FORCE_INSTALL = {'pillow', 'typing_extensions'}
print('🚀 Installing required packages for RunPod (including TinyLLaMA & Gradio)...')
for p in packages:
    # Reduce the requirement string to a bare module name: drop extras
    # ('uvicorn[standard]') and any version specifier ('>=', '<', '==', ...).
    name = p
    for _sep in ('[', '>', '<', '=', '!', '~'):
        name = name.split(_sep)[0]
    name = name.strip().replace('-', '_')
    if name in _FORCE_INSTALL:
        print(f'📦 Forcing upgrade of {name} ...')
        pip_install(p)
        continue
    try:
        __import__(name)
        print(f'✅ {name} already installed')
    except Exception:
        # Any import failure (missing or broken install) triggers an install.
        print(f'📦 Installing {p} ...')
        pip_install(p)
print('✅ Dependencies ready')

In [None]:
# =============================================================================
# MAIN IMPORTS AND DEVICE SETUP
# =============================================================================
import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision, torchvision.transforms as transforms
import numpy as np, random, warnings
import matplotlib; matplotlib.use('Agg')
from PIL import Image
warnings.filterwarnings('ignore')

def setup_device():
    """Pick the torch device for this session and announce the choice.

    Returns ``torch.device('cuda')`` when CUDA is available (and enables the
    cuDNN benchmark mode), otherwise ``torch.device('cpu')``.
    """
    if not torch.cuda.is_available():
        print('❌ CUDA not available - CPU')
        return torch.device('cpu')
    print('🚀 Using CUDA')
    torch.backends.cudnn.benchmark = True
    return torch.device('cuda')

device = setup_device()
print('=' * 80)

## 📁 Data Preparation

In [None]:
# =============================================================================
# IMAGE/CHANNEL CONFIGURATION
# =============================================================================
IMAGE_CHANNELS = 3  # set to 1 to remain pure grayscale; 3 replicates grayscale to RGB
IMG_SIZE = 28       # keep 28x28 for Fashion-MNIST; can change if using a larger backbone later

def _norm_tuples(ch):
    mean = tuple([0.5]*ch)
    std = tuple([0.5]*ch)
    return mean, std

NORM_MEAN, NORM_STD = _norm_tuples(IMAGE_CHANNELS)
print(f"🧩 Image config -> channels: {IMAGE_CHANNELS}, size: {IMG_SIZE}x{IMG_SIZE}")

In [None]:
# =============================================================================
# CUSTOM DATA ROOT CONFIGURATION
# =============================================================================
import os
from pathlib import Path

# Default to a folder next to this notebook: projects/clothing-preference-model/data/custom
# Since the working directory is typically the notebook's folder, use a relative path
CUSTOM_DATA_ROOT = os.environ.get('CUSTOM_DATA_ROOT', str(Path('data') / 'custom'))
print(f"📂 CUSTOM_DATA_ROOT = {CUSTOM_DATA_ROOT}")

In [None]:
# =============================================================================
# DATA PREPARATION
# =============================================================================
# Display names for the ten Fashion-MNIST classes; the rest of this file
# indexes this list by class id.
FASHION_CLASSES = ['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']
print('📥 Loading Fashion-MNIST...')
# Dataset preprocessing: when IMAGE_CHANNELS == 3 the grayscale image is
# replicated to three channels, otherwise it passes through unchanged; then it
# is converted to a tensor and normalized with NORM_MEAN / NORM_STD.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=IMAGE_CHANNELS) if IMAGE_CHANNELS==3 else transforms.Lambda(lambda x: x),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD)
])
# download=True fetches the data into ./data when it is not already there.
train_dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset  = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
print(f'✅ Train: {len(train_dataset)} | Test: {len(test_dataset)} | Channels: {IMAGE_CHANNELS}')

## 🧩 Custom User Image Data (Optional)

> Folder layout:
```
projects/clothing-preference-model/data/custom/
  user_0/
    likes/        # images the user likes
    dislikes/     # images the user dislikes
    # optional: likes/Dress/*.jpg (class subfolders)
```
- User id is inferred from folder name: user_0 -> 0
- If class subfolders are used (matching entries in FASHION_CLASSES), we train the classifier too; else class loss is skipped.

In [None]:
# =============================================================================
# CUSTOM USER DATASET HELPERS
# =============================================================================
import os, glob
from PIL import Image as PILImage

def _class_from_path(path: str):
    """Infer a class index from the folder names contained in ``path``.

    Path components are scanned from the innermost outwards and the first one
    that names an entry of ``FASHION_CLASSES`` wins. Matching ignores case.

    A class name containing ``/`` (``'T-shirt/top'``) can never equal a single
    path component, so for such names each slash-separated part (``T-shirt``,
    ``top``) and the ``_`` / ``-`` joined spellings (``T-shirt_top``,
    ``T-shirt-top``) are accepted as aliases.

    Returns:
        The index into ``FASHION_CLASSES``, or ``None`` when nothing matches.
    """
    lookup = {}
    # Register the real class names first so an alias can never shadow them.
    for idx, cls_name in enumerate(FASHION_CLASSES):
        lookup.setdefault(cls_name.lower(), idx)
    for idx, cls_name in enumerate(FASHION_CLASSES):
        if '/' not in cls_name:
            continue
        aliases = cls_name.split('/') + [cls_name.replace('/', '_'), cls_name.replace('/', '-')]
        for alias in aliases:
            lookup.setdefault(alias.lower(), idx)

    parts = os.path.normpath(path).split(os.sep)
    for part in reversed(parts):
        idx = lookup.get(part.lower())
        if idx is not None:
            return idx
    return None

def _pil_to_tensor(pil_img: PILImage.Image):
    """Convert a PIL image into a normalized model-input tensor.

    The image is converted to the mode matching IMAGE_CHANNELS ('RGB' for 3,
    'L' for 1), resized to IMG_SIZE x IMG_SIZE and pushed through the shared
    ``_def_transform`` pipeline when that global exists; otherwise an
    equivalent pipeline is built on the spot.
    """
    wanted_mode = {3: 'RGB', 1: 'L'}.get(IMAGE_CHANNELS)
    if wanted_mode is not None and pil_img.mode != wanted_mode:
        pil_img = pil_img.convert(wanted_mode)
    pil_img = pil_img.resize((IMG_SIZE, IMG_SIZE))

    if '_def_transform' in globals():
        return _def_transform(pil_img)

    # Fallback: same steps as the dataset transform.
    if IMAGE_CHANNELS == 3:
        first_step = transforms.Grayscale(num_output_channels=IMAGE_CHANNELS)
    else:
        first_step = transforms.Lambda(lambda x: x)
    pipeline = transforms.Compose([
        first_step,
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ])
    return pipeline(pil_img)

def load_custom_user_samples(root_dir=None) -> list:
    """Load per-user like/dislike images from a folder tree.

    Expected layout: ``<root_dir>/user_<id>/{likes,dislikes}/**/<image file>``.
    Files are searched recursively; only .jpg/.jpeg/.png/.bmp/.webp are read.

    Args:
        root_dir: Root folder. Defaults to the global ``CUSTOM_DATA_ROOT``
            when defined, else 'projects/clothing-preference-model/data/custom'.

    Returns:
        A list of dicts: {'user_id','image','preference','label'(optional)}.
        ``preference`` is 1 for ``likes`` and 0 for ``dislikes``; ``label`` is
        present only when ``_class_from_path`` recognises a class folder.
        An empty list is returned when ``root_dir`` does not exist.
    """
    if root_dir is None:
        root_dir = CUSTOM_DATA_ROOT if 'CUSTOM_DATA_ROOT' in globals() else 'projects/clothing-preference-model/data/custom'
    samples = []
    if not os.path.isdir(root_dir):
        print(f"(custom) root not found: {root_dir}")
        return samples
    image_exts = ('.jpg','.jpeg','.png','.bmp','.webp')
    users = sorted(d for d in os.listdir(root_dir) if d.startswith('user_') and os.path.isdir(os.path.join(root_dir, d)))
    for ud in users:
        try:
            uid = int(ud.split('_')[-1])
        except ValueError:
            # Folder suffix is not an integer (e.g. 'user_alice'): not a user id.
            continue
        if uid < 0:
            # Negative ids cannot index the user embedding table.
            print(f'⚠️ Skip folder {ud}: negative user id')
            continue
        like_dir = os.path.join(root_dir, ud, 'likes')
        dislike_dir = os.path.join(root_dir, ud, 'dislikes')
        for lbl, ddir in [(1, like_dir), (0, dislike_dir)]:
            if not os.path.isdir(ddir):
                continue
            for fp in glob.glob(os.path.join(ddir, '**', '*.*'), recursive=True):
                if not fp.lower().endswith(image_exts):
                    continue
                try:
                    # The context manager closes the file handle; convert()
                    # returns an independent, fully loaded copy.
                    with PILImage.open(fp) as src:
                        pil = src.convert('RGB' if IMAGE_CHANNELS==3 else 'L')
                    tensor = _pil_to_tensor(pil)
                    cls_idx = _class_from_path(fp)
                    entry = {'user_id': uid, 'image': tensor, 'preference': lbl}
                    if cls_idx is not None:
                        entry['label'] = int(cls_idx)
                    samples.append(entry)
                except Exception as e:
                    # Best effort: one unreadable file must not abort the load.
                    print(f'⚠️ Skip {fp}: {e}')
    print(f'📦 Loaded custom samples: {len(samples)} from {root_dir}')
    return samples

## 👤 User Profiles (Likes/Dislikes/Style)

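> We define rich user profiles with names, preferred styles, and per-class preferences. These profiles drive the synthetic training labels (used when no custom user data is found) and act as a prior that is fused with the model score at inference time.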

In [None]:
# =============================================================================
# USER PROFILES CONFIG
# =============================================================================
from dataclasses import dataclass
from typing import List, Dict, Any
import math

@dataclass
class UserProfile:
    """Static description of one user's taste, expressed over class indices."""
    user_id: int
    name: str
    style: str  # e.g., 'casual', 'formal', 'sporty', 'minimal'
    likes: List[int]  # class indices
    dislikes: List[int]  # class indices
    neutral: List[int]  # class indices

# Hand-written profiles keyed by user id. The integers in likes / dislikes /
# neutral are class indices (positions in FASHION_CLASSES).
USER_PROFILES: Dict[int, UserProfile] = {
    0: UserProfile(
        user_id=0,
        name='Joan Doe',
        style='casual, comfortable',
        likes=[0, 2, 3, 6],  # T-shirt/top, Pullover, Dress, Shirt
        dislikes=[4, 8],     # Coat, Bag (prefers wearable over accessories/formal layers)
        neutral=[1, 5, 7, 9],  # Trouser, Sandal, Sneaker, Ankle boot
    ),
    1: UserProfile(
        user_id=1,
        name='Alex Kim',
        style='sporty athleisure',
        likes=[1, 5, 7, 9],  # Trouser, Sandal, Sneaker, Ankle boot
        dislikes=[3, 8],
        neutral=[0, 2, 4, 6],
    ),
    2: UserProfile(
        user_id=2,
        name='Riya Singh',
        style='smart casual',
        likes=[3, 4, 6],
        dislikes=[5],
        neutral=[0, 1, 2, 7, 8, 9],
    ),
    3: UserProfile(
        user_id=3,
        name='Diego López',
        style='minimal formal',
        likes=[1, 4, 6],
        dislikes=[5, 9],
        neutral=[0, 2, 3, 7, 8],
    ),
}

def build_preference_lookup(user_profiles: Dict[int, UserProfile]) -> Dict[int, Dict[int, float]]:
    """Build a ``user_id -> class_id -> base preference weight`` table.

    Every class starts at 0.5; liked classes become 0.9 and disliked classes
    0.1. A class listed as neutral is capped at 0.5, so "neutral" overrides
    "like" but leaves a "dislike" weight untouched.
    """
    table: Dict[int, Dict[int, float]] = {}
    for user_id, profile in user_profiles.items():
        weights = dict.fromkeys(range(len(FASHION_CLASSES)), 0.5)
        for liked in profile.likes:
            weights[liked] = 0.9
        for disliked in profile.dislikes:
            weights[disliked] = 0.1
        for cls in profile.neutral:
            current = weights.get(cls, 0.5)
            weights[cls] = min(current, 0.5)
        table[user_id] = weights
    return table

PROFILE_PREFS = build_preference_lookup(USER_PROFILES)
print('👤 Loaded user profiles:', {uid: p.name for uid, p in USER_PROFILES.items()})

## 🏗️ Two-Tower Model

In [None]:
# =============================================================================
# MODEL
# =============================================================================
class TwoTowerModel(nn.Module):
    """Two-tower preference model with an auxiliary item classifier.

    * User tower: id embedding -> MLP -> LayerNorm, output size ``embedding_dim``.
    * Item tower: small CNN -> flattened 128*4*4 features -> MLP -> LayerNorm.
    * Classifier: a separate MLP head on the same CNN features.

    Args:
        embedding_dim: Size of the user and item embeddings.
        num_users: Number of rows in the user embedding table; user ids must
            lie in ``[0, num_users)``.
        hidden_dim: Width of the hidden layer in both towers and the classifier.
        num_classes: Number of classifier outputs (default 10, as before).
        in_channels: Input image channels. ``None`` keeps the previous
            behaviour: the global ``IMAGE_CHANNELS`` when defined, else 1.
    """

    def __init__(self, embedding_dim=128, num_users=1000, hidden_dim=256, num_classes=10, in_channels=None):
        super().__init__()
        if in_channels is None:
            in_channels = IMAGE_CHANNELS if 'IMAGE_CHANNELS' in globals() else 1
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.user_tower = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden_dim, embedding_dim), nn.LayerNorm(embedding_dim)
        )
        # AdaptiveAvgPool2d fixes the spatial output at 4x4, so the feature
        # size below does not depend on the input resolution.
        self.item_cnn = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d((4, 4))
        )
        self._feat_dim = 128 * 4 * 4
        self.item_tower = nn.Sequential(
            nn.Linear(self._feat_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden_dim, embedding_dim), nn.LayerNorm(embedding_dim)
        )
        self.classifier = nn.Sequential(
            nn.Linear(self._feat_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden_dim, num_classes)
        )

    def _item_features(self, item_images):
        """Run the CNN and flatten each sample to a ``_feat_dim`` vector."""
        # flatten(1) also handles empty batches, which view(B, -1) rejects.
        return self.item_cnn(item_images).flatten(1)

    def forward(self, user_ids, item_images):
        """Return ``(user_embedding, item_embedding)``, both un-normalized."""
        u = self.user_tower(self.user_embedding(user_ids))
        feat = self._item_features(item_images)
        v = self.item_tower(feat)
        return u, v

    def predict_preference(self, user_ids, item_images):
        """Return one preference score per (user, image) pair.

        The score is ``sigmoid(cosine_similarity(u, v))``. Because cosine
        similarity lies in [-1, 1], scores are confined to roughly
        [0.269, 0.731]; downstream thresholds should account for that.
        """
        u, v = self.forward(user_ids, item_images)
        u, v = F.normalize(u, p=2, dim=1), F.normalize(v, p=2, dim=1)
        return torch.sigmoid(torch.sum(u * v, dim=1))

    def classify(self, item_images):
        """Return raw class logits, one row per image."""
        feat = self._item_features(item_images)
        return self.classifier(feat)

# Determine num_users from profiles if available
# The embedding table gets at least 1000 rows and is enlarged only when a
# profile id would not fit (max id + 1).
_num_users = 1000
try:
    if 'USER_PROFILES' in globals() and len(USER_PROFILES) > 0:
        _num_users = max(1000, max(USER_PROFILES.keys())+1)
except Exception:
    # Best effort: keep the default capacity if the profiles are malformed.
    pass

# Build the model on the selected device and report its parameter count.
model = TwoTowerModel(num_users=_num_users).to(device)
print(f'✅ Model params: {sum(p.numel() for p in model.parameters()):,} | Users capacity: {_num_users} | InCh: {IMAGE_CHANNELS}')

## 🎯 Training

In [None]:
# =============================================================================
# TRAINING SETUP
# =============================================================================
def create_preference_data(dataset, num_users=1000, num_samples=2000):
    """Generate synthetic (user, image, like/dislike) training samples.

    For every sample a user and a dataset item are drawn at random.

    * Profile mode (global ``PROFILE_PREFS`` is a non-empty dict): the user is
      drawn from the profile ids and the like-probability is the profile
      weight for the item's class plus uniform jitter in [-0.1, 0.1],
      clipped to [0, 1].
    * Fallback mode: users fall into three taste groups by ``user_id % 3``
      (classes [0,1,2], [3,4,6], [5,7,9]). An item from the user's group is
      liked with probability 0.8, any other item with probability 0.3.

    Args:
        dataset: Indexable collection of ``(image, label)`` pairs.
        num_users: Upper bound (exclusive) for user ids in fallback mode.
        num_samples: Number of samples to generate.

    Returns:
        A list of dicts with keys 'user_id', 'image', 'preference', 'label'.
    """
    using_profiles = 'PROFILE_PREFS' in globals() and isinstance(PROFILE_PREFS, dict) and len(PROFILE_PREFS) > 0
    # Hoisted out of the loop: the candidate ids do not change per sample.
    profile_ids = list(PROFILE_PREFS.keys()) if using_profiles else []
    # Keyed by the group number itself so each group maps to its own classes.
    fallback_groups = {0: [0, 1, 2], 1: [3, 4, 6], 2: [5, 7, 9]}
    prefs = []
    for _ in range(num_samples):
        uid = random.choice(profile_ids) if using_profiles else random.randint(0, num_users-1)
        idx = random.randint(0, len(dataset)-1)
        img, label = dataset[idx]
        label = int(label)
        if using_profiles:
            base = PROFILE_PREFS.get(uid, {}).get(label, 0.5)
            jitter = random.uniform(-0.1, 0.1)
            prob = min(max(base + jitter, 0.0), 1.0)
            pref = 1 if random.random() < prob else 0
        else:
            liked_classes = fallback_groups[uid % 3]
            # One draw per sample: > 0.2 gives P(like)=0.8, > 0.7 gives 0.3.
            threshold = 0.2 if label in liked_classes else 0.7
            pref = 1 if random.random() > threshold else 0
        prefs.append({'user_id': uid, 'image': img, 'preference': pref, 'label': label})
    return prefs

class PreferenceDataset(Dataset):
    """Wrap a list of preference dicts as a torch ``Dataset``.

    Each item is ``(user_id, image, preference, label)``; ``label`` is -1 when
    the source dict has no 'label' key.
    """

    def __init__(self, prefs):
        self.prefs = prefs

    def __len__(self):
        return len(self.prefs)

    def __getitem__(self, i):
        record = self.prefs[i]
        user = torch.tensor(record['user_id'], dtype=torch.long)
        image = record['image']
        liked = torch.tensor(record['preference'], dtype=torch.float)
        class_idx = record['label'] if 'label' in record else -1
        return (user, image, liked, torch.tensor(class_idx, dtype=torch.long))

print('🎯 Preparing training data...')
# Try custom samples first
custom_samples = load_custom_user_samples()
if len(custom_samples) > 0:
    preference_data = custom_samples
    print('✅ Using custom user data for training')
else:
    # No custom images: fall back to synthetic samples drawn from Fashion-MNIST.
    print('ℹ️ No custom data found; generating profile-based synthetic samples')
    _num_users_create = _num_users if '_num_users' in globals() else 1000
    preference_data = create_preference_data(train_dataset, num_users=_num_users_create, num_samples=3000)

train_loader = DataLoader(PreferenceDataset(preference_data), batch_size=32, shuffle=True)

# Module-level loss functions and optimizer; train_model() reads these globals.
bce = nn.BCELoss()
ce = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# NOTE(review): "Using profiles" only reports that PROFILE_PREFS exists; it
# prints 'yes' even when custom data was chosen above.
print(f"✅ Training setup complete | Samples: {len(preference_data)} | Using profiles: {'yes' if 'PROFILE_PREFS' in globals() else 'no'}")

In [None]:
# =============================================================================
# TRAIN LOOP
# =============================================================================
def train_model(model, loader, epochs=3, alpha_cls=0.3):
    """Jointly train the preference score and the item classifier.

    Per batch the loss is ``BCE(preference) + alpha_cls * CE(class)``; the
    class term only covers rows whose label is >= 0 (-1 marks "no label") and
    is dropped when a batch has no labelled rows. Progress is printed every
    20 batches and once per epoch.

    Uses the module-level ``optimizer``, ``bce`` and ``ce`` objects.

    Args:
        model: Model exposing ``predict_preference(uids, imgs)`` and
            ``classify(imgs)``.
        loader: Iterable of ``(user_ids, images, preferences, labels)`` batches.
        epochs: Number of passes over ``loader``.
        alpha_cls: Weight of the classification loss.
    """
    model.train()
    # Send batches to wherever the model lives; fall back to the global
    # device only for a model without parameters.
    first_param = next(model.parameters(), None)
    dev = first_param.device if first_param is not None else device
    for ep in range(epochs):
        ep_loss, n_batches = 0.0, 0
        correct_pref, total_pref = 0, 0
        correct_cls, total_cls = 0, 0
        for b, (uids, imgs, prefs, labels) in enumerate(loader):
            uids, imgs, prefs, labels = uids.to(dev), imgs.to(dev), prefs.to(dev), labels.to(dev)
            optimizer.zero_grad()
            preds_pref = model.predict_preference(uids, imgs)
            logits = model.classify(imgs)
            loss_pref = bce(preds_pref, prefs)
            # Class loss if labels are provided (>=0)
            has_cls = (labels >= 0)
            any_cls = bool(has_cls.any())
            if any_cls:
                loss_cls = ce(logits[has_cls], labels[has_cls])
                loss = loss_pref + alpha_cls * loss_cls
            else:
                loss_cls = torch.tensor(0.0, device=dev)
                loss = loss_pref
            loss.backward()
            optimizer.step()
            ep_loss += loss.item()
            n_batches += 1
            pred_bin = (preds_pref > 0.5).float()
            total_pref += prefs.size(0)
            correct_pref += (pred_bin == prefs).sum().item()
            if any_cls:
                pred_cls = logits.argmax(dim=1)
                # Count only where class labels exist
                correct_cls += (pred_cls[has_cls] == labels[has_cls]).sum().item()
                total_cls += has_cls.sum().item()
            if b % 20 == 0:
                print(f'Ep {ep+1} B{b} Loss {loss.item():.4f} | Pref {loss_pref.item():.4f} | Cls {loss_cls.item():.4f}')
        msg_cls = f" | ClsAcc {(100.*correct_cls/max(1,total_cls)):.2f}%" if total_cls>0 else ""
        # max(1, ...) keeps the summary printable when the loader is empty.
        print(f'Epoch {ep+1} - Loss {ep_loss/max(1, n_batches):.4f} | PrefAcc {(100.*correct_pref/max(1, total_pref)):.2f}%{msg_cls}')
        if dev.type == 'cuda':
            torch.cuda.empty_cache()

train_model(model, train_loader)
print('✅ Training complete')

## 🤖 LLM Integration (TinyLLaMA default; gpt-oss-20b optional)

In [None]:
# =============================================================================
# CONFIGURABLE LLM (TinyLLaMA default; gpt-oss-20b optional)
# =============================================================================
# Keep a provider chosen in an earlier cell/run; default to TinyLLaMA.
LLM_PROVIDER = globals().get('LLM_PROVIDER', 'tinyllama')  # 'tinyllama' | 'gpt-oss-20b'

# Everything LLM-related lives in one try block: any failure (missing package,
# download error, out of memory, ...) drops to the rule-based fallback below.
try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    _llama_ok = False
    _llm_name = None
    print(f'🤖 Initializing LLM provider: {LLM_PROVIDER}')
    if LLM_PROVIDER.lower() in ['gpt-oss-20b', 'gpt_oss_20b', 'oss20b']:
        _llm_name = 'openai/gpt-oss-20b'
    else:
        _llm_name = 'TinyLlama/TinyLlama-1.1B-Chat-v1.0'

    _tok = AutoTokenizer.from_pretrained(_llm_name)
    # Half precision and automatic device placement on CUDA; float32 on CPU.
    _llm = AutoModelForCausalLM.from_pretrained(
        _llm_name,
        torch_dtype=(torch.float16 if device.type=='cuda' else torch.float32),
        device_map=('auto' if device.type=='cuda' else None)
    )
    # Reuse the EOS token for padding when the tokenizer defines no pad token.
    if _tok.pad_token is None:
        _tok.pad_token = _tok.eos_token
    _llama_ok = True
    print(f'✅ LLM ready: {_llm_name}')

    def _format_prompt(system_msg: str, user_msg: str):
        """Wrap a system and a user message in the prompt format of the loaded model."""
        # Support both chat-tuned TinyLLaMA and generic causal models
        if 'TinyLlama' in _llm_name:
            return f"<|system|>\n{system_msg}\n\n<|user|>\n{user_msg}\n\n<|assistant|>"
        else:
            # For gpt-oss-20b (causal), use a simple instruction format
            return f"System: {system_msg}\nUser: {user_msg}\nAssistant:"

    def generate_explanation(score, item_class, user_id):
        """Ask the LLM for a 1-2 sentence explanation of a preference score.

        Profile details (name, style, liked/disliked class indices) are added
        to the prompt when ``user_id`` has an entry in USER_PROFILES.
        Returns the text after the assistant marker of the decoded output.
        """
        user_context = ''
        try:
            if 'USER_PROFILES' in globals() and int(user_id) in USER_PROFILES:
                p = USER_PROFILES[int(user_id)]
                user_context = f"User name: {p.name}. Style: {p.style}. Likes: {p.likes}. Dislikes: {p.dislikes}. "
        except Exception:
            # Profile context is optional; continue without it.
            pass
        sys_prompt = 'You are a fashion AI assistant explaining clothing preferences.'
        user_prompt = (
            f"{user_context}A user (ID: {user_id}) has a preference score of {score:.2f} for a {item_class}. "
            "Explain this preference in 1-2 sentences, considering that 1.0 is highest preference and 0.0 is lowest."
        )
        prompt = _format_prompt(sys_prompt, user_prompt)
        inputs = _tok(prompt, return_tensors='pt', truncation=True, max_length=512)
        if device.type=='cuda':
            inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            out = _llm.generate(
                **inputs,
                max_new_tokens=(120 if 'gpt-oss-20b' in _llm_name else 50),
                temperature=0.7,
                do_sample=True,
                pad_token_id=_tok.eos_token_id
            )
        # The decoded text is split on the assistant marker and only the part
        # after the last marker is returned.
        text = _tok.decode(out[0], skip_special_tokens=True)
        if 'TinyLlama' in _llm_name and '<|assistant|>' in text:
            return text.split('<|assistant|>')[-1].strip()
        # For generic models, strip trailing artifacts
        return text.split('Assistant:')[-1].strip() if 'Assistant:' in text else text.strip()

except Exception as e:
    print(f'⚠️ LLM not available: {e}')
    # Rule-based replacement with the same signature; user_id is unused here.
    def generate_explanation(score, item_class, user_id):
        return ('Very strong' if score>0.7 else 'Moderate' if score>0.4 else 'Low') + f' interest in {item_class} (score {score:.2f}).'
    _llama_ok = False

# Report the model name only when it actually loaded: a failed load leaves
# _llm_name set while the rule-based fallback is the one in use.
print('🧠 Explanation system:', _llm_name if globals().get('_llama_ok', False) else 'Fallback')

In [None]:
# =============================================================================
# PILLOW SANITY CHECK (EXIF & TIFF support)
# =============================================================================
# Import Pillow (including its TIFF plugin) and call getexif() on a throwaway
# image; on any failure reinstall the pinned Pillow range and retry the import.
try:
    from PIL import Image, TiffImagePlugin
    from PIL import __version__ as PIL_VERSION
    print(f"🖼️ Pillow version: {PIL_VERSION}")
    # Create a simple grayscale image and call getexif safely
    im = Image.new('L', (8, 8), color=128)
    exif = getattr(im, 'getexif', lambda: {})()
    print("✅ Pillow EXIF access OK")
except Exception as e:
    print(f"❌ Pillow sanity check failed: {e}")
    print("🔧 Reinstalling Pillow to a compatible version...")
    import subprocess, sys
    # subprocess.call does not raise on a non-zero exit; the import below is the check.
    subprocess.call([sys.executable, '-m', 'pip', 'install', 'pillow>=10.2.0,<11', '--upgrade', '--no-cache-dir'])
    try:
        from PIL import Image
        print("✅ Pillow reinstalled successfully")
    except Exception as e2:
        print(f"❌ Pillow import still failing: {e2}")

## 🧪 Inference & Tests

In [None]:
# =============================================================================
# TESTING
# =============================================================================
def _user_display(uid: int):
    if 'USER_PROFILES' in globals() and uid in USER_PROFILES:
        p = USER_PROFILES[uid]
        return f"{p.name} (style: {p.style})"
    return f"User {uid}"

def _fuse_with_profile(uid: int, cls_idx: int, score: float):
    base = PROFILE_PREFS.get(uid, {}).get(int(cls_idx), 0.5) if 'PROFILE_PREFS' in globals() else 0.5
    w = 0.6 if base < 0.3 else 0.5 if 0.3 <= base <= 0.7 else 0.6
    return (w*base + (1-w)*score)

def manual_predict(user_id=0, sample_idx=0, pil_image=None):
    """Score one item for one user, print the result dict and return it.

    The item is ``pil_image`` when given, otherwise the test-set sample at
    ``sample_idx`` (wrapped around the dataset length). The raw model score is
    fused with the profile prior of the predicted class, turned into a
    like / dislike / uncertain verdict and explained in text.
    """
    model.eval()
    if pil_image is None:
        img, _ = test_dataset[sample_idx % len(test_dataset)]
        x = img.unsqueeze(0).to(device)
    else:
        x = _preprocess_image_for_model(pil_image)

    with torch.no_grad():
        uid = int(user_id)
        user_batch = torch.tensor([uid]).to(device)
        score_raw = model.predict_preference(user_batch, x).item()
        pred_cls = int(model.classify(x).argmax(dim=1).item())

    fused = _fuse_with_profile(uid, pred_cls, score_raw)
    cls_name = FASHION_CLASSES[pred_cls]
    expl = generate_explanation(fused, cls_name, uid)

    if 'PROFILE_PREFS' in globals():
        prior = PROFILE_PREFS.get(uid, {}).get(pred_cls, 0.5)
    else:
        prior = 0.5

    if prior < 0.3 and fused < 0.5:
        verdict = 'dislike'
    elif fused > 0.6:
        verdict = 'like'
    else:
        verdict = 'uncertain'

    res = {
        'user': _user_display(uid),
        'predicted_class': cls_name,
        'model_score_raw': round(score_raw,3),
        'profile_prior': round(prior,3),
        'fused_score': round(fused,3),
        'verdict': verdict,
        'explanation': expl
    }
    print(res)
    return res

def test_multiple_users(sample_idx=0, n=3):
    """Run manual_predict on the same test sample for user ids 0..n-1."""
    results = []
    for user in range(n):
        results.append(manual_predict(user, sample_idx))
    return results

def test_multiple_items(user_id=0, n=3):
    """Run manual_predict for one user on test samples 0..n-1."""
    results = []
    for item_idx in range(n):
        results.append(manual_predict(user_id, item_idx))
    return results

print('✅ Testing helpers ready')
manual_predict(0,0)

## 🎨 Gradio Web Interface

In [None]:
# =============================================================================
# GRADIO (with port handling + TinyLLaMA chat)
# =============================================================================
import socket

def find_free_port(start=7860, attempts=20):
    """Return the first port in ``[start, start + attempts)`` that can be bound.

    Each candidate is probed by binding a temporary TCP socket on all
    interfaces. Returns 0 when every candidate fails.
    """
    for candidate in range(start, start+attempts):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                probe.bind(('', candidate))
        except OSError:
            continue
        else:
            return candidate
    return 0

port = find_free_port()
print(f'🔌 Port selected: {port}')

# Common image preprocessing used by both prediction and chat
# Same steps as the dataset `transform`: optional grayscale replication to
# IMAGE_CHANNELS, tensor conversion, then NORM_MEAN / NORM_STD normalization.
# _pil_to_tensor() also picks this pipeline up once it is defined.
_def_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=IMAGE_CHANNELS) if IMAGE_CHANNELS==3 else transforms.Lambda(lambda x: x),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD)
])

def _preprocess_image_for_model(image):
    """Turn an uploaded image into a single-item model batch on ``device``.

    Accepts a PIL image or a numpy array. The image is converted to the mode
    matching IMAGE_CHANNELS, resized to IMG_SIZE x IMG_SIZE, run through
    ``_def_transform`` and given a leading batch dimension.

    Raises:
        ValueError: if ``image`` is None.
    """
    from PIL import Image as PILImage
    if image is None:
        raise ValueError('No image provided')
    pil = PILImage.fromarray(image) if isinstance(image, np.ndarray) else image
    target_mode = {3: 'RGB', 1: 'L'}.get(IMAGE_CHANNELS)
    if target_mode is not None and pil.mode != target_mode:
        pil = pil.convert(target_mode)
    resized = pil.resize((IMG_SIZE, IMG_SIZE))
    batch = _def_transform(resized).unsqueeze(0)
    return batch.to(device)

try:
    # Sanity check typing_extensions.TypeIs presence before importing gradio
    try:
        from typing_extensions import TypeIs  # type: ignore
        print('🔎 typing_extensions.TypeIs OK')
    except Exception as te:
        print(f'⚠️ typing_extensions.TypeIs missing: {te}. Attempting quick reinstall...')
        import subprocess, sys
        # subprocess.call ignores pip's exit status; the retry import below is
        # the real check and, if it fails, propagates to the enclosing try.
        subprocess.call([sys.executable, '-m', 'pip', 'install', 'typing_extensions>=4.12.2', '--upgrade', '--no-cache-dir'])
        from typing_extensions import TypeIs  # retry
        print('✅ typing_extensions.TypeIs available after upgrade')

    import gradio as gr
    from PIL import Image as PILImage

    def _user_display(uid: int):
        """Return '<name> (style: <style>)' for a profiled user, else 'User <uid>'.

        Same body as the module-level helper defined in the testing cell;
        this definition rebinds the name.
        """
        if 'USER_PROFILES' in globals() and uid in USER_PROFILES:
            p = USER_PROFILES[uid]
            return f"{p.name} (style: {p.style})"
        return f"User {uid}"

    def _fuse_with_profile(uid: int, cls_idx: int, score: float):
        """Blend the model score with the profile prior (default prior 0.5).

        Same body as the module-level helper defined in the testing cell.
        """
        base = PROFILE_PREFS.get(uid, {}).get(int(cls_idx), 0.5) if 'PROFILE_PREFS' in globals() else 0.5
        # Priors outside [0.3, 0.7] get weight 0.6, priors inside it get 0.5.
        w = 0.6 if base < 0.3 else 0.5 if 0.3 <= base <= 0.7 else 0.6
        return (w*base + (1-w)*score)

    def predict_ui(user_id, image):
        """Gradio callback: score ``image`` for ``user_id`` and return a result dict.

        Returns ``{'Error': ...}`` when no image is given or any step fails.
        """
        if image is None:
            return {'Error': 'Upload an image or use sample'}
        try:
            x = _preprocess_image_for_model(image)
            uid = int(user_id)
            user_batch = torch.tensor([uid]).to(device)
            with torch.no_grad():
                score_raw = model.predict_preference(user_batch, x).item()
                pred_cls = int(model.classify(x).argmax(dim=1).item())
            fused = _fuse_with_profile(uid, pred_cls, score_raw)
            user_info = _user_display(uid)
            cls_name = FASHION_CLASSES[pred_cls]
            if 'PROFILE_PREFS' in globals():
                prior = PROFILE_PREFS.get(uid, {}).get(pred_cls, 0.5)
            else:
                prior = 0.5
            if prior < 0.3 and fused < 0.5:
                verdict = 'dislike'
            elif fused > 0.6:
                verdict = 'like'
            else:
                verdict = 'uncertain'
            result = {
                'User': user_info,
                'Predicted Class': cls_name,
                'Model Score (raw)': f'{score_raw:.3f}',
                'Profile Prior': f'{prior:.3f}',
                'Fused Score': f'{fused:.3f}',
                'Verdict': verdict,
            }
            result['Explanation'] = generate_explanation(fused, cls_name, uid)
            return result
        except Exception as e:
            return {'Error': str(e)}

    def sample_img():
        """Return a random test-set image as a PIL image for the upload widget.

        Dataset tensors are normalized with NORM_MEAN / NORM_STD, so they hold
        negative values. They are mapped back to [0, 1] before conversion,
        because ToPILImage expects float pixel data in that range.
        """
        i = random.randint(0, len(test_dataset)-1)
        img, _ = test_dataset[i]
        # Undo Normalize: x * std + mean, per channel.
        mean = torch.tensor(NORM_MEAN, dtype=img.dtype).view(-1, 1, 1)
        std = torch.tensor(NORM_STD, dtype=img.dtype).view(-1, 1, 1)
        pil = transforms.ToPILImage()((img * std + mean).clamp(0.0, 1.0))
        if IMAGE_CHANNELS == 3 and pil.mode != 'RGB':
            pil = pil.convert('RGB')
        if IMAGE_CHANNELS == 1 and pil.mode != 'L':
            pil = pil.convert('L')
        return pil

    def chat_about_outfit(user_id, image, message):
        """Gradio callback: answer ``message`` about ``image`` for ``user_id``.

        Scores the image like predict_ui, builds a text context from the
        profile and the scores, then asks the LLM when one is loaded;
        otherwise a templated reply is produced from generate_explanation().
        Returns the reply text, or 'Error: ...' when any step fails.
        """
        if image is None:
            return 'Please upload an image first.'
        try:
            x = _preprocess_image_for_model(image)
            u = torch.tensor([int(user_id)]).to(device)
            with torch.no_grad():
                score_raw = model.predict_preference(u, x).item()
                logits = model.classify(x)
                pred_cls = int(logits.argmax(dim=1).item())
            fused = _fuse_with_profile(int(user_id), pred_cls, score_raw)
            cls_name = FASHION_CLASSES[pred_cls]
            prior = PROFILE_PREFS.get(int(user_id), {}).get(pred_cls, 0.5) if 'PROFILE_PREFS' in globals() else 0.5
            profile_line = ''
            if 'USER_PROFILES' in globals() and int(user_id) in USER_PROFILES:
                p = USER_PROFILES[int(user_id)]
                profile_line = f"User: {p.name} (style: {p.style})\nLikes: {p.likes} Dislikes: {p.dislikes}\n"
            # Facts handed to the LLM: profile, predicted class and the scores.
            context = (
                f"{profile_line}Predicted item class: {cls_name}\n"
                f"Model score (raw): {score_raw:.3f} | Profile prior: {prior:.3f} | Fused: {fused:.3f}\n"
                f"Interpretation: {'high' if fused>0.7 else 'moderate' if fused>0.4 else 'low'} interest\n"
            )
            # Use the LLM only when the LLM cell loaded it successfully.
            if '_llm' in globals() and '_tok' in globals() and globals().get('_llama_ok', False):
                sys_prompt = (
                    "You are a helpful fashion AI assistant. Use context and be decisive.\n"
                    "If profile strongly dislikes the predicted class and fused<0.5, clearly recommend against it and suggest alternatives."
                )
                full_prompt = _format_prompt(sys_prompt, f"CONTEXT:\n{context}\n\nQUESTION:\n{message}")
                inputs = _tok(full_prompt, return_tensors='pt', truncation=True, max_length=768)
                if device.type == 'cuda':
                    inputs = {k: v.to(device) for k, v in inputs.items()}
                with torch.no_grad():
                    out = _llm.generate(
                        **inputs,
                        max_new_tokens=(180 if 'gpt-oss-20b' in _llm_name else 120),
                        temperature=0.7,
                        do_sample=True,
                        pad_token_id=_tok.eos_token_id
                    )
                # Keep only the text after the last assistant marker.
                text = _tok.decode(out[0], skip_special_tokens=True)
                if 'TinyLlama' in _llm_name and '<|assistant|>' in text:
                    reply = text.split('<|assistant|>')[-1].strip()
                else:
                    reply = text.split('Assistant:')[-1].strip() if 'Assistant:' in text else text.strip()
            else:
                # No LLM: templated answer built from the explanation helper.
                expl = generate_explanation(fused, cls_name, int(user_id))
                if prior < 0.3 and fused < 0.5:
                    # NOTE(review): this wording is fixed and does not depend
                    # on the user's profile or the predicted class.
                    reply = f"No—this looks too formal for the user. Try casual options like tees, pullovers, or sneakers. (Fused {fused:.3f})\n{expl}"
                else:
                    reply = (
                        f"Based on fused score {fused:.3f}, the user shows "
                        f"{'high' if fused>0.7 else 'moderate' if fused>0.4 else 'low'} interest. "
                        f"Explanation: {expl}"
                    )
            return reply
        except Exception as e:
            return f'Error: {e}'

    # Two-tab UI: 'Predict' returns the JSON result of predict_ui, 'Chat'
    # returns the text reply of chat_about_outfit.
    with gr.Blocks(title='Two-Tower Clothing Preference Model') as demo:
        gr.Markdown('# 👗 Two-Tower Clothing Preference Model')
        with gr.Tabs():
            with gr.Tab('Predict'):
                with gr.Row():
                    with gr.Column():
                        uid = gr.Number(label='User ID', value=0, minimum=0, maximum=999)
                        img = gr.Image(label='Upload Clothing Image', type='pil', height=300)
                        btn = gr.Button('🔮 Predict Preference')
                        sample = gr.Button('🎲 Random Sample')
                    with gr.Column():
                        out = gr.JSON(label='Results')
                btn.click(predict_ui, inputs=[uid, img], outputs=out)
                # The sample button fills the image widget with a test-set image.
                sample.click(sample_img, outputs=img)
            with gr.Tab('Chat'):
                with gr.Row():
                    with gr.Column():
                        chat_uid = gr.Number(label='User ID', value=0, minimum=0, maximum=999)
                        chat_img = gr.Image(label='Upload Outfit Image', type='pil', height=300)
                        chat_q = gr.Textbox(label='Ask LLM', placeholder='Would user like this clothing outfit?', lines=2)
                        chat_btn = gr.Button('💬 Ask')
                    with gr.Column():
                        chat_out = gr.Textbox(label='LLM Response')
                chat_btn.click(chat_about_outfit, inputs=[chat_uid, chat_img, chat_q], outputs=chat_out)

    print('🎨 Gradio ready; launching...')
    # find_free_port() returns 0 when nothing was free; None is passed instead
    # of a port in that case.
    # NOTE(review): share=True presumably requests a public Gradio share link;
    # confirm that is intended for this deployment.
    demo.launch(server_name='0.0.0.0', server_port=(port if port>0 else None), share=True, show_error=True)

except Exception as e:
    print(f'⚠️ Gradio not available or failed: {e}')
    print('💡 If this is a typing_extensions issue, rerun the first dependency cell to upgrade to >=4.12.2 and restart the kernel.')
    print('💡 Use manual_predict(user_id, sample_idx) and chat_about_outfit(user_id, image, message) for testing')

## 📋 Summary

- Standard Jupyter JSON format (RunPod compatible)
- Two-Tower model + training + testing
- TinyLLaMA AI explanations with fallback
- Gradio interface with port handling