# Two-Tower Clothing Preference Model + TinyLLaMA Chat Interface

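This notebook implements an end-to-end pipeline for clothing preference prediction: a two-tower model trained on Fashion-MNIST, natural-language explanations generated by TinyLLaMA (with a rule-based fallback), and a Gradio web interface.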

Environment: Optimized for RunPod GPU instances (≤16GB VRAM).

In [None]:
# =============================================================================
# DEPENDENCY INSTALLATION FOR RUNPOD (Including TinyLLaMA)
# =============================================================================
import subprocess, sys

def pip_install(pkg):
    """Install or upgrade one pip requirement string into the running interpreter.

    A failing pip invocation is reported on stdout instead of raised, so one
    bad package does not abort the rest of the setup cell.
    """
    command = [sys.executable, '-m', 'pip', 'install', pkg, '--upgrade', '--no-cache-dir']
    try:
        subprocess.check_call(command)
    except subprocess.CalledProcessError as err:
        print(f'❌ Failed: {pkg} -> {err}')
        return
    print(f'✅ Installed: {pkg}')

import re

# IMPORTANT: Pin Pillow to a compatible version to avoid EXIF import errors
packages = [
    'torch', 'torchvision', 'numpy', 'pandas', 'matplotlib', 'seaborn',
    'pillow>=10.2.0,<11',
    'transformers>=4.30.0', 'accelerate>=0.20.0', 'tokenizers>=0.13.0', 'sentencepiece',
    'typing_extensions>=4.7.0', 'gradio>=4.0.0', 'fastapi>=0.100.0', 'uvicorn[standard]', 'websockets', 'httpx'
]

# pip distribution names whose importable module name differs from the package name.
_IMPORT_NAME_OVERRIDES = {'pillow': 'PIL'}


def _import_name(spec):
    """Return the module name to try importing for the pip requirement *spec*.

    Strips extras (``uvicorn[standard]`` -> ``uvicorn``) and version specifiers
    (``>=``, ``<``, ``==``, ...), then maps distributions whose module name
    differs (``pillow`` -> ``PIL``); otherwise dashes become underscores.
    """
    dist = re.split(r'[\[<>=!~,;\s]', spec, maxsplit=1)[0].strip().lower()
    return _IMPORT_NAME_OVERRIDES.get(dist, dist.replace('-', '_'))


def _has_version_cap(spec):
    """Return True when *spec* has an upper bound or exact pin that must be enforced."""
    return '<' in spec or '==' in spec


print('🚀 Installing required packages for RunPod (including TinyLLaMA & Gradio)...')
for p in packages:
    name = _import_name(p)
    if _has_version_cap(p):
        # Being importable says nothing about the installed version, so re-run
        # pip to actually enforce caps such as pillow<11.
        print(f'📌 Enforcing pin {p} ...')
        pip_install(p)
        continue
    try:
        __import__(name)
        print(f'✅ {name} already installed')
    except Exception:
        # Deliberately broad: a half-broken install can raise non-ImportError
        # exceptions at import time, and reinstalling is the remedy either way.
        # Note: lower bounds (>=) are not checked for importable packages.
        print(f'📦 Installing {p} ...')
        pip_install(p)
print('✅ Dependencies ready')

In [None]:
# =============================================================================
# MAIN IMPORTS AND DEVICE SETUP
# =============================================================================
import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision, torchvision.transforms as transforms
import numpy as np, random, warnings
import matplotlib; matplotlib.use('Agg')
from PIL import Image
warnings.filterwarnings('ignore')

def setup_device():
    """Return the CUDA device when available (enabling cuDNN autotuning), else the CPU."""
    if not torch.cuda.is_available():
        print('❌ CUDA not available - CPU')
        return torch.device('cpu')
    print('🚀 Using CUDA')
    # Let cuDNN benchmark convolution algorithms; pays off when input shapes stay constant.
    torch.backends.cudnn.benchmark = True
    return torch.device('cuda')


device = setup_device()
print('=' * 80)

## 📁 Data Preparation

In [None]:
# =============================================================================
# DATA PREPARATION
# =============================================================================
# Fashion-MNIST label names, indexed by class id 0..9.
FASHION_CLASSES = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
]
print('📥 Loading Fashion-MNIST...')
# ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) rescales them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_dataset = torchvision.datasets.FashionMNIST(
    root='./data', train=True, transform=transform, download=True,
)
test_dataset = torchvision.datasets.FashionMNIST(
    root='./data', train=False, transform=transform, download=True,
)
print(f'✅ Train: {len(train_dataset)} | Test: {len(test_dataset)}')

## 🏗️ Two-Tower Model

In [None]:
# =============================================================================
# MODEL
# =============================================================================
class TwoTowerModel(nn.Module):
    """Two-tower model scoring (user id, clothing image) pairs.

    The user tower embeds an integer user id; the item tower runs a small CNN
    over a single-channel image. ``predict_preference`` turns the cosine
    similarity of the two embeddings into a probability.

    Args:
        embedding_dim: size of the shared user/item embedding space.
        num_users: rows in the user embedding table; valid ids are
            ``0 .. num_users - 1``.
        hidden_dim: width of each tower's hidden layer.
        logit_scale: multiplier applied to the cosine similarity before the
            sigmoid. Cosine lies in [-1, 1], so with a scale of 1 the output is
            confined to roughly [0.27, 0.73] and BCE targets of 0/1 can never
            be approached; a larger scale widens the reachable range.
    """

    def __init__(self, embedding_dim=128, num_users=1000, hidden_dim=256, logit_scale=5.0):
        super().__init__()
        # Plain float (not a Parameter/buffer), so the state_dict layout is unchanged.
        self.logit_scale = float(logit_scale)
        self.user_embedding = nn.Embedding(num_users, embedding_dim)
        self.user_tower = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden_dim, embedding_dim), nn.LayerNorm(embedding_dim),
        )
        # Conv stack over 1-channel images; AdaptiveAvgPool2d fixes the output
        # at 128 x 4 x 4 regardless of the input's spatial size.
        self.item_cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d((4, 4)),
        )
        self.item_tower = nn.Sequential(
            nn.Linear(128 * 4 * 4, hidden_dim), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(hidden_dim, embedding_dim), nn.LayerNorm(embedding_dim),
        )

    def forward(self, user_ids, item_images):
        """Return ``(user_embeddings, item_embeddings)``.

        For 1-D ``user_ids`` and ``item_images`` of shape (batch, 1, H, W), both
        outputs are (batch, embedding_dim).
        """
        u = self.user_tower(self.user_embedding(user_ids))
        v = self.item_tower(torch.flatten(self.item_cnn(item_images), 1))
        return u, v

    def predict_preference(self, user_ids, item_images):
        """Return ``sigmoid(logit_scale * cosine(u, v))`` per pair, a value in (0, 1)."""
        u, v = self.forward(user_ids, item_images)
        u, v = F.normalize(u, p=2, dim=1), F.normalize(v, p=2, dim=1)
        cosine = torch.sum(u * v, dim=1)
        return torch.sigmoid(self.logit_scale * cosine)

# Build the model with its default sizes and move its parameters to the selected device.
model = TwoTowerModel()
model = model.to(device)
print(f'✅ Model params: {sum(param.numel() for param in model.parameters()):,}')

## 🎯 Training

In [None]:
# =============================================================================
# TRAINING SETUP
# =============================================================================
def create_preference_data(dataset, num_users=1000, num_samples=2000,
                           groups=((0, 1, 2), (3, 4, 6), (5, 7, 9))):
    """Synthesize noisy (user, image, liked) samples from a labelled image dataset.

    Users fall into ``len(groups)`` taste groups via ``user_id % len(groups)``;
    a user in group ``g`` favours the class labels in ``groups[g]``. Each sample
    draws a random user and a random dataset item, then marks it liked with
    probability 0.8 if its label is favoured and 0.3 otherwise.

    Args:
        dataset: indexable of ``(image, label)`` pairs supporting ``len()``
            (must be non-empty).
        num_users: user ids are drawn uniformly from ``0 .. num_users - 1``.
        num_samples: number of samples to generate.
        groups: favoured class labels for each taste group.

    Returns:
        list of dicts with keys ``'user_id'``, ``'image'`` and ``'preference'`` (0 or 1).
    """
    # Previously a dict keyed by uid%3 was looked up with uid%3, which always
    # returned the first group, so every user had identical tastes.
    liked_by_group = [frozenset(g) for g in groups]
    prefs = []
    for _ in range(num_samples):
        uid = random.randint(0, num_users - 1)
        idx = random.randint(0, len(dataset) - 1)
        img, label = dataset[idx]
        favoured = label in liked_by_group[uid % len(liked_by_group)]
        # One random draw per sample: 80% like-rate for favoured labels, 30% otherwise.
        pref = int(random.random() > (0.2 if favoured else 0.7))
        prefs.append({'user_id': uid, 'image': img, 'preference': pref})
    return prefs

class PreferenceDataset(Dataset):
    """Map-style Dataset over a list of preference dicts.

    Each record needs ``'user_id'``, ``'image'`` and ``'preference'`` keys; an item
    is returned as ``(user_id long tensor, image, preference float tensor)``.
    """

    def __init__(self, prefs):
        self.prefs = prefs

    def __len__(self):
        return len(self.prefs)

    def __getitem__(self, i):
        record = self.prefs[i]
        user = torch.tensor(record['user_id'], dtype=torch.long)
        image = record['image']
        target = torch.tensor(record['preference'], dtype=torch.float)
        return user, image, target

print('🎯 Creating preference dataset...')
preference_data = create_preference_data(train_dataset)
train_loader = DataLoader(
    PreferenceDataset(preference_data),
    batch_size=64,
    shuffle=True,
)
# predict_preference already applies a sigmoid, hence BCELoss on probabilities.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
print('✅ Training setup complete')

In [None]:
# =============================================================================
# TRAIN LOOP
# =============================================================================
def train_model(model, loader, criterion, optimizer, epochs=3, device=None, log_every=20):
    """Train *model* on binary preference targets, printing per-batch and per-epoch stats.

    Each batch from *loader* must be ``(user_ids, images, preferences)``; the
    output of ``model.predict_preference`` (probabilities) is fed to *criterion*.

    Args:
        model: module exposing ``predict_preference(user_ids, images)``.
        loader: iterable of batches; ``len(loader)`` is not required.
        criterion: loss taking ``(probabilities, targets)``.
        optimizer: optimizer over the model's parameters.
        epochs: number of passes over *loader*.
        device: where batches are moved; defaults to the device of the model's
            first parameter (CPU for a parameter-less model).
        log_every: print the batch loss every this many batches; <= 0 disables.

    Returns:
        list of ``(mean_loss, accuracy_percent)`` tuples, one per epoch. An epoch
        with no batches reports ``(0.0, 0.0)`` instead of dividing by zero.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    else:
        device = torch.device(device)

    history = []
    model.train()
    for ep in range(epochs):
        ep_loss, correct, total, n_batches = 0.0, 0, 0, 0
        for b, (uids, imgs, prefs) in enumerate(loader):
            uids, imgs, prefs = uids.to(device), imgs.to(device), prefs.to(device)
            optimizer.zero_grad()
            preds = model.predict_preference(uids, imgs)
            loss = criterion(preds, prefs)
            loss.backward()
            optimizer.step()

            ep_loss += loss.item()
            n_batches += 1
            total += prefs.size(0)
            correct += ((preds > 0.5).float() == prefs).sum().item()
            if log_every > 0 and b % log_every == 0:
                print(f'Epoch {ep+1} Batch {b} Loss {loss.item():.4f}')

        avg_loss = ep_loss / n_batches if n_batches else 0.0
        acc = 100.0 * correct / total if total else 0.0
        print(f'Epoch {ep+1} - Loss {avg_loss:.4f} | Acc {acc:.2f}%')
        history.append((avg_loss, acc))
        if device.type == 'cuda':
            torch.cuda.empty_cache()
    return history

# Three epochs over the synthetic preference data (same as train_model's default).
train_model(model, train_loader, criterion, optimizer, epochs=3)
print('✅ Training complete')

## 🤖 TinyLLaMA Integration

In [None]:
# =============================================================================
# TINYLLAMA AI EXPLANATIONS
# =============================================================================
def rule_based_explanation(score, item_class, user_id):
    """Return a templated explanation; used when TinyLLaMA is unavailable or fails.

    ``user_id`` is accepted for signature parity with ``generate_explanation``
    but is not used.
    """
    level = 'Very strong' if score > 0.7 else 'Moderate' if score > 0.4 else 'Low'
    return level + f' interest in {item_class} (score {score:.2f}).'


_EXPLAIN_SYSTEM = 'You are a fashion AI assistant explaining clothing preferences.'

try:
    from transformers import AutoTokenizer, AutoModelForCausalLM
    _on_cuda = device.type == 'cuda'
    print('🤖 Loading TinyLLaMA...')
    _llm_name = 'TinyLlama/TinyLlama-1.1B-Chat-v1.0'
    _tok = AutoTokenizer.from_pretrained(_llm_name)
    _llm = AutoModelForCausalLM.from_pretrained(
        _llm_name,
        torch_dtype=(torch.float16 if _on_cuda else torch.float32),
        device_map=('auto' if _on_cuda else None),
    )
    if _tok.pad_token is None:
        _tok.pad_token = _tok.eos_token
    _llama_ok = True
    print('✅ TinyLLaMA ready')

    def _chat_prompt(system_msg, user_msg):
        """Render a system + user exchange in the model's chat format, ending at the assistant turn."""
        if getattr(_tok, 'chat_template', None):
            return _tok.apply_chat_template(
                [{'role': 'system', 'content': system_msg},
                 {'role': 'user', 'content': user_msg}],
                tokenize=False,
                add_generation_prompt=True,
            )
        # Older tokenizers without a template: TinyLlama-Chat (Zephyr-style)
        # turn format, where each turn ends with the EOS token.
        eos = _tok.eos_token or '</s>'
        return f'<|system|>\n{system_msg}{eos}\n<|user|>\n{user_msg}{eos}\n<|assistant|>\n'

    def generate_explanation(score, item_class, user_id):
        """Generate a 1-2 sentence natural-language explanation using TinyLLaMA.

        Falls back to ``rule_based_explanation`` if generation raises or the
        model produces no text.
        """
        user_msg = (
            f'A user (ID: {user_id}) has a preference score of {score:.2f} for a {item_class}.\n'
            'Explain this preference in 1-2 sentences, considering that 1.0 is highest '
            'preference and 0.0 is lowest.'
        )
        try:
            inputs = _tok(_chat_prompt(_EXPLAIN_SYSTEM, user_msg), return_tensors='pt',
                          truncation=True, max_length=256)
            # Follow the model's actual placement (device_map may differ from `device`).
            inputs = {k: v.to(_llm.device) for k, v in inputs.items()}
            with torch.no_grad():
                out = _llm.generate(
                    **inputs,
                    max_new_tokens=50,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=_tok.eos_token_id,
                )
            # Decode only the newly generated tokens rather than splitting the echoed prompt.
            prompt_len = inputs['input_ids'].shape[1]
            text = _tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
        except Exception as gen_err:
            print(f'⚠️ TinyLLaMA generation failed: {gen_err}')
            text = ''
        return text or rule_based_explanation(score, item_class, user_id)

except Exception as e:
    print(f'⚠️ TinyLLaMA not available: {e}')
    generate_explanation = rule_based_explanation
    _llama_ok = False

print('🧠 Explanation system:', 'TinyLLaMA' if _llama_ok else 'Fallback')

In [None]:
# =============================================================================
# PILLOW SANITY CHECK (EXIF & TIFF support)
# =============================================================================
try:
    # Importing TiffImagePlugin alongside Image checks that the TIFF plugin loads.
    from PIL import Image, TiffImagePlugin
    from PIL import __version__ as PIL_VERSION
    print(f"🖼️ Pillow version: {PIL_VERSION}")
    # Create a simple grayscale image and call getexif safely
    im = Image.new('L', (8, 8), color=128)
    exif = getattr(im, 'getexif', lambda: {})()
    print("✅ Pillow EXIF access OK")
except Exception as e:
    print(f"❌ Pillow sanity check failed: {e}")
    print("🔧 Reinstalling Pillow to a compatible version...")
    import subprocess, sys
    rc = subprocess.call([sys.executable, '-m', 'pip', 'install', 'pillow>=10.2.0,<11', '--upgrade', '--no-cache-dir'])
    if rc != 0:
        print(f"❌ pip exited with status {rc}; Pillow was not reinstalled")
    else:
        try:
            from PIL import Image
        except Exception as e2:
            print(f"❌ Pillow import still failing: {e2}")
        else:
            # PIL modules imported earlier in this process (e.g. by the imports
            # cell) stay cached in sys.modules, so the reinstalled version is only
            # used after the kernel/runtime restarts.
            print("✅ Pillow reinstalled - restart the kernel to load the new version")

## 🧪 Inference & Tests

In [None]:
# =============================================================================
# TESTING
# =============================================================================
def manual_predict(user_id=0, sample_idx=0):
    """Score one test-set image for one user, print the result dict and return it.

    ``sample_idx`` wraps around the test-set length. Switches ``model`` to eval mode.
    """
    model.eval()
    img, label = test_dataset[sample_idx % len(test_dataset)]
    with torch.no_grad():
        user_batch = torch.tensor([user_id]).to(device)
        image_batch = img.unsqueeze(0).to(device)
        score = model.predict_preference(user_batch, image_batch).item()
    item_class = FASHION_CLASSES[label]
    explanation = generate_explanation(score, item_class, user_id)
    result = dict(
        user_id=user_id,
        item_class=item_class,
        preference_score=round(score, 3),
        explanation=explanation,
    )
    print(result)
    return result

def test_multiple_users(sample_idx=0, n=3):
    """Run ``manual_predict`` on one test image for user ids ``0 .. n-1``; return the results."""
    return [manual_predict(user_id=uid, sample_idx=sample_idx) for uid in range(n)]

def test_multiple_items(user_id=0, n=3):
    """Run ``manual_predict`` for one user on test images ``0 .. n-1``; return the results."""
    return [manual_predict(user_id=user_id, sample_idx=idx) for idx in range(n)]

print('✅ Testing helpers ready')
# Smoke test: user 0 on the first test image.
manual_predict(user_id=0, sample_idx=0)

## 🎨 Gradio Web Interface

In [None]:
# =============================================================================
# GRADIO (with port handling + TinyLLaMA chat)
# =============================================================================
import socket

def find_free_port(start=7860, attempts=20):
    """Return the first port in ``[start, start + attempts)`` that can be bound on all interfaces.

    The probe socket is closed right away, so another process could still take
    the port before it is used. Returns 0 when every candidate is busy.
    """
    def can_bind(port_no):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                probe.bind(('', port_no))
        except OSError:
            return False
        return True

    return next((cand for cand in range(start, start + attempts) if can_bind(cand)), 0)

# 0 means no free port was found; the launch call below then passes None instead.
port = find_free_port(start=7860, attempts=20)
print(f'🔌 Port selected: {port}')

# Preprocessing shared by the Predict and Chat handlers; same steps as the
# training transform (ToTensor to [0, 1], then Normalize to [-1, 1]).
_def_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

def _preprocess_image_for_model(image):
    """Turn a PIL image or numpy array into a normalized (1, 1, 28, 28) tensor on ``device``.

    Raises:
        ValueError: if ``image`` is None.
    """
    from PIL import Image as PILImage
    if image is None:
        raise ValueError('No image provided')
    pil_img = PILImage.fromarray(image) if isinstance(image, np.ndarray) else image
    if pil_img.mode != 'L':
        pil_img = pil_img.convert('L')
    pil_img = pil_img.resize((28, 28))
    batch = _def_transform(pil_img).unsqueeze(0)
    return batch.to(device)

try:
    import gradio as gr
    from PIL import Image as PILImage

    def _interest_level(score):
        """Return the coarse 'high' / 'moderate' / 'low' wording used in chat replies."""
        return 'high' if score > 0.7 else 'moderate' if score > 0.4 else 'low'

    def _score_image(user_id, image):
        """Return the two-tower preference score (0-1) of *image* for *user_id*.

        Propagates errors from preprocessing (ValueError for a missing image) or
        from the model (e.g. an out-of-range user id).
        """
        # train_model() leaves the model in train mode; without eval() the
        # Dropout layers would make the score change from click to click.
        model.eval()
        x = _preprocess_image_for_model(image)
        u = torch.tensor([int(user_id)]).to(device)
        with torch.no_grad():
            return model.predict_preference(u, x).item()

    def predict_ui(user_id, image):
        """Predict-tab handler: return score, confidence and explanation as a dict."""
        if image is None:
            return {'Error': 'Upload an image or use sample'}
        try:
            score = _score_image(user_id, image)
            return {
                'Preference Score': f'{score:.3f}',
                'Confidence': f'{score*100:.1f}%',
                'Explanation': generate_explanation(score, 'clothing item', int(user_id))
            }
        except Exception as e:
            # Show the error in the JSON panel instead of failing the request.
            return {'Error': str(e)}

    def sample_img():
        """Return a random Fashion-MNIST test image as a PIL image."""
        i = random.randint(0, len(test_dataset)-1)
        img, _ = test_dataset[i]
        # Dataset tensors were normalized to [-1, 1]; ToPILImage multiplies float
        # tensors by 255 and casts to uint8, so negative pixels would be
        # corrupted. Undo Normalize((0.5,), (0.5,)) first.
        return transforms.ToPILImage()((img * 0.5 + 0.5).clamp(0, 1))

    def chat_about_outfit(user_id, image, message):
        """Use TinyLLaMA to answer a user question about the uploaded outfit.

        Context includes the model's preference score for the given user and
        image. Without TinyLLaMA (or if it produces no text) a templated answer
        built from ``generate_explanation`` is returned instead.
        """
        if image is None:
            return 'Please upload an image first.'
        # An empty box falls back to the UI placeholder question. The length cap
        # keeps the prompt inside the tokenizer's max_length, whose right-side
        # truncation would otherwise cut off the assistant turn marker.
        max_question_chars = 500
        message = (message or '').strip()[:max_question_chars] or 'Would user like this clothing outfit?'
        try:
            # Compute preference score as structured context
            score = _score_image(user_id, image)
            level = _interest_level(score)
            context = (
                f"User ID: {int(user_id)}\n"
                f"Computed preference score (0-1): {score:.3f}\n"
                f"Interpretation: {level} interest\n"
            )
            reply = ''
            # If TinyLLaMA is available, generate a response using it
            if '_llm' in globals() and '_tok' in globals() and globals().get('_llama_ok', False):
                sys_prompt = (
                    "You are a helpful fashion AI assistant.\n"
                    "Use the provided context about the user and outfit to answer the question concisely.\n"
                    "Be clear about confidence based on the score (0.0 to 1.0)."
                )
                user_turn = f"CONTEXT:\n{context}\n\nQUESTION:\n{message}"
                if getattr(_tok, 'chat_template', None):
                    full_prompt = _tok.apply_chat_template(
                        [{'role': 'system', 'content': sys_prompt},
                         {'role': 'user', 'content': user_turn}],
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                else:
                    # TinyLlama-Chat (Zephyr-style) format; each turn ends with EOS.
                    eos = _tok.eos_token or '</s>'
                    full_prompt = (
                        f"<|system|>\n{sys_prompt}{eos}\n"
                        f"<|user|>\n{user_turn}{eos}\n<|assistant|>\n"
                    )
                inputs = _tok(full_prompt, return_tensors='pt', truncation=True, max_length=1024)
                # Follow the model's actual placement (device_map may differ from `device`).
                inputs = {k: v.to(_llm.device) for k, v in inputs.items()}
                with torch.no_grad():
                    out = _llm.generate(
                        **inputs,
                        max_new_tokens=120,
                        temperature=0.7,
                        do_sample=True,
                        pad_token_id=_tok.eos_token_id
                    )
                # Decode only the generated continuation, not the echoed prompt.
                prompt_len = inputs['input_ids'].shape[1]
                reply = _tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
            if not reply:
                # Fallback: Use explanation generator and wrap into an answer
                expl = generate_explanation(score, 'clothing item', int(user_id))
                reply = (
                    f"Based on the model's score {score:.3f}, the user shows "
                    f"{level} interest. "
                    f"Explanation: {expl}"
                )
            return reply
        except Exception as e:
            return f'Error: {e}'

    # Build UI with tabs: Predict and Chat
    with gr.Blocks(title='Two-Tower Clothing Preference Model') as demo:
        gr.Markdown('# 👗 Two-Tower Clothing Preference Model')
        with gr.Tabs():
            with gr.Tab('Predict'):
                with gr.Row():
                    with gr.Column():
                        uid = gr.Number(label='User ID', value=0, minimum=0, maximum=999)
                        img = gr.Image(label='Upload Clothing Image', type='pil', height=300)
                        btn = gr.Button('🔮 Predict Preference')
                        sample = gr.Button('🎲 Random Sample')
                    with gr.Column():
                        out = gr.JSON(label='Results')
                btn.click(predict_ui, inputs=[uid, img], outputs=out)
                sample.click(sample_img, outputs=img)
            with gr.Tab('Chat'):
                with gr.Row():
                    with gr.Column():
                        chat_uid = gr.Number(label='User ID', value=0, minimum=0, maximum=999)
                        chat_img = gr.Image(label='Upload Outfit Image', type='pil', height=300)
                        chat_q = gr.Textbox(label='Ask TinyLLaMA', placeholder='Would user like this clothing outfit?', lines=2)
                        chat_btn = gr.Button('💬 Ask')
                    with gr.Column():
                        chat_out = gr.Textbox(label='TinyLLaMA Response')
                chat_btn.click(chat_about_outfit, inputs=[chat_uid, chat_img, chat_q], outputs=chat_out)

    print('🎨 Gradio ready; launching...')
    # share=True also creates a public Gradio share link; anyone with the URL can use the app.
    demo.launch(server_name='0.0.0.0', server_port=(port if port>0 else None), share=True, show_error=True)

except Exception as e:
    print(f'⚠️ Gradio not available or failed: {e}')
    print('💡 Use manual_predict(user_id, sample_idx) and chat_about_outfit(user_id, image, message) for testing')

## 📋 Summary

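- Standard Jupyter notebook format (.ipynb JSON), compatible with RunPod
- Two-Tower preference model with training and test helpers
- TinyLLaMA natural-language explanations and chat, with a rule-based fallback
- Gradio web interface (Predict and Chat tabs) with automatic port selection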