In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell execution so that edits to imported source files (e.g. the
# local `lib` package) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline


In [None]:
from typing import Optional
import os
import numpy as np
import scipy as sp
from scipy.spatial import ConvexHull
import torch
import torch.nn as nn
from timeit import default_timer
from torch import Tensor
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.path import Path

from lib.problems import ProblemDataset
#from lib.utils import visualize_clustering
from lib.utils.visualization import group_by_label_mean, calculate_inertia


In [None]:
# Dataset switch: True selects the Italia telecom result files, False the
# Shanghai telecom ones.
ITALIA = False

# Both branches bind the same set of names: one result pickle per solver
# (GB21 / rpack / NCC / capacitated k-means, going by the path names), the
# directory figures are saved to, and the display name of the NCC variant.
if not ITALIA:
    SAVE_DIR = "plot_data/shanghai_tel_2373/"
    ncc_name = "NCC_(s-25-128)"
    gb21_pth, rpack_pth, ncc_pth, ckm_pth = (
        "plot_data/shanghai_tel_2373/ccp_mh/full_test_n2373_k40_cap1_1/eval_results_full_1.pkl",
        "plot_data/shanghai_tel_2373/rpack/full_test_n2373_k40_cap1_1/eval_results_full_1.pkl",
        "plot_data/shanghai_tel_2373/ncc_samp_re2/full_test_n2373_k40_cap1_1/eval_results_full_1.pkl",
        "plot_data/shanghai_tel_2373/cap_kmeans/full_test_n2373_k40_cap1_1/eval_results_full_1.pkl",
    )
else:
    SAVE_DIR = "plot_data/italia_tel_2020/"
    ncc_name = "NCC_(s-25-64)"
    gb21_pth, rpack_pth, ncc_pth, ckm_pth = (
        "plot_data/italia_tel_2020/ccp_mh/full_test_n2020_k25_cap1_1/eval_results_full_1.pkl",
        "plot_data/italia_tel_2020/rpack/full_test_n2020_k25_cap1_1/eval_results_full_1.pkl",
        "plot_data/italia_tel_2020/ncc_samp_re2/full_test_n2020_k25_cap1_1/eval_results_full_1.pkl",
        "plot_data/italia_tel_2020/cap_kmeans/full_test_n2020_k25_cap1_1/eval_results_full_1.pkl",
    )

In [None]:
# Display names and result files are index-aligned: NAMES[i] labels the
# result loaded from PTHS[i].
NAMES = [
    "GB21",
    "PACK",
    "CapKMeans",
    ncc_name,
]
PTHS = [
    gb21_pth,
    rpack_pth,
    ckm_pth,
    ncc_pth,
]

In [None]:
# Keep only the entry at index 0 of each loaded result object, in PTHS order.
# NOTE(review): torch.load unpickles arbitrary Python objects - only point
# PTHS at trusted files.
results = [entries[0] for entries in map(torch.load, PTHS)]
results

In [None]:
def plot_cluster(
    ax, x, y,
    plot_centers: bool = True,
    plot_medoids: bool = False,
    plot_convex_hull: bool = True,
    plot_hull_points: bool = False,
    compute_inertia: bool = False,
    alpha: float = 0.5,
    rt_str: str = "",
    name: Optional[str] = None,
    legend=False,
    fontsize: int = 15,
):
    """Draw one clustering result onto ``ax``.

    Points are scattered at their coordinates, coloured by cluster label and
    sized by weight. Optionally the cluster centroids, the convex hull of each
    cluster and the total inertia (shown in the title) are added.

    Args:
        ax: matplotlib axes to draw on.
        x: 2-D numpy array whose first three columns are used as
            x-coordinate, y-coordinate and weight.
        y: 1-D numpy array with one integer cluster label per row of ``x``.
        plot_centers: mark each cluster centroid with a black "X".
        plot_medoids: not implemented; raises if set.
        plot_convex_hull: outline each cluster with its convex hull.
        plot_hull_points: additionally circle the hull vertices
            (only used when ``plot_convex_hull`` is set).
        compute_inertia: compute the inertia, i.e. the sum over all clusters
            of the squared Euclidean distances of the members to their
            cluster mean, and show it in the title.
        alpha: opacity of the scattered points.
        rt_str: run time, already formatted as a string; shown in the title
            as ``time: <rt_str>s`` when non-empty.
        name: optional heading placed first in the title.
        legend: forwarded to ``sns.scatterplot(legend=...)``.
        fontsize: font size of the title.

    Raises:
        NotImplementedError: if ``plot_medoids`` is True.
    """
    if plot_medoids:
        # Fail before anything is drawn so ``ax`` is not left half-populated.
        raise NotImplementedError()

    x_data = pd.DataFrame(
        np.concatenate((x[:, :3], y[:, None]), axis=-1),
        columns=["x_coord", "y_coord", "weight", "pred_label"],
    )
    # The concatenation promotes the labels to float; restore integer labels.
    x_data.pred_label = x_data.pred_label.astype(int)

    sns.scatterplot(x="x_coord", y="y_coord", hue="pred_label",
                    size="weight", sizes=(10, 100),
                    alpha=alpha, palette="gist_rainbow", legend=legend,
                    data=x_data, ax=ax)

    coords = x[:, :2]
    lbls = np.unique(y)
    nc = len(lbls)

    if plot_centers:
        # NOTE(review): the helper is queried for labels 0..nc-1, which
        # presumably requires the labels in ``y`` to be exactly that
        # contiguous range - confirm against group_by_label_mean.
        centers = group_by_label_mean(
            torch.from_numpy(coords)[None, :, :],
            torch.from_numpy(y)[None, None, :],
            torch.arange(nc).unsqueeze(0)
        ).squeeze(0).squeeze(0).cpu().numpy()
        ax.scatter(x=centers[:, 0], y=centers[:, 1], marker="X", c="black", alpha=1.0)

    if plot_convex_hull:
        colors = sns.color_palette("gist_rainbow", n_colors=nc)
        for lb, c in zip(lbls, colors):
            cl_points = coords[y == lb]
            # Qhull cannot build a 2-D hull from fewer than three points or
            # from collinear points and raises; such clusters have no area to
            # outline, so they are skipped instead of aborting the whole plot.
            if np.linalg.matrix_rank(cl_points - cl_points[0]) < 2:
                continue
            hull = ConvexHull(cl_points)
            for simplex in hull.simplices:
                ax.plot(cl_points[simplex, 0], cl_points[simplex, 1], c=c)
            if plot_hull_points:
                ax.plot(cl_points[hull.vertices, 0], cl_points[hull.vertices, 1],
                        'o', mec='r', color='none', lw=1, markersize=10)

    if compute_inertia:
        # Sum of squared Euclidean distances to the cluster mean. An explicit
        # element-wise sum is used because np.linalg.norm(..., ord=2) on a 2-D
        # array is the spectral norm (largest singular value), not the
        # Frobenius norm, and would under-report the inertia.
        inertia = 0.0
        for lbl in lbls:  # labels come from np.unique(y), so no cluster is empty
            members = coords[y == lbl]
            inertia += float(np.sum((members - members.mean(axis=0)) ** 2))

    name_str = name if name is not None and len(name) > 0 else ""
    inertia_str = f"\ninertia: {round(inertia, 4): .4f}  " if compute_inertia else ""
    rt_str = f"time: {rt_str}s" if rt_str is not None and len(rt_str) > 0 else ""
    # Without the inertia line nothing separates the name from the run time
    # ("GB21time: 1.2s"); put the run time on its own line in that case.
    sep = "\n" if (name_str and rt_str and not inertia_str) else ""
    title = f"{name_str}{inertia_str}{sep}{rt_str}"
    ax.set_title(title, fontdict={'fontsize': fontsize, 'fontweight': 'bold'})


In [None]:
def plot_full(
    results: list,
    names: list,
    alpha: float = 0.5,
    plot_legend: bool = False,
    plot_centers: bool = True,
    plot_medoids: bool = False,
    plot_convex_hull: bool = True,
    compute_inertia: bool = False,
    remove_axis_ticks_and_labels: bool = True,
    save_dir: str = None,
    fontsize: int = 25,
    **kwargs
):
    """Plot several clustering results side by side, one subplot per result.

    Args:
        results: list of result dicts. Each is read via ``res['instance']``
            (with ``['coords'][0]`` and ``['demands'][0]``),
            ``res['assignment']`` and ``res['run_time']``.
        names: one title name per result, in the same order as ``results``.
        alpha: opacity of the scattered points.
        plot_legend: show a legend below each subplot.
        plot_centers: forwarded to ``plot_cluster``.
        plot_medoids: forwarded to ``plot_cluster``.
        plot_convex_hull: forwarded to ``plot_cluster``.
        compute_inertia: forwarded to ``plot_cluster``.
        remove_axis_ticks_and_labels: strip tick labels and axis labels.
        save_dir: if None the figure is shown; otherwise it is written to
            ``<save_dir>/<names joined by "_">_plot.pdf``.
        fontsize: title font size, forwarded to ``plot_cluster``.
        **kwargs: further keyword arguments forwarded to ``plot_cluster``.

    Raises:
        ValueError: if ``results`` is empty.
    """
    assert len(results) == len(names)
    n_plots = len(results)
    if n_plots == 0:
        raise ValueError("plot_full needs at least one result to plot")

    sns.set_theme(style="white", palette=None)
    legend = "auto" if plot_legend else False

    # squeeze=False always yields a 2-D array of axes; without it a single
    # result returns a bare Axes object and the loops below fail.
    fig, axs = plt.subplots(nrows=1, ncols=n_plots,
                            figsize=(5.9*n_plots, 6), squeeze=False)
    axs = axs.ravel()

    for ax, res, nm in zip(axs, results, names):
        inst = res['instance']
        # Append the demands as a third column, used as the point weight.
        x = np.concatenate((inst['coords'][0], inst['demands'][0][:, None]), axis=-1)
        y = res['assignment']
        rt = str(round(res['run_time'], 1))
        plot_cluster(ax, x, y,
                     plot_centers=plot_centers,
                     plot_medoids=plot_medoids,
                     plot_convex_hull=plot_convex_hull,
                     compute_inertia=compute_inertia, name=nm, alpha=alpha,
                     legend=legend, rt_str=rt, fontsize=fontsize, **kwargs)

    if remove_axis_ticks_and_labels:
        for ax in axs:
            ax.set(xticklabels=[], yticklabels=[], xlabel=None, ylabel=None)

    if plot_legend:
        for ax in axs:
            sns.move_legend(ax, loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=n_plots)

    fig.tight_layout()
    if save_dir is None:
        plt.show()
    else:   # save figure
        if save_dir:
            # Create the target directory instead of failing on a fresh checkout.
            os.makedirs(save_dir, exist_ok=True)
        nm_str = "_".join(names)
        save_pth = os.path.join(save_dir, f"{nm_str}_plot.pdf")
        # Save this figure explicitly rather than pyplot's "current" figure.
        fig.savefig(save_pth, format="pdf", bbox_inches='tight')

In [None]:
# Render the side-by-side comparison of all loaded results with centroids,
# convex hulls and inertia, and write it as a PDF into SAVE_DIR.
plot_full(
    results,
    NAMES,
    save_dir=SAVE_DIR,
    fontsize=80,
    compute_inertia=True,
    plot_convex_hull=True,
    plot_centers=True,
    plot_legend=False,
    remove_axis_ticks_and_labels=True,
)