In [None]:
## Generate Data ##

import json
import logging
import re
from collections import defaultdict
from datetime import datetime

import ollama
import pandas as pd
from openai import OpenAI

# Initialize logging
# The helpers below log through the module-level logging.* functions, so the
# root logger is configured here (INFO level, timestamped records).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Initialize OpenAI client
# NOTE(review): "XXXX" is a placeholder; do not commit a real key in this cell.
client = OpenAI(api_key="XXXX")  # Replace with your OpenAI API key

# Constants and variables
NUM_ITERATIONS = 100  # Number of valid keyword lists collected per model
NUM_ENTITIES = 5  # Number of entities to request in each query
NUM_ENTITIES_IN_CHART = 20  # Maximum number of distinct keywords kept per model by process_results

# Model names
# These are display labels written to the 'Model' column of each results
# table. The model ids actually sent to the APIs are separate string literals
# at the call sites further down.
MODEL_NAME_LLAMA = 'llama3'
MODEL_NAME_GPT = 'GPT-4'
MODEL_NAME_GPT3_5 = 'GPT-3.5'
MODEL_NAME_GPT4O_MINI = 'gpt-4o-mini'

# Enable or disable models
# Each flag gates both the querying step and the result-processing step for
# its model; a disabled model ends up with an empty result list / DataFrame.
USE_LLAMA = True
USE_GPT = True
USE_GPT3_5 = True
USE_GPT4O_MINI = True

# Query templates
def select_query_template(choice):
    """Return the prompt template registered under ``choice``.

    Every template contains a single ``{}`` placeholder that the caller fills
    with the number of entities to request. A ``choice`` other than 1, 2 or 3
    yields template 1.
    """
    generative_ai_prompt = '''
List the top {} companies in the world leading the advancement of generative AI. Respond only with the company name as a single lower case word. Return no other text before or after the words.
'''
    ev_manufacturer_prompt = '''
List the top {} best electric car manufacturers in the world. Respond only with the manufacturer name as a single lower case word. If the manufacturer name is two or more words, combine them into a single word. Return no other text before or after the words.
'''
    ev_quality_prompt = '''
List the top {} qualities that make a good electric car. Respond only with each factor as a single noun in lower case letters. Return no other text before or after the noun.
'''
    prompts_by_choice = {
        1: generative_ai_prompt,
        2: ev_manufacturer_prompt,
        3: ev_quality_prompt,
    }
    if choice in prompts_by_choice:
        return prompts_by_choice[choice]
    # Unknown selections fall back to the generative-AI prompt.
    return generative_ai_prompt

# Choose the query template
# Valid values are 1, 2 or 3; select_query_template falls back to template 1
# for anything else.
CHOICE = 3  # Set this to 1, 2 or 3 based on the desired template
query_template = select_query_template(CHOICE)

# Function to normalize and validate results
def normalize_and_validate_keywords(keywords):
    """Trim and lower-case each keyword, keeping only pure a-z words.

    Args:
        keywords: Iterable of strings to clean.

    Returns:
        The cleaned keywords that consist solely of the letters a-z, in input
        order. Every rejected entry is reported with a warning log record.
    """
    lowercase_word = re.compile(r'^[a-z]+$')
    accepted = []
    for raw in keywords:
        candidate = raw.strip().lower()
        if lowercase_word.match(candidate) is None:
            logging.warning(f"Invalid format detected: {candidate}")
            continue
        accepted.append(candidate)
    return accepted

# Function to process query for Llama3
def process_query_llama(query, iterations, num_entities, max_retries=20, model='llama3'):
    """Collect ``iterations`` keyword lists from a model served by Ollama.

    The prompt is sent repeatedly; from each reply the lower-case words are
    extracted and validated, and a reply counts as valid when it yields at
    least ``num_entities`` keywords (extra keywords are dropped).

    Args:
        query: Fully formatted prompt text.
        iterations: Number of valid keyword lists to collect.
        num_entities: Keywords required from (and kept per) reply.
        max_retries: Attempts per batch before an error is logged.
        model: Ollama model id to query. Defaults to ``'llama3'``, the value
            that used to be hard-coded.

    Returns:
        A list of ``iterations`` lists, each holding ``num_entities`` keywords.

    Note:
        Exhausting ``max_retries`` only logs an error; a fresh batch of
        attempts is then started for the same iteration. The function
        therefore does not return until enough valid replies were received
        and can loop indefinitely if the model never complies.
    """
    all_keywords = []
    i = 0
    while i < iterations:
        for attempt in range(max_retries):
            response = ollama.chat(model=model, messages=[{'role': 'user', 'content': query}])
            keywords = normalize_and_validate_keywords(re.findall(r'\b[a-z]+\b', response['message']['content']))
            if len(keywords) >= num_entities:
                all_keywords.append(keywords[:num_entities])
                i += 1
                break
            logging.warning(f"{model} attempt {attempt + 1} for iteration {i + 1}: Invalid result, retrying...")
            if attempt == max_retries - 1:
                logging.error(f"{model} failed to return a valid result after {max_retries} attempts for iteration {i + 1}. Retrying the iteration.")
    return all_keywords

# Function to query GPT models
def query_gpt(question, model, temperature=1.0):
    """Send ``question`` to an OpenAI chat model and return the reply text.

    Args:
        question: User message to send.
        model: OpenAI model id passed to the chat completions endpoint.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The text of the first choice with surrounding whitespace removed, or
        an empty string when the response carries no text content.
    """
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": question}
        ],
        temperature=temperature
    )
    # The message content is optional in the API response. Falling back to ""
    # avoids an AttributeError on None, which would otherwise abort a long run;
    # an empty reply simply yields no keywords downstream.
    response_content = completion.choices[0].message.content
    return (response_content or "").strip()

# Function to process query for GPT models
def process_query_gpt(query, iterations, num_entities, model, max_retries=20):
    """Collect ``iterations`` keyword lists from an OpenAI chat model.

    Each reply is reduced to its lower-case words, validated, and accepted
    when at least ``num_entities`` keywords remain; only the first
    ``num_entities`` are stored. A failed attempt is logged as a warning, and
    the last failed attempt of a batch of ``max_retries`` additionally logs an
    error before a new batch starts for the same iteration.

    Returns:
        A list of keyword lists, one per completed iteration.
    """
    keyword_batches = []
    # The number of stored batches doubles as the count of finished iterations.
    while len(keyword_batches) < iterations:
        for retry_index in range(max_retries):
            reply_text = query_gpt(query, model)
            lowercase_tokens = re.findall(r'\b[a-z]+\b', reply_text)
            valid_words = normalize_and_validate_keywords(lowercase_tokens)
            if len(valid_words) >= num_entities:
                keyword_batches.append(valid_words[:num_entities])
                break
            iteration_number = len(keyword_batches) + 1
            logging.warning(f"{model} attempt {retry_index + 1} for iteration {iteration_number}: Invalid result, retrying...")
            if retry_index == max_retries - 1:
                logging.error(f"{model} failed to return a valid result after {max_retries} attempts for iteration {iteration_number}. Retrying the iteration.")
    return keyword_batches

# Process queries
# Query every enabled model; a disabled model contributes an empty list.
all_results_llama = (
    process_query_llama(query_template.format(NUM_ENTITIES), NUM_ITERATIONS, NUM_ENTITIES)
    if USE_LLAMA
    else []
)

all_results_gpt = (
    process_query_gpt(query_template.format(NUM_ENTITIES), NUM_ITERATIONS, NUM_ENTITIES, "gpt-4")
    if USE_GPT
    else []
)

all_results_gpt3_5 = (
    process_query_gpt(query_template.format(NUM_ENTITIES), NUM_ITERATIONS, NUM_ENTITIES, "gpt-3.5-turbo")
    if USE_GPT3_5
    else []
)

all_results_gpt4o_mini = (
    process_query_gpt(query_template.format(NUM_ENTITIES), NUM_ITERATIONS, NUM_ENTITIES, "gpt-4o-mini")
    if USE_GPT4O_MINI
    else []
)

# Function to process results and prepare data for plotting
def process_results(all_results, model_name, num_iterations, num_entities_in_chart):
    """Aggregate per-iteration keyword lists into a frequency table.

    Related company names (e.g. ``google`` / ``deepmind``) are merged under a
    combined label before counting.

    Args:
        all_results: List of keyword lists, one list per iteration.
        model_name: Label stored in the ``Model`` column of every row.
        num_iterations: Denominator used for the ``Percentage`` column.
        num_entities_in_chart: Maximum number of rows to keep, most frequent
            first.

    Returns:
        A DataFrame with the columns ``Word``, ``Count``, ``Percentage`` and
        ``Model``, sorted by ``Count`` in descending order (ties keep
        first-seen order). The columns are present even when there are no
        rows, so an empty result still produces a CSV with a header.

    Note:
        Because merged names are counted together, one iteration can
        contribute more than once to a combined label and ``Percentage`` can
        exceed 100.
    """
    columns = ['Word', 'Count', 'Percentage', 'Model']
    company_mapping = {
        'google': 'google + deepmind',
        'deepmind': 'google + deepmind',
        'facebook': 'meta + facebook',
        'meta': 'meta + facebook',
        'microsoft': 'microsoft + openai',
        'openai': 'microsoft + openai'
    }

    # Count occurrences of each keyword after applying the company mapping
    combined_keywords_count = defaultdict(int)
    for sublist in all_results:
        for word in sublist:
            combined_keywords_count[company_mapping.get(word, word)] += 1

    # Sort keywords by count in descending order and keep the top entries
    sorted_words = sorted(combined_keywords_count.items(), key=lambda item: item[1], reverse=True)[:num_entities_in_chart]

    # Create DataFrame with the count and percentage of each keyword
    data = [{'Word': word, 'Count': count, 'Percentage': (count / num_iterations) * 100, 'Model': model_name} for word, count in sorted_words]
    # Passing the columns explicitly keeps the schema stable for empty input.
    return pd.DataFrame(data, columns=columns)

# Process results for all models
# One summary frame per model; a disabled model gets an empty frame.
if USE_LLAMA:
    df_llama = process_results(all_results_llama, MODEL_NAME_LLAMA, NUM_ITERATIONS, NUM_ENTITIES_IN_CHART)
else:
    df_llama = pd.DataFrame()

if USE_GPT:
    df_gpt = process_results(all_results_gpt, MODEL_NAME_GPT, NUM_ITERATIONS, NUM_ENTITIES_IN_CHART)
else:
    df_gpt = pd.DataFrame()

if USE_GPT3_5:
    df_gpt3_5 = process_results(all_results_gpt3_5, MODEL_NAME_GPT3_5, NUM_ITERATIONS, NUM_ENTITIES_IN_CHART)
else:
    df_gpt3_5 = pd.DataFrame()

if USE_GPT4O_MINI:
    df_gpt4o_mini = process_results(all_results_gpt4o_mini, MODEL_NAME_GPT4O_MINI, NUM_ITERATIONS, NUM_ENTITIES_IN_CHART)
else:
    df_gpt4o_mini = pd.DataFrame()

# Persist every frame for later analysis; all files of one run share a timestamp.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
df_llama.to_csv(path_or_buf=f"llama3_results_q{CHOICE}_{timestamp}.csv", index=False)
df_gpt.to_csv(path_or_buf=f"gpt4_results_q{CHOICE}_{timestamp}.csv", index=False)
df_gpt3_5.to_csv(path_or_buf=f"gpt3_5_results_q{CHOICE}_{timestamp}.csv", index=False)
df_gpt4o_mini.to_csv(path_or_buf=f"gpt4o_mini_results_q{CHOICE}_{timestamp}.csv", index=False)

# Function to create the data for visualization
def create_visualization_data(df):
    """Total the ``Count`` of every ``Word`` in ``df`` and rank the totals.

    Args:
        df: DataFrame whose rows provide ``Word`` and ``Count`` values. A
            frame without rows is accepted.

    Returns:
        A list of ``(word, total_count)`` tuples ordered from highest to
        lowest total; words with equal totals keep their first-seen order.
    """
    totals = {}
    for _, record in df.iterrows():
        word = record['Word']
        totals[word] = totals.get(word, 0) + int(record['Count'])
    # Rank all keywords, largest total first
    ranked = list(totals.items())
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked

# Create data for visualizations
# Rank the keyword totals of every model for the charts
visualization_data_llama = create_visualization_data(df_llama)
visualization_data_gpt = create_visualization_data(df_gpt)
visualization_data_gpt3_5 = create_visualization_data(df_gpt3_5)
visualization_data_gpt4o_mini = create_visualization_data(df_gpt4o_mini)

# Convert data to JavaScript variable declarations for the HTML page.
# json.dumps emits a valid JS array literal and escapes quotes and backslashes;
# the previous str(...).replace("'", '"') approach produced the same text for
# plain words but broke as soon as a value contained a quote character.
js_var_data_llama = "var data_llama = " + json.dumps([{'product': k, 'count': v} for k, v in visualization_data_llama]) + ";"
js_var_data_gpt = "var data_gpt = " + json.dumps([{'product': k, 'count': v} for k, v in visualization_data_gpt]) + ";"
js_var_data_gpt3_5 = "var data_gpt3_5 = " + json.dumps([{'product': k, 'count': v} for k, v in visualization_data_gpt3_5]) + ";"
js_var_data_gpt4o_mini = "var data_gpt4o_mini = " + json.dumps([{'product': k, 'count': v} for k, v in visualization_data_gpt4o_mini]) + ";"


In [9]:
## Visualise Results ##

import json
from datetime import datetime

# Build a standalone HTML page that draws one waffle chart per model with D3
# (loaded from d3js.org, v6) plus a single shared legend.
# This is an f-string: literal CSS/JS braces are doubled ({{ }}), and the only
# substitutions are the run date and the four js_var_data_* declarations
# produced by the data-generation cell, which must have been run first.
# NOTE(review): the page title ("top 5 factors ... electric car"), the chart
# labels and the "other" cut-off (count < 15) are hard-coded in the markup and
# do not follow CHOICE / NUM_ENTITIES; adjust them by hand when the query changes.
html_content = f"""
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>Waffle Charts</title>
  <style>
    body {{
      font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
      color: #333;
      background-color: #e0e0e0; /* Light background for the whole page */
      padding-left: 20px; /* Align the body content to the left */
    }}
    .chart-title {{
      font-size: 22px;
      font-weight: bold;
      margin-bottom: 5px;
    }}
    .chart-subtitle {{
      font-size: 16px;
      color: #666;
      margin-top: 0;
      margin-bottom: 20px;
    }}
    .charts-and-legend {{
      display: flex;
      align-items: flex-start;
      background-color: #fff;
      border: 1px solid #ccc;
      padding: 20px;
      border-radius: 8px;
      box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
      margin-bottom: 20px;
      max-width: 1200px; /* Increased width to allow more space */
    }}
    .charts {{
      flex: 0 0 850px; /* Increased width for the charts */
    }}
    .waffle-chart-container {{
      display: flex;
      align-items: center;
      margin-bottom: 15px;
    }}
    .waffle-chart {{
      shape-rendering: crispEdges;
      margin-left: 10px;
    }}
    .square {{
      stroke: #fff;
    }}
    .chart-label {{
      font-size: 16px;
      font-weight: bold;
      width: 80px;
      margin-left: 15px;
      margin-right: 5px;
    }}
    .legend {{
      width: 180px; /* Increased width for longer labels */
      margin-left: 40px; /* Increased margin to space out the legend */
      display: flex;
      flex-direction: column;
    }}
    .legend-item {{
      display: flex;
      align-items: center;
      margin-bottom: 8px;
      white-space: nowrap; /* Prevent labels from wrapping to multiple lines */
    }}
    .legend-color {{
      width: 15px;
      height: 15px;
      margin-right: 5px;
      border: 1px solid #ccc; /* Border around color squares */
    }}
    .chart-footer {{
      font-size: 12px;
      color: #999;
      margin-top: 10px;
      text-align: left; /* Align footer to the left */
    }}
    .axis-label {{
      text-align: center;
      font-size: 14px;
      margin-top: 5px;
      margin-left: 15px;
    }}
  </style>
</head>
<body>
  <main>
    <header>
      <h1 class="chart-title">
        "List top 5 factors that make a good electric car."
      </h1>
      <p class="chart-subtitle">
        Model Temperature: Defaults<br>
        Date Run: {datetime.now().strftime('%d.%m.%Y')}
      </p>
    </header>

    <section class="charts-and-legend">
      <div class="charts">
        <div class="waffle-chart-container">
          <div class="chart-label">llama3</div>
          <svg width="850" height="150" id="chart_llama" class="waffle-chart"></svg>
        </div>
        <div class="waffle-chart-container">
          <div class="chart-label">gpt-3.5</div>
          <svg width="850" height="150" id="chart_gpt3_5" class="waffle-chart"></svg>
        </div>
        <div class="waffle-chart-container">
          <div class="chart-label">gpt-4</div>
          <svg width="850" height="150" id="chart_gpt" class="waffle-chart"></svg>
        </div>
        <div class="waffle-chart-container">
          <div class="chart-label">gpt-4o-mini</div>
          <svg width="850" height="150" id="chart_gpt4o_mini" class="waffle-chart"></svg>
        </div>
        <!-- Axis Label -->
        <div class="axis-label">Number of iterations where factor was present</div>
      </div>
      <!-- Legend on the right side -->
      <div class="legend" id="chart_legend"></div>
    </section>

    <footer class="chart-footer">
      Data Source: Your Data Source Here<br>
      Note: Additional notes if necessary.
    </footer>
  </main>

  <script src="https://d3js.org/d3.v6.min.js"></script>
  <script>
    {js_var_data_llama}
    {js_var_data_gpt3_5}
    {js_var_data_gpt}
    {js_var_data_gpt4o_mini}

    // Function to process data and group items into "other"
    function processData(data) {{
      const filteredData = [];
      let otherCount = 0;
      data.forEach(d => {{
        if (d.count >= 15) {{
          filteredData.push(d);
        }} else {{
          otherCount += d.count;
        }}
      }});
      if (otherCount > 0) {{
        filteredData.push({{ product: "other", count: otherCount }});
      }}
      return filteredData;
    }}

    // Process data for each chart
    const processedData_llama = processData(data_llama);
    const processedData_gpt3_5 = processData(data_gpt3_5);
    const processedData_gpt = processData(data_gpt);
    const processedData_gpt4o_mini = processData(data_gpt4o_mini);

    // Collect products from filteredData arrays
    const allProductsSet = new Set();
    [processedData_llama, processedData_gpt3_5, processedData_gpt, processedData_gpt4o_mini].forEach(filteredData => {{
      filteredData.forEach(d => {{
        allProductsSet.add(d.product);
      }});
    }});
    const allProducts = Array.from(allProductsSet).sort();

    // Define the color palette without brown and ensure "other" is light grey
    const professionalColors = [
      "#1f77b4", // Blue
      "#ff7f0e", // Orange
      "#2ca02c", // Green
      "#d62728", // Red
      "#9467bd", // Purple
      "#e377c2", // Pink
      "#7f7f7f", // Gray
      "#bcbd22", // Olive
      "#17becf", // Teal
      "#ffbb78", // Light Orange
      "#98df8a"  // Light Green
    ];

    // Assign colors to products consistently across charts
    const assignedColors = {{}};
    const lightGreyColor = "#d3d3d3"; // Light grey for "other"

    allProducts.forEach((product, index) => {{
      if (product === "other") {{
        assignedColors[product] = lightGreyColor;
      }} else {{
        assignedColors[product] = professionalColors[index % professionalColors.length];
      }}
    }});

    // Function to assign a color to each product
    function assignColor(product) {{
      return assignedColors[product] || lightGreyColor; // Default to light grey if product not assigned
    }}

    // Function to draw a waffle chart with consistent item order
    const drawWaffleChart = (filteredData, svgId, topMargin) => {{
      const svg = d3.select(svgId);
      const width = +svg.attr("width");
      const height = +svg.attr("height");

      // Create a mapping from product to count
      const productCounts = new Map();
      filteredData.forEach(d => {{
        productCounts.set(d.product, d.count);
      }});

      // Build blocks in the same order for each chart
      const blocks = [];
      allProducts.forEach(product => {{
        const count = productCounts.get(product) || 0;
        const color = assignColor(product);
        for (let i = 0; i < count; i++) {{
          blocks.push({{ product, color }});
        }}
      }});

      const maxBlocksY = 8;
      const blockSize = 10;
      const blockPadding = 2;
      const margin = {{ left: 10, top: topMargin }};

      svg.selectAll(".square")
        .data(blocks)
        .enter()
        .append("rect")
        .attr("class", "square")
        .attr("width", blockSize)
        .attr("height", blockSize)
        .attr("x", (d, i) => margin.left + Math.floor(i / maxBlocksY) * (blockSize + blockPadding))
        .attr("y", (d, i) => margin.top + (i % maxBlocksY) * (blockSize + blockPadding))
        .attr("fill", d => d.color);
    }};

    // Function to create a single legend
    function createLegend() {{
      const legendContainer = d3.select("#chart_legend");
      const products = allProducts;

      products.forEach(product => {{
        const legendItem = legendContainer.append("div")
          .attr("class", "legend-item");

        legendItem.append("div")
          .attr("class", "legend-color")
          .style("background-color", assignColor(product));

        legendItem.append("div")
          .text(product);
      }});
    }}

    // Draw the charts
    drawWaffleChart(processedData_llama, "#chart_llama", 10);
    drawWaffleChart(processedData_gpt3_5, "#chart_gpt3_5", 10);
    drawWaffleChart(processedData_gpt, "#chart_gpt", 10);
    drawWaffleChart(processedData_gpt4o_mini, "#chart_gpt4o_mini", 10);

    // Create the single legend
    createLegend();
  </script>
</body>
</html>
"""

# Write the HTML content to a file
# The page declares <meta charset="UTF-8">, so write it as UTF-8 explicitly
# instead of relying on the platform's default text encoding.
with open('waffle_charts.html', 'w', encoding='utf-8') as file:
    file.write(html_content)
