In [None]:
import random
import pandas as pd
import matplotlib.pyplot as plt
from jupyterthemes import jtplot

# Apply the jupyterthemes "onedork" theme to the plots drawn in this notebook.
jtplot.style(theme='onedork')

In [None]:
# CSV download endpoint of a dataset hosted on opendata.arcgis.com
# (query string requests format=csv with spatialRefId=4326).
url = "https://opendata.arcgis.com/api/v3/datasets/dd4580c810204019a7b8eb3e0b329dd6_0/downloads/data?format=csv&spatialRefId=4326"

In [None]:
# Load the dataset straight from the download URL.
df_raw = pd.read_csv(url)
# Drop every column whose name contains "Id".  The axis is selected with the
# `columns=` keyword: the positional form `drop(labels, 1)` was deprecated in
# pandas 1.3 and raises a TypeError from pandas 2.0 on.
df_raw = df_raw.drop(columns=[c for c in df_raw.columns if "Id" in c])
df_raw

In [None]:
# Column labels are looked up by position in `df_raw` (after the "Id" columns
# were dropped): position 0 serves as the region key and position 6 as the
# time key in the cells below.
region, time = df_raw.columns[0], df_raw.columns[6]

In [None]:
# Smallest and largest value in the time column (the covered time span,
# assuming the values sort chronologically).
df_raw[time].min(), df_raw[time].max()

In [None]:
# Total the numeric columns per region.  `numeric_only=True` is spelled out
# because pandas 2.0 changed the groupby default to False, which would also
# aggregate non-numeric columns instead of leaving them out as older versions
# did; the split metric further down divides these totals and needs numbers.
df = df_raw.groupby(region).sum(numeric_only=True)
df

In [None]:
def random_split(n_elements, group_names:set):
    """Randomly assign each of ``n_elements`` items to one of the groups.

    Parameters
    ----------
    n_elements : int
        Number of items to label.
    group_names : set (or other iterable of hashable labels), or int
        The group labels to draw from.  An int ``k`` is shorthand for the
        labels ``0 .. k-1``.  Labels that occur more than once are counted
        only once.

    Returns
    -------
    list
        ``n_elements`` labels drawn with ``random.choices`` such that every
        distinct group label occurs at least once.

    Raises
    ------
    AssertionError
        If the number of distinct groups is not strictly smaller than
        ``n_elements``.
    """
    if isinstance(group_names, int):
        group_names = set(range(group_names))
    # De-duplicate once, keeping iteration order.  With repeated labels the
    # acceptance test below could never succeed (the number of distinct drawn
    # labels can not reach the length of the raw input) and the loop would
    # never terminate.
    labels = list(dict.fromkeys(group_names))
    assert len(labels)<n_elements
    # Rejection sampling: redraw until no group is left empty.
    while True:
        split = random.choices(population=labels, k=n_elements)
        if len(set(split))==len(labels):
            return split


def single_changes(split):
    """Yield every split that differs from ``split`` in exactly one position.

    Only neighbours that still use all groups present in ``split`` are
    produced, i.e. an element is never moved away if it is the last member of
    its group.

    Parameters
    ----------
    split : list
        Group label of each element.  It is not modified.

    Yields
    ------
    list
        A fresh copy of ``split`` with one element reassigned to another of
        the groups occurring in ``split``.
    """
    # Work on a snapshot so the caller's list is never touched, even when the
    # generator is only partially consumed.
    base = split.copy()
    group_names = set(base)
    # Members per group: reassigning an element empties its group exactly when
    # it is the only member, independent of the target group.
    members = {}
    for name in base:
        members[name] = members.get(name, 0) + 1
    for ix, group_name in enumerate(base):
        if members[group_name] == 1:
            continue
        for new_group_name in group_names - {group_name}:
            candidate = base.copy()
            candidate[ix] = new_group_name
            yield candidate


def split_approx(df2, group_names, metric):
    """Search for a split of the rows of ``df2`` with a locally minimal metric.

    Starts from a random split and repeatedly moves to the best neighbouring
    split (one element reassigned) until no neighbour has a strictly smaller
    metric value.

    Parameters
    ----------
    df2 : sized object
        Data to split; ``len(df2)`` is the number of elements.  It is only
        passed through to ``metric``.
    group_names : set or int
        Groups to use, forwarded to ``random_split``.
    metric : callable
        ``metric(split, df2)`` returning a comparable score; lower is better.

    Returns
    -------
    list
        Group label per element of ``df2``.
    """
    split = random_split(len(df2), group_names)
    current_metric = metric(split, df2)
    while True:
        new_splits = list(single_changes(split))
        if not new_splits:
            # No neighbouring split exists (e.g. only one group), so there is
            # nothing to compare against; min() of an empty list would raise.
            return split
        new_metrics = [metric(new_split, df2) for new_split in new_splits]
        best_metric = min(new_metrics)
        if best_metric >= current_metric:
            return split
        # Move to the first neighbour that reaches the best score.
        split = new_splits[new_metrics.index(best_metric)]
        current_metric = best_metric

In [None]:
def calc_split_metric(split, df):
    """Score how unevenly ``split`` distributes the column totals of ``df``.

    The rows of ``df`` are summed per group, each column is converted to the
    share of the column total held by every group, and the spread between the
    largest and the smallest share is added up over all columns.  A value of
    0 means every group holds the same share of every column.
    """
    group_totals = df.groupby(split).sum()
    group_shares = group_totals / group_totals.sum()
    spread_per_column = group_shares.max() - group_shares.min()
    return spread_per_column.sum()

In [None]:
# Name of the column that will hold each region's group label.
group_col = "group"
# Guard against re-running this cell: once the label column is part of `df`
# it would be handed to `split_approx` together with the data columns.
assert group_col not in df.columns
# Search for a split of the regions into 3 groups scored by
# `calc_split_metric`.  The search starts from a random split and no seed is
# set in the visible cells, so the outcome can differ between runs.
df[group_col] = split_approx(df, 3, calc_split_metric)
# Column totals per group.
df.groupby(group_col).sum()

In [None]:
# Display the per-region frame, which now includes the group label column.
df

In [None]:
# Attach each region's group label to the raw rows and total the numeric
# columns per (group, time).  The label column is addressed through
# `group_col` instead of a hard-coded attribute name, and `numeric_only=True`
# keeps non-numeric columns out of the sum (pandas >= 2.0 would otherwise
# aggregate them too, and they could not be plotted below).
df_eval = (
    df_raw.join(df[group_col], on=region)
    .groupby([group_col, time])
    .sum(numeric_only=True)
)
# One chart per column: time on the x axis, one line per group.
for col in df_eval.columns:
    df_eval[col].unstack(group_col).plot(figsize=(16,6))
    plt.xticks(rotation=20)
    plt.title(col)
    plt.show()