# Balanced Risk Set Matching: Implementation

## Notations Used

| Notation/Variable | Description | Section |
| --- | --- | ---|
| $T_m$ | Time when patient $m$ received treatment. If $T_m$ = $\inf$, the patient was not treated. | 1.2 (Risk Set Matching) |
| $Y_i(t)$ | Observed symptoms (e.g., pain, urgency) of patient $i$ at time $t$ | 1.2 (Risk Set Matching) |
| SymptomHistory(T) | Aggregated symptom history up to time $T$. | 1.2 (Risk Set Matching) |
| $\mathcal{A}$ | Set of all patients included in the study | 2.1 (Optimal Balanced Matching) |
| $\mathcal{T}$ | Subset of treated patients in $\mathcal{A}$ | 2.1 (Optimal Balanced Matching) | 
| $\mathcal{A}$ - $\mathcal{T}$ | Subset of controls in $\mathcal{A}$ | 2.1 (Optimal Balanced Matching) |
| $\mathcal{E}$ | Set of all possible edges $\mathcal{e}$ |  2.1 (Optimal Balanced Matching) |
| $\mathcal{e}$ = ($\alpha_p$, $\alpha_q$) | An edge connecting a treated patient $\alpha_p$ with a potential control $\alpha_q$ | 2.1 (Optimal Balanced Matching) |
| $\delta_e$ | Distance between two patients connected by edge $\mathcal{e}$, often measured using Mahalanobis distance. | 2.1 (Optimal Balanced Matching) |
| $S$ | Number of matched pairs. | 2.1 (Optimal Balanced Matching) |
| $M$ | Set of matched pairs $S\subset\mathcal{E}$ | 2.1 (Optimal Balanced Matching) |
| $B_{pk},B_{ek}$ | Binary variables representing attributes (e.g., symptoms or covariates) for treated/control patients. | 2.2 (Balanced Pair Matching) |
| $g_k^+,g_k^-$ | Gap variables indicating positive or negative imbalance in binary variables. | Appendix |
| $\mathcal{L}_t(a)$ | Risk set: Patients with observed covariates $a$ at time $t$. | 4.3 (Matching on Observed Histories |

In [1]:
# Importing Libraries
import numpy as np # For Data Manipulation
import pandas as pd # For Data visualization

from scipy.spatial.distance import mahalanobis
from ortools.graph.python import min_cost_flow

## I. Interstitial Cystitis Data Set

The data, or the patients, in the study is represented in a vector of covariates (i.e. a variable that is observed or measured in a study, which may influence the outcome)

These covariates include (There were SIX Covariates in the study):
- Baseline Values:
  - Pain at baseline ($P_{baseline}$)
  - Urgency at baseline ($U_{baseline}$)
  - Nocturnal Frequency of voiding i.e. peeing at night ($F_{baseline}$}
- Values at treatment time ($T_p$):
  - Pain at the time of treatment ($P_{T_p}$)
  - Urgency at the time of treatment ($U_{T_p}$)
  - Nocturnal Frequency of voiding at the time of treatment i.e. peeing at night ($F_{T_p}$)

These are time-dependent covariates because their values evolve over time.

Thus, this represents a patient vector with the six covariates:

\begin{equation}
a_p = (P_{baseline}, U_{baseline}, F_{baseline}, P_{T_p}, U_{T_p}, F_{T_p})
\end{equation}

NOTE: I also added more variables like an indicator if the patient is treated (Treatment Status) and if so, how long has he/she/they been treated (Treatment Time).

### Loading the Data Set

In [2]:
# Sample Interstitial Data

# Step 1: Loading the Data
df = pd.read_csv('integer_synthetic_data.csv')
print("Sum (NULL): ", df.isnull().sum())
print("Dataset (No. of Patients, No. of Variables): ", df.shape)
df.head()

Sum (NULL):  Patient ID                      0
Baseline Pain                   0
Baseline Urgency                0
Baseline Frequency              0
Treatment Status                0
Treatment Time                  0
Treatment Pain                  0
Treatment Urgency               0
Treatment Frequency             0
Treatment Pain (3 mos)         47
Treatment Pain (6 mos)         39
Treatment Urgency (3 mos)      44
Treatment Urgency (6 mos)      34
Treatment Frequency (3 mos)    46
Treatment Frequency (6 mos)    40
dtype: int64
Dataset (No. of Patients, No. of Variables):  (400, 15)


Unnamed: 0,Patient ID,Baseline Pain,Baseline Urgency,Baseline Frequency,Treatment Status,Treatment Time,Treatment Pain,Treatment Urgency,Treatment Frequency,Treatment Pain (3 mos),Treatment Pain (6 mos),Treatment Urgency (3 mos),Treatment Urgency (6 mos),Treatment Frequency (3 mos),Treatment Frequency (6 mos)
0,P001,4,4,4,1,0,4,1,3,3.0,2.0,1.0,1.0,3.0,3.0
1,P002,5,6,5,0,-1,5,6,6,,,5.0,,4.0,4.0
2,P003,5,5,3,0,-1,6,6,2,5.0,4.0,4.0,4.0,2.0,2.0
3,P004,6,4,4,0,-1,7,4,5,6.0,6.0,4.0,4.0,3.0,4.0
4,P005,5,6,3,0,-1,6,6,3,6.0,6.0,,,2.0,4.0


### Cleaning the Data Set

In [3]:
# Filling in the Empty Values
df['Treatment Pain'] = df['Treatment Pain'].fillna(-1).astype(int)
df['Treatment Urgency'] = df['Treatment Urgency'].fillna(-1).astype(int)
df['Treatment Frequency'] = df['Treatment Frequency'].fillna(-1).astype(int)
df['Treatment Time'] = df['Treatment Time'].fillna(-1).astype(int)
df['Treatment Pain (3 mos)'] = df['Treatment Pain (3 mos)'].fillna(-1).astype(int)
df['Treatment Urgency (3 mos)'] = df['Treatment Urgency (3 mos)'].fillna(-1).astype(int)
df['Treatment Frequency (3 mos)'] = df['Treatment Frequency (3 mos)'].fillna(-1).astype(int)
df['Treatment Pain (6 mos)'] = df['Treatment Pain (6 mos)'].fillna(-1).astype(int)
df['Treatment Urgency (6 mos)'] = df['Treatment Urgency (6 mos)'].fillna(-1).astype(int)
df['Treatment Frequency (6 mos)'] = df['Treatment Frequency (6 mos)'].fillna(-1).astype(int)
print("Sum (NULL): ", df.isnull().sum())
display(df[:10])

Sum (NULL):  Patient ID                     0
Baseline Pain                  0
Baseline Urgency               0
Baseline Frequency             0
Treatment Status               0
Treatment Time                 0
Treatment Pain                 0
Treatment Urgency              0
Treatment Frequency            0
Treatment Pain (3 mos)         0
Treatment Pain (6 mos)         0
Treatment Urgency (3 mos)      0
Treatment Urgency (6 mos)      0
Treatment Frequency (3 mos)    0
Treatment Frequency (6 mos)    0
dtype: int64


Unnamed: 0,Patient ID,Baseline Pain,Baseline Urgency,Baseline Frequency,Treatment Status,Treatment Time,Treatment Pain,Treatment Urgency,Treatment Frequency,Treatment Pain (3 mos),Treatment Pain (6 mos),Treatment Urgency (3 mos),Treatment Urgency (6 mos),Treatment Frequency (3 mos),Treatment Frequency (6 mos)
0,P001,4,4,4,1,0,4,1,3,3,2,1,1,3,3
1,P002,5,6,5,0,-1,5,6,6,-1,-1,5,-1,4,4
2,P003,5,5,3,0,-1,6,6,2,5,4,4,4,2,2
3,P004,6,4,4,0,-1,7,4,5,6,6,4,4,3,4
4,P005,5,6,3,0,-1,6,6,3,6,6,-1,-1,2,4
5,P006,4,4,2,0,-1,4,5,3,3,5,5,5,-1,2
6,P007,5,5,4,1,6,3,4,4,-1,1,-1,1,-1,1
7,P008,5,6,4,0,-1,5,6,5,6,5,7,6,-1,5
8,P009,4,4,4,0,-1,5,3,3,5,5,4,3,4,3
9,P010,4,5,4,1,0,2,5,2,1,1,3,1,1,1


# II. Performing Risk Set Matching

Let's go over to the terms described in the section, 1.2 Risk Set Matching:
- **Risk Set**: The risk set consists of all patients who are at risk of receiving treatment at a given time $T_m$. When a patient $m$ receives treatment, they are matched with another patient from this risk set who has similar symptom profiles but has not yet received treatment by that time.
- **Risk Set Matching**: It is designed to compare a treated patient with a control patient who has not yet received treatment but has as similar history of symptoms up to the time of treatment. This approach aims to create comparable groups by ensuring that the distributions of symptoms are balanced.

Mathematical Representation of the Risk Set:

\begin{equation}
R(T_m) = \{j : T_j > T_m \text{ or } T_j = ∞\}
\end{equation}

Note: We are getting the patients that were treated earlier than the current patient and also the ones that are yet to be treated

In summary, risk sets are collections of untreated patients eligible for matching against treated patients based on their symptom histories, rather than being pairs themselves. The aim is to achieve balanced and comparable groups for analysis in observational studies.

In [4]:
# Function to construct risk sets
def construct_risk_set(df):
    risk_sets = {}

    for index, treated in df[df['Treatment Status'] == 1].iterrows():
        T_treated = treated['Treatment Time']

        # Select patients who were either never treated (NaN) or treated after T_treated
        eligible_controls = df[(df['Treatment Time'].isna()) | 
                               (df['Treatment Time'] > T_treated)]
        
        # Store the risk set
        risk_sets[treated['Patient ID']] = eligible_controls[['Patient ID', 'Baseline Pain', 'Baseline Urgency', 'Baseline Frequency']]

    return risk_sets

In [5]:
# Constructing Risk Set
risk_sets = construct_risk_set(df)

# Example output: Print risk set for the first treated patient
for index, (patient, controls) in enumerate(risk_sets.items()):
    if index >= 3:  # Printing Only 3 Treated Patients and their Eligible Controls
        break
    print(f"Treated Patient: {patient}")
    print("Eligible Control Patients:")
    print(controls)
    print("-" * 50)

Treated Patient: P001
Eligible Control Patients:
    Patient ID  Baseline Pain  Baseline Urgency  Baseline Frequency
6         P007              5                 5                   4
14        P015              4                 5                   3
31        P032              5                 7                   7
47        P048              5                 6                   5
55        P056              5                 5                   3
58        P059              4                 6                   3
64        P065              5                 6                   4
81        P082              4                 6                   3
86        P087              4                 6                   5
92        P093              5                 5                   4
114       P115              5                 6                   5
120       P121              5                 6                   3
188       P189              7                 6                   5

## III. Performing Optimal Matching using Minimum Cost Flow in a Network

In this section, we describe the optimal matching process using minimum cost flow algorithms, which allow for efficient pairing of treated and control patients based on their covariate profiles.

**Overview of Minimum Cost Flow** <br>
The minimum cost flow algorithm is a network-based optimization method that seeks to minimize the total cost associated with transporting goods through a network while satisfying supply and demand constraints. In the context of patient matching, "goods" represent the matched pairs of treated and control patients, and "cost" corresponds to the distance (or dissimilarity) between their covariate profiles.

**Constraints** <br>
1. Each treated patient must be matched with exactly one control patient
2. Each control patient can be matched with at most one treated patient

In [6]:
# Helper Function
def compute_cov_matrix(df, features):
    cov_matrix = np.cov(df[features].dropna().values.T)
    return np.linalg.inv(cov_matrix)  # Return inverse covariance matrix

In [7]:
def minimum_cost_flow_matching(risk_sets, df, features=['Baseline Pain', 'Baseline Urgency', 'Baseline Frequency']):
    """ Solves the optimal matching using minimum cost flow in a network """
    min_cost_flow_net = min_cost_flow.SimpleMinCostFlow()

    # Node mapping and control usage tracking
    treated_nodes = list(risk_sets.keys())
    control_nodes = []
    control_usage = {}

    # Build control nodes list and count risk set appearances
    for controls in risk_sets.values():
        for _, control in controls.iterrows():
            control_id = control['Patient ID']
            if control_id not in control_usage:
                control_nodes.append(control_id)
                control_usage[control_id] = 0
            control_usage[control_id] += 1

    print(f"🔹 Treated: {len(treated_nodes)}, Controls: {len(control_nodes)}")

    # Create node index mapping
    node_index_map = {pid: i for i, pid in enumerate(treated_nodes + control_nodes)}
    source = len(node_index_map)
    sink = source + 1

    # Compute covariance matrix
    cov_inv = compute_cov_matrix(df, features)

    # Add treated-control edges with costs
    for treated_id, controls in risk_sets.items():
        treated_idx = node_index_map[treated_id]
        treated_data = df[df['Patient ID'] == treated_id][features].values[0]
        
        for _, control in controls.iterrows():
            control_id = control['Patient ID']
            control_idx = node_index_map[control_id]
            control_data = control[features].values
            
            distance = mahalanobis(treated_data, control_data, cov_inv)
            if np.isnan(distance):
                continue
                
            # Add edge with capacity 1 and scaled cost
            min_cost_flow_net.add_arc_with_capacity_and_unit_cost(
                treated_idx, control_idx, 1, int(distance * 100)
            )

    # Connect source to treated nodes (1 unit each)
    for treated_id in treated_nodes:
        min_cost_flow_net.add_arc_with_capacity_and_unit_cost(
            source, node_index_map[treated_id], 1, 0
        )

    # Connect control nodes to sink with MULTIPLE USAGE capacity
    for control_id in control_nodes:
        capacity = control_usage[control_id]
        min_cost_flow_net.add_arc_with_capacity_and_unit_cost(
            node_index_map[control_id], sink, capacity, 0
        )

    # Set supply/demand (match all treated patients)
    min_cost_flow_net.set_node_supply(source, len(treated_nodes))
    min_cost_flow_net.set_node_supply(sink, -len(treated_nodes))

    # Solve and extract matches
    if min_cost_flow_net.solve() == min_cost_flow_net.OPTIMAL:
        matches = {}
        matched_controls = set()  # Set to track matched controls
        
        for i in range(min_cost_flow_net.num_arcs()):
            if min_cost_flow_net.flow(i) > 0:
                start = min_cost_flow_net.tail(i)
                end = min_cost_flow_net.head(i)
                
                if start < len(treated_nodes) and end >= len(treated_nodes):
                    treated_pid = treated_nodes[start]
                    control_pid = control_nodes[end - len(treated_nodes)]
                    
                    # Check if this control patient has already been matched
                    if control_pid not in matched_controls:
                        matches[treated_pid] = control_pid
                        matched_controls.add(control_pid)  # Mark this control as matched
                    
        print(f"✅ Successfully matched {len(matches)} pairs")
        return matches
    else:
        print("❌ No solution: Possible reasons -")
        print("- Not enough eligible controls for all treated patients")
        print("- Network constraints too restrictive")
        print("- Extreme distance values causing cost overflow")
        return {}

In [8]:
# Execute Risk Set Matching
risk_sets = construct_risk_set(df)

# Perform Optimal Matching
matches = minimum_cost_flow_matching(risk_sets, df)

print("Number of Edges: ", len(matches))

# Print results
# for treated, control in matches.items():
#     print(f"Treated Patient: {treated} → Matched Control: {control}")

🔹 Treated: 36, Controls: 19
✅ Successfully matched 3 pairs
Number of Edges:  3


In [9]:
# Check for duplicate matched control patients
matched_controls = [control for _, control in matches.items()]
if len(matched_controls) == len(set(matched_controls)):
    print("All matches are unique!")
else:
    print("Duplicate matches found! duplicates: ", len(matched_controls) - len(set(matched_controls)))

All matches are unique!


## IV. Balanced Pair Matching Using Integer Programming (IN PROGRESS)

In [10]:
from ortools.linear_solver import pywraplp

def balanced_pair_matching_ip(df, covariate_pairs, quantiles=[0.33, 0.66], penalty=1e6):
    # Binarize covariates (using pooled quantiles)
    binary_features = []
    for baseline_col, treatment_col in covariate_pairs:
        # Use pooled data for quantiles
        treated_data = df[df['Treatment Status'] == 1][baseline_col]
        control_data = df[df['Treatment Status'] == 0][baseline_col]
        pooled_data = pd.concat([treated_data, control_data])
        q1 = pooled_data.quantile(quantiles[0])
        q2 = pooled_data.quantile(quantiles[1])
        
        for col in [baseline_col, treatment_col]:
            df[f'{col}_leq_q1'] = (df[col] <= q1).astype(int)
            df[f'{col}_leq_q2'] = (df[col] <= q2).astype(int)
            binary_features.extend([f'{col}_leq_q1', f'{col}_leq_q2'])
    
    # Check control count
    treated = df[df['Treatment Status'] == 1].reset_index(drop=True)
    controls = df[df['Treatment Status'] == 0].reset_index(drop=True)
    if len(controls) < len(treated):
        raise ValueError("Insufficient controls for matching.")
    
    # Integer programming setup
    solver = pywraplp.Solver.CreateSolver('SCIP')
    pairs = [(t, c) for t in treated.index for c in controls.index]
    x = {(t, c): solver.BoolVar(f'x_{t}_{c}') for t, c in pairs}
    
    # Matching constraints
    for t in treated.index:
        solver.Add(solver.Sum(x[(t, c)] for c in controls.index) == 1)
    for c in controls.index:
        solver.Add(solver.Sum(x[(t, c)] for t in treated.index) <= 1)
    
    # Balance constraints with penalties
    objective = solver.Objective()
    for feature in binary_features:
        treated_total = solver.Sum(x[(t, c)] * treated.loc[t, feature] for t, c in pairs)
        control_total = solver.Sum(x[(t, c)] * controls.loc[c, feature] for t, c in pairs)
        imbalance = treated_total - control_total
        solver.Add(imbalance <= penalty)
        solver.Add(imbalance >= -penalty)
        objective.SetCoefficient(imbalance, 1)  # Penalize absolute imbalance
    
    # Mahalanobis distance
    cov_features = [col for pair in covariate_pairs for col in pair]
    df[cov_features] = df[cov_features].fillna(df[cov_features].median())  # Impute missing
    cov_matrix = np.cov(df[cov_features].values.T)
    cov_inv = np.linalg.inv(cov_matrix)
    
    for (t, c) in pairs:
        delta = mahalanobis(
            treated.loc[t, cov_features].values,
            controls.loc[c, cov_features].values,
            cov_inv
        )
        objective.SetCoefficient(x[(t, c)], delta)
    objective.SetMinimization()
    
    # Solve
    status = solver.Solve()
    if status == pywraplp.Solver.OPTIMAL:
        return [(treated.loc[t, 'Patient ID'], controls.loc[c, 'Patient ID']) 
                for (t, c) in pairs if x[(t, c)].solution_value() > 0.5]
    else:
        print("No solution. Try relaxing constraints or checking data.")
        return []

In [11]:
# Define covariate pairs to binarize and balance
covariate_pairs = [
    ('Baseline Pain', 'Treatment Pain'),
    ('Baseline Urgency', 'Treatment Urgency'),
    ('Baseline Frequency', 'Treatment Frequency')
]

# Call the integrated function
balanced_matches = balanced_pair_matching_ip(
    df, 
    covariate_pairs, 
    quantiles=[0.33, 0.66]  # Tertiles
)

# Print results
for treated_id, control_id in balanced_matches:
    print(f"Treated Patient: {treated_id} → Balanced Matched Control: {control_id}")

SystemError: <built-in function Objective_SetCoefficient> returned NULL without setting an error

## V. Graphical Comparisons (IN PROGRESS)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Step 1: Prepare the data for box plotsf
# Assume that Treatment Time is in months and we have hypothetical values for post-treatment scores
df['Treatment Pain (3 Months)'] = df['Treatment Pain'] - np.random.randint(0, 3, size=len(df)) # Simulating some change
df['Treatment Pain (6 Months)'] = df['Treatment Pain'] - np.random.randint(0, 5, size=len(df)) # Simulating some change

# Step 2: Calculate differences
df['Difference Pain (3 Months)'] = df['Treatment Pain (3 Months)'] - df['Baseline Pain']
df['Difference Pain (6 Months)'] = df['Treatment Pain (6 Months)'] - df['Baseline Pain']

# Create a function to plot box plots for each covariate category
def plot_boxplots(df):
    categories = {
        "Baseline": ['Baseline Pain', 'Baseline Urgency', 'Baseline Frequency'],
        "At Treatment": ['Treatment Pain', 'Treatment Urgency', 'Treatment Frequency'],
        "3 Months Post-Treatment": ['Treatment Pain (3 Months)', 'Treatment Urgency', 'Treatment Frequency'],
        "6 Months Post-Treatment": ['Treatment Pain (6 Months)', 'Treatment Urgency', 'Treatment Frequency'],
        "Difference (3 Months)": ['Difference Pain (3 Months)'],
        "Difference (6 Months)": ['Difference Pain (6 Months)']
    }

    for title, vars in categories.items():
        plt.figure(figsize=(12, 6))
        
        # Melt the DataFrame for the current category variables
        melted_df = pd.melt(df[df['Treatment Status'].isin([0, 1])], 
                             id_vars=['Patient ID', 'Treatment Status'], 
                             value_vars=vars,
                             var_name='Covariate',
                             value_name='Score')

        # Create a box plot for the current category
        sns.boxplot(x='Covariate', y='Score', hue='Treatment Status', data=melted_df)
        plt.title(f'Box Plots of {title}')
        plt.xticks(rotation=45)
        plt.ylabel('Score')
        plt.xlabel('Covariate')
        plt.legend(title='Group', labels=['Never Treated', 'Treated'])
        plt.tight_layout()
        plt.show()

# Call the function to create box plots
plot_boxplots(df)