# Pruebas con OBSERVATION_TABLE simulada

Aquí vamos a simular una pequeña base de datos con fechas de pacientes para ver si es más rápido trabajar con pyarrow o con pandas.

TO DO:
- Rehacer el método shift (v2) sin iterar sobre personas y sin volver a guardar los datos. Hay que basarse en el método de VISIT_OCCURRENCE.
- Try with [numba](https://pandas.pydata.org/docs/user_guide/enhancingperf.html#enhancingperf)

Creamos la base de datos de mentira

In [None]:
import numpy as np
import pandas as pd
import pyarrow as pa
from pyarrow import parquet

def create_sample_df(n:int = 1000,n_dates:int=50,
                     first_date:str='2020-01-01',
                     last_date:str='2023-01-01',
                     mean_duration_days:int=60,
                     std_duration_days:int=180)->pd.DataFrame:

    # == Parameters ==    
    np.random.seed(42)
    pd.options.mode.string_storage = "pyarrow"
    # Start date from which to start the dates
    first_date = pd.to_datetime(first_date)
    last_date = pd.to_datetime(last_date)
    max_days = (last_date-first_date).days

    # == Generate IDs randomly ==
    # -- Generate the Ids
    people = np.random.randint(10000000, 99999999 + 1, size=n)
    person_id = np.random.choice(people,n*n_dates)

    # == Generate random dates ==
    # Generate random integers for days and convert to timedelta
    random_days = np.random.randint(0, max_days, size=n*n_dates)
    # Create the columns
    observation_start_date = first_date + pd.to_timedelta(random_days, unit='D')
    # Generate a gaussian sample of dates
    random_days = np.random.normal(mean_duration_days, std_duration_days, size=n*n_dates)
    random_days = np.int32(random_days)
    observation_end_date = observation_start_date + pd.to_timedelta(random_days, unit='D')
    # Correct end_dates
    # => If they are smaller than start_date, take start_date
    observation_end_date = np.where(observation_end_date<observation_start_date,
                                    observation_start_date,observation_end_date)
    

    # == Generate the code ==
    period_type_concept_id = np.random.randint(1, 11, size=n*n_dates)

    # == Generate the dataframe ==
    df_raw = {'person_id':person_id,'observation_period_start_date':observation_start_date,
            'observation_period_end_date':observation_end_date,'period_type_concept_id':period_type_concept_id}
    return pd.DataFrame(df_raw)

## Comparación entre métodos

**06/09/2024:** Por ahora va ganando la implementación de pyarrow. 
* **pyarrow** Default hasta ahora.
    * => 475 ms ± 5.34 ms per loop (mean ± std. dev. of 10 runs, 5 loops each)
* **pandas v1:** Copia pyarrow pero en pandas
    * => 793 ms ± 8.65 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)
* **pandas v2** Shifting but iterating by person
    * => 3.82 s ± 19.6 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)
* **pandas v3** Shifting pero sin iterar por persona, operando siempre sobre el mismo dataframe. 
    * => 4 ms ± 129 µs per loop (mean ± std. dev. of 5 runs, 10 loops each)
* **Dask** 
    * => No ha habido manera de que eso lance. Se basa en usar shift, que es mucho más lento, así que tiene pinta de que no va a ser posible hacerlo por ahí.

Parece que lo lento no son los cálculos, sino el encadenar los dataframes o las tablas. Haciendo los shifts podemos hacer el cálculo 2 órdenes de magnitud más rápido.

### Usando tablas pyarrow
#### v1 - Original, iterando sobre personas
Nos traemos las funciones que usaremos.


In [None]:
import pyarrow as pa
import pyarrow.compute as pc
import numpy as np
import pandas as pd
# import dask
# from dask.distributed import Client


def create_uniform_int_array(
        length: int,
        value: int = 0) -> pa.array:
    """Build a pyarrow array of ``length`` entries all holding ``value``.

    Parameters
    ----------
    length : int
        Number of entries in the array.
    value : int, optional
        Fill value, by default 0.

    Returns
    -------
    pa.array
        The zero-filled base is created as int32; other values are
        obtained by adding ``value`` to it.
        NOTE(review): callers in this file pass pyarrow scalars as
        ``value``; the resulting dtype then depends on ``pc.add`` type
        promotion - confirm before relying on int32.
    """
    # numpy produces the zeros; pyarrow converts them to int32.
    base = pa.array(np.zeros(shape=length), pa.int32())
    # Zero needs no arithmetic; any other fill value is added on top.
    return base if value == 0 else pc.add(base, value)  # pylint: disable=E1101


def group_observation_dates(
        start_dates: pa.Array,
        end_dates: pa.Array,
        n_days: int,
        verbose: bool = False) -> tuple[pa.Array, pa.Array, None | pa.Table]:
    """Merge consecutive observation periods separated by short gaps.

    For every row the number of days between its end date and the next
    row's start date is computed. Wherever that gap is at least
    'n_days' a split is placed; rows between two splits are collapsed
    into a single period running from the first start date to the last
    end date of the run.

    Parameters
    ----------
    start_dates : pa.Array
        Array of start dates. Presumably sorted ascending (the error
        raised below suggests sorting) - verify against the caller.
    end_dates : pa.Array
        Array of end dates, row-aligned with 'start_dates'.
    n_days : int
        Minimum gap, in days, between an end date and the following
        start date for the two rows to stay in separate periods
        (the comparison is ``>=``).
    verbose : bool, optional
        When True, also build a debugging table and print the inputs
        and outputs if the sanity check fails, by default False

    Returns
    -------
    tuple[pa.Array, pa.Array, None | pa.Table]
        Always a 3-item tuple.
        First item is the reduced start dates.
        Second item is the reduced end dates.
        Third item is None if verbose=False; if verbose=True it is a
        table with columns 'start', 'end' and 'intervals' (days to the
        next start, NaN on the last row). Useful when verifying dates.
        NOTE(review): the table is passed through ``.to_pandas()``, so
        it is a pandas DataFrame rather than the annotated pa.Table.

    Raises
    ------
    AssertionError
        If any resulting end date is earlier than its corresponding
        start date.
    """
    # End dates without the last one
    # (the last end date has no following start date to compare to)
    from_dates = end_dates[:-1]
    # Start dates without the first one
    # (the first start date has no preceding end date to compare to)
    to_dates = start_dates[1:]

    # -- Days from each end date to the next row's start date
    intervals = pc.days_between(  # pylint: disable=E1101
        from_dates, to_dates).to_numpy()
    # Optional debugging table; intervals is one element short, so it is
    # padded with NaN to match the number of rows
    inner_table = None
    if verbose:
        inner_table = pa.Table.from_arrays(
            [start_dates, end_dates, pa.array(np.append(intervals, np.nan))],
            names=['start', 'end', 'intervals']).to_pandas()

    # Mark the gaps that are long enough to separate two periods
    filt = intervals >= n_days
    # => When filt is True at index 'idx', there are at least 'n_days'
    # days between the end date of row 'idx' and the start date of row
    # 'idx+1', i.e.:
    # (start_date[idx+1] - end_date[idx]).days >= n_days

    # No gap is long enough: everything collapses into one period that
    # goes from the first start date to the last end date
    if np.nansum(filt) == 0:
        # len(intervals) is the index of the last row
        idx_end_dates = np.array([len(intervals)])
        idx_start_dates = np.array([0])

    # Otherwise split at every marked gap
    else:
        # A period ends on each row followed by a long gap
        idx_end_dates = filt.nonzero()[0]
        # ... and the next period starts on the row right after it
        idx_start_dates = idx_end_dates+1
        # The last row always closes the final period
        idx_end_dates = np.append(idx_end_dates, len(intervals))
        # The first row always opens the first period
        idx_start_dates = np.append(0, idx_start_dates)

    # if verbose:
    #     print(f'{idx_start_dates=}')
    #     print(f'{idx_end_dates=}')

    # Sanity check: no period may end before it starts
    new_start = start_dates.take(idx_start_dates)
    new_end = end_dates.take(idx_end_dates)
    if pc.any(pc.less(new_end, new_start)).as_py():  # pylint: disable=E1101
        if verbose:
            print(f"{start_dates=}", f"{end_dates=}")
            print(f"{new_start=}", f"{new_end=}")
        raise AssertionError(
            'Some end dates happen before start dates. Try sorting the original data.')

    return (new_start, new_end, inner_table)


def group_person_dates(
        table_rare: pa.Table,
        person: str | int,
        n_days: int) -> pa.Table:
    """Collapse one person's observation periods into merged periods.

    The rows of 'table_rare' belonging to 'person' are selected, their
    dates are reduced with 'group_observation_dates' and a new table is
    built in which every merged period carries the person id and the
    most frequent period type of that person.

    Parameters
    ----------
    table_rare : pa.Table
        Table with the columns 'person_id',
        'observation_period_start_date', 'observation_period_end_date'
        and 'period_type_concept_id'.
    person : str | int
        Person id used to select the rows.
    n_days : int
        Gap threshold forwarded to 'group_observation_dates'.

    Returns
    -------
    pa.Table
        Table with the same four columns and one row per merged period.
    """
    column_names = ['person_id', 'observation_period_start_date',
                    'observation_period_end_date', 'period_type_concept_id']

    # Keep only the rows of the requested person
    person_mask = pc.is_in(table_rare['person_id'],  # pylint: disable=E1101
                           pa.array([person]))
    person_rows = table_rare.filter(person_mask)

    # Merge the periods that are close to each other
    merged_start, merged_end, _ = group_observation_dates(
        person_rows['observation_period_start_date'],
        person_rows['observation_period_end_date'],
        n_days, verbose=False)
    n_periods = len(merged_start)

    # Constant columns: the person id and the most common period type
    person_column = create_uniform_int_array(n_periods, value=person)
    top_period_type = pc.mode(  # pylint: disable=E1101
        person_rows['period_type_concept_id'])[0][0]
    period_type_column = create_uniform_int_array(n_periods,
                                                  value=top_period_type)

    return pa.Table.from_arrays(
        [person_column, merged_start, merged_end, period_type_column],
        names=column_names)

In [None]:
# Build the sample data sorted by person and start date, plus its
# pyarrow counterpart used by the benchmark below
n_days = 60
df_raw = create_sample_df().sort_values(
    ['person_id', 'observation_period_start_date'])
table_pa = pa.Table.from_pandas(df_raw)
def serial_grouping():
    """Run 'group_person_dates' once per person id, one after another.

    Reads the module-level 'table_pa' and 'n_days' and returns the list
    of per-person reduced tables (they are not concatenated).
    """
    return [group_person_dates(table_pa, person, n_days)
            for person in table_pa['person_id'].unique()]

%timeit -n 5 -r 10 serial_grouping()

#### v2 - Sin iterar sobre personas
Probamos ahora a intentar simplemente buscar los índices que nos convengan, emulando la idea de v3 con pandas.

In [None]:
# == NOT FINISHED == #
def serial_grouping_v2(table_pa, n_days):
    """Flag rows that start less than 'n_days' after the previous row ends.

    Prototype of the shift-based approach: no per-person loop is used.
    Each row is compared with the row right above it; a row is flagged
    when both rows belong to the same person and the gap between the
    previous end date and the current start date is under 'n_days'.

    Parameters
    ----------
    table_pa : pa.Table | pd.DataFrame
        Observation periods sorted by person and start date. Anything
        that is not a DataFrame is converted with ``.to_pandas()``.
    n_days : int
        Gap threshold in days.

    Returns
    -------
    pd.DataFrame
        The flagged rows only (``to_remove`` is True), including the
        helper columns 'previous_end_date', 'previous_interval',
        'idx_person', 'idx_interval' and 'to_remove'. The input is not
        modified.
    """
    lbl_0 = 'person_id'
    lbl_1 = 'observation_period_start_date'
    lbl_2 = 'observation_period_end_date'

    # Work on a private pandas copy of the argument (the previous code
    # read an undefined/global 'df_rare' instead of its own input)
    if isinstance(table_pa, pd.DataFrame):
        df_rare = table_pa.copy()
    else:
        df_rare = table_pa.to_pandas()
    df_rare = df_rare.reset_index(drop=True)

    # Gap between the previous row's end and this row's start. The first
    # row has no previous end (NaT); it is filled with a value far above
    # the threshold so it is never flagged.
    df_rare['previous_end_date'] = df_rare[lbl_2].shift(1)
    df_rare['previous_interval'] = (
        df_rare[lbl_1] - df_rare['previous_end_date']
    ).fillna(pd.Timedelta(n_days*10, 'D'))
    # Only compare rows of the same person
    df_rare['idx_person'] = df_rare[lbl_0] == df_rare[lbl_0].shift(1)
    # A gap under n_days means the row can be merged into the previous one
    df_rare['idx_interval'] = df_rare['previous_interval'] < pd.Timedelta(
        n_days, unit='D')
    df_rare['to_remove'] = df_rare['idx_interval'] & df_rare['idx_person']
    return df_rare[df_rare['to_remove']]

In [None]:
# Build the sample data sorted by person and start date, plus its
# pyarrow counterpart
n_days = 60
df_raw = create_sample_df().sort_values(
    ['person_id', 'observation_period_start_date'])
table_pa = pa.Table.from_pandas(df_raw)
%timeit -n 10 -r 5 serial_grouping_v2(table_pa,n_days)

### Intentamos paralelizar pyarrow con dask.delayed

**25/07/2024**: No hay manera de que funcione con una tabla de pyarrow.

In [None]:
# client = Client(n_workers=4)


# @dask.delayed
# def delayed_filter(table, person):
#     filt = pc.is_in(table['person_id'], pa.array([person]))
#     table_person = table.filter(filt)
#     return table_person


# @dask.delayed
# def delayed_group_observation_dates(start_dates, end_dates, n_days):
#     return group_observation_dates(start_dates, end_dates, n_days)


# def table_reduced_explicit():
#     table_reduced = []
#     for person in table_pa['person_id'].unique():
#         # ============================================ #
#         # Filter for the current person_id
#         table_person = delayed_filter(table_pa, person)
#         # Retrieve corresponding dates
#         start_dates = table_person['observation_period_start_date']
#         end_dates = table_person['observation_period_end_date']
#         # Group dates closer
#         start_dates, end_dates, _ = group_observation_dates(
#             start_dates, end_dates, n_days)
#         # Create person
#         person_id = create_uniform_int_array(len(start_dates),
#                                              value=person)
#         # Retrieve most common period type
#         period_type_concept_id = pc.mode(  # pylint: disable=E1101
#             table_person['period_type_concept_id'])[0][0]
#         period_type_concept_id = create_uniform_int_array(len(start_dates),
#                                                           value=period_type_concept_id)
#         table_person = pa.Table.from_arrays(
#             [person_id, start_dates, end_dates, period_type_concept_id],
#             names=['person_id', 'observation_period_start_date',
#                    'observation_period_end_date', 'period_type_concept_id'])
#         # ============================================ #

#         table_reduced.append(table_person)
#     return table_reduced

# table_reduced_explicit = pa.concat_tables(table_reduced_explicit)
# %timeit serial_grouping_explicit()

In [None]:
# results = table_reduced_explicit() # NO FUNCIONA!

### Usando pandas.dataframes

#### v1
Vamos a probar usando simplemente pandas. Hay que corregir la función para intentar usar el método shift, como se ha hecho con dask, en lugar de hacer indexación explícita. Quizá así es más rápido.

Redefinimos la función:

In [None]:
def group_observation_dates_df(
        start_dates: pd.Series,
        end_dates: pd.Series,
        n_days: int,
        verbose: bool = False) -> tuple[pd.Series, pd.Series]:
    """Merge consecutive observation periods separated by short gaps.

    For every row the number of days between its end date and the next
    row's start date is computed. Wherever that gap is at least
    'n_days' a split is placed; rows between two splits are collapsed
    into a single period running from the first start date to the last
    end date of the run.

    Parameters
    ----------
    start_dates : pd.Series
        Start dates, sorted ascending. A numpy datetime64 array works
        too (anything exposing a positional ``.take``).
    end_dates : pd.Series
        End dates, row-aligned with 'start_dates'.
    n_days : int
        Minimum gap, in days, between an end date and the following
        start date for the two rows to stay in separate periods
        (the comparison is ``>=``).
    verbose : bool, optional
        When True, print the inputs and outputs if the sanity check
        fails, by default False

    Returns
    -------
    tuple[pd.Series, pd.Series]
        Reduced start dates and reduced end dates, of the same
        container type as the inputs. Empty inputs are returned as-is.

    Raises
    ------
    AssertionError
        If any resulting end date is earlier than its corresponding
        start date.
    """
    # Work on plain arrays for the gap computation: rows are paired by
    # position, and Series arithmetic would align on index labels instead.
    starts = np.asarray(start_dates)
    ends = np.asarray(end_dates)

    # Nothing to merge
    if len(starts) == 0:
        return (start_dates, end_dates)

    # Days from each end date to the next row's start date. Dividing by
    # one day keeps this correct for any datetime64 resolution.
    gap_days = (starts[1:] - ends[:-1]) / np.timedelta64(1, 'D')

    # Index 'idx' is listed here when the gap between row 'idx' and row
    # 'idx+1' is long enough to separate two periods
    split_after = np.flatnonzero(gap_days >= n_days)
    # The first row always opens a period, and so does every row that
    # follows a long gap
    idx_start_dates = np.append(0, split_after + 1)
    # Every row followed by a long gap closes a period, and so does the
    # last row (whose index is len(gap_days))
    idx_end_dates = np.append(split_after, len(gap_days))

    # Sanity check: no period may end before it starts
    new_start = start_dates.take(idx_start_dates)
    new_end = end_dates.take(idx_end_dates)
    if np.any(np.asarray(new_end) < np.asarray(new_start)):
        if verbose:
            print(f"{start_dates=}", f"{end_dates=}")
            print(f"{new_start=}", f"{new_end=}")
        raise AssertionError(
            'Some end dates happen before start dates. Try sorting the original data.')

    return (new_start, new_end)

In [None]:
def serial_table_df_reduced(df_raw, n_days):
    """Reduce the observation periods person by person with pandas.

    For each distinct 'person_id' the person's rows are selected, their
    dates are merged with 'group_observation_dates_df' and a small
    dataframe is built with one row per merged period.

    Parameters
    ----------
    df_raw : pd.DataFrame
        Frame with the columns 'person_id',
        'observation_period_start_date', 'observation_period_end_date'
        and 'period_type_concept_id'.
    n_days : int
        Gap threshold forwarded to 'group_observation_dates_df'.

    Returns
    -------
    list[pd.DataFrame]
        One dataframe per person (they are not concatenated).
    """
    per_person_frames = []
    for person in df_raw['person_id'].unique():
        # Rows of the current person
        person_rows = df_raw[df_raw['person_id'] == person]

        # Merge close periods, working on the raw numpy values
        merged_start, merged_end = group_observation_dates_df(
            person_rows['observation_period_start_date'].values,
            person_rows['observation_period_end_date'].values,
            n_days)
        n_periods = len(merged_start)

        # Most common period type of this person
        top_period_type = person_rows['period_type_concept_id'].mode()[0]

        # The id and the period type are repeated for every merged period
        per_person_frames.append(pd.DataFrame({
            'person_id': np.array([person]*n_periods),
            'observation_period_start_date': merged_start,
            'observation_period_end_date': merged_end,
            'period_type_concept_id': np.array([top_period_type]*n_periods),
        }))
    return per_person_frames

In [None]:
# Build the sample data sorted by person and start date
n_days = 60
df_raw = create_sample_df().sort_values(
    ['person_id', 'observation_period_start_date'])
%timeit -n 10 -r 5 serial_table_df_reduced(df_raw,n_days)

#### v2

Vamos a probar haciendo los cálculos extendiendo los dataframes, como tuvimos que hacer con dask.

In [None]:
def serial_table_df_reduced_v2(df_raw, n_days):
    """Reduce the observation periods person by person using shifts.

    For each person, every row is compared with its neighbours:

    * a row opens a merged period when it is the person's first row or
      when it starts more than 'n_days' after the previous row ends;
    * a row closes a merged period when it is the person's last row or
      when the next row starts more than 'n_days' after it ends.

    Parameters
    ----------
    df_raw : pd.DataFrame
        Frame with the columns 'person_id',
        'observation_period_start_date', 'observation_period_end_date'
        and 'period_type_concept_id', sorted by person and start date.
    n_days : int
        Gap threshold in days. (It used to be ignored in favour of a
        hard-coded 10-day threshold.)

    Returns
    -------
    list[pd.DataFrame]
        One dataframe per person with the four columns above and one
        row per merged period. The input is not modified.
    """
    lbl_0 = 'person_id'
    lbl_1 = 'observation_period_start_date'
    lbl_2 = 'observation_period_end_date'
    lbl_3 = 'period_type_concept_id'
    threshold = pd.Timedelta(n_days, 'D')

    table_reduced = []
    for person in df_raw[lbl_0].unique():
        # Rows of the current person
        table_person = df_raw[df_raw[lbl_0] == person]
        starts = table_person[lbl_1]
        ends = table_person[lbl_2]

        # Gap to the previous row / to the next row. The first and last
        # rows have no neighbour (NaT compares as False), so they are
        # forced to open / close a period explicitly.
        is_new_start = (starts - ends.shift(1)) > threshold
        is_new_end = (starts.shift(-1) - ends) > threshold
        is_new_start.iloc[0] = True
        is_new_end.iloc[-1] = True

        # Both masks are driven by the same gaps, so they select the same
        # number of rows; the scalars are broadcast to that length.
        new_table_person = pd.DataFrame({
            lbl_0: person,
            lbl_1: starts[is_new_start].to_numpy(),
            lbl_2: ends[is_new_end].to_numpy(),
            # Most common period type of this person
            lbl_3: table_person[lbl_3].mode().values[0],
        })
        table_reduced.append(new_table_person)
    return table_reduced

In [None]:
# Enable copy-on-write before running the v2 benchmark
pd.options.mode.copy_on_write = True
# Build the sample data sorted by person and start date
n_days = 60
df_raw = create_sample_df().sort_values(
    ['person_id', 'observation_period_start_date'])
%timeit -n 10 -r 5 serial_table_df_reduced_v2(df_raw,n_days)

#### v3 (shift sin cribar por persona)


In [None]:
def serial_table_df_reduced_v3(df_raw, n_days):
    """Flag rows that start less than 'n_days' after the previous row ends.

    Shift-based approach without a per-person loop: each row is compared
    with the row right above it, and it is flagged when both rows belong
    to the same person and the gap between the previous end date and the
    current start date is under 'n_days'.

    Parameters
    ----------
    df_raw : pd.DataFrame
        Observation periods sorted by person and start date, with the
        columns 'person_id', 'observation_period_start_date' and
        'observation_period_end_date'.
    n_days : int
        Gap threshold in days.

    Returns
    -------
    pd.DataFrame
        The flagged rows only (``to_remove`` is True), including the
        helper columns 'previous_end_date', 'previous_interval',
        'idx_person', 'idx_interval' and 'to_remove'. The input is not
        modified.
    """
    lbl_0 = 'person_id'
    lbl_1 = 'observation_period_start_date'
    lbl_2 = 'observation_period_end_date'

    # Private copy with a positional index
    df_rare = df_raw.copy().reset_index(drop=True)

    # Gap between the previous row's end and this row's start (it used to
    # be computed with the operands swapped, which made nearly every gap
    # negative and therefore "small"). The first row has no previous end
    # (NaT); it is filled with a value far above the threshold.
    df_rare['previous_end_date'] = df_rare[lbl_2].shift(1)
    df_rare['previous_interval'] = (
        df_rare[lbl_1] - df_rare['previous_end_date']
    ).fillna(pd.Timedelta(n_days*10, 'D'))
    # Only compare rows of the same person
    df_rare['idx_person'] = df_rare[lbl_0] == df_rare[lbl_0].shift(1)
    # A gap under n_days means the row can be merged into the previous one
    df_rare['idx_interval'] = df_rare['previous_interval'] < pd.Timedelta(
        n_days, unit='D')
    df_rare['to_remove'] = df_rare['idx_interval'] & df_rare['idx_person']
    return df_rare[df_rare['to_remove']]

In [None]:
# Build the sample data sorted by person and start date
n_days = 60
df_raw = create_sample_df().sort_values(
    ['person_id', 'observation_period_start_date'])
%timeit -n 10 -r 5 serial_table_df_reduced_v3(df_raw,n_days)

### Eliminando overlapping rows

In [None]:
# Column labels shared by the overlap-removal helpers below
lbl_0, lbl_1, lbl_2, lbl_3 = (
    'person_id',
    'observation_period_start_date',
    'observation_period_end_date',
    'period_type_concept_id',
)

# Parameters of the simulated table
n = 50
n_days = 300
first_date = '2020-01-01'
last_date = '2022-01-01'
mean_duration_days = 60
std_duration_days = 365*3
# NOTE(review): 'n_days' fills the 'n_dates' slot of create_sample_df,
# i.e. it sets the row multiplier here - confirm this is intended.
df_raw = create_sample_df(n=n,
                          n_dates=n_days,
                          first_date=first_date,
                          last_date=last_date,
                          mean_duration_days=mean_duration_days,
                          std_duration_days=std_duration_days)

# Sort so that, within a person and start date, the longest period
# (latest end date) comes first
df_rare = df_raw.copy().sort_values(
    [lbl_0, lbl_1, lbl_2, lbl_3],
    ascending=[True, True, False, True])

def find_overlap_index(df: pd.DataFrame,
                       person_lbl,start_lbl,end_lbl) -> pd.Series:
    """Flag rows whose period is contained in the previous row's period.

    A row is flagged when, compared with the row right above it, the
    person is the same, its start date is not earlier and its end date
    is not later. The first row is never flagged.

    Parameters
    ----------
    df : pd.DataFrame
        Frame to inspect; rows are compared in their current order.
    person_lbl, start_lbl, end_lbl
        Labels of the person, start-date and end-date columns.

    Returns
    -------
    pd.Series
        Boolean mask aligned with 'df'.
    """
    person_col, start_col, end_col = df[person_lbl], df[start_lbl], df[end_lbl]
    # The row does not start before the previous one ...
    starts_inside = start_col >= start_col.shift(1)
    # ... and does not end after it ...
    ends_inside = end_col <= end_col.shift(1)
    # ... and both rows belong to the same person
    same_person = person_col == person_col.shift(1)
    # All three conditions together mean the row adds nothing new
    return starts_inside & ends_inside & same_person

def remove_all_overlap(df: pd.DataFrame,
                       counter_lim: int = 1000,
                       verbose: bool = False) -> pd.DataFrame:
    """Repeatedly drop rows contained in the previous row's period.

    'find_overlap_index' only compares each row with the one right above
    it, so removing rows can expose new contained rows. The search is
    therefore repeated until a pass flags nothing or the pass limit is
    reached. Column labels are read from the module-level 'lbl_0',
    'lbl_1' and 'lbl_2'.

    Parameters
    ----------
    df : pd.DataFrame
        Frame to clean; it is copied, not modified.
    counter_lim : int, optional
        Passes keep running while the pass counter is ``<= counter_lim``,
        by default 1000
    verbose : bool, optional
        Print the number of rows flagged on each pass together with the
        first flagged row and the row above it, by default False

    Returns
    -------
    pd.DataFrame
        Cleaned frame with a fresh 0..n-1 index.
    """
    cleaned = df.copy()
    preview_cols = [lbl_0, lbl_1, lbl_2]
    # Start above zero so that the first pass always runs
    n_flagged = 1
    n_passes = 0
    if verbose:
        print('Cleaning...')

    while (n_flagged > 0) and (n_passes <= counter_lim):
        flagged = find_overlap_index(cleaned, lbl_0, lbl_1, lbl_2)
        n_flagged = flagged.sum()
        n_passes += 1
        if verbose:
            print(f"{n_passes} => {n_flagged} rows removed. Example:")
            if n_flagged > 0:
                # Show the first flagged row and the row it was compared
                # with. Positions are used because on the first pass the
                # index still carries the caller's (possibly shuffled)
                # labels, so "label - 1" is not the previous row. A
                # flagged row always has a row above it.
                first_pos = int(flagged.to_numpy().argmax())
                print(cleaned.iloc[[first_pos - 1, first_pos]][preview_cols])
        # Drop the flagged rows and renumber for the next pass
        cleaned = cleaned.loc[~flagged].reset_index(drop=True)

    return cleaned.reset_index(drop=True)

%timeit -n 10 -r 5 remove_all_overlap(df_rare,verbose=False)

### Usando dask.dataframes
 
Vamos a intentar reproducir el código que hemos hecho con pyarrow usando dask.dataframes

In [None]:
# import dask.dataframe as dd
# n_days = 60

# # Primero leemos como dask.dataframe
# table_dd = dd.read_parquet('OBSERVATION_TABLE.parquet')
# table_dd = table_dd.sort_values('observation_period_start_date')
# table_dd.compute()

Nos encontramos varios problemas al intentar usar dask:
* Al cargar los dataframes de manera lazy, no podemos acceder fácilmente a las filas específicas.
    * Es decir, no podemos indexar por fila con iloc[], por ejemplo, por lo que calcular las diferencias entre fechas no se puede hacer al momento.
    * Vamos a tener que generar un dataframe previo modificado que tenga las fechas en el orden que queremos.

In [None]:
# # Sacamos los datos de una persona concreta para probar las nuevas columnas
# person = 62827729
# n_days = 10
# filt = table_dd['person_id'] == person
# table_person = table_dd[filt]
# table_person['previous_end_date'] = table_person['observation_period_end_date'].shift(
#     1)
# table_person['previous_interval'] = (table_person['observation_period_start_date'] -
#                                      table_person['previous_end_date']).fillna(pd.Timedelta(n_days*10, 'D'))
# table_person['next_start_date'] = table_person['observation_period_start_date'].shift(
#     -1)
# table_person['next_interval'] = (table_person['next_start_date'] -
#                                  table_person['observation_period_end_date']).fillna(pd.Timedelta(n_days*10, 'D'))
# table_person.compute()

En esta tabla podemos ver las relaciones de cada fila con las filas inmediatamente anteriores y posteriores. Se han calculado los intervalos de tiempo hasta las siguientes fechas, de modo que se pueda ver la distancia temporal entre ellas y compararla con el número de días mínimo requerido (*n_days*).

* Si *previous_interval* > *n_days* -> Entonces la distancia entre el final de la interacción anterior y el comienzo de la actual es significativa. 
    * La start_date de la actual fila debe guardarse como *new_start_date*.
* Si *next_interval* > *n_days* -> Entonces la distancia entre el final de la interacción actual y el comienzo de la siguiente es significativa.
    * La end_date de la actual fila debe guardarse como *new_end_date*.
* La primera y la última fecha siempre tendrán un nan, ya que no tienen con quien compararse. Hemos llenado los nan con n_days*10 para asegurar que siempre cumplen el criterio.

In [None]:
# # Hacemos el filtrado en función del número de días
# new_start_idx = table_person['previous_interval'] > pd.Timedelta(10, 'D')
# print(new_start_idx.compute(), '\n\n== ==\n')
# new_start_dates = table_person.loc[new_start_idx,
#                                    'observation_period_start_date']
# print(new_start_dates.compute())

In [None]:
# # Hacemos el filtrado en función del número de días
# new_end_idx = table_person['next_interval'] > pd.Timedelta(10, 'D')
# print(new_end_idx.compute(), '\n\n== ==\n')
# new_end_dates = table_person.loc[new_end_idx, 'observation_period_end_date']
# print(new_end_dates.compute())

In [None]:
# # Creamos un nuevo dataframe con esto y comprobamos que las distancias son correctas
# new_table_person = dd.concat([new_start_dates.reset_index(
#     drop=True), new_end_dates.reset_index(drop=True)], axis=1)
# check = (new_table_person.iloc[:, 1] -
#          new_table_person.iloc[:, 0]) < pd.Timedelta(0, 'D')
# check.all().compute()

In [None]:
# # Create person
# new_table_person['person_id'] = person
# new_table_person['period_type_concept_id'] = table_person['period_type_concept_id'].mode(
# ).values.compute()[0]
# new_table_person = new_table_person[['person_id', 'observation_period_start_date',
#                                      'observation_period_end_date', 'period_type_concept_id']]
# new_table_person.compute()

In [None]:
# import dask
# import dask.dataframe as dd


# def serial_table_dd_reduced():
#     table_reduced = []
#     for person in table_dd['person_id'].unique():
#         # Filter for the current person_id
#         filt = table_dd['person_id'] == person
#         table_person = table_dd[filt]

#         # Create necessary columns to do the calculations
#         table_person['previous_end_date'] = table_person['observation_period_end_date'].shift(
#             1)
#         table_person['previous_interval'] = (
#             table_person['observation_period_start_date']-table_person['previous_end_date']).fillna(pd.Timedelta(n_days*10, 'D'))
#         table_person['next_start_date'] = table_person['observation_period_start_date'].shift(
#             -1)
#         table_person['next_interval'] = (
#             table_person['next_start_date']-table_person['observation_period_end_date']).fillna(pd.Timedelta(n_days*10, 'D'))

#         # Filter out significant start_days
#         new_start_idx = table_person['previous_interval'] > pd.Timedelta(
#             10, 'D')
#         new_start_dates = table_person.loc[new_start_idx,
#                                            'observation_period_start_date']
#         # Filter out significant end_days
#         new_end_idx = table_person['next_interval'] > pd.Timedelta(10, 'D')
#         new_end_dates = table_person.loc[new_end_idx,
#                                          'observation_period_end_date']

#         # Make sure all end values are after start values
#         new_table_person = dd.concat([new_start_dates.reset_index(
#             drop=True), new_end_dates.reset_index(drop=True)], axis=1)
#         # check = (new_table_person.iloc[:,1]-new_table_person.iloc[:,0]) < pd.Timedelta(0,'D')
#         # if check.all().compute():
#         #     raise AssertionError(
#         #         'Some end dates happen before start dates. Try sorting the original data.')

#         # Create person
#         new_table_person['person_id'] = person
#         # Retrieve most common period type
#         new_table_person['period_type_concept_id'] = table_person['period_type_concept_id'].mode(
#         ).values.compute()[0]
#         new_table_person = new_table_person[['person_id', 'observation_period_start_date',
#                                              'observation_period_end_date', 'period_type_concept_id']]
#         table_reduced.append(new_table_person)
#     table_reduced = dd.concat(table_reduced)
#     return table_reduced

In [None]:
# results = serial_table_dd_reduced()

In [None]:
# table_reduced = results.compute()

In [None]:
# table_reduced.compute()

Esto "funciona", pero es lentísimo. Probablemente porque no está en absoluto optimizado y dask tiene que estar pegando trozos de la misma person_id. Habría que ver cómo decirle que debe trabajar con todos los trozos de un person_id de un tirón, sin hacer particiones.

Prueba a lanzarlo sólo para una persona, a ver qué pasa.

#### Intentamos paralelizar
¿Esto funciona?

In [None]:
# import dask.delayed
# import pandas as pd


# @dask.delayed
# def pick_person(table_dd, person):
#     # Filter for the current person_id
#     filt = table_dd['person_id'] == person
#     return table_dd[filt]


# @dask.delayed
# def compute_intervals(table_person):
#     tmp = table_person
#     # Create necessary columns to do the calculations
#     tmp['previous_end_date'] = tmp['observation_period_end_date'].shift(1)
#     tmp['previous_interval'] = (tmp['observation_period_start_date'] -
#                                 tmp['previous_end_date']).fillna(pd.Timedelta(n_days*10, 'D'))
#     tmp['next_start_date'] = tmp['observation_period_start_date'].shift(-1)
#     tmp['next_interval'] = (
#         tmp['next_start_date']-tmp['observation_period_end_date']).fillna(pd.Timedelta(n_days*10, 'D'))
#     return tmp


# @dask.delayed
# def compute_new_dates(table_person, n_days):
#     tmp = table_person
#     new_start_idx = tmp['previous_interval'] > pd.Timedelta(n_days, 'D')
#     new_start_idx = np.array(new_start_idx)
#     new_start_dates = tmp.loc[new_start_idx, 'observation_period_start_date']
#     # Filter out significant end_days
#     new_end_idx = tmp['next_interval'] > pd.Timedelta(n_days, 'D')
#     new_end_idx = np.array(new_end_idx)
#     new_end_dates = tmp.loc[new_end_idx.compute(),
#                             'observation_period_end_date']
#     return (new_start_dates, new_end_dates)
# # new_start_dates, new_end_dates = ompute_new_dates(table_person, n_days)


# # == Main Loop ==
# n_days = 10
# table_reduced = []
# for person in table_dd['person_id'].unique():
#     # Filter for the current person_id
#     table_person = pick_person(table_dd, person)
#     table_person = table_person.persist()

#     # Create necessary columns to do the calculations
#     table_person = compute_intervals(table_person)

#     # Filter out significant start_days
#     new_start_dates, new_end_dates = compute_new_dates(table_person, n_days)

#     # new_start_dates = dd.Series(new_start_dates.reset_index(drop=True))
#     # new_end_dates = new_end_dates.reset_index(drop=True)
#     # Make sure all end values are after start values
#     # new_table_person = dd.concat([new_start_dates, new_end_dates], axis=1)
#     # check = (new_table_person.iloc[:,1]-new_table_person.iloc[:,0]) < pd.Timedelta(0,'D')
#     # if check.all().compute():
#     #     raise AssertionError(
#     #         'Some end dates happen before start dates. Try sorting the original data.')

#     # # Create person
#     # new_table_person['person_id'] = person
#     # # Retrieve most common period type
#     # new_table_person['period_type_concept_id'] = 1
#     # new_table_person = new_table_person[['person_id', 'observation_period_start_date',
#     #                                      'observation_period_end_date', 'period_type_concept_id']]
#     # table_reduced.append(new_table_person)
# # table_reduced = dd.concat(table_reduced)