# Extract Embeddings from MIMIC-IV-CXR using CXR-Foundation
For more information on CXR-Foundation, please refer to https://github.com/Google-Health/imaging-research/blob/master/cxr-foundation/README.md  
**We recommend running the following code in Google Colab.**

In [1]:
import numpy as np
import pandas as pd
import os
import pickle
import matplotlib.pyplot as plt
from tqdm import tqdm
import time

## Package Preparation

### Installation
Install the CXR Foundation package

In [3]:
# Notebook specific dependencies
# NOTE(review): only tf-models-official is pinned; matplotlib, google-cloud-storage and the
# cloned imaging-research repo (default branch) float, so results may change over time.
!pip install matplotlib tf-models-official==2.14.0 google-cloud-storage

# Clone the repository containing the cxr-foundation package and install it from source.
# On a re-run, `git clone` reports that imaging-research/ already exists and the
# existing checkout is reused by the install below.
!git clone https://github.com/Google-Health/imaging-research.git
!pip install imaging-research/cxr-foundation/

Collecting tf-models-official==2.14.0
  Downloading tf_models_official-2.14.0-py2.py3-none-any.whl (2.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m20.6 MB/s[0m eta [36m0:00:00[0m
Collecting sacrebleu (from tf-models-official==2.14.0)
  Downloading sacrebleu-2.4.2-py3-none-any.whl (106 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.7/106.7 kB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
Collecting seqeval (from tf-models-official==2.14.0)
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting tensorflow-model-optimization>=0.4.1 (from tf-models-official==2.14.0)
  Downloading tensorflow_model_optimization-0.8.0-py2.py3-none-any.whl (242 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m242.5/242.5 kB[0m [31m19.4 MB/s[0m 

In [4]:
# pydicom is used by show_dicom() (pydicom.dcmread) to read downloaded DICOM files.
!pip install pydicom



### Authenticate to Access Data

The following cell is for Colab only. If running elsewhere, authenticate with the [gcloud CLI](https://cloud.google.com/sdk/gcloud/reference/auth/login).

In [5]:
from google.colab import auth

# Authenticate user for access. There will be a popup asking you to sign in with your user and approve access.
# Colab only; presumably these are the credentials storage.Client() picks up further down.
auth.authenticate_user()

## Utility functions and imports

In [6]:
import io
import pydicom
from PIL import Image
from cxr_foundation import constants
from cxr_foundation import example_generator_lib

def show_dicom(dicom_path, figure_size=7):
    """Display the pixel data of a DICOM file as a grayscale image.

    Args:
        dicom_path: Path to a DICOM file readable by ``pydicom.dcmread``.
        figure_size: Width and height of the square figure in inches
            (default 7, the previously hard-coded value).

    Returns:
        None. The image is rendered with ``plt.show()``.
    """
    dicom_data = pydicom.dcmread(dicom_path)
    image = dicom_data.pixel_array

    # With MONOCHROME1 the lowest stored value must be shown as white, i.e. the
    # inverse of MONOCHROME2. pixel_array returns the stored values without
    # applying that inversion, so reverse the colormap instead of showing a
    # negative image.
    photometric = getattr(dicom_data, 'PhotometricInterpretation', 'MONOCHROME2')
    cmap = 'gray_r' if photometric == 'MONOCHROME1' else 'gray'

    fig, ax = plt.subplots(figsize=(figure_size, figure_size))
    ax.imshow(image, cmap=cmap)
    ax.axis('off')
    plt.show()

In [7]:
def generate_path(img_metadata, dicom_dir, embedding_dir, output_type):
  """Add remote DICOM, local DICOM and embedding file paths to image metadata.

  Args:
    img_metadata: DataFrame with 'subject_id', 'study_id' and 'dicom_id'
      columns, one row per image. It is not modified.
    dicom_dir: Local directory the DICOMs are downloaded into.
    embedding_dir: Directory the per-image embedding files are written to.
    output_type: Embedding file extension including the dot, e.g. '.npz'.

  Returns:
    A copy of ``img_metadata`` (same index) with three extra columns:
      'remote_dicom_file': object path in the bucket,
          files/p<first 2 digits of subject_id>/p<subject_id>/s<study_id>/<dicom_id>.dcm
      'local_dicom_file': <dicom_dir>/<dicom_id>.dcm
      'embedding_file': <embedding_dir>/<dicom_id><output_type>
  """
  # Work on a copy: callers pass a boolean-filtered slice, and adding columns
  # to it in place triggers pandas' SettingWithCopyWarning.
  img_metadata = img_metadata.copy()

  # IDs may be stored as floats (e.g. 10000032.0); int() drops the '.0'.
  img_metadata['remote_dicom_file'] = [
      f"files/p{str(int(p_id))[:2]}/p{int(p_id)}/s{int(s_id)}/{d_id}.dcm"
      for p_id, s_id, d_id in zip(img_metadata['subject_id'],
                                  img_metadata['study_id'],
                                  img_metadata['dicom_id'])]

  dicom_names = img_metadata['remote_dicom_file'].map(os.path.basename)
  # Path for downloaded DICOMs
  img_metadata['local_dicom_file'] = dicom_names.map(
      lambda name: os.path.join(dicom_dir, name))
  # Path for generated embeddings. splitext swaps only the trailing '.dcm',
  # unlike str.replace, which would also rewrite a '.dcm' inside the name.
  img_metadata['embedding_file'] = dicom_names.map(
      lambda name: os.path.join(embedding_dir, os.path.splitext(name)[0] + output_type))
  return img_metadata

In [8]:
def download_blob(bucket_name, source_blob_name, destination_file_name, project_id, client=None):
    """Download one object from a Google Cloud Storage bucket to a local file.

    Args:
        bucket_name: Name of the GCS bucket.
        source_blob_name: Object path inside the bucket.
        destination_file_name: Local path to write to. Missing parent
            directories are created first.
        project_id: Project billed for the request (passed as ``user_project``,
            as required by requester-pays buckets).
        client: Storage client to use. Defaults to the notebook-level
            ``storage_client``.
    """
    if client is None:
        client = storage_client
    # download_to_filename cannot create directories, and the target directory
    # may not exist yet on the first download.
    dest_dir = os.path.dirname(destination_file_name)
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)
    bucket = client.bucket(bucket_name, user_project=project_id)
    blob = bucket.blob(source_blob_name)
    blob.download_to_filename(destination_file_name)

## Load pre-processed ICU information

In [5]:
# Colab only: mount Google Drive. The pre-processed data files and the DICOM,
# embedding and output directories configured below live under
# /content/drive/My Drive/MIMIC_Multimodal_Data/.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
from google.colab import files
# Colab only: opens a browser dialog to upload local files into the runtime.
uploaded = files.upload()
# Upload data_utils.py
# NOTE(review): presumably data_utils.py defines the classes stored in
# processed_icu_24h.pkl, so it must be available before pickle.load below. Please confirm.

In [5]:
path = '/content/drive/My Drive/MIMIC_Multimodal_Data/'

# Pre-processed ICU stays; icus[i].images['metadata'] is each stay's image table.
# SECURITY: pickle can execute arbitrary code on load; only open files you created.
# `with` makes sure the file handle is closed after loading.
with open(os.path.join(path, 'processed_icu_24h.pkl'), 'rb') as icu_file:
  icus = pickle.load(icu_file)
metadata = pd.read_csv(os.path.join(path, 'metadata_24h.csv'), index_col=0)

### Indices of ICU stays with usable chest X-rays (AP/PA views)

In [8]:
# Accepted values of the images' 'ViewPosition' column (presumably the frontal
# AP / PA projections); images with any other view position are filtered out below.
allowed_view_position = ['AP', 'PA']

In [9]:
# Find the ICU stays that have at least one image with an allowed view position,
# and record how many such images each of them has.
idx_list = []   # positions in `icus` of stays with >= 1 usable image
num_imgs = []   # usable-image count for each entry of idx_list (same order)
for i in range(len(icus)):
  icu = icus[i]
  img_metadata = icu.images['metadata']
  # Filter out images whose view position is not in accepted set
  filtered_img_metadata = img_metadata[img_metadata['ViewPosition'].isin(allowed_view_position)]
  if not filtered_img_metadata.empty:
    idx_list.append(i)
    num_imgs.append(len(filtered_img_metadata))

In [10]:
# Number of ICU stays with at least one usable image, then the distribution of
# usable images per stay.
print(len(idx_list))
print(pd.Series(num_imgs).describe())

11181
count    11181.000000
mean         1.510688
std          0.885279
min          1.000000
25%          1.000000
50%          1.000000
75%          2.000000
max         14.000000
dtype: float64


## Download DICOMs from Google Cloud & Generate all embeddings

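Request access to the MIMIC-CXR DICOM files (Google Cloud Storage bucket `mimic-cxr-2.0.0.physionet.org`, used below) through PhysioNet's Google Cloud Storage Browser. The JPG release is at https://physionet.org/content/mimic-cxr-jpg/2.1.0/; the DICOMs themselves belong to the MIMIC-CXR release at https://physionet.org/content/mimic-cxr/2.0.0/.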

### Paths

In [3]:
# Downloaded DICOMs (<dicom_id>.dcm, one per image).
DICOM_DIR = '/content/drive/My Drive/MIMIC_Multimodal_Data/dicoms' #@param {type: 'string'}
# Per-image embeddings; this directory is emptied by the embedding-version cell.
EMBEDDINGS_DIR = '/content/drive/My Drive/MIMIC_Multimodal_Data/embeddings' #@param {type: 'string'}
# Extension of the per-image embedding files; must agree with OutputFileType.NPZ used when generating them.
OUTPUT_TYPE = '.npz'#@param {type: 'string'}
# Averaged embedding per ICU stay (CXR_{idx}.npz). NOTE: reassigned to a local path
# in the "Integrate generated embeddings" section.
OUTPUT_DIR = '/content/drive/My Drive/MIMIC_Multimodal_Data/outputs' #@param {type: 'string'}

### Get access for downloading DICOMs

In [16]:
# GCP project billed for the downloads (passed to download_blob as project_id).
project_id = 'your-project-id' # fill in the project id of your project to bill
# Bucket the MIMIC-CXR .dcm files are downloaded from.
bucket_name = 'mimic-cxr-2.0.0.physionet.org'

In [17]:
from google.cloud import storage
# Create a storage client
# Module-level client used by download_blob(); presumably it authenticates with the
# credentials from auth.authenticate_user() (or the gcloud CLI outside Colab).
storage_client = storage.Client()

### Selecting the embedding version

#### Embedding Version

We support the following three embedding versions:
- cxr_foundation: the original CXR foundation embedding
- elixr: the raw image embedding from the Q-former output in ELIXR (https://arxiv.org/abs/2308.01317), can be used for data-efficient classification (same as CXR foundation embedding)
- elixr_img_contrastive: the text-aligned image embedding from the Q-former output in ELIXR (https://arxiv.org/abs/2308.01317), can be used for image retrieval. Refer to the "Image Retrieval Demo" section of the original CXR Foundation demo Colab (not reproduced in this notebook) for example usage.

In [18]:
from cxr_foundation.inference import generate_embeddings
from cxr_foundation import embeddings_data
from cxr_foundation.embeddings_data import read_tfrecord_values, read_npz_values, get_dataset

# help(generate_embeddings)
# help(read_tfrecord_values)
# help(read_npz_values)
# help(get_dataset)

In [19]:
from cxr_foundation.inference import ModelVersion
import shutil

# Choose the embedding model. TOKEN_NUM x EMBEDDINGS_SIZE presumably describes one
# image's embedding (1 x 1376 for cxr_foundation, matching the 1376 columns of the
# final table). Please confirm against the cxr_foundation docs.
EMBEDDING_VERSION = 'cxr_foundation' #@param ['elixr', 'cxr_foundation', 'elixr_img_contrastive']
if EMBEDDING_VERSION == 'cxr_foundation':
  MODEL_VERSION = ModelVersion.V1
  TOKEN_NUM = 1
  EMBEDDINGS_SIZE = 1376
elif EMBEDDING_VERSION == 'elixr':
  MODEL_VERSION = ModelVersion.V2
  TOKEN_NUM = 32
  EMBEDDINGS_SIZE = 768
elif EMBEDDING_VERSION == 'elixr_img_contrastive':
  MODEL_VERSION = ModelVersion.V2_CONTRASTIVE
  TOKEN_NUM = 32
  EMBEDDINGS_SIZE = 128
else:
  # Without this, an unknown value would leave MODEL_VERSION undefined or silently
  # keep the value from a previous run of this cell. Failing here also avoids wiping
  # EMBEDDINGS_DIR below for a misconfigured run.
  raise ValueError(f'Unknown EMBEDDING_VERSION: {EMBEDDING_VERSION!r}')

# WARNING: deletes everything in EMBEDDINGS_DIR. Start empty so files produced by a
# previously selected version are never read back as if they came from this one.
if os.path.exists(EMBEDDINGS_DIR):
  shutil.rmtree(EMBEDDINGS_DIR)
os.makedirs(EMBEDDINGS_DIR)

In [20]:
import logging

from cxr_foundation.inference import generate_embeddings, InputFileType, OutputFileType, ModelVersion

# Lower the root logger's threshold to INFO so INFO-level log records (presumably
# including cxr_foundation's progress messages) are not filtered out.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

### Generate and store embedding outputs

In [22]:
# Globally silences pandas' SettingWithCopyWarning. The loop below passes boolean-filtered
# slices of each stay's image metadata to generate_path(), which adds columns to the
# frame it works on; this keeps that from printing a warning per stay.
pd.options.mode.chained_assignment = None

In [None]:
# For every ICU stay in idx_list: download its AP/PA DICOMs, embed each image, and
# save the mean of the per-image embeddings to OUTPUT_DIR/CXR_{idx}.npz.

# Embeddings are generated with OutputFileType.NPZ and read with read_npz_values, so
# the embedding_file paths built by generate_path() must use the same extension.
assert OUTPUT_TYPE == '.npz', f"OUTPUT_TYPE must be '.npz', got {OUTPUT_TYPE!r}"
# Make sure the download and output directories exist before writing into them.
os.makedirs(DICOM_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

START_POS = 0  # position in idx_list to start from; raise it to resume an interrupted run
for i, idx in enumerate(idx_list[START_POS:]):
  icu = icus[idx]
  img_metadata = icu.images['metadata']
  # Filter out images whose view position is not in accepted set
  filtered_img_metadata = img_metadata[img_metadata['ViewPosition'].isin(allowed_view_position)]
  df_img_metadata = generate_path(filtered_img_metadata, DICOM_DIR, EMBEDDINGS_DIR, OUTPUT_TYPE)

  # Download this stay's DICOMs from Google Cloud Storage
  for source_blob_name, destination_file_name in zip(df_img_metadata['remote_dicom_file'],
                                                      df_img_metadata['local_dicom_file']):
    download_blob(bucket_name, source_blob_name, destination_file_name, project_id)

  # Generate the per-image embeddings; the read below expects one file per DICOM
  # at df_img_metadata['embedding_file'].
  generate_embeddings(input_files=df_img_metadata["local_dicom_file"].values, output_dir=EMBEDDINGS_DIR,
                      input_type=InputFileType.DICOM, output_type=OutputFileType.NPZ, model_version=MODEL_VERSION)

  # Average the per-image embeddings into a single embedding for the stay
  embeddings = [embeddings_data.read_npz_values(emb_path)
                for emb_path in df_img_metadata["embedding_file"]]
  output_embedding = np.mean(embeddings, axis=0)

  # Export averaged embeddings
  output_path = os.path.join(OUTPUT_DIR, f'CXR_{idx}.npz')
  np.savez_compressed(output_path, array=output_embedding)
  print(f'idx: {idx}, i: {i}')

### Integrate generated embeddings

In [6]:
# Directory for output embeddings
# Directory for output embeddings
# NOTE(review): this replaces the Drive OUTPUT_DIR set in the "Paths" section, where the
# generation loop wrote CXR_{idx}.npz. Presumably those files were copied here first. Please confirm.
OUTPUT_DIR = './embeddings/cxr_foundation_outputs/'

In [11]:
# Load the averaged embedding of every usable ICU stay, in idx_list order.
embeddings = []
for idx in tqdm(idx_list):
  output_path = os.path.join(OUTPUT_DIR, f'CXR_{idx}.npz')
  # np.load on an .npz returns an NpzFile holding an open file; the context
  # manager closes it once the array has been read into memory.
  with np.load(output_path) as emb:
    embeddings.append(emb['array'])

100%|██████████| 11181/11181 [00:03<00:00, 3494.92it/s]


In [13]:
# Convert the list of arrays into a DataFrame
cxr_foundation_embeddings= pd.DataFrame(embeddings)
cxr_foundation_embeddings.index = idx_list
cxr_foundation_embeddings.index.name = 'Index'
cxr_foundation_embeddings

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9,...,1366,1367,1368,1369,1370,1371,1372,1373,1374,1375
Index,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
3,-1.381176,-0.940575,0.715067,-0.797603,0.356198,0.246666,-0.003760,0.528408,0.644446,1.259951,...,-0.618315,-0.935153,0.016467,0.924472,1.342215,1.881090,1.097744,-1.578163,0.880282,-1.080822
11,0.581038,-1.797232,0.575292,-2.580911,-0.198946,-0.137570,0.249380,0.587596,0.143332,0.340715,...,0.084487,-1.056361,0.339028,0.841624,1.585286,1.118701,1.623493,-0.970832,0.595209,-0.655883
13,-0.614110,-1.026556,0.810352,-2.059892,0.188802,0.699066,0.191839,0.398325,0.440779,1.621906,...,0.376719,-0.724244,0.346482,0.279807,1.396719,1.366546,1.493984,-2.406215,0.620417,-1.347265
28,-0.736443,-1.590252,0.475662,-3.094897,-0.319674,0.403751,0.022844,0.378421,1.919591,1.027444,...,0.232888,-1.040388,-0.132198,-0.233667,1.888865,0.769814,1.028867,-1.937112,0.634653,-0.550897
32,0.480620,-0.940023,0.649912,-1.676881,0.778880,-0.413885,0.363687,1.024940,1.054450,0.434259,...,-0.281471,-1.436268,0.815421,-0.568383,1.810253,1.011805,1.896561,-1.716560,0.120864,-0.852970
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
73156,-0.962804,-0.943192,0.648740,-2.202918,0.427941,0.547557,0.352034,-0.198258,0.797327,0.485057,...,0.053254,-1.654803,-0.112861,0.768092,1.514888,0.731088,1.334879,-1.849943,1.660445,-0.264385
73158,-0.393963,-2.005205,1.043008,-1.708687,0.032316,0.020179,0.944152,-0.067904,0.194546,0.307785,...,0.330033,-0.159879,0.256290,0.216845,1.135049,0.159038,1.036105,-2.024945,0.863886,-0.367802
73167,0.273118,-1.392482,1.045443,-2.198258,-0.017214,-0.622114,0.391003,0.580721,1.298687,0.236781,...,-0.500133,-1.005181,-0.167412,0.304767,1.647456,0.437331,2.075221,-2.000816,0.798234,-0.215647
73168,1.088542,-1.974966,0.448694,-2.974573,0.709508,0.482062,0.144350,1.297873,-0.983312,1.449268,...,0.236932,-2.503725,1.837659,-0.132703,1.242561,2.597867,0.892291,-0.887340,-0.384910,-1.353121


In [14]:
# Convert DataFrame to NumPy arrays
data = cxr_foundation_embeddings.to_numpy()
columns = cxr_foundation_embeddings.columns.to_numpy()
index = cxr_foundation_embeddings.index.to_numpy()

# Save to NPZ file
np.savez('embeddings/cxr_foundation_embeddings.npz', data=data, columns=columns, index=index)