<a href="https://colab.research.google.com/github/ccorbett0116/Fall2025ResearchProject/blob/main/Research_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Project Title:
# Authors: Jose Henriquez, Cole Corbett
## Description:
The deployment of medical AI systems across different hospitals raises critical questions about whether fairness and representation quality can be reliably transferred across clinical domains. Models trained on one hospital’s imaging data are often reused in new environments where patient demographics, imaging devices, and diagnostic practices differ substantially, potentially resulting in unintended bias against certain groups. This project investigates this challenge by studying fairness-aware representation alignment in medical imaging.

The student will train contrastive learning models, such as SimCLR, independently on two large-scale chest X-ray datasets: CheXpert (from Stanford Hospital) and MIMIC-CXR (from Beth Israel Deaconess Medical Center). After learning embeddings in each domain, the student will apply domain alignment techniques such as Procrustes alignment to map representations from the CheXpert embedding space into the MIMIC-CXR space. The aligned embeddings will then be evaluated using fairness metrics designed for representation spaces, including demographic subgroup alignment, intra- vs. inter-group embedding disparity, and cluster-level demographic parity.

The expected outcome is a rigorous understanding of whether fairness properties learned in one hospital setting are preserved, degraded, or improved when transferred to another, revealing how robust model fairness is to real-world clinical domain shifts. A practical use case involves a healthcare network seeking to deploy a model trained at a major academic hospital (e.g., Stanford) into a community hospital setting: this project helps determine whether the transferred representations remain equitable across patient groups such as older adults, women, or specific disease cohorts. The findings will support responsible AI deployment in healthcare by highlighting the conditions under which fairness is stable across institutions and identifying scenarios where domain-specific mitigation strategies may be required.

In [1]:
# Installation may differ on Colab; this setup targets a PyCharm kernel connected to WSL.
import sys
!{sys.executable} -m pip install kagglehub polars
# Polars is used instead of pandas because it is typically much faster: it is built on Rust, runs multi-threaded, and uses memory more efficiently.


In [2]:
#Again, this is probably different on colab
import kagglehub
path_chexpert = kagglehub.dataset_download("mimsadiislam/chexpert")
print("Path to chexpert dataset files:", path_chexpert)
path_mimic = kagglehub.dataset_download("simhadrisadaram/mimic-cxr-dataset")
print("Path to mimic dataset files:", path_mimic)

Downloading from https://www.kaggle.com/api/v1/datasets/download/mimsadiislam/chexpert?dataset_version_number=1...
100%|██████████| 10.7G/10.7G [02:08<00:00, 89.4MB/s]
Extracting files...
Path to chexpert dataset files: C:\Users\joseh\.cache\kagglehub\datasets\mimsadiislam\chexpert\versions\1
Downloading from https://www.kaggle.com/api/v1/datasets/download/simhadrisadaram/mimic-cxr-dataset?dataset_version_number=2...
100%|██████████| 16.5G/16.5G [03:30<00:00, 84.2MB/s]
Extracting files...
Path to mimic dataset files: C:\Users\joseh\.cache\kagglehub\datasets\simhadrisadaram\mimic-cxr-dataset\versions\2


In [None]:
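import os

# Create output directories for model checkpoints and saved embeddings
os.makedirs("./checkpoints", exist_ok=True)
os.makedirs("./embeddings", exist_ok=True)

# Inspect the MIMIC-CXR download (last expression in the cell, so the notebook displays it)
os.listdir(path_mimic)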

In [5]:
import polars as pl
import os

dir_chexpert = os.path.join(path_chexpert, "CheXpert-v1.0-small")
dir_mimic = path_mimic

train_csv_chexpert = os.path.join(dir_chexpert, "train.csv")
train_csv_mimic = os.path.join(dir_mimic, "mimic_cxr_aug_train.csv")
valid_csv_chexpert = os.path.join(dir_chexpert, "valid.csv")
valid_csv_mimic = os.path.join(dir_mimic, "mimic_cxr_aug_validate.csv")

df_train_chexpert = pl.read_csv(train_csv_chexpert)
df_train_mimic = pl.read_csv(train_csv_mimic)
df_valid_chexpert = pl.read_csv(valid_csv_chexpert)
df_valid_mimic = pl.read_csv(valid_csv_mimic)

In [6]:
df_train_chexpert.head()

Path,Sex,Age,Frontal/Lateral,AP/PA,No Finding,Enlarged Cardiomediastinum,Cardiomegaly,Lung Opacity,Lung Lesion,Edema,Consolidation,Pneumonia,Atelectasis,Pneumothorax,Pleural Effusion,Pleural Other,Fracture,Support Devices
str,str,i64,str,str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""CheXpert-v1.0-small/train/pati…","""Female""",68,"""Frontal""","""AP""",1.0,,,,,,,,,0.0,,,,1.0
"""CheXpert-v1.0-small/train/pati…","""Female""",87,"""Frontal""","""AP""",,,-1.0,1.0,,-1.0,-1.0,,-1.0,,-1.0,,1.0,
"""CheXpert-v1.0-small/train/pati…","""Female""",83,"""Frontal""","""AP""",,,,1.0,,,-1.0,,,,,,1.0,
"""CheXpert-v1.0-small/train/pati…","""Female""",83,"""Lateral""",,,,,1.0,,,-1.0,,,,,,1.0,
"""CheXpert-v1.0-small/train/pati…","""Male""",41,"""Frontal""","""AP""",,,,,,1.0,,,,0.0,,,,


In [7]:
df_train_mimic.head()

Unnamed: 0.1,Unnamed: 0,subject_id,image,view,AP,PA,Lateral,text,text_augment
i64,i64,i64,str,str,str,str,str,str,str
0,0,10000032,"""['files/p10/p10000032/s5041426…","""['PA', 'LATERAL', 'AP']""","""['files/p10/p10000032/s5391176…","""['files/p10/p10000032/s5041426…","""['files/p10/p10000032/s5041426…","""['Findings: There is no focal …","""['Findings: There is no focus,…"
1,1,10000764,"""['files/p10/p10000764/s5737596…","""['AP', 'LATERAL']""","""['files/p10/p10000764/s5737596…","""[]""","""['files/p10/p10000764/s5737596…","""['Findings: PA and lateral vie…","""['Finds: PA and lateral view o…"
2,2,10000898,"""['files/p10/p10000898/s5077138…","""['LATERAL', 'PA']""","""[]""","""['files/p10/p10000898/s5077138…","""['files/p10/p10000898/s5077138…","""['Findings: PA and lateral vie…","""['Finds: PA and side view of t…"
3,3,10000935,"""['files/p10/p10000935/s5057897…","""['AP', 'LATERAL', 'LL', 'PA']""","""['files/p10/p10000935/s5057897…","""['files/p10/p10000935/s5569729…","""['files/p10/p10000935/s5117837…","""['Findings: Lung volumes remai…","""['Results: Pulmonary volumes r…"
4,4,10000980,"""['files/p10/p10000980/s5098509…","""['PA', 'LL', 'AP', 'LATERAL']""","""['files/p10/p10000980/s5196728…","""['files/p10/p10000980/s5098509…","""['files/p10/p10000980/s5457736…","""['Findings: Impression: Compa…","""['Findings: Impression: Compar…"
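
The `image`, `view`, `AP`, `PA`, and `Lateral` columns in the MIMIC table hold Python-style lists stored as strings, so they need parsing before individual image paths can be used. The following is a minimal sketch of one way to flatten them with `ast.literal_eval`; `explode_image_lists` is a hypothetical helper name, and the sample values are made up for illustration rather than taken from the dataset.

```python
import ast

def explode_image_lists(subject_ids, image_cells):
    """Flatten stringified path lists into (subject_id, path) pairs."""
    rows = []
    for sid, cell in zip(subject_ids, image_cells):
        # literal_eval safely parses "['a.jpg', 'b.jpg']" into a Python list
        for path in ast.literal_eval(cell):
            rows.append((sid, path))
    return rows

# Hypothetical values shaped like the table (not real file names)
subject_ids = [1, 2]
image_cells = ["['files/a/view1.jpg', 'files/a/view2.jpg']", "[]"]
pairs = explode_image_lists(subject_ids, image_cells)
print(pairs)
```

Subjects with an empty list simply contribute no rows, which keeps the flattened result aligned one-to-one with image files.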


In [10]:
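# Install PyTorch into the same interpreter the kernel uses (as in the first cell).
# The cu121 index only hosts wheels for certain Python versions; if pip reports
# "No matching distribution found", choose the index URL that matches your
# Python/CUDA combination from the PyTorch install selector.
import sys
!{sys.executable} -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121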


Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://download.pytorch.org/whl/cu121


ERROR: Could not find a version that satisfies the requirement torch (from versions: none)
ERROR: No matching distribution found for torch


In [None]:
import os

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as T


# Mild augmentations for SimCLR on chest X-rays.
# Note: a horizontal flip mirrors left/right anatomy; drop it if laterality should be preserved.
simclr_aug = T.Compose([
    T.RandomResizedCrop(224, scale=(0.8, 1.0)),
    T.RandomHorizontalFlip(),
    T.RandomRotation(10),
    T.GaussianBlur(3),
    T.ToTensor(),
])

class XRaySimCLRDataset(Dataset):
    def __init__(self, df, root_dir, transform=None, path_col="Path"):
        self.df = df
        self.root_dir = root_dir
        self.transform = transform
        # df[idx] on a Polars DataFrame returns a one-row DataFrame, not a row,
        # so pull the path column into a plain list once and index that instead.
        self.paths = df[path_col].to_list()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.paths[idx])
        img = Image.open(img_path).convert("RGB")

        if self.transform:
            # Return two independently augmented views of the same image
            return self.transform(img), self.transform(img)
        else:
            return img, img


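In [None]:
# Prepare dataloaders
import ast
from torch.utils.data import DataLoader

# CheXpert "Path" values already start with "CheXpert-v1.0-small/", so the root is the download directory
train_dataset_chexpert = XRaySimCLRDataset(df_train_chexpert, path_chexpert, simclr_aug)
train_loader_chexpert = DataLoader(train_dataset_chexpert, batch_size=128, shuffle=True, num_workers=4)

# MIMIC has no "Path" column: "image" holds a stringified list of paths per subject, so flatten it
# (assumes the "files/..." paths are relative to the download directory)
mimic_paths = [p for cell in df_train_mimic["image"].to_list() for p in ast.literal_eval(cell)]
df_train_mimic_paths = pl.DataFrame({"Path": mimic_paths})
train_dataset_mimic = XRaySimCLRDataset(df_train_mimic_paths, dir_mimic, simclr_aug)
train_loader_mimic = DataLoader(train_dataset_mimic, batch_size=128, shuffle=True, num_workers=4)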


In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F

class ProjectionHead(nn.Module):
    def __init__(self, in_dim, hidden_dim=2048, out_dim=128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim)
        )

    def forward(self, x):
        return self.mlp(x)

class SimCLR(nn.Module):
    def __init__(self, out_dim=128):
        super().__init__()
        resnet = models.resnet50(weights=None)
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])
        self.projector = ProjectionHead(2048, 2048, out_dim)

    def forward(self, x):
        # flatten(1) keeps the batch dimension even when the batch holds one sample
        h = self.encoder(x).flatten(1)
        z = self.projector(h)
        return h, z


In [None]:
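def nt_xent_loss(z, temperature=0.5):
    # z stacks both views: rows i and i + N are the positive pair
    z = F.normalize(z, dim=1)
    n = z.size(0) // 2
    logits = torch.matmul(z, z.T) / temperature
    # Mask self-similarity so a sample cannot be selected as its own positive
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # Target for row i is its partner view: i + N in the first half, i - N in the second
    labels = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    return F.cross_entropy(logits, labels)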
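
As a cross-check on the contrastive objective, here is a minimal NumPy sketch of NT-Xent under the assumption that the batch stacks two views so that rows `i` and `i + N` form a positive pair. The function name `nt_xent_numpy` and the toy data are illustrative only. The loss should be lower when the second view really belongs to the same image than when the pairing is scrambled.

```python
import numpy as np

def nt_xent_numpy(z, temperature=0.5):
    """NT-Xent over 2N rows where rows i and i + N are two views of one image."""
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # unit vectors -> cosine similarity
    two_n = z.shape[0]
    n = two_n // 2
    logits = (z @ z.T) / temperature
    np.fill_diagonal(logits, -np.inf)  # a row may not pick itself
    positives = np.concatenate([np.arange(n, two_n), np.arange(0, n)])
    # Row-wise log-softmax, then read off the entry of the positive partner
    row_max = logits.max(axis=1, keepdims=True)
    shifted = logits - row_max
    log_prob = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_prob[np.arange(two_n), positives].mean()

rng = np.random.default_rng(0)
views_a = rng.normal(size=(4, 8))
aligned = np.concatenate([views_a, views_a])                 # second view matches the first
shuffled = np.concatenate([views_a, views_a[[1, 2, 3, 0]]])  # second view is another image
loss_aligned = nt_xent_numpy(aligned)
loss_shuffled = nt_xent_numpy(shuffled)
print(loss_aligned < loss_shuffled)
```

A loss that treats each row as its own positive would not separate these two cases, which makes this comparison a quick regression test for the PyTorch version.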


In [None]:
#CheXpert Dataset training loop
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimCLR().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

epochs = 10  # adjust for full training

for epoch in range(epochs):
    for x1, x2 in train_loader_chexpert:
        x1, x2 = x1.to(device), x2.to(device)

        _, z1 = model(x1)
        _, z2 = model(x2)

        z = torch.cat([z1, z2], dim=0)
        loss = nt_xent_loss(z)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")

# Save checkpoint
torch.save(model.state_dict(), "./checkpoints/simclr_chexpert.pth")


In [None]:
#MIMIC Dataset training loop
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimCLR().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

epochs = 10  # adjust for full training

for epoch in range(epochs):
    for x1, x2 in train_loader_mimic:
        x1, x2 = x1.to(device), x2.to(device)

        _, z1 = model(x1)
        _, z2 = model(x2)

        z = torch.cat([z1, z2], dim=0)
        loss = nt_xent_loss(z)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")

# Save checkpoint
torch.save(model.state_dict(), "./checkpoints/simclr_mimic.pth")


In [None]:
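import copy

def extract_embeddings(model, loader, device=device):
    model.eval()
    embeddings = []
    with torch.no_grad():
        for x1, _ in loader:
            x1 = x1.to(device)
            h, _ = model(x1)
            embeddings.append(h.cpu())
    return torch.cat(embeddings)

def eval_loader(dataset):
    # Same images, but deterministic preprocessing and no shuffling, so that
    # embedding row i corresponds to row i of the source dataframe
    ds = copy.copy(dataset)
    ds.transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])
    return DataLoader(ds, batch_size=128, shuffle=False, num_workers=4)

# `model` holds the MIMIC weights after the last training loop, so the CheXpert
# encoder is reloaded from its checkpoint
model_chexpert = SimCLR().to(device)
model_chexpert.load_state_dict(torch.load("./checkpoints/simclr_chexpert.pth", map_location=device))

emb_chexpert = extract_embeddings(model_chexpert, eval_loader(train_dataset_chexpert))
torch.save(emb_chexpert, "./embeddings/chexpert.pt")
emb_mimic = extract_embeddings(model, eval_loader(train_dataset_mimic))
torch.save(emb_mimic, "./embeddings/mimic.pt")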



In [None]:
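def procrustes_alignment(X, Y):
    # Orthogonal Procrustes: the orthogonal R that minimizes ||X @ R.T - Y||_F.
    # X and Y must be row-paired (same shape, row i of X corresponds to row i of Y).
    U, _, Vt = torch.linalg.svd(Y.T @ X)
    return U @ Vt

# CheXpert and MIMIC differ in size and have no image-to-image correspondence, so
# truncating to the first n rows is only a placeholder that makes the shapes agree;
# a principled anchor set (e.g. matched on label or demographics) is still needed.
n = min(len(emb_chexpert), len(emb_mimic))
R = procrustes_alignment(emb_chexpert[:n], emb_mimic[:n])
aligned_chexpert = emb_chexpert @ R.T  # apply the fitted rotation to every CheXpert embedding
torch.save(aligned_chexpert, "./embeddings/chexpert_aligned.pt")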
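
A quick way to gain confidence in the SVD-based alignment is to test it on synthetic data where the true rotation is known. The sketch below uses NumPy rather than PyTorch, and `procrustes_rotation` is a hypothetical stand-in for the alignment step. It assumes the rows of `X` and `Y` are paired, which is exactly the condition orthogonal Procrustes needs.

```python
import numpy as np

def procrustes_rotation(X, Y):
    """Orthogonal R minimizing ||X @ R.T - Y||_F for row-paired X and Y."""
    U, _, Vt = np.linalg.svd(Y.T @ X)
    return U @ Vt

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 6))
# Build a known orthogonal map Q via QR, then transform X to obtain Y
Q, _ = np.linalg.qr(rng.normal(size=(6, 6)))
Y = X @ Q.T

R = procrustes_rotation(X, Y)
residual = np.linalg.norm(X @ R.T - Y)
print(f"residual after alignment: {residual:.2e}")
```

On real embeddings the residual will not vanish, since two independently trained encoders are unlikely to differ by a pure rotation; the size of the residual is itself a rough indicator of how well the spaces can be matched.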


In [None]:
def subgroup_centroid_distance_pl(embeddings, demographics):
    # demographics: one label per embedding row, in the same order as the embeddings
    demographics = list(demographics)
    groups = list(dict.fromkeys(demographics))  # unique labels, first-seen order
    centroids = {}

    for g in groups:
        # Select the embedding rows that belong to this group
        idxs = [i for i, val in enumerate(demographics) if val == g]
        group_emb = embeddings[idxs]
        centroids[g] = group_emb.mean(dim=0)

    dist = {}
    for g1 in groups:
        for g2 in groups:
            dist[(g1, g2)] = torch.norm(centroids[g1] - centroids[g2]).item()
    return dist

# train.csv already includes a 'Sex' column (see the head() output above), so it is read from there
demographics = df_train_chexpert["Sex"].to_list()  # Polars column → Python list
distances = subgroup_centroid_distance_pl(aligned_chexpert, demographics)
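
The description also lists intra- vs. inter-group embedding disparity as a metric but does not define it. One possible reading, shown here as an assumption rather than the project's definition, compares the mean pairwise distance between embeddings of the same group with the mean pairwise distance across groups. The function name `intra_inter_disparity` and the toy data are illustrative.

```python
import numpy as np

def intra_inter_disparity(embeddings, groups):
    """Mean pairwise distance within groups vs. between groups (assumed definition)."""
    embeddings = np.asarray(embeddings, dtype=float)
    groups = np.asarray(groups)
    # Full pairwise Euclidean distance matrix (O(n^2) memory, so subsample large sets)
    diffs = embeddings[:, None, :] - embeddings[None, :, :]
    dist = np.sqrt((diffs ** 2).sum(axis=-1))
    same = groups[:, None] == groups[None, :]
    off_diag = ~np.eye(len(groups), dtype=bool)
    intra = dist[same & off_diag].mean()  # same group, excluding self-pairs
    inter = dist[~same].mean()            # pairs drawn from different groups
    return intra, inter

# Toy data: two tight clusters that coincide with the group labels
emb = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 0.0], [10.0, 1.0]])
sex = ["Female", "Female", "Male", "Male"]
intra, inter = intra_inter_disparity(emb, sex)
print(intra, inter)
```

When the inter-group distance is much larger than the intra-group distance, as in this toy case, the representation separates the demographic groups, which is the situation the fairness analysis is meant to detect.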


In [None]:
def subgroup_centroid_distance(embeddings, demographics):
    # demographics: pandas Series aligned row-for-row with the embeddings
    groups = demographics.unique()
    # A torch tensor cannot be indexed by a pandas mask directly, so convert it first
    centroids = {
        g: embeddings[torch.from_numpy((demographics == g).to_numpy())].mean(dim=0)
        for g in groups
    }
    dist = {}
    for g1 in groups:
        for g2 in groups:
            dist[(g1, g2)] = torch.norm(centroids[g1] - centroids[g2]).item()
    return dist

import pandas as pd
df_demo = pd.read_csv(train_csv_chexpert)  # train.csv contains the 'Sex' column
distances = subgroup_centroid_distance(aligned_chexpert, df_demo["Sex"])
print(distances)
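
Cluster-level demographic parity, the third metric named in the description, could be sketched as follows. This assumes cluster assignments are already available (for example from k-means on the aligned embeddings, which is not shown) and measures the largest gap between any cluster's group share and that group's overall share. Both the definition and the name `cluster_parity_gap` are assumptions made for illustration.

```python
from collections import Counter

def cluster_parity_gap(cluster_ids, groups):
    """Largest deviation of any cluster's group share from the overall share."""
    total = Counter(groups)
    n = len(groups)
    per_cluster = {}
    for c, g in zip(cluster_ids, groups):
        per_cluster.setdefault(c, Counter())[g] += 1
    gap = 0.0
    for counts in per_cluster.values():
        size = sum(counts.values())
        for g, overall in total.items():
            # counts[g] is 0 when the group is absent from this cluster
            gap = max(gap, abs(counts[g] / size - overall / n))
    return gap

clusters = [0, 0, 0, 0, 1, 1, 1, 1]
balanced = ["F", "M", "F", "M", "F", "M", "F", "M"]
skewed = ["F", "F", "F", "M", "M", "M", "M", "F"]
gap_balanced = cluster_parity_gap(clusters, balanced)
gap_skewed = cluster_parity_gap(clusters, skewed)
print(gap_balanced, gap_skewed)
```

A gap of zero means every cluster mirrors the overall demographic mix; larger values indicate clusters dominated by one group.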