## Overview

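Notebook workspace for developing `PlaneSpotSlicer`, a slice discovery method (SDM) implementing PlaneSpot, followed by a demo on the Imagenette validation split.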

Prerequisites:

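* Install scvis (https://github.com/shahcompbio/scvis) into its own conda environment. `PlaneSpotSlicer` runs scvis via `conda run -n <env>`, so pass that environment's name as `scvis_conda_env`.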

In [1]:
from typing import Union

import meerkat as mk
import numpy as np
import torch
import torch.optim as optim
from torch.nn.functional import cross_entropy
from tqdm import tqdm

from domino.utils import unpack_args

from abstract import Slicer
from sklearn import mixture

import glob

In [26]:
from typing import Union

import meerkat as mk
import numpy as np
import torch
import torch.optim as optim
from torch.nn.functional import cross_entropy
from tqdm import tqdm

from domino.utils import unpack_args

from abstract import Slicer

## PlaneSpot imports
from sklearn import mixture
import glob

from domino.utils import convert_to_numpy, unpack_args
import pandas as pd

class PlaneSpotSlicer(Slicer):
    r"""
    Implements PlaneSpot [plumb_2023], a simple SDM that fits a GMM to a 2D model
    embedding, fit using scvis [ding_2018].

    Steps, as implemented here:

    1. Reduce the model embeddings to 2D with scvis, run through its command-line
       interface inside the conda environment ``scvis_conda_env``.
    2. Min-max normalize the 2D embedding to [0, 1], append the predicted
       probabilities scaled by ``weight``, and fit Gaussian mixture models with
       ``n_slices`` up to ``n_max_mixture_components - 1`` components, keeping the
       model with the lowest BIC.
    3. Score every mixture component by (number of errors) x (error rate) of the
       examples assigned to it and use the ``n_slices`` best components as slices.
       NOTE(review): confirm this ranking matches the one described in [plumb_2023].

    Limitation: scvis is only run on the data passed to :meth:`fit`, so
    :meth:`predict_proba` / :meth:`predict` currently only support that same data
    (same rows, same order). Predicting on new data would require an out-of-sample
    scvis mapping (TODO).

    ..  [plumb_2023]
        Gregory Plumb*, Nari Johnson*, Ángel Alexander Cabrera, Ameet Talwalkar.
        Towards a More Rigorous Science of Blindspot Discovery in Image 
        Classification Models. arXiv:2207.04104 [cs] (2023)
        
    ..  [ding_2018]
        Jiarui Ding, Anne Condon, and Sohrab P Shah. 
        Interpretable dimensionality reduction of single cell transcriptome 
        data with deep generative models. 
        Nature communications, 9(1):1–13. (2018)
    """

    def __init__(
        self,
        scvis_conda_env: str, # name of conda environment where scvis is installed
        n_slices: int = 10,
        n_max_mixture_components: int = 33, # exclusive upper bound on the number of GMM components tried
        weight: float = 0.025, # weight applied to the predicted probabilities in the GMM input
        scvis_config_path = None, # custom scvis config path
        scvis_output_dir = 'scvis', # path to output directory for scvis
        fit_scvis = True, # if False, load a previously computed scvis embedding from scvis_output_dir
        random_state = None, # seed passed to every GaussianMixture fit, for reproducibility
    ):
        super().__init__(n_slices=n_slices)

        # scvis hyper-parameters
        self.scvis_conda_env = scvis_conda_env
        self.config.scvis_config_path = scvis_config_path
        self.config.scvis_output_dir = scvis_output_dir
        self.fit_scvis = fit_scvis

        # GMM hyper-parameters
        self.config.n_max_mixture_components = n_max_mixture_components
        self.config.weight = weight
        self.config.random_state = random_state

        # Fitted state, populated by `fit`.
        self.gmm = None
        self.slice_indices = None  # GMM component indices used as slices, best first
        self._scvis_embeddings = None  # 2D scvis embedding of the data passed to `fit`
        self._embedding_min = None  # per-column minimum of the scvis embedding
        self._embedding_span = None  # per-column (max - min), with 0 replaced by 1

    def fit(
        self,
        data: Union[dict, mk.DataPanel] = None,
        embeddings: Union[str, np.ndarray] = "embedding",
        targets: Union[str, np.ndarray] = None,
        pred_probs: Union[str, np.ndarray] = None,
        losses: Union[str, np.ndarray] = None,
        verbose: bool = True,
        **kwargs
    ):
        """Fit scvis and the GMM, then choose which GMM components are the slices.

        Args:
            data: optional container from which string-valued arguments are looked
                up (resolved by ``unpack_args``).
            embeddings: model embeddings, one row per example. Trailing dimensions
                are flattened, e.g. pooled ``(N, C, 1, 1)`` features become ``(N, C)``.
            targets: binary ground-truth labels, used to rank components by their
                errors. If None, the first ``n_slices`` components are used and a
                warning is issued.
            pred_probs: predicted probability of the positive class, one per example.
            losses: unused.
            verbose: print progress messages.
        """
        embeddings, targets, pred_probs, losses = unpack_args(
            data, embeddings, targets, pred_probs, losses
        )
        
        embeddings, targets, pred_probs = convert_to_numpy(
            embeddings, targets, pred_probs
        )
        embeddings = np.asarray(embeddings)
        pred_probs = np.asarray(pred_probs).reshape(-1)
        
        # 1.  Fit scvis on 2D (N, D) input.
        if verbose:
            print('Fitting scvis...')
        self._scvis_embeddings = self._fit_scvis(
            embeddings.reshape(embeddings.shape[0], -1), verbose=verbose
        )
        
        # 2.  Fit GMM.
        if verbose:
            print('Fitting GMM...')
        self._fit_gmm(self._scvis_embeddings, pred_probs)

        # 3.  Select the GMM components that become slices.
        self.slice_indices = self._rank_components(
            self._gmm_features(self._scvis_embeddings, pred_probs), targets, pred_probs
        )

    def predict_proba(
        self,
        data: Union[dict, mk.DataPanel] = None,
        embeddings: Union[str, np.ndarray] = "embedding",
        targets: Union[str, np.ndarray] = None,
        pred_probs: Union[str, np.ndarray] = None,
        losses: Union[str, np.ndarray] = None,
    ) -> np.ndarray:
        """Return, for each example, the GMM probability of belonging to each slice.

        Column ``j`` corresponds to GMM component ``self.slice_indices[j]``. Only
        ``n_slices`` of the components are returned, so rows need not sum to 1.

        Only the data passed to :meth:`fit` is supported: the stored scvis embedding
        is reused, and ``pred_probs`` must have one entry per fitted row.

        Raises:
            ValueError: if the slicer has not been fit, or if ``pred_probs`` has a
                different number of rows than the data used in :meth:`fit`.
        """
        if self.gmm is None or self._scvis_embeddings is None:
            raise ValueError("PlaneSpotSlicer must be fit before calling predict_proba.")

        embeddings, targets, pred_probs, losses = unpack_args(
            data, embeddings, targets, pred_probs, losses
        )
        embeddings, targets, pred_probs = convert_to_numpy(
            embeddings, targets, pred_probs
        )
        pred_probs = np.asarray(pred_probs).reshape(-1)

        n_fitted = self._scvis_embeddings.shape[0]
        if pred_probs.shape[0] != n_fitted:
            raise ValueError(
                f"PlaneSpotSlicer can only predict on the data it was fit on "
                f"({n_fitted} rows); got {pred_probs.shape[0]} rows."
            )

        X = self._gmm_features(self._scvis_embeddings, pred_probs)
        component_probs = self.gmm.predict_proba(X)
        return component_probs[:, self.slice_indices]

    def predict(
        self,
        data: mk.DataPanel,
        embeddings: Union[str, np.ndarray] = "embedding",
        targets: Union[str, np.ndarray] = None,
        pred_probs: Union[str, np.ndarray] = None,
        losses: Union[str, np.ndarray] = None,
    ) -> np.ndarray:
        """Return hard slice memberships by thresholding :meth:`predict_proba` at 0.5."""
        probs = self.predict_proba(
            data=data,
            embeddings=embeddings,
            targets=targets,
            pred_probs=pred_probs,
            losses=losses,
        )

        # TODO (Greg): check if this is the preferred way to get hard predictions from
        # probabilities
        return (probs > 0.5).astype(np.int32)

    def _fit_scvis(
        self, embeddings: np.ndarray, verbose: bool = True
    ) -> np.ndarray:
        """Fit scvis to ``embeddings`` (if ``self.fit_scvis``) and return its 2D embedding.

        When ``self.fit_scvis`` is True:
        - ``scvis_output_dir`` is deleted and recreated;
        - the embeddings are written there as a TSV;
        - ``scvis train`` is run via ``conda run -n <scvis_conda_env>``.

        In both cases the embedding is then read back from ``scvis_output_dir``.

        Raises:
            subprocess.CalledProcessError: if the scvis command exits with an error.
            FileNotFoundError: if no ``.tsv`` output exists in ``scvis_output_dir``.
        """
        import os
        import shutil
        import subprocess

        output_dir = str(self.config.scvis_output_dir)

        if self.fit_scvis:
            # Start from an empty output directory so only this run's results are loaded.
            shutil.rmtree(output_dir, ignore_errors=True)
            os.makedirs(output_dir)

            # scvis reads its input matrix from a tab-separated file.
            embedding_filepath = os.path.join(output_dir, 'tmp.tsv')
            pd.DataFrame(embeddings).to_csv(embedding_filepath, sep='\t', index=False)

            # Run scvis using the command line (source: https://github.com/shahcompbio/scvis).
            # An argument list (no shell) avoids quoting/injection issues in paths and names.
            command = [
                'conda', 'run', '-n', str(self.scvis_conda_env),
                'scvis', 'train',
                '--data_matrix_file', embedding_filepath,
                '--out_dir', output_dir,
            ]
            if self.config.scvis_config_path is not None:
                command += ['--config_file', str(self.config.scvis_config_path)]

            if verbose:
                print(' '.join(command))
            try:
                subprocess.run(command, check=True)  # blocking
            finally:
                # Remove the input file so it is never mistaken for scvis output below.
                if os.path.exists(embedding_filepath):
                    os.remove(embedding_filepath)
            if verbose:
                print('done')

        # Load and return the scvis embedding.
        # NOTE(review): scvis writes several .tsv files; the shortest path is assumed
        # to be the 2D embedding — confirm against scvis's output naming.
        candidates = sorted(glob.glob(os.path.join(output_dir, '*.tsv')), key=len)
        if not candidates:
            raise FileNotFoundError(
                f"No scvis output (*.tsv) found in {output_dir!r}; "
                "run with fit_scvis=True to compute it."
            )
        return pd.read_csv(candidates[0], sep='\t', index_col=0).values

    def _gmm_features(
        self, reduced_embeddings: np.ndarray, pred_probs: np.ndarray
    ) -> np.ndarray:
        """Build the GMM input.

        The columns are the min-max normalized scvis coordinates followed by
        ``weight * pred_probs``. The normalization statistics come from
        :meth:`_fit_gmm`.
        """
        normalized = (
            np.asarray(reduced_embeddings, dtype=float) - self._embedding_min
        ) / self._embedding_span
        weighted_probs = self.config.weight * np.asarray(pred_probs, dtype=float).reshape(-1, 1)
        return np.concatenate((normalized, weighted_probs), axis=1)

    def _fit_gmm(
        self, reduced_embeddings: np.ndarray, pred_probs: np.ndarray
    ):
        """Fit GMMs with a range of component counts and keep the lowest-BIC one in ``self.gmm``.

        Component counts run from ``n_slices`` up to (excluding)
        ``n_max_mixture_components``. At least ``n_slices`` is always tried, so
        ``n_slices`` slices can be selected.
        """
        # Store the normalization so that predict_proba rebuilds identical features.
        reduced_embeddings = np.asarray(reduced_embeddings, dtype=float)
        self._embedding_min = np.min(reduced_embeddings, axis=0)
        span = np.max(reduced_embeddings, axis=0) - self._embedding_min
        # A constant column would otherwise divide by zero and produce NaNs.
        self._embedding_span = np.where(span > 0, span, 1.0)

        X = self._gmm_features(reduced_embeddings, pred_probs)

        n_lower = self.config.n_slices
        n_upper = max(self.config.n_max_mixture_components, n_lower + 1)

        best_gmm, lowest_bic = None, np.inf
        for n_components in range(n_lower, n_upper):
            gmm = mixture.GaussianMixture(
                n_components=n_components,
                covariance_type='full',
                random_state=self.config.random_state,
            )
            gmm.fit(X)

            # Keep the model with the lowest Bayesian Information Criterion.
            bic = gmm.bic(X)
            if bic < lowest_bic:
                lowest_bic, best_gmm = bic, gmm

        if best_gmm is None:
            raise RuntimeError("No GMM produced a finite BIC; cannot fit PlaneSpot.")
        self.gmm = best_gmm

    def _rank_components(
        self, X: np.ndarray, targets, pred_probs: np.ndarray
    ) -> np.ndarray:
        """Return the indices of the ``n_slices`` highest-scoring GMM components, best first.

        A component's score is (number of errors) x (error rate) among the examples
        the GMM assigns to it. An example counts as an error when
        ``pred_probs > 0.5`` disagrees with its binary target.
        """
        import warnings

        n_slices = self.config.n_slices
        if targets is None:
            warnings.warn(
                "PlaneSpotSlicer.fit was called without targets; slices cannot be "
                "ranked by error, so the first n_slices GMM components are used."
            )
            return np.arange(n_slices)

        is_error = (pred_probs > 0.5) != np.asarray(targets).reshape(-1).astype(bool)

        n_components = self.gmm.n_components
        assignments = self.gmm.predict(X)
        sizes = np.bincount(assignments, minlength=n_components)
        n_errors = np.bincount(
            assignments, weights=is_error.astype(float), minlength=n_components
        )
        error_rate = np.divide(
            n_errors, sizes, out=np.zeros(n_components), where=sizes > 0
        )
        scores = n_errors * error_rate
        # Highest score first; a stable sort keeps component order on ties.
        return np.argsort(-scores, kind='stable')[:n_slices]




## Demo

(copied from `examples/01_intro.ipynb`)

In [3]:
import os

# Load Imagenette and keep only its validation split, the only split used below.
dp = mk.datasets.get("imagenette")
valid_mask = dp["split"] == "valid"
dp = dp.lz[valid_mask]

In [4]:
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet18

# ImageNet-pretrained ResNet-18: the classifier whose mistakes are sliced below.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favor of `weights=`.
model = resnet18(pretrained=True)

In [5]:
# Standard ImageNet preprocessing:
# - resize to 256 and center-crop to 224;
# - convert to a tensor;
# - normalize with the ImageNet channel statistics.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Apply the transform lazily to every image, exposed as the "input" column.
dp["input"] = dp["img"].to_lambda(transform)

# Inference runs on the CPU with the model in evaluation mode.
DEVICE = 'cpu'
model.to(DEVICE).eval()


@torch.no_grad()
def predict(batch: mk.DataPanel):
    """Classify one batch of images.

    Returns a dict with one entry per new column, each holding one value per example:
    "pred" is the argmax class index and "probs" is the softmax over the logits.
    """
    images = batch["input"].data.to(DEVICE)
    logits = model(images)
    return dict(
        pred=logits.cpu().numpy().argmax(axis=-1),
        probs=torch.softmax(logits, dim=-1).cpu().numpy(),
    )


# `predict` consumes whole batches (is_batched_fn=True) and reads only the "input"
# column, so only that column is loaded.
dp = dp.update(
    function=predict,
    is_batched_fn=True,
    batch_size=32,
    input_columns=["input"],
    pbar=True,
)

  0%|          | 0/123 [00:00<?, ?it/s]

In [6]:
import hashlib  # NOTE(review): unused by this cell now that the scratch hash code is gone

# Top-1 correctness of each prediction against the ground-truth ImageNet class index.
dp["correct"] = dp["pred"] == mk.NumpyArrayColumn(dp["label_idx"])
accuracy = dp["correct"].mean()
print(f"Micro accuracy across the ten Imagenette classes: {accuracy:0.3}")

Micro accuracy across the ten Imagenette classes: 0.672


In [7]:
# Choose a single ImageNet class

In [8]:
# ImageNet class index treated as the positive class
# (presumably "gas pump" in the standard ImageNet ordering — verify).
LABEL_IDX = 571

# Reduce to a binary problem: positive-class probability and a boolean target.
dp["prob"] = dp["probs"][:, LABEL_IDX]
dp["target"] = dp["label_idx"] == LABEL_IDX

## 1. Embed

In [9]:
from domino import embed

In [10]:
class Features:
    """Callable forward hook that keeps the most recent output of the hooked module.

    ``requires_grad``: when not None, this ``requires_grad`` flag is forced onto each
    captured output before it is stored.
    """

    def __init__(self, requires_grad = None):
        self.requires_grad = requires_grad
        self.features = None  # replaced on every forward pass of the hooked module

    def __call__(self, modules, module_in, module_out):
        # Called by PyTorch as hook(module, inputs, output).
        flag = self.requires_grad
        if flag is not None:
            module_out.requires_grad = flag
        self.features = module_out
        
# Register the feature hook on ResNet-18's global average-pooling layer.
# For torchvision's resnet18 this is the module the original code reached as
# `list(model.modules())[66]`. Naming it directly does not depend on module
# enumeration order.
# Remove a hook left by a previous run of this cell so hooks do not accumulate.
if "handle" in globals():
    handle.remove()
feature_hook = Features()
handle = model.avgpool.register_forward_hook(feature_hook)

In [11]:
# Extract the activations captured by `feature_hook` and add them as the "embedding" column.

@torch.no_grad()
def last_layer(batch: mk.DataPanel):
    """Return the hooked layer's activations for one batch as the "embedding" column.

    The forward pass runs only for its side effect: `feature_hook` records the output
    of the module it is registered on. Gradients are disabled so no autograd graph
    is built for the discarded logits.
    """
    images: torch.Tensor = batch["input"].data.to(DEVICE)
    model(images)  # logits unused; the hook captures the features
    features: np.ndarray = feature_hook.features.detach().cpu().numpy()

    # One key per new column; each value has one row per example in the batch.
    return {
        "embedding": features
    }


In [12]:
# Fill the "embedding" column batch by batch; `last_layer` only needs the "input" column.
dp = dp.update(
    function=last_layer,
    input_columns=["input"],
    is_batched_fn=True,
    batch_size=32,
    pbar=True,
)

  0%|          | 0/123 [00:00<?, ?it/s]

## 2. Slice

In [27]:
# Reuse the scvis embedding cached in `scvis/` when one exists. Otherwise fit scvis from
# scratch, so Restart & Run All also works on a fresh checkout.
planespot = PlaneSpotSlicer(
    scvis_conda_env='scvis',
    scvis_output_dir='scvis',
    fit_scvis=not glob.glob('scvis/*.tsv'),
)

In [None]:
# Fit scvis (or load its cached output), then the GMM, on the binary task built above.
planespot.fit(
    data=dp,
    embeddings="embedding",
    targets="target",
    pred_probs="prob",
)

# Slice membership can then be attached to the DataPanel with
# (NOTE(review): requires a working `PlaneSpotSlicer.predict_proba`):
# dp["planespot_slices"] = planespot.predict_proba(
#     data=dp, embeddings="embedding", targets="target", pred_probs="prob"
# )

Fitting scvis...
Fitting GMM...
