In [6]:
# %matplotlib
# %matplotlib inline
# %matplotlib notebook

import pandas as pd
import numpy as np
import random
import os
import warnings
from datetime import datetime, timedelta, timezone
warnings.simplefilter("ignore")
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.ticker import AutoMinorLocator
from matplotlib.ticker import FormatStrFormatter
import re
import math
from copy import deepcopy
from numba_stats import t
import scipy.stats as stats

# IMM
from gridmeter import IMM
from gridmeter import IMM_Settings

# Clustering
from gridmeter import Clustering
from gridmeter import Clustering_Settings

from IPython.display import Image, Markdown, display
# Interactive plotting mode: figures are drawn/updated without blocking on plt.show().
plt.ion()
# Large, high-resolution default figures for all plots in this notebook.
plt.rcParams['figure.figsize'] = [24, 16]
plt.rcParams['figure.dpi'] = 300

# Uncomment to show full DataFrames instead of pandas' truncated display.
# pd.set_option('display.max_rows', None)
# pd.set_option('display.max_columns', None)
# pd.set_option('display.max_colwidth', None)

# Re-import edited modules (e.g. gridmeter) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# Directory holding the cached example loadshape CSVs; change it here to point elsewhere.
EXAMPLE_DFS_DIR = "/app/.recurve_cache/clustering/example_dfs"

# Long-format loadshapes (columns id, hour, ls) for the treatment group and comparison pool.
df_ls_t = pd.read_csv(os.path.join(EXAMPLE_DFS_DIR, "df_ls_t.csv"))
df_ls_cp = pd.read_csv(os.path.join(EXAMPLE_DFS_DIR, "df_ls_cp.csv"))

In [26]:
# Test IMM: reshape both groups to wide format, then build the comparison group.

def _long_to_wide(df_long):
    """Reshape a long (id, hour, value) frame to one row per id and one column per hour."""
    wide = df_long.set_index(["id", "hour"]).unstack()
    # unstack leaves (value-column, hour) column labels; keep only the hour level
    wide.columns = wide.columns.droplevel(0)
    return wide


df_ls_t_mod = _long_to_wide(df_ls_t)
df_ls_cp_mod = _long_to_wide(df_ls_cp)

imm_settings = IMM_Settings()
df_cg, df_t_coeffs = (
    IMM(imm_settings)
    .get_comparison_group(df_ls_t_mod, df_ls_cp_mod)
)
df_cg

Unnamed: 0_level_0,treatment,distance,duplicated,cluster,weight
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
None-3234991905-3234991905,None-1094275585-1094275585,1.488844,False,0,1.0
None-1518448307-1518448307,None-1397301805-1397301805,4.454165,False,0,1.0
None-3517431310-3517431310,None-1432022910-1432022910,4.414845,False,0,1.0
None-1908503005-1908503005,None-1469355610-1469355610,1.487711,False,0,1.0
None-3529331605-3529331605,None-1504812305-1504812305,0.969864,False,0,1.0
...,...,...,...,...,...
None-1802814010-1802814010,None-5519977972-5519977972,80.664935,False,0,1.0
None-1646832205-1646832205,None-8098217928-8098217928,18.023014,False,0,1.0
None-0603795751-0603795751,None-8313277985-8313277985,9.958277,False,0,1.0
None-3831190605-3831190605,None-8860891437-8860891437,15.392707,False,0,1.0


In [None]:
from gridmeter._utils.calculate_distances import calculate_distances
from copy import deepcopy as copy

def TestDistanceMatching(
    df_ls_t,
    df_ls_c,
    n_matches_per_treatment=4,
    distance_metric="euclidean",
    allow_duplicate_match=True,
    replace_duplicate_method=None,  # reserved: only None supported when allow_duplicate_match=False
    max_distance_threshold=None,
    n_match_multiplier=None,
    n_meters_per_chunk=10000,
):
    """Match each treatment meter to its nearest comparison-pool meters.

    Parameters
    ----------
    df_ls_t : pandas.DataFrame
        Treatment loadshapes; one row per meter, indexed by meter id.
    df_ls_c : pandas.DataFrame
        Comparison-pool loadshapes with the same columns as ``df_ls_t``.
    n_matches_per_treatment : int, default 4
        Maximum number of comparison meters kept for each treatment meter.
    distance_metric : str, default "euclidean"
        Forwarded to ``calculate_distances``.
    allow_duplicate_match : bool, default True
        If True, one comparison meter may be matched to several treatment
        meters. If False, each comparison meter is kept only in its
        lowest-distance pair (greedy one-to-one). Some treatment meters may
        then receive fewer than ``n_matches_per_treatment`` matches.
    replace_duplicate_method : None
        Reserved. Any non-None value together with
        ``allow_duplicate_match=False`` raises ``NotImplementedError``.
    max_distance_threshold : float, optional
        If given, candidate pairs with a distance greater than this are dropped.
    n_match_multiplier : int, optional
        If None, ``None`` is passed to ``calculate_distances`` as the number of
        candidates (presumably "all comparison meters" -- confirm in gridmeter).
        Otherwise ``n_matches_per_treatment`` candidates are requested and
        multiplied by this factor when ``max_distance_threshold`` will discard
        some of them. The result is capped at the comparison-pool size.
    n_meters_per_chunk : int, default 10000
        Forwarded to ``calculate_distances``.

    Returns
    -------
    pandas.DataFrame
        Indexed by comparison meter id (``"id"``), with columns:

        - ``treatment``
        - ``distance``
        - ``duplicated``: True for every repeat appearance of a comparison id
          among the returned rows.
        - ``cluster``: always 1.

        Rows are ordered by treatment, then ascending distance.

    Raises
    ------
    NotImplementedError
        If ``replace_duplicate_method`` is set while duplicates are disallowed.
    """
    if not allow_duplicate_match and replace_duplicate_method is not None:
        # Fail fast, before the potentially expensive distance computation.
        raise NotImplementedError(
            "'replace_duplicate_method' is not implemented; pass None"
        )

    ls_t = df_ls_t.to_numpy()
    ls_cp = df_ls_c.to_numpy()

    # Number of nearest comparison meters requested per treatment meter.
    n_matches_per_chunk = None
    if n_match_multiplier is not None:
        n_matches_per_chunk = n_matches_per_treatment
        if max_distance_threshold is not None:
            # Over-fetch so enough candidates survive the distance filter below.
            n_matches_per_chunk *= n_match_multiplier
        n_matches_per_chunk = min(n_matches_per_chunk, ls_cp.shape[0])

    cp_id_idx, dist = calculate_distances(
        ls_t, ls_cp, distance_metric, n_matches_per_chunk, n_meters_per_chunk
    )
    cp_id_idx = np.asarray(cp_id_idx)
    dist = np.asarray(dist)

    # One row per (treatment, candidate) pair; each row of `dist` belongs to
    # one treatment meter, in df_ls_t order.
    clusters = pd.DataFrame(
        {
            "treatment": np.repeat(df_ls_t.index.values, dist.shape[1]),
            "id": df_ls_c.index.values[cp_id_idx.ravel()],
            "distance": dist.ravel(),
        }
    )

    if max_distance_threshold is not None:
        clusters = clusters[clusters["distance"] <= max_distance_threshold]

    if allow_duplicate_match:
        clusters = clusters.sort_values(by=["treatment", "distance"])
    else:
        # Greedy one-to-one: keep each comparison meter only in its closest
        # pair. The stable sort keeps tied pairs in their original row order.
        clusters = (
            clusters.sort_values(by="distance", kind="mergesort")
            .drop_duplicates(subset="id", keep="first")
            .sort_values(by=["treatment", "distance"])
        )

    # Keep the n closest remaining candidates for every treatment meter.
    clusters = clusters.groupby("treatment").head(n_matches_per_treatment)

    # Flag repeats only among the rows actually returned.
    clusters = clusters.assign(
        duplicated=clusters["id"].duplicated(keep="first"),
        cluster=1,
    )
    return clusters.set_index("id")


def get_comparison_group(df_ls_t, df_ls_cp, weights=None, **kwargs):
    """Build a distance-matched comparison group and trivial treatment coefficients.

    Parameters
    ----------
    df_ls_t : pandas.DataFrame
        Treatment loadshapes indexed by meter id.
    df_ls_cp : pandas.DataFrame
        Comparison-pool loadshapes.
    weights : object, optional
        Accepted for interface compatibility; not used.
    **kwargs
        Forwarded unchanged to ``TestDistanceMatching``.

    Returns
    -------
    tuple
        ``(df_cg, df_t_coeffs)``. ``df_cg`` is the matching result.
        ``df_t_coeffs`` has one row per unique treatment id (index named
        ``"id"``) and a single ``pct_cluster_1`` column of ones.
    """
    matched = TestDistanceMatching(df_ls_t, df_ls_cp, **kwargs)

    treatment_ids = df_ls_t.index.unique()
    # Every treatment meter belongs entirely to the single matching cluster.
    coeffs = pd.DataFrame(
        {"pct_cluster_1": np.ones(len(treatment_ids))},
        index=treatment_ids,
    ).rename_axis("id")

    return matched, coeffs

In [None]:
# Run the in-notebook matcher with duplicate matches allowed.
df_cg, df_t_coeffs = get_comparison_group(
    df_ls_t_mod,
    df_ls_cp_mod,
    allow_duplicate_match=True,
)
# Move the id index into a column so pairs sort by treatment, then comparison id.
df_cg.reset_index().sort_values(["treatment", "id", "distance"])

In [None]:
# Matched pairs grouped by treatment meter, closest candidate first.
df_cg.sort_values(by=["treatment", "distance"])

In [21]:
# Raw treatment loadshapes in long format: one row per (id, hour) with value column ls.
df_ls_t

Unnamed: 0,id,hour,ls
0,None-1094275585-1094275585,1,-0.004892
1,None-1094275585-1094275585,2,0.014424
2,None-1094275585-1094275585,3,0.024523
3,None-1094275585-1094275585,4,0.009783
4,None-1094275585-1094275585,5,0.002749
...,...,...,...
50395,None-9589493717-9589493717,500,0.029618
50396,None-9589493717-9589493717,501,0.001449
50397,None-9589493717-9589493717,502,0.007834
50398,None-9589493717-9589493717,503,0.007981


In [29]:
# Convert the wide treatment frame back to long format (id, hour, ls) to compare with df_ls_t.
df_ls_t_mod.stack().rename("ls").reset_index()

Unnamed: 0,id,hour,ls
0,None-1094275585-1094275585,1,-0.004892
1,None-1094275585-1094275585,2,0.014424
2,None-1094275585-1094275585,3,0.024523
3,None-1094275585-1094275585,4,0.009783
4,None-1094275585-1094275585,5,0.002749
...,...,...,...
50395,None-9589493717-9589493717,500,0.029618
50396,None-9589493717-9589493717,501,0.001449
50397,None-9589493717-9589493717,502,0.007834
50398,None-9589493717-9589493717,503,0.007981


In [39]:
# Test Clustering: build the comparison group with default clustering settings.

clustering_settings = Clustering_Settings()
df_cg, df_t_coeffs = (
    Clustering(clustering_settings)
    .get_comparison_group(df_ls_t_mod, df_ls_cp_mod)
)
df_cg

Unnamed: 0_level_0,cluster
id,Unnamed: 1_level_1
None-6500609087-6500609087,0
None-4394605605-4394605605,0
None-1572272305-1572272305,0
None-2615141910-2615141910,0
None-2615294205-2615294205,0
...,...
None-3590533705-3590533705,8
None-1721108410-1721108410,8
None-1709262405-1709262405,8
None-1851355305-1851355305,8


In [19]:
# Per-treatment cluster coefficients from whichever cell last assigned df_t_coeffs.
# NOTE(review): the saved output (execution count 19) predates the Clustering cell
# (count 39); re-run the notebook top to bottom before trusting it.
df_t_coeffs

Unnamed: 0_level_0,pct_cluster_1,pct_cluster_2,pct_cluster_3,pct_cluster_4,pct_cluster_5,pct_cluster_6,pct_cluster_7,pct_cluster_8
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
None-1094275585-1094275585,2.163893e-12,5.029705e-01,4.288347e-01,0.000000e+00,4.847305e-13,0.000000e+00,4.097588e-02,2.721895e-02
None-1397301805-1397301805,1.574652e-11,1.887717e-01,6.672347e-01,1.439936e-01,7.034535e-12,0.000000e+00,1.078794e-11,0.000000e+00
None-1432022910-1432022910,1.296811e-12,6.902494e-04,4.342816e-03,7.915817e-05,1.381838e-12,3.685489e-06,3.904629e-01,6.044212e-01
None-1469355610-1469355610,4.351317e-12,3.474865e-01,5.179979e-01,1.345155e-01,3.327867e-12,0.000000e+00,1.160953e-14,1.171110e-12
None-1504812305-1504812305,7.987435e-12,4.239205e-01,4.190504e-01,1.570292e-01,1.727044e-12,0.000000e+00,1.716741e-11,0.000000e+00
...,...,...,...,...,...,...,...,...
None-5519977972-5519977972,3.001672e-05,1.697287e-07,5.576015e-02,1.722308e-03,3.799684e-02,2.551940e-01,3.890770e-01,2.602195e-01
None-8098217928-8098217928,7.050122e-08,3.575080e-01,4.697889e-01,9.410056e-03,1.201133e-01,6.750067e-03,3.872386e-03,3.255729e-02
None-8313277985-8313277985,2.632192e-12,1.684967e-02,1.464463e-11,8.571292e-11,1.276320e-11,2.251083e-01,3.401503e-11,7.580420e-01
None-8860891437-8860891437,1.365416e-02,5.706510e-01,4.043674e-01,1.084731e-03,6.107133e-12,1.060991e-11,1.024268e-02,6.747586e-12
