In [3]:
import plotly.express as px
from dash import Dash, html, dcc, Input, Output, callback
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import re
import pycountry
import geopandas as gpd
import json

In [4]:
filename = "World_Wide_Unicorn_Startups.csv"
basic_su_df = pd.read_csv(f"../data/{filename}")
basic_su_df.head()
basic_su_df.rename(columns={"year": "Year"}, inplace=True)
basic_su_df.head()

Unnamed: 0,Company,Valuation,Date,Country,City,Industry,Investors,Year,month,day
0,Bytedance,140.0,4/7/2017,China,Beijing,Artificial intelligence,"0 Sequoia Capital China, SIG Asia Investm...",2017,7,4
1,SpaceX,100.3,12/1/2012,United States,Hawthorne,Other,"0 Sequoia Capital China, SIG Asia Investm...",2012,1,12
2,Stripe,95.0,1/23/2014,United States,San Francisco,Fintech,"0 Sequoia Capital China, SIG Asia Investm...",2014,23,1
3,Klarna,45.6,12/12/2011,Sweden,Stockholm,Fintech,"0 Sequoia Capital China, SIG Asia Investm...",2011,12,12
4,Canva,40.0,1/8/2018,Australia,Surry Hills,Internet software & services,"0 Sequoia Capital China, SIG Asia Investm...",2018,8,1


In [8]:
filename = "Unicorn_Startups_Per_Country.csv"
preprocessed_su_df = pd.read_csv(f"../data/preprocessed/{filename}")
preprocessed_su_df.head()


Unnamed: 0,Country,Year,ISO3,N_Unicorns,Total_Val,Startups,Industries,N_Unicorns_Cumulative,Total_Val_Cumulative,Startups_Cumulative,Industries_Cumulative
0,Argentina,2015,ARG,0.0,0.0,{},[],0.0,0.0,{},[]
1,Argentina,2016,ARG,0.0,0.0,{},[],0.0,0.0,{},[]
2,Argentina,2017,ARG,0.0,0.0,{},[],0.0,0.0,{},[]
3,Argentina,2018,ARG,0.0,0.0,{},[],0.0,0.0,{},[]
4,Argentina,2019,ARG,0.0,0.0,{},[],0.0,0.0,{},[]


In [9]:
filename = "QOL.csv"
preprocessed_qol_df = pd.read_csv(f"../data/preprocessed/{filename}")
preprocessed_qol_df.head()

Unnamed: 0,Country,Quality_of_Life_Index,Purchasing_Power_Index,Cost_of_Living_Index,Property_Price_to_Income_Ratio,Year,GDP_Per_Capita,ISO3
0,Argentina,77.0,59.4,67.1,11.6,2015,14833.19968,ARG
1,Australia,180.8,110.4,99.3,7.1,2015,52009.802759,AUS
2,Austria,182.6,104.6,76.9,9.6,2015,43908.420277,AUT
3,Belgium,136.0,86.2,87.2,6.5,2015,40889.67357,BEL
4,Brazil,29.8,41.2,55.3,16.7,2015,8936.195589,BRA


In [10]:
merged_df = pd.merge(preprocessed_qol_df, preprocessed_su_df, on=["Country", "Year", "ISO3"])
merged_df['N_Unicorns'] = merged_df['N_Unicorns'].apply(lambda x: round(x)).astype(int)
# merged_df = merged_df[merged_df['n_unicorns'] > 0]
# merged_df = merged_df[merged_df['total_val'] > 0]
merged_df.head()

Unnamed: 0,Country,Quality_of_Life_Index,Purchasing_Power_Index,Cost_of_Living_Index,Property_Price_to_Income_Ratio,Year,GDP_Per_Capita,ISO3,N_Unicorns,Total_Val,Startups,Industries,N_Unicorns_Cumulative,Total_Val_Cumulative,Startups_Cumulative,Industries_Cumulative
0,Argentina,77.0,59.4,67.1,11.6,2015,14833.19968,ARG,0,0.0,{},[],0.0,0.0,{},[]
1,Australia,180.8,110.4,99.3,7.1,2015,52009.802759,AUS,0,0.0,{},[],0.0,0.0,{},[]
2,Austria,182.6,104.6,76.9,9.6,2015,43908.420277,AUT,0,0.0,{},[],0.0,0.0,{},[]
3,Belgium,136.0,86.2,87.2,6.5,2015,40889.67357,BEL,0,0.0,{},[],0.0,0.0,{},[]
4,Brazil,29.8,41.2,55.3,16.7,2015,8936.195589,BRA,0,0.0,{},[],0.0,0.0,{},[]


In [11]:
def country_to_iso3(name):
    try: return pycountry.countries.lookup(name).alpha_3
    except: return None

In [12]:
def normalize_country(name):
    name = str(name).strip()
    name = name.replace('U.S.', 'United States').replace('USA', 'United States')
    return name.title()

In [13]:
world = gpd.read_file("../data/ne_110m_admin_0_countries/ne_110m_admin_0_countries.shp").to_crs(4326)
print(world.columns)

Index(['featurecla', 'scalerank', 'LABELRANK', 'SOVEREIGNT', 'SOV_A3',
       'ADM0_DIF', 'LEVEL', 'TYPE', 'TLC', 'ADMIN',
       ...
       'FCLASS_TR', 'FCLASS_ID', 'FCLASS_PL', 'FCLASS_GR', 'FCLASS_IT',
       'FCLASS_NL', 'FCLASS_SE', 'FCLASS_BD', 'FCLASS_UA', 'geometry'],
      dtype='object', length=169)


In [14]:
import pyproj

countries_gdf = world[['ISO_A3','geometry']].rename(columns={'ISO_A3':'ISO3'})

unicorn_agg = (
    merged_df
      .groupby(["ISO3", "Country"], as_index=False)
      .agg(
          n_unicorns = ("ISO3", "size"),      # count
          total_val  = ("Total_Val", "sum")   # sum
      )
)

unicorn_agg_geo = pd.merge(unicorn_agg, countries_gdf, on='ISO3', how='left')
unicorn_agg_gdf = gpd.GeoDataFrame(unicorn_agg_geo, geometry='geometry').to_crs(4326)
unicorn_agg_gdf.info()

<class 'geopandas.geodataframe.GeoDataFrame'>
RangeIndex: 32 entries, 0 to 31
Data columns (total 5 columns):
 #   Column      Non-Null Count  Dtype   
---  ------      --------------  -----   
 0   ISO3        32 non-null     object  
 1   Country     32 non-null     object  
 2   n_unicorns  32 non-null     int64   
 3   total_val   32 non-null     float64 
 4   geometry    29 non-null     geometry
dtypes: float64(1), geometry(1), int64(1), object(2)
memory usage: 1.4+ KB


In [15]:
merged_df_new = pd.merge(preprocessed_qol_df, preprocessed_su_df, on=["Country", "Year", "ISO3"])
merged_df_new = pd.merge(merged_df_new, countries_gdf, on=[ "ISO3"])
merged_df_new = merged_df_new.rename(columns={"geometry": "Country_Geom"})
merged_df_new.info()





<class 'pandas.core.frame.DataFrame'>
RangeIndex: 203 entries, 0 to 202
Data columns (total 17 columns):
 #   Column                          Non-Null Count  Dtype   
---  ------                          --------------  -----   
 0   Country                         203 non-null    object  
 1   Quality_of_Life_Index           203 non-null    float64 
 2   Purchasing_Power_Index          203 non-null    float64 
 3   Cost_of_Living_Index            203 non-null    float64 
 4   Property_Price_to_Income_Ratio  203 non-null    float64 
 5   Year                            203 non-null    int64   
 6   GDP_Per_Capita                  203 non-null    float64 
 7   ISO3                            203 non-null    object  
 8   N_Unicorns                      203 non-null    float64 
 9   Total_Val                       203 non-null    float64 
 10  Startups                        203 non-null    object  
 11  Industries                      203 non-null    object  
 12  N_Unicorns_Cumulative 

In [None]:
import pandas as pd
import geopandas as gpd
from shapely import wkt
import plotly.express as px
from sklearn.preprocessing import MinMaxScaler

# Define metrics and year
metric_1 = "Purchasing_Power_Index"
metric_2 = "N_Unicorns_Cumulative"
year = 2021

# Filter data based on year
df = merged_df_new[merged_df_new["Year"] == year].copy()
df = df[df["Country_Geom"].notna()].reset_index(drop=True)

# Create GeoDataFrame - Handle the geometry properly
# Check if Country_Geom is already a geometry object
if hasattr(df["Country_Geom"].iloc[0], 'geom_type'):
    # Already a geometry object - use directly
    gdf = gpd.GeoDataFrame(df, geometry="Country_Geom", crs="EPSG:4326")
else:
    # It's a WKT string - convert to geometry
    df["geometry"] = df["Country_Geom"].apply(wkt.loads)
    gdf = gpd.GeoDataFrame(df, geometry="geometry", crs="EPSG:4326")

# Simplify geometries for better performance
gdf["geometry"] = gdf.geometry.simplify(tolerance=0.05, preserve_topology=True)

# Normalize the metric for color mapping
scaler = MinMaxScaler((0, 1))
gdf["scaled_metric"] = scaler.fit_transform(gdf[[metric_1]])

# Create choropleth map using built-in GeoJSON conversion
fig = px.choropleth(
    gdf,
    geojson=gdf.geometry.__geo_interface__,
    locations=gdf.index,
    color="scaled_metric",
    hover_name="Country",
    hover_data={metric_2: True, metric_1: True, "scaled_metric": False},
    color_continuous_scale="YlOrRd",
    labels={metric_1: metric_1, metric_2: metric_2, "scaled_metric": metric_1},
)

# Update layout for better appearance
fig.update_layout(
    margin=dict(l=0, r=0, t=0, b=0),
    geo=dict(
        showframe=False,
        showcoastlines=True,
        projection_type='equirectangular'
    )
)

# Display the map
fig.show()

In [95]:
import json
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
from dash import Dash, dcc, html
from dash.dependencies import Input, Output, State
import geopandas as gpd
from shapely import wkt
from sklearn.preprocessing import MinMaxScaler

# metric_1 = "Purchasing_Power_Index"
# metric_2 = "Total_Val"

label_map = {
    "GDP_Per_Capita": "GDP per Capita",
    "Total_Val": "Total Valuation",
    "Purchasing_Power_Index": "Purchasing Power",
    "N_Unicorns_Cumulative": "Number of Unicorns Cumulative",
    "Cost_of_Living_Index": "Cost of Living",
    "Quality_of_Life_Index": "Quality of Life",
    "Property_Price_to_Income_Ratio": "Property Price to Income Ratio",
    "N_Unicorns": "Number of Unicorns",
    "Total_Val_Cumulative": "Total Val Cumulative"
}


# List of selectable metrics
metric_options = [
    "Quality_of_Life_Index", #
    "Purchasing_Power_Index", #
    "Cost_of_Living_Index", #
    "Property_Price_to_Income_Ratio", #
    "GDP_Per_Capita", #
    "N_Unicorns", #
    "Total_Val", #
    "N_Unicorns_Cumulative", #
    "Total_Val_Cumulative"	
]
country_options = [{"label": country, "value": country} for country in merged_df_new["Country"].unique()]
print(country_options)


#available_metrics = [ "GDP_Per_Capita", "Cost_of_Living_Index", "Quality_of_Life_Index"]
#available_metrics2 = ["N_Unicorns_Cumulative", "Total_Val"]
available_metrics = metric_options
available_metrics2 = metric_options

# Aggregate by country
agg_df = merged_df_new.groupby('Country').agg({
    'GDP_Per_Capita': 'max',
    'Quality_of_Life_Index': 'max'
}).reset_index()

# Precompute KPIs
total_unicorns = merged_df_new['N_Unicorns'].sum().max().astype(int)
highest_valuation_country = merged_df_new.groupby('Country')['Total_Val'].sum().idxmax()
highest_valuation_value = merged_df_new.groupby('Country')['Total_Val'].sum().max()
country_most_unicorns = merged_df_new.groupby('Country')['N_Unicorns'].sum().idxmax()
highest_gdp_country = agg_df.loc[agg_df['GDP_Per_Capita'].idxmax()]
highest_quality_country = agg_df.loc[agg_df['Quality_of_Life_Index'].idxmax()]

# App setup
app = Dash(__name__)

app.layout = html.Div([
    html.Div([
        html.H1("GLOBAL UNICORN STARTUPS DASHBOARD", style={
            "textAlign": "center",
            "marginBottom": "10px",
            "fontFamily": "Segoe UI",
            "color": "#333",
            "fontWeight": "bold",
            "fontSize": "24px"
        }),
    html.Div([
         

            # Total Companies KPI
            html.Div([
                html.H2(f"{total_unicorns:,}", style={
                    "color": "#333", "margin": "0", "fontSize": "18px"
                }),
                html.P("Global Unicorn Count", style={
                    "margin": "0", "fontSize": "11px", "color": "#333"
                })
            ], style={
                "flex": "1",
                "maxWidth": "220px",
                "padding": "6px",
                "textAlign": "center",
                "backgroundColor": "#C4B8F7", 
                "borderRadius": "10px",
                "boxShadow": "0 2px 4px rgba(0, 0, 0, 0.05)",
                "fontFamily": "Arial",
                "margin": "2px"
            }),

            # Highest Valuation KPI
            html.Div([
                html.H2(f"{highest_valuation_country}", style={
                    "color": "#333", "margin": "0", "fontSize": "18px"
                }),
                #html.P(f"Country with Highest Unicorn Valuation(${highest_valuation_value:,.2f}B)", style={
                html.P(f"Country with Highest Unicorn Valuation", style={
                    "margin": "0", "fontSize": "11px", "color": "#333"
                })
            ], style={
                "flex": "1",
                "maxWidth": "220px",
                "padding": "6px",
                "textAlign": "center",
                "backgroundColor": "#acf2bb",
                "borderRadius": "10px",
                "boxShadow": "0 2px 4px rgba(0, 0, 0, 0.05)",
                "fontFamily": "Arial",
                "margin": "2px"
            }),

            # Most Unicorns KPI
            html.Div([
                html.H2(f"{country_most_unicorns}", style={
                    "color": "#333", "margin": "0", "fontSize": "18px"
                }),
                html.P("Country with Highest Unicorn Count", style={
                    "margin": "0", "fontSize": "11px", "color": "#333"
                })
            ], style={
                "flex": "1",
                "maxWidth": "220px",
                "padding": "6px",
                "textAlign": "center",
                "backgroundColor": "#f4b6b6",  
                "borderRadius": "10px",
                "boxShadow": "0 2px 4px rgba(0, 0, 0, 0.05)",
                "fontFamily": "Arial",
                "margin": "2px"
            }),


            # Total QOL KPI
            html.Div([
                html.H2(highest_quality_country["Country"], style={
                    "color": "#333", "margin": "0", "fontSize": "18px"
                }),
                html.P(f"Country with Highest Quality of Life Index", style={
                    "margin": "0", "fontSize": "11px", "color": "#333"
                })
            ], style={
                "flex": "1",
                "maxWidth": "220px",
                "padding": "6px",
                "textAlign": "center",
                "backgroundColor": "#ffe082",
                "borderRadius": "10px",
                "boxShadow": "0 2px 4px rgba(0, 0, 0, 0.05)",
                "fontFamily": "Arial",
                "margin": "2px"
            }),


        # Total GDP KPI
            html.Div([
                html.H2(highest_gdp_country["Country"], style={
                    "color": "#333", "margin": "0", "fontSize": "18px"
                }),
                html.P(f"Country with Highest GDP per Capita", style={
                    "margin": "0", "fontSize": "11px", "color": "#333"
                })
            ], style={
                "flex": "1",
                "maxWidth": "220px",
                "padding": "6px",
                "textAlign": "center",
                "backgroundColor": "#f9ac79",  
                "borderRadius": "10px",
                "boxShadow": "0 2px 4px rgba(0, 0, 0, 0.05)",
                "fontFamily": "Arial",
                "margin": "2px"
            }),

        ], style={
            "display": "flex",
            "justifyContent": "center",  
            "flexWrap": "wrap",
            "width": "100%",
            "padding": "5px 10px",
            "overflowX": "hidden"
        })
    ], style={"width": "100%", "padding": "0 10px"}),


 html.Div([
    # Left panel - Map and line chart
    html.Div([
        html.Div([
            html.Label("Select Metric to Visualize on Map:", style={"fontSize": "14px" , "fontStyle": "italic"}),
            dcc.Dropdown(
                id='metric-1-dropdown',
                options=[{"label": label_map[m], "value": m} for m in metric_options],
                value='Total_Val',
                searchable=False,
                clearable=False 
            )
        ], style={"width": "35%", "display": "inline-block"}),

        html.Div([
            html.Label("Select Countries to Compare:", style={"fontSize": "14px", "fontStyle": "italic"}),
            dcc.Dropdown(
                id='country-dropdown',
                options=country_options,
                value=['United States'],
                multi=True
            )
        ], style={"width": "65%", "display": "inline-block"}),

        # World Map 
        dcc.Graph(
            id="world-map",
            config={"clickmode": "event+select",
                   "modeBarButtonsToRemove": ["select2d", "lasso2d", "pan"], 
                    "displaylogo": False, 
                    "watermark": False},
            style={
                "height": "45vh",
                "marginBottom": "10px",      
            }
        ),

        # Year Slider
        dcc.Slider(
            id='year-slider',
            min=merged_df_new['Year'].min(),
            max=merged_df_new['Year'].max(),
            step=1,
            value=merged_df_new['Year'].min(),
            marks={str(year): str(year) for year in sorted(merged_df_new['Year'].unique())},
            tooltip={"placement": "bottom", "always_visible": True}
        ),

        # Bar Chart 
        dcc.Graph(
            id="fig_bar",
            style={
                "height": "23vh",
                "boxShadow": "0 2px 8px rgba(0,0,0,0.1)"
            }
        ),
    ], style={
        "width": "57%",
        "padding": "1px",
        "backgroundColor": '#f9f9f9'  
    }),

        
    # Right panel
html.Div([
    # Container to center the dropdowns
    html.Div([
        html.Div([
            html.Label("Select Metric 1 (Bubble size):", style={"fontSize": "14px", "fontStyle": "italic"}),
            dcc.Dropdown(
                id="metric-dropdown1",
                options=[{"label": label_map[m], "value": m} for m in available_metrics2],
                value="N_Unicorns_Cumulative",
                 searchable=False,
                clearable=False   
            )
        ], style={"width": "200px", "margin": "0 10px","width": "35%"}),

        html.Div([
            html.Label("Select Metric 2 (Y-axis):", style={"fontSize": "14px", "fontStyle": "italic"}),
            dcc.Dropdown(
                id="metric-dropdown",
                 options=[{"label": label_map[m], "value": m} for m in available_metrics],
                value="GDP_Per_Capita",
                searchable=False,
                clearable=False 
            )
        ], style={"width": "200px", "margin": "0 10px","width": "35%"}),
    ], style={
        "display": "flex",
        "justifyContent": "center",
        "marginBottom": "10px",
        "flexWrap": "wrap"
    }),
    
    # Bubble chart
    html.Div([
        dcc.Graph(id="bubble-plot", config={
                    "scrollZoom": False,
                    "doubleClick": False,
                    "displaylogo": False,
                    "watermark": False,
                    "displayModeBar": True,
                    "modeBarButtonsToRemove": ["select2d", "lasso2d", 'pan', 'select', 'zoom', 'zoomIn',
                                               'zoomOut', 'resetScale', "AutoScale"],
                }
                          , style={
                        "height": "38vh", "marginBottom": "5px", "boxShadow": "0 2px 8px rgba(0,0,0,0.1)"
                    },
                          )
    ]),

    # Second graph row
    html.Div(
        style={
            "display": "flex",
            "alignItems": "flex-start",
        },
        children=[
            html.Div(
                dcc.Graph(id="company-bar-chart", style={"height": "35vh","boxShadow": "0 2px 8px rgba(0,0,0,0.1)"}, config = {"modeBarButtonsToRemove": ["select2d", "lasso2d", 'pan', 'select', 'zoom', 'zoomIn', 'zoomOut', 'resetScale', "AutoScale"], 
                    "displaylogo": False, 
                    "watermark": False}),
                style={"flex": "1", "marginRight": "2px"},
                
            ),
            html.Div(
                dcc.Graph(id="industry-bar-chart", style={"height": "35vh","boxShadow": "0 2px 8px rgba(0,0,0,0.1)"}, config = {"modeBarButtonsToRemove": ["select2d", "lasso2d", 'pan', 'select', 'zoom', 'zoomIn', 'zoomOut', 'resetScale', "AutoScale"], 
                    "displaylogo": False, 
                    "watermark": False}),
                style={"flex": "1", "marginLeft": "2px"},
            ),
        ],
    ),
], style={"width": "43%", "padding": "2px"})

    ], style={
        "display": "flex", 
        "gap": "1px", 
        "justifyContent": "space-between", 
        "margin": "0 auto", 
        "maxWidth": "98%"  
    }),
    
    # Hidden div to store clicked countries
    html.Div(id='clicked-countries-store', style={'display': 'none'}, children=json.dumps(['United States']))
],style={"backgroundColor": "#f9f9f9"})


# Callback to update map and bar chart by year
@app.callback(
    [Output("world-map", "figure"),
     Output("fig_bar", "figure")],
    [Input("year-slider", "value"),
     Input("metric-1-dropdown", "value"),
     Input("country-dropdown", "value")]
)
def update_dashboard(selected_year, selected_metric, selected_countries):
    # ── 1  FILTER DATA FOR THE SELECTED YEAR ───────────────────────
    df = merged_df_new[merged_df_new["Year"] == selected_year].copy()
    df = df[df["Country_Geom"].notna()].reset_index(drop=True)

    # GeoDataFrame handling (unchanged)
    if hasattr(df["Country_Geom"].iloc[0], "geom_type"):
        gdf = gpd.GeoDataFrame(df, geometry="Country_Geom", crs="EPSG:4326")
    else:
        df["geometry"] = df["Country_Geom"].apply(wkt.loads)
        gdf = gpd.GeoDataFrame(df, geometry="geometry", crs="EPSG:4326")

    gdf["geometry"] = gdf.geometry.simplify(tolerance=0.05, preserve_topology=True)
    gdf["color_metric"] = gdf[selected_metric]

    # Build GeoJSON
    geojson = {
        "type": "FeatureCollection",
        "features": [
            {
                "type": "Feature",
                "id": i,
                "properties": {
                    "Country": row["Country"],
                    "color_metric": row["color_metric"],
                    selected_metric: row[selected_metric],
                    "Total_Val": row["Total_Val"],
                },
                "geometry": row["geometry"].__geo_interface__,
            }
            for i, row in gdf.iterrows()
        ],
    }

    # ── 2  PICK ONE COLOUR SCALE + NUMERIC RANGE  (NEW) ────────────
    color_scale = [
        "#fffde7", "#fff59d", "#ffe082",
        "#ffca28", "#ffa000", "#f57c00"
    ]
    cmin = gdf[selected_metric].min()
    cmax = gdf[selected_metric].max()

    # ── 3  CHOROPLETH  (now receives scale + range) ────────────────
    fig_map = px.choropleth_map(
        gdf,
        geojson=geojson,
        locations=gdf.index,
        featureidkey="id",
        color="color_metric",
        hover_name="Country",
        hover_data={selected_metric: True, "Total_Val": True},
        color_continuous_scale=color_scale,      # same palette
        range_color=(cmin, cmax),                # same numeric window
        map_style="carto-positron",
        zoom=0.6,
        center={"lat": 20, "lon": 0},
        opacity=0.8,
        labels={"color_metric": selected_metric.replace("_", " ")},
    )

    hover_label = selected_metric.replace("_", " ")
    hover_label_data = None
    if selected_metric in ["Total_Val", "Total_Val_Cumulative"]:
        hover_label_data = "%{customdata[0]:,.2f}" + "B<br>"
    else:
        hover_label_data = "%{customdata[0]:,.2f}" + "<br>"
        
    fig_map.update_traces(
        hovertemplate="%{hovertext}<br>" +
                    f"{hover_label}: " + hover_label_data,
        marker=dict(opacity=0.85),
        unselected=dict(marker=dict(opacity=0.85)),
        selected=dict(marker=dict(opacity=1.0))
    )
    
    # Check if any countries are selected in the dropdown
    if selected_countries:
        # Make selected countries more visible on the map
        for country in selected_countries:
            indices = gdf[gdf["Country"] == country].index.tolist()
            for i in indices:
                # Add a thicker border to highlight selected countries
                fig_map.add_trace(
                    go.Choropleth(
                        geojson=geojson,
                        locations=[i],
                        z=[1],
                        featureidkey="id",
                        colorscale=[[0, 'rgba(0,0,0,0)'], [1, 'rgba(0,0,0,0)']],
                        showscale=False,
                        marker=dict(line=dict(color='rgba(0,0,0,1)', width=2)),
                        hoverinfo='skip'
                    )
                )

    fig_map.update_layout(
        geo=dict(showframe=False, showcoastlines=False),
        margin=dict(l=0, r=10, t=0, b=0),
        mapbox=dict(
            zoom=0.65, 
            center=dict(lat=20, lon=0), 
            style="carto-positron"
        ),
        dragmode='zoom',
        modebar_add=['pan', 'select', 'lasso2d', 'zoom', 'zoomIn', 'zoomOut', 'resetScale'],
        uirevision='constant',
        coloraxis_colorbar=dict(
            thickness=15,
            len=0.75,
            x=0.99,
            xanchor='left',
            y=0.5,
            yanchor='middle',
            title=None,
            title_side='top'
        ),
        plot_bgcolor='#f9f9f9',
        paper_bgcolor='#f9f9f9'
    )

    # ── 4  BUILD THE TOP-10 BAR WITH IDENTICAL COLOUR MAPPING 
    top_counts = (
        df.groupby("Country")[selected_metric]
          .sum().reset_index()
          .replace({"United Arab Emirates": "UAE"})
          .loc[lambda d: d[selected_metric] > 0]
          .sort_values(selected_metric, ascending=False)
          .head(10).reset_index(drop=True)
    )

    fig_bar = go.Figure(
        go.Bar(
            x=top_counts["Country"],
            y=top_counts[selected_metric],
            marker=dict(
                color=top_counts[selected_metric],   # numeric values
                colorscale=color_scale,              # ← same palette
                cmin=cmin,
                cmax=cmax,                           # ← same range
                showscale=False                      # don’t duplicate colourbar
            )
        )
    )


    fig_bar.update_layout(
        title=f"Top 10 Countries by {selected_metric.replace("_", " ")} in {selected_year}",
        margin=dict(t=30, l=20, r=20, b=20),
        plot_bgcolor="#f9f9f9",
        paper_bgcolor="#f9f9f9",
        font=dict(family="Arial", size=12),
        xaxis_title="Country",
        yaxis_title=label_map[selected_metric],
    )

    return fig_map, fig_bar

# Callback to update the country dropdown when map is clicked
@app.callback(
    Output("country-dropdown", "value"),
    [Input("world-map", "clickData")],
    [State("country-dropdown", "value"),
     State("clicked-countries-store", "children")]
)
def update_dropdown_from_map(clickData, current_countries, clicked_countries_json):
    # Get the current list of clicked countries
    selected_countries = current_countries if current_countries else []
    
    # If map was clicked
    if clickData and "points" in clickData:
        country = clickData["points"][0]["hovertext"]
        
        # Toggle the country in the selection list
        if country in selected_countries:
            selected_countries.remove(country)
        else:
            selected_countries.append(country)
    
    # Make sure we always have at least one country selected
    if not selected_countries:
        selected_countries = ["United States"]
        
    return selected_countries

# Store clicked countries for persistence
@app.callback(
    Output("clicked-countries-store", "children"),
    [Input("country-dropdown", "value")]
)
def store_countries(countries):
    # Make sure we always have at least one country
    if not countries:
        countries = ["United States"]
    return json.dumps(countries)

@app.callback(
    Output("bubble-plot", "figure"),
    [Input("metric-dropdown1", "value"),
     Input("metric-dropdown", "value"),
     Input("country-dropdown", "value")]
)
def update_bubble_plot(selected_metric1, selected_metric, selected_countries):
    # Fallback if no country selected
    if not selected_countries:
        selected_countries = ["United States"]

    df_selected = merged_df_new[merged_df_new["Country"].isin(selected_countries)].copy()
    df_selected["CountryLabel"] = df_selected["Country"]
    data = np.array(df_selected[selected_metric1]).reshape((-1,1))
    scaler = MinMaxScaler(feature_range=(0,1))
    scaled = scaler.fit_transform(data).flatten()
    df_selected["scaled_metric"] = scaled 
        

    selected_metric_label = None
    selected_metric1_label = None
    if selected_metric is not None:
        selected_metric_label = selected_metric.replace("_", " ")
    if selected_metric1 is not None:
        selected_metric_label1 = selected_metric1.replace("_", " ")    
    
    fig = px.scatter(
        df_selected,
        x="Year",
        y=selected_metric,
        size="scaled_metric",
        color="CountryLabel",
        hover_name="Country",
         hover_data={                 
        "Year": False,
        selected_metric1: True,
        selected_metric: False,
        "scaled_metric": False},
        size_max=40,
        title=f"Comparison of {selected_metric_label} Vs {selected_metric_label1}",
        labels={"CountryLabel": "Country"},
       color_discrete_sequence = [
        "#7A5DF0",  
        "#FFD700",  
        "#FF6B6B",  
        "#4D4DFF" ,
        "#18A645",  
    ]
    

    )

    fig.update_xaxes(fixedrange=True)
    fig.update_yaxes(fixedrange=True)

    fig.update_traces(
        marker=dict(opacity=0.7, line=dict(width=1, color='white')),
        selector=dict(mode='markers')
    )

    fig.update_layout(
        template="plotly_white",
        height=300,
        margin=dict(l=20, r=20, t=40, b=20),
        font=dict(family="Arial", size=11),
        xaxis=dict(tickmode='linear', dtick=1),
        plot_bgcolor='#f9f9f9',
        paper_bgcolor='#f9f9f9'
    )

    return fig

##company-bar-chart
@app.callback(
    Output("company-bar-chart", "figure"),
    [Input("country-dropdown", "value"),
     Input("year-slider", "value")]
)
def update_company_bar_chart(selected_countries, year):
    # Default to first selected country or United States
    selected_country = selected_countries[-1] if selected_countries else "United States"
    
    # Filter company data for the selected country
    company_df = basic_su_df[(basic_su_df["Country"] == selected_country) & (basic_su_df["Date"].map(lambda x: int(str(x)[-4:])) <= year)]
    # Projection for company and evaluation
    company_val = (
        company_df[["Company", "Valuation"]]
        .sort_values(by="Valuation",ascending=False)
        .reset_index(drop=True)
        .head(10)
    )
    company_val.Company = company_val.Company + "   "

    # Create horizontal bar chart
    fig_ind = px.bar(
        company_val,
        x="Valuation",
        y="Company",
        orientation="h",
        title=f"Top Companies in {selected_country} in {year}",
        labels={"Valuation": "", "Company": ""},   # no x‐axis title
        color="Valuation",
      color_continuous_scale = [
    "#fde4ec",  
    "#f8bbd0",  
    "#FC8C8C", 
    "#FC7272",  
    "#f572a0",  
    "#f66598"   
]



    )

    # Hide the x‐axis completely
    fig_ind.update_xaxes(visible=False, range=[0, company_val["Valuation"].max() * 1.05] if not company_val.empty else [0, 10])

    # Restore y‐axis labels in their normal position
    fig_ind.update_yaxes(ticklabelposition="outside")

    # Add the value labels on the outside of each bar
    fig_ind.update_traces(
        text=company_val["Valuation"].map(lambda v: f"{v:,.1f}B"),
        textposition="outside",
        cliponaxis=False,
        hovertemplate="<b>%{y}</b><br>Valuation: %{x:,.1f}B<extra></extra>"
    )

    fig_ind.update_layout(
        yaxis=dict(categoryorder="total ascending"),
        template="plotly_white",
        height=280,
        margin=dict(t=40, r=90, b=30),
        coloraxis_showscale=False,
        title_x=0.5,
        title_font_size=14,
        plot_bgcolor='#f9f9f9',
        paper_bgcolor='#f9f9f9'
    )

    return fig_ind
    
# Callback for industry bar chart
@app.callback(
    Output("industry-bar-chart", "figure"),
    [Input("country-dropdown", "value"),
    Input("year-slider", "value")]
)
def update_industry_bar_chart(selected_countries, year):
    # Default to first selected country or United States
    selected_country = selected_countries[-1] if selected_countries else "United States"

    # Filter industry data for the selected country
    industry_df = basic_su_df[(basic_su_df["Country"] == selected_country) & (basic_su_df["Date"].map(lambda x: int(str(x)[-4:])) == year)]

    # Group by industry and calculate total valuation
    short_names = {
                    "Other": "Other",
                    "Fintech": "Fintech",
                    "Finttech": "Fintech",
                    "Supply chain, logistics, & delivery": "Supply Chain",
                    "Data management & analytics": "Data & Analytics",
                    "E-commerce & direct-to-consumer": "E-commerce",
                    "Internet software & services": "Internet Services",
                    "Health": "Health",
                    "Artificial intelligence": "AI",
                    "Artificial Intelligence": "AI",
                    "Consumer & retail": "Retail",
                    "Cybersecurity": "Cybersecurity",
                    "Mobile & telecommunications": "Telecom",
                    "Auto & transportation": "Auto",
                    "Travel": "Travel",
                    "Hardware": "Hardware",
                    "Edtech": "Edtech",
                }

    # Group by industry and calculate total valuation
    industry_val = (
        industry_df.groupby("Industry")["Valuation"]
        .sum()
        .sort_values(ascending=False)
        .reset_index()
        .head(10)
    )
    
    industry_val["Industry"] = industry_val["Industry"].map(short_names).fillna("Other")
    industry_val.Industry = industry_val.Industry + "   "
    
    if not industry_val.empty:
        industry_val.Industry = industry_val.Industry + "   "

        # Create horizontal bar chart
        fig_ind = px.bar(
            industry_val,
            x="Valuation",
            y="Industry",
            orientation="h",
            title=f"Top Industries in {selected_country} in {year}",
            labels={"Valuation": "", "Industry": ""},   # no x‐axis title
            color="Valuation",
         color_continuous_scale = [
    "#e0fbe5", 
    "#b9f6c5",  
    "#93f0a7",  
    "#66ea89",  
    "#3fdd70",  
    "#25c961"   
]

        )

        # Hide the x‐axis completely
        fig_ind.update_xaxes(visible=False, range=[0, industry_val["Valuation"].max() * 1.05])

        # Restore y‐axis labels in their normal position
        fig_ind.update_yaxes(ticklabelposition="outside")

        # Add the value labels on the outside of each bar
        fig_ind.update_traces(
            text=industry_val["Valuation"].map(lambda v: f"{v:,.1f}B"),
            textposition="outside",
            cliponaxis=False,
            hovertemplate="<b>%{y}</b><br>Valuation: %{x:,.1f}B<extra></extra>"
        )
    else:
        # Create an empty figure if no data
        fig_ind = go.Figure()
        fig_ind.update_layout(
            title=f"No Industry Data for {selected_country}"
        )

    fig_ind.update_layout(
        yaxis=dict(categoryorder="total ascending"),
        template="plotly_white",
        height=280,
        margin=dict(t=40, r=90, b=30),
        coloraxis_showscale=False,
        title_x=0.5,
        title_font_size=14,
        plot_bgcolor='#f9f9f9',
        paper_bgcolor='#f9f9f9'
    )

    return fig_ind

# Run app
if __name__ == "__main__":
    app.run(port=8066, mode="inline")

[{'label': 'Argentina', 'value': 'Argentina'}, {'label': 'Australia', 'value': 'Australia'}, {'label': 'Austria', 'value': 'Austria'}, {'label': 'Belgium', 'value': 'Belgium'}, {'label': 'Brazil', 'value': 'Brazil'}, {'label': 'Canada', 'value': 'Canada'}, {'label': 'Chile', 'value': 'Chile'}, {'label': 'Colombia', 'value': 'Colombia'}, {'label': 'Croatia', 'value': 'Croatia'}, {'label': 'Denmark', 'value': 'Denmark'}, {'label': 'Finland', 'value': 'Finland'}, {'label': 'Germany', 'value': 'Germany'}, {'label': 'India', 'value': 'India'}, {'label': 'Indonesia', 'value': 'Indonesia'}, {'label': 'Ireland', 'value': 'Ireland'}, {'label': 'Israel', 'value': 'Israel'}, {'label': 'Japan', 'value': 'Japan'}, {'label': 'Lithuania', 'value': 'Lithuania'}, {'label': 'Malaysia', 'value': 'Malaysia'}, {'label': 'Mexico', 'value': 'Mexico'}, {'label': 'Netherlands', 'value': 'Netherlands'}, {'label': 'Philippines', 'value': 'Philippines'}, {'label': 'South Africa', 'value': 'South Africa'}, {'label