Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Models for Camelyon17K

This Colab shows how to load and run models on Camelyon17K. We provide two sets of models:

1. *Generative*: here the models generate synthetic images with and without tumours across different hospitals.
2. *Classification*: here the models take a slide image that may contain tumour tissue and classify whether or not a tumour is present.

The Colab is divided into two sections, one for each use case.

We save our models using jax2tf for ease of use. Note that this Colab was *NOT* used for any results in the paper; it is provided to (1) show sample results with our saved models and (2) demonstrate how our pipeline operates.

This code was run on a TPU, so it is unclear how feasible it is to run the different parts on a CPU.


In [None]:
# See instructions at https://github.com/google/jax#installation for how to install.
# !pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

!pip install matplotlib
!pip install numpy
!pip install pandas
!pip install tensorflow
!pip install tensorflow_datasets

In [None]:
# @title Imports
import jax
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import tensorflow as tf
import tensorflow_datasets as tfds

In [None]:
# @markdown Download the open source models.
# @markdown Save them to `./open_source/`.

# Path to the directory that holds the open-sourced models.
base_path = './open_source/' # @param {type: 'string'}

## Image Generation

Here we show how to sample from two different generative models (trained on the full Camelyon dataset or the *most skewed* version) and also load samples generated by those models. Both models were trained with unlabelled data as well as labelled data.
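
The sampling cell below conditions the generator on a tumour label and a hospital id, each passed as a one-hot row vector with a leading batch dimension. As a rough, NumPy-only sketch of what those inputs look like (the helper `one_hot_row` is a hypothetical name introduced here for illustration; the notebook itself uses `jax.nn.one_hot`, and we assume label `1` denotes the tumour class, as in the `tumor` metadata column):

```python
import numpy as np


def one_hot_row(index, num_classes):
  """Returns a (1, num_classes) float32 one-hot row vector.

  Hypothetical helper that mirrors jax.nn.one_hot(index, num_classes)[None, :].
  """
  row = np.zeros((1, num_classes), dtype=np.float32)
  row[0, index] = 1.0
  return row


# Hospital 3 out of 5 hospitals, label 1 out of 2 classes.
one_hot_hospital = one_hot_row(3, 5)
one_hot_label = one_hot_row(1, 2)
print(one_hot_hospital.shape, one_hot_label.shape)

# The sampling loop iterates hospitals in the outer loop and labels in the
# inner loop, so the flat list of samples is ordered like this:
hospital_ids = [0, 3, 4]
label_ids = [0, 1]
conditions = [(h, l) for h in hospital_ids for l in label_ids]
for flat_index, (hospital, label) in enumerate(conditions):
  print(flat_index, hospital, label)
```

With this ordering, the sample at flat index `i * 2 + j` belongs to `hospital_ids[i]` and label `j`, which is the indexing the plotting cell relies on.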

In [None]:
# @title Sampling code
# @markdown Note that you will need a GPU or TPU to run this in a reasonable amount of time.

file_name = 'skewed100_gendata' # @param {type: 'string'} ['gendata', 'skewed100_gendata']

if file_name == 'skewed100_gendata':
  model_name = '44644773_3_skewed100_gendatamodel'
else:
  model_name = '51586976_1_genmodel'
with tf.device('TPU'):
  restored_model = tf.saved_model.load(f'{base_path}/histopathology/models/{model_name}/')

all_images = []
# We use these hospital ids from the WILDS dataset.
# Hospital ids [1, 2] are OOD Val and Test.
for hospital in [0, 3, 4]:
  for label_id in [0, 1]:
    one_hot_hospital = jax.nn.one_hot(hospital, 5)[None, :]
    one_hot_label = jax.nn.one_hot(label_id, 2)[None, :]
    res = restored_model(np.zeros(1,), one_hot_label, one_hot_hospital)
    all_images.append(res[0])



In [None]:
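# Hospital ids in the order used by the sampling cell, so that titles show
# the actual center id rather than the row index.
hospital_ids = [0, 3, 4]
fig, ax = plt.subplots(3, 2, figsize=(10, 10))

for i, hospital in enumerate(hospital_ids):
  for j in range(2):
    ax[i][j].imshow(all_images[i * 2 + j])
    ax[i][j].axis('off')
    ax[i][j].set_title(f'Label: {j}, Center: {hospital}')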

## Classification

In [None]:
# @title Load in a saved model and evaluate
# @markdown This cell and the ones below show how to load the saved models and run inference with them on the evaluation datasets.
# @markdown We exported four histopathology models: the baseline and our model conditioned on the hospital and tumor label (with color augmentation),
# @markdown each trained in the *most skewed* and the *all data* setting.


# @markdown Note that the following code gives only the results for one model: in the paper we report results across five runs.

model_name = 'baseline' # @param {type: 'string'} ['ours_multiclass', 'baseline', 'skewed100_baseline', 'skewed100_ours_multiclass']
device = 'CPU' # @param {type: 'string'} ['CPU', 'GPU', 'TPU']
with tf.device(device):
  restored_model = tf.saved_model.load(f'{base_path}/histopathology/models/{model_name}')

In [None]:
# @markdown Create a tfds version of the Camelyon17 dataset:
# @markdown Follow the instructions in the [WILDS code](https://github.com/p-lambda/wilds/blob/main/wilds/datasets/camelyon17_dataset.py)
# @markdown to download the blob file, which includes the images and metadata.
_CAMELYON_LOCATION = './camelyon17/' # @param
TEST_CENTER = 2
VAL_CENTER = 1

# @markdown Here we load the full dataset. In the paper we also created skewed versions
# @markdown to demonstrate the robustness of our approach in those settings; they are not shown here.

def parse_function(filename, label, center):
  image_string = tf.io.read_file(filename)
  image_decoded = tf.image.decode_image(image_string)
  image = tf.cast(image_decoded, tf.float32)
  return {'image': image, 'label': label, 'center': center}

def load_camelyon():
  camelyon_path = os.path.join(_CAMELYON_LOCATION, 'metadata.csv')

  metadata_df = pd.read_csv(camelyon_path, index_col=0, dtype={'patient': 'str'})
  patches_location = f'{_CAMELYON_LOCATION}/patches/'
  input_array = [
      f'{patches_location}/patient_{patient}_node_{node}/patch_patient_{patient}_node_{node}_x_{x}_y_{y}.png'
      for patient, node, x, y in
      metadata_df.loc[:, ['patient', 'node', 'x_coord', 'y_coord']].itertuples(index=False, name=None)]
  metadata_df['images'] = input_array

  # Extract splits
  split_dict = {
      'train': 0,
      'id_val': 1,
      'test': 2,
      'val': 3,
  }
  val_center_mask = (metadata_df['center'] == VAL_CENTER)
  test_center_mask = (metadata_df['center'] == TEST_CENTER)
  metadata_df.loc[val_center_mask, 'split'] = split_dict['val']
  metadata_df.loc[test_center_mask, 'split'] = split_dict['test']
  return metadata_df

camelyon_metadata = load_camelyon()

def load_eval_dataset(batch_size, split='id_val'):
  """Load in the Camelyon eval dataset into a tfds structure."""
  if split == 'id_val':
    split_id = 1
  elif split == 'ood_test':
    split_id = 2
  elif split == 'ood_val':
    split_id = 3
  else:
    raise ValueError(f'Unknown split: {split}')
  eval_data = camelyon_metadata[camelyon_metadata['split'] == split_id]
  files = eval_data['images'].values
  labels = eval_data['tumor'].values
  center = eval_data['center'].values
  images = tf.constant(files)
  labels = tf.constant(labels)
  center = tf.constant(center)
  dataset = tf.data.Dataset.from_tensor_slices((images, labels, center))
  dataset = dataset.map(parse_function).batch(batch_size)
  return dataset
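
To make the metadata handling in `load_camelyon` more concrete, here is a small sketch on made-up rows (the patient, node, and coordinate values are illustrative assumptions, not taken from the dataset, and `patch_path` and `assign_split` are hypothetical helper names). It follows the same file naming pattern and the same center-based split override: every row from `VAL_CENTER` becomes `val` (3) and every row from `TEST_CENTER` becomes `test` (2), whatever split it had before.

```python
VAL_CENTER = 1
TEST_CENTER = 2
SPLIT_IDS = {'train': 0, 'id_val': 1, 'test': 2, 'val': 3}

# Hypothetical metadata rows. Patient ids are kept as strings, presumably so
# that leading zeros survive when they are formatted into the file name.
rows = [
    {'patient': '004', 'node': 4, 'x_coord': 3328, 'y_coord': 21792,
     'center': 0, 'split': 0},
    {'patient': '034', 'node': 3, 'x_coord': 1024, 'y_coord': 2048,
     'center': 1, 'split': 0},
    {'patient': '050', 'node': 1, 'x_coord': 512, 'y_coord': 256,
     'center': 2, 'split': 0},
]


def patch_path(root, row):
  """Builds the PNG path for one metadata row."""
  folder = f"patient_{row['patient']}_node_{row['node']}"
  name = f"patch_{folder}_x_{row['x_coord']}_y_{row['y_coord']}.png"
  return f"{root.rstrip('/')}/patches/{folder}/{name}"


def assign_split(row):
  """Overrides the split id for rows that come from a held-out center."""
  if row['center'] == VAL_CENTER:
    return SPLIT_IDS['val']
  if row['center'] == TEST_CENTER:
    return SPLIT_IDS['test']
  return row['split']


paths = [patch_path('./camelyon17/', row) for row in rows]
splits = [assign_split(row) for row in rows]
for path, split in zip(paths, splits):
  print(split, path)
```

In `load_eval_dataset`, the split names `'id_val'`, `'ood_test'`, and `'ood_val'` then select rows with split ids 1, 2, and 3 respectively.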

In [None]:
for eval_dataset in ['id_val', 'ood_val', 'ood_test']:
  print(f'Results for {eval_dataset}')
  # Reset the accumulators so that each split is scored independently.
  predictions = []
  centers = []
  true_labels = []

  ds = load_eval_dataset(512, eval_dataset)
  ds = tfds.as_numpy(ds)
  for ds_item in ds:
    # Scale pixel values from [0, 255] to [0, 1] before running the model.
    images = ds_item['image'].astype(np.float32) / 255.0

    logits = restored_model(images)
    predictions.append(np.argmax(logits, axis=-1))
    centers.append(ds_item['center'])
    true_labels.append(ds_item['label'])

  predictions = np.concatenate(predictions)
  centers = np.concatenate(centers)
  true_labels = np.concatenate(true_labels)
  print(f'# samples: {true_labels.shape[0]} in dataset {eval_dataset}')
  print(f'Accuracy: {(predictions == true_labels).mean()}')

  if eval_dataset == 'id_val':
    # Per-center accuracy; the fairness gap is the best minus the worst.
    acc_center = [
        (predictions[centers == c] == true_labels[centers == c]).mean()
        for c in np.unique(centers)
    ]
    print(f'Fairness GAP: {max(acc_center) - min(acc_center)}')
  print('\n\n')

With the code above, you should get the following results. Note that these are results with a *single* model (in the paper we reported mean and standard deviation across five seeds):

| Model | Checkpoint name | Training setup | ID_VAL (%) | OOD_VAL (%) | OOD_TEST (%) | FAIRNESS_GAP |
|-------|-----------------|----------------|------------|-------------|--------------|--------------|
| Ours (Multi class) | `ours_multiclass` | All train | 98.0 | 94.2 | 94.8 | 0.006 |
| Baseline | `baseline` | All train | 92.4 | 85.1 | 62.4 | 0.041 |
| Ours (Multi class) | `skewed100_ours_multiclass` | Most skewed | 96.0 | 92.9 | 94.2 | 0.023 |
| Baseline | `skewed100_baseline` | Most skewed | 75.7 | 88.6 | 64.3 | 0.464 |
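
The FAIRNESS_GAP column corresponds to the fairness gap that the evaluation cell prints for `id_val`: the difference between the best and the worst per-center accuracy. The following is a minimal sketch of that computation on toy arrays (the logits, labels, and centers are made up for illustration and are not model outputs):

```python
import numpy as np

# Toy logits for six patches (two classes), with made-up labels and centers.
logits = np.array([
    [2.0, 0.1],
    [0.3, 1.5],
    [1.2, 0.4],
    [0.2, 0.9],
    [0.1, 2.2],
    [1.8, 0.5],
])
true_labels = np.array([0, 1, 1, 1, 1, 0])
centers = np.array([0, 0, 3, 3, 4, 4])

# The predicted class is the index of the largest logit.
predictions = np.argmax(logits, axis=-1)
accuracy = (predictions == true_labels).mean()

# Accuracy restricted to the patches of each center.
acc_per_center = {
    int(c): float((predictions[centers == c] == true_labels[centers == c]).mean())
    for c in np.unique(centers)
}
fairness_gap = max(acc_per_center.values()) - min(acc_per_center.values())

print(f'Accuracy: {accuracy:.3f}')
print(f'Per-center accuracy: {acc_per_center}')
print(f'Fairness gap: {fairness_gap}')
```

In this toy case center 3 has one of its two patches misclassified while centers 0 and 4 are fully correct, so the gap is 0.5 even though the overall accuracy is 5/6. A gap near zero means the model performs about equally well on every training hospital.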