In [1]:
# Vector Embedding Visualizer - Interactive ML Token Explorer
# A modern, interactive tool for exploring transformer model embeddings

import numpy as np
import pandas as pd
import umap.umap_ as umap
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import warnings
from transformers import GPT2Tokenizer, GPT2Model, AutoTokenizer, AutoModel
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
from sklearn.decomposition import PCA
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output
import json
from datetime import datetime
import re
import os
import tempfile

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

# Modern CSS styling with glassmorphism and contemporary design
display(HTML("""
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap');

body {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
    background: linear-gradient(135deg, #0f0f23 0%, #1a1a2e 50%, #16213e 100%);
    color: #e4e4e7;
    line-height: 1.6;
}

.visualizer-header {
    background: linear-gradient(135deg, 
        rgba(139, 92, 246, 0.3) 0%, 
        rgba(59, 130, 246, 0.25) 35%,
        rgba(16, 185, 129, 0.2) 100%);
    backdrop-filter: blur(20px);
    border: 1px solid rgba(255, 255, 255, 0.1);
    padding: 40px 30px;
    border-radius: 24px;
    margin-bottom: 30px;
    color: white;
    text-align: center;
    box-shadow: 
        0 20px 25px -5px rgba(0, 0, 0, 0.1),
        0 10px 10px -5px rgba(0, 0, 0, 0.04),
        inset 0 1px 0 rgba(255, 255, 255, 0.1);
    position: relative;
    overflow: hidden;
}

.visualizer-header::before {
    content: '';
    position: absolute;
    top: 0;
    left: 0;
    right: 0;
    height: 1px;
    background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.4), transparent);
}

.visualizer-header h1 {
    font-size: 2.5em;
    font-weight: 700;
    margin: 0 0 16px 0;
    background: linear-gradient(135deg, #ffffff 0%, #e0e7ff 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    background-clip: text;
    text-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}

.visualizer-header p {
    font-size: 1.2em;
    font-weight: 500;
    margin: 0 0 8px 0;
    opacity: 0.9;
}

.visualizer-header small {
    font-size: 0.95em;
    opacity: 0.7;
    font-weight: 400;
}

.control-panel {
    background: rgba(255, 255, 255, 0.03);
    backdrop-filter: blur(16px);
    border: 1px solid rgba(255, 255, 255, 0.08);
    border-radius: 20px;
    padding: 32px;
    margin: 24px 0;
    box-shadow: 
        0 4px 6px -1px rgba(0, 0, 0, 0.1),
        0 2px 4px -1px rgba(0, 0, 0, 0.06),
        inset 0 1px 0 rgba(255, 255, 255, 0.05);
}

.control-panel h2 {
    font-size: 1.5em;
    font-weight: 600;
    margin: 0 0 24px 0;
    color: #f8fafc;
    display: flex;
    align-items: center;
    gap: 12px;
}

.control-panel h3 {
    font-size: 1.2em;
    font-weight: 600;
    margin: 24px 0 16px 0;
    color: #e2e8f0;
    display: flex;
    align-items: center;
    gap: 8px;
}

.metric-card {
    background: linear-gradient(145deg, 
        rgba(255, 255, 255, 0.1) 0%, 
        rgba(255, 255, 255, 0.05) 100%);
    backdrop-filter: blur(10px);
    border: 1px solid rgba(255, 255, 255, 0.1);
    border-radius: 16px;
    padding: 24px;
    margin: 20px 0;
    box-shadow: 
        0 4px 6px -1px rgba(0, 0, 0, 0.1),
        0 2px 4px -1px rgba(0, 0, 0, 0.06);
    color: #f1f5f9;
}

.metric-card h3 {
    font-size: 1.3em;
    font-weight: 600;
    margin: 0 0 16px 0;
    color: #ffffff;
    display: flex;
    align-items: center;
    gap: 10px;
}

.metric-card p {
    margin: 8px 0;
    color: #cbd5e1;
}

.metric-card strong {
    color: #e2e8f0;
    font-weight: 600;
}

.metric-card ul {
    margin: 12px 0;
    padding-left: 20px;
}

.metric-card li {
    margin: 6px 0;
    color: #94a3b8;
}

.token-info {
    background: rgba(15, 23, 42, 0.8);
    backdrop-filter: blur(12px);
    border: 1px solid rgba(100, 116, 139, 0.2);
    color: #f1f5f9;
    padding: 24px;
    border-radius: 16px;
    margin: 20px 0;
    font-family: 'JetBrains Mono', 'Monaco', 'Consolas', monospace;
    box-shadow: 
        0 4px 6px -1px rgba(0, 0, 0, 0.1),
        0 2px 4px -1px rgba(0, 0, 0, 0.06);
}

.token-info h3 {
    font-size: 1.2em;
    font-weight: 600;
    margin: 0 0 20px 0;
    color: #ffffff;
    display: flex;
    align-items: center;
    gap: 10px;
}

.search-result {
    margin: 16px 0;
    padding: 16px;
    border-left: 3px solid #3b82f6;
    background: rgba(59, 130, 246, 0.1);
    border-radius: 0 12px 12px 0;
    transition: all 0.2s ease;
}

.search-result:hover {
    background: rgba(59, 130, 246, 0.15);
    transform: translateX(4px);
}

.search-result strong {
    color: #60a5fa;
    font-weight: 600;
}

.token-detail {
    background: linear-gradient(145deg, 
        rgba(16, 185, 129, 0.15) 0%, 
        rgba(16, 185, 129, 0.05) 100%);
    border: 1px solid rgba(16, 185, 129, 0.3);
    border-radius: 12px;
    padding: 20px;
    margin: 16px 0;
    color: #f1f5f9;
}

.vector-mode {
    background: rgba(139, 92, 246, 0.1);
    border: 1px solid rgba(139, 92, 246, 0.3);
    border-radius: 8px;
    padding: 12px;
    margin: 12px 0;
    color: #c4b5fd;
}

/* Widget styling improvements */
.widget-label {
    color: #e2e8f0 !important;
    font-weight: 500 !important;
}

.widget-dropdown select {
    background: rgba(30, 41, 59, 0.8) !important;
    border: 1px solid rgba(100, 116, 139, 0.3) !important;
    color: #f1f5f9 !important;
    border-radius: 8px !important;
}

.widget-text input {
    background: rgba(30, 41, 59, 0.8) !important;
    border: 1px solid rgba(100, 116, 139, 0.3) !important;
    color: #f1f5f9 !important;
    border-radius: 8px !important;
    padding: 8px 12px !important;
}

.widget-text input::placeholder {
    color: #64748b !important;
}

.widget-button {
    background: linear-gradient(135deg, #3b82f6 0%, #1d4ed8 100%) !important;
    border: none !important;
    border-radius: 10px !important;
    padding: 10px 20px !important;
    font-weight: 600 !important;
    transition: all 0.2s ease !important;
    box-shadow: 0 2px 4px rgba(59, 130, 246, 0.2) !important;
}

.widget-button:hover {
    transform: translateY(-1px) !important;
    box-shadow: 0 4px 8px rgba(59, 130, 246, 0.3) !important;
}

.widget-button.btn-success {
    background: linear-gradient(135deg, #10b981 0%, #059669 100%) !important;
    box-shadow: 0 2px 4px rgba(16, 185, 129, 0.2) !important;
}

.widget-button.btn-success:hover {
    box-shadow: 0 4px 8px rgba(16, 185, 129, 0.3) !important;
}

.widget-button.btn-warning {
    background: linear-gradient(135deg, #f59e0b 0%, #d97706 100%) !important;
    box-shadow: 0 2px 4px rgba(245, 158, 11, 0.2) !important;
}

.widget-button.btn-info {
    background: linear-gradient(135deg, #06b6d4 0%, #0891b2 100%) !important;
    box-shadow: 0 2px 4px rgba(6, 182, 212, 0.2) !important;
}

/* Toggle button styling */
.widget-toggle-button {
    background: rgba(100, 116, 139, 0.3) !important;
    border: 1px solid rgba(100, 116, 139, 0.5) !important;
    color: #94a3b8 !important;
    border-radius: 8px !important;
    transition: all 0.2s ease !important;
}

.widget-toggle-button.selected {
    background: linear-gradient(135deg, #8b5cf6 0%, #7c3aed 100%) !important;
    border-color: #8b5cf6 !important;
    color: white !important;
}

/* Slider improvements */
.widget-hslider .ui-slider {
    background: rgba(100, 116, 139, 0.3) !important;
    border-radius: 6px !important;
}

.widget-hslider .ui-slider .ui-slider-handle {
    background: linear-gradient(135deg, #3b82f6 0%, #1d4ed8 100%) !important;
    border: none !important;
    border-radius: 50% !important;
    box-shadow: 0 2px 4px rgba(59, 130, 246, 0.3) !important;
}

/* Add some subtle animations */
.control-panel, .metric-card, .token-info {
    animation: fadeInUp 0.5s ease-out;
}

@keyframes fadeInUp {
    from {
        opacity: 0;
        transform: translateY(20px);
    }
    to {
        opacity: 1;
        transform: translateY(0);
    }
}

/* Status indicators */
.status-success {
    color: #10b981;
    font-weight: 600;
}

.status-error {
    color: #ef4444;
    font-weight: 600;
}

.status-loading {
    color: #f59e0b;
    font-weight: 600;
}

/* Improve code blocks */
code {
    background: rgba(30, 41, 59, 0.8);
    color: #60a5fa;
    padding: 2px 6px;
    border-radius: 6px;
    font-family: 'JetBrains Mono', monospace;
    font-size: 0.9em;
    font-weight: 500;
}
</style>
"""))

# Header with enhanced styling
display(HTML("""
<div class="visualizer-header">
    <h1>Vector Embedding Visualizer</h1>
    <p>Explore transformer model embeddings in interactive 3D space</p>
    <small>Interactive tool for understanding how models encode language</small>
</div>
"""))

# =====================================================================================
# CORE VISUALIZER CLASS
# =====================================================================================

class EmbeddingVisualizer:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.embeddings = None
        self.tokens = None
        self.reduced_embeddings = None
        self.current_model_name = None
        self.token_metadata = {}
        self.selected_token_idx = None
        self.vector_mode = False
        
    def load_custom_embeddings(self, embedding_data, vocab_data=None):
        """Load custom embedding matrix and vocabulary"""
        print("Loading custom embeddings...")
        
        try:
            # Handle different embedding formats
            if isinstance(embedding_data, str):
                # File path
                if embedding_data.endswith('.npy'):
                    embeddings = np.load(embedding_data)
                elif embedding_data.endswith('.npz'):
                    data = np.load(embedding_data)
                    embeddings = data['embeddings'] if 'embeddings' in data else data[data.files[0]]
                elif embedding_data.endswith('.csv'):
                    df = pd.read_csv(embedding_data)
                    embeddings = df.values
                else:
                    print(f"Unsupported file format: {embedding_data}")
                    return False
            else:
                # Direct numpy array
                embeddings = embedding_data
            
            # Validate embeddings
            if len(embeddings.shape) != 2:
                print(f"Error: Embeddings must be 2D matrix, got shape {embeddings.shape}")
                return False
            
            self.embeddings = embeddings.astype(np.float32)
            self.current_model_name = "custom_embeddings"
            self.model = None  # No underlying transformer model
            
            # Handle vocabulary
            if vocab_data is not None:
                if isinstance(vocab_data, str):
                    # File path
                    if vocab_data.endswith('.json'):
                        with open(vocab_data, 'r') as f:
                            vocab = json.load(f)
                        if isinstance(vocab, dict):
                            # Assume it's {token: id} mapping
                            self.custom_vocab = {int(v): k for k, v in vocab.items()}
                        else:
                            # Assume it's a list of tokens
                            self.custom_vocab = {i: token for i, token in enumerate(vocab)}
                    elif vocab_data.endswith('.txt'):
                        with open(vocab_data, 'r') as f:
                            tokens = [line.strip() for line in f.readlines()]
                        self.custom_vocab = {i: token for i, token in enumerate(tokens)}
                elif isinstance(vocab_data, (list, dict)):
                    if isinstance(vocab_data, list):
                        self.custom_vocab = {i: token for i, token in enumerate(vocab_data)}
                    else:
                        self.custom_vocab = vocab_data
                else:
                    print("Unsupported vocabulary format")
                    self.custom_vocab = None
            else:
                # Generate default token names
                self.custom_vocab = {i: f"token_{i}" for i in range(len(embeddings))}
            
            self.tokenizer = None  # Custom tokenizer object
            
            print(f"Custom embeddings loaded successfully!")
            print(f"   Vocabulary size: {self.embeddings.shape[0]:,}")
            print(f"   Embedding dimension: {self.embeddings.shape[1]}")
            
            return True
            
        except Exception as e:
            print(f"Error loading custom embeddings: {str(e)}")
            return False
    
    def load_model(self, model_name='gpt2', custom_path=None, custom_embeddings=None, custom_vocab=None):
        """Load transformer model and tokenizer or custom embeddings"""
        
        # Handle custom embeddings
        if custom_embeddings is not None:
            return self.load_custom_embeddings(custom_embeddings, custom_vocab)
        
        # Handle custom model path
        if model_name == 'custom' and custom_path:
            model_name = custom_path
            print(f"Loading custom model from: {custom_path}")
        else:
            print(f"Loading model: {model_name}")
        
        try:
            if model_name == 'custom' and not custom_path:
                print("Please specify a custom model path!")
                return False
                
            if 'bert' in model_name.lower():
                # For BERT models
                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                self.model = AutoModel.from_pretrained(model_name)
            else:
                # For GPT-2 models and other custom models
                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                self.model = AutoModel.from_pretrained(model_name)
                
                # Fallback to GPT2 tokenizer for unknown models
                if self.tokenizer is None:
                    print("Falling back to GPT2Tokenizer...")
                    self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
                    self.model = AutoModel.from_pretrained(model_name)
            
            # Handle padding token for GPT-2
            if hasattr(self.tokenizer, 'pad_token') and self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                
            self.current_model_name = model_name
            
            # Extract embeddings
            self.embeddings = self.model.get_input_embeddings().weight.data.cpu().numpy()
            print(f"Model loaded successfully!")
            print(f"   Vocabulary size: {self.embeddings.shape[0]:,}")
            print(f"   Embedding dimension: {self.embeddings.shape[1]}")
            
            return True
            
        except Exception as e:
            print(f"Error loading model: {str(e)}")
            print("For custom models, ensure they are compatible with Hugging Face transformers or that embeddings can be extracted")
            return False
    
    def prepare_tokens(self, top_n=3000):
        """Prepare token data with metadata"""
        print(f"Preparing top {top_n:,} tokens...")
        
        # Select top N embeddings (roughly by frequency for most models)
        self.selected_embeddings = self.embeddings[:top_n]
        
        # Decode tokens and create metadata
        self.tokens = []
        self.token_metadata = {
            'lengths': [],
            'types': [],
            'has_special_chars': [],
            'is_uppercase': [],
            'is_digit': [],
            'frequency_rank': [],
            'embedding_norm': [],
            'cosine_distance_from_origin': []
        }
        
        for i in range(top_n):
            try:
                # Decode token - handle custom vocabularies
                if hasattr(self, 'custom_vocab') and self.custom_vocab:
                    token = self.custom_vocab.get(i, f"<UNK_{i}>")
                    clean_token = str(token).strip() if str(token).strip() else f"<TOKEN_{i}>"
                elif self.tokenizer is not None:
                    token = self.tokenizer.decode([i])
                    clean_token = token.strip() if token.strip() else f"<TOKEN_{i}>"
                else:
                    clean_token = f"<TOKEN_{i}>"
                
                self.tokens.append(clean_token)
                
                # Compute metadata
                embedding = self.selected_embeddings[i]
                
                self.token_metadata['lengths'].append(len(clean_token))
                self.token_metadata['frequency_rank'].append(i)
                self.token_metadata['has_special_chars'].append(bool(re.search(r'[^a-zA-Z0-9\s]', clean_token)))
                self.token_metadata['is_uppercase'].append(clean_token.isupper())
                self.token_metadata['is_digit'].append(clean_token.isdigit())
                self.token_metadata['embedding_norm'].append(float(np.linalg.norm(embedding)))
                self.token_metadata['cosine_distance_from_origin'].append(float(1 - np.dot(embedding, np.zeros_like(embedding))))
                
                # Classify token type
                if clean_token.startswith('<') and clean_token.endswith('>'):
                    token_type = 'special'
                elif clean_token.isdigit():
                    token_type = 'number'
                elif clean_token.isalpha():
                    token_type = 'word'
                elif clean_token.startswith('token_'):
                    token_type = 'custom'
                else:
                    token_type = 'mixed'
                
                self.token_metadata['types'].append(token_type)
                
            except Exception as e:
                # Handle problematic tokens
                self.tokens.append(f"<UNK_{i}>")
                self.token_metadata['lengths'].append(0)
                self.token_metadata['frequency_rank'].append(i)
                self.token_metadata['has_special_chars'].append(True)
                self.token_metadata['is_uppercase'].append(False)
                self.token_metadata['is_digit'].append(False)
                self.token_metadata['embedding_norm'].append(0.0)
                self.token_metadata['cosine_distance_from_origin'].append(1.0)
                self.token_metadata['types'].append('unknown')
        
        print(f"Tokens prepared!")
        print(f"   Sample tokens: {self.tokens[:10]}")
    
    def reduce_dimensions(self, method='umap', n_components=3, **kwargs):
        """Reduce embedding dimensions"""
        print(f"Reducing dimensions using {method.upper()}...")
        
        if method == 'umap':
            # UMAP parameters
            n_neighbors = kwargs.get('n_neighbors', 15)
            min_dist = kwargs.get('min_dist', 0.1)
            metric = kwargs.get('metric', 'cosine')
            
            reducer = umap.UMAP(
                n_neighbors=n_neighbors,
                min_dist=min_dist,
                metric=metric,
                n_components=n_components,
                random_state=42
            )
            
        elif method == 'pca':
            reducer = PCA(n_components=n_components, random_state=42)
        
        else:
            raise ValueError(f"Unknown reduction method: {method}")
        
        self.reduced_embeddings = reducer.fit_transform(self.selected_embeddings)
        print(f"Dimensions reduced to {n_components}D")
        
        return reducer
    
    def find_neighbors(self, token_idx, n_neighbors=10, metric='cosine'):
        """Find nearest neighbors for a token"""
        if metric == 'cosine':
            similarities = cosine_similarity([self.selected_embeddings[token_idx]], 
                                           self.selected_embeddings)[0]
            # Convert to distances (1 - similarity)
            distances = 1 - similarities
        elif metric == 'euclidean':
            distances = euclidean_distances([self.selected_embeddings[token_idx]], 
                                          self.selected_embeddings)[0]
        else:
            raise ValueError(f"Unsupported metric: {metric}")
        
        # Get indices of nearest neighbors (excluding self)
        neighbor_indices = np.argsort(distances)[1:n_neighbors+1]
        neighbor_distances = distances[neighbor_indices]
        
        neighbors = []
        for idx, dist in zip(neighbor_indices, neighbor_distances):
            neighbors.append({
                'token': self.tokens[idx],
                'distance': dist,
                'index': idx
            })
        
        return neighbors
    
    def search_tokens(self, query, max_results=50):
        """Search for tokens matching a query"""
        query_lower = query.lower()
        results = []
        
        for i, token in enumerate(self.tokens):
            token_lower = token.lower()
            if query_lower in token_lower:
                results.append({
                    'token': token,
                    'index': i,
                    'match_type': 'contains'
                })
                
                if len(results) >= max_results:
                    break
        
        return results
    
    def get_token_details(self, token_idx):
        """Get detailed information about a token"""
        if token_idx >= len(self.tokens):
            return None
            
        token = self.tokens[token_idx]
        embedding = self.selected_embeddings[token_idx]
        
        details = {
            'token': token,
            'index': token_idx,
            'length': self.token_metadata['lengths'][token_idx],
            'type': self.token_metadata['types'][token_idx],
            'frequency_rank': self.token_metadata['frequency_rank'][token_idx],
            'embedding_norm': self.token_metadata['embedding_norm'][token_idx],
            'has_special_chars': self.token_metadata['has_special_chars'][token_idx],
            'is_uppercase': self.token_metadata['is_uppercase'][token_idx],
            'is_digit': self.token_metadata['is_digit'][token_idx],
        }
        
        if self.reduced_embeddings is not None:
            details.update({
                'x': float(self.reduced_embeddings[token_idx, 0]),
                'y': float(self.reduced_embeddings[token_idx, 1]),
                'z': float(self.reduced_embeddings[token_idx, 2]) if self.reduced_embeddings.shape[1] > 2 else 0.0
            })
        
        return details
    
    def export_metadata(self):
        """Export comprehensive token metadata"""
        if self.tokens is None:
            return None
            
        data = []
        for i in range(len(self.tokens)):
            entry = self.get_token_details(i)
            if self.reduced_embeddings is not None and 'x' not in entry:
                entry.update({
                    'x': float(self.reduced_embeddings[i, 0]),
                    'y': float(self.reduced_embeddings[i, 1]),
                    'z': float(self.reduced_embeddings[i, 2]) if self.reduced_embeddings.shape[1] > 2 else 0.0
                })
            data.append(entry)
        
        df = pd.DataFrame(data)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"token_metadata_{self.current_model_name.replace('/', '_')}_{timestamp}.csv"
        df.to_csv(filename, index=False)
        
        return df, filename

# Initialize visualizer
viz = EmbeddingVisualizer()

# =====================================================================================
# INTERACTIVE CONTROLS
# =====================================================================================

# Model selection
model_dropdown = widgets.Dropdown(
    options=[
        ('GPT-2 (124M)', 'gpt2'),
        ('GPT-2 Medium (355M)', 'gpt2-medium'),
        ('GPT-2 Large (774M)', 'gpt2-large'),
        ('DistilGPT-2', 'distilgpt2'),
        ('BERT Base Uncased', 'bert-base-uncased'),
        ('DistilBERT', 'distilbert-base-uncased'),
        ('Custom Model', 'custom'),
    ],
    value='gpt2',
    description='Model:',
    style={'description_width': 'initial'}
)

# Custom model inputs (initially hidden)
custom_model_path = widgets.Text(
    placeholder='Enter Hugging Face model name or local path...',
    description='Model Path:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='400px')
)

custom_embeddings_upload = widgets.FileUpload(
    accept='.npy,.npz,.pt,.pkl,.csv',
    multiple=False,
    description='Or upload embeddings:',
    style={'description_width': 'initial'}
)

custom_tokenizer_upload = widgets.FileUpload(
    accept='.json,.txt,.pkl',
    multiple=False,
    description='Upload vocab/tokenizer:',
    style={'description_width': 'initial'}
)

# Container for custom model controls (initially hidden)
custom_model_container = widgets.VBox([
    widgets.HTML("<p style='color: #94a3b8; font-size: 0.9em; margin: 10px 0;'>Load custom model from Hugging Face or local files:</p>"),
    custom_model_path,
    widgets.HTML("<p style='color: #64748b; font-size: 0.8em; margin: 5px 0;'>OR upload embedding matrix directly:</p>"),
    custom_embeddings_upload,
    widgets.HTML("<p style='color: #64748b; font-size: 0.8em; margin: 5px 0;'>Optional: Upload custom vocabulary/tokenizer:</p>"),
    custom_tokenizer_upload,
], layout=widgets.Layout(display='none'))  # Initially hidden

# Visualization mode toggle
viz_mode_toggle = widgets.ToggleButtons(
    options=[('Scatter Plot', False), ('Vector Plot', True)],
    value=False,
    description='Visualization:',
    button_style='',
    tooltips=['Standard scatter plot view', 'Vector field visualization']
)

# Parameters
top_n_slider = widgets.IntSlider(value=3000, min=500, max=10000, step=500, description='Top N tokens:')
n_neighbors_slider = widgets.IntSlider(value=15, min=5, max=50, step=5, description='UMAP neighbors:')
min_dist_slider = widgets.FloatSlider(value=0.1, min=0.01, max=1.0, step=0.01, description='UMAP min dist:')

# Distance metric for neighbors
distance_metric = widgets.Dropdown(
    options=['cosine', 'euclidean'],
    value='cosine',
    description='Distance metric:'
)

# Color scheme
color_scheme = widgets.Dropdown(
    options=['Viridis', 'Plasma', 'Inferno', 'Magma', 'Rainbow', 'Turbo', 'Electric'],
    value='Viridis',
    description='Color scheme:'
)

# Search box
search_box = widgets.Text(
    placeholder='Search tokens...',
    description='Search:',
    style={'description_width': 'initial'}
)

# Buttons
load_button = widgets.Button(description='Load Model', button_style='primary')
visualize_button = widgets.Button(description='Create Visualization', button_style='success')
export_button = widgets.Button(description='Download All Metadata', button_style='warning')
export_selected_button = widgets.Button(description='Download Selected Token', button_style='info')

# Output areas
model_output = widgets.Output()
viz_output = widgets.Output()
search_output = widgets.Output()
token_detail_output = widgets.Output()

# Selected token info widget
selected_token_info = widgets.HTML()

# Global variable to store the current figure for click handling
current_fig = None

def model_dropdown_change(change):
    """Show/hide custom model inputs based on selection"""
    if change['new'] == 'custom':
        custom_model_container.layout.display = 'block'
    else:
        custom_model_container.layout.display = 'none'

def load_model_callback(b):
    with model_output:
        clear_output(wait=True)
        
        # Handle different loading scenarios
        if model_dropdown.value == 'custom':
            # Check if custom embeddings were uploaded
            if custom_embeddings_upload.value:
                try:
                    # Handle uploaded embedding file
                    uploaded_file = list(custom_embeddings_upload.value.values())[0]
                    
                    # Save uploaded file temporarily
                    with tempfile.NamedTemporaryFile(delete=False, suffix='.npy') as tmp_file:
                        tmp_file.write(uploaded_file['content'])
                        tmp_path = tmp_file.name
                    
                    # Handle vocabulary if uploaded
                    vocab_path = None
                    if custom_tokenizer_upload.value:
                        vocab_file = list(custom_tokenizer_upload.value.values())[0]
                        with tempfile.NamedTemporaryFile(delete=False, suffix='.json', mode='w') as vocab_tmp:
                            if vocab_file['name'].endswith('.json'):
                                vocab_tmp.write(vocab_file['content'].decode('utf-8'))
                            else:
                                vocab_tmp.write(vocab_file['content'].decode('utf-8'))
                            vocab_path = vocab_tmp.name
                    
                    # Load custom embeddings
                    success = viz.load_custom_embeddings(tmp_path, vocab_path)
                    
                    # Cleanup temporary files
                    os.unlink(tmp_path)
                    if vocab_path:
                        os.unlink(vocab_path)
                        
                except Exception as e:
                    print(f"Error processing uploaded files: {str(e)}")
                    success = False
                    
            elif custom_model_path.value.strip():
                # Load from Hugging Face or local path
                success = viz.load_model('custom', custom_path=custom_model_path.value.strip())
            else:
                print("Please either upload embedding files or specify a model path!")
                return
        else:
            # Load preset model
            success = viz.load_model(model_dropdown.value)
        
        if success:
            viz.prepare_tokens(top_n_slider.value)

def create_visualization(b):
    """Fixed version of create_visualization with proper click handling"""
    global current_fig
    
    with viz_output:
        clear_output(wait=True)
        
        if viz.embeddings is None:
            print("Please load a model first!")
            return
        
        # Update vector mode
        viz.vector_mode = viz_mode_toggle.value
        
        # Reduce dimensions
        viz.reduce_dimensions(
            method='umap',
            n_components=3,
            n_neighbors=n_neighbors_slider.value,
            min_dist=min_dist_slider.value,
            metric='cosine'
        )
        
        print("Creating 3D visualization...")
        
        # Color by token length or embedding norm based on mode
        if viz.vector_mode:
            colors = viz.token_metadata['embedding_norm']
            color_title = "Embedding Norm"
        else:
            colors = viz.token_metadata['lengths']
            color_title = "Token Length"
        
        # Create hover text with rich information
        hover_text = []
        for i, token in enumerate(viz.tokens):
            hover_info = f"""
            <b>Token:</b> {token}<br>
            <b>Index:</b> {i}<br>
            <b>Length:</b> {viz.token_metadata['lengths'][i]}<br>
            <b>Type:</b> {viz.token_metadata['types'][i]}<br>
            <b>Frequency Rank:</b> {viz.token_metadata['frequency_rank'][i]}<br>
            <b>Embedding Norm:</b> {viz.token_metadata['embedding_norm'][i]:.3f}
            """
            hover_text.append(hover_info)
        
        # Create the figure with improved click handling
        if viz.vector_mode:
            # Vector visualization
            fig = go.Figure()
            
            # Add origin point
            fig.add_trace(go.Scatter3d(
                x=[0], y=[0], z=[0],
                mode='markers',
                marker=dict(size=8, color='white', symbol='diamond'),
                name='Origin',
                showlegend=True
            ))
            
            # Add vectors
            x_coords = []
            y_coords = []
            z_coords = []
            
            for i in range(len(viz.tokens)):
                x_coords.extend([0, viz.reduced_embeddings[i, 0], None])
                y_coords.extend([0, viz.reduced_embeddings[i, 1], None])
                z_coords.extend([0, viz.reduced_embeddings[i, 2], None])
            
            fig.add_trace(go.Scatter3d(
                x=x_coords,
                y=y_coords,
                z=z_coords,
                mode='lines',
                line=dict(color='rgba(100, 150, 200, 0.3)', width=1),
                showlegend=False,
                hoverinfo='skip',
                name='Vectors'
            ))
            
            # Add clickable points
            fig.add_trace(go.Scatter3d(
                x=viz.reduced_embeddings[:, 0],
                y=viz.reduced_embeddings[:, 1],
                z=viz.reduced_embeddings[:, 2],
                mode='markers',
                marker=dict(
                    size=4,
                    opacity=0.9,
                    color=colors,
                    colorscale=color_scheme.value,
                    showscale=True,
                    colorbar=dict(
                        title=dict(text=color_title, font=dict(color='white')),
                        tickfont=dict(color='white')
                    )
                ),
                text=[f"{i}:{token}" for i, token in enumerate(viz.tokens)],  # Include index
                hovertext=hover_text,
                hovertemplate='<b>%{text}</b><br>%{hovertext}<extra></extra>',
                name='Token Vectors',
                customdata=list(range(len(viz.tokens)))  # Store indices as custom data
            ))
            
        else:
            # Standard scatter plot
            fig = go.Figure(data=go.Scatter3d(
                x=viz.reduced_embeddings[:, 0],
                y=viz.reduced_embeddings[:, 1],
                z=viz.reduced_embeddings[:, 2],
                mode='markers',
                marker=dict(
                    size=3,
                    opacity=0.8,
                    color=colors,
                    colorscale=color_scheme.value,
                    showscale=True,
                    colorbar=dict(
                        title=dict(text=color_title, font=dict(color='white')),
                        tickfont=dict(color='white')
                    )
                ),
                text=[f"{i}:{token}" for i, token in enumerate(viz.tokens)],  # Include index
                hovertext=hover_text,
                hovertemplate='<b>%{text}</b><br>%{hovertext}<extra></extra>',
                name='Tokens',
                customdata=list(range(len(viz.tokens)))  # Store indices as custom data
            ))
        
        # Modern dark theme styling
        fig.update_layout(
            title=dict(
                text=f'{viz.current_model_name} Token Embeddings ({len(viz.tokens):,} tokens) - {"Vector" if viz.vector_mode else "Scatter"} Mode',
                font=dict(size=20, color='white'),
                x=0.5
            ),
            scene=dict(
                xaxis=dict(
                    showbackground=False,
                    showgrid=True,
                    gridcolor='rgba(255, 255, 255, 0.1)',
                    title=dict(text='UMAP 1', font=dict(color='white')),
                    tickfont=dict(color='white')
                ),
                yaxis=dict(
                    showbackground=False,
                    showgrid=True,
                    gridcolor='rgba(255, 255, 255, 0.1)',
                    title=dict(text='UMAP 2', font=dict(color='white')),
                    tickfont=dict(color='white')
                ),
                zaxis=dict(
                    showbackground=False,
                    showgrid=True,
                    gridcolor='rgba(255, 255, 255, 0.1)',
                    title=dict(text='UMAP 3', font=dict(color='white')),
                    tickfont=dict(color='white')
                ),
                bgcolor='rgba(0, 0, 0, 0)',
                camera=dict(
                    eye=dict(x=1.5, y=1.5, z=1.5)
                )
            ),
            paper_bgcolor='rgba(0, 0, 0, 0.9)',
            plot_bgcolor='rgba(0, 0, 0, 0)',
            font=dict(color='white'),
            height=700
        )
        
        # Store figure globally
        current_fig = fig
        
        # Create the FigureWidget for interactive callbacks
        fig_widget = go.FigureWidget(fig)
        
        # Add click handler to the appropriate trace
        if viz.vector_mode:
            fig_widget.data[-1].on_click(handle_point_click)  # Points trace
        else:
            fig_widget.data[0].on_click(handle_point_click)   # Scatter trace
        
        display(fig_widget)
        
        # Show statistics
        display(HTML(f"""
        <div class="metric-card">
            <h3>Embedding Statistics</h3>
            <p><strong>Model:</strong> <span style="color: #60a5fa;">{viz.current_model_name}</span></p>
            <p><strong>Tokens visualized:</strong> <span style="color: #34d399;">{len(viz.tokens):,}</span></p>
            <p><strong>Original dimension:</strong> <span style="color: #fbbf24;">{viz.embeddings.shape[1]}</span></p>
            <p><strong>Reduced to:</strong> <span style="color: #a78bfa;">3D</span></p>
            <p><strong>Visualization mode:</strong> <span style="color: #f97316;">{"Vector Field" if viz.vector_mode else "Scatter Plot"}</span></p>
        </div>
        """))
        
        # Count token types
        type_counts = {}
        for t in viz.token_metadata['types']:
            type_counts[t] = type_counts.get(t, 0) + 1
        
        display(HTML("<p><strong>Token distribution:</strong></p><ul style='margin-left: 20px;'>"))
        
        for token_type, count in sorted(type_counts.items()):
            percentage = (count / len(viz.tokens)) * 100
            display(HTML(f"""
                <li style="margin: 8px 0;">
                    <span style="color: #e2e8f0; font-weight: 500;">{token_type}:</span> 
                    <span style="color: #94a3b8;">{count:,} tokens</span>
                    <span style="color: #64748b; font-size: 0.9em;">({percentage:.1f}%)</span>
                </li>
            """))
        
        display(HTML("</ul></div>"))

def handle_point_click(trace, points, selector):
    """Fixed point click handler"""
    if not points.point_inds:
        return
        
    point_idx = points.point_inds[0]
    viz.selected_token_idx = point_idx
    
    # Update the selected token display
    with token_detail_output:
        clear_output(wait=True)
        show_token_details(point_idx)
    
    # Update the selection info widget
    token_name = viz.tokens[point_idx]
    selected_token_info.value = f"""
    <div style="background: rgba(16, 185, 129, 0.2); border: 1px solid rgba(16, 185, 129, 0.5); 
                border-radius: 8px; padding: 12px; margin: 10px 0; text-align: center;">
        <strong>Selected Token:</strong> 
        <span style="color: #10b981; font-size: 1.2em;">"{token_name}"</span> 
        <span style="color: #64748b;">(Index: {point_idx})</span>
    </div>
    """

def show_token_details(token_idx):
    """Fixed token details display"""
    details = viz.get_token_details(token_idx)
    
    if details is None:
        return
    
    # Find neighbors
    neighbors = viz.find_neighbors(token_idx, n_neighbors=10, metric=distance_metric.value)
    
    display(HTML(f"""
    <div class="token-detail">
        <h3>üéØ Selected Token Details</h3>
        <p><strong>Token:</strong> <span style="color: #fbbf24; font-size: 1.3em; font-weight: bold;">"{details['token']}"</span></p>
        <p><strong>Index:</strong> {details['index']}</p>
        <p><strong>Character Length:</strong> {details['length']}</p>
        <p><strong>Token Type:</strong> <span style="color: #a78bfa;">{details['type']}</span></p>
        <p><strong>Frequency Rank:</strong> {details['frequency_rank']:,}</p>
        <p><strong>Embedding Norm:</strong> {details['embedding_norm']:.4f}</p>
        <p><strong>Special Characters:</strong> {"‚úÖ Yes" if details['has_special_chars'] else "‚ùå No"}</p>
        <p><strong>All Uppercase:</strong> {"‚úÖ Yes" if details['is_uppercase'] else "‚ùå No"}</p>
        <p><strong>Is Digit:</strong> {"‚úÖ Yes" if details['is_digit'] else "‚ùå No"}</p>
    </div>
    """))
    
    if 'x' in details:
        display(HTML(f"""
        <div class="vector-mode">
            <h4>üìç 3D Coordinates</h4>
            <p><strong>X:</strong> {details['x']:.4f}</p>
            <p><strong>Y:</strong> {details['y']:.4f}</p>
            <p><strong>Z:</strong> {details['z']:.4f}</p>
        </div>
        """))
    
    # Show nearest neighbors
    display(HTML("<div class='token-info'><h3>üîç Nearest Neighbors (Most Similar)</h3>"))
    
    for i, neighbor in enumerate(neighbors[:8]):
        similarity_score = 1 - neighbor['distance']  # Convert distance to similarity
        similarity_bar = "‚ñà" * int(similarity_score * 10) + "‚ñë" * (10 - int(similarity_score * 10))
        
        display(HTML(f"""
        <div class='search-result'>
            <strong>#{i+1}:</strong> "<span style="color: #fbbf24;">{neighbor['token']}</span>" 
            <br><span style="color: #10b981;">Similarity: {similarity_score:.3f}</span>
            <span style="color: #64748b;">[{similarity_bar}]</span>
            <span style="color: #94a3b8; font-size: 0.85em;">(distance: {neighbor['distance']:.4f})</span>
        </div>
        """))
    
    display(HTML("</div>"))

def export_selected_token_callback(b):
    """Fixed export function for selected token"""
    if viz.selected_token_idx is None:
        with token_detail_output:
            display(HTML("""
                <div class="metric-card" style="text-align: center; border-color: #ef4444;">
                    <h3>‚ùå No Token Selected</h3>
                    <p style="color: #ef4444;">Please select a token first by clicking on a point in the visualization above!</p>
                    <p style="color: #94a3b8; font-size: 0.9em;">The point will highlight and details will appear below.</p>
                </div>
            """))
        return
        
    try:
        # Get token details and neighbors
        token_idx = viz.selected_token_idx
        token_name = viz.tokens[token_idx]
        details = viz.get_token_details(token_idx)
        neighbors = viz.find_neighbors(token_idx, n_neighbors=15, metric=distance_metric.value)
        
        # Create comprehensive export data
        export_data = {
            'selected_token': details,
            'neighbors': neighbors,
            'embedding_vector': viz.selected_embeddings[token_idx].tolist() if hasattr(viz, 'selected_embeddings') else None,
            'export_timestamp': datetime.now().isoformat(),
            'model_name': viz.current_model_name,
            'visualization_parameters': {
                'umap_neighbors': n_neighbors_slider.value,
                'umap_min_dist': min_dist_slider.value,
                'distance_metric': distance_metric.value,
                'color_scheme': color_scheme.value,
                'vector_mode': viz.vector_mode
            }
        }
        
        # Create DataFrame with token and neighbors
        rows = []
        
        # Add selected token
        token_row = dict(details)
        token_row['relationship'] = 'selected'
        token_row['distance_from_selected'] = 0.0
        token_row['similarity_to_selected'] = 1.0
        rows.append(token_row)
        
        # Add neighbors
        for i, neighbor in enumerate(neighbors):
            neighbor_details = viz.get_token_details(neighbor['index'])
            neighbor_details['relationship'] = f'neighbor_rank_{i+1}'
            neighbor_details['distance_from_selected'] = neighbor['distance']
            neighbor_details['similarity_to_selected'] = 1 - neighbor['distance']
            rows.append(neighbor_details)
        
        df = pd.DataFrame(rows)
        
        # Generate clean filenames
        clean_token = re.sub(r'[^\w\-_\.]', '_', token_name)
        if len(clean_token) > 20:
            clean_token = clean_token[:20]
        
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        base_name = f"token_{clean_token}_idx{token_idx}_{timestamp}"
        
        csv_filename = f"{base_name}.csv"
        json_filename = f"{base_name}.json"
        
        # Save files
        df.to_csv(csv_filename, index=False)
        
        with open(json_filename, 'w', encoding='utf-8') as f:
            json.dump(export_data, f, indent=2, ensure_ascii=False)
        
        # Show success message
        with token_detail_output:
            clear_output(wait=True)
            
            display(HTML(f"""
            <div class="metric-card" style="border-color: #10b981;">
                <h3>‚úÖ Export Successful!</h3>
                <div style="text-align: center; margin: 20px 0;">
                    <h4 style="color: #10b981;">Selected Token: "{token_name}"</h4>
                    <p style="color: #64748b;">Index: {token_idx} | Type: {details.get('type', 'unknown')}</p>
                </div>
                
                <div style="background: rgba(59, 130, 246, 0.1); border: 1px solid rgba(59, 130, 246, 0.3); 
                           border-radius: 8px; padding: 16px; margin: 16px 0;">
                    <h4>üìÅ Files Created:</h4>
                    <p><strong>CSV:</strong> <code style="color: #60a5fa;">{csv_filename}</code></p>
                    <p><strong>JSON:</strong> <code style="color: #60a5fa;">{json_filename}</code></p>
                </div>
                
                <div style="background: rgba(16, 185, 129, 0.1); border: 1px solid rgba(16, 185, 129, 0.3); 
                           border-radius: 8px; padding: 16px; margin: 16px 0;">
                    <h4>üìä Export Contents:</h4>
                    <ul style="margin-left: 20px; color: #cbd5e1;">
                        <li><strong>Records:</strong> {len(df)} (1 selected token + {len(neighbors)} neighbors)</li>
                        <li><strong>Token metadata:</strong> Position, type, frequency, characteristics</li>
                        <li><strong>Neighbor analysis:</strong> Distance and similarity rankings</li>
                        <li><strong>3D coordinates:</strong> UMAP reduced dimensions</li>
                        <li><strong>Original embedding:</strong> Full {len(export_data.get('embedding_vector', []))}D vector (JSON only)</li>
                        <li><strong>Visualization settings:</strong> Parameters used for analysis</li>
                    </ul>
                </div>
                
                <div style="text-align: center; margin-top: 20px;">
                    <p style="color: #94a3b8; font-size: 0.9em;">
                        Files saved to your current working directory<br>
                        Ready for further analysis or sharing!
                    </p>
                </div>
            </div>
            """))
            
            # Re-display token details below export confirmation
            show_token_details(token_idx)
            
    except Exception as e:
        with token_detail_output:
            display(HTML(f"""
                <div class="metric-card" style="border-color: #ef4444;">
                    <h3>‚ùå Export Error</h3>
                    <p style="color: #ef4444;">Failed to export token data:</p>
                    <pre style="color: #fbbf24; background: rgba(239, 68, 68, 0.1); padding: 12px; 
                              border-radius: 6px; margin: 12px 0;">{str(e)}</pre>
                    <p style="color: #94a3b8; font-size: 0.9em;">
                        Please try again or check that you have write permissions in the current directory.
                    </p>
                </div>
            """))

def search_callback(change):
    with search_output:
        clear_output(wait=True)
        
        if not search_box.value or viz.tokens is None:
            return
        
        # Search for tokens
        results = viz.search_tokens(search_box.value)
        
        if results:
            display(HTML(f"<div class='token-info'><h3>Search Results ({len(results)} found)</h3>"))
            
            for i, result in enumerate(results[:15]):  # Show top 15
                token_idx = result['index']
                
                # Find neighbors
                neighbors = viz.find_neighbors(token_idx, n_neighbors=3, metric=distance_metric.value)
                neighbor_tokens = [n['token'] for n in neighbors]
                
                display(HTML(f"""
                <div class='search-result' style="cursor: pointer;">
                    <strong>Token:</strong> "<span style="color: #fbbf24;">{result['token']}</span>" 
                    <span style="color: #64748b;">(index: {token_idx})</span><br>
                    <strong>Nearest neighbors:</strong> 
                    <span style="color: #a78bfa;">{', '.join(neighbor_tokens)}</span>
                    <br><small style="color: #64748b;">Click on visualization to select and export</small>
                </div>
                """))
                
                if i >= 14:  # Limit display
                    break
                    
            display(HTML("</div>"))
        else:
            display(HTML("""
                <div class='token-info' style='text-align: center; opacity: 0.8;'>
                    <p>No tokens found matching your search.</p>
                    <p style='color: #64748b; font-size: 0.9em;'>Try a different search term or check your spelling.</p>
                </div>
            """))

def export_metadata_callback(b):
    """Export all token metadata to CSV"""
    with model_output:
        if viz.tokens is None:
            print("Please load a model first!")
            return
            
        try:
            df, filename = viz.export_metadata()
            print(f"All metadata exported successfully!")
            print(f"File: {filename}")
            print(f"Records: {len(df):,}")
            print(f"Columns: {', '.join(df.columns.tolist())}")
            
            # Show sample of data
            display(HTML(f"""
            <div class="metric-card">
                <h3>Export Complete</h3>
                <p><strong>Filename:</strong> <code>{filename}</code></p>
                <p><strong>Records exported:</strong> {len(df):,}</p>
                <p><strong>Columns:</strong> {len(df.columns)}</p>
                <div style="margin-top: 16px;">
                    <strong>Sample data:</strong>
                    <div style="font-family: monospace; font-size: 0.9em; margin-top: 8px; max-height: 200px; overflow-y: auto;">
                        {df.head().to_html(classes='table table-dark', escape=False)}
                    </div>
                </div>
            </div>
            """))
            
        except Exception as e:
            print(f"Error exporting metadata: {str(e)}")

# Connect callbacks
model_dropdown.observe(model_dropdown_change, names='value')
load_button.on_click(load_model_callback)
visualize_button.on_click(create_visualization)
export_button.on_click(export_metadata_callback)
export_selected_button.on_click(export_selected_token_callback)
search_box.observe(search_callback, names='value')

# =====================================================================================
# DISPLAY INTERFACE
# =====================================================================================

display(HTML("<div class='control-panel'>"))
display(HTML("<h2>üéõÔ∏è Control Panel</h2>"))

# Model selection section
display(HTML("<h3>üì¶ Model Selection</h3>"))
display(model_dropdown)
display(custom_model_container)
display(widgets.HBox([load_button]))
display(model_output)

# Visualization mode section
display(HTML("<h3>üé® Visualization Mode</h3>"))
display(viz_mode_toggle)

# Parameters section
display(HTML("<h3>‚öôÔ∏è Parameters</h3>"))
display(widgets.HBox([top_n_slider, n_neighbors_slider, min_dist_slider]))
display(widgets.HBox([distance_metric, color_scheme]))
display(widgets.HBox([visualize_button, export_button, export_selected_button]))

# Search section
display(HTML("<h3>üîç Token Search & Inspection</h3>"))
display(search_box)
display(search_output)

# Selected token info
display(selected_token_info)

# Token details section
display(HTML("<h3>üéØ Selected Token Details</h3>"))
display(HTML("<p style='color: #94a3b8; font-size: 0.9em;'>Click on any point in the visualization above to inspect a token</p>"))
display(token_detail_output)

display(HTML("</div>"))

# Main visualization area
display(viz_output)

# =====================================================================================
# UTILITY FUNCTIONS
# =====================================================================================

def create_neighborhood_plot(token, n_neighbors=20):
    """Create a focused plot showing a token and its neighbors"""
    if viz.tokens is None:
        print("Please load a model first!")
        return
    
    # Find token index
    token_indices = [i for i, t in enumerate(viz.tokens) if t.lower() == token.lower()]
    
    if not token_indices:
        print(f"Token '{token}' not found!")
        return
    
    token_idx = token_indices[0]
    neighbors = viz.find_neighbors(token_idx, n_neighbors=n_neighbors, metric='cosine')
    
    # Get coordinates for token and neighbors
    indices = [token_idx] + [n['index'] for n in neighbors]
    coords = viz.reduced_embeddings[indices]
    labels = [f"TARGET: {viz.tokens[token_idx]}"] + [n['token'] for n in neighbors]
    colors = ['red'] + ['blue'] * len(neighbors)
    sizes = [12] + [8] * len(neighbors)
    
    fig = go.Figure(data=go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers+text',
        marker=dict(size=sizes, color=colors, opacity=0.8),
        text=labels,
        textposition="middle right",
        textfont=dict(color='white', size=10)
    ))
    
    fig.update_layout(
        title=f"Neighborhood of '{token}' ({n_neighbors} nearest neighbors)",
        scene=dict(
            xaxis=dict(showbackground=False, title='UMAP 1', titlefont=dict(color='white')),
            yaxis=dict(showbackground=False, title='UMAP 2', titlefont=dict(color='white')),
            zaxis=dict(showbackground=False, title='UMAP 3', titlefont=dict(color='white'))
        ),
        paper_bgcolor='rgba(0, 0, 0, 0.9)',
        font=dict(color='white'),
        height=600
    )
    
    fig.show()

def compare_tokens(token1, token2):
    """Compare two tokens and their embeddings"""
    if viz.tokens is None:
        print("Please load a model first!")
        return
    
    # Find token indices
    idx1 = None
    idx2 = None
    
    for i, token in enumerate(viz.tokens):
        if token.lower() == token1.lower():
            idx1 = i
        if token.lower() == token2.lower():
            idx2 = i
    
    if idx1 is None:
        print(f"Token '{token1}' not found!")
        return
    if idx2 is None:
        print(f"Token '{token2}' not found!")
        return
    
    # Calculate similarity
    emb1 = viz.selected_embeddings[idx1]
    emb2 = viz.selected_embeddings[idx2]
    
    cosine_sim = cosine_similarity([emb1], [emb2])[0][0]
    euclidean_dist = euclidean_distances([emb1], [emb2])[0][0]
    
    display(HTML(f"""
    <div class="token-detail">
        <h3>üîÑ Token Comparison</h3>
        <p><strong>Token 1:</strong> "{token1}" (index: {idx1})</p>
        <p><strong>Token 2:</strong> "{token2}" (index: {idx2})</p>
        <p><strong>Cosine Similarity:</strong> <span style="color: #10b981;">{cosine_sim:.4f}</span></p>
        <p><strong>Euclidean Distance:</strong> <span style="color: #f59e0b;">{euclidean_dist:.4f}</span></p>
        <p><strong>Semantic Relationship:</strong> 
        <span style="color: {'#10b981' if cosine_sim > 0.8 else '#34d399' if cosine_sim > 0.6 else '#fbbf24' if cosine_sim > 0.4 else '#ef4444'};">
        {"Very Similar" if cosine_sim > 0.8 else "Similar" if cosine_sim > 0.6 else "Somewhat Related" if cosine_sim > 0.4 else "Different"}
        </span>
        </p>
    </div>
    """))

def load_embedding_file(file_path):
    """Helper function to load embeddings from various file formats"""
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return None
        
    try:
        if file_path.endswith('.npy'):
            return np.load(file_path)
        elif file_path.endswith('.npz'):
            data = np.load(file_path)
            # Try common key names
            for key in ['embeddings', 'embedding', 'vectors', 'data']:
                if key in data:
                    return data[key]
            # Return first array if no standard key found
            return data[data.files[0]]
        elif file_path.endswith('.csv'):
            df = pd.read_csv(file_path)
            return df.values
        elif file_path.endswith('.pt'):
            try:
                import torch
                data = torch.load(file_path, map_location='cpu')
                if isinstance(data, torch.Tensor):
                    return data.numpy()
                elif isinstance(data, dict):
                    for key in ['embeddings', 'embedding', 'vectors', 'weight']:
                        if key in data:
                            return data[key].numpy() if hasattr(data[key], 'numpy') else data[key]
            except ImportError:
                print("PyTorch not available for loading .pt files")
                return None
        else:
            print(f"Unsupported file format: {file_path}")
            return None
    except Exception as e:
        print(f"Error loading file {file_path}: {str(e)}")
        return None

def batch_analyze_tokens(token_list):
    """Analyze multiple tokens at once and return comparison matrix"""
    if viz.tokens is None:
        print("Please load a model first!")
        return None
    
    # Find indices for all tokens
    token_indices = {}
    for token in token_list:
        for i, t in enumerate(viz.tokens):
            if t.lower() == token.lower():
                token_indices[token] = i
                break
    
    # Filter out tokens not found
    valid_tokens = list(token_indices.keys())
    if len(valid_tokens) < 2:
        print("Need at least 2 valid tokens for comparison")
        return None
    
    # Get embeddings
    embeddings = [viz.selected_embeddings[token_indices[token]] for token in valid_tokens]
    
    # Calculate similarity matrix
    similarity_matrix = cosine_similarity(embeddings)
    
    # Create DataFrame
    df = pd.DataFrame(similarity_matrix, index=valid_tokens, columns=valid_tokens)
    
    display(HTML("""
    <div class="metric-card">
        <h3>üîç Batch Token Analysis</h3>
        <p>Cosine similarity matrix for selected tokens:</p>
    </div>
    """))
    
    # Style the dataframe for better visualization
    styled_df = df.style.background_gradient(cmap='viridis', vmin=0, vmax=1).format('{:.3f}')
    display(styled_df)
    
    return df

def export_embedding_cluster(center_token, radius=0.3, max_tokens=100):
    """Export all tokens within a semantic cluster around a center token"""
    if viz.tokens is None:
        print("Please load a model first!")
        return None
    
    # Find center token
    center_idx = None
    for i, token in enumerate(viz.tokens):
        if token.lower() == center_token.lower():
            center_idx = i
            break
    
    if center_idx is None:
        print(f"Token '{center_token}' not found!")
        return None
    
    # Calculate distances from center token
    center_embedding = viz.selected_embeddings[center_idx]
    similarities = cosine_similarity([center_embedding], viz.selected_embeddings)[0]
    distances = 1 - similarities
    
    # Find tokens within radius
    cluster_indices = np.where(distances <= radius)[0]
    cluster_indices = cluster_indices[:max_tokens]  # Limit results
    
    # Create export data
    cluster_data = []
    for idx in cluster_indices:
        details = viz.get_token_details(idx)
        details['distance_from_center'] = distances[idx]
        details['similarity_to_center'] = similarities[idx]
        cluster_data.append(details)
    
    # Sort by similarity
    cluster_data.sort(key=lambda x: x['similarity_to_center'], reverse=True)
    
    df = pd.DataFrame(cluster_data)
    
    # Generate filename
    clean_token = re.sub(r'[^\w\-_\.]', '_', center_token)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"cluster_{clean_token}_r{radius}_{timestamp}.csv"
    
    df.to_csv(filename, index=False)
    
    display(HTML(f"""
    <div class="metric-card" style="border-color: #10b981;">
        <h3>‚úÖ Cluster Export Complete</h3>
        <p><strong>Center Token:</strong> "{center_token}"</p>
        <p><strong>Radius:</strong> {radius}</p>
        <p><strong>Tokens in cluster:</strong> {len(cluster_data)}</p>
        <p><strong>File:</strong> <code>{filename}</code></p>
    </div>
    """))
    
    return df, filename

# Enhanced utility functions display
display(HTML("""
<div class="control-panel">
    <h2>üõ†Ô∏è Advanced Analysis Tools</h2>
    <div style="background: rgba(59, 130, 246, 0.1); border: 1px solid rgba(59, 130, 246, 0.2); border-radius: 12px; padding: 20px;">
        <p style="margin: 0 0 16px 0; color: #e2e8f0; font-weight: 500;">Use these functions in code cells for advanced analysis:</p>
        
        <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 20px; margin: 20px 0;">
            <div style="background: rgba(16, 185, 129, 0.1); border: 1px solid rgba(16, 185, 129, 0.2); border-radius: 8px; padding: 16px;">
                <h4 style="color: #10b981; margin: 0 0 12px 0;">üìä Single Token Analysis</h4>
                <ul style="margin: 0; color: #cbd5e1; font-size: 0.9em;">
                    <li><code>create_neighborhood_plot("token")</code></li>
                    <li><code>show_token_details(index)</code></li>
                    <li><code>viz.find_neighbors(index, n=20)</code></li>
                </ul>
            </div>
            
            <div style="background: rgba(139, 92, 246, 0.1); border: 1px solid rgba(139, 92, 246, 0.2); border-radius: 8px; padding: 16px;">
                <h4 style="color: #8b5cf6; margin: 0 0 12px 0;">üîÑ Multi-Token Comparison</h4>
                <ul style="margin: 0; color: #cbd5e1; font-size: 0.9em;">
                    <li><code>compare_tokens("token1", "token2")</code></li>
                    <li><code>batch_analyze_tokens(["a", "b", "c"])</code></li>
                </ul>
            </div>
        </div>
        
        <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 20px; margin: 20px 0;">
            <div style="background: rgba(245, 158, 11, 0.1); border: 1px solid rgba(245, 158, 11, 0.2); border-radius: 8px; padding: 16px;">
                <h4 style="color: #f59e0b; margin: 0 0 12px 0;">üìÅ Data Import/Export</h4>
                <ul style="margin: 0; color: #cbd5e1; font-size: 0.9em;">
                    <li><code>load_embedding_file("path.npy")</code></li>
                    <li><code>export_embedding_cluster("center")</code></li>
                </ul>
            </div>
            
            <div style="background: rgba(6, 182, 212, 0.1); border: 1px solid rgba(6, 182, 212, 0.2); border-radius: 8px; padding: 16px;">
                <h4 style="color: #06b6d4; margin: 0 0 12px 0;">üéØ Direct Access</h4>
                <ul style="margin: 0; color: #cbd5e1; font-size: 0.9em;">
                    <li><code>viz.embeddings</code> - Full embedding matrix</li>
                    <li><code>viz.tokens</code> - Token list</li>
                    <li><code>viz.reduced_embeddings</code> - 3D coordinates</li>
                </ul>
            </div>
        </div>
        
        <div style="margin-top: 20px; padding: 16px; background: rgba(16, 185, 129, 0.1); border: 1px solid rgba(16, 185, 129, 0.2); border-radius: 8px;">
            <h4 style="margin: 0 0 12px 0; color: #10b981;">üìã Custom Model Examples</h4>
            <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 16px;">
                <div>
                    <p style="margin: 8px 0; color: #cbd5e1; font-size: 0.9em;">
                        <strong>ü§ó Hugging Face models:</strong><br>
                        <code>microsoft/DialoGPT-medium</code><br>
                        <code>EleutherAI/gpt-neo-125M</code><br>
                        <code>openai-gpt</code>
                    </p>
                </div>
                <div>
                    <p style="margin: 8px 0; color: #cbd5e1; font-size: 0.9em;">
                        <strong>üìÅ File formats supported:</strong><br>
                        <code>.npy</code> - NumPy arrays<br>
                        <code>.npz</code> - Compressed NumPy<br>
                        <code>.pt</code> - PyTorch tensors<br>
                        <code>.csv</code> - CSV matrices
                    </p>
                </div>
            </div>
        </div>
    </div>
</div>
"""))

print("üöÄ Vector Embedding Visualizer loaded successfully!")
print("\nüìã Quick Start Instructions:")
print("1. üì¶ Select a model and click 'Load Model'")
print("2. üé® Choose visualization mode (Scatter or Vector)")
print("3. ‚öôÔ∏è  Adjust parameters as needed")
print("4. üéØ Click 'Create Visualization' to generate the plot")
print("5. üñ±Ô∏è  Click on any point to inspect token details")
print("6. üîç Use search to find specific tokens")
print("7. üíæ Export metadata for further analysis")
print("\n‚ú® Advanced features available via code cells above!")
print("üìö Documentation and examples included in the interface.")

Dropdown(description='Model:', options=(('GPT-2 (124M)', 'gpt2'), ('GPT-2 Medium (355M)', 'gpt2-medium'), ('GP‚Ä¶

VBox(children=(HTML(value="<p style='color: #94a3b8; font-size: 0.9em; margin: 10px 0;'>Load custom model from‚Ä¶

HBox(children=(Button(button_style='primary', description='Load Model', style=ButtonStyle()),))

Output()

ToggleButtons(description='Visualization:', options=(('Scatter Plot', False), ('Vector Plot', True)), tooltips‚Ä¶

HBox(children=(IntSlider(value=3000, description='Top N tokens:', max=10000, min=500, step=500), IntSlider(val‚Ä¶

HBox(children=(Dropdown(description='Distance metric:', options=('cosine', 'euclidean'), value='cosine'), Drop‚Ä¶

HBox(children=(Button(button_style='success', description='Create Visualization', style=ButtonStyle()), Button‚Ä¶

Text(value='', description='Search:', placeholder='Search tokens...', style=TextStyle(description_width='initi‚Ä¶

Output()

HTML(value='')

Output()

Output()

üöÄ Vector Embedding Visualizer loaded successfully!

üìã Quick Start Instructions:
1. üì¶ Select a model and click 'Load Model'
2. üé® Choose visualization mode (Scatter or Vector)
3. ‚öôÔ∏è  Adjust parameters as needed
4. üéØ Click 'Create Visualization' to generate the plot
5. üñ±Ô∏è  Click on any point to inspect token details
6. üîç Use search to find specific tokens
7. üíæ Export metadata for further analysis

‚ú® Advanced features available via code cells above!
üìö Documentation and examples included in the interface.
