## Overview

Notebook workspace for `PlaneSpotSlicer`: the slicer implementation, followed by a demo on the Imagenette validation set.

Prerequisites:

* Install scvis (https://github.com/shahcompbio/scvis) in its own conda environment, and pass that environment's name to `PlaneSpotSlicer` as `scvis_conda_env`.

In [1]:
from typing import Union

import meerkat as mk
import numpy as np
import torch
import torch.optim as optim
from torch.nn.functional import cross_entropy
from tqdm import tqdm

from domino.utils import unpack_args

from abstract import Slicer
from sklearn import mixture

import glob

In [62]:
from typing import Union

import meerkat as mk
import numpy as np
import torch
import torch.optim as optim
from torch.nn.functional import cross_entropy
from tqdm import tqdm

from domino.utils import unpack_args

from abstract import Slicer

## PlaneSpot imports
from sklearn import mixture
import glob

from domino.utils import convert_to_numpy, unpack_args
import pandas as pd

class PlaneSpotSlicer(Slicer):
    r"""
    Implements PlaneSpot [plumb_2023], a simple SDM that fits a GMM to a 2D model
    embedding, fit using scvis [ding_2018].

    Pipeline:

    1. Reduce the model embeddings to 2D with scvis, which is run in a separate
       conda environment via ``conda run``. Alternatively, load a previously
       saved scvis output from ``scvis_output_dir``.
    2. Min-max normalize each scvis column to [0, 1] and append the predicted
       probabilities scaled by ``weight``.
    3. Fit full-covariance Gaussian mixtures with ``n_slices`` up to
       ``n_max_mixture_components - 1`` components, and keep the one with the
       lowest BIC.

    ..  [plumb_2023]
        Gregory Plumb*, Nari Johnson*, Ángel Alexander Cabrera, Ameet Talwalkar.
        Towards a More Rigorous Science of Blindspot Discovery in Image
        Classification Models. arXiv:2207.04104 [cs] (2023)

    ..  [ding_2018]
        Jiarui Ding, Anne Condon, and Sohrab P Shah.
        Interpretable dimensionality reduction of single cell transcriptome
        data with deep generative models.
        Nature communications, 9(1):1–13. (2018)
    """

    def __init__(
        self,
        scvis_conda_env: str, # name of conda environment where scvis is installed
        n_slices: int = 10, # smallest number of GMM components tried
        n_max_mixture_components: int = 33, # upper bound (exclusive) on GMM components tried
        weight: float = 0.025, # weight applied to the predicted probabilities
        scvis_config_path: Union[str, None] = None, # custom scvis config path
        scvis_output_dir: str = 'scvis', # path to output directory for scvis
        fit_scvis: bool = True # True: (re-)run scvis; False: load existing output from scvis_output_dir
    ):
        super().__init__(n_slices=n_slices)

        # scvis hyper-parameters
        self.scvis_conda_env = scvis_conda_env
        self.config.scvis_config_path = scvis_config_path
        self.config.scvis_output_dir = scvis_output_dir
        self.fit_scvis = fit_scvis

        # GMM hyper-parameters
        self.config.n_max_mixture_components = n_max_mixture_components
        self.config.weight = weight

        # Set by `_fit_gmm`.
        self.gmm = None
        self.min_scvis_vals = None
        self.max_scvis_vals = None

    def fit(
        self,
        data: Union[dict, mk.DataPanel] = None,
        embeddings: Union[str, np.ndarray] = "embedding",
        targets: Union[str, np.ndarray] = None,
        pred_probs: Union[str, np.ndarray] = None,
        losses: Union[str, np.ndarray] = None,
        verbose: bool = True,
        **kwargs
    ):
        """Fit scvis to ``embeddings``, then fit the GMM to the scvis output.

        Args:
            data: Optional DataPanel. When given, string arguments are resolved
                as column names in it.
            embeddings: Model embeddings (column name or array), one row per
                example. Any trailing dimensions are flattened, e.g.
                (N, 512, 1, 1) -> (N, 512).
            targets: Accepted for interface compatibility; unused.
            pred_probs: Predicted probability per example (column name or
                array).
            losses: Accepted for interface compatibility; unused.
            verbose: Print progress messages.
            **kwargs: Ignored.

        Returns:
            self
        """
        embeddings, targets, pred_probs, losses = unpack_args(
            data, embeddings, targets, pred_probs, losses
        )

        embeddings, targets, pred_probs = convert_to_numpy(
            embeddings, targets, pred_probs
        )

        # 1.  Fit scvis.
        if verbose:
            print('Fitting scvis...')

        # Flatten everything after the example axis so scvis gets a 2D matrix.
        flat_embeddings = embeddings.reshape(embeddings.shape[0], -1)
        scvis_embeddings = self._fit_scvis(flat_embeddings)

        # 2.  Fit GMM.
        if verbose:
            print('Fitting GMM...')

        self._fit_gmm(scvis_embeddings, pred_probs)
        return self

    def predict_proba(
        self,
        data: mk.DataPanel,
        scvis_embeddings: Union[str, np.ndarray], # scvis column name (or array)
        pred_probs: Union[str, np.ndarray], # predicted probabilities column name (or array)
    ) -> np.ndarray:
        """Return GMM component posterior probabilities for each example.

        Args:
            data: DataPanel used to resolve string column names.
            scvis_embeddings: 2D scvis coordinates (column name or array).
            pred_probs: Predicted probability per example (column name or
                array).

        Returns:
            Array of shape (n_examples, n_components) from
            ``GaussianMixture.predict_proba``.

        Raises:
            RuntimeError: If called before ``fit``.
        """
        if self.gmm is None:
            raise RuntimeError(
                "PlaneSpotSlicer must be fit before calling predict_proba."
            )

        # Resolve against the `data` argument, never a notebook global.
        scvis_embeddings, pred_probs = unpack_args(data, scvis_embeddings, pred_probs)
        scvis_embeddings, pred_probs = convert_to_numpy(scvis_embeddings, pred_probs)

        # Append the scvis embedding and predicted probabilities; normalize
        X = self._combine_embedding(scvis_embeddings, pred_probs)
        return self.gmm.predict_proba(X)

    def predict(
        self,
        data: mk.DataPanel,
        embeddings: Union[str, np.ndarray] = "embedding",
        targets: Union[str, np.ndarray] = None,
        pred_probs: Union[str, np.ndarray] = None,
        losses: Union[str, np.ndarray] = None,
        scvis_embeddings: Union[str, np.ndarray] = "scvis",
    ) -> np.ndarray:
        """Return hard (0/1) slice memberships by thresholding ``predict_proba``.

        The GMM consumes the 2D scvis coordinates (``scvis_embeddings``) and
        ``pred_probs``. The ``embeddings``, ``targets`` and ``losses`` arguments
        are accepted for interface compatibility with other slicers but unused.
        """
        probs = self.predict_proba(
            data=data,
            scvis_embeddings=scvis_embeddings,
            pred_probs=pred_probs,
        )

        # TODO (Greg): check if this is the preferred way to get hard predictions from
        # probabilities
        return (probs > 0.5).astype(np.int32)

    def _load_scvis_embeddings(self) -> np.ndarray:
        ''' Load and return the scvis embeddings from ``scvis_output_dir``.

        Raises:
            FileNotFoundError: If the directory contains no scvis ``.tsv``
                output.
        '''
        import os

        output_dir = self.config.scvis_output_dir
        # Ignore the temporary input dump written by `_fit_scvis`.
        candidates = [
            path
            for path in glob.glob(os.path.join(output_dir, '*.tsv'))
            if os.path.basename(path) != 'tmp.tsv'
        ]
        if not candidates:
            raise FileNotFoundError(
                f"No scvis embedding (*.tsv) found in '{output_dir}'."
            )

        # Pick the shortest path. This heuristic assumes the 2D embedding has
        # the shortest file name among scvis's .tsv outputs — TODO confirm
        # against scvis output naming.
        scvis_embedding_filepath = min(candidates, key=len)
        return pd.read_csv(scvis_embedding_filepath, sep='\t', index_col=0).values

    def _combine_embedding(self,
                           scvis_reps: np.ndarray,
                           pred_probs: np.ndarray) -> np.ndarray:
        ''' Min-max normalize the scvis coordinates and append the weighted
        predicted probabilities as an extra column.

        Uses the per-column min/max stored by ``_fit_gmm``, so new data is
        scaled exactly like the training data.
        '''
        scvis_reps = np.asarray(scvis_reps, dtype=float)
        pred_probs = np.asarray(pred_probs, dtype=float).reshape(-1, 1)
        if scvis_reps.shape[0] != pred_probs.shape[0]:
            raise ValueError(
                f"Got {scvis_reps.shape[0]} scvis rows but "
                f"{pred_probs.shape[0]} predicted probabilities."
            )

        # (x - min) / (max - min) maps each training column onto [0, 1];
        # constant columns get a range of 1 to avoid division by zero.
        value_range = self.max_scvis_vals - self.min_scvis_vals
        value_range = np.where(value_range == 0, 1.0, value_range)
        X = (scvis_reps - self.min_scvis_vals) / value_range

        # Append (weighted) predicted probabilities to the embedding
        return np.concatenate((X, self.config.weight * pred_probs), axis = 1)

    def _fit_scvis(
        self, embeddings: np.ndarray
    ) -> np.ndarray:
        ''' Fit scvis to the input embeddings (unless ``fit_scvis`` is False)
        and return the resulting scvis embeddings.

        Raises:
            subprocess.CalledProcessError: If the scvis command fails.
        '''
        import os
        import shutil
        import subprocess

        if self.fit_scvis:
            output_dir = self.config.scvis_output_dir

            # Start from an empty output directory so stale results are never loaded.
            if os.path.isdir(output_dir):
                shutil.rmtree(output_dir)
            os.makedirs(output_dir)

            embedding_filepath = os.path.join(output_dir, 'tmp.tsv')
            try:
                # Dump the embeddings as a TSV file for scvis to read.
                pd.DataFrame(embeddings).to_csv(embedding_filepath, sep = '\t', index = False)

                # Run scvis from its own conda env. Arguments are passed as a
                # list (no shell), so paths with spaces/metacharacters are safe.
                # source: https://github.com/shahcompbio/scvis
                command = [
                    'conda', 'run', '-n', self.scvis_conda_env,
                    'scvis', 'train',
                    '--data_matrix_file', embedding_filepath,
                    '--out_dir', output_dir,
                ]
                if self.config.scvis_config_path is not None:
                    # Add optional scvis config
                    command += ['--config_file', self.config.scvis_config_path]

                print(' '.join(command))
                # Blocking; raises instead of silently continuing on failure.
                subprocess.run(command, check=True)
            finally:
                # Cleanup the temporary input even if scvis failed.
                if os.path.exists(embedding_filepath):
                    os.remove(embedding_filepath)

        ### Load and return the scvis embeddings
        return self._load_scvis_embeddings()

    def _fit_gmm(
        self,
        reduced_embeddings: np.ndarray,
        pred_probs: np.ndarray
    ):
        ''' Fit GMMs over a range of component counts; keep the lowest-BIC one.

        Raises:
            ValueError: If ``n_slices >= n_max_mixture_components`` (no
                component counts to try).
        '''
        reduced_embeddings = np.asarray(reduced_embeddings, dtype=float)

        # Store the min and max column values to normalize in the future.
        self.min_scvis_vals = np.min(reduced_embeddings, axis = 0)
        self.max_scvis_vals = np.max(reduced_embeddings, axis = 0)

        X = self._combine_embedding(reduced_embeddings, pred_probs)

        # Upper bound is exclusive: tries n_slices .. n_max_mixture_components - 1.
        n_components_range = range(self.config.n_slices, self.config.n_max_mixture_components)
        if len(n_components_range) == 0:
            raise ValueError(
                "n_max_mixture_components must be greater than n_slices "
                f"(got {self.config.n_max_mixture_components} <= {self.config.n_slices})."
            )

        lowest_bic = np.inf
        best_gmm = None
        for n_components in n_components_range:
            # Fit a GMM with n_components components
            gmm = mixture.GaussianMixture(n_components = n_components, covariance_type = 'full')
            gmm.fit(X)

            # Bayesian Information Criterion: lower is better.
            bic = gmm.bic(X)
            if bic < lowest_bic:
                lowest_bic = bic
                best_gmm = gmm

        self.gmm = best_gmm




## Demo

(copied from `examples/01_intro.ipynb`)

In [3]:
import os

# Load the Imagenette dataset as a meerkat DataPanel.
dp = mk.datasets.get("imagenette")

# Keep only the validation rows (split == "valid"); the rest of the demo uses these.
dp = dp.lz[dp["split"] == "valid"]

In [4]:
import torch
from torchvision.models import resnet18
import torchvision.transforms as transforms
# ResNet-18 with pretrained weights; used both for predictions and for embeddings.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favor of
# `weights=`; consider switching once the pinned torchvision version allows it.
model = resnet18(pretrained=True)

In [5]:
# Put the model on the target device in evaluation mode.
DEVICE = 'cpu'
model.to(DEVICE).eval()

# Preprocessing: resize, center-crop to 224x224, convert to a tensor, and
# normalize each channel with the given mean/std.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# "input" is a lambda column: the transform is applied to each image on access.
dp["input"] = dp["img"].to_lambda(transform)

# 2. Define a function that runs a forward pass over a batch
@torch.no_grad()
def predict(batch: mk.DataPanel):
    """Run the model on one batch; return predicted labels and class probabilities.

    Returns a dict with one key per new column; each value has one row per
    example in the batch:
      * "pred":  argmax class index of the model output
      * "probs": softmax of the model output over the class dimension
    """
    input_col: mk.TensorColumn = batch["input"]
    # `.data` is the underlying torch tensor; move it to DEVICE for the forward pass.
    x: torch.Tensor = input_col.data.to(DEVICE)
    logits: torch.Tensor = model(x)  # Run forward pass

    return {
        "pred": logits.cpu().numpy().argmax(axis=-1),
        "probs": torch.softmax(logits, dim=-1).cpu().numpy(),
    }

# 3. Map `predict` over the DataPanel in batches of 32 (it is a batched function).
# `predict` only reads "input", so `input_columns` restricts loading to that
# column and the other columns are skipped.
dp = dp.update(
    function=predict,
    input_columns=["input"],
    batch_size=32,
    is_batched_fn=True,
    pbar=True,
)

  0%|          | 0/123 [00:00<?, ?it/s]

In [6]:
import hashlib

import hashlib
# Mark each validation image as correct when the top-1 prediction matches its
# label, then report the overall (micro) accuracy.
dp["correct"] = dp["pred"] == mk.NumpyArrayColumn(dp["label_idx"])
accuracy = dp["correct"].mean()
print(f"Micro accuracy across the ten Imagenette classes: {accuracy:0.3}")

Micro accuracy across the ten Imagenette classes: 0.672


In [7]:
# Choose a single ImageNet class

In [8]:
# ImageNet class index under study (the "gas pump" images in this dataset).
LABEL_IDX = 571

# Binary task: the model's probability for LABEL_IDX, and whether the image
# truly belongs to that class.
dp["prob"] = dp["probs"][:, LABEL_IDX]
dp["target"] = dp["label_idx"] == LABEL_IDX

## 1. Embed

In [9]:
from domino import embed

In [10]:
class Features:
    """Forward-hook callable that remembers the latest output of a module.

    Args:
        requires_grad: When not None, the captured output's ``requires_grad``
            flag is set to this value before it is stored.
    """

    def __init__(self, requires_grad=None):
        self.features = None
        self.requires_grad = requires_grad

    def __call__(self, modules, module_in, module_out):
        # Called as hook(module, input, output) by register_forward_hook.
        if self.requires_grad is None:
            self.features = module_out
            return
        module_out.requires_grad = self.requires_grad
        self.features = module_out
        
# Register a forward hook on ResNet-18's average-pooling layer so each forward
# pass stores that layer's output (512 channels) in `feature_hook.features`.
# `model.avgpool` is the same module as the former `list(model.modules())[66]`,
# without depending on module enumeration order.
feature_hook = Features()
handle = model.avgpool.register_forward_hook(feature_hook)

In [11]:
# Extract the last-layer embeddings from the model, and add them as the "embedding" column.

@torch.no_grad()
def last_layer(batch: mk.DataPanel):
    """Return the hooked-layer features for one batch as an "embedding" column.

    The forward pass is run only for its side effect: the forward hook that
    populates `feature_hook` stores the hooked layer's output. No gradients
    are needed, so autograd is disabled (matching `predict`).

    Returns a dict with a single "embedding" key; the value has one row per
    example in the batch.
    """
    input_col: mk.TensorColumn = batch["input"]
    # `.data` is the underlying torch tensor; move it to DEVICE for the forward pass.
    x: torch.Tensor = input_col.data.to(DEVICE)

    model(x)  # output unused; the hook captures the features
    features: np.ndarray = feature_hook.features.detach().cpu().numpy()

    return {
        "embedding": features
    }


In [12]:
# Add the "embedding" column by mapping `last_layer` over batches of 32,
# loading only the "input" column.
dp = dp.update(
    function=last_layer,
    input_columns=["input"],
    batch_size=32,
    is_batched_fn=True,
    pbar=True,
)

  0%|          | 0/123 [00:00<?, ?it/s]

## 2. Slice

In [63]:
# fit_scvis=False: do not run scvis; reuse the output already saved in the
# default scvis output directory.
planespot = PlaneSpotSlicer(
    scvis_conda_env="scvis",
    fit_scvis=False,
)

In [64]:
# Fit PlaneSpot on the model embeddings and each image's probability for LABEL_IDX.
planespot.fit(
    data=dp,
    embeddings="embedding",
    targets="target",
    pred_probs="prob",
)

Fitting scvis...
Fitting GMM...


In [65]:
# Store the 2D scvis coordinates as a column so predict_proba can read them.
dp["scvis"] = planespot._load_scvis_embeddings()

In [66]:
# Preview the first rows of the DataPanel.
dp.head()

Unnamed: 0,img_path (PandasSeriesColumn),label (PandasSeriesColumn),label_id (PandasSeriesColumn),label_idx (PandasSeriesColumn),split (PandasSeriesColumn),img (ImageColumn),input (LambdaColumn),pred (NumpyArrayColumn),probs (NumpyArrayColumn),correct (NumpyArrayColumn),prob (NumpyArrayColumn),target (PandasSeriesColumn),embedding (NumpyArrayColumn),planespot_slices (NumpyArrayColumn),scvis (NumpyArrayColumn)
0,val/n02979186/n02979186_8971.JPEG,cassette player,n02979186,482,valid,,"tensor([[[2.2318, 2.2318, 2.2318, ..., 2.2318, 2.2318, 2.2318],  [2.2318, 2.2318, 2.2318, ..., 2.2318, 2.2318, 2.2318],  [2.2318, 2.2318, 2.2318, ..., 2.2318, 2.2318, 2.2318],  ...,  [2.2318, 2.2318, 2.2318, ..., 2.2318, 2.2318, 2.2318],  [2.2318, 2.2318, 2.2318, ..., 2.2318, 2.2318, 2.2318],  [2.2318, 2.2318, 2.2318, ..., 2.2318, 2.2318, 2.2318]],  [[2.4111, 2.4111, 2.4111, ..., 2.4111, 2.4111, 2.4111],  [2.4111, 2.4111, 2.4111, ..., 2.4111, 2.4111, 2.4111],  [2.4111, 2.4111, 2.4111, ..., 2.4111, 2.4111, 2.4111],  ...,  [2.4111, 2.4111, 2.4111, ..., 2.4111, 2.4111, 2.4111],  [2.4111, 2.4111, 2.4111, ..., 2.4111, 2.4111, 2.4111],  [2.4111, 2.4111, 2.4111, ..., 2.4111, 2.4111, 2.4111]],  [[2.6226, 2.6226, 2.6226, ..., 2.6226, 2.6226, 2.6226],  [2.6226, 2.6226, 2.6226, ..., 2.6226, 2.6226, 2.6226],  [2.6226, 2.6226, 2.6226, ..., 2.6226, 2.6226, 2.6226],  ...,  [2.6226, 2.6226, 2.6226, ..., 2.6226, 2.6226, 2.6226],  [2.6226, 2.6226, 2.6226, ..., 2.6226, 2.6226, 2.6226],  [2.6226, 2.6226, 2.6226, ..., 2.6226, 2.6226, 2.6226]]])",482,"np.ndarray(shape=(1000,))",True,3.824188e-07,False,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(18,))","np.ndarray(shape=(2,))"
1,val/n02979186/n02979186_14550.JPEG,cassette player,n02979186,482,valid,,"tensor([[[-1.3987, -1.3987, -1.3987, ..., 0.5193, 0.5193, 0.4508],  [-1.3815, -1.3815, -1.3815, ..., 0.5878, 0.5878, 0.5364],  [-1.3302, -1.3302, -1.3302, ..., 0.7933, 0.8104, 0.7762],  ...,  [-1.9124, -1.8953, -1.8953, ..., -1.5870, -1.5870, -1.6213],  [-1.8953, -1.8782, -1.8782, ..., -1.6042, -1.6042, -1.6213],  [-1.8953, -1.8782, -1.8782, ..., -1.6384, -1.6384, -1.6213]],  [[-1.3529, -1.3529, -1.3529, ..., 1.2031, 1.2031, 1.1681],  [-1.3880, -1.3880, -1.3880, ..., 1.2731, 1.2906, 1.2556],  [-1.3880, -1.3880, -1.3880, ..., 1.4832, 1.5182, 1.5007],  ...,  [-2.0007, -1.9832, -1.9832, ..., -1.7206, -1.7206, -1.7381],  [-1.9832, -1.9657, -1.9657, ..., -1.7731, -1.7731, -1.7731],  [-1.9832, -1.9657, -1.9657, ..., -1.8431, -1.8431, -1.8081]],  [[-1.1073, -1.1073, -1.1073, ..., 1.1585, 1.1237, 1.0191],  [-1.1421, -1.1421, -1.1421, ..., 1.1934, 1.1759, 1.0714],  [-1.1421, -1.1421, -1.1421, ..., 1.3851, 1.3851, 1.2805],  ...,  [-1.7522, -1.7347, -1.7347, ..., -1.3687, -1.3861, -1.4210],  [-1.7347, -1.7173, -1.7173, ..., -1.4384, -1.4384, -1.4559],  [-1.7347, -1.7173, -1.7173, ..., -1.5081, -1.5081, -1.4907]]])",754,"np.ndarray(shape=(1000,))",False,0.001313724,False,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(18,))","np.ndarray(shape=(2,))"
2,val/n02979186/n02979186_11971.JPEG,cassette player,n02979186,482,valid,,"tensor([[[1.2385, 1.3584, 1.3755, ..., 0.6392, 0.6049, 0.6906],  [1.1529, 1.2557, 1.3070, ..., 0.6221, 0.6221, 0.7591],  [1.0673, 1.1529, 1.2214, ..., 0.6392, 0.6392, 0.8104],  ...,  [0.5536, 0.6221, 0.6734, ..., 0.4166, 0.4166, 0.4337],  [0.5364, 0.6392, 0.7077, ..., 0.3994, 0.3823, 0.3823],  [0.5536, 0.6734, 0.7591, ..., 0.3652, 0.3309, 0.3309]],  [[1.3081, 1.4307, 1.4482, ..., 0.7829, 0.7304, 0.7829],  [1.2206, 1.3256, 1.3782, ..., 0.7829, 0.7654, 0.8529],  [1.1331, 1.2206, 1.2906, ..., 0.8004, 0.8004, 0.9230],  ...,  [0.9230, 0.9930, 1.0455, ..., 0.7654, 0.7654, 0.7829],  [0.8880, 0.9755, 1.0630, ..., 0.7129, 0.7129, 0.7304],  [0.8704, 0.9930, 1.0805, ..., 0.6604, 0.6429, 0.6779]],  [[1.0017, 1.1237, 1.1411, ..., 0.7925, 0.7402, 0.7576],  [0.9145, 1.0191, 1.0714, ..., 0.7751, 0.7576, 0.8274],  [0.8274, 0.9145, 0.9842, ..., 0.7925, 0.7925, 0.8797],  ...,  [1.2980, 1.3677, 1.4200, ..., 1.2631, 1.2631, 1.2457],  [1.2631, 1.3502, 1.4374, ..., 1.2282, 1.2108, 1.1934],  [1.2457, 1.3677, 1.4548, ..., 1.1585, 1.1411, 1.1411]]])",482,"np.ndarray(shape=(1000,))",True,5.850358e-06,False,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(18,))","np.ndarray(shape=(2,))"
3,val/n02979186/n02979186_11550.JPEG,cassette player,n02979186,482,valid,,"tensor([[[-0.4739, -0.4397, -0.3198, ..., 2.0434, 2.0263, 1.9920],  [-0.3883, -0.3541, -0.2856, ..., 2.0434, 2.0263, 2.0092],  [-0.3712, -0.3712, -0.3369, ..., 2.0263, 2.0092, 1.9920],  ...,  [-1.2445, -1.2445, -1.2445, ..., -1.2959, -1.2959, -1.2959],  [-1.2959, -1.2959, -1.2959, ..., -1.3987, -1.3987, -1.3987],  [-1.3644, -1.3644, -1.3644, ..., -1.4843, -1.4843, -1.4843]],  [[-1.1253, -1.0728, -0.9678, ..., 0.9755, 0.9405, 0.9055],  [-1.0378, -1.0203, -0.9503, ..., 0.9755, 0.9405, 0.9230],  [-1.0728, -1.0728, -1.0203, ..., 0.9580, 0.9230, 0.9055],  ...,  [-1.8081, -1.8081, -1.8081, ..., -1.7381, -1.7381, -1.7381],  [-1.8256, -1.8256, -1.8256, ..., -1.7731, -1.7731, -1.7731],  [-1.8431, -1.8431, -1.8431, ..., -1.8081, -1.8081, -1.8081]],  [[-1.0898, -1.0376, -0.9330, ..., 0.3045, 0.2871, 0.2522],  [-1.0201, -1.0027, -0.9156, ..., 0.3219, 0.2871, 0.2696],  [-1.0724, -1.0724, -1.0201, ..., 0.3045, 0.2696, 0.2522],  ...,  [-1.7347, -1.7347, -1.7347, ..., -1.5604, -1.5604, -1.5604],  [-1.6824, -1.6824, -1.6824, ..., -1.4733, -1.4733, -1.4733],  [-1.6302, -1.6302, -1.6302, ..., -1.4559, -1.4559, -1.4559]]])",482,"np.ndarray(shape=(1000,))",True,1.493245e-06,False,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(18,))","np.ndarray(shape=(2,))"
4,val/n02979186/n02979186_8751.JPEG,cassette player,n02979186,482,valid,,"tensor([[[ 0.1254, 0.1254, 0.1254, ..., -0.0458, -0.0458, -0.0458],  [ 0.1254, 0.1254, 0.1254, ..., -0.0629, -0.0629, -0.0629],  [ 0.1254, 0.1254, 0.1083, ..., -0.0629, -0.0629, -0.0629],  ...,  [-1.6898, -1.6898, -1.6727, ..., -1.5870, -1.5699, -1.5699],  [-1.7583, -1.7412, -1.7240, ..., -1.5699, -1.5528, -1.5528],  [-1.8097, -1.7754, -1.7412, ..., -1.5699, -1.5528, -1.5528]],  [[-0.3550, -0.3550, -0.3550, ..., -0.4776, -0.4776, -0.4776],  [-0.3550, -0.3550, -0.3550, ..., -0.4776, -0.4776, -0.4776],  [-0.3550, -0.3550, -0.3725, ..., -0.4776, -0.4776, -0.4776],  ...,  [-1.4230, -1.4230, -1.4055, ..., -1.4930, -1.4755, -1.4755],  [-1.4755, -1.4580, -1.4405, ..., -1.5105, -1.4930, -1.4930],  [-1.5280, -1.4930, -1.4580, ..., -1.5455, -1.5280, -1.5280]],  [[-0.6541, -0.6541, -0.6541, ..., -0.7936, -0.7936, -0.7936],  [-0.6541, -0.6541, -0.6541, ..., -0.7761, -0.7761, -0.7761],  [-0.6541, -0.6541, -0.6715, ..., -0.7587, -0.7587, -0.7587],  ...,  [-1.3513, -1.3513, -1.3339, ..., -1.4036, -1.3861, -1.3861],  [-1.4384, -1.4210, -1.3861, ..., -1.4210, -1.4036, -1.4036],  [-1.5081, -1.4733, -1.4384, ..., -1.4384, -1.4210, -1.4210]]])",482,"np.ndarray(shape=(1000,))",True,5.843006e-07,False,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(18,))","np.ndarray(shape=(2,))"


In [67]:
# Soft GMM component memberships for every image, computed from the scvis
# coordinates and the predicted-probability column.
dp["planespot_slices"] = planespot.predict_proba(
    data=dp,
    scvis_embeddings="scvis",
    pred_probs="prob",
)

Unnamed: 0,img_path (PandasSeriesColumn),label (PandasSeriesColumn),label_id (PandasSeriesColumn),label_idx (PandasSeriesColumn),split (PandasSeriesColumn),img (ImageColumn),input (LambdaColumn),pred (NumpyArrayColumn),probs (NumpyArrayColumn),correct (NumpyArrayColumn),prob (NumpyArrayColumn),target (PandasSeriesColumn),embedding (NumpyArrayColumn),scvis (NumpyArrayColumn),planespot_slices (NumpyArrayColumn)
0,val/n02979186/n02979186_8971.JPEG,cassette player,n02979186,482,valid,,"tensor([[[2.2318, 2.2318, 2.2318, ..., 2.2318, 2.2318, 2.2318],  [2.2318, 2.2318, 2.2318, ..., 2.2318, 2.2318, 2.2318],  [2.2318, 2.2318, 2.2318, ..., 2.2318, 2.2318, 2.2318],  ...,  [2.2318, 2.2318, 2.2318, ..., 2.2318, 2.2318, 2.2318],  [2.2318, 2.2318, 2.2318, ..., 2.2318, 2.2318, 2.2318],  [2.2318, 2.2318, 2.2318, ..., 2.2318, 2.2318, 2.2318]],  [[2.4111, 2.4111, 2.4111, ..., 2.4111, 2.4111, 2.4111],  [2.4111, 2.4111, 2.4111, ..., 2.4111, 2.4111, 2.4111],  [2.4111, 2.4111, 2.4111, ..., 2.4111, 2.4111, 2.4111],  ...,  [2.4111, 2.4111, 2.4111, ..., 2.4111, 2.4111, 2.4111],  [2.4111, 2.4111, 2.4111, ..., 2.4111, 2.4111, 2.4111],  [2.4111, 2.4111, 2.4111, ..., 2.4111, 2.4111, 2.4111]],  [[2.6226, 2.6226, 2.6226, ..., 2.6226, 2.6226, 2.6226],  [2.6226, 2.6226, 2.6226, ..., 2.6226, 2.6226, 2.6226],  [2.6226, 2.6226, 2.6226, ..., 2.6226, 2.6226, 2.6226],  ...,  [2.6226, 2.6226, 2.6226, ..., 2.6226, 2.6226, 2.6226],  [2.6226, 2.6226, 2.6226, ..., 2.6226, 2.6226, 2.6226],  [2.6226, 2.6226, 2.6226, ..., 2.6226, 2.6226, 2.6226]]])",482,"np.ndarray(shape=(1000,))",True,3.824188e-07,False,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(2,))","np.ndarray(shape=(19,))"
1,val/n02979186/n02979186_14550.JPEG,cassette player,n02979186,482,valid,,"tensor([[[-1.3987, -1.3987, -1.3987, ..., 0.5193, 0.5193, 0.4508],  [-1.3815, -1.3815, -1.3815, ..., 0.5878, 0.5878, 0.5364],  [-1.3302, -1.3302, -1.3302, ..., 0.7933, 0.8104, 0.7762],  ...,  [-1.9124, -1.8953, -1.8953, ..., -1.5870, -1.5870, -1.6213],  [-1.8953, -1.8782, -1.8782, ..., -1.6042, -1.6042, -1.6213],  [-1.8953, -1.8782, -1.8782, ..., -1.6384, -1.6384, -1.6213]],  [[-1.3529, -1.3529, -1.3529, ..., 1.2031, 1.2031, 1.1681],  [-1.3880, -1.3880, -1.3880, ..., 1.2731, 1.2906, 1.2556],  [-1.3880, -1.3880, -1.3880, ..., 1.4832, 1.5182, 1.5007],  ...,  [-2.0007, -1.9832, -1.9832, ..., -1.7206, -1.7206, -1.7381],  [-1.9832, -1.9657, -1.9657, ..., -1.7731, -1.7731, -1.7731],  [-1.9832, -1.9657, -1.9657, ..., -1.8431, -1.8431, -1.8081]],  [[-1.1073, -1.1073, -1.1073, ..., 1.1585, 1.1237, 1.0191],  [-1.1421, -1.1421, -1.1421, ..., 1.1934, 1.1759, 1.0714],  [-1.1421, -1.1421, -1.1421, ..., 1.3851, 1.3851, 1.2805],  ...,  [-1.7522, -1.7347, -1.7347, ..., -1.3687, -1.3861, -1.4210],  [-1.7347, -1.7173, -1.7173, ..., -1.4384, -1.4384, -1.4559],  [-1.7347, -1.7173, -1.7173, ..., -1.5081, -1.5081, -1.4907]]])",754,"np.ndarray(shape=(1000,))",False,1.313724e-03,False,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(2,))","np.ndarray(shape=(19,))"
2,val/n02979186/n02979186_11971.JPEG,cassette player,n02979186,482,valid,,"tensor([[[1.2385, 1.3584, 1.3755, ..., 0.6392, 0.6049, 0.6906],  [1.1529, 1.2557, 1.3070, ..., 0.6221, 0.6221, 0.7591],  [1.0673, 1.1529, 1.2214, ..., 0.6392, 0.6392, 0.8104],  ...,  [0.5536, 0.6221, 0.6734, ..., 0.4166, 0.4166, 0.4337],  [0.5364, 0.6392, 0.7077, ..., 0.3994, 0.3823, 0.3823],  [0.5536, 0.6734, 0.7591, ..., 0.3652, 0.3309, 0.3309]],  [[1.3081, 1.4307, 1.4482, ..., 0.7829, 0.7304, 0.7829],  [1.2206, 1.3256, 1.3782, ..., 0.7829, 0.7654, 0.8529],  [1.1331, 1.2206, 1.2906, ..., 0.8004, 0.8004, 0.9230],  ...,  [0.9230, 0.9930, 1.0455, ..., 0.7654, 0.7654, 0.7829],  [0.8880, 0.9755, 1.0630, ..., 0.7129, 0.7129, 0.7304],  [0.8704, 0.9930, 1.0805, ..., 0.6604, 0.6429, 0.6779]],  [[1.0017, 1.1237, 1.1411, ..., 0.7925, 0.7402, 0.7576],  [0.9145, 1.0191, 1.0714, ..., 0.7751, 0.7576, 0.8274],  [0.8274, 0.9145, 0.9842, ..., 0.7925, 0.7925, 0.8797],  ...,  [1.2980, 1.3677, 1.4200, ..., 1.2631, 1.2631, 1.2457],  [1.2631, 1.3502, 1.4374, ..., 1.2282, 1.2108, 1.1934],  [1.2457, 1.3677, 1.4548, ..., 1.1585, 1.1411, 1.1411]]])",482,"np.ndarray(shape=(1000,))",True,5.850358e-06,False,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(2,))","np.ndarray(shape=(19,))"
3,val/n02979186/n02979186_11550.JPEG,cassette player,n02979186,482,valid,,"tensor([[[-0.4739, -0.4397, -0.3198, ..., 2.0434, 2.0263, 1.9920],  [-0.3883, -0.3541, -0.2856, ..., 2.0434, 2.0263, 2.0092],  [-0.3712, -0.3712, -0.3369, ..., 2.0263, 2.0092, 1.9920],  ...,  [-1.2445, -1.2445, -1.2445, ..., -1.2959, -1.2959, -1.2959],  [-1.2959, -1.2959, -1.2959, ..., -1.3987, -1.3987, -1.3987],  [-1.3644, -1.3644, -1.3644, ..., -1.4843, -1.4843, -1.4843]],  [[-1.1253, -1.0728, -0.9678, ..., 0.9755, 0.9405, 0.9055],  [-1.0378, -1.0203, -0.9503, ..., 0.9755, 0.9405, 0.9230],  [-1.0728, -1.0728, -1.0203, ..., 0.9580, 0.9230, 0.9055],  ...,  [-1.8081, -1.8081, -1.8081, ..., -1.7381, -1.7381, -1.7381],  [-1.8256, -1.8256, -1.8256, ..., -1.7731, -1.7731, -1.7731],  [-1.8431, -1.8431, -1.8431, ..., -1.8081, -1.8081, -1.8081]],  [[-1.0898, -1.0376, -0.9330, ..., 0.3045, 0.2871, 0.2522],  [-1.0201, -1.0027, -0.9156, ..., 0.3219, 0.2871, 0.2696],  [-1.0724, -1.0724, -1.0201, ..., 0.3045, 0.2696, 0.2522],  ...,  [-1.7347, -1.7347, -1.7347, ..., -1.5604, -1.5604, -1.5604],  [-1.6824, -1.6824, -1.6824, ..., -1.4733, -1.4733, -1.4733],  [-1.6302, -1.6302, -1.6302, ..., -1.4559, -1.4559, -1.4559]]])",482,"np.ndarray(shape=(1000,))",True,1.493245e-06,False,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(2,))","np.ndarray(shape=(19,))"
4,val/n02979186/n02979186_8751.JPEG,cassette player,n02979186,482,valid,,"tensor([[[ 0.1254, 0.1254, 0.1254, ..., -0.0458, -0.0458, -0.0458],  [ 0.1254, 0.1254, 0.1254, ..., -0.0629, -0.0629, -0.0629],  [ 0.1254, 0.1254, 0.1083, ..., -0.0629, -0.0629, -0.0629],  ...,  [-1.6898, -1.6898, -1.6727, ..., -1.5870, -1.5699, -1.5699],  [-1.7583, -1.7412, -1.7240, ..., -1.5699, -1.5528, -1.5528],  [-1.8097, -1.7754, -1.7412, ..., -1.5699, -1.5528, -1.5528]],  [[-0.3550, -0.3550, -0.3550, ..., -0.4776, -0.4776, -0.4776],  [-0.3550, -0.3550, -0.3550, ..., -0.4776, -0.4776, -0.4776],  [-0.3550, -0.3550, -0.3725, ..., -0.4776, -0.4776, -0.4776],  ...,  [-1.4230, -1.4230, -1.4055, ..., -1.4930, -1.4755, -1.4755],  [-1.4755, -1.4580, -1.4405, ..., -1.5105, -1.4930, -1.4930],  [-1.5280, -1.4930, -1.4580, ..., -1.5455, -1.5280, -1.5280]],  [[-0.6541, -0.6541, -0.6541, ..., -0.7936, -0.7936, -0.7936],  [-0.6541, -0.6541, -0.6541, ..., -0.7761, -0.7761, -0.7761],  [-0.6541, -0.6541, -0.6715, ..., -0.7587, -0.7587, -0.7587],  ...,  [-1.3513, -1.3513, -1.3339, ..., -1.4036, -1.3861, -1.3861],  [-1.4384, -1.4210, -1.3861, ..., -1.4210, -1.4036, -1.4036],  [-1.5081, -1.4733, -1.4384, ..., -1.4384, -1.4210, -1.4210]]])",482,"np.ndarray(shape=(1000,))",True,5.843006e-07,False,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(2,))","np.ndarray(shape=(19,))"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3920,val/n03425413/n03425413_17521.JPEG,gas pump,n03425413,571,valid,,"tensor([[[-1.2445, 0.0227, 0.7762, ..., 1.0159, 1.3070, 1.2899],  [-1.2617, -0.0116, 0.7933, ..., 0.8961, 1.2214, 1.2728],  [-1.2959, -0.0458, 0.7933, ..., 0.7077, 1.0502, 1.2385],  ...,  [ 0.4166, 0.3994, 0.3823, ..., 1.6324, 1.6495, 1.5639],  [ 0.3994, 0.3823, 0.3823, ..., 1.8037, 1.8037, 1.6495],  [ 0.3823, 0.3823, 0.3823, ..., 1.8722, 1.8550, 1.6495]],  [[-1.1253, 0.1527, 0.9230, ..., 1.2731, 1.6232, 1.6057],  [-1.1253, 0.1352, 0.9405, ..., 1.1856, 1.5357, 1.5882],  [-1.1429, 0.1176, 0.9405, ..., 1.0105, 1.3606, 1.5532],  ...,  [ 0.9230, 0.9055, 0.8880, ..., 1.4482, 1.4657, 1.4307],  [ 0.9230, 0.9230, 0.9055, ..., 1.5882, 1.5882, 1.4132],  [ 0.9230, 0.9230, 0.9230, ..., 1.6408, 1.6057, 1.3606]],  [[-0.8458, 0.3916, 1.1237, ..., 1.4897, 1.8557, 1.9951],  [-0.8458, 0.3916, 1.1585, ..., 1.4025, 1.7860, 2.0125],  [-0.8633, 0.3742, 1.1759, ..., 1.2108, 1.6291, 1.9951],  ...,  [ 1.6814, 1.6640, 1.6465, ..., 1.0714, 1.1062, 1.1237],  [ 1.6988, 1.6814, 1.6814, ..., 1.1237, 1.1411, 1.0365],  [ 1.6988, 1.6988, 1.6988, ..., 1.1062, 1.1062, 0.9145]]])",571,"np.ndarray(shape=(1000,))",True,9.950945e-01,True,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(2,))","np.ndarray(shape=(19,))"
3921,val/n03425413/n03425413_20711.JPEG,gas pump,n03425413,571,valid,,"tensor([[[-0.9705, -0.7650, -0.4226, ..., 0.8447, 0.8276, 0.8276],  [-0.1828, 0.0912, 0.5022, ..., 0.8618, 0.8618, 0.8618],  [ 0.3823, 0.4508, 0.5536, ..., 0.8618, 0.8618, 0.8618],  ...,  [-2.0665, -2.1179, -2.1008, ..., -2.1008, -2.1008, -2.1008],  [-1.9980, -2.0323, -2.0494, ..., -2.1008, -2.1008, -2.1008],  [-1.7754, -1.7925, -1.9124, ..., -2.1179, -2.1179, -2.1179]],  [[-1.3880, -1.1779, -0.8102, ..., 0.9230, 0.9055, 0.9055],  [-0.6001, -0.2850, 0.1352, ..., 0.9405, 0.9405, 0.9405],  [-0.0049, 0.0826, 0.2052, ..., 0.9405, 0.9405, 0.9405],  ...,  [-0.2500, -0.1625, -0.1450, ..., -0.2675, -0.2675, -0.2675],  [-0.2850, -0.1800, -0.1625, ..., -0.2675, -0.2675, -0.2675],  [-0.4076, -0.2325, -0.1975, ..., -0.2675, -0.2850, -0.2850]],  [[-1.2641, -1.1247, -0.8284, ..., 0.9494, 0.9319, 0.9319],  [-0.3753, -0.1835, 0.1476, ..., 0.9668, 0.9668, 0.9668],  [ 0.1302, 0.1476, 0.1825, ..., 0.9668, 0.9668, 0.9668],  ...,  [ 0.6531, 0.7576, 0.8274, ..., 0.8622, 0.8448, 0.8274],  [ 0.6008, 0.7576, 0.8448, ..., 0.8448, 0.8274, 0.8099],  [ 0.4788, 0.7576, 0.9145, ..., 0.8099, 0.7925, 0.7751]]])",590,"np.ndarray(shape=(1000,))",False,6.099857e-02,True,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(2,))","np.ndarray(shape=(19,))"
3922,val/n03425413/n03425413_19050.JPEG,gas pump,n03425413,571,valid,,"tensor([[[-0.3883, -0.3198, -0.4568, ..., -1.6384, -1.5014, -1.3302],  [-0.5253, -0.5082, -0.5938, ..., -1.4672, -1.2274, -1.1932],  [-0.8164, -0.8335, -0.7993, ..., -1.1760, -0.8678, -1.0390],  ...,  [-0.0972, -0.0629, -0.0801, ..., 0.8104, 0.8276, 0.8276],  [-0.1999, -0.1999, -0.2856, ..., 0.7762, 0.7933, 0.7933],  [-0.3541, -0.4397, -0.5767, ..., 0.7591, 0.7762, 0.7762]],  [[-0.0049, 0.0651, -0.0749, ..., -1.3880, -1.2304, -1.0728],  [-0.1625, -0.1450, -0.2325, ..., -1.2129, -0.9853, -0.9328],  [-0.4601, -0.4776, -0.4426, ..., -0.9328, -0.6176, -0.7752],  ...,  [ 0.0476, 0.0651, 0.0476, ..., 0.8179, 0.8354, 0.8529],  [-0.0049, -0.0399, -0.1275, ..., 0.7829, 0.8004, 0.8179],  [-0.1099, -0.2325, -0.4076, ..., 0.7654, 0.7829, 0.8004]],  [[-0.4624, -0.3927, -0.5321, ..., -1.6650, -1.5256, -1.4036],  [-0.5844, -0.5670, -0.6541, ..., -1.4384, -1.2293, -1.2467],  [-0.8458, -0.8633, -0.8284, ..., -1.1247, -0.8284, -1.0724],  ...,  [-0.2707, -0.2184, -0.2010, ..., 0.6356, 0.6705, 0.7228],  [-0.3578, -0.3404, -0.3927, ..., 0.6008, 0.6356, 0.6879],  [-0.4798, -0.5844, -0.6890, ..., 0.5834, 0.6182, 0.6705]]])",866,"np.ndarray(shape=(1000,))",False,3.938547e-04,True,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(2,))","np.ndarray(shape=(19,))"
3923,val/n03425413/n03425413_13831.JPEG,gas pump,n03425413,571,valid,,"tensor([[[ 1.0331, 0.9988, 0.9646, ..., 0.4851, 0.7591, 0.9646],  [ 1.0844, 1.0159, 0.9817, ..., 0.4508, 0.7591, 0.9817],  [ 1.1358, 1.0673, 1.0502, ..., 0.3994, 0.7419, 0.9988],  ...,  [-0.4054, 0.6221, 1.5468, ..., 0.3994, 0.7248, 0.9646],  [-0.4054, 0.6392, 1.5639, ..., 0.4337, 0.7248, 0.9474],  [-0.4054, 0.6392, 1.5639, ..., 0.4679, 0.7248, 0.9303]],  [[ 1.3957, 1.3606, 1.3431, ..., -0.1625, -0.0399, 0.0476],  [ 1.4307, 1.3782, 1.3431, ..., -0.1975, -0.0224, 0.0826],  [ 1.4832, 1.4307, 1.3957, ..., -0.2500, -0.0399, 0.1001],  ...,  [-1.4580, -1.2479, -1.1253, ..., -0.1450, -0.2500, -0.3025],  [-1.4405, -1.2479, -1.1078, ..., -0.1099, -0.2325, -0.2850],  [-1.4405, -1.2479, -1.1078, ..., -0.0749, -0.2150, -0.2850]],  [[ 1.8731, 1.8034, 1.7337, ..., -0.3753, -0.3055, -0.3230],  [ 1.8383, 1.7685, 1.6814, ..., -0.3753, -0.2881, -0.2532],  [ 1.8208, 1.7337, 1.6640, ..., -0.4101, -0.2881, -0.2184],  ...,  [-1.3164, -0.9853, -0.7238, ..., -0.3404, -0.4101, -0.4798],  [-1.3164, -0.9853, -0.7064, ..., -0.3055, -0.4101, -0.4798],  [-1.3164, -0.9853, -0.7064, ..., -0.2707, -0.3927, -0.4798]]])",571,"np.ndarray(shape=(1000,))",True,4.792972e-01,True,"np.ndarray(shape=(512, 1, 1))","np.ndarray(shape=(2,))","np.ndarray(shape=(19,))"
