# AIS Cohort Builder - UK Biobank RAP

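This notebook builds the adolescent idiopathic scoliosis (AIS) case-control cohort from UK Biobank data: AIS cases (ICD-10 M41.1/M41.2) are matched to controls on age, sex and genetic principal components.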

**Inputs:**
- `ukb_phenotypes.csv` - Exported phenotype data
- `ukb_rel.dat` - KING kinship coefficients

**Outputs:**
- `cohort.parquet` - Final matched cohort
- `matching_info.parquet` - Case-control pairs
- `ancestry_pcs.parquet` - Genetic PCs for cohort
- `cohort_qc_report.yaml` - QC statistics

## 1. Setup and Configuration

In [None]:
# Configuration - UPDATE THESE PATHS
# DNAnexus project holding the inputs; remote files are "<project>:/<path>".
PROJECT_ID = "project-J2qppj8JFxzP5QV25fJ8yjQG"
PHENOTYPE_FILE = PROJECT_ID + ":/ukb_phenotypes.csv"
KINSHIP_FILE = PROJECT_ID + ":/Bulk/Genotype Results/Genotype calls/ukb_rel.dat"
OUTPUT_DIR = "/opt/notebooks/cohort_output"  # local directory for generated files

# Cohort parameters
CONTROLS_PER_CASE = 4       # controls matched to each case, without replacement
PCA_OUTLIER_SD = 6.0        # drop samples this many SDs from the mean on any of the first 4 PCs
KINSHIP_THRESHOLD = 0.0884  # KING kinship cut-off; pairs above it (2nd-degree or closer) are pruned

In [None]:
# Install dependencies if needed.
# Run pip through the kernel's own interpreter (sys.executable -m pip) so the
# packages land in the environment this notebook imports from; a bare "pip"
# on PATH may belong to a different Python installation.
import subprocess
import sys

pip_result = subprocess.run(
    [sys.executable, "-m", "pip", "install", "-q", "pyarrow", "pyyaml"]
)
# Best-effort install: report a failure instead of silently ignoring it, but
# keep going in case the packages are already available.
if pip_result.returncode != 0:
    print(
        f"WARNING: pip install exited with code {pip_result.returncode}; "
        "continuing in case pyarrow/pyyaml are already installed."
    )

In [None]:
import os
import re
import logging
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist
from sklearn.preprocessing import StandardScaler
import yaml

# Configure root logging with a timestamped, level-padded format
# (logging.basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(
    format="%(asctime)s | %(levelname)-8s | %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Local directory that receives every output file; created if it is missing.
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
print(f"Output directory: {OUTPUT_DIR}")

## 2. Download Data Files

In [None]:
import dxpy


def _download_project_file(dx_path, local_path):
    """
    Download a DNAnexus file addressed as "project-xxxx:/folder/name".

    dxpy.download_dxfile() needs a file ID ("file-xxxx") or a DXFile handler,
    not a "project:/path" string, so the path is first resolved to exactly one
    file with dxpy.find_one_data_object(). That call raises if the path matches
    no file or several files, rather than silently picking one.

    Parameters
    ----------
    dx_path : str
        "<project-id>:/<folder>/<file name>" (root-level files allowed).
    local_path : str
        Destination path on the notebook instance.
    """
    project, _, file_path = dx_path.partition(":")
    folder, _, name = file_path.rpartition("/")
    found = dxpy.find_one_data_object(
        classname="file",
        project=project,
        folder=folder or "/",
        name=name,
        recurse=False,
        zero_ok=False,
        more_ok=False,
    )
    dxpy.download_dxfile(found["id"], local_path, project=found["project"])


# Download phenotype file
print("Downloading phenotype file...")
local_phenotype = "/opt/notebooks/ukb_phenotypes.csv"
_download_project_file(PHENOTYPE_FILE, local_phenotype)
print(f"Downloaded to {local_phenotype}")

# Download kinship file
print("Downloading kinship file...")
local_kinship = "/opt/notebooks/ukb_rel.dat"
_download_project_file(KINSHIP_FILE, local_kinship)
print(f"Downloaded to {local_kinship}")

## 3. Load and Explore Data

In [None]:
# Load phenotype data.
# low_memory=False makes pandas infer each column's dtype from the whole file
# rather than chunk by chunk, so a column cannot come back as a mixture of
# ints and strings (and no DtypeWarning is raised on a large export).
print("Loading phenotype data...")
pheno_df = pd.read_csv(local_phenotype, low_memory=False)
print(f"Loaded {len(pheno_df)} samples")
print(f"Columns: {list(pheno_df.columns)}")
pheno_df.head()

In [None]:
# Load kinship data (whitespace-delimited table).
# The separator is a raw-string regex: "\s" inside a normal string literal is
# an invalid escape sequence (DeprecationWarning; SyntaxWarning from 3.12).
print("Loading kinship data...")
kinship_df = pd.read_csv(local_kinship, sep=r"\s+")
print(f"Loaded {len(kinship_df)} related pairs")
print(f"Columns: {list(kinship_df.columns)}")
kinship_df.head()

## 4. Standardize Column Names

In [None]:
def standardize_phenotype_columns(df):
    """
    Standardize UK Biobank column names to our expected format.

    Recognised fields, in header styles such as ``p31``, ``participant.p31``,
    ``p21000_i0``, ``p22009_a1`` or ``22009-0.1``:

    - eid                     -> ``eid``
    - 31 (sex)                -> ``sex``
    - 21003 (age)             -> ``age``
    - 21000 (ethnic bkgd)     -> ``ethnicity``
    - 22006 (genetic ethn.)   -> ``genetic_ethnicity``
    - 22009 (genetic PCs)     -> ``pc1``, ``pc2``, ...
    - 41270 (ICD-10 diag.)    -> ``diagnoses``

    Exports often contain several columns for one field (one per instance or
    array index). Renaming them all to the same label would produce duplicate
    column names, after which ``df['ethnicity']`` returns a DataFrame instead
    of a Series and downstream steps break. Therefore each target name is
    given to the first matching column only (later ones keep their original
    names), and multiple ICD-10 columns are merged into one ``|``-separated
    string column.

    Parameters
    ----------
    df : pandas.DataFrame
        Raw phenotype export.

    Returns
    -------
    pandas.DataFrame
        A new DataFrame with standardized, unique column names.
    """
    # Print original columns for debugging
    print(f"Original columns: {list(df.columns)}")

    # Field 31 as a whole token ("p31", "participant.p31", "p31_i0", "31",
    # "31-0.0"); a plain substring test would also match e.g. "p3140".
    sex_field = re.compile(r'(?:^|\.)p31(?:_|$)|^31(?:-|$)')

    rename_map = {}
    diagnosis_cols = []

    def claim(col, target):
        # Give each target name to at most one source column.
        if target not in rename_map.values():
            rename_map[col] = target

    for col in df.columns:
        col_lower = col.lower()

        if col_lower in ('eid', 'participant.eid', 'participant_eid'):
            claim(col, 'eid')
        elif sex_field.search(col_lower) or 'sex' in col_lower or 'gender' in col_lower:
            claim(col, 'sex')
        elif 'p21003' in col_lower or '21003' in col:
            claim(col, 'age')
        elif 'p21000' in col_lower or '21000' in col:
            claim(col, 'ethnicity')
        elif 'p22006' in col_lower or '22006' in col:
            claim(col, 'genetic_ethnicity')
        elif 'p22009' in col_lower or '22009' in col:
            # PC number is the trailing integer: p22009_a1, 22009-0.1, ...
            match = re.search(r'[_a.-](\d+)$', col)
            if match:
                claim(col, f'pc{int(match.group(1))}')
        elif 'p41270' in col_lower or '41270' in col:
            diagnosis_cols.append(col)

    if len(diagnosis_cols) == 1:
        rename_map[diagnosis_cols[0]] = 'diagnoses'

    # Apply renaming
    df_renamed = df.rename(columns=rename_map)

    if len(diagnosis_cols) > 1:
        # One column per array index (e.g. 41270-0.0, 41270-0.1, ...): join the
        # non-missing codes per participant; all-missing rows stay missing.
        codes = df[diagnosis_cols].astype('string')
        merged = codes.iloc[:, 0].str.cat(codes.iloc[:, 1:], sep='|', na_rep='')
        merged = merged.mask(codes.isna().all(axis=1))
        df_renamed = df_renamed.drop(columns=diagnosis_cols)
        df_renamed['diagnoses'] = merged
        print(f"Merged {len(diagnosis_cols)} ICD-10 columns into 'diagnoses'")

    print(f"Column rename mapping: {rename_map}")
    print(f"Final columns after rename: {list(df_renamed.columns)}")

    # Warn if sex column wasn't found
    if 'sex' not in df_renamed.columns:
        print("WARNING: 'sex' column not found! Check if p31 was exported.")
        print("Looking for any column that might be sex...")
        for col in df_renamed.columns:
            print(f"  - {col}: first values = {df_renamed[col].head(3).tolist()}")

    return df_renamed

# Apply the header standardization to the raw export in place of the original
standardized_df = standardize_phenotype_columns(pheno_df)
pheno_df = standardized_df
print(f"\nStandardized columns: {list(pheno_df.columns)}")

## 5. Case Identification

In [None]:
# ICD-10 codes treated as AIS, in dotted and undotted spellings
AIS_CODES = ["M41.1", "M41.2"] + ["M411", "M412"]

# Every M41 (scoliosis) code: the bare "M41" followed by the dotted and the
# undotted subcodes M41.0-M41.5, M41.8 and M41.9.
_M41_SUBCODES = "01234589"
ALL_SCOLIOSIS_CODES = (
    ["M41"]
    + [f"M41.{digit}" for digit in _M41_SUBCODES]
    + [f"M41{digit}" for digit in _M41_SUBCODES]
)

def check_diagnosis(diag_str, codes):
    """
    Return True when any code in ``codes`` occurs as a substring of ``diag_str``.

    The value is converted with ``str()`` before searching, so a delimited or
    list-like diagnosis field is searched through its text form. A value that
    ``pd.isna`` reports as missing never matches.
    """
    if pd.isna(diag_str):
        return False
    text = str(diag_str)
    return any(code in text for code in codes)

# Flag participants whose diagnoses mention an AIS code, and separately any
# scoliosis (M41) code; the second flag feeds the exclusion label below.
print("Identifying AIS cases...")
diagnosis_series = pheno_df['diagnoses']
pheno_df['is_ais_case'] = diagnosis_series.apply(check_diagnosis, args=(AIS_CODES,))
pheno_df['has_any_scoliosis'] = diagnosis_series.apply(
    check_diagnosis, args=(ALL_SCOLIOSIS_CODES,)
)

# Cohort labels: 1 = case, 0 = potential control, -1 = excluded
def assign_label(row):
    """
    Map one participant row to its cohort label.

    Reads ``row['is_ais_case']`` and ``row['has_any_scoliosis']``. AIS takes
    precedence (1). Any other scoliosis is excluded from the control pool (-1).
    Everyone else is a potential control (0).
    """
    if row['is_ais_case']:
        return 1
    return -1 if row['has_any_scoliosis'] else 0

pheno_df['label'] = pheno_df.apply(assign_label, axis=1)

# Tally each label value; a label that never occurs counts as 0.
label_counts = pheno_df['label'].value_counts()
n_cases = label_counts.get(1, 0)
n_controls = label_counts.get(0, 0)
n_excluded = label_counts.get(-1, 0)

print(f"\nCase identification results:")
print(f"  AIS cases (M41.1, M41.2): {n_cases}")
print(f"  Potential controls: {n_controls}")
print(f"  Excluded (other scoliosis): {n_excluded}")

## 6. Ancestry Quality Control

In [None]:
# European ancestry codes (field 21000): "White" (1) and the sub-groups
# "British" (1001), "Irish" (1002) and "Any other white background" (1003),
# listed as ints, numeric strings and text labels to cope with however the
# export encoded the column.
EUROPEAN_CODES = [1, 1001, 1002, 1003, "1", "1001", "1002", "1003",
                  "British", "Irish", "White", "Any other white background"]

# Filter to non-excluded samples
df = pheno_df[pheno_df['label'] >= 0].copy()
print(f"Starting with {len(df)} samples (excluding other scoliosis)")

# Step 1: Filter to European ancestry
if 'ethnicity' in df.columns:
    df['is_european'] = df['ethnicity'].isin(EUROPEAN_CODES)
    n_before = len(df)
    df = df[df['is_european']].copy()
    print(f"European ancestry filter: {n_before} -> {len(df)} ({n_before - len(df)} removed)")
else:
    print("Warning: No ethnicity column found, skipping ancestry filter")

# Step 2: PCA outlier removal on the first four PCs.
# PCs are ordered by number, not column position, so "first four" means
# pc1-pc4 even if the export lists e.g. pc1, pc10, pc11, ... before pc2.
pc_cols = sorted(
    (col for col in df.columns if col.startswith('pc') and col[2:].isdigit()),
    key=lambda col: int(col[2:]),
)
print(f"Found PC columns: {pc_cols}")

if pc_cols:
    n_before = len(df)
    outlier_mask = pd.Series(False, index=df.index)

    for pc in pc_cols[:4]:  # Use first 4 PCs
        pc_values = pd.to_numeric(df[pc], errors='coerce')
        mean = pc_values.mean()
        std = pc_values.std()

        # Missing PC values compare False here and are therefore kept.
        pc_outliers = (pc_values - mean).abs() > PCA_OUTLIER_SD * std
        n_outliers = pc_outliers.sum()
        if n_outliers > 0:
            print(f"  {pc}: {n_outliers} outliers (>{PCA_OUTLIER_SD} SD)")
        outlier_mask = outlier_mask | pc_outliers

    df = df[~outlier_mask].copy()
    print(f"PCA outlier removal: {n_before} -> {len(df)} ({n_before - len(df)} removed)")
else:
    print("Warning: No PC columns found, skipping PCA outlier removal")

print(f"\nAfter ancestry QC: {len(df)} samples")
print(f"  Cases: {(df['label'] == 1).sum()}")
print(f"  Potential controls: {(df['label'] == 0).sum()}")

## 7. Relatedness Exclusion

In [None]:
print(f"Excluding related samples (kinship > {KINSHIP_THRESHOLD})...")

# Verify we have the eid column
if 'eid' not in df.columns:
    print(f"ERROR: 'eid' column not found in df!")
    print(f"Available columns: {list(df.columns)}")
    raise ValueError("Missing 'eid' column - check column standardization")

print(f"Using 'eid' column for sample identification")
print(f"Sample IDs look like: {df['eid'].head(3).tolist()}")

# Kinship coefficient column: first header containing "kinship" in any case
# (this already covers "Kinship").
kinship_col = next(
    (col for col in kinship_df.columns if 'kinship' in col.lower()),
    None,
)

# Participant ID columns: headers containing upper-case "ID" (e.g. ID1, ID2),
# falling back to a case-insensitive match.
id_cols = [col for col in kinship_df.columns if 'ID' in col]
if not id_cols:
    id_cols = [col for col in kinship_df.columns if 'id' in col.lower()]

print(f"Kinship ID columns: {id_cols}")
print(f"Kinship value column: {kinship_col}")

if kinship_col and len(id_cols) >= 2:
    id1_col, id2_col = id_cols[0], id_cols[1]

    print(f"Kinship IDs look like: {kinship_df[id1_col].head(3).tolist()}")

    # Filter to related pairs above threshold
    related = kinship_df[kinship_df[kinship_col] > KINSHIP_THRESHOLD]
    print(f"Found {len(related)} related pairs above threshold")

    # Get samples in our cohort - use 'eid' column explicitly
    cohort_ids = set(df['eid'].astype(str))
    label_map = dict(zip(df['eid'].astype(str), df['label']))

    print(f"Cohort has {len(cohort_ids)} unique IDs")

    # Convert IDs to strings column-wise. Reading them through
    # related.iterrows() is wrong: each row is returned as one Series with a
    # single common dtype, so integer IDs are upcast to float alongside the
    # float kinship values, and str() then gives "1234567.0", which never
    # matches a cohort eid. Every pair would be skipped and nobody removed.
    related_id1 = related[id1_col].astype(str)
    related_id2 = related[id2_col].astype(str)

    # Check for overlap between kinship IDs and cohort IDs
    kinship_ids = set(related_id1) | set(related_id2)
    overlap = cohort_ids & kinship_ids
    print(f"Kinship file has {len(kinship_ids)} unique IDs in related pairs")
    print(f"Overlap with cohort: {len(overlap)} IDs")

    if len(overlap) == 0:
        print("WARNING: No overlap between cohort IDs and kinship IDs!")
        print(f"  Cohort ID examples: {list(cohort_ids)[:3]}")
        print(f"  Kinship ID examples: {list(kinship_ids)[:3]}")

    # Only pairs where both members are still in the cohort matter.
    both_in_cohort = related_id1.isin(cohort_ids) & related_id2.isin(cohort_ids)
    pairs_in_cohort = int(both_in_cohort.sum())

    # Greedily drop one member per pair, in file order, preferring to drop the
    # control when a case is related to a control.
    to_remove = set()
    for id1, id2 in zip(related_id1[both_in_cohort], related_id2[both_in_cohort]):
        # Pair already resolved by an earlier removal
        if id1 in to_remove or id2 in to_remove:
            continue

        label1 = label_map.get(id1, 0)
        label2 = label_map.get(id2, 0)

        if label1 == 1 and label2 == 0:
            to_remove.add(id2)
        elif label1 == 0 and label2 == 1:
            to_remove.add(id1)
        else:
            to_remove.add(id1)  # Same label: remove one arbitrarily

    print(f"Related pairs where both are in cohort: {pairs_in_cohort}")

    n_before = len(df)
    df = df[~df['eid'].astype(str).isin(to_remove)].copy()
    print(f"Removed {len(to_remove)} related samples: {n_before} -> {len(df)}")
else:
    print("Warning: Could not parse kinship file, skipping relatedness exclusion")

print(f"\nAfter relatedness exclusion: {len(df)} samples")
print(f"  Cases: {(df['label'] == 1).sum()}")
print(f"  Potential controls: {(df['label'] == 0).sum()}")

## 8. Control Matching

In [None]:
print(f"\nMatching controls to cases ({CONTROLS_PER_CASE}:1 ratio)...")

# Check what columns are available
print(f"\nAvailable columns in df: {list(df.columns)}")

# Matching variables: age and sex, plus up to the first four genetic PCs.
# PCs are ordered by number (pc1, pc2, ..., pc10) rather than by column
# position, so "first four" really means pc1-pc4.
matching_vars = ['age', 'sex']
pc_cols_available = sorted(
    (col for col in df.columns if col.startswith('pc') and col[2:].isdigit()),
    key=lambda col: int(col[2:]),
)
matching_vars.extend(pc_cols_available[:4])  # Add up to 4 PCs

print(f"\nDesired matching variables: {matching_vars}")

# Check which are actually in the dataframe
present_vars = [v for v in matching_vars if v in df.columns]
missing_vars = [v for v in matching_vars if v not in df.columns]

if missing_vars:
    print(f"WARNING: Missing matching variables: {missing_vars}")
    print("These will NOT be used for matching!")

# Convert to numeric (unparseable values become NaN) and report coverage
for var in present_vars:
    df[var] = pd.to_numeric(df[var], errors='coerce')
    n_valid = df[var].notna().sum()
    pct_valid = 100 * n_valid / len(df) if len(df) else 0.0
    print(f"  {var}: {n_valid} valid values ({pct_valid:.1f}%)")

# Split cases and controls
cases_df = df[df['label'] == 1].copy()
controls_df = df[df['label'] == 0].copy()

print(f"\nCases: {len(cases_df)}, Controls: {len(controls_df)}")

# Get available matching variables (present and with valid data)
available_vars = [v for v in present_vars if df[v].notna().sum() > 0]
print(f"\nFinal matching variables to use: {available_vars}")

if 'sex' not in available_vars:
    print("\n" + "="*60)
    print("WARNING: SEX IS NOT BEING USED FOR MATCHING!")
    print("This may result in sex imbalance between cases and controls.")
    print("Check if p31 was exported from Table Exporter.")
    print("="*60 + "\n")

if not available_vars:
    raise ValueError("No matching variables available!")

In [None]:
# Build the matching feature matrices: one row per participant, one column
# per entry of available_vars (same column order for both groups).
case_features = cases_df[available_vars].values
control_features = controls_df[available_vars].values

# Mean-impute missing values column by column, using the mean over cases and
# controls combined; NaNs are overwritten in place through column views.
for col_idx in range(case_features.shape[1]):
    pooled_column = np.concatenate((case_features[:, col_idx], control_features[:, col_idx]))
    fill_value = np.nanmean(pooled_column)
    for matrix in (case_features, control_features):
        column_view = matrix[:, col_idx]
        column_view[np.isnan(column_view)] = fill_value

# Fit a single scaler on the pooled features so cases and controls share the
# same centring and scaling.
all_features = np.vstack([case_features, control_features])
scaler = StandardScaler().fit(all_features)

case_features_scaled = scaler.transform(case_features)
control_features_scaled = scaler.transform(control_features)

print(f"Feature matrix shapes: cases={case_features_scaled.shape}, controls={control_features_scaled.shape}")

In [None]:
# Nearest-neighbour matching without replacement: cases are processed in
# order and each takes its CONTROLS_PER_CASE closest still-unused controls
# (Euclidean distance on the standardized matching features).
#
# Distances are computed for one case at a time against the remaining control
# pool instead of as a full n_cases x n_controls float64 matrix, which needs
# 8 bytes per pair (several GB for thousands of cases against hundreds of
# thousands of controls).
print("Computing distances case by case against the remaining controls...")

matched_control_indices = []
matching_records = []

control_is_available = np.ones(len(controls_df), dtype=bool)
case_ids = cases_df['eid'].values
control_ids = controls_df['eid'].values

print(f"Matching {len(case_ids)} cases to up to {CONTROLS_PER_CASE} controls each...")

for i, case_id in enumerate(case_ids):
    # Positions (ascending) of controls not yet taken by an earlier case
    available_idx = np.flatnonzero(control_is_available)
    if available_idx.size == 0:
        print(f"Ran out of controls at case {i}")
        break

    case_distances = cdist(
        case_features_scaled[i:i + 1],
        control_features_scaled[available_idx],
        metric='euclidean',
    )[0]

    # Stable sort: equal distances go to the lower control index, so the
    # matching is reproducible.
    n_to_match = min(CONTROLS_PER_CASE, available_idx.size)
    nearest_local = np.argsort(case_distances, kind='stable')[:n_to_match]

    for rank, local_idx in enumerate(nearest_local):
        global_idx = int(available_idx[local_idx])

        matched_control_indices.append(global_idx)
        matching_records.append({
            'case_eid': case_id,
            'control_eid': control_ids[global_idx],
            'match_rank': rank + 1,
            'distance': float(case_distances[local_idx]),
        })

        control_is_available[global_idx] = False

# Create results. Explicit columns keep the frame well-formed (and the column
# lookups below valid) even when no matches were made.
matching_info = pd.DataFrame(
    matching_records,
    columns=['case_eid', 'control_eid', 'match_rank', 'distance'],
)
matched_control_ids = set(matching_info['control_eid'])

matched_cases = cases_df.copy()
matched_controls = controls_df[controls_df['eid'].isin(matched_control_ids)].copy()
matched_cohort = pd.concat([matched_cases, matched_controls], ignore_index=True)

n_matched_cases = len(matched_cases)
n_matched_controls = len(matched_controls)
ratio = n_matched_controls / n_matched_cases if n_matched_cases > 0 else 0

print(f"\nMatching complete:")
print(f"  Matched cases: {n_matched_cases}")
print(f"  Matched controls: {n_matched_controls}")
print(f"  Ratio: {ratio:.2f}:1")
print(f"  Mean distance: {matching_info['distance'].mean():.4f}")
print(f"  Max distance: {matching_info['distance'].max():.4f}")

## 9. Assess Covariate Balance

In [None]:
print("\nCovariate balance assessment:")
print("-" * 50)

cases = matched_cohort[matched_cohort['label'] == 1]
controls = matched_cohort[matched_cohort['label'] == 0]

balance_results = {}

for var in available_vars:
    case_vals = pd.to_numeric(cases[var], errors='coerce')
    control_vals = pd.to_numeric(controls[var], errors='coerce')

    case_mean = case_vals.mean()
    control_mean = control_vals.mean()
    case_std = case_vals.std()
    control_std = control_vals.std()

    # Standardized mean difference, using the pooled SD of the two groups
    pooled_std = np.sqrt((case_std**2 + control_std**2) / 2)

    if pd.isna(pooled_std) or pd.isna(case_mean) or pd.isna(control_mean):
        # Too few non-missing values to estimate a mean/SD (e.g. a single
        # case). Reporting SMD = 0 here would wrongly mark the variable OK.
        smd = float('nan')
        balanced = False
        status = "NOT ASSESSABLE"
    else:
        if pooled_std > 0:
            smd = (case_mean - control_mean) / pooled_std
        elif case_mean == control_mean:
            smd = 0.0
        else:
            # Both groups constant but different (e.g. all cases female, all
            # controls male): maximal imbalance, not zero.
            smd = float('inf') if case_mean > control_mean else float('-inf')
        balanced = abs(smd) < 0.1
        status = "OK" if balanced else "IMBALANCED"

    balance_results[var] = {
        'case_mean': case_mean,
        'control_mean': control_mean,
        'smd': smd,
        'balanced': balanced
    }

    print(f"  {var}: SMD = {smd:.4f} [{status}]")

all_balanced = all(b['balanced'] for b in balance_results.values())
print(f"\nOverall balance: {'GOOD' if all_balanced else 'NEEDS REVIEW'}")

## 10. Save Outputs

In [None]:
# Prepare final cohort dataframe: IDs, labels and the matching covariates
cohort_cols = ['eid', 'label', 'age', 'sex']
cohort_cols.extend([c for c in pc_cols_available[:4] if c in matched_cohort.columns])
available_cohort_cols = [c for c in cohort_cols if c in matched_cohort.columns]

final_cohort = matched_cohort[available_cohort_cols].copy()
final_cohort['is_case'] = final_cohort['label'] == 1

# Save cohort
cohort_path = f"{OUTPUT_DIR}/cohort.parquet"
final_cohort.to_parquet(cohort_path, index=False)
print(f"Saved cohort to {cohort_path}")

# Save matching info
matching_path = f"{OUTPUT_DIR}/matching_info.parquet"
matching_info.to_parquet(matching_path, index=False)
print(f"Saved matching info to {matching_path}")

# Save all ancestry PCs for matched samples. Use the same "pc<digits>" rule as
# the QC and matching steps, so other columns that merely start with "pc" are
# not swept in, and keep the PCs in numeric order.
all_pc_cols = sorted(
    (c for c in matched_cohort.columns if c.startswith('pc') and c[2:].isdigit()),
    key=lambda c: int(c[2:]),
)
pc_cols_to_save = ['eid'] + all_pc_cols
ancestry_pcs = matched_cohort[pc_cols_to_save].copy()
pcs_path = f"{OUTPUT_DIR}/ancestry_pcs.parquet"
ancestry_pcs.to_parquet(pcs_path, index=False)
print(f"Saved ancestry PCs to {pcs_path}")

In [None]:
# Assemble the QC report from plain Python scalars so yaml.dump can serialize it.
cohort_summary = {
    'total_samples': int(len(final_cohort)),
    'n_cases': int(n_matched_cases),
    'n_controls': int(n_matched_controls),
    'control_case_ratio': float(round(ratio, 2)),
}

match_distances = matching_info['distance']
matching_summary = {
    'mean_distance': float(round(match_distances.mean(), 4)),
    'max_distance': float(round(match_distances.max(), 4)),
    'min_distance': float(round(match_distances.min(), 4)),
}

covariate_balance = {}
for metric_name, metric in balance_results.items():
    covariate_balance[metric_name] = {
        'standardized_mean_diff': float(round(metric['smd'], 4)),
        'balanced': bool(metric['balanced']),
    }

demographics = {}
if 'age' in final_cohort.columns:
    demographics['case_mean_age'] = float(round(cases['age'].mean(), 1))
    demographics['control_mean_age'] = float(round(controls['age'].mean(), 1))
if 'sex' in final_cohort.columns:
    # Female share; assumes sex coded 0 = Female (UK Biobank coding) -- confirm for other exports
    demographics['case_female_pct'] = float(round(100 * (cases['sex'] == 0).mean(), 1))
    demographics['control_female_pct'] = float(round(100 * (controls['sex'] == 0).mean(), 1))

report = {
    'cohort_summary': cohort_summary,
    'matching_summary': matching_summary,
    'covariate_balance': covariate_balance,
    'demographics': demographics,
}

report_path = f"{OUTPUT_DIR}/cohort_qc_report.yaml"
with open(report_path, 'w') as report_file:
    yaml.dump(report, report_file, default_flow_style=False)
print(f"Saved QC report to {report_path}")

## 11. Upload Results to Project

In [None]:
# Upload every generated file to /cohort_output in the DNAnexus project
# (the folder is created if needed via parents=True).
import dxpy

output_names = [
    "cohort.parquet",
    "matching_info.parquet",
    "ancestry_pcs.parquet",
    "cohort_qc_report.yaml",
]

print("Uploading results to project...")
for name in output_names:
    local_path = f"{OUTPUT_DIR}/{name}"
    dxpy.upload_local_file(local_path, project=PROJECT_ID, folder="/cohort_output", parents=True)
    print(f"  Uploaded {name}")

print(f"\nAll files uploaded to {PROJECT_ID}:/cohort_output/")

## 12. Summary

In [None]:
# Print a closing summary of cohort sizes and the uploaded artefacts.
banner = "=" * 60
print(banner)
print("COHORT BUILDING COMPLETE")
print(banner)

print(f"\nFinal Cohort:")
for heading, value in [
    ("Total samples", len(final_cohort)),
    ("Cases", n_matched_cases),
    ("Controls", n_matched_controls),
    ("Ratio", f"{ratio:.2f}:1"),
]:
    print(f"  {heading}: {value}")

print(f"\nOutput files in {PROJECT_ID}:/cohort_output/")
for name in ("cohort.parquet", "matching_info.parquet",
             "ancestry_pcs.parquet", "cohort_qc_report.yaml"):
    print(f"  - {name}")
print(banner)