In [None]:
from openai import OpenAI
import random
import json
import pandas as pd
import importlib

import constants # Must be imported before it can be reloaded (e.g. after a kernel restart)
importlib.reload(constants) # Pick up edits to constants.py; otherwise the stale key value is kept
from constants import openai_api

# Shared OpenAI client used by the query-generation methods further down.
# NOTE(review): consider reading the key from an environment variable instead of a repo module.
client = OpenAI(api_key=openai_api)

In [None]:
from typing import List, Dict
import random
import pandas as pd
import torch
from transformers import BertTokenizer, BertModel
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

class DummyQueryGenerator:
    """Generate dummy search queries that mimic a real query's style but cover other categories.

    Category embeddings are loaded from ``<data_dir>/<dataset_name>.csv``, which
    must have a ``category`` column and an ``embedding`` column holding vectors
    serialized as whitespace-separated text such as ``"[0.1 0.2 ...]"``. A BERT
    model embeds incoming queries so they can be matched against the categories.
    The module-level OpenAI ``client`` writes the actual queries.
    """

    def __init__(self, dataset_name="google_trends", data_dir="data"):
        """Load the category embeddings and the BERT encoder.

        Args:
            dataset_name: Base name (without ``.csv``) of the category file.
            data_dir: Directory containing the category files (default ``"data"``).
        """
        self.categories = pd.read_csv(f"{data_dir}/{dataset_name}.csv")
        # Embeddings are stored as text like "[0.1 0.2 ...]"; parse them back into float arrays.
        self.categories['embedding'] = self.categories['embedding'].apply(
            lambda x: np.fromstring(x.strip('[]'), dtype=float, sep=' ')
        )
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.model = BertModel.from_pretrained('bert-base-uncased')
        # Cosine similarities from the most recent identify_query_category call.
        self.similarities = None

        # Kept for backward compatibility; prompts are built per call instead.
        self.style_prompt = ""

    def _embedding_matrix(self) -> np.ndarray:
        """Stack the per-category embedding arrays into one 2-D array (one row per category)."""
        return np.array(self.categories['embedding'].tolist())

    @staticmethod
    def _tsne_2d(embeddings: np.ndarray) -> np.ndarray:
        """Project ``embeddings`` to 2-D with a seeded t-SNE for reproducible layouts."""
        # sklearn requires perplexity < n_samples; keep its default of 30 whenever possible.
        perplexity = min(30.0, max(1.0, len(embeddings) - 1.0))
        return TSNE(n_components=2, random_state=42, perplexity=perplexity).fit_transform(embeddings)

    @staticmethod
    def _build_style_prompt(query_style: Dict) -> str:
        """Render the features from ``analyze_query_style`` as prompt instructions."""
        return f"""
        Generate a query that matches these style characteristics:
        - Similar length (around {query_style['length']} words)
        - {'Use' if query_style['has_question_mark'] else 'Avoid'} question marks
        - {'Start with question words' if query_style['starts_with_question_word'] else 'Use declarative form'}
        - {'Capitalize first letter' if query_style['capitalization'] else 'Use lowercase'}
        """

    def get_embedding(self, query: str):
        """Return the BERT embedding of ``query`` as a list of floats.

        The embedding is the mean of the last hidden layer over the token axis.
        """
        inputs = self.tokenizer(query, return_tensors='pt', truncation=True, padding=True)
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.last_hidden_state.mean(dim=1).squeeze().tolist()

    def analyze_query_style(self, query: str) -> Dict:
        """Extract simple surface-style features of ``query``.

        Returns:
            A dict with these keys:

            - ``length``: number of whitespace-separated words.
            - ``has_question_mark``: whether ``?`` occurs anywhere.
            - ``starts_with_question_word``: whether the first *word* is one of
              how/what/where/when/why/who/which. Contractions such as "what's" count.
            - ``capitalization``: whether the first character is uppercase.
            - ``lowercase_ratio``: fraction of characters that are lowercase.
        """
        import re

        # Match whole words only: a plain startswith('who') also fires on
        # "whole wheat pizza" or "however" and misreports the query's style.
        starts_with_question_word = re.match(
            r"(how|what|where|when|why|who|which)\b", query.lstrip().lower()
        ) is not None
        style_features = {
            'length': len(query.split()),
            'has_question_mark': '?' in query,
            'starts_with_question_word': starts_with_question_word,
            'capitalization': query[0].isupper() if query else False,
            'lowercase_ratio': sum(1 for c in query if c.islower()) / len(query) if query else 0
        }
        return style_features

    def identify_query_category(self, query: str) -> str:
        """Return the category whose embedding is most cosine-similar to ``query``.

        Side effect: stores every similarity score in ``self.similarities``, as a
        Series aligned with ``self.categories``.
        """
        query_embedding = np.array(self.get_embedding(query)).reshape(1, -1)
        # One batched call instead of one sklearn call per row (matters for the 100k datasets).
        scores = cosine_similarity(query_embedding, self._embedding_matrix())[0]
        self.similarities = pd.Series(scores, index=self.categories.index)
        # A positional argmax pairs correctly with iloc, whatever the DataFrame index labels are.
        return self.categories['category'].iloc[int(np.argmax(scores))]

    def get_distant_categories(self, num_categories: int = 20) -> List[str]:
        """Randomly sample ``num_categories`` categories spread across the embedding space.

        The category embeddings are partitioned with k-means into
        ``min(num_categories, len(categories))`` clusters. Categories are then drawn
        at random from each cluster, as evenly as possible. The cluster containing
        the input query is not excluded. The result is shuffled.

        Returns:
            The sampled category names. The list is empty when
            ``num_categories <= 0`` or no categories are loaded.
        """
        n_clusters = min(num_categories, len(self.categories))
        if n_clusters <= 0:
            return []

        from sklearn.cluster import KMeans
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        cluster_labels = kmeans.fit_predict(self._embedding_matrix())

        cluster_assignments = {i: [] for i in range(n_clusters)}
        for idx, cluster in enumerate(cluster_labels):
            cluster_assignments[cluster].append(idx)

        # Spread num_categories as evenly as possible; the first `remaining` clusters get one extra.
        samples_per_cluster, remaining_samples = divmod(num_categories, n_clusters)

        selected_categories = []
        for cluster_idx in range(n_clusters):
            cluster_indices = cluster_assignments[cluster_idx]
            n_samples = samples_per_cluster + (1 if cluster_idx < remaining_samples else 0)
            if cluster_indices:
                sampled_indices = random.sample(cluster_indices, min(n_samples, len(cluster_indices)))
                selected_categories.extend(self.categories['category'].iloc[sampled_indices].tolist())

        random.shuffle(selected_categories)
        return selected_categories

    def visualize_category_distribution(self, original_query: str, selected_categories: List[str]):
        """Plot a t-SNE map of all categories, highlighting ``selected_categories`` and the query."""
        all_embeddings = self._embedding_matrix()
        all_categories = self.categories['category'].tolist()
        # Positional row numbers of the selected categories (valid for any DataFrame index).
        selected_indices = np.flatnonzero(self.categories['category'].isin(selected_categories).to_numpy())

        query_embedding = np.array(self.get_embedding(original_query)).reshape(1, -1)
        embeddings_2d = self._tsne_2d(np.vstack([all_embeddings, query_embedding]))

        categories_2d = embeddings_2d[:-1]
        query_2d = embeddings_2d[-1]

        plt.figure(figsize=(15, 10))

        plt.scatter(categories_2d[:, 0], categories_2d[:, 1],
                alpha=0.1, color='gray', label='All Categories')

        plt.scatter(categories_2d[selected_indices, 0], categories_2d[selected_indices, 1],
                color='blue', alpha=0.6, label='Selected Categories')

        for idx in selected_indices:
            plt.annotate(all_categories[idx],
                        (categories_2d[idx, 0], categories_2d[idx, 1]),
                        xytext=(5, 5), textcoords='offset points',
                        fontsize=8, alpha=0.8)

        plt.scatter(query_2d[0], query_2d[1],
                color='red', s=200, marker='*', label='Original Query')
        plt.annotate('Original Query',
                    (query_2d[0], query_2d[1]),
                    xytext=(10, 10), textcoords='offset points',
                    fontsize=10, color='red', fontweight='bold')

        plt.legend(fontsize=10)
        plt.title('Distribution of Categories and Original Query in Embedding Space',
                fontsize=12, pad=20)
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    def visualize_query_trajectories(self, original_query: str, dummy_queries: List[dict], follow_up_queries: List[dict]):
        """Plot dummy queries and their follow-ups on a t-SNE map of the categories.

        ``dummy_queries[i]`` is paired with ``follow_up_queries[i]``, and an arrow
        is drawn from the dummy to its follow-up. Every item must be a dict with a
        ``'query'`` key.

        Raises:
            ValueError: If either list is empty or the two lists differ in length.
        """
        if not dummy_queries or not follow_up_queries:
            raise ValueError("visualize_query_trajectories needs at least one dummy and one follow-up query.")
        if len(dummy_queries) != len(follow_up_queries):
            raise ValueError(
                f"Expected one follow-up per dummy query, got {len(dummy_queries)} dummies "
                f"and {len(follow_up_queries)} follow-ups."
            )

        all_embeddings = self._embedding_matrix()
        n_categories = len(all_embeddings)
        n_dummy = len(dummy_queries)

        original_embedding = np.array(self.get_embedding(original_query)).reshape(1, -1)
        dummy_embeddings = np.vstack([self.get_embedding(q['query']) for q in dummy_queries])
        followup_embeddings = np.vstack([self.get_embedding(q['query']) for q in follow_up_queries])

        embeddings_2d = self._tsne_2d(np.vstack([
            all_embeddings,
            original_embedding,
            dummy_embeddings,
            followup_embeddings
        ]))

        # The rows are laid out as [categories | original | dummies | follow-ups].
        categories_2d = embeddings_2d[:n_categories]
        original_2d = embeddings_2d[n_categories]
        dummy_2d = embeddings_2d[n_categories + 1:n_categories + 1 + n_dummy]
        followup_2d = embeddings_2d[n_categories + 1 + n_dummy:]

        plt.figure(figsize=(15, 10))

        plt.scatter(categories_2d[:, 0], categories_2d[:, 1],
                alpha=0.1, color='gray', label='Categories')

        plt.scatter(original_2d[0], original_2d[1],
                color='red', s=200, marker='*', label='Original Query')

        for i in range(n_dummy):
            plt.scatter(dummy_2d[i, 0], dummy_2d[i, 1],
                    color='blue', alpha=0.6, label='Dummy Query' if i == 0 else "")

            plt.scatter(followup_2d[i, 0], followup_2d[i, 1],
                    color='green', alpha=0.6, label='Follow-up Query' if i == 0 else "")

            plt.arrow(dummy_2d[i, 0], dummy_2d[i, 1],
                    followup_2d[i, 0] - dummy_2d[i, 0],
                    followup_2d[i, 1] - dummy_2d[i, 1],
                    head_width=0.3, head_length=0.5, fc='k', ec='k', alpha=0.3)

            plt.annotate(f"D{i+1}: {dummy_queries[i]['query'][:30]}...",
                        (dummy_2d[i, 0], dummy_2d[i, 1]),
                        xytext=(5, 5), textcoords='offset points',
                        fontsize=8, alpha=0.8)
            plt.annotate(f"F{i+1}: {follow_up_queries[i]['query'][:30]}...",
                        (followup_2d[i, 0], followup_2d[i, 1]),
                        xytext=(5, 5), textcoords='offset points',
                        fontsize=8, alpha=0.8)

        plt.title('Query Trajectory Visualization\nShowing Original Query, Dummy Queries, and Follow-ups',
                fontsize=12, pad=20)
        plt.legend(fontsize=10)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

    def generate_dummy_queries(self, input_query: str, num_queries: int = 20, show_plots: bool = True, use_new_prompt: bool = False):
        """Generate up to ``num_queries`` dummy queries, one per sampled category.

        Each query is requested from the chat model with the input query's style
        constraints and a category from ``get_distant_categories``. When
        ``use_new_prompt`` is True, the input query and its best-matching category
        are added to the prompt as an example.

        Args:
            input_query: The real query whose style the dummies should imitate.
            num_queries: Maximum number of dummy queries to return.
            show_plots: Plot the sampled categories on a t-SNE map first.
            use_new_prompt: Include the input query as an in-context example.

        Returns:
            A list of dicts parsed from the model's JSON (at most ``num_queries``).
        """
        query_style = self.analyze_query_style(input_query)

        # Also refreshes self.similarities for callers that inspect it.
        input_category = self.identify_query_category(input_query)

        distant_categories = self.get_distant_categories(num_queries)  # Get enough categories for individual queries
        if show_plots:
            self.visualize_category_distribution(input_query, distant_categories)

        style_prompt = self._build_style_prompt(query_style)

        all_queries = []
        if use_new_prompt:
            print('Input category:' + input_category)
        print(distant_categories)

        for category in distant_categories:
            system_message = f"""You are a query generation assistant. Generate a single query that:
            1. Is specifically about {category}
            2. Matches the original query's style
            {style_prompt}
            """

            if use_new_prompt:
                system_message += f"""
                Here is an example query based on a category. But yours should not be exactly the same:
                Category: {input_category}
                Query: {input_query}
                """

            completion = client.chat.completions.create(
                model="gpt-4o-mini-2024-07-18",
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": """Generate one search query.
                    Format your response as JSON with the following structure:
                    {
                        "queries": [
                            {"query": "query text", "category": "category_name"}
                        ]
                    }"""}
                ],
                response_format={ "type": "json_object" },
                temperature=0.7,
                seed=random.randint(0, 1000000)
            )

            query_result = parse_dummy_queries(completion.choices[0].message.content)
            all_queries.extend(query_result)

            if len(all_queries) >= num_queries:
                break

        return all_queries[:num_queries]

    def generate_consecutive_queries(self, input_query, dummy_queries, old_input_query=None, use_new_prompt: bool = False):
        """Generate one follow-up query for each query in ``dummy_queries``.

        The follow-ups imitate the style of ``input_query``. When ``use_new_prompt``
        is True and ``old_input_query`` is given, the pair
        (``old_input_query`` -> ``input_query``) is shown to the model as an example.

        Returns:
            A list aligned with ``dummy_queries``: element ``i`` is the follow-up
            dict (with a ``query`` key) for ``dummy_queries[i]``.

        Raises:
            ValueError: If the model returns no query for some dummy.
        """
        style_prompt = self._build_style_prompt(self.analyze_query_style(input_query))

        all_queries = []
        for dummy_query in dummy_queries:
            system_message = f"""You are a query generation assistant. Generate a single query that:
            1. Will act as a follow-up to the query: {dummy_query['query']} 
            2. Matches the following style characteristics:
            {style_prompt}
            """

            if use_new_prompt and old_input_query is not None:
                system_message += f"""
                Here is an example follow-up query based on a query. But yours should not be exactly the same:
                Old Query: {old_input_query}
                Follow-up query: {input_query}
                """

            follow_up_query = client.chat.completions.create(
                model="gpt-4o-mini-2024-07-18",
                messages=[
                    {"role": "system", "content": system_message},
                    {"role": "user", "content": """Generate one search query.
                    Format your response as JSON with the following structure:
                    {
                        "queries": [
                            {"query": "query text"}
                        ]
                    }"""}
                ],
                # JSON mode, as in generate_dummy_queries, so the reply always parses.
                response_format={ "type": "json_object" },
                temperature=0.7,
                seed=random.randint(0, 1000000)
            )

            query_result = parse_dummy_queries(follow_up_query.choices[0].message.content)
            # Keep exactly one follow-up per dummy so callers can pair them by position.
            if not query_result:
                raise ValueError(f"No follow-up query returned for dummy query: {dummy_query['query']!r}")
            all_queries.append(query_result[0])

        return all_queries

import re
def parse_dummy_queries(response):
    """Parse a chat-completion reply into a list of query dicts.

    The model is asked to answer with ``{"queries": [{"query": ...}, ...]}``.
    Markdown code fences (```json ... ```) are stripped before parsing, because
    models sometimes wrap JSON in them.

    Args:
        response: Raw ``message.content`` text from the completion. This may be
            ``None``, for example when the model refuses.

    Returns:
        The list stored under ``"queries"``. A bare single-query object
        (``{"query": ...}``) is also accepted and wrapped in a one-element list.

    Raises:
        ValueError: If ``response`` is empty/None, the cleaned text is not valid
            JSON (``json.JSONDecodeError`` is a ``ValueError``), or the JSON has
            neither a ``"queries"`` list nor a ``"query"`` entry.
    """
    if not response:
        raise ValueError("Empty model response; expected a JSON object with a 'queries' list.")

    cleaned_response = re.sub(r'```json|```', '', response).strip()
    payload = json.loads(cleaned_response)

    if isinstance(payload, dict):
        queries = payload.get('queries')
        if isinstance(queries, list):
            return queries
        # Tolerate a model that drops the wrapper and returns a single query object.
        if 'query' in payload:
            return [payload]

    raise ValueError(
        f"Unexpected response structure, expected {{'queries': [...]}}: {cleaned_response[:200]!r}"
    )


In [None]:
# Usage: dummy queries for a seed query, then two rounds of follow-ups whose
# style imitates the user's next real queries.
generator = DummyQueryGenerator()

dummy_queries = generator.generate_dummy_queries("cancer treatment")
for line in ['Initial dummy queries: ', *dummy_queries]:
    print(line)

consecutive_queries = generator.generate_consecutive_queries("cancer symptoms", dummy_queries)
for line in ['Consecutive queries: ', *consecutive_queries]:
    print(line)

# Arrows link each dummy query to its follow-up on a t-SNE map of the categories.
generator.visualize_query_trajectories("cancer treatment", dummy_queries, consecutive_queries)

consecutive_queries_2 = generator.generate_consecutive_queries("chemotherapy side effects", consecutive_queries)
for line in ['Consecutive queries 2: ', *consecutive_queries_2]:
    print(line)
    

In [None]:
import csv
import os

def create_output_file(num_queries, original_queries, dataset_name, prompt_changed):
    """Generate dummy query sequences for ``original_queries`` and write them to CSV.

    Row 0 holds the real query sequence. Each of rows 1..num_queries-1 holds one
    dummy sequence: a dummy for ``original_queries[0]`` followed by one follow-up
    per later real query. Column ``j`` (header "Query j+1") is step ``j`` of every
    sequence. The file is written to
    ``outputs/{dataset_name}_r{num_queries}_c{len(original_queries)}[_p2].csv``.

    Args:
        num_queries: Rows per column (1 real + ``num_queries - 1`` dummies).
        original_queries: The real query sequence; it must be non-empty.
        dataset_name: Category dataset passed to ``DummyQueryGenerator``.
        prompt_changed: Use the example-augmented prompts; this adds ``_p2`` to the file name.

    Raises:
        ValueError: If ``num_queries < 1``, ``original_queries`` is empty, or a
            generation step returns more queries than there are dummy rows.
    """
    if num_queries < 1:
        raise ValueError(f"num_queries must be >= 1, got {num_queries}")
    if not original_queries:
        raise ValueError("original_queries must contain at least one query")

    generator = DummyQueryGenerator(dataset_name=dataset_name)

    result = [[] for _ in range(num_queries)]
    result[0].extend(original_queries)

    def _append_column(queries):
        # Row 0 is reserved for the real query; generated queries fill rows 1..num_queries-1.
        if len(queries) > num_queries - 1:
            raise ValueError(f"Got {len(queries)} generated queries but only {num_queries - 1} dummy rows")
        for row, q in zip(result[1:], queries):
            row.append(q["query"])

    dummy_queries = generator.generate_dummy_queries(original_queries[0], num_queries=num_queries - 1, show_plots=False, use_new_prompt=prompt_changed)
    _append_column(dummy_queries)

    for step in range(1, len(original_queries)):
        # old_input_query is only consulted when use_new_prompt is True, so passing it
        # unconditionally is equivalent to branching on prompt_changed.
        dummy_queries = generator.generate_consecutive_queries(
            original_queries[step], dummy_queries, original_queries[step - 1], use_new_prompt=prompt_changed
        )
        _append_column(dummy_queries)

    os.makedirs("outputs", exist_ok=True)
    prompt_str = "_p2" if prompt_changed else ""
    output_file = f"outputs/{dataset_name}_r{num_queries}_c{len(original_queries)}{prompt_str}.csv"
    # Explicit UTF-8: model output may contain non-ASCII text that a platform
    # default encoding (e.g. cp1252 on Windows) cannot encode.
    with open(output_file, mode='w', newline='', encoding='utf-8') as file:
        writer = csv.writer(file)
        writer.writerow([f"Query {i + 1}" for i in range(len(original_queries))])
        writer.writerows(result)

    print(f"{output_file} generated.")

def create_all_output_files(original_queries, prompt_changed, datasets=None, num_queries=None, num_consec=None):
    """Run ``create_output_file`` over a grid of datasets, sizes and sequence lengths.

    A failure for one configuration is printed and skipped, so the rest of the
    grid still runs.

    Args:
        original_queries: Full real query sequence; each run uses its first ``nc`` items.
        prompt_changed: Forwarded to ``create_output_file``.
        datasets: Dataset names. Defaults to google_trends, wiki_categories_50k
            and wiki_categories_100k.
        num_queries: Row counts per file. Defaults to 10, 30 and 50.
        num_consec: Query-sequence lengths. Defaults to 1, 3 and 5.
    """
    datasets = ["google_trends", "wiki_categories_50k", "wiki_categories_100k"] if datasets is None else datasets
    num_queries = [10, 30, 50] if num_queries is None else num_queries
    num_consec = [1, 3, 5] if num_consec is None else num_consec
    for ds in datasets:
        for nq in num_queries:
            for nc in num_consec:
                try:
                    create_output_file(nq, original_queries[:nc], ds, prompt_changed)
                except Exception as e:
                    # Deliberate best-effort: report and move on to the next configuration.
                    print(f"Error creating output file for dataset: {ds}, num_queries: {nq}, num_consec: {nc}. Error: {e}")


# Settings for a single create_output_file run (next cell).
num_queries = 10         # rows per output file: the real query plus num_queries - 1 dummies
prompt_changed = False   # True -> example-augmented prompts; output file name gets "_p2"
dataset_name = "wiki_categories_50k"  # alternative: "google_trends"

# The real query sequence; item i is the user's i-th consecutive search.
original_queries = [
    "best pizza",
    "margherita pizza",
    "which cheese for margherita pizza",
    "where to find swiss cheese",
    "why does swiss cheese has holes",
]

In [None]:
create_output_file(num_queries=num_queries, original_queries=original_queries, dataset_name=dataset_name, prompt_changed=prompt_changed)

In [None]:
create_all_output_files(original_queries=original_queries, prompt_changed=prompt_changed)