# New observation_period_grouping

It turns out that the code that groups the dates in OBSERVATION_PERIOD is not entirely correct: it does not account for visits contained inside one another. We have already handled this for VISIT_OCCURRENCE. Since we have also found that everything is much (**MUCH**) faster when we do not rely on concatenating dataframes, we will take the opportunity to rewrite it.

First we have to remove the visits that are contained in a previous one. This was already done for VISIT_OCCURRENCE, but we will try to rewrite it as a recursive function.

Then, with the dataframe cleaned of overlapping visits, we will write another function that computes the distance between visits and removes those that are close together.

**12/09/2024** - The functions generated and described in this document have been moved to `ETL1_transform.general`. They replace the ones created for the OBSERVATION_PERIOD and VISIT_OCCURRENCE tables, so those have been deleted from their respective files.

# 1. Remove overlap

The problem is to identify whether, for a given patient, the system holds visits that overlap. To check this we need to sort the rows for each person.

We first sort by person_id in ascending order, so that all the interactions of a person are together. Second, by start_date in ascending order, so that the start dates are ordered in time. **Third, we sort by end_date, but in descending order**. This ensures that, in a series of visits that start on the same day, the first row is the longest one and therefore the most likely to encompass the others.

If the start_date of a row falls within the interval defined by the start_date and end_date of the previous row, the two rows overlap, so the later row can be removed or merged into the previous one.
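
A minimal pandas sketch of this sort and of the comparison with the previous row, on made-up data. The `visits` frame and the `overlaps_previous` column are illustrative names, not part of the ETL code. Looking only at the immediately previous row is a simplification of this sketch: the third row below is contained in the first one and is still not flagged, which is why the later sections compare every row against a main visit instead.

```python
import pandas as pd

# Made-up visits for two people (illustrative data only)
visits = pd.DataFrame(
    {
        "person_id": [1, 1, 1, 2],
        "start_date": pd.to_datetime(
            ["2020-01-01", "2020-01-01", "2020-01-10", "2020-01-05"]
        ),
        "end_date": pd.to_datetime(
            ["2020-01-03", "2020-01-20", "2020-01-12", "2020-01-06"]
        ),
    }
)

# person_id ascending, start_date ascending, end_date descending
visits = visits.sort_values(
    ["person_id", "start_date", "end_date"],
    ascending=[True, True, False],
).reset_index(drop=True)

# end_date of the previous row of the same person (NaT for the first row)
prev_end = visits.groupby("person_id")["end_date"].shift()

# A row overlaps the previous one when it starts before that row ends;
# comparisons against NaT evaluate to False
visits["overlaps_previous"] = visits["start_date"] <= prev_end
print(visits)
```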

The next key point is the hierarchy of the visits. If two visits overlap and one has a hospital visit code (code 8756) and the other a pharmacy prescription code (code 581458), the more general event is the hospital visit, so it is the one that must prevail when we have to choose which one to drop.
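
One way to express this hierarchy is a small rank table. The sketch below is only an illustration: `VISIT_PRIORITY` and `prevailing_concept` are hypothetical names, and the ranks cover just the two codes mentioned above.

```python
# Lower rank = more general event = higher priority (assumed ordering)
VISIT_PRIORITY = {
    8756: 0,  # hospital visit
    581458: 1,  # pharmacy prescription
}


def prevailing_concept(concept_a, concept_b):
    """Return the visit concept that should be kept when two visits overlap."""
    # Codes missing from the table get the lowest priority,
    # so a ranked code always wins against them
    lowest = len(VISIT_PRIORITY)
    return min(concept_a, concept_b, key=lambda c: VISIT_PRIORITY.get(c, lowest))


print(prevailing_concept(581458, 8756))
```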

## 1.1 Creating the test dataset
We will build a dataframe by hand that has every problem we may run into:
- Later dates completely contained in earlier dates
    * (2020-01-01, 2020-02-01) contains (2020-01-02, 2020-01-02) and (2020-01-04, 2020-01-04)
        - Here I want to delete the contained ones.
- Later dates that partially overlap earlier dates
    * (2020-03-01, 2020-04-01) partially contains (2020-03-15, 2020-04-15)
        - Here I want to combine the older one with the newer one => (2020-03-01, 2020-04-15).
- Make sure that data from two different people is not mixed.
    * Check that data with different person_id values is never combined.

In [None]:
import pandas as pd

nombre_columnas = [
    "person_id",
    "start_date",
    "end_date",
    "type_concept",
    "should_remain",
    "visit_concept_id",
    "provider_id",
]
filas = [
    # == Date problems ==
    # -- Main visit 1 --
    (
        1,  # Person_id
        "2020-01-01",  # start_date
        "2020-02-01",  # end_date
        1,  # type_concept
        True,  # should_remain
        9202,  # visit_concept_id
        0,  # provider_id
    ),
    # Fully overlapped, same type_concept
    (
        1,  # Person_id
        "2020-01-02",  # start_date
        "2020-01-02",  # end_date
        1,  # type_concept
        False,  # should_remain
        9202,  # visit_concept_id
        0,  # provider_id
    ),
    # Fully overlapped, different type_concept
    (
        1,  # Person_id
        "2020-01-04",  # start_date
        "2020-01-04",  # end_date
        2,  # type_concept
        False,  # should_remain
        9202,  # visit_concept_id
        0,  # provider_id
    ),
    # Partially overlapped, same type_concept
    (
        1,  # Person_id
        "2020-01-06",  # start_date
        "2020-02-06",  # end_date
        1,  # type_concept
        False,  # should_remain
        9202,  # visit_concept_id
        0,  # provider_id
    ),
    # Partially overlapped, different type_concept
    (
        1,  # Person_id
        "2020-01-08",  # start_date
        "2020-02-08",  # end_date
        2,  # type_concept
        False,  # should_remain
        9202,  # visit_concept_id
        0,  # provider_id
    ),
    # -- Main visit 2 --
    (
        1,  # Person_id
        "2020-03-01",  # start_date
        "2020-04-01",  # end_date
        1,  # type_concept
        True,  # should_remain
        9202,  # visit_concept_id
        0,  # provider_id
    ),
    # Fully overlapped, same type_concept
    (
        1,  # Person_id
        "2020-03-02",  # start_date
        "2020-03-02",  # end_date
        1,  # type_concept
        False,  # should_remain
        9202,  # visit_concept_id
        0,  # provider_id
    ),
    # Fully overlapped, different type_concept
    (
        1,  # Person_id
        "2020-03-04",  # start_date
        "2020-03-04",  # end_date
        2,  # type_concept
        False,  # should_remain
        9202,  # visit_concept_id
        0,  # provider_id
    ),
    # Partially overlapped, same type_concept
    (
        1,  # Person_id
        "2020-03-06",  # start_date
        "2020-04-06",  # end_date
        1,  # type_concept
        False,  # should_remain
        9202,  # visit_concept_id
        0,  # provider_id
    ),
    # Partially overlapped, different type_concept
    (
        1,  # Person_id
        "2020-03-08",  # start_date
        "2020-04-08",  # end_date
        2,  # type_concept
        False,  # should_remain
        9202,  # visit_concept_id
        0,  # provider_id
    ),
    # == person_id problems ==
    # Three different people
    (
        2,  # Person_id
        "2021-01-01",  # start_date
        "2021-01-01",  # end_date
        1,  # type_concept
        True,  # should_remain
        9202,  # visit_concept_id
        0,  # provider_id
    ),
    (
        2,  # Person_id
        "2021-02-01",  # start_date
        "2021-02-01",  # end_date
        1,  # type_concept
        True,  # should_remain
        9202,  # visit_concept_id
        0,  # provider_id
    ),
    # Shares its date with a row from person 2
    (
        3,  # Person_id
        "2021-02-01",  # start_date
        "2021-02-01",  # end_date
        2,  # type_concept
        True,  # should_remain
        9202,  # visit_concept_id
        0,  # provider_id
    ),
    # Does not share its date with any other person
    (
        3,  # Person_id
        "2021-03-01",  # start_date
        "2021-03-01",  # end_date
        2,  # type_concept
        True,  # should_remain
        9202,  # visit_concept_id
        0,  # provider_id
    ),
    # == type_concept problems ==
    (
        4,  # Person_id
        "2022-03-01",  # start_date
        "2022-04-01",  # end_date
        1,  # type_concept
        True,  # should_remain
        9202,  # visit_concept_id
        0,  # provider_id
    ),
    # Same person and dates, only type_concept == 1 should remain
    (
        4,  # Person_id
        "2022-03-01",  # start_date
        "2022-04-01",  # end_date
        2,  # type_concept
        False,  # should_remain
        9202,  # visit_concept_id
        0,  # provider_id
    ),
]
df_raw = pd.DataFrame.from_records(filas, columns=nombre_columnas)
df_raw["start_date"] = pd.to_datetime(df_raw["start_date"])
df_raw["end_date"] = pd.to_datetime(df_raw["end_date"])
# info() prints its summary itself and returns None, so it needs no print()
df_raw.info()
df_raw

In [None]:
df = df_raw.copy()
df.groupby("person_id")["type_concept"].first()

## 1.2 Remove overlap
The idea behind this code is straightforward. We start from a dataframe with the columns `person_id`, `start_date`, `end_date`, `type_concept`. It is sorted in the following order:
1. person_id, ascending
2. start_date, ascending
3. end_date, descending
4. type_concept, ascending

* This way we make sure that, for each `start_date`, the first row has the furthest `end_date`, which is the one that can contain the other rows with the same `start_date`.
* The type_concept column must first be converted to a categorical type with a predefined order, so that it can be sorted. This order represents the priority of the code. When everything else is equal, the row that remains is the one that sits higher up.
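
A short pandas sketch of the categorical step, assuming made-up labels. `"visit"` and `"prescription"` are hypothetical stand-ins for the real type_concept codes, chosen so that the priority order differs from the alphabetical one. With an ordered `pd.Categorical`, `sort_values` follows the declared category order, so the highest-priority row ends up first among otherwise identical rows.

```python
import pandas as pd

# Hypothetical priority: earlier in the list means higher priority
type_concept_priority = ["visit", "prescription"]

df_example = pd.DataFrame(
    {
        "person_id": [4, 4, 1],
        "start_date": pd.to_datetime(["2022-03-01", "2022-03-01", "2020-01-01"]),
        "end_date": pd.to_datetime(["2022-04-01", "2022-04-01", "2020-02-01"]),
        "type_concept": ["prescription", "visit", "visit"],
    }
)

# Ordered categorical: sorting uses the category order, not the alphabet
df_example["type_concept"] = pd.Categorical(
    df_example["type_concept"], categories=type_concept_priority, ordered=True
)

df_sorted = df_example.sort_values(
    ["person_id", "start_date", "end_date", "type_concept"],
    ascending=[True, True, False, True],
).reset_index(drop=True)

# When person and dates are equal, keep only the row that sits higher up
kept = df_sorted.drop_duplicates(["person_id", "start_date", "end_date"], keep="first")
print(kept)
```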

### Polars approach

It may be possible to do the whole process by working with columns. If we also do it with polars, it will be considerably faster.

The idea is to build the VISIT_DETAIL and VISIT_OCCURRENCE tables in a single dataframe.

Imagine this as a visit in a timeline:

```python
-(- - -)- - - - - - - -
 |     |
 |     |-> End of the visit
 |
 |-> Start of the visit
```
We want to find the **main visits**:

- A main visit contains other visits.
- Main visits can be adjacent to each other, i.e. the end of the 1st can be the start of the 2nd.
- Main visits cannot overlap each other, i.e. the end of the 1st cannot be after the start of the 2nd.
- Only main visits can populate the VISIT_OCCURRENCE table.
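
These rules can be turned into a small check to run on any candidate list of main visits. This is a plain-Python sketch; `main_visits_are_valid` is a hypothetical helper that takes `(person_id, start, end)` tuples and is not part of the ETL code.

```python
from datetime import date


def main_visits_are_valid(visits):
    """Check that no two main visits of the same person overlap."""
    # Sorting the tuples orders them by person_id, then start, then end
    ordered = sorted(visits)
    for (p_prev, _, end_prev), (p_next, start_next, _) in zip(ordered, ordered[1:]):
        # Adjacent visits (end == next start) are allowed, overlaps are not
        if p_prev == p_next and end_prev > start_next:
            return False
    return True


adjacent = [
    (1, date(2020, 1, 1), date(2020, 2, 1)),
    (1, date(2020, 2, 1), date(2020, 3, 1)),
]
overlapping = [
    (1, date(2020, 1, 1), date(2020, 2, 1)),
    (1, date(2020, 1, 15), date(2020, 3, 1)),
]
print(main_visits_are_valid(adjacent), main_visits_are_valid(overlapping))
```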

If we sort the visits by person_id (asc), start_date (asc), end_date (desc) and type_concept (asc), we will have something like this for each person:
```python
-(- - -)- - - - - - - - 
-(- -)- - - - - - - - -
-(-)- - - - - - - - - -
- -(- - -)- - - - - - -
- -(- -)- - - - - - - -
- -(-)- - - - - - - - - 
- - -(- - -)- - - - - -
- - -(- -)- - - - - - -
- - -(-)- - - - - - - -
- - - - - -(- - -)- - - 
- - - - - -(- -)- - - -
- - - - - -(-)- - - - -
```

If we compare each visit with the **FIRST ONE**, there are different cases here:
1. **COMPLETELY CONTAINED VISITS**: The visit is completely contained in the first visit we are considering (this includes consecutive single-day visits).
2. **PARTIALLY CONTAINED VISITS**: The visit starts after the start of the 1st and before its end, but ends after the end of the 1st.
    - It starts afterwards, but extends further into the future.
3. **NOT CONTAINED VISITS**: The visit starts after the end of the 1st. It is a "new" main visit.

```python
-(- - -)- - - - - - - - # This is a main visit
-(- -)- - - - - - - - - 1. contained
-(-)- - - - - - - - - - 1. contained
- -(- - -)- - - - - - - 2. partial
- -(- -)- - - - - - - - 1. contained
- -(-)- - - - - - - - - 1. contained
- - -(- - -)- - - - - - 2. partial
- - -(- -)- - - - - - - 2. partial
- - -(-)- - - - - - - - 1. contained
- - - - - -(- - -)- - - # This is the next main visit
- - - - - -(- -)- - - - 1. contained
- - - - - -(-)- - - - - 1. contained
```

Completely contained visits are the easy ones. We can link them to the main visit and remove them.

Partially contained visits are problematic. These will force us to extend our initial visit further into the future, essentially creating a new record.
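
The three cases can be written down as one small function. This is a sketch in plain Python rather than the polars code used below: `classify_visit` is a hypothetical helper, and it assumes the visit does not start before the main visit, which the sort order guarantees.

```python
from datetime import date


def classify_visit(main_start, main_end, start, end):
    """Classify a visit against the main visit it is compared with."""
    if main_start <= start and end <= main_end:
        return "contained"
    if start <= main_end < end:
        # Starts inside the main visit but ends after it
        return "partial"
    return "not_contained"


main = (date(2020, 1, 1), date(2020, 2, 1))
labels = [
    classify_visit(*main, date(2020, 1, 2), date(2020, 1, 2)),
    classify_visit(*main, date(2020, 1, 6), date(2020, 2, 6)),
    classify_visit(*main, date(2020, 3, 1), date(2020, 4, 1)),
]
print(labels)
```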




We start by joining all the files. Doing so gives us essentially the VISIT_DETAIL table.

With this table we can:

1. Build the skeleton of VISIT_DETAIL
   1. Sort by person_id (asc), start_date (asc), end_date (desc) and type_concept (asc).
   2. Generate the visit_detail_id field for all the visits.
   3. Rename columns to match the VISIT_DETAIL format.

With this we have the skeleton of the VISIT_DETAIL table. From here we can start building VISIT_OCCURRENCE on top of VISIT_DETAIL.

2. Extend the table to prepare it for generating VISIT_OCCURRENCE.
   1. Generate placeholders for the future VISIT_OCCURRENCE columns and for the temporary columns used while generating them.
      1. Group by patient and generate:
        - **visit_start_datetime**: Initially the first record of **visit_detail_start_datetime**. It will become the **visit_start_datetime** column of VISIT_OCCURRENCE.
        - **visit_end_datetime**: Initially the first record of **visit_detail_end_datetime**. It will become the **visit_end_datetime** column of VISIT_OCCURRENCE. If visits have to be merged, this column will hold the end_date of the last merged visit.
        - **visit_detail_id_original**: Initially the first record of **visit_detail_id**. It holds the **visit_detail_id** of the main visit each row is compared against.
        - **main_visit**: Initially a constant *"Unknown"* value for every row. We will update it to tell the rows that are a main_visit (changed to *"Yes"*) from the ones that are not (changed to *"No"*).
      2. Join with VISIT_DETAIL so that these fields can be compared row by row with every visit of each patient.
   2. Generate a *main_visit* column.
      - The *main_visit* rows are the ones that encompass another set of visits, i.e. the rows that will end up in the VISIT_OCCURRENCE table.
      - This field has 3 possible values:
         - *Unknown:* It has not been evaluated whether the row is a **main_visit**
         - *Yes:* It has been evaluated and it IS a **main_visit**
         - *No:* It has been evaluated and it is NOT a **main_visit**
3. Identify the first **main_visit** rows
   1. They are the first rows (in sort order) that satisfy:
      1. they are still *"Unknown"*
      2. **visit_detail_id** == **visit_detail_id_original**
4. Check whether the rest of the visits are:
   1. Completely contained.
      - **main_visit** is changed to *"No"*.
      - The **parent_visit_detail_id** column is set to the **visit_detail_id** of the main visit.
   2. Partially contained.
      - **main_visit** is changed to *"No"*.
      - The **visit_end_datetime** column is extended to the latest **visit_detail_end_datetime** value.
   3. Not contained
      - These rows have to be analysed again. Some of them may be a **main_visit**.
      1. For the not contained rows, group by patient and take the **visit_detail_start_datetime**, **visit_detail_end_datetime** and **visit_detail_id** columns of the first record.
      2. Join with the previous table, regenerating the **visit_start_datetime**, **visit_end_datetime** and **visit_detail_id_original** columns.
5. Go back to step 3 and repeat until no rows have **main_visit** == *"Unknown"* or a preset number of iterations is exceeded.
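
For small inputs, the outcome of these steps can be cross-checked against a plain-Python sweep over the sorted rows. The sketch below is the classic interval-merging loop, not the polars implementation of this notebook. `merge_into_main_visits` is a hypothetical helper, it only tracks dates (no type_concept or extra columns), and it compares each row against the already extended end of the current main visit.

```python
from datetime import date


def merge_into_main_visits(rows):
    """Merge (person_id, start, end) tuples into non-overlapping main visits."""
    # Same ordering as the notebook: person asc, start asc, end desc
    ordered = sorted(rows, key=lambda r: (r[0], r[1], -r[2].toordinal()))
    merged = []
    for person_id, start, end in ordered:
        current = merged[-1] if merged else None
        if current and current[0] == person_id and start <= current[2]:
            # Contained or partially contained: extend the end only if needed
            current[2] = max(current[2], end)
        else:
            # Not contained (or a new person): this row opens a new main visit
            merged.append([person_id, start, end])
    return [tuple(visit) for visit in merged]


rows = [
    (1, date(2020, 1, 1), date(2020, 2, 1)),
    (1, date(2020, 1, 2), date(2020, 1, 2)),
    (1, date(2020, 1, 8), date(2020, 2, 8)),
    (1, date(2020, 3, 1), date(2020, 4, 1)),
    (2, date(2021, 1, 1), date(2021, 1, 1)),
]
main_visits = merge_into_main_visits(rows)
print(main_visits)
```

A loop like this is too slow for the real tables, but it is handy as an oracle in unit tests for the vectorized version.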


It is important to note that every additional column that previously existed in VISIT_DETAIL will be carried to VISIT_OCCURRENCE. Due to the initial sorting, only the values of the first record at every grouping step (steps 3 and 4) will remain.

In [None]:
import polars as pl
import numpy as np

pl.Config(
    tbl_formatting="MARKDOWN",
    tbl_rows=20,
    set_tbl_width_chars=400,
    set_tbl_cols=-1,
)

df_raw_pl = pl.DataFrame(df_raw)
print(df_raw_pl)

In [None]:
def build_visit_detail(df):
    visit_detail = df
    # First we do the sorting
    visit_detail = visit_detail.sort(
        ["person_id", "start_date", "end_date", "type_concept"],
        descending=[False, False, True, False],
    )
    # Assign the visit_detail_id
    visit_detail = visit_detail.with_columns(visit_detail_id=np.arange(df.shape[0]))
    # Rename columns
    visit_detail = visit_detail.rename(
        {
            "start_date": "visit_detail_start_datetime",
            "end_date": "visit_detail_end_datetime",
            "type_concept": "visit_detail_type_concept_id",
        }
    )
    return visit_detail


visit_detail = build_visit_detail(df_raw_pl)
print(visit_detail)

In [None]:
def build_visit_detail_extended(visit_detail):

    # Get the first date of every person and
    # create a new dataframe with person_id and the first_end_date
    visit_occurrence_dates = (
        visit_detail.group_by("person_id")
        .first()
        .select(
            "person_id",
            "visit_detail_start_datetime",
            "visit_detail_end_datetime",
            "visit_detail_id",
        )
    )
    visit_occurrence_dates = visit_occurrence_dates.rename(
        {
            "visit_detail_start_datetime": "visit_start_datetime",
            "visit_detail_end_datetime": "visit_end_datetime",
            "visit_detail_id": "visit_detail_id_original",
        }
    )
    # Join with the first_end_date dataframe
    # This is to compare each date of each person_id with the first_date of that person_id
    df = visit_detail.join(visit_occurrence_dates, on="person_id", how="left")
    # Create a column to mark whether the row is a main_visit
    df = df.with_columns(main_visit=pl.lit("Unknown").cast(pl.Enum(["Yes", "No", "Unknown"])))

    return df


df = build_visit_detail_extended(visit_detail)
print(df)

This lets us identify the first batch of main visits: those that satisfy `visit_detail_id_original == visit_detail_id`.

In [None]:
# Identify main_visits
def identify_next_main_visits(df):
    df = df.with_columns(
        main_visit=(
            pl.when(
                (pl.col("visit_detail_id_original") == pl.col("visit_detail_id"))
                & (pl.col("main_visit") == "Unknown")
            )
            .then(pl.lit("Yes"))
            .otherwise(pl.col("main_visit"))
        )
    )
    return df


df = identify_next_main_visits(df)
print(df)

In [None]:
def identify_contained_rows(df):
    if "is_contained" not in df.columns:
        df = df.with_columns(is_contained=pl.lit(False))

    df = df.with_columns(
        is_contained=pl.when(
            (pl.col("main_visit") == "Unknown")
            & (pl.col("visit_start_datetime") <= pl.col("visit_detail_start_datetime"))
            & (pl.col("visit_end_datetime") >= pl.col("visit_detail_end_datetime"))
        )
        .then(True)
        .otherwise(pl.col("is_contained"))
    )
    return df


def identify_partial_rows(df):
    if "is_partial" not in df.columns:
        df = df.with_columns(is_partial=pl.lit(False))

    df = df.with_columns(
        is_partial=pl.when(
            (pl.col("main_visit") == "Unknown")
            & (pl.col("visit_detail_start_datetime") <= pl.col("visit_end_datetime"))
            & (pl.col("visit_detail_end_datetime") > pl.col("visit_end_datetime"))
        )
        .then(True)
        .otherwise(pl.col("is_partial"))
    )
    return df


def identify_not_contained_rows(df):
    if "not_contained" not in df.columns:
        df = df.with_columns(not_contained=pl.lit(False))

    df = df.with_columns(
        not_contained=pl.when(
            (pl.col("main_visit") == "Unknown")
            & (pl.col("visit_detail_start_datetime") >= pl.col("visit_end_datetime"))
        )
        .then(True)
        .otherwise(pl.col("not_contained"))
    )
    return df


df = identify_contained_rows(df)
df = identify_partial_rows(df)
df = identify_not_contained_rows(df)
print(df)

- If `is_contained == True`, the visit is contained in the one we are currently treating as the main one.
  - We can assign `parent_visit_detail_id = visit_detail_id_original`
  - These rows do not affect VISIT_OCCURRENCE. We keep `visit_occurrence_id = False`


In [None]:
def update_contained_rows(df):

    if "parent_visit_detail_id" not in df.columns:
        df = df.with_columns(parent_visit_detail_id=pl.lit(None))

    df = df.with_columns(
        # Update main_visit to mark contained visits
        main_visit=(
            pl.when((pl.col("is_contained") == True))
            .then(pl.lit("No"))
            .otherwise(pl.col("main_visit"))
        ),
        # Build the parent_visit_detail_id, since we are here
        parent_visit_detail_id=(
            pl.when(
                (pl.col("is_contained") == True)
                & (pl.col("visit_detail_id") != pl.col("visit_detail_id_original"))
            )
            .then(pl.col("visit_detail_id_original"))
            .otherwise(pl.col("parent_visit_detail_id"))
        ),
    )
    return df


df = update_contained_rows(df)
print(df)

- If `is_partial == True`, the visit is partially contained in the one we are currently treating as the main one.
  - We have to replace the current `visit_end_datetime` with its `visit_detail_end_datetime`.
  - This can happen in several rows, so we have to take the latest one.

In [None]:
# Retrieve the latest visit_detail_end_datetime from the partially contained visits
# of the current main visit. Rows flagged in earlier iterations keep
# is_partial == True, so we only look at the ones that are still "Unknown"
def retrieve_latest_date(df):
    latest_date = (
        df.filter((pl.col("is_partial") == True) & (pl.col("main_visit") == "Unknown"))
        .group_by("visit_detail_id_original")
        .agg(pl.col("visit_detail_end_datetime").max().alias("latest_end_datetime"))
        # Rename the key so it can be joined against the main visit's own id
        .rename({"visit_detail_id_original": "visit_detail_id"})
    )
    return latest_date


latest_date = retrieve_latest_date(df)
print(latest_date)

In [None]:
# Join back to the main dataframe and update visit_end_datetime
def update_partial_rows(df):

    latest_date = retrieve_latest_date(df)

    df = (
        df.join(
            latest_date,
            on="visit_detail_id",
            how="left",
        )
        .with_columns(
            main_visit=(
                pl.when((pl.col("is_partial") == True))
                .then(pl.lit("No"))
                .otherwise(pl.col("main_visit"))
            ),
            # Only the main visit that the partial rows point to gets a non-null
            # latest_end_datetime; every other row keeps its current value, so
            # main visits closed in earlier iterations are not overwritten
            visit_end_datetime=pl.coalesce(
                [pl.col("latest_end_datetime"), pl.col("visit_end_datetime")]
            ),
        )
        .drop("latest_end_datetime")
    )
    return df


df = update_partial_rows(df)
print(df)

- If `not_contained == True`, the visits are not contained in the visit we were considering.
  - We have to start again using the first of these `not_contained` visits.
  - This can happen in several rows, so we have to take the first one.

In [None]:
# Retrieve the first not contained visit of each person. Flags from earlier
# iterations stay True, so only rows that are still "Unknown" are considered;
# otherwise an already processed main visit would be picked again
def retrieve_newest_not_contained(df):
    newest_not_contained = (
        df.filter((pl.col("not_contained") == True) & (pl.col("main_visit") == "Unknown"))
        .group_by("person_id")
        .agg(
            pl.col("visit_detail_id").first().alias("visit_detail_id_newest"),
            pl.col("visit_detail_start_datetime")
            .first()
            .alias("visit_detail_start_datetime_newest"),
            pl.col("visit_detail_end_datetime")
            .first()
            .alias("visit_detail_end_datetime_newest"),
        )
    )
    return newest_not_contained


newest_not_contained = retrieve_newest_not_contained(df)
print(newest_not_contained)

In [None]:
def update_not_contained_rows(df):

    newest_not_contained = retrieve_newest_not_contained(df)

    # Join back to the main dataframe and update visit_end_datetime
    df = df.join(
        newest_not_contained,
        on="person_id",
        how="left",
    )
    # Update the values on visit_detail_id_original, visit_start_datetime and visit_end_datetime
    df = df.with_columns(
        visit_detail_id_original=pl.when((pl.col("main_visit") == "Unknown"))
        .then(pl.col("visit_detail_id_newest"))
        .otherwise(pl.col("visit_detail_id_original")),
        visit_start_datetime=pl.when((pl.col("main_visit") == "Unknown"))
        .then(pl.col("visit_detail_start_datetime_newest"))
        .otherwise(pl.col("visit_start_datetime")),
        visit_end_datetime=pl.when((pl.col("main_visit") == "Unknown"))
        .then(pl.col("visit_detail_end_datetime_newest"))
        .otherwise(pl.col("visit_end_datetime")),
    ).drop(
        pl.col(
            "visit_detail_id_newest",
            "visit_detail_start_datetime_newest",
            "visit_detail_end_datetime_newest",
        )
    )
    return df


df = update_not_contained_rows(df)
print(df)

In [None]:
# Identify main_visits
df = identify_next_main_visits(df)
# Now, we can identify all the completely contained cases
df = identify_contained_rows(df)
df = identify_partial_rows(df)
df = identify_not_contained_rows(df)
# Update completely contained rows
df = update_contained_rows(df)
# Update partially contained rows
df = update_partial_rows(df)
# Update not contained rows
df = update_not_contained_rows(df)
print(df)

We need to assign a unique visit_occurrence_id only for main visits.

This function does that:

In [None]:
def assign_visit_occurrence_id(visit_occurrence):
    # Create a helper column to track main visit sequence
    visit_occurrence = visit_occurrence.with_columns(
        is_main_visit=(pl.col("main_visit") == "Yes")
    )

    # Assign a unique identifier only to main visits using row_number and clean the helper
    visit_occurrence = visit_occurrence.with_columns(
        visit_occurrence_id=pl.when(pl.col("is_main_visit"))
        .then(pl.col("is_main_visit").cast(pl.Int32).cum_sum() - 1)
        .otherwise(None)
    ).drop("is_main_visit")

    # Fill the rest using forward fill (ffill)
    visit_occurrence = visit_occurrence.with_columns(
        visit_occurrence_id=pl.col("visit_occurrence_id").forward_fill()
    )

    return visit_occurrence

visit_occurrence = assign_visit_occurrence_id(df)
visit_occurrence

We need to build the VISIT_DETAIL and VISIT_OCCURRENCE tables by removing all the helper columns we have created.

In [None]:
def build_visit_occurrence(df, verbose=0):
    # -- Initialization --
    # Get the core of the visit_detail table
    df = build_visit_detail(df)
    # Extend the table for processing
    df = build_visit_detail_extended(df)
    # Initialize counters for the while loop
    n_unknown = df.filter(pl.col("main_visit") == "Unknown").height
    n_iter = 0

    # -- Loop --
    while n_unknown > 0 and n_iter < 10:
        if verbose > 0:
            print(f"Iter {n_iter:>2}: {n_unknown} unknown rows left.")
        if verbose > 1:
            print(df.filter(pl.col("main_visit") == "Unknown").head(10))

        # Look for next batch of main_visits
        df = identify_next_main_visits(df)

        # Identify and mark completely contained visits
        df = identify_contained_rows(df)
        df = update_contained_rows(df)

        # Identify and mark partially contained visits
        df = identify_partial_rows(df)
        df = update_partial_rows(df)

        # Identify and mark not contained visits
        df = identify_not_contained_rows(df)
        df = update_not_contained_rows(df)

        # Update conditions
        n_unknown = df.filter(pl.col("main_visit") == "Unknown").height
        n_iter += 1

    # Assign a unique visit_occurrence_id only to main_visits
    df = assign_visit_occurrence_id(df)    
    
    # Drop the extra helper columns
    df = df.drop(
        # Drop helpers
        pl.col("visit_detail_id_original"),
        pl.col("is_contained"),
        pl.col("is_partial"),
        pl.col("not_contained"),
    )
    
    # -- Build the core of the visit_detail table --
    visit_detail = df
    
    # Drop visit_occurrence columns
    visit_detail = visit_detail.drop(
        pl.col("visit_start_datetime"),
        pl.col("visit_end_datetime"),
        pl.col("main_visit"), # This one is dropped here so it can be used for visit_occurrence
    )
    
    # -- Build the core of the visit_occurrence table --
    # Get only main visits
    visit_occurrence = df.filter(pl.col("main_visit") == "Yes").drop(pl.col("main_visit"))

    # Rename columns 
    visit_occurrence = visit_occurrence.rename(
        {
            "visit_detail_type_concept_id": "visit_type_concept_id",
        }
    )

    # Drop columns from visit_detail
    visit_occurrence = visit_occurrence.drop(
        pl.col("visit_detail_start_datetime"),
        pl.col("visit_detail_end_datetime"),
        pl.col("visit_detail_id"),
        pl.col("parent_visit_detail_id"),
    )

    return visit_detail, visit_occurrence

visit_detail, visit_occurrence = build_visit_occurrence(df_raw_pl, verbose=2)

In [None]:
visit_detail


In [None]:
visit_occurrence

In [None]:
%timeit -n 10 -r 5 build_visit_occurrence(df_raw_pl)

TODO:

- Verify tests.
- Run benchmarks.
- Clean up this notebook so that it only contains the explanation of the current code.
  - Translate to English.

#### Fixing this

This approach is not really faster, so let's try to fix it.

The main problem seems to be that we are creating lots of dataframes along the way. We should build filters instead, and then apply them.

Maybe we can just update a single series with the main visits, as we were doing, but leave the original dataframe as is.

In [None]:
# -- Initialization --
# Get the core of the visit_detail table
df = build_visit_detail(df_raw_pl)
# Extend the table for processing
df = build_visit_detail_extended(df)
# Initialize is_contained, is_partial, not_contained and parent_visit_detail_id
df = df.with_columns(
    parent_visit_detail_id=pl.lit(None),
    is_contained=pl.lit(False),
    is_partial=pl.lit(False),
    not_contained=pl.lit(False),
)
# Initialize counters for the while loop
n_unknown = df.filter(pl.col("main_visit") == "Unknown").height
n_iter = 0
df

In [None]:
def single_pass(df):
    # Let's rewrite the inner loop to be done in one pass with polars
    result = (
        df.lazy()
        # Find next batch of main_visits
        .with_columns(
            main_visit=pl.when(
                (pl.col("main_visit") == "Unknown")
                & (pl.col("visit_detail_id_original") == pl.col("visit_detail_id"))
            )
            .then(pl.lit("Yes"))
            .otherwise(pl.col("main_visit"))
        )
        # Identify contained, partial and not contained rows
        .with_columns(
            is_contained=pl.when(
                (pl.col("main_visit") == "Unknown")
                & (pl.col("visit_start_datetime") <= pl.col("visit_detail_start_datetime"))
                & (pl.col("visit_end_datetime") >= pl.col("visit_detail_end_datetime"))
            )
            .then(True)
            .otherwise(pl.col("is_contained")),
            is_partial=pl.when(
                (pl.col("main_visit") == "Unknown")
                & (pl.col("visit_detail_start_datetime") <= pl.col("visit_end_datetime"))
                & (pl.col("visit_detail_end_datetime") > pl.col("visit_end_datetime"))
            )
            .then(True)
            .otherwise(pl.col("is_partial")),
            not_contained=pl.when(
                (pl.col("main_visit") == "Unknown")
                & (pl.col("visit_detail_start_datetime") >= pl.col("visit_end_datetime"))
            )
            .then(True)
            .otherwise(pl.col("not_contained")),
        )
        # -- Contained cases --
        # Assign parent_visit_detail_id for contained visits, update the main_visit on contained cases
        .with_columns(
            parent_visit_detail_id=(
                pl.when(
                    (pl.col("is_contained") == True)
                    & (pl.col("visit_detail_id") != pl.col("visit_detail_id_original"))
                )
                .then(pl.col("visit_detail_id_original"))
                .otherwise(pl.col("parent_visit_detail_id"))
            ),
            main_visit=(
                pl.when((pl.col("is_contained") == True))
                .then(pl.lit("No"))
                .otherwise(pl.col("main_visit"))
            ),
        )
        # -- Partial Cases --
        # Update the main_visit on partial cases and get the latest partial.
        # Flags from earlier passes stay True, so only partial rows that are still
        # "Unknown" count, grouped by the main visit they point to
        .with_columns(
            main_visit=(
                pl.when((pl.col("is_partial") == True))
                .then(pl.lit("No"))
                .otherwise(pl.col("main_visit"))
            ),
            latest=pl.col("visit_detail_end_datetime")
            .filter((pl.col("is_partial") == True) & (pl.col("main_visit") == "Unknown"))
            .max()
            .over("visit_detail_id_original"),
        )
        # Main visits closed in earlier passes get a null `latest` and keep their end
        .with_columns(
            visit_end_datetime=pl.when(
                pl.col("visit_detail_id") == pl.col("visit_detail_id_original")
            )
            .then(pl.coalesce([pl.col("latest"), pl.col("visit_end_datetime")]))
            .otherwise(pl.col("visit_end_datetime")),
        )
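        # -- Not contained cases --
        # We need to refresh visit_detail_id_original, visit_start_datetime, visit_end_datetime for Unknown cases.
        # not_contained flags from earlier passes stay True, so only rows that are
        # still "Unknown" are candidates for the next main visit
        .with_columns(
            newest_id=(
                pl.col("visit_detail_id")
                .filter((pl.col("not_contained") == True) & (pl.col("main_visit") == "Unknown"))
                .first()
                .over("person_id")
            ),
            newest_start=(
                pl.col("visit_detail_start_datetime")
                .filter((pl.col("not_contained") == True) & (pl.col("main_visit") == "Unknown"))
                .first()
                .over("person_id")
            ),
            newest_end=(
                pl.col("visit_detail_end_datetime")
                .filter((pl.col("not_contained") == True) & (pl.col("main_visit") == "Unknown"))
                .first()
                .over("person_id")
            ),
        )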
        .with_columns(
            visit_detail_id_original=pl.when(pl.col("main_visit") == "Unknown")
            .then(pl.col("newest_id").fill_null(pl.col("visit_detail_id_original")))
            .otherwise(pl.col("visit_detail_id_original")),
            visit_start_datetime=pl.when(pl.col("main_visit") == "Unknown")
            .then(pl.col("newest_start").fill_null(pl.col("visit_start_datetime")))
            .otherwise(pl.col("visit_start_datetime")),
            visit_end_datetime=pl.when(pl.col("main_visit") == "Unknown")
            .then(pl.col("newest_end").fill_null(pl.col("visit_end_datetime")))
            .otherwise(pl.col("visit_end_datetime")),
        )
        .select(df.columns)
        .collect()
    )
    return result

result = single_pass(df)
removed_count = df.height - result.height
print(f"Removed {removed_count} overlapping rows. Final rows: {result.height}")

print(
    result.select(
        "person_id",
        "visit_detail_start_datetime",
        "visit_detail_end_datetime",
        "visit_start_datetime",
        "visit_end_datetime",
        "main_visit",
        "is_contained",
        "is_partial",
        "not_contained",
        "visit_detail_id",
        "visit_detail_id_original",
        # "newest_id",
        # "newest_start",
        # "newest_end",
    )
)

In [None]:
result = single_pass(df)
print(
    result.select(
        "person_id",
        "visit_detail_id",
        "visit_detail_id_original",
        "visit_detail_start_datetime",
        "visit_detail_end_datetime",
        "visit_start_datetime",
        "visit_end_datetime",
        "main_visit",
    )
)

In [None]:
result = single_pass(result)
print(
    result.select(
        "person_id",
        "visit_detail_id",
        "visit_detail_id_original",
        "visit_detail_start_datetime",
        "visit_detail_end_datetime",
        "visit_start_datetime",
        "visit_end_datetime",
        "main_visit",
    )
)

In [None]:
def build_visit_occurrence_v2(df_raw_pl, verbose=0):
    # -- Initialization --
    # Get the core of the visit_detail table
    df = build_visit_detail(df_raw_pl)
    # Extend the table for processing
    df = build_visit_detail_extended(df)
    # Initialize is_contained, is_partial, not_contained and parent_visit_detail_id
    df = df.with_columns(
        parent_visit_detail_id=pl.lit(None),
        is_contained=pl.lit(False),
        is_partial=pl.lit(False),
        not_contained=pl.lit(False),
    )
    # Initialize counters for the while loop
    n_unknown = df.filter(pl.col("main_visit") == "Unknown").height
    n_iter = 0
    # -- Loop --
    while n_unknown > 0 and n_iter < 1000:
        if verbose > 0:
            print(f"Iter {n_iter:>2}: {n_unknown} unknown rows left.")

        df = single_pass(df)

        # Update conditions
        n_unknown = df.filter(pl.col("main_visit") == "Unknown").height
        n_iter += 1

    # Assign a unique visit_occurrence_id only to main_visits
    df = assign_visit_occurrence_id(df)    
    
    # Drop the extra helper columns
    df = df.drop(
        # Drop helpers
        pl.col("visit_detail_id_original"),
        pl.col("is_contained"),
        pl.col("is_partial"),
        pl.col("not_contained"),
    )
    
    # -- Build the core of the visit_detail table --
    visit_detail = df
    
    # Drop visit_occurrence columns
    visit_detail = visit_detail.drop(
        pl.col("visit_start_datetime"),
        pl.col("visit_end_datetime"),
        pl.col("main_visit"), # This one is dropped here so it can be used for visit_occurrence
    )
    
    # -- Build the core of the visit_occurrence table --
    visit_occurrence = df
    
    # Get only main visits
    visit_occurrence = df.filter(pl.col("main_visit") == "Yes").drop(pl.col("main_visit"))

    # Rename columns 
    visit_occurrence = visit_occurrence.rename(
        {
            "visit_detail_type_concept_id": "visit_type_concept_id",
        }
    )

    # Drop columns from visit_detail
    visit_occurrence = visit_occurrence.drop(
        pl.col("visit_detail_start_datetime"),
        pl.col("visit_detail_end_datetime"),
        pl.col("visit_detail_id"),
        pl.col("parent_visit_detail_id"),
    )

    return visit_detail, visit_occurrence

%timeit -n 10 -r 5 build_visit_occurrence_v2(df_raw_pl)

In [None]:
import polars as pl
import warnings

def remove_overlap_optimized(
    df: pl.DataFrame,
    person_col: str = "person_id", 
    start_col: str = "start_date",
    end_col: str = "end_date",
    type_col: str = "type_concept",
    verbose: int = 0
) -> pl.DataFrame:
    """
    Optimized version that attempts to remove overlaps in fewer iterations
    by using window functions and more sophisticated logic.
    """
    if verbose > 0:
        print("Removing overlapping rows (optimized)...")
        print(f"Initial rows: {df.height}")
    
    # Sort by person, start_date, end_date (desc)
    df_sorted = df.sort([person_col, start_col, end_col, type_col], descending=[False, False, True, False])
    
    # Use window functions to identify overlaps more efficiently
    result = (
        df_sorted
        .lazy()
        .with_columns([
            # Previous row's end date within same person
            pl.col(end_col).shift(1).over(person_col).alias("prev_end"),
            pl.col(start_col).shift(1).over(person_col).alias("prev_start"),
            pl.col(person_col).shift(1).alias("prev_person")
        ])
        .with_columns([
            # Calculate interval lengths
            (pl.col(end_col) - pl.col(start_col)).alias("curr_interval"),
            (pl.col("prev_end") - pl.col("prev_start")).alias("prev_interval")
        ])
        .with_columns([
            # Mark rows to keep (inverse of removal logic)
            ~(
                # Same person
                (pl.col(person_col) == pl.col("prev_person")) &
                # Current starts after or at same time as previous
                (pl.col(start_col) >= pl.col("prev_start")) &
                # Current ends before or at same time as previous
                (pl.col(end_col) <= pl.col("prev_end")) &
                # Not both single-day intervals
                ~(
                    (pl.col("curr_interval") <= pl.duration(days=1)) &
                    (pl.col("prev_interval") <= pl.duration(days=1))
                )
            )
            # The very first row has no previous row, so the condition is null.
            # Treat null as "do not remove" so the filter below keeps that row.
            .fill_null(False)
            .alias("keep_row")
        ])
        .filter(pl.col("keep_row"))
        .select(df.columns)
        .collect()
    )
    
    if verbose > 0:
        removed_count = df.height - result.height
        print(f"Removed {removed_count} overlapping rows. Final rows: {result.height}")
    
    return result

In [None]:
def build_visit_occurrence_v2_optimized(df_raw_pl, verbose=0):
    # -- Initialization (optimized with chaining) --
    df = (df_raw_pl
          .pipe(build_visit_detail_optimized)
          .pipe(build_visit_detail_extended_optimized)
          .with_columns([
              pl.lit(None).alias("parent_visit_detail_id"),
              pl.lit(False).alias("is_contained"),
              pl.lit(False).alias("is_partial"),
              pl.lit(False).alias("not_contained"),
          ]))
    
    # Pre-compute unknown mask for efficient counting
    unknown_mask = pl.col("main_visit") == "Unknown"
    n_unknown = df.select(unknown_mask.sum()).item()
    n_iter = 0
    
    # -- Optimized Loop --
    while n_unknown > 0 and n_iter < 1000:
        if verbose > 0:
            print(f"Iter {n_iter:>2}: {n_unknown} unknown rows left.")

        df = single_pass_optimized(df)

        # More efficient unknown counting
        n_unknown = df.select(unknown_mask.sum()).item()
        n_iter += 1

    # Assign visit_occurrence_id only to main visits
    df = assign_visit_occurrence_id_optimized(df)    
    
    # -- Build both tables efficiently --
    # Define columns to drop once
    helper_cols = [
        "visit_detail_id_original",
        "is_contained", 
        "is_partial",
        "not_contained"
    ]
    
    visit_detail_drop_cols = helper_cols + [
        "visit_start_datetime",
        "visit_end_datetime",
        "main_visit"
    ]
    
    visit_occurrence_drop_cols = helper_cols + [
        "visit_detail_start_datetime",
        "visit_detail_end_datetime",
        "visit_detail_id", 
        "parent_visit_detail_id",
        "main_visit"
    ]
    
    # Build visit_detail
    visit_detail = df.drop(visit_detail_drop_cols)
    
    # Build visit_occurrence (single operation chain)
    visit_occurrence = (df
        .filter(pl.col("main_visit") == "Yes")
        .drop(visit_occurrence_drop_cols)
        .rename({"visit_detail_type_concept_id": "visit_type_concept_id"})
    )

    return visit_detail, visit_occurrence


def build_visit_detail_optimized(df):
    """Optimized version using lazy evaluation and efficient operations"""
    return (df.lazy()
            .sort(["person_id", "start_date", "end_date", "type_concept"], 
                  descending=[False, False, True, False])
            .with_row_index("visit_detail_id")  # More efficient than numpy arange
            .rename({
                "start_date": "visit_detail_start_datetime",
                "end_date": "visit_detail_end_datetime", 
                "type_concept": "visit_detail_type_concept_id",
            })
            .collect())


def build_visit_detail_extended_optimized(visit_detail):
    """Optimized version with better aggregation"""
    # More efficient aggregation - get first row per person
    visit_occurrence_dates = (
        visit_detail
        .group_by("person_id", maintain_order=True)
        .agg([
            pl.col("visit_detail_start_datetime").first().alias("visit_start_datetime"),
            pl.col("visit_detail_end_datetime").first().alias("visit_end_datetime"),
            pl.col("visit_detail_id").first().alias("visit_detail_id_original"),
        ])
    )
    
    # Single join and column creation
    return (visit_detail
            .join(visit_occurrence_dates, on="person_id", how="left")
            .with_columns(
                pl.lit("Unknown")
                .cast(pl.Enum(["Yes", "No", "Unknown"]))
                .alias("main_visit")
            ))


def single_pass_optimized(df):
    """Optimized single_pass with reduced redundancy and better column operations"""
    return (
        df.lazy()
        # Find next batch of main_visits
        .with_columns(
            main_visit=pl.when(
                (pl.col("main_visit") == "Unknown") &
                (pl.col("visit_detail_id_original") == pl.col("visit_detail_id"))
            )
            .then(pl.lit("Yes"))
            .otherwise(pl.col("main_visit"))
        )
        # Identify contained, partial and not contained rows (optimized conditions)
        .with_columns([
            pl.when(
                (pl.col("main_visit") == "Unknown") &
                (pl.col("visit_start_datetime") <= pl.col("visit_detail_start_datetime")) &
                (pl.col("visit_end_datetime") >= pl.col("visit_detail_end_datetime"))
            )
            .then(True)
            .otherwise(pl.col("is_contained"))
            .alias("is_contained"),
            
            pl.when(
                (pl.col("main_visit") == "Unknown") &
                (pl.col("visit_detail_start_datetime") <= pl.col("visit_end_datetime")) &
                (pl.col("visit_detail_end_datetime") > pl.col("visit_end_datetime"))
            )
            .then(True)
            .otherwise(pl.col("is_partial"))
            .alias("is_partial"),
            
            pl.when(
                (pl.col("main_visit") == "Unknown") &
                (pl.col("visit_detail_start_datetime") >= pl.col("visit_end_datetime"))
            )
            .then(True)
            .otherwise(pl.col("not_contained"))
            .alias("not_contained"),
        ])
        # -- Contained cases --
        .with_columns([
            pl.when(
                pl.col("is_contained") & 
                (pl.col("visit_detail_id") != pl.col("visit_detail_id_original"))
            )
            .then(pl.col("visit_detail_id_original"))
            .otherwise(pl.col("parent_visit_detail_id"))
            .alias("parent_visit_detail_id"),
            
            pl.when(pl.col("is_contained"))
            .then(pl.lit("No"))
            .otherwise(pl.col("main_visit"))
            .alias("main_visit"),
        ])
        # -- Partial Cases --
        .with_columns([
            pl.when(pl.col("is_partial"))
            .then(pl.lit("No"))
            .otherwise(pl.col("main_visit"))
            .alias("main_visit"),
            
            pl.col("visit_detail_end_datetime")
            .filter(pl.col("is_partial"))
            .max()
            .over("person_id")
            .alias("latest"),
        ])
        .with_columns(
            visit_end_datetime=pl.when(pl.col("visit_detail_id") == pl.col("visit_detail_id_original"))
            .then(pl.coalesce([pl.col("latest"), pl.col("visit_detail_end_datetime")]))
            .otherwise(pl.col("visit_end_datetime"))
        )
        # -- Not contained cases --
        .with_columns([
            pl.col("visit_detail_id")
            .filter(pl.col("not_contained"))
            .first()
            .over("person_id")
            .alias("newest_id"),
            
            pl.col("visit_detail_start_datetime")
            .filter(pl.col("not_contained"))
            .first()
            .over("person_id")
            .alias("newest_start"),
            
            pl.col("visit_detail_end_datetime")
            .filter(pl.col("not_contained"))
            .first()
            .over("person_id")
            .alias("newest_end"),
        ])
        .with_columns([
            pl.when(pl.col("main_visit") == "Unknown")
            .then(pl.coalesce([pl.col("newest_id"), pl.col("visit_detail_id_original")]))
            .otherwise(pl.col("visit_detail_id_original"))
            .alias("visit_detail_id_original"),
            
            pl.when(pl.col("main_visit") == "Unknown")
            .then(pl.coalesce([pl.col("newest_start"), pl.col("visit_start_datetime")]))
            .otherwise(pl.col("visit_start_datetime"))
            .alias("visit_start_datetime"),
            
            pl.when(pl.col("main_visit") == "Unknown")
            .then(pl.coalesce([pl.col("newest_end"), pl.col("visit_end_datetime")]))
            .otherwise(pl.col("visit_end_datetime"))
            .alias("visit_end_datetime"),
        ])
        # Clean up temporary columns
        .drop(["latest", "newest_id", "newest_start", "newest_end"])
        .select(df.columns)
        .collect()
    )


def assign_visit_occurrence_id_optimized(visit_occurrence):
    """Optimized version - eliminates unnecessary operations and counts"""
    return (visit_occurrence
        .with_columns(
            # Direct calculation without helper column or unnecessary cast
            visit_occurrence_id=pl.when(pl.col("main_visit") == "Yes")
            .then((pl.col("main_visit") == "Yes").cum_sum() - 1)
            .otherwise(None)
        )
        .with_columns(
            visit_occurrence_id=pl.col("visit_occurrence_id").forward_fill()
        )
    )

In [None]:
build_visit_occurrence_v2(df_raw_pl)

In [None]:
build_visit_occurrence_v2_optimized(df_raw_pl)

In [None]:
%timeit -n 10 -r 5 build_visit_occurrence_v2_optimized(df_raw_pl)

### Recursive approach

In [None]:
import warnings
import pandas as pd

def find_overlap_index(df: pd.DataFrame) -> pd.Series:
    """Finds all rows that:
       - belong to the same person_id
       - are contained within the previous row.
       - are not single day visits
    and flags them for removal.

    Parameters
    ----------
    df : pd.DataFrame
        pandas Dataframe with at least four columns.
        Assumes first column is person_id, second column is
        start_date and third column is end_date.

    Returns
    -------
    pd.Series
        pandas Series with bools. True if row is contained
        within the previous row, False otherwise.
    """
    # 1. Check that current and previous patient are the same
    idx_person = df.iloc[:, 0] == df.iloc[:, 0].shift(1)
    # 2. Check that current start_date is the same as or later than previous start_date
    idx_start = df.iloc[:, 1] >= df.iloc[:, 1].shift(1)
    # 3. Check that current end_date is the same as or sooner than previous end_date
    idx_end = df.iloc[:, 2] <= df.iloc[:, 2].shift(1)
    # 4. Check that current interval and previous interval are not both single_day
    interval = df.iloc[:, 2] - df.iloc[:, 1]
    idx_int_curr = interval <= pd.Timedelta(1, unit="D")
    idx_int_prev = interval.shift(1) <= pd.Timedelta(1, unit="D")
    idx_interval = ~(idx_int_curr & idx_int_prev)
    # 5. If everything past is true, I can drop the row
    return idx_start & idx_end & idx_person & idx_interval


def remove_overlap(
    df: pd.DataFrame,
    sorting_columns: tuple,
    ascending_order: tuple,
    verbose: int = 0,
    _counter: int = 0,
    _counter_lim: int = 1000,
) -> pd.DataFrame:
    """Removes all rows that are completely contained within
    another row. It will not remove rows that are only partially
    contained within the previous one.

    The function works by sorting the rows by columns. If two or
    more rows are overlapping, only the top one will be kept.


    Parameters
    ----------
    df : pd.DataFrame
        pandas dataframe with overlapping rows to be removed.
        Selection of columns is done by selecting ncols in order.
        This allows its use for different tables with columns
        that have the same purpose but different names.
    sorting_columns : tuple
        Columns to use for sorting.
        Usually, expects 4 columns: 'person_id', 'start_date', 'end_date'
        and some '*_concept_id', like 'visit_concept_id'.
    ascending_order : tuple
        Bools indicating whether each column should be sorted in ascending
        (True) or descending (False) order.
        Important! Usually all are True except the end_date column. See Notes.
    verbose : int, optional, default 0
        Information output
        - 0 No info
        - 1 Show a message when the process starts
        - 2 Show number of iterations
        - 3 Show an example of the first row being removed and
            the row that contains it.
    _counter : int
        Iteration control param. Number of iterations.
        0 will be used to begin and function will take over.
    _counter_lim : int, optional, default 1000
        Iteration control param. Limit of iterations

    Returns
    -------
    pd.DataFrame
        Copy of input dataframe with contained rows removed.

    Notes
    -------
    The usual behavior is to have 'person_id', 'start_date' and 'end_date'
    as first columns, in ascending, ascending and descending order, respectively.
    This ensures that:
    - All records for the same person are together (sorting by person_id first)
    - Earlier records are placed at the top (sorting by ascending start_date)
    - Longer duration visits are placed at the top (sorting by descending end_date)

    Bear in mind that missing values will be placed at the bottom by default. Any extra
    columns provided will leave any missing values out in case of overlapping records.
    """
    # == Preparation =================================================
    # Sanity checks
    if len(sorting_columns) != len(ascending_order):
        raise ValueError(
            "'sorting_columns' and 'ascending_order' lengths must be equal."
        )

    # Compare as lists so that tuples and lists are treated alike
    cond_sort = list(sorting_columns[:3]) != ["person_id", "start_date", "end_date"]
    cond_asce = list(ascending_order[:3]) != [True, True, False]
    if cond_sort or cond_asce:
        warnings.warn(
            "Sorting and ascending initial columns are not the expected order. "
            "Make sure data output is correct."
        )

    # Sort the dataframe if first iteration
    if _counter == 0:
        if verbose > 0:
            print("Removing overlapping rows...")
        if verbose > 1:
            print(f" Iter 0 => {df.shape[0]} initial rows.")
        df = df.sort_values(sorting_columns, ascending=ascending_order)

    # == Find indexes ================================================
    # Get the rows
    idx_to_remove = find_overlap_index(df)

    # == Main "loop" =================================================
    # Prepare next loop
    idx_to_remove_sum = idx_to_remove.sum()
    _counter += 1
    # If there's still room to go, go
    if (idx_to_remove_sum != 0) and (_counter < _counter_lim):
        if verbose > 1:
            # Show iteration and number of rows removed
            print(f" Iter {_counter} => {idx_to_remove_sum} rows removed.")
        if verbose > 2:
            # Get first removed row and show container and contained row
            idx_max = df.index.get_loc(idx_to_remove.idxmax())
            print(f"{df.iloc[(idx_max-1):idx_max+1, :4]}")
        return remove_overlap(
            df.loc[~idx_to_remove],
            sorting_columns,
            ascending_order,
            verbose,
            _counter,
            _counter_lim,  # keep a user-provided iteration limit across calls
        )
    else:
        return df

In [None]:
import sys
import pyarrow as pa

sys.path.append("../../")
import bps_to_omop.utils.process_dates as pro_dat

table_raw = pa.Table.from_pandas(df_raw)
table_raw = table_raw.cast(
    pa.schema(
        [
            ("person_id", pa.int64()),
            ("start_date", pa.date64()),
            ("end_date", pa.date64()),
            ("type_concept", pa.int64()),
            ("should_remain", pa.int64()),
            ("visit_concept_id", pa.int64()),
            ("provider_id", pa.int64()),
        ]
    )
)
# Define sorting order
sorting_columns = ["person_id", "start_date", "end_date", "visit_concept_id"]
ascending_order = [True, True, False, True]

df_rare = table_raw.to_pandas()
df_done = remove_overlap(df_rare, sorting_columns, ascending_order, verbose=1)
df_done.sort_index()

In [None]:
%timeit -n 10 -r 5 remove_overlap(df_rare, sorting_columns, ascending_order, verbose=0)

### Adapt polars to remove_overlap

The old method had a flaw: it did not join date ranges that were only partially contained in one another, so some rows that should be removed were kept.
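
As an illustration of what joining partially contained ranges means, here is a minimal sketch in plain Python rather than the polars pipeline used in this notebook. The function name `merge_intervals` and the sample rows are hypothetical, and the sketch assumes that a row starting on or before the running end date counts as overlapping.

```python
from datetime import date


def merge_intervals(rows):
    """Merge contained and partially overlapping intervals per person.

    `rows` is an iterable of (person_id, start, end) tuples (hypothetical layout).
    """
    merged = []
    # Same ordering idea as in this notebook: person, start ascending, end descending
    ordered = sorted(rows, key=lambda r: (r[0], r[1], -r[2].toordinal()))
    for person, start, end in ordered:
        if merged and merged[-1][0] == person and start <= merged[-1][2]:
            # The row overlaps the running block: extend the block if it reaches further
            last = merged[-1]
            merged[-1] = (person, last[1], max(last[2], end))
        else:
            # Different person or a gap: start a new block
            merged.append((person, start, end))
    return merged


rows = [
    (1, date(2020, 1, 1), date(2020, 1, 10)),
    (1, date(2020, 1, 3), date(2020, 1, 5)),   # fully contained in the first row
    (1, date(2020, 1, 8), date(2020, 1, 15)),  # partially overlaps the first row
    (1, date(2020, 2, 1), date(2020, 2, 2)),   # separate visit
    (2, date(2020, 1, 4), date(2020, 1, 6)),   # different person
]
for row in merge_intervals(rows):
    print(row)
```

A containment-only check would drop the second row but keep the third one as a separate visit, whereas the merge above extends the first block up to 2020-01-15.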

To measure speed on big datasets, we first need to verify that both functions return the same results. To do so, we adapt the functions written for polars so that they match the expected output of the old implementation.

In [None]:
# Retrieve the first remaining "Unknown" visit per person (id, start and end)
def retrieve_newest_not_contained_adapted(df):
    newest_not_contained = (
        df.filter(pl.col("main_visit") == "Unknown")
        .group_by("person_id")
        .agg(
            pl.col("visit_detail_id").first().alias("visit_detail_id_newest"),
            pl.col("visit_detail_start_datetime")
            .first()
            .alias("visit_detail_start_datetime_newest"),
            pl.col("visit_detail_end_datetime")
            .first()
            .alias("visit_detail_end_datetime_newest"),
        )
    )
    return newest_not_contained


def update_not_contained_rows_adapted(df):

    newest_not_contained = retrieve_newest_not_contained_adapted(df)

    # Join back to the main dataframe and update visit_end_datetime
    df = df.join(
        newest_not_contained,
        on="person_id",
        how="left",
    )
    # Update the values on visit_detail_id_original, visit_start_datetime and visit_end_datetime
    df = df.with_columns(
        visit_detail_id_original=pl.when((pl.col("main_visit") == "Unknown"))
        .then(pl.col("visit_detail_id_newest"))
        .otherwise(pl.col("visit_detail_id_original")),
        visit_start_datetime=pl.when((pl.col("main_visit") == "Unknown"))
        .then(pl.col("visit_detail_start_datetime_newest"))
        .otherwise(pl.col("visit_start_datetime")),
        visit_end_datetime=pl.when((pl.col("main_visit") == "Unknown"))
        .then(pl.col("visit_detail_end_datetime_newest"))
        .otherwise(pl.col("visit_end_datetime")),
    ).drop(
        pl.col(
            "visit_detail_id_newest",
            "visit_detail_start_datetime_newest",
            "visit_detail_end_datetime_newest",
        )
    )
    return df


def remove_overlap_polars(df, verbose=0):
    # -- Initialization --
    # Get the core of the visit_detail table
    df = build_visit_detail(df)
    # Extend the table for processing
    df = build_visit_detail_extended(df)
    # Initialize counters for the while loop
    n_unknown = df.filter(pl.col("main_visit") == "Unknown").height
    n_iter = 0

    # -- Loop --
    while n_unknown > 0 and n_iter < 1000:
        if verbose > 0:
            print(f"Iter {n_iter:>2}: {n_unknown} unknown rows left.")
        if verbose > 1:
            print(df.head(20))

        # Look for next batch of main_visits
        df = identify_next_main_visits(df)

        # Identify and mark completely contained visits
        df = identify_contained_rows(df)
        df = update_contained_rows(df)

        # Identify and mark partially contained visits
        df = identify_partial_rows(df)
        df = update_partial_rows(df)

        # Identify and mark not contained visits
        df = identify_not_contained_rows(df)
        df = update_not_contained_rows(df)

        # Update conditions
        n_unknown = df.filter(pl.col("main_visit") == "Unknown").height
        n_iter += 1

    # Assign a unique visit_occurrence_id only to main_visits
    df = assign_visit_occurrence_id(df)

    # Drop the extra helper columns
    df = df.drop(
        # Drop helpers
        pl.col("visit_detail_id_original"),
        pl.col("is_contained"),
        pl.col("is_partial"),
        pl.col("not_contained"),
    )

    # -- Build the core of the visit_detail table --
    visit_detail = df

    # Drop visit_occurrence columns
    visit_detail = visit_detail.drop(
        pl.col("visit_start_datetime"),
        pl.col("visit_end_datetime"),
        pl.col(
            "main_visit"
        ),  # This one is dropped here so it can be used for visit_occurrence
    )

    # -- Build the core of the visit_occurrence table --
    visit_occurrence = df

    # Get only main visits
    visit_occurrence = df.filter(pl.col("main_visit") == "Yes").drop(
        pl.col("main_visit")
    )

    # Rename columns
    visit_occurrence = visit_occurrence.rename(
        {
            "visit_detail_type_concept_id": "visit_type_concept_id",
        }
    )

    # Drop columns from visit_detail
    visit_occurrence = visit_occurrence.drop(
        pl.col("visit_detail_start_datetime"),
        pl.col("visit_detail_end_datetime"),
        pl.col("visit_detail_id"),
        pl.col("parent_visit_detail_id"),
    )

    return visit_detail, visit_occurrence


_visit_detail, visit_occurrence_polars = remove_overlap_polars(df_raw_pl, verbose=1)
visit_occurrence_polars = visit_occurrence_polars.select(
    "person_id",
    "visit_start_datetime",
    "visit_end_datetime",
    "visit_type_concept_id",
    "should_remain",
    "visit_concept_id",
    "provider_id",
)

visit_occurrence_polars = visit_occurrence_polars.with_columns(
    pl.col("visit_start_datetime").dt.to_string().str.to_datetime(),
    pl.col("visit_end_datetime").dt.to_string().str.to_datetime(),
)

print(visit_occurrence_polars)

In [None]:
# Get the recursive approach here
visit_occurrence_recursive = remove_overlap(df_rare, sorting_columns, ascending_order, verbose=2)
visit_occurrence_recursive = (
    visit_occurrence_recursive.reset_index(drop=True)
    .sort_values(["person_id", "start_date", "end_date", "visit_concept_id"])
    .rename({
        "start_date":"visit_start_datetime",
        "end_date": "visit_end_datetime",
        "type_concept":"visit_type_concept_id",
    },axis=1)
)

visit_occurrence_recursive = pl.DataFrame(visit_occurrence_recursive)
visit_occurrence_recursive = visit_occurrence_recursive.with_columns(
    pl.col("visit_start_datetime").dt.to_string().str.to_datetime(),
    pl.col("visit_end_datetime").dt.to_string().str.to_datetime(),
    pl.col("should_remain").cast(pl.Boolean)
)
print(visit_occurrence_recursive)

In [None]:
from polars.testing import assert_frame_equal

# assert_frame_equal(visit_occurrence_recursive, visit_occurrence_polars)

## Test with large datasets

Let's compare the speed of both methods on large datasets.

We bring in the function that generates the datasets.

In [None]:
import numpy as np
import pandas as pd
import pyarrow as pa
from pyarrow import parquet


def create_sample_df(
    n: int = 1000,
    n_dates: int = 50,
    first_date: str = "2020-01-01",
    last_date: str = "2023-01-01",
    mean_duration_days: int = 60,
    std_duration_days: int = 180,
) -> pd.DataFrame:

    # == Parameters ==
    np.random.seed(42)
    pd.options.mode.string_storage = "pyarrow"
    # Start date from which to start the dates
    first_date = pd.to_datetime(first_date)
    last_date = pd.to_datetime(last_date)
    max_days = (last_date - first_date).days

    # == Generate IDs randomly ==
    # -- Generate the Ids
    people = np.random.randint(10000000, 99999999 + 1, size=n)
    person_id = np.random.choice(people, n * n_dates)

    # == Generate random dates ==
    # Generate random integers for days and convert to timedelta
    random_days = np.random.randint(0, max_days, size=n * n_dates)
    # Create the columns
    observation_start_date = first_date + pd.to_timedelta(random_days, unit="D")
    # Generate a gaussian sample of dates
    random_days = np.random.normal(
        mean_duration_days, std_duration_days, size=n * n_dates
    )
    random_days = np.int32(random_days)
    observation_end_date = observation_start_date + pd.to_timedelta(
        random_days, unit="D"
    )
    # Correct end_dates
    # => If they are smaller than start_date, take start_date
    observation_end_date = np.where(
        observation_end_date < observation_start_date,
        observation_start_date,
        observation_end_date,
    )

    # == Generate the code ==
    period_type_concept_id = np.random.randint(1, 11, size=n * n_dates)

    # == Generate the dataframe ==
    df_raw = {
        "person_id": person_id,
        "observation_period_start_date": observation_start_date,
        "observation_period_end_date": observation_end_date,
        "period_type_concept_id": period_type_concept_id,
    }
    return pd.DataFrame(df_raw)

### Sanity check - Return the same results

In [None]:
# Load the data
df_raw = create_sample_df(
    n=100,
    n_dates=10,
)
df_raw.columns = ["person_id", "start_date", "end_date", "type_concept"]
df_raw.loc[:,"visit_concept_id"] = 9202

# df_raw = df_raw.sort_values(
#     ["person_id", "start_date", "end_date", "type_concept"],
#     ascending=[True, True, False, True],
# )
df_raw.head(10)

In [None]:
# ==  Recursive approach ==
visit_occurrence_recursive = remove_overlap(df_raw, sorting_columns, ascending_order, verbose=2)
visit_occurrence_recursive = (
    visit_occurrence_recursive.reset_index(drop=True)
    .sort_values(["person_id", "start_date", "end_date", "visit_concept_id"])
    .rename({
        "start_date":"visit_start_datetime",
        "end_date": "visit_end_datetime",
        "type_concept":"visit_type_concept_id",
    },axis=1)
)

visit_occurrence_recursive = pl.DataFrame(visit_occurrence_recursive)
visit_occurrence_recursive = visit_occurrence_recursive.with_columns(
    pl.col("visit_start_datetime").dt.to_string().str.to_datetime(),
    pl.col("visit_end_datetime").dt.to_string().str.to_datetime(),
)
visit_occurrence_recursive.head(10)


In [None]:
# == polars approach ==
_visit_detail, visit_occurrence_polars = remove_overlap_polars(pl.DataFrame(df_raw), verbose=1)
visit_occurrence_polars = visit_occurrence_polars.select(
    "person_id",
    "visit_start_datetime",
    "visit_end_datetime",
    "visit_type_concept_id",
    "visit_concept_id",
)

visit_occurrence_polars = visit_occurrence_polars.with_columns(
    pl.col("visit_start_datetime").dt.to_string().str.to_datetime(),
    pl.col("visit_end_datetime").dt.to_string().str.to_datetime(),
)
visit_occurrence_polars.head(10)

In [None]:
build_visit_occurrence_v2(df_raw_pl)

In [None]:
from polars.testing import assert_frame_equal

# assert_frame_equal(visit_occurrence_recursive, visit_occurrence_polars)

### Time test

In [None]:
# Load the data
df_raw = create_sample_df(n=1000)
df_raw.columns = ['person_id', 'start_date', 'end_date', 'type_concept']
df_raw.loc[:,"visit_concept_id"] = 9202

df_raw_pl = pl.DataFrame(df_raw)

print('\nrecursive:')
%timeit -n 10 -r 5 remove_overlap(df_raw, sorting_columns, ascending_order, verbose=0)

# print('\npolars:') # => Currently broken
# %timeit -n 10 -r 5 build_visit_occurrence(df_raw_pl, verbose=0)

print('\npolars optimized :')
%timeit -n 10 -r 5 remove_overlap_optimized(df_raw_pl, verbose=0)

print('\npolars test:')
%timeit -n 10 -r 5 build_visit_occurrence_v2(df_raw_pl)

print('\npolars test optimized:')
%timeit -n 10 -r 5 build_visit_occurrence_v2_optimized(df_raw_pl)


For n = 100 000

    recursive:
    3.66 s ± 68.8 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)

    polars optimized :
    1.08 s ± 50.2 ms per loop (mean ± std. dev. of 5 runs, 10 loops each)

# 2. Removing nearby rows
Once the rows contained in other rows have been removed, the next goal is to merge rows that are separated by fewer days than a threshold we set.

## 2.1 Building a test dataset
We will assume that dates less than 1 year apart (exactly 365 days) are grouped together.

In [None]:
import numpy as np
import pandas as pd

nombre_columnas = ["person_id", "start_date", "end_date", "type_concept"]
filas = [
    # These dates should be merged because they are less than 365 days apart
    # type_concept should be 2
    (1, "2020-01-01", "2020-02-01", 1),
    (1, "2020-03-01", "2020-04-01", 2),
    (1, "2020-05-01", "2020-12-01", 2),
    # This last one, for the same person, should not
    (1, "2022-01-01", "2022-01-01", 2),
    # These dates should be merged because they overlap
    # type_concept should be 1
    (2, "2020-01-01", "2020-06-01", 1),
    (2, "2020-03-01", "2020-09-01", 1),
    (2, "2020-06-01", "2020-12-01", 2),
    # These three dates should NOT be merged,
    # each one is its own period
    (3, "2021-01-01", "2021-01-01", 1),
    (3, "2023-02-01", "2023-02-01", 2),
    (3, "2024-03-01", "2024-04-01", 3),
    # These would be merged, but they belong to different people
    (4, "2024-01-01", "2024-02-01", 1),
    (5, "2025-01-01", "2025-02-01", 2),
    # These should be merged because consecutive rows are close together,
    # but if rows are dropped in one go the remaining ones look
    # too far apart and do not get merged.
    # type_concept should be 2
    (6, "2020-01-01", "2020-12-01", 1),
    (6, "2021-01-01", "2021-12-01", 2),
    (6, "2022-01-01", "2022-12-01", 2),
    (6, "2023-01-01", "2023-12-01", 2),
]
df_raw = pd.DataFrame.from_records(filas, columns=nombre_columnas)
df_raw["start_date"] = pd.to_datetime(df_raw["start_date"])
df_raw["end_date"] = pd.to_datetime(df_raw["end_date"])
df_raw

Assuming that we **group periods separated by less than 365 days** and that we **use the mode to compute the final `type_concept`**, the result of grouping the data created above should be the following:

In [None]:
nombre_columnas = ["person_id", "start_date", "end_date", "type_concept"]
filas = [
    # These dates should be merged because they are less than 365 days apart
    # type_concept should be 2
    (1, "2020-01-01", "2020-12-01", 2),
    # This last one, for the same person, should not
    (1, "2022-01-01", "2022-01-01", 2),
    # These dates should be merged because they overlap
    # type_concept should be 1
    (2, "2020-01-01", "2020-12-01", 1),
    # These three dates should NOT be merged,
    # each one is its own period
    (3, "2021-01-01", "2021-01-01", 1),
    (3, "2023-02-01", "2023-02-01", 2),
    (3, "2024-03-01", "2024-04-01", 3),
    # These would be merged, but they belong to different people
    (4, "2024-01-01", "2024-02-01", 1),
    (5, "2025-01-01", "2025-02-01", 2),
    # These should be merged because consecutive rows are close together,
    # but if rows are dropped in one go the remaining ones look
    # too far apart and do not get merged.
    # type_concept should be 2
    (6, "2020-01-01", "2023-12-01", 2),
]
df_result = pd.DataFrame.from_records(filas, columns=nombre_columnas)
df_result["start_date"] = pd.to_datetime(df_result["start_date"])
df_result["end_date"] = pd.to_datetime(df_result["end_date"])
df_result
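
The expected result above depends on which gaps reach the 365-day threshold. As a quick sanity check, the sketch below recomputes a few of those gaps using only the standard library. The helper name `gap_days` and the labels are made up for this example; they are not part of the project code.

```python
from datetime import date

N_DAYS = 365  # grouping threshold used in this section


def gap_days(prev_end: str, next_start: str) -> int:
    """Days between the end of one row and the start of the next one."""
    return (date.fromisoformat(next_start) - date.fromisoformat(prev_end)).days


# (label, end_date of a row, start_date of the following row)
pairs = [
    ("person 1, rows 3-4", "2020-12-01", "2022-01-01"),
    ("person 3, rows 2-3", "2023-02-01", "2024-03-01"),
    ("person 6, rows 1-2", "2020-12-01", "2021-01-01"),
]
gaps = {label: gap_days(end, start) for label, end, start in pairs}
for label, days in gaps.items():
    # A gap of N_DAYS or more starts a new period
    print(f"{label}: {days} days -> new period: {days >= N_DAYS}")
```

The first two gaps are above the threshold, so those rows open a new period; the third one is well below it, so the rows of person 6 chain into a single period.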

## 2.2 Removing nearby rows

There are several problems:
- If several consecutive rows meet the condition, information can be lost when the first and last rows are far apart.
- No effective way of doing this without iterating, as before, has been found.
    - Either you iterate over people, and there is no need to watch out for mixing people,
    - or you do it all at once, but then it is very hard to keep track of the `type_concept` and `person_id` values that have been removed.
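
For reference, one possible way of doing it "all at once" is to label each period with a cumulative sum of the breaks and then aggregate per label. This is only a sketch under assumptions, not the method used in this notebook: the function name `group_dates_cumsum` and the helper column `period_id` are hypothetical, it takes the maximum `end_date` of each period (the index method takes the `end_date` of the last row), and ties in the mode are resolved by taking the smallest value.

```python
import pandas as pd


def group_dates_cumsum(df: pd.DataFrame, n_days: int) -> pd.DataFrame:
    """Hypothetical alternative: label periods with a cumsum of breaks."""
    df = df.sort_values(
        ["person_id", "start_date", "end_date"], ascending=[True, True, False]
    ).reset_index(drop=True)
    # Gap between the previous end_date of the same person and this start_date
    prev_end = df.groupby("person_id")["end_date"].shift(1)
    gap = df["start_date"] - prev_end
    # The first row of each person has no previous row (NaT gap), so it
    # opens a period; a gap of n_days or more opens one too
    new_period = gap.isna() | (gap >= pd.Timedelta(days=n_days))
    df["period_id"] = new_period.cumsum()
    return (
        df.groupby(["person_id", "period_id"], as_index=False)
        .agg(
            start_date=("start_date", "min"),
            end_date=("end_date", "max"),
            type_concept=("type_concept", lambda s: s.mode().iloc[0]),
        )
        .drop(columns="period_id")
    )


rows = [
    (1, "2020-01-01", "2020-02-01", 1),
    (1, "2020-03-01", "2020-04-01", 2),
    (1, "2020-05-01", "2020-12-01", 2),
    (1, "2022-01-01", "2022-01-01", 2),
    (4, "2024-01-01", "2024-02-01", 1),
    (5, "2025-01-01", "2025-02-01", 2),
]
toy = pd.DataFrame(rows, columns=["person_id", "start_date", "end_date", "type_concept"])
toy["start_date"] = pd.to_datetime(toy["start_date"])
toy["end_date"] = pd.to_datetime(toy["end_date"])
out = group_dates_cumsum(toy, n_days=365)
print(out)
```

Because the period label is computed per person, `person_id` never needs to be tracked by hand. Whether this is faster than the index method has not been measured here.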


### 2.2.1 Using only indexes

We need to find the first and last row of each person, as well as the cases where a person has a single row.

We also need to find the cases where the next row is far away, which means there is a gap in the observation period.

In [None]:
# Sort the raw data
df_rare = df_raw.sort_values(
    ["person_id", "start_date", "end_date"], ascending=[True, True, False]
)
# It is VERY important to reset the index to make sure we can
# retrieve the rows reliably after sorting them.
df_rare = df_rare.reset_index(drop=True)

# Flag the first, last or only row of each person in the dataset
df_rare["idx_person_first"] = (
    df_rare["person_id"] == df_rare["person_id"].shift(-1)
) & (df_rare["person_id"] != df_rare["person_id"].shift(1))
df_rare["idx_person_last"] = (
    df_rare["person_id"] != df_rare["person_id"].shift(-1)
) & (df_rare["person_id"] == df_rare["person_id"].shift(1))
df_rare["idx_person_only"] = (
    df_rare["person_id"] != df_rare["person_id"].shift(-1)
) & (df_rare["person_id"] != df_rare["person_id"].shift(1))
# Create index if the break is too big and needs to be kept
n_days = 365
df_rare["next_interval"] = df_rare["start_date"].shift(-1) - df_rare["end_date"]
df_rare["idx_interval"] = df_rare["next_interval"] >= pd.Timedelta(n_days, unit="D")
# Combine all to see which rows remain
df_rare["to_remain"] = (
    df_rare["idx_person_first"]
    | df_rare["idx_person_last"]
    | df_rare["idx_person_only"]
    | df_rare["idx_interval"]
)

df_rare
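
To make the three person flags easier to read, here is a minimal sketch on a toy `person_id` column (the values are invented for illustration). It uses the same `shift` comparisons as the cell above.

```python
import pandas as pd

# Toy column, already sorted: person 1 has three rows, person 2 one, person 3 two
person = pd.Series([1, 1, 1, 2, 3, 3])

same_as_next = person == person.shift(-1)
same_as_prev = person == person.shift(1)

flags = pd.DataFrame(
    {
        "person_id": person,
        # First row of a person that has more rows after it
        "first": same_as_next & ~same_as_prev,
        # Last row of a person that has more rows before it
        "last": ~same_as_next & same_as_prev,
        # Person with a single row
        "only": ~same_as_next & ~same_as_prev,
    }
)
print(flags)
```

Every person ends up with either one `only` row or exactly one `first` and one `last` row, which is what the rest of the section relies on.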

The rows I definitely have to keep are those where:
- `idx_person_first == True`
- `idx_person_last == True`
- `idx_person_only == True`

1. If a person has a single row with `idx_person_only == True`, I keep it and that is all. Nothing else needs to be done: that row holds both the first and the last date of the patient.

2. If a person has one `idx_person_first == True` and one `idx_person_last == True`, I have the beginning and the end.
    1. If there is no `idx_interval == True`, I join the `start_date` of the `idx_person_first == True` row with the `end_date` of the `idx_person_last == True` row.
    2. If there is any `idx_interval == True`, those rows mark gaps in the observation period. A row with `idx_interval == True` is the last row of a period, and the next row is the start of another.

In short, we need to record the `start_date` values, with their `person_id`, on one side and the new `end_date` values on the other. Let's write the code that does this.

In [None]:
# We will create an initial dataframe with only person_id and
# start_date. The end_date rows and type_concept will be added
# later as new columns.

# == start_date and person_id ==========================================
# To retrieve the start_date we need the indexes of:
# - people with a single row (idx_person_only == True)
# - first rows of each person (idx_person_first == True)
# - rows just after period breaks (idx_interval shifted by one row)

# Get the indexes where a period starts
idx_start = df_rare.index[
    df_rare["idx_person_only"]
    | df_rare["idx_person_first"]
    | df_rare["idx_interval"].shift(1)
]
# Start the final dataframe with person_id and start_date
df_done = df_rare.loc[idx_start, ["person_id", "start_date"]]
df_done

# == end_date ==========================================================
# Get the indexes
idx_end = df_rare.index[
    df_rare["idx_person_only"] | df_rare["idx_person_last"] | df_rare["idx_interval"]
]
# Append values found to final dataframe
df_done["end_date"] = df_rare.loc[idx_end, ["end_date"]].values

df_done
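
One detail worth knowing when shifting a boolean mask: the shifted Series gets a missing value in its first position, so it is no longer a plain boolean column. The sketch below, on an invented mask, shows the effect and how `fill_value=False` keeps the dtype boolean; it is an illustration, not a change to the code above.

```python
import pandas as pd

# Toy mask: True marks the last row before a break
is_break = pd.Series([False, True, False])

# Default shift: the first position becomes missing
shifted_default = is_break.shift(1)
# Shift with an explicit fill value: the result stays boolean
shifted_filled = is_break.shift(1, fill_value=False)

print(shifted_default.tolist(), shifted_default.dtype)
print(shifted_filled.tolist(), shifted_filled.dtype)
```

With `fill_value=False` the first row is simply treated as "not after a break", which is what the start-index logic needs, since that row is already flagged as the first or only row of a person.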

To find the `type_concept`, we can zip the start and end indexes: since they should be in order, this gives a list of (start, end) pairs, one per period.

If I collect all the `type_concept` values within each period, I can compute the mode and assign the most common `type_concept`.

In [None]:
import scipy.stats as st

# I can iterate over idx_start and idx_end to get the
# periods
mode_values = []
for i in np.arange(len(idx_start)):
    df_tmp = df_rare.loc[idx_start[i] : idx_end[i]]
    print(f"{i=}")
    print(df_tmp[["person_id", "start_date", "end_date", "type_concept"]])
    mode = st.mode(df_tmp["type_concept"].values)
    print(f"mode is {mode}", "\n")
    mode_values.append(mode[0])

In [None]:
# Add to dataframe
df_done["type_concept"] = mode_values
df_done
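
When two `type_concept` values are equally common within a period, the mode is ambiguous and a tie-breaking rule is needed. The sketch below is a NumPy-only helper (the name `smallest_mode` is hypothetical) that always returns the smallest of the tied values, so the rule is explicit rather than left to the library.

```python
import numpy as np


def smallest_mode(values) -> int:
    """Most common value; ties are resolved by taking the smallest."""
    # np.unique returns the unique values sorted in ascending order
    uniq, counts = np.unique(values, return_counts=True)
    # argmax returns the first maximum, i.e. the smallest tied value
    return int(uniq[np.argmax(counts)])


print(smallest_mode([1, 2, 2]))     # clear winner
print(smallest_mode([2, 1, 2, 1]))  # tie between 1 and 2
```

Whatever rule is chosen, it helps to apply the same one in every implementation that is being compared, so that differences in `type_concept` are not just tie-breaking noise.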

And that's it: the dataframe now holds only the start and end of each period, including the most common `type_concept` computed with the mode. Let's check that it matches the expected results:

In [None]:
check_person = df_done["person_id"].values == df_result["person_id"].values
print(f"{'Person_id col is correct:':<28} {check_person.all()}")
check_start_date = df_done["start_date"].values == df_result["start_date"].values
print(f"{'start_date col is correct:':<28} {check_start_date.all()}")
check_end_date = df_done["end_date"].values == df_result["end_date"].values
print(f"{'end_date col is correct:':<28} {check_end_date.all()}")
check_type_concept = df_done["type_concept"].values == df_result["type_concept"].values
print(f"{'type_concept col is correct:':<28} {check_type_concept.all()}")
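
A more compact way to run this kind of check is `pandas.testing.assert_frame_equal`, which compares all columns at once and reports where the frames differ. The sketch below uses small invented frames as stand-ins for `df_result` and `df_done`. It assumes the computed frame keeps row labels left over from filtering, and that dtype differences should be ignored.

```python
import pandas as pd
from pandas.testing import assert_frame_equal

# Toy stand-ins for df_result (expected) and df_done (computed)
expected = pd.DataFrame(
    {
        "person_id": [1, 1],
        "start_date": pd.to_datetime(["2020-01-01", "2022-01-01"]),
        "end_date": pd.to_datetime(["2020-12-01", "2022-01-01"]),
        "type_concept": [2, 2],
    }
)
computed = expected.copy()
computed.index = [0, 3]  # row labels left over from filtering

# Drop the leftover index and ignore dtype differences before comparing
assert_frame_equal(
    computed.reset_index(drop=True),
    expected.reset_index(drop=True),
    check_dtype=False,
)
print("frames match")
```

If the frames differ, the call raises an `AssertionError` that names the offending column, which is usually quicker to act on than a list of `True`/`False` lines.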

### 2.2.2 Attempt with a recursive function (NOT FINISHED)

THIS IS KEPT HERE FOR FUTURE REFERENCE. THE CODE IS NOT FINISHED BECAUSE THE INDEX METHOD WORKS WELL ENOUGH AND THERE IS NO GUARANTEE THAT THIS WOULD IMPROVE ON IT.

It looks like the best option (cough) is to repeat the previous strategy and iterate recursively. That way we also make sure we can keep track of the type_concept values and keep the most representative one.

In [None]:
# def find_neighbors_index(df: pd.DataFrame,
#                          n_days: int) -> pd.Series:

#     # 1. Check that current and next patient are the same
#     idx_person = df.iloc[:, 0] == df.iloc[:, 0].shift(-1)
#     # 2. Check that current end_date and next start_date
#     # are closer than n_days
#     idx_interval = (
#         (df.iloc[:, 2] - df.iloc[:, 1].shift(-1)) <=
#         pd.Timedelta(n_days, unit='D')
#     )
#     # 4. If everything past is true, I can drop the row
#     return idx_person & idx_interval


# def remove_all_neighbors_recursive_v1(
#         df: pd.DataFrame,
#         n_days: int,
#         verbose: int = 0,
#         _counter: int = 0,
#         _counter_lim: int = 1000) -> pd.DataFrame:

#     # Get the rows
#     idx_to_remove = find_neighbors_index(df, n_days)
#     # Prepare next loop
#     idx_to_remove_sum = idx_to_remove.sum()
#     _counter += 1
#     # If there's still room to go, go
#     if (idx_to_remove_sum != 0) and (_counter < _counter_lim):
#         if verbose >= 1:
#             print(f"Iter {_counter} => {idx_to_remove_sum} rows removed.")
#         if verbose >= 2:
#             print(df[idx_to_remove].head(10))

#         # Modify end_dates
#         df.iloc[:,2] = np.where(idx_to_remove,
#                                 df.iloc[:,2].shift(-1),
#                                 df.iloc[:,2])
#         return remove_all_neighbors_recursive_v1(
#             df[idx_to_remove], verbose, _counter)
#     else:
#         return df

# n_days = 365
# df_rare = df_raw.sort_values(
#     ['person_id', 'start_date', 'end_date'],
#     ascending=[True, True, True])
# df_rare = remove_all_neighbors_recursive_v1(df_rare, n_days, verbose=2)
# df_done = df_rare.sort_index()
# df_done

In [None]:
# n_days = 365
# df = df_raw.sort_values(
#     ['person_id', 'start_date', 'end_date'],
#     ascending=[True, True, True])

# # >>> Iter 1
# idx_to_remove = find_neighbors_index(df, n_days)
# print(f"Iter {1} => {idx_to_remove.sum()} rows removed.")
# print(df[idx_to_remove])
# # <<<

# # Record changes
# df['to_join'] = idx_to_remove
# df['new_end_date'] = np.where(df['to_join'],
#                               df.iloc[:, 2].shift(-1),
#                               df.iloc[:, 2])
# df

In [None]:
# # >>> # Iter 2
# df.iloc[:,2] = np.where(idx_to_remove,
#                         df.iloc[:,2].shift(-1),
#                         df.iloc[:,2])
# idx_to_remove = find_neighbors_index(df, n_days)
# # <<<

# # Record changes
# df['to_join'] = idx_to_remove
# df['new_end_date'] = np.where(df['to_join'],
#                               df.iloc[:, 2].shift(-1),
#                               df.iloc[:, 2])
# df

In [None]:
# # >>> # Iter 3
# df.iloc[:,2] = np.where(idx_to_remove,
#                         df.iloc[:,2].shift(-1),
#                         df.iloc[:,2])
# idx_to_remove = find_neighbors_index(df, n_days)
# # <<<

# # Record changes
# df['to_join'] = idx_to_remove
# df['new_end_date'] = np.where(df['to_join'],
#                               df.iloc[:, 2].shift(-1),
#                               df.iloc[:, 2])
# df

## 2.3 Putting it all together

Now we put everything into a single function so that we can time it.

In [None]:
import numpy as np
import pandas as pd
import scipy.stats as st


def find_person_index(df: pd.DataFrame) -> tuple[pd.Series, pd.Series, pd.Series]:
    """Flags the first, last and only row of each person_id.
    Assumes the rows are already sorted by person_id.

    Parameters
    ----------
    df : pd.DataFrame
        pandas Dataframe with at least three columns.
        Assumes first column is person_id, second column is
        start_date and third column is end_date

    Returns
    -------
    tuple[pd.Series, pd.Series, pd.Series]
        Tuple with three pandas Series with bools:
        - idx_person_first, True if first row of a person with several rows
        - idx_person_last, True if last row of a person with several rows
        - idx_person_only, True if only row of the person
        False otherwise.
    """

    # Create index for first, last or only person in dataset
    idx_person_first = (df.iloc[:, 0] == df.iloc[:, 0].shift(-1)) & (
        df.iloc[:, 0] != df.iloc[:, 0].shift(1)
    )
    idx_person_last = (df.iloc[:, 0] != df.iloc[:, 0].shift(-1)) & (
        df.iloc[:, 0] == df.iloc[:, 0].shift(1)
    )
    idx_person_only = (df.iloc[:, 0] != df.iloc[:, 0].shift(-1)) & (
        df.iloc[:, 0] != df.iloc[:, 0].shift(1)
    )
    return (idx_person_first, idx_person_last, idx_person_only)


def group_dates(df: pd.DataFrame, n_days: int) -> pd.DataFrame:
    """Groups rows of dates from the same person that are less
    than n_days apart, keeping only the first start_date and
    the last end_date, respectively.

    Rows that partially overlap the previous one are merged
    as well.

    Parameters
    ----------
    df : pd.DataFrame
        pandas dataframe with at least four columns:
        ['person_id', 'start_date', 'end_date', 'type_concept'].
        Column names do not need to be the same, but the order
        must be the same as here.
        This allows its use for different tables with columns
        that have the same purpose but different names.
    n_days : int
        Minimum gap, in days, between an end_date and the next
        start_date for the two rows to stay in separate periods.

    Returns
    -------
    pd.DataFrame
        Copy of input dataframe with grouped rows.
    """

    # == Preparation ==============================================
    # Sort so we know for sure the order is right
    df_rare = df.copy().sort_values(
        [df.columns[0], df.columns[1], df.columns[2]], ascending=[True, True, False]
    )
    # It is VERY important to reset the index to make sure we can
    # retrieve the rows reliably after sorting them.
    df_rare = df_rare.reset_index(drop=True)

    # == Index look-up ============================================
    (idx_person_first, idx_person_last, idx_person_only) = find_person_index(df_rare)
    # Create index if the break is too big and needs to be kept
    next_interval = df_rare.iloc[:, 1].shift(-1) - df_rare.iloc[:, 2]
    idx_interval = next_interval >= pd.Timedelta(n_days, unit="D")

    # == Retrieve relevant rows ===================================
    # -- start_date and person_id ---------------------------------
    # To retrieve the start_date we need the indexes of:
    # - people with a single row (idx_person_only == True)
    # - first rows of each person (idx_person_first == True)
    # - rows just after period breaks (idx_interval shifted by one row)

    # Get the person condition indexes
    idx_start = df_rare.index[
        idx_person_only | idx_person_first | idx_interval.shift(1)
    ]

    # -- end_date -------------------------------------------------
    # Get the indexes
    idx_end = df_rare.index[idx_person_only | idx_person_last | idx_interval]

    # == Compute type_concept =====================================
    # Iterate over idx_start and idx_end to get the periods
    mode_values = []
    for i in np.arange(len(idx_start)):
        df_tmp = df_rare.loc[idx_start[i] : idx_end[i]]
        mode_values.append(st.mode(df_tmp.iloc[:, 3].values)[0])

    # == Build final dataframe ====================================
    # Create a copy (.loc) with the first two columns
    df_done = df_rare.loc[idx_start, [df.columns[0], df.columns[1]]]
    # Append values found to final dataframe
    df_done[df.columns[2]] = df_rare.loc[idx_end, [df.columns[2]]].values
    # Add to dataframe
    df_done[df.columns[3]] = mode_values

    return df_done

In [None]:
# == Parameters ==
n_days = 365

# == Group the data ==
df_done = group_dates(df_raw, n_days)
df_done

Let's check again that the result is correct:

In [None]:
check_person = df_done["person_id"].values == df_result["person_id"].values
print(f"{'Person_id col is correct:':<28} {check_person.all()}")
check_start_date = df_done["start_date"].values == df_result["start_date"].values
print(f"{'start_date col is correct:':<28} {check_start_date.all()}")
check_end_date = df_done["end_date"].values == df_result["end_date"].values
print(f"{'end_date col is correct:':<28} {check_end_date.all()}")
check_type_concept = df_done["type_concept"].values == df_result["type_concept"].values
print(f"{'type_concept col is correct:':<28} {check_type_concept.all()}")

In [None]:
%timeit -n 10 -r 10 group_dates(df_rare,n_days)

# 3. Test with large datasets
Let's check whether the pyarrow method is still faster with large datasets.

We bring in the function that generates datasets:

In [None]:
import numpy as np
import pandas as pd
import pyarrow as pa
from pyarrow import parquet


def create_sample_df(
    n: int = 1000,
    n_dates: int = 50,
    first_date: str = "2020-01-01",
    last_date: str = "2023-01-01",
    mean_duration_days: int = 60,
    std_duration_days: int = 180,
) -> pd.DataFrame:

    # == Parameters ==
    np.random.seed(42)
    pd.options.mode.string_storage = "pyarrow"
    # Start date from which to start the dates
    first_date = pd.to_datetime(first_date)
    last_date = pd.to_datetime(last_date)
    max_days = (last_date - first_date).days

    # == Generate IDs randomly ==
    # -- Generate the Ids
    people = np.random.randint(10000000, 99999999 + 1, size=n)
    person_id = np.random.choice(people, n * n_dates)

    # == Generate random dates ==
    # Generate random integers for days and convert to timedelta
    random_days = np.random.randint(0, max_days, size=n * n_dates)
    # Create the columns
    observation_start_date = first_date + pd.to_timedelta(random_days, unit="D")
    # Generate a gaussian sample of dates
    random_days = np.random.normal(
        mean_duration_days, std_duration_days, size=n * n_dates
    )
    random_days = np.int32(random_days)
    observation_end_date = observation_start_date + pd.to_timedelta(
        random_days, unit="D"
    )
    # Correct end_dates
    # => If they are smaller than start_date, take start_date
    observation_end_date = np.where(
        observation_end_date < observation_start_date,
        observation_start_date,
        observation_end_date,
    )

    # == Generate the code ==
    period_type_concept_id = np.random.randint(1, 11, size=n * n_dates)

    # == Generate the dataframe ==
    df_raw = {
        "person_id": person_id,
        "observation_period_start_date": observation_start_date,
        "observation_period_end_date": observation_end_date,
        "period_type_concept_id": period_type_concept_id,
    }
    return pd.DataFrame(df_raw)

We bring in the pyarrow function as it was on 12/09/2024:

In [None]:
import bps_to_omop.general as gen
import pyarrow as pa
import pyarrow.compute as pc
import numpy as np

# Add the parent directory to the path so that we can import
# the functions from the ETL* folders
import sys
import os

# Add the parent directory of func_folder to the Python path
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))


def group_dates_original_pyarrow(table_done, n_days):
    # -- Thirdly, group up dates -----------------------------------------------
    # Group the dates using group_person_dates(). It basically computes the
    # time gap between adjacent rows of each person and merges them
    # if the gap is below the limit set by n_days.
    table_OBSERVATION_PERIOD = []

    person_list = pc.unique(table_done["person_id"])
    # Percentage points where you want to print progress
    for i, person in enumerate(person_list[:]):
        # --Group person
        table_person = group_person_dates(table_done, person, n_days)
        # Append table
        table_OBSERVATION_PERIOD.append(table_person)
    # Concatenate
    table_OBSERVATION_PERIOD = pa.concat_tables(table_OBSERVATION_PERIOD)
    return table_OBSERVATION_PERIOD


def group_person_dates(
    table_rare: pa.Table, person: str | int, n_days: int
) -> pa.Table:
    """Filters original table for a specific person and reduces
    the amount of date records grouping all records that are separated
    by n_days or less.

    Parameters
    ----------
    table_rare : pa.Table
        Table as prepared by 'prepare_table_raw_to_rare()'.
    person : str | int
        person id, can be an int (the usual) or a string.
    n_days : int
        number of maximum days between subsequent records.

    Returns
    -------
    pa.Table
        Table identical to table_rare but with less date records.
    """

    # Filter for the current person_id
    filt = pc.is_in(
        table_rare["person_id"], pa.array([person])  # pylint: disable=E1101
    )
    table_person = table_rare.filter(filt)
    # Retrieve corresponding dates
    start_dates = table_person["start_date"]
    end_dates = table_person["end_date"]
    # Group dates closer
    start_dates, end_dates, _ = group_observation_dates(
        start_dates, end_dates, n_days, verbose=False
    )
    # Create person
    person_id = pa_utils.create_uniform_int_array(len(start_dates), value=person)
    # Retrieve most common period type
    period_type_concept_id = pc.mode(  # pylint: disable=E1101
        table_person["period_type_concept_id"]
    )[0][0]
    period_type_concept_id = pa_utils.create_uniform_int_array(
        len(start_dates), value=period_type_concept_id
    )
    # return table
    return pa.Table.from_arrays(
        [person_id, start_dates, end_dates, period_type_concept_id],
        names=["person_id", "start_date", "end_date", "period_type_concept_id"],
    )


def group_observation_dates(
    start_dates: pa.Array, end_dates: pa.Array, n_days: int, verbose: bool = False
) -> tuple[pa.Array, pa.Array, None | pa.Table]:
    """Given a pair of 'start_dates' and 'end_dates', it will
    compute the days between each 'end_date' and the next
    'start_date' and merge consecutive records whose gap is
    smaller than a given number of days ('n_days').
    The new dates will only contain start and end dates that have
    more than 'n_days' of difference between them.

    If dates contain nans/nulls, they will be ignored and grouped
    with the closest dates.

    Parameters
    ----------
    start_dates : pa.Array
        Array of start dates
    end_dates : pa.Array
        Array of end dates
    n_days : int
        Minimum gap, in days, between an end date and the next
        start date for the two records to stay in separate periods.
    verbose : bool, optional
        If True, also build a table with the dates and the
        intervals between them, by default False

    Returns
    -------
    tuple[pa.Array, pa.Array, None | pa.Table]
        Always returns a 3-item tuple.
        First item is reduced start dates.
        Second item is reduced end dates.
        Third item is None if verbose=False;
        if verbose=True, it is a table with start_dates,
        end_dates and days between them. Useful when
        verifying dates.

    Raises
    ------
    AssertionError
        The resulting starting dates should always come before
        their corresponding end dates. Raises an AssertionError
        otherwise.
    """
    # Get an array of end_dates, taking away the last one
    # (last date cannot be compared to the next start date)
    from_dates = end_dates[:-1]
    # Get an array of start_dates, taking away the first one
    # (first date cannot be compared to the previous end date)
    to_dates = start_dates[1:]

    # -- Compute days between
    intervals = pc.days_between(from_dates, to_dates).to_numpy(  # pylint: disable=E1101
        zero_copy_only=False
    )
    # Create an inner table for the calculations if verbose
    inner_table = None
    if verbose:
        inner_table = pa.Table.from_arrays(
            [start_dates, end_dates, pa.array(np.append(intervals, np.nan))],
            names=["start", "end", "intervals"],
        )

    # Filter intervals under some assumption
    filt = intervals >= n_days
    # => When this filt is 'true', it means that for that index,
    # let's call it 'idx', there are at least 'n_days' days between
    # the end date of 'idx' and the start date of 'idx+1'.
    # i.e.:
    # (start_date[idx+1] - end_date[idx]).days >= n_days

    # if no interval is greater, take the first and last rows
    if np.nansum(filt) == 0:
        idx_end_dates = np.array([len(intervals)])
        # The only start date is the first row
        idx_start_dates = np.array([0])

    # If some filters exist take those
    else:
        # Get indexes of end_dates
        idx_end_dates = filt.nonzero()[0]
        # Sum 1 to get corresponding start dates
        idx_start_dates = idx_end_dates + 1
        # Append last entry as last end_date
        idx_end_dates = np.append(idx_end_dates, len(intervals))
        # Append first entry as first start_date
        idx_start_dates = np.append(0, idx_start_dates)

    # if verbose:
    #     print(f'{idx_start_dates=}')
    #     print(f'{idx_end_dates=}')

    # Make sure all end values are after start values
    new_start = start_dates.take(idx_start_dates)
    new_end = end_dates.take(idx_end_dates)
    if pc.any(pc.less(new_end, new_start)).as_py():  # pylint: disable=E1101
        if verbose:
            print(f"{start_dates=}", f"{end_dates=}")
            print(f"{new_start=}", f"{new_end=}")
        raise AssertionError(
            "Some end dates happen before start dates. Try sorting the original data."
        )

    return (new_start, new_end, inner_table)
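
The index bookkeeping in `group_observation_dates` can be seen in isolation with a small NumPy sketch. It is a simplification under assumptions: the helper name `period_bounds` and the intervals are invented, and the two branches of the original function are collapsed into one, since appending to an empty array of breaks gives the same "first row to last row" result.

```python
import numpy as np


def period_bounds(intervals: np.ndarray, n_days: int):
    """Row positions where each grouped period starts and ends."""
    # Positions whose gap to the next row is n_days or more
    breaks = np.flatnonzero(intervals >= n_days)
    # A period starts at row 0 and right after every break
    idx_start = np.append(0, breaks + 1)
    # A period ends at every break and at the last row
    idx_end = np.append(breaks, len(intervals))
    return idx_start, idx_end


# Four rows produce three gaps; only the second gap reaches the threshold
starts, ends = period_bounds(np.array([10, 400, 5]), n_days=365)
print(starts, ends)

# No gap reaches the threshold: a single period from first to last row
starts_one, ends_one = period_bounds(np.array([10, 20, 5]), n_days=365)
print(starts_one, ends_one)
```

The start and end arrays always have the same length, one entry per period, which is why they can be paired directly with `take` in the function above.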

## 3.1. Sanity check
### DATA CREATION

First we check that both functions return the same results.

In [None]:
n_days = 365

# Load the data
df_raw = create_sample_df(
    n=100,
    n_dates=10,
)
df_raw.columns = ["person_id", "start_date", "end_date", "type_concept"]

df_raw = df_raw.sort_values(
    ["person_id", "start_date", "end_date", "type_concept"],
    ascending=[True, True, False, True],
)

Let's check how quick `remove_overlap()` is:

In [None]:
%timeit -n 10 -r 10 remove_overlap(df_raw,0,False)

Quite quick!

In [None]:
# Remove contained dates
df_raw = remove_overlap(df_raw, 0, False)

In [None]:
df_raw[df_raw["person_id"] == 10271836]

In [None]:
df_raw[df_raw["person_id"] == 23315092]

### PYARROW

In [None]:
# == pyarrow method ==
df_rare = df_raw.copy()
df_rare.columns = ["person_id", "start_date", "end_date", "period_type_concept_id"]
table_rare = pa.Table.from_pandas(df_rare, preserve_index=False)
table_done = group_dates_original_pyarrow(table_rare, n_days)
df_done_pyarrow = table_done.to_pandas()
df_done_pyarrow = df_done_pyarrow.sort_values(
    ["person_id", "start_date", "end_date", "period_type_concept_id"],
    ascending=[True, True, False, True],
)

In [None]:
df_done_pyarrow[df_done_pyarrow["person_id"] == 10271836]

In [None]:
df_done_pyarrow[df_done_pyarrow["person_id"] == 23315092]

The pyarrow method handles the first person (10271836) correctly: all of their dates are less than 365 days apart, so they are merged into a single period. The following people all have a single date each, up to 23315092, who has two. That one is handled correctly too.


### SHIFT

In [None]:
# == Shift method ==
df_rare = df_raw.copy()
df_done_shift = group_dates(df_rare, n_days)

In [None]:
df_done_shift[df_done_shift["person_id"] == 10271836]

In [None]:
df_done_shift[df_done_shift["person_id"] == 23315092]

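The shift method now also handles correctly both the first person (10271836) and the one with two periods (23315092). The type_concept differs between the pyarrow and the shift methods, probably because `group_person_dates` takes the mode over all of a person's rows while `group_dates` takes it per period; I trust the shift result more at this point.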

## 3.2 Time measurement
Now we measure how long each one takes:

In [None]:
n_days = 365

# Load the data
df_raw = create_sample_df(n=100)
df_raw.columns = ['person_id', 'start_date', 'end_date', 'type_concept']

df_raw = df_raw.sort_values(
    ['person_id', 'start_date', 'end_date', 'type_concept'],
    ascending=[True, True, False, True])
df_rare = df_raw.reset_index(drop=True).copy()

print('\nshift:')
%timeit -n 1 -r 1 group_dates(df_rare,n_days)

df_rare = df_raw.reset_index(drop=True).copy()
df_rare.columns = ['person_id', 'start_date', 'end_date', 'period_type_concept_id']
table_rare = pa.Table.from_pandas(df_rare,preserve_index=False)
print('\npyarrow:')
%timeit -n 1 -r 1 group_dates_original_pyarrow(table_rare,n_days)

For n = 1000

    shift:
    375 ms ± 884 μs per loop (mean ± std. dev. of 10 runs, 10 loops each)
    pyarrow:
    454 ms ± 664 μs per loop (mean ± std. dev. of 10 runs, 10 loops each)

For n = 10000

    shift:
    3.69 s ± 4.92 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
    pyarrow:
    23.6 s ± 565 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)

For n = 30000

    shift:
    10.2 s ± 18.5 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)
    pyarrow:
    3min 10s ± 5.35 s per loop (mean ± std. dev. of 2 runs, 2 loops each)
 

It is now fairly clear that the shift method is much faster when there are many people. The original method stacks tables on top of one another, which costs a lot of time. This holds even though the shift method sorts inside the function itself, whereas for the pyarrow method the sorting is left outside the timed call.

If the shift pattern could be implemented in pyarrow, it might do even better. Still, 10 s for 30000 patients with 50 dates per patient already seems like a good result.
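To illustrate why stacking tables one on top of another is costly, here is a minimal sketch in pandas (not the pyarrow code of the original method). The helper names `grow_by_concat` and `concat_once` are hypothetical. Growing a result inside a loop copies everything accumulated so far at every step, while collecting the pieces and concatenating once copies each row a single time.

```python
import pandas as pd


def grow_by_concat(chunks):
    """Stack chunks one at a time; each step copies the whole result so far."""
    out = chunks[0]
    for chunk in chunks[1:]:
        out = pd.concat([out, chunk], ignore_index=True)
    return out


def concat_once(chunks):
    """Stack all chunks in a single call; each row is copied once."""
    return pd.concat(chunks, ignore_index=True)


# Hypothetical per-person results, one small dataframe each
chunks = [pd.DataFrame({"person_id": [i], "n_periods": [1]}) for i in range(5)]

# Both approaches give the same table; only the amount of copying differs
same = grow_by_concat(chunks).equals(concat_once(chunks))
print(same)
```

The shift approach goes one step further and avoids the per-person pieces altogether, since it works on the whole sorted dataframe at once.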

## Timestamp vs datetime test

We found that the remove_overlap code runs much faster when the dates are in timestamp format (pa.timestamp('us')) than when they are in datetime format (pa.date64()).

The problem is that the final data in the `sarscov` project differ depending on which format is used.

We run the code here to check.

In [None]:
df_raw = create_sample_df(n=1000)
df_raw.columns = ["person_id", "start_date", "end_date", "type_concept"]

In [None]:
n_days = 365
df_rare = df_raw.loc[:,['person_id','start_date','end_date','type_concept']]
table_raw = pa.Table.from_pandas(df_rare, preserve_index=False)
table_raw = table_raw.cast(
    pa.schema([
        ('person_id', pa.int64()),
        ('start_date', pa.timestamp('us')),
        ('end_date', pa.timestamp('us')),
        ('type_concept', pa.int64()),
    ])
)
df_rare = table_raw.to_pandas()
remove_overlap(df_rare,2).info()

%timeit -n 10 -r 10 remove_overlap(df_rare, 2)

In [None]:
n_days = 365
df_rare = df_raw.loc[:,['person_id','start_date','end_date','type_concept']]
table_raw = pa.Table.from_pandas(df_rare, preserve_index=False)
table_raw = table_raw.cast(
    pa.schema([
        ('person_id', pa.int64()),
        ('start_date', pa.date64()),
        ('end_date', pa.date64()),
        ('type_concept', pa.int64()),
    ])
)
df_rare = table_raw.to_pandas()
remove_overlap(df_rare,2).info()

%timeit -n 10 -r 10 remove_overlap(df_rare, 2)

Both formats work, leaving the same number of rows.

The problem may come from the fact that some dates in the project data include a time of day, for example all the pharmacy dispensing records. If these records are converted to date64, the time information is lost, so the resulting order may differ.
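A small sketch of how losing the time of day could change the order. The two records are made up, and `dt.normalize()` in pandas is used here only as a stand-in for the loss of time information described above, not as the actual pyarrow cast. With the time kept, the morning record sorts first. Without it, both start dates tie and the descending end_date decides.

```python
import pandas as pd

# Hypothetical records: same person, same calendar day, different times
df = pd.DataFrame(
    {
        "person_id": [1, 1],
        "start_date": pd.to_datetime(["2020-01-01 09:00", "2020-01-01 18:00"]),
        "end_date": pd.to_datetime(["2020-01-01 09:00", "2020-01-05 00:00"]),
    }
)

sort_cols = ["person_id", "start_date", "end_date"]
ascending = [True, True, False]

# Time kept: the 09:00 record (index 0) comes first
with_time = df.sort_values(sort_cols, ascending=ascending)

# Time dropped: start dates tie, so the longer record (index 1) comes first
date_only = df.assign(
    start_date=df["start_date"].dt.normalize(),
    end_date=df["end_date"].dt.normalize(),
)
without_time = date_only.sort_values(sort_cols, ascending=ascending)

print(with_time.index.tolist(), without_time.index.tolist())
```

Any step that depends on which row comes first, such as deciding which of two overlapping rows to drop, could then give a different result for the two formats.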

# 4. New method for computing type_concept

We compare the current group_dates method with using groupby.

In [None]:
import sys
import pandas as pd
import numpy as np
import scipy.stats as st

sys.path.append("../../")
from bps_to_omop.general import group_dates, find_person_index


def create_sample_data():
    nombre_columnas = ["person_id", "start_date", "end_date", "type_concept"]
    df_in = [
        # A single date
        (1, "2020-01-01", "2020-02-01", 1),
        # Two dates that get merged, with equal type_concept
        (2, "2020-01-01", "2020-02-01", 1),
        (2, "2020-03-01", "2020-04-01", 1),
        # Two dates that get merged, with different type_concept
        (3, "2020-01-01", "2020-02-01", 1),
        (3, "2020-03-01", "2020-04-01", 2),
        # Three dates that get merged
        (4, "2020-01-01", "2020-02-01", 1),
        (4, "2020-03-01", "2020-04-01", 1),
        (4, "2020-05-01", "2020-12-01", 2),
        # One person with two separate groups
        (5, "2020-01-01", "2020-02-01", 1),
        (5, "2020-03-01", "2020-04-01", 1),
        (5, "2020-05-01", "2020-12-01", 2),
        (5, "2022-01-01", "2022-02-01", 3),
        (5, "2022-03-01", "2022-04-01", 3),
        (5, "2022-05-01", "2022-12-01", 2),
    ]
    df_in = pd.DataFrame.from_records(df_in, columns=nombre_columnas).assign(
        start_date=lambda x: pd.to_datetime(x["start_date"]),
        end_date=lambda x: pd.to_datetime(x["end_date"]),
    )
    return df_in


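def group_dates_v2(df: pd.DataFrame, n_days: int, verbose: int = 0) -> pd.DataFrame:
    """Merge each person's date ranges that are less than n_days apart and
    assign each resulting period its most frequent type_concept.

    Columns are read by position: person_id, start_date, end_date, type_concept.
    """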
    # == Preparation ==============================================
    if verbose > 0:
        print("Grouping dates:")
        print("- Sorting and preparing data...")
    # Sort so we know for sure the order is right
    df_rare = df.copy().sort_values(
        [df.columns[0], df.columns[1], df.columns[2]], ascending=[True, True, False]
    )
    # It is VERY important to reset the index to make sure we can
    # retrieve them reliably after sorting them.
    df_rare = df_rare.reset_index(drop=True)

    # == Index look-up ============================================
    if verbose > 0:
        print("- Looking up indexes...")
    (idx_person_first, idx_person_last, idx_person_only) = find_person_index(df_rare)
    # Create index if the break is too big and needs to be kept
    next_interval = df_rare.iloc[:, 1].shift(-1) - df_rare.iloc[:, 2]
    idx_interval = next_interval >= pd.Timedelta(n_days, unit="D")

    # == Retrieve relevant rows ===================================
    if verbose > 0:
        print("- Retrieving rows...")
    # -- start_date and person_id ---------------------------------
    # To retrieve the start_date we need the indexes of:
    # - single day periods (idx_person_only == True)
    # - first dates (idx_person_first == True)
    # - Rows just after period breaks, (idx_interval.index + 1)

    # Get the person condition indexes
    # fill_value=False keeps the shifted mask boolean (no NaN in the first row)
    idx_start = df_rare.index[
        idx_person_only | idx_person_first | idx_interval.shift(1, fill_value=False)
    ]

    # -- end_date -------------------------------------------------
    # Get the indexes
    idx_end = df_rare.index[idx_person_only | idx_person_last | idx_interval]

    # == Compute type_concept =====================================
    if verbose > 0:
        print("- Computing type_concept...")
    # Iterate over idx_start and idx_end to get the periods
    mode_values = []
    for i in np.arange(len(idx_start)):
        df_tmp = df_rare.loc[idx_start[i] : idx_end[i]]
        mode_values.append(st.mode(df_tmp.iloc[:, 3].values)[0])

        # max(1, ...) avoids a division by zero when there are fewer than 4 periods
        if (verbose > 1) and (i % max(1, len(idx_start) // 4) == 0):
            print(f"  - ({(i+1)/len(idx_start)*100:.1f} %) {(i+1)}/{len(idx_start)}")
    if verbose > 1:
        print(f"  - (100.0 %) {len(idx_start)}/{len(idx_start)}")

    # == Build final dataframe ====================================
    if verbose > 0:
        print("- Closing up...")
    # Create a copy (.loc) with the first two columns
    df_done = df_rare.loc[idx_start, [df.columns[0], df.columns[1]]]
    # Append values found to final dataframe
    df_done[df.columns[2]] = df_rare.loc[idx_end, [df.columns[2]]].values
    # Add to dataframe
    df_done[df.columns[3]] = mode_values

    if verbose > 0:
        print("- Done!")
    return df_done

In [None]:
verbose = 1
n_days = 365

df = create_sample_data()
df

**(!!)**

The idea here is that we look up the indexes and manually build a dataframe with the start and end dates.

We CANNOT use the groupby trick for the type_concept directly, because we do not know the final intervals beforehand.

That is, we can group by person_id, but we would also have to group by the dates in order to get, for each person and each observation_period, the most frequent type_concept.
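One possible way around this, sketched here under assumptions and not tested against the real data, is to first give every row a period label and then group by person and label. The column `period_id` is a hypothetical helper that does not exist in the original code. The sketch also assumes the rows are already sorted and free of overlaps.

```python
import pandas as pd

# Hypothetical toy data, already sorted by person_id and start_date
df = pd.DataFrame(
    {
        "person_id": [5, 5, 5, 5, 5],
        "start_date": pd.to_datetime(
            ["2020-01-01", "2020-03-01", "2020-05-01", "2022-01-01", "2022-03-01"]
        ),
        "end_date": pd.to_datetime(
            ["2020-02-01", "2020-04-01", "2020-12-01", "2022-02-01", "2022-04-01"]
        ),
        "type_concept": [1, 1, 2, 3, 3],
    }
)
n_days = 365

# A row opens a new period if it is the person's first row (gap is NaT)
# or if it starts n_days or more after the previous row ended
gap = df["start_date"] - df.groupby("person_id")["end_date"].shift(1)
new_period = gap.isna() | (gap >= pd.Timedelta(days=n_days))

# The running count of those flags labels every row with its period
df["period_id"] = new_period.cumsum()

# One groupby then gives the dates and the most frequent type_concept
periods = df.groupby(["person_id", "period_id"], as_index=False).agg(
    start_date=("start_date", "min"),
    end_date=("end_date", "max"),
    type_concept=("type_concept", lambda s: s.mode().iloc[0]),
)
print(periods)
```

When several values tie for most frequent, `mode()` returns them sorted, so this sketch keeps the smallest. Whether that matches the tie-breaking of `st.mode` in the loop version would need to be checked.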

In [None]:
# == Preparation ==============================================
if verbose > 0:
    print("Grouping dates:")
    print("- Sorting and preparing data...")
# Sort so we know for sure the order is right
df_rare = df.copy().sort_values(
    [df.columns[0], df.columns[1], df.columns[2]], ascending=[True, True, False]
)
# It is VERY important to reset the index to make sure we can
# retrieve them reliably after sorting them.
df_rare = df_rare.reset_index(drop=True)

# == Index look-up ============================================
if verbose > 0:
    print("- Looking up indexes...")
(idx_person_first, idx_person_last, idx_person_only) = find_person_index(df_rare)
# Create index if the break is too big and needs to be kept
next_interval = df_rare.iloc[:, 1].shift(-1) - df_rare.iloc[:, 2]
idx_interval = next_interval >= pd.Timedelta(n_days, unit="D")

# == Retrieve relevant rows ===================================
if verbose > 0:
    print("- Retrieving rows...")
# -- start_date and person_id ---------------------------------
# To retrieve the start_date we need the indexes of:
# - single day periods (idx_person_only == True)
# - first dates (idx_person_first == True)
# - Rows just after period breaks, (idx_interval.index + 1)

# Get the person condition indexes
# fill_value=False keeps the shifted mask boolean (no NaN in the first row)
idx_start = df_rare.index[
    idx_person_only | idx_person_first | idx_interval.shift(1, fill_value=False)
]

# -- end_date -------------------------------------------------
# Get the indexes
idx_end = df_rare.index[idx_person_only | idx_person_last | idx_interval]

# == Compute type_concept =====================================
if verbose > 0:
    print("- Computing type_concept...")
# Iterate over idx_start and idx_end to get the periods
mode_values = []
for i in np.arange(len(idx_start)):
    df_tmp = df_rare.loc[idx_start[i] : idx_end[i]]
    mode_values.append(st.mode(df_tmp.iloc[:, 3].values)[0])

    # max(1, ...) avoids a division by zero when there are fewer than 4 periods
    if (verbose > 1) and (i % max(1, len(idx_start) // 4) == 0):
        print(f"  - ({(i+1)/len(idx_start)*100:.1f} %) {(i+1)}/{len(idx_start)}")
if verbose > 1:
    print(f"  - (100.0 %) {len(idx_start)}/{len(idx_start)}")

# == Build final dataframe ====================================
if verbose > 0:
    print("- Closing up...")
# Create a copy (.loc) with the first two columns
df_done = df_rare.loc[idx_start, [df.columns[0], df.columns[1]]]
# Append values found to final dataframe
df_done[df.columns[2]] = df_rare.loc[idx_end, [df.columns[2]]].values
# Add to dataframe
df_done[df.columns[3]] = mode_values

if verbose > 0:
    print("- Done!")
df_done

# A. Unfinished testing code

In [None]:
import pandas as pd
import numpy as np
import time
from typing import Tuple


def assign_groups_masking(
    indices: np.ndarray, starts: np.ndarray, ends: np.ndarray
) -> np.ndarray:
    """Group assignment using boolean masking"""
    group_ids = np.zeros(len(indices), dtype=int)
    for group_num, (start, end) in enumerate(zip(starts, ends), 1):
        mask = (indices >= start) & (indices <= end)
        group_ids[mask] = group_num
    return group_ids


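def assign_groups_searchsorted(
    indices: np.ndarray, starts: np.ndarray, ends: np.ndarray
) -> np.ndarray:
    """Group assignment using searchsorted (1-based ids, 0 = no group)"""
    # Assumes the groups are sorted and do not overlap
    boundaries = np.sort(np.concatenate([starts, ends + 1]))
    pos = np.searchsorted(boundaries, indices, side="right")
    # An odd position means the index lies inside a [start, end] interval
    return np.where(pos % 2 == 1, (pos + 1) // 2, 0)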


def generate_test_case(n_rows: int, n_groups: int) -> Tuple[np.ndarray, np.ndarray]:
    """Generate test data with given size and number of groups"""
    # Create roughly equal-sized groups
    group_size = n_rows // n_groups
    starts = np.arange(0, n_rows, group_size)
    ends = starts + group_size - 1
    ends[-1] = n_rows - 1  # Adjust last group
    return starts, ends


def run_benchmark():
    # Test configurations
    row_sizes = [10_000, 100_000, 1_000_000]
    group_configs = [
        ("Few Large Groups", lambda x: max(5, x // 1_000_000)),
        ("Medium Groups", lambda x: max(50, x // 100_000)),
        ("Many Small Groups", lambda x: max(500, x // 10_000)),
    ]

    results = []

    for n_rows in row_sizes:
        indices = np.arange(n_rows)

        for group_desc, group_func in group_configs:
            n_groups = group_func(n_rows)
            starts, ends = generate_test_case(n_rows, n_groups)

            # Warm-up run
            _ = assign_groups_masking(indices, starts, ends)
            _ = assign_groups_searchsorted(indices, starts, ends)

            # Timing masking approach
            start_time = time.perf_counter()
            for _ in range(5):  # Multiple runs for more stable results
                _ = assign_groups_masking(indices, starts, ends)
            masking_time = (time.perf_counter() - start_time) / 5

            # Timing searchsorted approach
            start_time = time.perf_counter()
            for _ in range(5):
                _ = assign_groups_searchsorted(indices, starts, ends)
            searchsorted_time = (time.perf_counter() - start_time) / 5

            results.append(
                {
                    "Rows": n_rows,
                    "Groups": n_groups,
                    "Configuration": group_desc,
                    "Masking Time": masking_time,
                    "Searchsorted Time": searchsorted_time,
                }
            )

    return pd.DataFrame(results)


# Run benchmark
results_df = run_benchmark()

# Print detailed results
print("\nDetailed Benchmark Results:")
print("=" * 80)
for _, row in results_df.iterrows():
    print(f"\nConfiguration: {row['Configuration']}")
    print(f"Data Size: {row['Rows']:,} rows, {row['Groups']:,} groups")
    print(f"Masking Time: {row['Masking Time']*1000:.2f}ms")
    print(f"Searchsorted Time: {row['Searchsorted Time']*1000:.2f}ms")
    speedup = row["Masking Time"] / row["Searchsorted Time"]
    faster_method = "searchsorted" if speedup > 1 else "masking"
    # Report the factor from the winner's point of view
    factor = speedup if speedup > 1 else 1 / speedup
    print(f"Winner: {faster_method} ({factor:,.2f}x faster)")

# Calculate and print summary statistics
print("\nSummary Statistics:")
print("=" * 80)
for config in results_df["Configuration"].unique():
    config_results = results_df[results_df["Configuration"] == config]
    print(f"\n{config}:")
    avg_speedup = (
        config_results["Masking Time"] / config_results["Searchsorted Time"]
    ).mean()
    print(f"Average speedup using searchsorted: {avg_speedup:.2f}x")

# Validation of correctness
print("\nValidating correctness of implementations...")
test_indices = np.arange(1000)
test_starts = np.array([0, 200, 400, 600, 800])
test_ends = np.array([199, 399, 599, 799, 999])

masking_results = assign_groups_masking(test_indices, test_starts, test_ends)
searchsorted_results = assign_groups_searchsorted(test_indices, test_starts, test_ends)

if np.array_equal(masking_results, searchsorted_results):
    print("✓ Both implementations produce identical results")
else:
    print("⚠ WARNING: Implementations produce different results!")