# Setup

In [2]:
%reload_ext autoreload

In [3]:
# Silence WARNING:root:The use of `check_types` is deprecated and does not have any effect.
# https://github.com/tensorflow/probability/issues/1523
import logging

logger = logging.getLogger()


class CheckTypesFilter(logging.Filter):
    """Logging filter that drops records whose message mentions ``check_types``.

    Used to hide the TensorFlow Probability ``check_types`` deprecation warning;
    every other record passes through unchanged.
    """

    def filter(self, record):
        # Returning False tells the logging machinery to discard the record.
        return "check_types" not in record.getMessage()


# Logger-level filters only see records emitted on that logger itself; the
# warning is logged on the root logger (its prefix is "WARNING:root:"), so the
# filter is attached there. Re-running this cell would otherwise stack another
# filter instance each time (the class object is redefined on every run, so
# match by name), so remove earlier copies first to keep the cell idempotent.
for _old_filter in [f for f in logger.filters if type(f).__name__ == "CheckTypesFilter"]:
    logger.removeFilter(_old_filter)
logger.addFilter(CheckTypesFilter())

In [1]:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.set_printoptions(precision=3)
import scipy.stats
import einops


from functools import partial
from collections import namedtuple
import itertools
from itertools import repeat
from time import time

import chex
import jax
import jax.random as jr
import jax.numpy as jnp
from jax import vmap, grad, jit, lax
from jax import numpy as jnp
import jax.scipy as jsp

from flax.core import freeze, unfreeze
from flax import linen as nn
import flax

import jaxopt
import optax


from PIL import Image


#jax.config.update("jax_enable_x64", False)



In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds


import torch
from torch.utils.data import TensorDataset
import torchvision.transforms as T

In [2]:
import os

# How many host CPU cores this machine exposes.
cpu_count = os.cpu_count()
print(cpu_count)

# Optionally expose the host CPU as several XLA devices so jax can use
# multiple CPU cores (e.g. with pmap). XLA_FLAGS presumably has to be set
# before jax initializes its backend to take effect — see the links below.
# https://github.com/google/jax/issues/5506
# https://stackoverflow.com/questions/72328521/jax-pmap-with-multi-core-cpu
#os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=90'

import jax

# List the devices jax can see.
print(jax.devices())

96
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]


# Load the labels

In [100]:
%pwd

'/home/kpmurphy/github/label-shift/tta'

In [2]:
import os
from pathlib import Path

# Dataset directory. Set the CHEXPERT_ROOT environment variable to point at a
# different location instead of editing this machine-specific default.
root = Path(os.environ.get("CHEXPERT_ROOT", "/home/kpmurphy/data/CheXpert"))
labels = pd.read_csv(root / "labels.csv", index_col="image_id")




In [3]:

labels.head()

Unnamed: 0_level_0,Unnamed: 0,NO_FINDING,ENLARGED_CARDIOMEDIASTINUM,CARDIOMEGALY,AIRSPACE_OPACITY,LUNG_LESION,PULMONARY_EDEMA,CONSOLIDATION,PNEUMONIA,ATELECTASIS,...,EFFUSION,PLEURAL_OTHER,FRACTURE,SUPPORT_DEVICES,patient_id,split,GENDER,AGE_AT_CXR,PRIMARY_RACE,ETHNICITY
image_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
CheXpert-v1.0/train/patient42720/study2/view1_frontal.jpg,0,3,1,3,3,3,1,3,3,3,...,1,3,3,1,patient42720,train,Male,58,White,Non-Hispanic/Non-Latino
CheXpert-v1.0/train/patient42720/study7/view1_frontal.jpg,1,3,3,0,1,3,3,3,3,3,...,1,3,3,1,patient42720,train,Male,58,White,Non-Hispanic/Non-Latino
CheXpert-v1.0/train/patient42720/study8/view1_frontal.jpg,2,3,3,0,1,3,3,3,3,3,...,1,3,3,1,patient42720,train,Male,58,White,Non-Hispanic/Non-Latino
CheXpert-v1.0/train/patient42720/study6/view1_frontal.jpg,3,3,3,3,3,3,1,3,3,3,...,1,3,3,1,patient42720,train,Male,58,White,Non-Hispanic/Non-Latino
CheXpert-v1.0/train/patient42720/study1/view1_frontal.jpg,4,3,3,3,3,3,1,3,3,3,...,1,3,3,1,patient42720,train,Male,58,White,Non-Hispanic/Non-Latino


In [27]:
len(labels)

190499

# Join the embeddings with the labels

In [104]:
def extract_labels(root, max_rows=0):
    """Load CheXpert labels and build an integer target/attribute matrix.

    Reads ``root / "labels.csv"`` (indexed by ``image_id``) and keeps only rows
    whose PNEUMONIA and EFFUSION codes are in {1, 3} and whose GENDER is not
    "Unknown". PNEUMONIA, EFFUSION and GENDER are then integer-encoded with
    ``pd.factorize(..., sort=True)`` (codes follow sorted order of the values,
    e.g. 1 -> 0, 3 -> 1 and Female -> 0, Male -> 1). A binary AGE_QUANTIZED
    column is added: True when AGE_AT_CXR is strictly above the median age.

    Args:
        root: ``pathlib.Path`` of the dataset directory containing labels.csv.
        max_rows: number of leading rows of the returned matrix to keep;
            0 means keep all rows.

    Returns:
        (YZ, labels, columns):
          * YZ: int64 array of shape (min(max_rows, n_kept_rows), len(columns)),
            rows in ``labels`` index order.
          * labels: the filtered, encoded DataFrame (NOT truncated by max_rows).
          * columns: the column names of YZ, in order.
    """
    labels = pd.read_csv(root / "labels.csv", index_col="image_id")

    # Keep only rows for which all labels are available.
    keep = (
        labels["PNEUMONIA"].isin({1, 3})
        & labels["EFFUSION"].isin({1, 3})
        & (labels["GENDER"] != "Unknown")
    )
    # .copy() so the column assignments below act on an independent frame
    # rather than on a slice of the original (avoids SettingWithCopyWarning).
    labels = labels.loc[keep].copy()

    if max_rows == 0:
        max_rows = len(labels)

    columns = ["PNEUMONIA", "EFFUSION", "GENDER"]
    for t in columns:
        code, uniques = pd.factorize(labels[t], sort=True)
        print(t, code, uniques)  # e.g. 1->0, 3->1, Female->0, Male->1
        labels[t] = code

    # Series.median skips missing ages; np.median would return NaN, and every
    # "> NaN" comparison is False, silently making AGE_QUANTIZED all zeros.
    m = labels["AGE_AT_CXR"].median()
    print('median age ', m)
    labels["AGE_QUANTIZED"] = labels["AGE_AT_CXR"] > m
    columns.append("AGE_QUANTIZED")

    # Force a numeric dtype: mixing the int code columns with the bool
    # AGE_QUANTIZED column otherwise yields an object array, which np.savez can
    # only store pickled (np.load then needs allow_pickle=True).
    YZ = labels[columns].to_numpy(dtype=np.int64)
    YZ = YZ[:max_rows]
    return YZ, labels, columns

In [106]:
YZ, labels, columns = extract_labels(root, max_rows=0)
print(YZ.shape)

PNEUMONIA [1 1 1 ... 0 0 0] Int64Index([1, 3], dtype='int64')
EFFUSION [0 0 0 ... 0 0 0] Int64Index([1, 3], dtype='int64')
GENDER [1 1 1 ... 1 0 0] Index(['Female', 'Male'], dtype='object')
median age  62.0
(139907, 4)


In [119]:
print(type(YZ))

<class 'numpy.ndarray'>


In [105]:
def extract_features(root, labels, max_rows=20, ndims=1376):
    """Load per-image embeddings into a dense matrix aligned with ``labels``.

    Row ``i`` of the result holds the embedding stored under key
    ``labels.index[i]`` in ``root / "embeddings.npz"``. The rows therefore line
    up with any label matrix built from the same ``labels`` frame in index order.

    Args:
        root: ``pathlib.Path`` of the dataset directory containing embeddings.npz.
        labels: DataFrame whose index holds the embedding keys (image ids).
        max_rows: number of rows to load; 0 means ``len(labels)``.
        ndims: embedding dimensionality, i.e. the width of the returned matrix.

    Returns:
        float64 array of shape (max_rows, ndims). If ``labels`` has fewer than
        ``max_rows`` rows, the trailing rows are left as zeros.
    """
    if max_rows == 0:
        max_rows = len(labels)
    X = np.zeros((max_rows, ndims))
    # Context manager closes the underlying zip archive when done.
    with np.load(root / "embeddings.npz") as datastore:
        # enumerate pairs row i with labels.index[i]; row alignment with the
        # label matrix depends on this.
        for i, fname in enumerate(labels.index[:max_rows]):
            X[i, :] = datastore[fname]
    return X

In [126]:
%%time
X = extract_features(root, labels, max_rows=0)
print(X.shape)
# X and YZ are both built by walking `labels` in index order; row i of X must
# describe the same image as row i of YZ, so their row counts have to agree.
assert X.shape[0] == YZ.shape[0], (X.shape, YZ.shape)



(139907, 1376)
CPU times: user 15min 47s, sys: 4.41 s, total: 15min 51s
Wall time: 15min 49s


In [123]:
print(columns)

['PNEUMONIA', 'EFFUSION', 'GENDER', 'AGE_QUANTIZED']


In [127]:
# Save the embedding matrix X, the label/attribute matrix YZ and YZ's column
# names together. YZ is cast to int64 first: an object-dtype array (e.g. ints
# mixed with bools) would be stored pickled, and np.load would then refuse to
# read it unless allow_pickle=True.
np.savez(root / 'data_matrix.npz', X=X, YZ=np.asarray(YZ, dtype=np.int64), columns=columns)
