# Interactive data‑profiling helper for CanonFodder
-------------------------------------------------
Run it as Jupyter-style cells
-------------------------------------------------
As of 2025‑05‑31
-------------------------------------------------
* **User‑country timeline** stores ISO 3166‑1 alpha‑2 (two‑letter) country codes.
* **Artist‑country lookup first hits the local *ArtistInfo* table or
  `PQ/artist_info.parquet`**, then falls back to MusicBrainz if needed, and rewrites the
  cache + parquet on the fly.
* Adds a fast, vectorised join that assigns a ``UserCountry`` column to every
  scrobble by interval matching against the timeline.
* Keeps all guard‑rails (no overlaps, sensible dates) and rewrites
  ``PQ/uc.parquet`` automatically.
Performance Notes:
-----------------
* Use the `--no-interactive` flag to disable interactive visualizations for faster execution.
* When running in non-interactive mode, visualizations will be saved to files instead of
  being displayed in interactive windows.
* This can significantly improve performance, especially when running the script in a console
  or in environments where interactive display is not needed.


In [None]:
from __future__ import annotations
from datetime import datetime
from dotenv import load_dotenv
load_dotenv()
from DB import SessionLocal
from DB.models import (
    ArtistInfo,
    ArtistVariantsCanonized,
)
import argparse
from branca.colormap import LinearColormap, StepColormap
import calendar
from corefunc.data_cleaning import clean_artist_info_table
import folium
from folium import plugins as folium_plugins
from helpers import cli
from helpers import io
from helpers import stats
import importlib
from HTTP import mbAPI
mbAPI.init()
import json
import logging
import matplotlib
matplotlib.use("TkAgg")
from matplotlib.collections import PolyCollection
import matplotlib.colors as mcolors
import matplotlib.dates
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
plt.ion()
import musicbrainzngs
musicbrainzngs.logging.getLogger().setLevel(logging.WARNING)
logging.getLogger("musicbrainzngs.mbxml").setLevel(logging.WARNING)
import numpy as np
import os
os.environ["MPLBACKEND"] = "TkAgg"
import pandas as pd
from pathlib import Path
import plotly.graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
import pypopulation
import re
from scipy.stats import gaussian_kde
import seaborn as sns
from sqlalchemy import select
import sys
from statsmodels.tsa.seasonal import seasonal_decompose
import threading
import webbrowser


In [None]:
# Constants & basic setup
def _not_from_mbxml(record):
    """Logging filter: reject records whose logger name starts with musicbrainzngs.mbxml."""
    return not record.name.startswith("musicbrainzngs.mbxml")


# Root logging: WARNING and above, timestamped, to the default stream handler.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler()],
    level=logging.WARNING,
)
logging.getLogger("musicbrainzngs.mbxml").setLevel(logging.WARNING)

log = logging.getLogger("dev_profile")
# NOTE(review): a filter attached to a logger is only consulted for records
# logged through that logger, so on "dev_profile" this probably never rejects
# anything – confirm whether it was meant to sit on the handler instead.
log.addFilter(_not_from_mbxml)

mbAPI.init()

# Default to TkAgg backend, but this can be overridden by command line args
matplotlib.use("TkAgg")
os.environ["MPLBACKEND"] = "TkAgg"


In [None]:
# Parsing command-line arguments
def parse_args() -> argparse.Namespace:
    """Parse the helper's command line and return the recognised options.

    Recognises the ``--no-interactive`` switch and the optional sub-commands
    ``country`` and ``cleanup-artists`` (stored in ``cmd``). Unrecognised
    options are ignored rather than rejected.
    """
    parser = argparse.ArgumentParser(description="CanonFodder dev profiling helper")
    parser.add_argument(
        "--no-interactive",
        action="store_true",
        help="Disable interactive visualizations (faster execution)",
    )

    subcommands = parser.add_subparsers(dest="cmd", help="Sub‑commands")
    subcommand_help = (
        ("country", "Edit user‑country timeline interactively"),
        ("cleanup-artists", "Clean up the ArtistInfo table by removing duplicates and orphaned entries"),
    )
    for command_name, help_text in subcommand_help:
        subcommands.add_parser(command_name, help=help_text)

    # parse_known_args tolerates foreign options (like PyCharm's --mode, --host, --port);
    # only the namespace of recognised options is handed back.
    known_options, _ignored = parser.parse_known_args()
    return known_options


# Global flag to control interactive mode
INTERACTIVE_MODE = True

# Parse the command line defensively: this file is also run cell-by-cell from
# consoles/kernels whose own arguments end up in sys.argv.
args = None
try:
    args = parse_args()
    if args and args.no_interactive:
        print("Running in non-interactive mode (visualizations will be saved to files)")
        matplotlib.use('Agg')  # Use non-interactive backend
        plt.ioff()  # Turn off interactive mode
        INTERACTIVE_MODE = False
except SystemExit as exit_request:
    # argparse reports unusable input through sys.exit(2); SystemExit is not an
    # Exception subclass, so without this clause a stray positional argument
    # would terminate the run instead of falling back to the defaults.
    if exit_request.code in (0, None):
        raise  # a clean exit such as --help is still honoured
    print(f"Warning: Error parsing arguments: exit status {exit_request.code}")
    print("Continuing with default settings (interactive mode enabled)")
except Exception as e:
    print(f"Warning: Error parsing arguments: {e}")
    print("Continuing with default settings (interactive mode enabled)")


In [None]:
def show_or_save_plot(filename, dpi=100, description=None):
    """
    Display the current matplotlib figure or write it below ``PROJECT_ROOT / "pics"``.

    The figure is displayed with ``plt.show()`` only when ``INTERACTIVE_MODE`` is on
    and stdout is a TTY. Otherwise it is saved as ``filename`` and closed afterwards;
    with a ``description`` an HTML page (same stem, ``.html``) containing the rendered
    description and the base64-embedded PNG is written as well.

    Args:
        filename: Name of the image file created inside the ``pics`` directory
        dpi: Resolution handed to ``plt.savefig``
        description: Optional Markdown text rendered with ``helpers.markdown.render_markdown``
    """
    pics_dir = PROJECT_ROOT / "pics"
    pics_dir.mkdir(exist_ok=True)
    filepath = pics_dir / filename

    # Display needs both the interactive flag and a real terminal on stdout.
    stdout_is_tty = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()
    display_on_screen = INTERACTIVE_MODE and stdout_is_tty

    def _write_html_companion():
        """Write the description + embedded PNG page next to the saved image."""
        from helpers.markdown import render_markdown
        from io import BytesIO
        from base64 import b64encode

        # Render the figure once more into memory so it can be inlined as base64.
        png_buffer = BytesIO()
        plt.savefig(png_buffer, format='png', dpi=dpi)
        img_str = b64encode(png_buffer.getvalue()).decode('utf-8')

        desc_html = render_markdown(description)

        html_path = filepath.with_suffix('.html')
        html_content = f"""<!DOCTYPE html>
<html>
<head>
    <title>{filename}</title>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 20px; }}
        .description {{ margin-bottom: 20px; }}
        img {{ max-width: 100%; }}
    </style>
</head>
<body>
    <div class="description">
        {desc_html}
    </div>
    <div class="image">
        <img src="data:image/png;base64,{img_str}" alt="Plot">
    </div>
</body>
</html>"""

        with open(html_path, 'w', encoding='utf-8') as html_file:
            html_file.write(html_content)

        print(f"Plot with description saved to {html_path}")

    try:
        if display_on_screen:
            plt.show()
        else:
            # Interactive mode without a terminal: tell the user why nothing pops up.
            if INTERACTIVE_MODE:
                print("(no TTY – saving plot to file instead of displaying)")

            plt.savefig(filepath, dpi=dpi)
            print(f"Plot saved to {filepath}")

            if description:
                try:
                    _write_html_companion()
                except ImportError:
                    print("Showdown not available. HTML with description not generated.")
    except Exception as plot_error:
        print(f"Error showing/saving plot: {plot_error}")
        print("Attempting to save plot using alternative method...")
        try:
            plt.savefig(str(filepath), dpi=dpi, format='png')
            print(f"Plot saved to {filepath} using alternative method")
        except Exception as e2:
            print(f"Failed to save plot: {e2}")
    finally:
        # Figures that were not shown on screen are released here.
        if not display_on_screen:
            plt.close()


In [None]:
def _compose_plotly_page(title, desc_html, figure_html, header_html=""):
    """Return a complete HTML document holding a description and a plotly figure div."""
    # The charset is declared because the page is always written as UTF-8.
    return f"""<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <title>{title}</title>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 20px; }}
        .description {{ margin-bottom: 20px; }}
        .assessment-header {{ background-color: #f0f0f0; padding: 10px; margin-bottom: 20px; border-left: 5px solid #007bff; }}
    </style>
</head>
<body>
{header_html}    <div class="description">
        {desc_html}
    </div>
    <div class="plotly-figure">
        {figure_html}
    </div>
</body>
</html>"""


def _assessment_header_html(filename):
    """Return the banner block that marks a page as the teacher-assessment copy."""
    generated = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    return f"""    <div class="assessment-header">
        <h2>CanonFodder Visualization - Teacher Assessment Copy</h2>
        <p>This visualization was generated for teacher assessment purposes.</p>
        <p>Filename: {filename}</p>
        <p>Generated: {generated}</p>
    </div>
"""


def _export_plotly_pages(figura, filename, description, config, html_path, public_html_path=None):
    """Write the figure to ``html_path`` and, if given, to ``public_html_path``.

    Returns True when the description was embedded in the pages, False when the
    plain ``write_html`` export was used (no description, or the Markdown
    renderer could not be imported).
    """
    desc_html = None
    if description:
        try:
            from helpers.markdown import render_markdown
        except ImportError:
            print("Showdown not available. Using standard HTML output.")
        else:
            desc_html = render_markdown(description)

    if desc_html is None:
        targets = [html_path]
        if public_html_path is not None:
            targets.append(public_html_path)
        for target in targets:
            figura.write_html(target, config=config, include_plotlyjs='cdn', full_html=True)
        return False

    # Only the figure div is requested; the surrounding document is composed here.
    figure_html = figura.to_html(config=config, include_plotlyjs='cdn', full_html=False)
    with open(html_path, 'w', encoding='utf-8') as fh:
        fh.write(_compose_plotly_page(filename, desc_html, figure_html))
    if public_html_path is not None:
        public_page = _compose_plotly_page(
            f"{filename} - Teacher Assessment",
            desc_html,
            figure_html,
            header_html=_assessment_header_html(filename),
        )
        with open(public_html_path, 'w', encoding='utf-8') as fh:
            fh.write(public_page)
    return True


def _write_plotly_placeholder(figura, html_path, config):
    """Write a light "loading" page that reloads itself every two seconds."""
    placeholder = go.Figure()

    # Copy the layout (except the template) so the placeholder resembles the final figure.
    if hasattr(figura, 'layout'):
        for attr in dir(figura.layout):
            if attr.startswith('_') or attr == 'template' or not hasattr(placeholder.layout, attr):
                continue
            try:
                setattr(placeholder.layout, attr, getattr(figura.layout, attr))
            except Exception:
                # Best effort: attributes that cannot be copied are simply left out.
                continue

    placeholder.add_annotation(
        text="Loading visualization...",
        x=0.5, y=0.5,
        xref="paper", yref="paper",
        showarrow=False,
        font=dict(size=20)
    )

    # json.dumps turns the Python config into valid JavaScript (true/false, double
    # quotes); interpolating the dict directly would emit True/False and break the script.
    config_js = json.dumps(config)
    with open(html_path, 'w', encoding='utf-8') as fh:
        fh.write(f"""
            <!DOCTYPE html>
            <html>
            <head>
                <meta charset="UTF-8">
                <meta http-equiv="refresh" content="2">
                <title>Loading Visualization...</title>
                <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
            </head>
            <body>
                <div id="plot" style="width:100%;height:100vh;"></div>
                <script>
                    var data = {placeholder.to_json()};
                    Plotly.newPlot('plot', data.data, data.layout, {config_js});
                </script>
                <div style="text-align:center;margin-top:10px;">
                    <p>Preparing visualization, please wait...</p>
                </div>
            </body>
            </html>
            """)


def show_or_save_plotly(figura, filename, description=None, public=True):
    """
    Show a plotly figure in the browser or export it to HTML, depending on the mode.

    * ``INTERACTIVE_MODE`` off: nothing is rendered; a ``.txt`` note is written
      under ``PROJECT_ROOT / "pics"`` instead.
    * Interactive but stdout is not a TTY: the figure is exported to
      ``pics/<filename>.html`` without opening a browser.
    * Interactive with a TTY: a self-refreshing placeholder page is opened in the
      browser while a background thread writes the full page to the same path.

    Args:
        figura: The plotly figure to show or save
        filename: Base file name (string) used for the output files
        description: Optional Markdown description embedded above the figure
        public: If True, an additional copy is written to ``PROJECT_ROOT /
            "public_visualizations"``; with a description this copy carries an
            assessment header

    Raises:
        TypeError: if ``filename`` is not a string (after the swap correction below).
    """
    # Handle case where parameters might be passed in wrong order
    if isinstance(filename, go.Figure) and isinstance(figura, str):
        figura, filename = filename, figura

    if not isinstance(filename, str):
        raise TypeError(f"filename must be a string, got {type(filename).__name__} instead")

    pics_dir = PROJECT_ROOT / "pics"
    pics_dir.mkdir(exist_ok=True)

    # Create a public directory for teacher assessment if it doesn't exist
    public_dir = PROJECT_ROOT / "public_visualizations"
    public_dir.mkdir(exist_ok=True)

    filepath = pics_dir / filename
    html_path = filepath.with_suffix('.html')
    public_html_path = (public_dir / filename).with_suffix('.html') if public else None

    if not INTERACTIVE_MODE:
        print(f"[PERFORMANCE OPTIMIZATION] Skipping actual figure rendering for {filename}")
        print("To view the figure, run the script with interactive mode enabled")
        with open(filepath.with_suffix('.txt'), 'w', encoding='utf-8') as f:
            f.write(f"Figure would be saved here: {filepath}\n")
            f.write("Running in non-interactive mode with performance optimizations enabled.\n")
            f.write("To view the actual figure, run the script without the --no-interactive flag.\n")
        return

    # Check if running in a TTY (interactive terminal)
    is_tty = hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()
    if not is_tty:
        print("(no TTY – assuming 'Y')")
        # When not in a TTY, just save to HTML without trying to open a browser
        static_config = {
            'displayModeBar': False,  # Hide the modebar
            'responsive': True,  # Make the plot responsive
            'staticPlot': True,  # Make the plot static (no interactivity)
        }
        try:
            with_description = _export_plotly_pages(
                figura, filename, description, static_config, html_path, public_html_path
            )
        except Exception as html_error:
            print(f"Error saving HTML: {html_error}")
            # Fall through to the browser-based path below if HTML saving fails
        else:
            if with_description:
                print(f"Visualization with description saved to {html_path}")
            else:
                print(f"Visualization saved to {html_path}")
            if public:
                print(f"Public copy for teacher assessment saved to {public_html_path}")
            return

    try:
        pio.renderers.default = "browser"

        # Use a static configuration to reduce rendering complexity
        config = {
            'displayModeBar': False,  # Hide the modebar
            'responsive': True,  # Make the plot responsive
            'staticPlot': True,  # Make the plot static (no interactivity)
            'scrollZoom': False,  # Disable scroll zoom for better performance
            'showTips': False,  # Disable tips for better performance
        }

        # Progressive rendering: 1. open a quick placeholder page ...
        print("Creating simplified version for quick display...")
        _write_plotly_placeholder(figura, html_path, config)
        webbrowser.open(str(html_path))

        # ... 2. then replace it with the full page from a separate thread.
        def save_full_version():
            try:
                print("Generating full visualization...")
                _export_plotly_pages(
                    figura, filename, description, config, html_path, public_html_path
                )
                print(f"Full visualization saved to {html_path}")
                # The public copy is written exactly once, so a copy carrying the
                # description and assessment header is no longer overwritten by a
                # plain export afterwards.
                if public:
                    print(f"Public copy for teacher assessment saved to {public_html_path}")
                # The browser will auto-refresh to show the full version
            except Exception as e:
                print(f"Error saving full visualization: {e}")

        threading.Thread(target=save_full_version).start()

    except Exception as render_error:
        print(f"Error rendering visualization: {render_error}")
        print("Attempting to save as static image instead...")
        try:
            # Try to save as a static image as a fallback
            img_path = filepath.with_suffix('.png')
            figura.write_image(str(img_path), scale=0.5)  # Lower scale for better performance
            print(f"Saved as static image: {img_path}")

            # Open the image
            webbrowser.open(str(img_path))
        except Exception as e3:
            print(f"All rendering methods failed: {e3}")


In [None]:
def find_project_root():
    """Locate the project root, i.e. a directory containing both ``JSON`` and ``PQ``.

    Candidates are probed in this order:

    1. the grandparent directory of this file, when ``__file__`` is defined;
    2. the current working directory and then each of its ancestors, nearest first;
    3. a hard-coded absolute path, returned unchecked as a last resort.

    Returns:
        Path: the first matching candidate, or the hard-coded fallback.
    """
    def _looks_like_root(folder):
        return (folder / "JSON").exists() and (folder / "PQ").exists()

    if "__file__" in globals():
        # Try the standard approach first (script lives one level below the root)
        file_ancestors = Path(__file__).resolve().parents
        if len(file_ancestors) > 1 and _looks_like_root(file_ancestors[1]):
            return file_ancestors[1]

    # Walk upwards from the working directory so the helper also works when it
    # is started from a deeper sub-directory of the project.
    current = Path.cwd()
    for folder in (current, *current.parents):
        if _looks_like_root(folder):
            return folder

    # If all else fails, use an absolute path
    return Path(r"C:\Users\jurda\PycharmProjects\CanonFodder")


# Project locations, all derived from the detected root directory.
PROJECT_ROOT = find_project_root()
JSON_DIR = PROJECT_ROOT / "JSON"
PQ_DIR = PROJECT_ROOT / "PQ"
UC_PARQUET = PQ_DIR / "uc.parquet"
AC_PARQUET = PQ_DIR / "artist_info.parquet"
AC_COLS = ["artist_name", "country", "mbid", "disambiguation_comment"]
PALETTES_FILE = JSON_DIR / "palettes.json"
SEPARATOR = "{"

# pandas display: no row/column truncation, wide output, two-decimal floats.
pd.set_option("display.max_columns", None)
pd.set_option("display.max_rows", None)
pd.set_option("display.width", 200)
pd.set_option("display.max_colwidth", 100)
pd.set_option("display.float_format", "{: .2f}".format)

# Colour palettes come from the "palettes" entry of the JSON palette file.
with PALETTES_FILE.open("r", encoding="utf-8") as fh:
    custom_palettes = json.loads(fh.read())["palettes"]
custom_colors = io.register_custom_palette("colorpalette_5", custom_palettes)

sns.set_style("whitegrid")
sns.set_palette(sns.color_palette(custom_colors))


In [None]:
# =============================================================================
# Artist‑country helpers (parquet → DB → MusicBrainz)
# =============================================================================
def _append_to_parquet(path: Path, adatk: pd.DataFrame) -> None:
    """Merge ``adatk`` into the parquet file at ``path`` and rewrite the file.

    If the file already exists, its rows are placed before the new ones and
    duplicates on ``artist_name`` are dropped keeping the last occurrence, so
    incoming rows replace stored rows with the same name. The result is
    written zstd-compressed without the index.
    """
    merged = adatk
    if path.exists():
        stored = pd.read_parquet(path)
        stacked = pd.concat([stored, adatk])
        merged = stacked.drop_duplicates(subset="artist_name", keep="last")
    merged.to_parquet(path, compression="zstd", index=False)


def _country_for_series(series: pd.Series, cache: dict[str, str | None]) -> pd.Series:
    """Map every artist name in ``series`` to a country via ``cache``.

    Names that are not in ``cache`` yet are looked up with
    ``mbAPI.fetch_country``; the results are added to ``cache`` (mutated in
    place) and appended to ``AC_PARQUET``.

    Args:
        series: Artist names; empty and missing values are never looked up.
        cache: artist name -> country mapping, extended with new lookups.

    Returns:
        ``series.map(cache)`` – values without a cache entry map to NaN.
    """
    # Missing values (NaN/None/pd.NA) are filtered before the truthiness test:
    # NaN is truthy and would otherwise be sent to MusicBrainz and cached, and
    # pd.NA cannot be truth-tested at all.
    missing = [
        artist
        for artist in series.unique()
        if pd.notna(artist) and artist and artist not in cache
    ]
    if missing:
        mb_results = {artist: mbAPI.fetch_country(artist) for artist in missing}
        cache.update(mb_results)
        new_df = (pd.DataFrame
                  .from_dict(mb_results, orient="index", columns=["country"])
                  .assign(artist_name=lambda d: d.index))
        _append_to_parquet(AC_PARQUET, new_df)
    return series.map(cache)


def _df_from_db() -> pd.DataFrame:
    """Return every ``ArtistInfo`` row as a DataFrame with the ``AC_COLS`` columns.

    An empty table yields an empty frame that still carries those columns.
    """
    with SessionLocal() as sessio:
        records = sessio.scalars(select(ArtistInfo)).all()

    if not records:
        return pd.DataFrame(columns=AC_COLS)

    # AC_COLS doubles as the list of ORM attributes to copy over
    payload = [{col: getattr(rec, col) for col in AC_COLS} for rec in records]
    return pd.DataFrame(payload, columns=AC_COLS)


In [None]:
# Refresh the artist parquet from the database
df = _df_from_db()
row_count = len(df)

# A table this large may indicate duplicate or orphaned entries; only warn.
if row_count > 1_000_000:
    print(
        f"WARNING: Artist info contains {row_count} rows, which is unusually large.",
        "This may indicate duplicate or unnecessary entries in the database.",
        "Consider running the clean_artist_info_table() function to remove duplicates and orphaned entries.",
        "Example usage: cleaned, remaining = clean_artist_info_table()",
        sep="\n",
    )

# Overwrite the parquet snapshot and report how many rows it holds
df.to_parquet(AC_PARQUET, index=False, compression="zstd")
print(f"Saved {row_count} artist records to {AC_PARQUET}")

# Dump scrobble data as well
io.dump_parquet()


In [None]:
# -------------------------------------------------------------------------------------
#   Step 1: Load scrobbles parquet & deduplicate
# -------------------------------------------------------------------------------------
_banner = "=" * 90
for _line in (
    _banner,
    "Welcome to the CanonFodder data profiling workflow!",
    "We'll load your scrobble data, apply any previously saved artist name unifications,",
    "then explore on forward.",
    _banner,
):
    print(_line)

# Newest scrobble parquet plus its file name; nothing to profile without it
data, latest_filename = io.latest_parquet(return_df=True)
if data is None or data.empty:
    sys.exit("🚫  No scrobble data found – aborting EDA.")


In [None]:
# Raw parquet column name -> label used throughout the analysis
column_mapping = {
    "artist_name": "Artist",
    "track_title": "Song",
    "album_title": "Album",
    "play_time": "Datetime"
}

# rename() skips mapping keys that are not present, so absent columns are simply left alone
data = data.rename(columns=column_mapping)

# Report (but do not abort on) analysis columns that are still missing
expected_columns = list(column_mapping.values())
missing_columns = [col for col in expected_columns if col not in data.columns]
if missing_columns:
    print(f"Warning: Missing expected columns: {missing_columns}")
    print(f"Available columns: {data.columns.tolist()}")

# Parse timestamps when they are not datetimes yet; unparseable values become NaT
_needs_parsing = (
    "Datetime" in data.columns
    and not pd.api.types.is_datetime64_any_dtype(data["Datetime"])
)
if _needs_parsing:
    print("Converting Datetime column to datetime format...")
    data["Datetime"] = pd.to_datetime(data["Datetime"], errors="coerce")

# Remove exact duplicate rows and say how many went away
before_dedup = len(data)
data = data.drop_duplicates()
after_dedup = len(data)
if after_dedup < before_dedup:
    print(f"Dropped {before_dedup - after_dedup} duplicate rows")

# Headline numbers
print(f"Loaded {len(data):,} scrobbles from {latest_filename}")
print(f"Date range: {data['Datetime'].min()} to {data['Datetime'].max()}")
for _noun, _column in (("artists", "Artist"), ("songs", "Song"), ("albums", "Album")):
    print(f"Unique {_noun}: {data[_column].nunique():,}")


In [None]:
# -------------------------------------------------------------------------------------
#   Step 2: Apply artist name canonization
# -------------------------------------------------------------------------------------
# Count the stored canonization records first; with none the step is a no-op
with SessionLocal() as session:
    canon_count = session.query(ArtistVariantsCanonized).count()

if canon_count <= 0:
    print("\nNo artist name canonization data found")
else:
    print(f"\nFound {canon_count} artist name canonization records")
    # Build variant -> canonical name; a variant listed twice keeps the later row's name
    with SessionLocal() as session:
        canon_rows = session.query(ArtistVariantsCanonized).all()
        canon_dict = {
            variant.strip(): row.canonical_name
            for row in canon_rows
            for variant in row.artist_variants_text.split(SEPARATOR)
            if variant.strip()  # ignore empty fragments
        }

    # Replace known variants, leave every other name untouched
    before_canon = data['Artist'].nunique()
    data['Artist'] = data['Artist'].map(lambda name: canon_dict.get(name, name))
    after_canon = data['Artist'].nunique()

    print(f"Applied artist name canonization: {before_canon} → {after_canon} unique artists")
    print(f"Unified {before_canon - after_canon} artist name variants")


In [None]:
# -------------------------------------------------------------------------------------
#   Step 3: Basic time series analysis
# -------------------------------------------------------------------------------------
# Index by timestamp so the frame can be resampled and split by calendar fields
data_ts = data.set_index('Datetime')


def _finish_count_plot(title, xlabel, filename, grid_axis='both'):
    """Title, label and grid the current axes, then pass *filename* to show_or_save_plot."""
    plt.title(title, fontsize=16)
    plt.xlabel(xlabel, fontsize=14)
    plt.ylabel('Scrobble Count', fontsize=14)
    plt.grid(True, axis=grid_axis, alpha=0.3)
    plt.tight_layout()
    show_or_save_plot(filename)


# Scrobbles per calendar day, drawn as a line
daily_scrobbles = data_ts.resample('D').size()
plt.figure(figsize=(15, 6))
daily_scrobbles.plot()
_finish_count_plot('Daily Scrobble Count', 'Date', 'daily_scrobbles.png')

# Scrobbles per month, drawn as bars
monthly_scrobbles = data_ts.resample('M').size()
plt.figure(figsize=(15, 6))
monthly_scrobbles.plot(kind='bar')
plt.xticks(rotation=45)
_finish_count_plot('Monthly Scrobble Count', 'Month', 'monthly_scrobbles.png', grid_axis='y')

# Scrobbles per weekday, Monday first
day_order = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
data_ts['day_of_week'] = data_ts.index.day_name()
day_of_week = data_ts['day_of_week'].value_counts().reindex(day_order)

plt.figure(figsize=(12, 6))
sns.barplot(x=day_of_week.index, y=day_of_week.values)
_finish_count_plot('Scrobbles by Day of Week', 'Day of Week', 'day_of_week_scrobbles.png', grid_axis='y')

# Scrobbles per hour of the day
data_ts['hour'] = data_ts.index.hour
hour_counts = data_ts['hour'].value_counts().sort_index()

plt.figure(figsize=(15, 6))
sns.barplot(x=hour_counts.index, y=hour_counts.values)
plt.xticks(range(0, 24))
_finish_count_plot('Scrobbles by Hour of Day', 'Hour of Day', 'hour_of_day_scrobbles.png', grid_axis='y')


In [None]:
# -------------------------------------------------------------------------------------
#   Step 4: Top artists analysis
# -------------------------------------------------------------------------------------
# Overall top 20 artists by number of scrobbles
top_artists = data['Artist'].value_counts().head(20)

plt.figure(figsize=(12, 8))
sns.barplot(x=top_artists.values, y=top_artists.index)
plt.title('Top 20 Artists by Scrobble Count', fontsize=16)
plt.xlabel('Scrobble Count', fontsize=14)
plt.ylabel('Artist', fontsize=14)
plt.grid(True, axis='x', alpha=0.3)
plt.tight_layout()
show_or_save_plot('top_artists.png')

# Scrobbles per (year, artist)
data_ts['year'] = data_ts.index.year
yearly_top_artists = data_ts.groupby(['year', 'Artist']).size().reset_index(name='count')

# Five most-played artists of each year.
# Sorting and taking the head of every group keeps the 'year' column in place;
# groupby.apply() on a frame that still contains the grouping column is
# deprecated since pandas 2.2. The stable sort keeps tied rows in input order.
yearly_top_5 = (
    yearly_top_artists
    .sort_values(['year', 'count'], ascending=[True, False], kind='mergesort')
    .groupby('year')
    .head(5)
    .reset_index(drop=True)
)

# One column per artist, one row per year, for the stacked bars
pivot_data = yearly_top_5.pivot_table(
    index='year',
    columns='Artist',
    values='count',
    aggfunc='sum'
).fillna(0)

# DataFrame.plot() opens its own figure unless it is given an Axes, so the
# Axes is created up front and passed in instead of leaving an empty figure behind.
_yearly_fig, _yearly_ax = plt.subplots(figsize=(15, 8))
pivot_data.plot(kind='bar', stacked=True, ax=_yearly_ax)
plt.title('Top 5 Artists by Year', fontsize=16)
plt.xlabel('Year', fontsize=14)
plt.ylabel('Scrobble Count', fontsize=14)
plt.legend(title='Artist', bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(True, axis='y', alpha=0.3)
plt.tight_layout()
show_or_save_plot('yearly_top_artists.png')

# Year x rank table holding the artist name in each cell
yearly_top_table = yearly_top_5.pivot_table(
    index='year',
    columns=yearly_top_5.groupby('year')['count'].rank(method='first', ascending=False).astype(int),
    values='Artist',
    aggfunc='first'
)
yearly_top_table.columns = [f'#{i}' for i in yearly_top_table.columns]

# Draw the table as a heatmap of filled cells; the floor on the height keeps
# the figure usable when only a few years are present.
plt.figure(figsize=(15, max(2.0, len(yearly_top_table) * 0.5)))
sns.heatmap(yearly_top_table.notnull(), cmap='Blues', cbar=False,
            linewidths=1, linecolor='white')

# Write each artist name into the centre of its cell
for i, year in enumerate(yearly_top_table.index):
    for j, rank in enumerate(yearly_top_table.columns):
        artist = yearly_top_table.loc[year, rank]
        if pd.notnull(artist):
            plt.text(j + 0.5, i + 0.5, artist,
                     ha='center', va='center', fontsize=10,
                     color='black')

plt.title('Top Artists by Year', fontsize=16)
plt.xlabel('Rank', fontsize=14)
plt.ylabel('Year', fontsize=14)
plt.tight_layout()
show_or_save_plot('yearly_top_artists_table.png')


In [None]:
# -------------------------------------------------------------------------------------
#   Step 5: Top songs and albums analysis
# -------------------------------------------------------------------------------------
def _plot_top_items(frame, label_column, title, ylabel, filename):
    """Draw a horizontal bar chart of *frame*'s counts against *label_column*."""
    plt.figure(figsize=(12, 8))
    sns.barplot(x=frame['count'], y=frame[label_column])
    plt.title(title, fontsize=16)
    plt.xlabel('Scrobble Count', fontsize=14)
    plt.ylabel(ylabel, fontsize=14)
    plt.grid(True, axis='x', alpha=0.3)
    plt.tight_layout()
    show_or_save_plot(filename)


# Twenty most-played (artist, song) pairs with a combined display label
top_songs = (
    data.groupby(['Artist', 'Song']).size().reset_index(name='count')
    .sort_values('count', ascending=False)
    .head(20)
)
top_songs['Artist_Song'] = top_songs['Artist'] + ' - ' + top_songs['Song']
_plot_top_items(top_songs, 'Artist_Song', 'Top 20 Songs by Scrobble Count', 'Song', 'top_songs.png')

# Twenty most-played (artist, album) pairs with a combined display label
top_albums = (
    data.groupby(['Artist', 'Album']).size().reset_index(name='count')
    .sort_values('count', ascending=False)
    .head(20)
)
top_albums['Artist_Album'] = top_albums['Artist'] + ' - ' + top_albums['Album']
_plot_top_items(top_albums, 'Artist_Album', 'Top 20 Albums by Scrobble Count', 'Album', 'top_albums.png')


In [None]:
# -------------------------------------------------------------------------------------
#   Step 6: Artist country analysis
# -------------------------------------------------------------------------------------
# Load artist country data from the ArtistInfo table
artist_info = _df_from_db()

# Attach each scrobble's artist country (NaN when the artist has no entry)
artist_country_map = dict(zip(artist_info['artist_name'], artist_info['country']))
data['ArtistCountry'] = data['Artist'].map(artist_country_map)

# Scrobbles per artist country, most frequent first
country_counts = data['ArtistCountry'].value_counts().reset_index()
country_counts.columns = ['country', 'count']
country_counts = country_counts[country_counts['country'].notna()]
top_countries = country_counts.head(15)

# Plot top countries
plt.figure(figsize=(12, 8))
sns.barplot(x=top_countries['count'], y=top_countries['country'])
plt.title('Top 15 Countries by Artist Scrobble Count', fontsize=16)
plt.xlabel('Scrobble Count', fontsize=14)
plt.ylabel('Country', fontsize=14)
plt.grid(True, axis='x', alpha=0.3)
plt.tight_layout()
show_or_save_plot('top_countries.png')

# Scrobbles per (country, artist)
top_artists_by_country = data.groupby(['ArtistCountry', 'Artist']).size().reset_index(name='count')
top_artists_by_country = top_artists_by_country[top_artists_by_country['ArtistCountry'].notna()]

# Restrict to the ten countries with the most scrobbles
top_10_countries = country_counts.head(10)['country'].tolist()
top_artists_by_top_countries = top_artists_by_country[
    top_artists_by_country['ArtistCountry'].isin(top_10_countries)
]

# Three most-played artists per country.
# Sort + groupby().head() keeps the 'ArtistCountry' column that catplot needs;
# groupby.apply() on a frame that still contains the grouping column is
# deprecated since pandas 2.2. The stable sort keeps tied rows in input order.
top_3_by_country = (
    top_artists_by_top_countries
    .sort_values(['ArtistCountry', 'count'], ascending=[True, False], kind='mergesort')
    .groupby('ArtistCountry')
    .head(3)
    .reset_index(drop=True)
)

# catplot is a figure-level function and builds its own figure, so no
# plt.figure() call precedes it (that would only leave an empty figure open).
g = sns.catplot(
    data=top_3_by_country,
    kind="bar",
    x="ArtistCountry", y="count", hue="Artist",
    height=8, aspect=1.5
)
g.set_xticklabels(rotation=45, ha="right")
g.fig.suptitle('Top 3 Artists by Country', fontsize=16, y=1.02)
g.set_axis_labels("Country", "Scrobble Count", fontsize=14)
plt.tight_layout()

# This chart is written straight to disk; make sure the target folder exists first
_country_pics_dir = PROJECT_ROOT / "pics"
_country_pics_dir.mkdir(parents=True, exist_ok=True)
g.fig.savefig(_country_pics_dir / "top_artists_by_country.png", dpi=100, bbox_inches="tight")


In [None]:
# -------------------------------------------------------------------------------------
#   Step 7: Listening patterns analysis
# -------------------------------------------------------------------------------------
# Calendar fields used for the weekday x hour grid
data_ts['day_of_week'] = data_ts.index.day_name()
data_ts['hour'] = data_ts.index.hour

# Count scrobbles per (weekday, hour) cell and put the weekdays in day_order
hour_day_pivot = data_ts.pivot_table(
    index='day_of_week',
    columns='hour',
    values='Artist',
    aggfunc='count'
).reindex(day_order)

# Heatmap with the count written into every cell
plt.figure(figsize=(15, 8))
sns.heatmap(hour_day_pivot, cmap='YlGnBu', annot=True, fmt='g')
plt.title('Listening Patterns by Hour and Day of Week', fontsize=16)
plt.xlabel('Hour of Day', fontsize=14)
plt.ylabel('Day of Week', fontsize=14)
plt.tight_layout()
show_or_save_plot('listening_patterns_heatmap.png')

# Seasonal view: scrobbles per calendar month across all years
data_ts['month'] = data_ts.index.month_name()
month_order = (
    'January February March April May June '
    'July August September October November December'
).split()
month_counts = data_ts['month'].value_counts().reindex(month_order)

plt.figure(figsize=(15, 6))
sns.barplot(x=month_counts.index, y=month_counts.values)
plt.xticks(rotation=45)
plt.title('Scrobbles by Month', fontsize=16)
plt.xlabel('Month', fontsize=14)
plt.ylabel('Scrobble Count', fontsize=14)
plt.grid(True, axis='y', alpha=0.3)
plt.tight_layout()
show_or_save_plot('monthly_listening_patterns.png')


In [None]:
# -------------------------------------------------------------------------------------
#   Step 8: Artist diversity analysis
# -------------------------------------------------------------------------------------
# Monthly diversity = distinct artists as a percentage of that month's scrobbles
data_ts['year_month'] = data_ts.index.to_period('M')
monthly_unique_artists = data_ts.groupby('year_month')['Artist'].nunique()
monthly_total_scrobbles = data_ts.groupby('year_month').size()
monthly_diversity = (monthly_unique_artists / monthly_total_scrobbles * 100).fillna(0)

# Two y-axes over a shared month axis
fig, ax1 = plt.subplots(figsize=(15, 6))

# Left axis: number of distinct artists
color = 'tab:blue'
ax1.set_xlabel('Month', fontsize=14)
ax1.set_ylabel('Unique Artists', color=color, fontsize=14)
ax1.plot(monthly_unique_artists.index.astype(str), monthly_unique_artists.values, color=color)
ax1.tick_params(axis='y', labelcolor=color)
ax1.grid(True, alpha=0.3)

# Right axis: diversity percentage
ax2 = ax1.twinx()
color = 'tab:red'
ax2.set_ylabel('Diversity (%)', color=color, fontsize=14)
ax2.plot(monthly_diversity.index.astype(str), monthly_diversity.values, color=color, linestyle='--')
ax2.tick_params(axis='y', labelcolor=color)

# Rotate the month labels on ax1: after twinx() the current axes is ax2,
# whose x axis is hidden, so plt.xticks() would not reach the visible labels.
ax1.tick_params(axis='x', labelrotation=45)
ax1.set_title('Monthly Artist Diversity', fontsize=16)
fig.tight_layout()

# This chart is written straight to disk; make sure the target folder exists first
_diversity_pics_dir = PROJECT_ROOT / "pics"
_diversity_pics_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(_diversity_pics_dir / "artist_diversity.png", dpi=100, bbox_inches="tight")

# Running total of distinct artists heard so far.
# An expanding window cannot aggregate the string-valued Artist column, so
# flag the first occurrence of every (non-missing) artist in chronological
# order and accumulate those flags instead.
data_ts = data_ts.sort_index()  # Ensure chronological order
_artist_names = data_ts['Artist']
_first_play = _artist_names.notna() & ~_artist_names.duplicated()
data_ts['cumulative_unique_artists'] = _first_play.cumsum()

# Month-end value of the running total; months without any scrobble carry
# the previous total forward rather than showing a gap.
monthly_cumulative = data_ts.resample('M')['cumulative_unique_artists'].max().ffill()

# Plot cumulative unique artists
plt.figure(figsize=(15, 6))
plt.plot(monthly_cumulative.index, monthly_cumulative.values)
plt.title('Cumulative Unique Artists Over Time', fontsize=16)
plt.xlabel('Date', fontsize=14)
plt.ylabel('Cumulative Unique Artists', fontsize=14)
plt.grid(True, alpha=0.3)
plt.tight_layout()
show_or_save_plot('cumulative_unique_artists.png')


In [None]:
# -------------------------------------------------------------------------------------
#   Step 9: Country population vs. scrobble count analysis
# -------------------------------------------------------------------------------------
def _annotate_notable_countries(frame, top_n, top_per_capita, top_population):
    """Label the notable countries of *frame* on the current axes.

    Three groups are labelled, each country at most once:
    the *top_n* rows by ``scrobble_count`` (default text colour), then the
    *top_per_capita* rows by ``scrobbles_per_million`` (dark green), then the
    *top_population* rows by ``population`` (navy). A country that belongs to
    an earlier group is skipped in the later ones. The label is the
    ``country_name`` when present, otherwise the ``country_code``.
    """
    def _leaders(column, size):
        return frame.sort_values(column, ascending=False).head(size)

    # Rank each group once instead of re-sorting the frame for every row
    by_count = _leaders('scrobble_count', top_n)
    by_capita = _leaders('scrobbles_per_million', top_per_capita)
    by_population = _leaders('population', top_population)
    count_codes = set(by_count['country_code'])
    capita_codes = set(by_capita['country_code'])

    # (rows, codes to skip, text offset in points, extra text style)
    label_groups = (
        (by_count, frozenset(), (5, 5), {}),
        (by_capita, count_codes, (5, -10), {'color': 'darkgreen'}),
        (by_population, count_codes | capita_codes, (-15, 10), {'color': 'navy'}),
    )
    for rows, already_labelled, offset, text_style in label_groups:
        for _, row in rows.iterrows():
            if row['country_code'] in already_labelled:
                continue
            name = row['country_name']
            plt.annotate(
                row['country_code'] if pd.isna(name) else name,
                (row['population'], row['scrobble_count']),
                xytext=offset,  # different offsets keep the groups' labels apart
                textcoords="offset points",
                fontsize=9,
                alpha=0.8,
                **text_style,
            )


# Scrobble totals per artist country code; scrobbles without a country are dropped
country_scrobbles = data['ArtistCountry'].value_counts().reset_index()
country_scrobbles.columns = ['country_code', 'scrobble_count']
country_scrobbles = country_scrobbles[country_scrobbles['country_code'].notna()]

# Look up population and display name for every country code via pypopulation.
# A failing lookup is reported and the loop moves on to the next code.
country_pop = {}
country_names = {}
for country in country_scrobbles['country_code']:
    try:
        pop = pypopulation.get_population(country)
        if pop is not None:
            country_pop[country] = pop
            country_names[country] = pypopulation.get_country_name(country)
    except Exception as e:
        print(f"Error getting population for {country}: {e}")

# One row per country with a known population; the code stands in for a missing name
country_pop_df = pd.DataFrame({
    'country_code': list(country_pop.keys()),
    'population': list(country_pop.values()),
    'country_name': [country_names.get(c, c) for c in country_pop.keys()]
})

# Merge with scrobble counts
country_pop_df = country_pop_df.merge(country_scrobbles, on='country_code')

# Calculate scrobbles per million population
country_pop_df['scrobbles_per_million'] = (country_pop_df['scrobble_count'] / country_pop_df['population']) * 1000000

# Colour scale spanning the observed per-million range
min_per_million = country_pop_df['scrobbles_per_million'].min()
max_per_million = country_pop_df['scrobbles_per_million'].max()
norm = mcolors.Normalize(vmin=min_per_million, vmax=max_per_million)
custom_cmap = plt.cm.ScalarMappable(norm=norm, cmap='viridis')

# How many countries to label in each group
top_n = 10  # Top N countries by scrobble count
top_per_capita = 5  # Top N countries by scrobbles per million
top_population = 5  # Top N countries by population

# ---- Plot 1: linear axes ------------------------------------------------------
plt.figure(figsize=(12, 8))
scatter = plt.scatter(
    country_pop_df['population'],
    country_pop_df['scrobble_count'],
    alpha=0.7,
    s=100,
    c=country_pop_df['scrobbles_per_million'],
    cmap='viridis'
)

# Add a colorbar
cbar = plt.colorbar(scatter)
cbar.set_label('Scrobbles per Million Population', fontsize=12)

_annotate_notable_countries(country_pop_df, top_n, top_per_capita, top_population)

# Least-squares straight line through (population, scrobble_count)
z = np.polyfit(country_pop_df["population"], country_pop_df["scrobble_count"], 1)
p = np.poly1d(z)
plt.plot(
    country_pop_df["population"],
    p(country_pop_df["population"]),
    "r--",
    alpha=0.7,
    label=f"Trend line (y = {z[0]:.2e}x + {z[1]:.2f})"
)

# Legend keyed to the label colours plus the trend line
legend_elements = [Line2D([0], [0], marker='o', color='w', markerfacecolor='black', markersize=8,
                          label=f'Top {top_n} by scrobble count'),
                   Line2D([0], [0], marker='o', color='w', markerfacecolor='darkgreen', markersize=8,
                          label=f'Top {top_per_capita} by scrobbles per million'),
                   Line2D([0], [0], marker='o', color='w', markerfacecolor='navy', markersize=8,
                          label=f'Top {top_population} by population'),
                   Line2D([0], [0], color='r', linestyle='--', label=f"Trend line (y = {z[0]:.2e}x + {z[1]:.2f})")]

# Calculate correlation coefficient
corr = country_pop_df["population"].corr(country_pop_df["scrobble_count"])
plt.title(f"Country population vs. scrobble count (Correlation: {corr:.2f})", fontsize=16)
plt.xlabel("Population", fontsize=14)
plt.ylabel("Scrobble count", fontsize=14)
plt.grid(True, linestyle="--", alpha=0.7)
plt.legend(handles=legend_elements, loc='upper left', fontsize=9)
plt.tight_layout()
show_or_save_plot("population_vs_scrobbles.png")

# ---- Plot 2: same data on log-log axes (drawn without a legend) ---------------
plt.figure(figsize=(12, 8))

# scatter() needs a colormap, not a ScalarMappable, so hand over the
# mappable's colormap and normalisation separately.
scatter = plt.scatter(
    country_pop_df["population"],
    country_pop_df["scrobble_count"],
    alpha=0.7,
    s=100,
    c=country_pop_df["scrobbles_per_million"],
    cmap=custom_cmap.get_cmap(),
    norm=custom_cmap.norm
)

# Add a colorbar
cbar = plt.colorbar(scatter)
cbar.set_label('Scrobbles per Million Population', fontsize=12)

_annotate_notable_countries(country_pop_df, top_n, top_per_capita, top_population)

# Set log scales for both axes
plt.xscale("log")
plt.yscale("log")

# Name the country that actually leads per capita instead of hard-coding one
_per_capita_leader = country_pop_df.loc[country_pop_df['scrobbles_per_million'].idxmax(), 'country_name']
plt.title(f"{_per_capita_leader} boasts an impressive scrobble count per population", fontsize=16)
plt.xlabel("Population (log scale)", fontsize=14)
plt.ylabel("Scrobble count (log scale)", fontsize=14)
plt.grid(True, linestyle="--", alpha=0.7)
plt.tight_layout()
show_or_save_plot("population_vs_scrobbles_log.png")


In [None]:
# -------------------------------------------------------------------------------------
#   Step 10: User-country analytics
# -------------------------------------------------------------------------------------
# Obtain the user-country timeline: reuse the stored parquet, edit it, or build a new one.
# NOTE(review): "e"/"n" presumably stand for edit / start over; any other answer keeps the file.
if UC_PARQUET.exists():
    use = cli.choose_timeline()
    if use == "e":
        uc_df = cli.edit_country_timeline()
    elif use == "n":
        UC_PARQUET.unlink(missing_ok=True)
        uc_df = cli.edit_country_timeline()
    else:  # "y"
        uc_df = pd.read_parquet(UC_PARQUET)
else:
    uc_df = cli.edit_country_timeline()

# Tag every scrobble with the user's country and keep the ten most frequent
data["UserCountry"] = stats.assign_user_country(data, uc_df)
user_country_count = data.UserCountry.value_counts().sort_values(ascending=False).to_frame()[:10]
# Depending on the pandas version the count column is named after the series; normalise it
user_country_count = user_country_count.rename(columns={"UserCountry": "count"})

# Load custom palette for better visualization of skewed data.
# PALETTES_FILE is anchored at the project root, so this works from any working directory.
with PALETTES_FILE.open("r", encoding="utf-8") as f:
    custom_palettes = json.load(f)["palettes"]
# Use caribbean_current_shades_13_d2l palette which orders colors from dark to light (more scrobbles = darker color)
custom_colors = io.register_custom_palette("caribbean_current_shades_13_d2l", custom_palettes)

plt.figure(figsize=(12, 6))
ax = sns.barplot(
    x=user_country_count.index,
    y="count",
    data=user_country_count,
    palette=custom_colors,
    hue=user_country_count.index,
)
ax.set_xticks(range(len(user_country_count)))
ax.set_xticklabels(user_country_count.index, rotation=45, ha="right", fontsize=12)
ax.grid(True, axis="y", linestyle="--", alpha=0.7)
ax.set_yscale('log')

# Write the count above every bar taller than 10
for bar_patch in ax.patches:
    bar_height = bar_patch.get_height()
    if bar_height > 10:
        ax.annotate(
            f"{int(bar_height)}",
            (bar_patch.get_x() + bar_patch.get_width() / 2.0, bar_height),
            ha="center",
            va="bottom",
            xytext=(0, 5),
            textcoords="offset points",
            fontsize=10,
            color="black",
            fontweight="bold",
            bbox=dict(facecolor='white', alpha=0.7, pad=2)
        )
ax.set_title("User countries per scrobble count (log scale)", fontsize=16)
ax.set_xlabel("UserCountry", fontsize=14)
ax.set_ylabel("Count (log scale)", fontsize=14)
plt.tight_layout()
show_or_save_plot("user_countries.png")


In [None]:
# --- entry point ----------------------------------------------------------------------
# Runs only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    # Send INFO-and-above log records to a stream handler with a timestamped format
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.StreamHandler()],
    )

    # Handle cleanup-artists command
    # NOTE(review): `args` is defined outside this section - presumably the parsed
    # command-line namespace; confirm. The hasattr() guard tolerates a namespace without `cmd`.
    if args and hasattr(args, 'cmd') and args.cmd == "cleanup-artists":
        print("Running ArtistInfo table cleanup...")
        # The two return values are reported below as removed and remaining record counts
        cleaned, remaining = clean_artist_info_table()
        if cleaned > 0:
            print(f"Cleanup successful! Removed {cleaned} records, {remaining} remain.")
            print("Updating artist_info.parquet file with cleaned data...")
            # Rewrite the parquet snapshot from the cleaned table
            df = _df_from_db()
            df.to_parquet(AC_PARQUET, index=False, compression="zstd")
            print(f"Saved {len(df)} artist records to {AC_PARQUET}")
        else:
            print("No records were cleaned. The ArtistInfo table is already optimized.")
        # Terminate with a success status once the cleanup command has been handled
        sys.exit(0)