## Overview

Notebook workspace for PlaneSpotSlicer.

Prerequisites:

* install scvis using this package: https://github.com/shahcompbio/scvis

In [1]:
from typing import Union

import meerkat as mk
import numpy as np
import torch
import torch.optim as optim
from torch.nn.functional import cross_entropy
from tqdm import tqdm

from domino.utils import unpack_args

from abstract import Slicer
from sklearn import mixture

import glob

In [102]:
from typing import Union

import meerkat as mk
import numpy as np
import torch
import torch.optim as optim
from torch.nn.functional import cross_entropy
from tqdm import tqdm

from domino.utils import unpack_args

from abstract import Slicer

## PlaneSpot imports
from sklearn import mixture
import glob
from collections import defaultdict

from domino.utils import convert_to_numpy, unpack_args
import pandas as pd

class PlaneSpotSlicer(Slicer):
    r"""
    Implements PlaneSpot [plumb_2023], a simple SDM that fits a GMM to a 2D model 
    embedding, fit using scvis [ding_2018]. 

    ..  [plumb_2023]
        Gregory Plumb*, Nari Johnson*, Ángel Alexander Cabrera, Ameet Talwalkar.
        Towards a More Rigorous Science of Blindspot Discovery in Image 
        Classification Models. arXiv:2207.04104 [cs] (2023)
        
    ..  [ding_2018]
        Jiarui Ding, Anne Condon, and Sohrab P Shah. 
        Interpretable dimensionality reduction of single cell transcriptome 
        data with deep generative models. 
        Nature communications, 9(1):1–13. (2018)
    """

    def __init__(
        self,
        scvis_conda_env: str, # name of conda environment where scvis is installed
        n_slices: int = 10,
        n_max_mixture_components: int = 33, # maximum number of mixture components
        weight: float = 0.025, # weight hyperparameter
        scvis_config_path = None, # custom scvis config path
        scvis_output_dir = 'scvis', # path to output directory for scvis
        fit_scvis = True # flag to load rather than re-compute the scvis embedding 
    ):
        super().__init__(n_slices=n_slices)
        
        # scvis hyper-parameters
        self.scvis_conda_env = scvis_conda_env
        self.config.scvis_config_path = scvis_config_path
        self.config.scvis_output_dir = scvis_output_dir
        self.fit_scvis = fit_scvis
        
        # GMM hyper-parameters
        self.config.n_max_mixture_components = n_max_mixture_components
        self.config.weight = weight

        self.gmm = None

    def fit(
        self,
        data: Union[dict, mk.DataPanel] = None,
        embeddings: Union[str, np.ndarray] = "embedding",
        targets: Union[str, np.ndarray] = None,
        pred_probs: Union[str, np.ndarray] = None,
        losses: Union[str, np.ndarray] = None,
        verbose: bool = True,
        random_state: int = 0, # random state for GMM
        **kwargs
    ):
        embeddings, targets, pred_probs, losses = unpack_args(
            data, embeddings, targets, pred_probs, losses
        )
        
        embeddings, targets, pred_probs = convert_to_numpy(
            embeddings, targets, pred_probs
        )
        
        # 1.  Fit scvis.
        if verbose:
            print('Fitting scvis...')
        
        scvis_embeddings = self._fit_scvis(embeddings.reshape(embeddings.shape[0], embeddings.shape[1]))
        
        # 2.  Fit GMM.
        if verbose:
            print('Fitting GMM...')
            
        self._fit_gmm(scvis_embeddings,
                     pred_probs,
                     random_state,
                     verbose,)

    def predict_proba(
        self,
        data: mk.DataPanel,
        scvis_embeddings: str, # scvis column name
        pred_probs: str, # predicted probabilities column name
    ) -> np.ndarray:
        ''' Returns the probability that each datapoint belongs to the
            top self.n_slices slices.
            
            Note that the probabilities may not sum to 1.
        '''
        
        # Append the scvis embedding and predicted probabilities; normalize
        X = self._combine_embedding(dp[scvis_embeddings], dp[pred_probs])
        probs_all_components = self.gmm.predict_proba(X)
        
        probs_slices = probs_all_components[:, self.slice_indices]
        return probs_slices
    
    def _gmm_label_to_slice_label(self, gmm_label: int):
        ''' Returns the slice index corresponding to the GMM component 
            index gmm_label.
        
            If the datapoint's GMM component is not in the top self.n_slices 
            slices, returns np.nan instead.
        '''
        slice_idxs = np.argwhere(self.slice_indices == gmm_label)
        
        if len(slice_idxs) > 0:
            return slice_idxs.item()
        else:
            return np.nan
        
    def predict(
        self,
        data: mk.DataPanel,
        scvis_embeddings: str, # scvis column name
        pred_probs: str, # predicted probabilities column name
    ) -> np.ndarray:
        ''' Assigns (or does not assign) each datapoint in dp to a slice.
        
            Datapoints that are not assigned to a slice have a returned label
            of np.nan.
        '''
        
        # Append the scvis embedding and predicted probabilities; normalize
        X = self._combine_embedding(dp[scvis_embeddings], dp[pred_probs])
        hard_predictions = self.gmm.predict(X)
        
        # Re-assign their indices
        return np.array([self._gmm_label_to_slice_label(l) for l in hard_predictions])
    
    def _load_scvis_embeddings(self) -> np.ndarray:
        ''' Loads and returns pre-computed scvis embeddings from 
            self.config.scvis_output_dir.
        '''
        ### Load and return the scvis embeddings
        search_string = f'{self.config.scvis_output_dir}/*.tsv'
        scvis_embedding_filepath = sorted(glob.glob(search_string), key = len)[0]
        return pd.read_csv(scvis_embedding_filepath, sep = '\t', index_col = 0).values
    
    def _combine_embedding(self, 
                           scvis_reps: np.ndarray, 
                           pred_probs: np.ndarray) -> np.ndarray:
        ''' Normalizes the scvis_reps and appends the predicted probabilities.
        '''
        # Normalize the embeddings using the minimum and maximum column values
        X = np.copy(scvis_reps)
        X -= self.min_scvis_vals
        X /= self.max_scvis_vals
        
        # Append (weighted) predicted probabilities to the embedding
        return np.concatenate((X, self.config.weight * pred_probs.reshape(-1, 1)), axis = 1)
        
    def _fit_scvis(
        self, embeddings: np.ndarray
    ):
        ''' Fits an scvis model to the input embedding(s).
        '''
        if self.fit_scvis:
            ### Fit scvis
            
            # Make output directory
            os.system(f'rm -rf {self.config.scvis_output_dir}')
            os.system(f'mkdir {self.config.scvis_output_dir}')

            # Dump the embeddings as a CSV file
            embedding_filepath = f'{self.config.scvis_output_dir}/tmp.tsv'
            embedding_df = pd.DataFrame(embeddings)
            embedding_df.to_csv(embedding_filepath, sep = '\t', index = False)

            # Run scvis using the command line
            # source: https://github.com/shahcompbio/scvis
            command = f'conda run -n {self.scvis_conda_env} scvis train --data_matrix_file {embedding_filepath} --out_dir {self.config.scvis_output_dir}'

            if self.config.scvis_config_path is not None:
                print(self.config.scvis_config_path)
                # Add optional scvis config
                command += f' --config_file {self.config.scvis_config_path}'

            # Run the command (blocking)
            print(command)
            os.system(command)
            print('done')

            # Cleanup
            os.system('rm -rf {}'.format(embedding_filepath))
        
        ### Load and return the scvis embeddings
        return self._load_scvis_embeddings()
    
    def _fit_gmm(
        self, 
        reduced_embeddings: np.ndarray, 
        pred_probs: np.ndarray,
        random_state: int, # random state for sklearn
        verbose: bool = False,
    ):
        ''' Fits an error-aware Gaussian Mixture model to the scvis embeddings 
            and model predictions.
        '''
        # Store the min and max column values to normalize in the future.
        self.min_scvis_vals = np.min(reduced_embeddings, axis = 0)
        self.max_scvis_vals = np.max(reduced_embeddings, axis = 0)

        X = self._combine_embedding(reduced_embeddings, pred_probs)
        
        lowest_bic = np.infty
        bic = []
        n_components_range = range(self.config.n_slices, self.config.n_max_mixture_components)
        for n_components in n_components_range:
            # Fit a GMM with n_components components
            gmm = mixture.GaussianMixture(n_components = n_components, 
                                          covariance_type = 'full', 
                                          random_state = random_state)
            gmm.fit(X)
            
            # Calculate the Bayesian Information Criteria
            bic.append(gmm.bic(X))
            if bic[-1] < lowest_bic:
                lowest_bic = bic[-1]
                best_gmm = gmm
                
        self.gmm = best_gmm
        
        # Assign a score to each mixture component to find the top-k slices
        # Create the map from "group" to "set of points" (recorded as indices)
        hard_preds = self.gmm.predict(X)
        
        cluster_map = defaultdict(list)
        for i, v in enumerate(hard_preds):        
            cluster_map[v].append(i)
            
        if verbose:
            print(f'The best GMM has {len(cluster_map)} components.')
            
        # Score each of those groups
        scores = []
        errors = (1. - pred_probs)
        for i in cluster_map:
            indices = cluster_map[i]
            score = len(indices) * np.mean(errors[indices]) ** 2 # Equivalent to 'number of errors * error rate'
            scores.append((i, score))
        scores = sorted(scores, key = lambda x: -1 * x[1])
        
        # Store the indices of the mixture components with the highest scores
        self.slice_indices = np.array([t[0] for t in scores[:self.config.n_slices]])
        
        if verbose:
            print('Scores:')
            for i, score in scores:
                indices = cluster_map[i]
                print(i, score, len(indices) * np.mean(errors[indices]), np.mean(errors[indices]))
            print()

## Demo

(copied from `examples/01_intro.ipynb`)

In [3]:
import os

dp = mk.datasets.get("imagenette")

# we'll only be using the validation data
dp = dp.lz[dp["split"] == "valid"]

In [4]:
import torch
from torchvision.models import resnet18
import torchvision.transforms as transforms
model = resnet18(pretrained=True)

In [5]:
# 1. Define transform
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]),
])

# 2. Create new column with transform 
dp["input"] = dp["img"].to_lambda(transform)

# 1. Move the model to device
DEVICE = 'cpu'
model.to(DEVICE).eval()

# 2. Define a function that runs a forward pass over a batch 
@torch.no_grad()
def predict(batch: mk.DataPanel):
    input_col: mk.TensorColumn = batch["input"] 
    x: torch.Tensor = input_col.data.to(DEVICE)  # We get the underlying torch tensor with `data` and move to GPU 
    out: torch.Tensor = model(x)  # Run forward pass

    # Return a dictionary with one key for each of the new columns. Each value in the
    # dictionary should have the same length as the batch. 
    return {
        "pred": out.cpu().numpy().argmax(axis=-1),
        "probs": torch.softmax(out, axis=-1).cpu().numpy(),
    }

# 3. Apply the update. Note that the `predict` function operates on batches, so we set 
# `is_batched_fn=True`. Also, the `predict` function only accesses the "input" column, by 
# specifying that here we instruct update to only load that one column and skip others 
dp = dp.update(
    function=predict,
    is_batched_fn=True,
    batch_size=32,
    input_columns=["input"], 
    pbar=True
)

  0%|          | 0/123 [00:00<?, ?it/s]

In [6]:
import hashlib

import hashlib
x = "helo"
int.from_bytes(hashlib.sha256(x.encode('utf-8')).digest(), 'big') % 100

dp["correct"] = dp["pred"] == mk.NumpyArrayColumn(dp["label_idx"])
accuracy = dp["correct"].mean()
print(f"Micro accuracy across the ten Imagenette classes: {accuracy:0.3}")

Micro accuracy across the ten Imagenette classes: 0.672


In [7]:
# Choose a single ImageNet class

In [8]:
LABEL_IDX = 571

# convert to a binary task 
dp["prob"] = dp["probs"][:, LABEL_IDX]
dp["target"] = (dp["label_idx"] == LABEL_IDX)

# Drop rows that do not have the label
dp = dp[dp['target']]

## 1. Embed

In [9]:
from domino import embed

In [10]:
class Features:
    def __init__(self, requires_grad = None):
        self.features = None
        self.requires_grad = requires_grad
        
    def __call__(self, modules, module_in, module_out):
        if self.requires_grad is not None:
            module_out.requires_grad = self.requires_grad
        self.features = module_out
        
# Register feature hook
feature_hook = Features()
handle = list(model.modules())[66].register_forward_hook(feature_hook)

In [11]:
# Extract the last-layer embeddings from the model, and add them as the "embedding" column.

def last_layer(batch: mk.DataPanel):
    input_col: mk.TensorColumn = batch["input"] 
    x: torch.Tensor = input_col.data.to(DEVICE)  # We get the underlying torch tensor with `data` and move to GPU 
    
    ## add a hook to the model
    out: torch.Tensor = model(x)  # Run forward pass
    features: np.ndarray = feature_hook.features.data.cpu().numpy()

    # Return a dictionary with one key for each of the new columns. Each value in the
    # dictionary should have the same length as the batch. 
    return {
        "embedding": features
    }


In [12]:
dp = dp.update(
    function=last_layer,
    is_batched_fn=True,
    batch_size=32,
    input_columns=["input"], 
    pbar=True
)

  0%|          | 0/14 [00:00<?, ?it/s]

# 2. Slice

In [114]:
planespot = PlaneSpotSlicer(scvis_conda_env = 'scvis',
                           fit_scvis = False,
                           n_slices = 10)

In [115]:
np.mean(dp['prob'])

0.54146

In [116]:
planespot.fit(data = dp, 
              embeddings="embedding", 
              targets="target", 
              pred_probs="prob", 
              verbose = True, 
              random_state = 42)


Fitting scvis...
Fitting GMM...
The best GMM has 10 components.
Scores:
6 73.16509943638584 77.45668566226959 0.9445937
7 25.1733506927864 26.070682168006897 0.9655808
8 21.44116238590626 28.544074177742004 0.75115985
0 10.568749205971507 24.109774082899094 0.43835953
1 6.459923337723339 12.18926727771759 0.52996814
4 5.170847513938526 6.431701183319092 0.80396265
9 2.919836722199136 7.641775608062744 0.38208878
2 0.8839976179332965 2.6593196392059326 0.33241495
3 0.595372972376218 3.7800729274749756 0.15750304
5 0.07857854842482989 3.2449230328202248 0.024215844



In [117]:
dp['scvis'] = planespot._load_scvis_embeddings()

In [118]:
dp["planespot_prbs"] = planespot.predict_proba(
    data=dp, scvis_embeddings = 'scvis', pred_probs = 'prob'
)

dp["planespot_hard"] = planespot.predict(
    data=dp, scvis_embeddings = 'scvis', pred_probs = 'prob'
)

In [119]:
dp.head()

Unnamed: 0,img_path (PandasSeriesColumn),label (PandasSeriesColumn),label_id (PandasSeriesColumn),label_idx (PandasSeriesColumn),split (PandasSeriesColumn),img (ListColumn),input (TensorColumn),pred (NumpyArrayColumn),probs (NumpyArrayColumn),correct (NumpyArrayColumn),prob (NumpyArrayColumn),target (PandasSeriesColumn),embedding (NumpyArrayColumn),planespot_slices (NumpyArrayColumn),scvis (NumpyArrayColumn),planespot_prbs (NumpyArrayColumn),planespot_hard (NumpyArrayColumn)
0,val/n03425413/n03425413_10022.JPEG,gas pump,n03425413,571,valid,,"torch.Tensor(shape=torch.Size([3, 224, 224]))",571,"np.ndarray(shape=(1000,))",True,0.518342,True,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(2,))","np.ndarray(shape=(2,))","np.ndarray(shape=(10,))",3
1,val/n03425413/n03425413_14600.JPEG,gas pump,n03425413,571,valid,,"torch.Tensor(shape=torch.Size([3, 224, 224]))",571,"np.ndarray(shape=(1000,))",True,0.995525,True,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(2,))","np.ndarray(shape=(2,))","np.ndarray(shape=(10,))",9
2,val/n03425413/n03425413_3471.JPEG,gas pump,n03425413,571,valid,,"torch.Tensor(shape=torch.Size([3, 224, 224]))",571,"np.ndarray(shape=(1000,))",True,0.619042,True,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(2,))","np.ndarray(shape=(2,))","np.ndarray(shape=(10,))",3
3,val/n03425413/n03425413_15700.JPEG,gas pump,n03425413,571,valid,,"torch.Tensor(shape=torch.Size([3, 224, 224]))",571,"np.ndarray(shape=(1000,))",True,0.561749,True,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(2,))","np.ndarray(shape=(2,))","np.ndarray(shape=(10,))",3
4,val/n03425413/n03425413_6772.JPEG,gas pump,n03425413,571,valid,,"torch.Tensor(shape=torch.Size([3, 224, 224]))",571,"np.ndarray(shape=(1000,))",True,0.831102,True,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(2,))","np.ndarray(shape=(2,))","np.ndarray(shape=(10,))",3


In [120]:
pd.Series(np.array(dp['planespot_hard'])).value_counts()

9    134
0     82
3     55
2     38
1     27
8     24
4     23
6     20
7      8
5      8
dtype: int64