In [2]:
# %% [markdown]
# # Fine-tuning all-MiniLM-L6-v2 for Android App Semantic Similarity

# %%
# Install required packages
!pip install -q google-play-scraper sentence-transformers transformers torch pandas matplotlib seaborn scikit-learn tqdm

# %%
import logging
import os
import random
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from pathlib import Path
from typing import List, Dict, Tuple
from google_play_scraper import app, Sort, reviews_all, search
from sentence_transformers import SentenceTransformer, InputExample, losses, models, evaluation
from torch.utils.data import DataLoader
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import LabelEncoder

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Set seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Configuration
BASE_DIR = Path(".").resolve()
DATA_DIR = BASE_DIR / "data"
MODEL_DIR = BASE_DIR / "model"
PLOTS_DIR = BASE_DIR / "plots"

# Create directories
DATA_DIR.mkdir(exist_ok=True)
MODEL_DIR.mkdir(exist_ok=True)
PLOTS_DIR.mkdir(exist_ok=True)

# Constants
CATEGORIES = ['GAME_ACTION', 'TRAVEL_AND_LOCAL', 'SHOPPING', 'SOCIAL', 'EDUCATION']  # Example categories
NUM_APPS_PER_CATEGORY = 50  # Reduced for demonstration
TRIPLET_BATCH_SIZE = 16
NUM_EPOCHS = 3
MODEL_NAME = '../model/all-MiniLM-L6-v2'

# %% [markdown]
# ## Dataset Collection

# %%
def scrape_google_play_apps(categories: List[str], num_apps: int) -> pd.DataFrame:
    """Scrape app data from Google Play Store using search."""
    apps = []
    
    for category in tqdm(categories, desc="Scraping categories"):
        try:
            # Search for apps in category
            results = search(
                query=category,
                lang='en',
                country='us',
                n_hits=num_apps,
                # sort=Sort.MOST_RELEVANT
            )
            
            # Fetch details for each app
            for result in tqdm(results, desc=f"Processing {category}", leave=False):
                try:
                    app_details = app(
                        app_id=result['appId'],
                        lang='en',
                        country='us'
                    )
                    apps.append({
                        'app_name': app_details.get('title', ''),
                        'category': category,
                        'description': app_details.get('summary', ''),
                        'app_id': result['appId']
                    })
                except Exception as e:
                    logger.error(f"Error fetching app {result['appId']}: {str(e)}")
                    continue
        except Exception as e:
            logger.error(f"Error processing category {category}: {str(e)}")
            continue
    
    df = pd.DataFrame(apps)
    # Keep previous cleaning and filtering code
    
    return df

# %%
# Scrape or load cached data
dataset_path = DATA_DIR / "google_play_apps.csv"
if not dataset_path.exists():
    logger.info("Starting data scraping...")
    df_apps = scrape_google_play_apps(CATEGORIES, NUM_APPS_PER_CATEGORY)
    df_apps.to_csv(dataset_path, index=False)
else:
    logger.info("Loading cached dataset...")
    df_apps = pd.read_csv(dataset_path)

logger.info(f"Total apps collected: {len(df_apps)}")
logger.info(f"Categories distribution:\n{df_apps['category'].value_counts()}")

# %% [markdown]
# ## Triplet Dataset Creation

# %%
def generate_triplets(df: pd.DataFrame) -> List[InputExample]:
    """Generate triplets for contrastive learning."""
    triplets = []
    grouped = df.groupby('category')
    category_apps = {cat: grp['app_name'].tolist() for cat, grp in grouped}
    
    # Generate triplets for each category
    for category, apps in tqdm(category_apps.items(), desc="Generating triplets"):
        other_categories = [c for c in category_apps.keys() if c != category]
        
        for anchor in apps:
            # Find positive sample (same category)
            positives = [app for app in apps if app != anchor]
            if not positives:
                continue
            
            # Find negative sample (different category)
            negative_category = np.random.choice(other_categories)
            negatives = category_apps[negative_category]
            if not negatives:
                continue
            
            # Create triplet
            positive = np.random.choice(positives)
            negative = np.random.choice(negatives)
            triplets.append(InputExample(
                texts=[anchor, positive, negative]
            ))
    
    return triplets

# %%
# Generate or load triplets
triplet_path = DATA_DIR / "triplets.pt"
if not triplet_path.exists():
    logger.info("Generating triplets...")
    triplets = generate_triplets(df_apps)
    torch.save(triplets, triplet_path)
else:
    logger.info("Loading cached triplets...")
    triplets = torch.load(triplet_path)

logger.info(f"Total triplets generated: {len(triplets)}")

# Split train/test
train_size = int(0.9 * len(triplets))
train_triplets = triplets[:train_size]
test_triplets = triplets[train_size:]

# %% [markdown]
# ## Model Setup

# %%
# Initialize model
model = SentenceTransformer(MODEL_NAME)
original_model = SentenceTransformer(MODEL_NAME)  # Keep original for comparison

# Create dataloader
train_dataloader = DataLoader(train_triplets, shuffle=True, batch_size=TRIPLET_BATCH_SIZE)
train_loss = losses.TripletLoss(model=model)

# %% [markdown]
# ## Fine-tuning

# %%
# Training configuration
warmup_steps = len(train_dataloader) // 10  # 10% of total steps

logger.info("Starting training...")
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=NUM_EPOCHS,
    warmup_steps=warmup_steps,
    output_path=str(MODEL_DIR / "fine-tuned-model"),
    show_progress_bar=True
)
logger.info("Training completed")

# %% [markdown]
# ## Evaluation

# %%
def evaluate_model(model: SentenceTransformer, test_triplets: List[InputExample]) -> Dict:
    """Evaluate model performance on test triplets."""
    similarities = {'anchor_positive': [], 'anchor_negative': []}
    
    for example in tqdm(test_triplets, desc="Evaluating"):
        anchor = model.encode(example.texts[0], convert_to_tensor=True)
        positive = model.encode(example.texts[1], convert_to_tensor=True)
        negative = model.encode(example.texts[2], convert_to_tensor=True)
        
        similarities['anchor_positive'].append(
            cosine_similarity(anchor.unsqueeze(0), positive.unsqueeze(0)).item()
        )
        similarities['anchor_negative'].append(
            cosine_similarity(anchor.unsqueeze(0), negative.unsqueeze(0)).item()
        )
    
    return {
        'avg_positive_sim': np.mean(similarities['anchor_positive']),
        'avg_negative_sim': np.mean(similarities['anchor_negative']),
        'margin': np.mean(similarities['anchor_positive']) - np.mean(similarities['anchor_negative'])
    }

# %%
# Evaluate both models
logger.info("Evaluating original model...")
original_results = evaluate_model(original_model, test_triplets)
logger.info("Original model results:")
logger.info(f"Avg positive similarity: {original_results['avg_positive_sim']:.4f}")
logger.info(f"Avg negative similarity: {original_results['avg_negative_sim']:.4f}")
logger.info(f"Margin: {original_results['margin']:.4f}")

logger.info("Evaluating fine-tuned model...")
fine_tuned_results = evaluate_model(model, test_triplets)
logger.info("Fine-tuned model results:")
logger.info(f"Avg positive similarity: {fine_tuned_results['avg_positive_sim']:.4f}")
logger.info(f"Avg negative similarity: {fine_tuned_results['avg_negative_sim']:.4f}")
logger.info(f"Margin: {fine_tuned_results['margin']:.4f}")

# %% [markdown]
# ## Visualization

# %%
def plot_embeddings(model: SentenceTransformer, df: pd.DataFrame, title: str, filename: str):
    """Visualize embeddings using PCA."""
    sample_df = df.groupby('category').sample(n=10, random_state=SEED)  # Sample 10 per category
    texts = sample_df['app_name'].tolist()
    categories = sample_df['category'].tolist()
    
    # Encode texts
    embeddings = model.encode(texts, show_progress_bar=True)
    
    # Reduce dimensionality
    pca = PCA(n_components=2)
    embeddings_2d = pca.fit_transform(embeddings)
    
    # Create plot
    plt.figure(figsize=(10, 8))
    sns.scatterplot(
        x=embeddings_2d[:, 0],
        y=embeddings_2d[:, 1],
        hue=categories,
        palette="tab10",
        alpha=0.8,
        s=100
    )
    plt.title(title)
    plt.savefig(PLOTS_DIR / filename, bbox_inches='tight')
    plt.close()

# %%
# Generate visualizations
logger.info("Generating visualization for original model...")
plot_embeddings(
    original_model,
    df_apps,
    "Original Model Embeddings (PCA)",
    "original_embeddings.png"
)

logger.info("Generating visualization for fine-tuned model...")
plot_embeddings(
    model,
    df_apps,
    "Fine-tuned Model Embeddings (PCA)",
    "fine_tuned_embeddings.png"
)

# %% [markdown]
# ## Save Artifacts

# %%
# Save final model
model.save(str(MODEL_DIR / "final-model"))

# Save evaluation results
results_df = pd.DataFrame([original_results, fine_tuned_results], index=['original', 'fine-tuned'])
results_df.to_csv(DATA_DIR / "evaluation_results.csv")

logger.info("All artifacts saved successfully")

2025-04-13 00:57:28,083 - INFO - Starting data scraping...


Scraping categories:   0%|          | 0/5 [00:00<?, ?it/s]

Processing GAME_ACTION:   0%|          | 0/30 [00:00<?, ?it/s]

Processing TRAVEL_AND_LOCAL:   0%|          | 0/30 [00:00<?, ?it/s]

Processing SHOPPING:   0%|          | 0/24 [00:00<?, ?it/s]

Processing SOCIAL:   0%|          | 0/26 [00:00<?, ?it/s]

Processing EDUCATION:   0%|          | 0/28 [00:00<?, ?it/s]

2025-04-13 00:59:21,166 - INFO - Total apps collected: 138
2025-04-13 00:59:21,169 - INFO - Categories distribution:
category
GAME_ACTION         30
TRAVEL_AND_LOCAL    30
EDUCATION           28
SOCIAL              26
SHOPPING            24
Name: count, dtype: int64
2025-04-13 00:59:21,170 - INFO - Generating triplets...


Generating triplets:   0%|          | 0/5 [00:00<?, ?it/s]

2025-04-13 00:59:21,201 - INFO - Total triplets generated: 138
2025-04-13 00:59:21,209 - INFO - Use pytorch device_name: cpu
2025-04-13 00:59:21,210 - INFO - Load pretrained SentenceTransformer: ../model/all-MiniLM-L6-v2
2025-04-13 00:59:21,537 - INFO - Use pytorch device_name: cpu
2025-04-13 00:59:21,538 - INFO - Load pretrained SentenceTransformer: ../model/all-MiniLM-L6-v2
2025-04-13 00:59:21,617 - INFO - Starting training...


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss


2025-04-13 00:59:36,628 - INFO - Save model to /Users/hissain/git/github/AndroidSemanticSearch/python/exp/model/fine-tuned-model
2025-04-13 00:59:37,307 - INFO - Training completed
2025-04-13 00:59:37,309 - INFO - Evaluating original model...


Evaluating:   0%|          | 0/14 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2025-04-13 00:59:39,583 - INFO - Original model results:
2025-04-13 00:59:39,584 - INFO - Avg positive similarity: 0.3849
2025-04-13 00:59:39,585 - INFO - Avg negative similarity: 0.1158
2025-04-13 00:59:39,586 - INFO - Margin: 0.2692
2025-04-13 00:59:39,587 - INFO - Evaluating fine-tuned model...


Evaluating:   0%|          | 0/14 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2025-04-13 00:59:41,476 - INFO - Fine-tuned model results:
2025-04-13 00:59:41,477 - INFO - Avg positive similarity: 0.4639
2025-04-13 00:59:41,479 - INFO - Avg negative similarity: 0.0761
2025-04-13 00:59:41,480 - INFO - Margin: 0.3878
2025-04-13 00:59:41,481 - INFO - Generating visualization for original model...


Batches:   0%|          | 0/2 [00:00<?, ?it/s]

2025-04-13 00:59:43,243 - INFO - Generating visualization for fine-tuned model...


Batches:   0%|          | 0/2 [00:00<?, ?it/s]

2025-04-13 00:59:43,707 - INFO - Save model to /Users/hissain/git/github/AndroidSemanticSearch/python/exp/model/final-model
2025-04-13 00:59:43,981 - INFO - All artifacts saved successfully
