## Download the CheXchoNet dataset

| Dataset  | Images | Size |
| -------- | ------ | ---- |
| [CheXchoNet](https://physionet.org/content/chexchonet/1.0.0/)  | 71,589 | 2.7 GB |

In [1]:
# Enter the PhysioNet username that the download command passes to wget via --user.
# Leading/trailing whitespace (e.g. from a copy/paste) is stripped so it cannot
# end up inside the shell command and break authentication.
PHYSIONET_USERNAME = input('Physionet Username:').strip()

In [2]:
# Download the files
# wget flags used here:
#   -r   recurse into the remote directory
#   -N   only fetch files that are newer than the local copy
#   -c   resume partially downloaded files
#   -np  never ascend to the parent directory
#   --ask-password  prompt for the password instead of putting it on the command line
# {PHYSIONET_USERNAME} is interpolated by IPython from the Python variable of that name.
# The files end up under ./physionet.org/files/chexchonet/1.0.0/, which is the
# path the following cell reads from.
!wget -r -N -c -np --user {PHYSIONET_USERNAME} --ask-password https://physionet.org/files/chexchonet/1.0.0/

In [3]:
# Relocate the downloaded images and copy the metadata CSV out of the wget
# mirror directory into the folders the rest of the notebook works with.
import os
import pandas as pd

DOWNLOAD_ROOT = 'physionet.org/files/chexchonet/1.0.0/'
OUTPUT_IMAGE_PATH = './cxrs/'
OUTPUT_METADATA_PATH = './diffusion_out/'

# Now move the images to the correct path.
# The move only happens while the downloaded folder is still in place, so
# re-running this cell after a successful move is a no-op instead of a
# FileNotFoundError from os.rename.
downloaded_images = os.path.join(DOWNLOAD_ROOT, 'images')
image_dir = os.path.normpath(OUTPUT_IMAGE_PATH)
if os.path.isdir(downloaded_images):
    os.rename(downloaded_images, image_dir)
elif not os.path.isdir(image_dir):
    raise FileNotFoundError(
        f'Neither {downloaded_images!r} nor {image_dir!r} exists; run the download cell first.'
    )

# Now move the csv to a local folder
os.makedirs(OUTPUT_METADATA_PATH, exist_ok=True)
metadata_df = pd.read_csv(os.path.join(DOWNLOAD_ROOT, 'metadata.csv'))
metadata_df.to_csv(os.path.join(OUTPUT_METADATA_PATH, 'metadata.csv'), index=False)

## Separate into Training, Validation, and Testing Splits

In [4]:
# Modules needed by the splitting cells
import os
import random

import pandas as pd

# Folder holding the metadata CSVs, and an optional seed for a repeatable split
OUTPUT_METADATA_PATH = './diffusion_out/'
RANDOM_SEED = None

# Make sure the output directory is there before anything is written to it
if not os.path.exists(OUTPUT_METADATA_PATH):
    os.makedirs(OUTPUT_METADATA_PATH)

# The generator is only seeded when a seed was supplied above
if RANDOM_SEED is not None:
    random.seed(RANDOM_SEED)

In [5]:
# Read the metadata CSV and mirror the `cxr_path` column under the name `file_path`
chexchonet_df = (
    pd.read_csv(os.path.join(OUTPUT_METADATA_PATH, 'metadata.csv'))
    .assign(file_path=lambda frame: frame['cxr_path'])
)

In [6]:
def split_list(data, train_split=0.9, test_split=0.05, valid_split=0.05):
    """Randomly partition ``data`` into train, test and validation sets.

    Args:
        data: Sequence of hashable items (list, tuple, 1-D array, ...). It is
            copied before shuffling, so the caller's object is left untouched.
        train_split: Fraction of the items placed in the training set.
        test_split: Fraction of the items placed in the test set.
        valid_split: Fraction intended for the validation set. The validation
            set receives every item left after the train and test slices, so
            rounding leftovers also end up there.

    Returns:
        A ``(train, test, valid)`` tuple of disjoint sets that together hold
        every item of ``data``.

    Raises:
        ValueError: If any fraction is negative or the fractions add up to
            more than 1.0 (beyond floating-point rounding error).
    """
    fractions = (train_split, test_split, valid_split)
    # A small tolerance keeps valid splits from being rejected when their
    # floating-point sum lands a hair above 1.0.
    if any(fraction < 0 for fraction in fractions) or sum(fractions) > 1.0 + 1e-9:
        raise ValueError('The splits must sum up to 1.0')

    # Shuffle a private copy: random.shuffle works in place and would
    # otherwise reorder the caller's sequence as a side effect.
    items = list(data)
    random.shuffle(items)

    # Calculate the split indices
    train_end = int(train_split * len(items))
    test_end = train_end + int(test_split * len(items))

    # Split the data
    train_data = items[:train_end]
    test_data = items[train_end:test_end]
    valid_data = items[test_end:]

    return set(train_data), set(test_data), set(valid_data)

# Split the unique patient IDs into the three sets
train, test, valid = split_list(chexchonet_df['patient_id'].unique())
def map_set(v):
    """Return the split label for the value ``v``.

    Values contained in the module-level ``train`` set map to 'train', values
    contained in ``test`` map to 'test', and anything else maps to 'valid'.
    """
    if v in train:
        return 'train'
    return 'test' if v in test else 'valid'

# Add the one-hot sex indicators and tag every row with the split of its patient.
# `diffusion_set` and `inference_set` both receive the same patient-level label.
chexchonet_df = chexchonet_df.assign(
    sex_m=lambda frame: frame['sex'].map(lambda code: int(code == 'M')),
    sex_f=lambda frame: frame['sex'].map(lambda code: int(code == 'F')),
    diffusion_set=lambda frame: frame['patient_id'].apply(map_set),
    inference_set=lambda frame: frame['patient_id'].apply(map_set),
)

In [7]:
# Write one metadata CSV per split; rows labelled 'valid' go to the *_eval file
for _split_label, _csv_name in (
    ('train', 'diffusion_metadata_train.csv'),
    ('valid', 'diffusion_metadata_eval.csv'),
    ('test', 'diffusion_metadata_test.csv'),
):
    _split_rows = chexchonet_df[chexchonet_df['diffusion_set'] == _split_label]
    _split_rows.to_csv(os.path.join(OUTPUT_METADATA_PATH, _csv_name), index=False)

## Train the DDPM Model

In [16]:
# Load the files
%load_ext autoreload
%autoreload 2

!pip install diffusers==0.28.0 > /dev/null 2>&1
!pip install accelerate > /dev/null 2>&1
!rm -rf cxr_ddpm
!git clone https://github.com/cstreiffer/cxr_ddpm.git

# File path maniupation
import sys
sys.path.append('cxr_ddpm/src/train/')

In [9]:
# Load the config file
# `os` is imported here so that this cell does not depend on the data-preparation
# cells having been executed earlier in the same kernel session.
import os

from run import load_file

OUTPUT_METADATA_PATH = 'diffusion_out/'
# NOTE(review): OUTPUT_MODEL_PATH is not read anywhere in the visible notebook.
# The model output root set below is OUTPUT_METADATA_PATH, and the generation
# cell loads the model from 'diffusion_out/class_diffusion_large_224x224/', so
# switching to OUTPUT_MODEL_PATH would require updating that path as well.
OUTPUT_MODEL_PATH = 'diffusion_out/model/'
CONFIG_FILE_PATH = 'cxr_ddpm/src/train/training_configs/class_diffusion_large_224.yaml'
args = load_file(CONFIG_FILE_PATH)

# Now specify paths: point the config at the split CSVs written by the data-preparation cells
args.metadata_df_paths['train_metadata_path'] = os.path.join(OUTPUT_METADATA_PATH, 'diffusion_metadata_train.csv')
args.metadata_df_paths['eval_metadata_path'] = os.path.join(OUTPUT_METADATA_PATH, 'diffusion_metadata_eval.csv')
args.metadata_df_paths['test_metadata_path'] = os.path.join(OUTPUT_METADATA_PATH, 'diffusion_metadata_test.csv')
args.model_output_path = OUTPUT_METADATA_PATH

In [10]:
# Call the `run` entry point of the `run` module (importable because
# 'cxr_ddpm/src/train/' was appended to sys.path) with the `args` object
# returned by load_file and updated with the local metadata/output paths.
# NOTE(review): presumably this performs the full training loop — confirm in cxr_ddpm.
from run import run
run(args)

## Generate Synthetic Data

In [11]:
from gen_images import gen
import numpy as np

# Where the trained model is read from and where generated output is written
MODEL_INPUT_PATH = 'diffusion_out/class_diffusion_large_224x224/'
OUTPUT_GEN_PATH = 'diffusion_out/gen_images'

# Total number of batches to run
NUM_BATCHES = 10
# Modify this based on available GPU RAM
BATCH_SIZE = 16

# Define custom sample function
def sample_context(bs):
    """Draw ``bs`` random conditioning rows.

    Each row is a list of six values, in this order:
      - age   (norm)     drawn from N(-0.5, 1.0)
      - sex_m (one-hot)  0 or 1, chosen uniformly
      - sex_f (one-hot)  the complement of sex_m
      - ivsd  (norm)     drawn from N(0.5, 1.0)
      - lvpwd (norm)     drawn from N(0.5, 1.0)
      - lvidd (norm)     drawn from N(0.5, 1.0)

    All sex flags are drawn first, then the continuous values row by row.
    """
    sex_m_flags = [np.random.choice([0, 1]) for _ in range(bs)]

    rows = []
    for sex_m in sex_m_flags:
        age = np.random.normal(loc=-.5, scale=1.0)
        sex_f = 1 if sex_m == 0 else 0
        ivsd = np.random.normal(loc=.5, scale=1.0)
        lvpwd = np.random.normal(loc=.5, scale=1.0)
        lvidd = np.random.normal(loc=.5, scale=1.0)
        rows.append([age, sex_m, sex_f, ivsd, lvpwd, lvidd])
    return rows

# Call `gen` (imported from gen_images) with the model path, the output path,
# the number of batches, the batch size, and `sample_context` as the function
# that supplies the conditioning rows for each batch.
# NOTE(review): later cells read OUTPUT_GEN_PATH/gen_metadata.csv and
# OUTPUT_GEN_PATH/images, so `gen` presumably writes both and returns the
# metadata as `df` — confirm against gen_images.
df = gen(
    MODEL_INPUT_PATH,
    OUTPUT_GEN_PATH,
    NUM_BATCHES,
    BATCH_SIZE,
    sample_fn=sample_context
)

## Evaluate the Synthetic Data

### Inception Score

In [12]:
from eval import inception_score

# Real training images
train_file_path = os.path.join(OUTPUT_METADATA_PATH, 'diffusion_metadata_train.csv')
train_is = inception_score(train_file_path)

# Real held-out test images
test_file_path = os.path.join(OUTPUT_METADATA_PATH, 'diffusion_metadata_test.csv')
test_is = inception_score(test_file_path)

# Generated images
gen_file_path = os.path.join(OUTPUT_GEN_PATH, 'gen_metadata.csv')
gen_is = inception_score(gen_file_path)

# Report the three scores side by side
print(f"Train Score: {train_is:0.4f}")
print(f"Test Score: {test_is:0.4f}")
print(f"Gen Score: {gen_is:0.4f}")

### FID Score

In [18]:
# Install pytorch-fid, which provides the `pytorch_fid` module invoked by the FID command in this section.
# NOTE(review): the version is not pinned, so results may vary between installs.
!pip install pytorch-fid

In [14]:
# Compute the FID between the generated images (OUTPUT_GEN_PATH/images) and the real images in cxrs/.
# NOTE(review): cxrs/ is where all downloaded images were moved, so the reference
# set is the full dataset rather than only the test split — confirm this is intended.
# --device cuda:0 requires a CUDA-capable GPU to be available.
!python -m pytorch_fid {os.path.join(OUTPUT_GEN_PATH, "images")} cxrs/ --device cuda:0