We use the official implementation of ICE available here: https://github.com/zhangrh93/InvertibleCE

In [1]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt  
import os
import torchvision
import numpy as np
import os
from pathlib import Path
from torchsummary import summary
import shutil
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [3]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Pretrained ResNet-34, switched to eval mode and moved onto `device`.
model = torchvision.models.resnet34(pretrained=True)
model = model.eval().to(device)
# g: all child modules of the network except the last two, chained in order.
g = nn.Sequential(*(list(model.children())[:-2]))  # All layers except the last two
# h: averages its input over dims 2 and 3 and feeds the result to the model's `fc` layer.
h = lambda x: model.fc(torch.mean(x, (2, 3))) 
# h_2d: applies `fc` directly; presumably for inputs that are already pooled — TODO confirm.
h_2d = lambda x: model.fc(x)

In [5]:
# Entry names found in the dataset folder; later cells treat each one as a class folder.
dog_list = os.listdir('dataset/ChurchFolder')

In [6]:
from PIL import Image

from torch.utils.data import Dataset


class ImageDataset(Dataset):
    """Dataset over a list of ``(image_path, label)`` pairs.

    Every image is opened with PIL and converted to RGB. When a ``transform``
    is supplied, its result is permuted with ``(1, 2, 0)`` before being
    returned; without a transform the PIL image is returned unchanged.
    """

    def __init__(self, imgs, transform=None):
        self.transform = transform
        self.imgs = imgs

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        img_path, label = self.imgs[index]
        img = Image.open(img_path).convert("RGB")
        if self.transform is None:
            return img, label
        # The transform output is permuted (1,2,0), i.e. the first axis moves last.
        return self.transform(img).permute(1, 2, 0), label

### Parameters and Methods initialization

The Stanford Dogs Dataset is a subset of ImageNet, so the only thing we need to do is connect each class folder to its ImageNet class through the class ID. The same `<ID>-<name>` folder naming is used below for the church class.

In [7]:
# Two-way lookup between class IDs and their 0-based line number in synset_words.txt:
#   ImageNet_ID2idx: ID -> line index, ImageNet_idx2ID: line index -> ID.
ImageNet_ID2idx = {}
ImageNet_idx2ID = []
with open('synset_words.txt','r') as f:
    for i,line in enumerate(f):
        # Drop the trailing newline, then split on single spaces:
        # the first token is the ID, the remaining tokens form the name.
        tsplit = line.split('\n')[0].split(' ')
        ID = tsplit[0]
        name = ' '.join(tsplit[1:])
        #print(ID)

        ImageNet_ID2idx[ID] = i
        ImageNet_idx2ID.append(ID)

# Path of every dataset entry listed in `dog_list`, keyed by entry name.
fpath = Path('dataset') / 'ChurchFolder'
dog_paths = {d:fpath / d for d in dog_list}

# Containers filled by the next cell (class index -> info dict, folder name -> class index).
dog_dict = {}
# NOTE(review): `fpath` is assigned the same value a second time here.
fpath = Path('dataset') / 'ChurchFolder'
dog_name2idx = {}

In [8]:
# Link every class folder named "<ID>-<name>" to its ImageNet class index.
# Folders whose ID is not listed in synset_words.txt are skipped.
for i,k in enumerate(dog_paths.keys()):
  # Split on the FIRST hyphen only: class names may themselves contain hyphens,
  # and entries without any hyphen (e.g. hidden files) must not raise.
  ID, _sep, name = k.partition('-')
  # Dict membership is O(1); the list lookup used before was O(n) per folder.
  if ID in ImageNet_ID2idx:
    dog_dict[ImageNet_ID2idx[ID]] = {'idx':ImageNet_ID2idx[ID],'ID':ID,'name':name,'path':fpath / k}
    dog_name2idx[k] = ImageNet_ID2idx[ID]

In [9]:
# Display the class-index -> class-info mapping built above.
dog_dict

{497: {'idx': 497,
  'ID': 'n03028079',
  'name': 'church',
  'path': PosixPath('dataset/ChurchFolder/n03028079-church')}}

In [10]:
# Settings for this run. 497 is the class index found for 'church' in `dog_dict` above.
# NOTE(review): the code consuming `paras` is not visible here; key meanings below are
# inferred from their names — confirm against the consumer.
paras = {}
target_class = 497
paras['target_classes'] = [497]
paras['classes_names'] = ['church']
paras['n_components'] = 25
paras['layer_name'] = 'layer4'
paras['title'] = "church"
paras['overwrite'] = False

In [11]:
def make_imgs(paths = [],labels = []):
    """Build a flat list of ``(file_path, label)`` pairs.

    Every entry returned by ``os.listdir(paths[i])`` is joined onto
    ``paths[i]`` and paired with ``labels[i]``. Entries are kept in the order
    the directories are given and in the order ``os.listdir`` reports them.
    """
    return [
        (os.path.join(folder, entry), labels[pos])
        for pos, folder in enumerate(paths)
        for entry in os.listdir(folder)
    ]

In [12]:
def get_loaders(target_classes = []):
    """Create one DataLoader per target class.

    For the class at position ``i`` in ``target_classes`` the images found in
    ``dog_dict[class_idx]['path']`` are labelled ``i`` (the position, not the
    class index). Images are resized to 224x224, converted to tensors and
    normalised with the mean/std given below. Loaders use a batch size of 16
    and two workers; no shuffling is requested.
    """
    def _loader_for(label, class_idx):
        samples = make_imgs([dog_dict[class_idx]['path']], [label])
        preprocess = transforms.Compose([
            transforms.Resize((224,224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        return DataLoader(ImageDataset(samples, preprocess),
                          batch_size = 16,
                          num_workers=2)

    return [_loader_for(label, class_idx)
            for label, class_idx in enumerate(target_classes)]

In [13]:
import numpy as np
import os

import torch
from torch.utils.data import TensorDataset,DataLoader

class ModelWrapper:
    """Interface shared by model wrappers.

    Stores the wrapped model and a batch size. The three methods are
    placeholders that return ``None``; subclasses provide the behaviour.
    """

    def __init__(self,model,batch_size = 128):
        self.batch_size = batch_size
        self.model = model

    def get_feature(self,x,layer_name):
        """Return the feature map of ``layer_name`` for input ``x`` (placeholder)."""
        return None

    def feature_predict(self,feature,layer_name = None):
        """Return the prediction obtained from feature maps (placeholder)."""
        return None

    def predict(self,x):
        """Return the prediction for input ``x`` (placeholder)."""
        return None
    

class PytorchModelWrapper(ModelWrapper):
    """ModelWrapper for PyTorch modules with access to intermediate layers.

    Layers are addressed by name through ``layer_dict``, which holds the
    entries supplied by the caller plus the model's direct children
    (``model.named_children()``). Forward hooks are used either to capture a
    layer's output (``get_feature``) or to replace a layer's output with
    caller-supplied feature maps (``feature_predict``).

    Args:
        model: the wrapped ``torch.nn.Module``.
        layer_dict: optional extra ``name -> module`` entries. The dict that is
            passed in is updated in place with the model's named children.
        predict_target: optional index applied to the second axis of
            predictions (``out[:, predict_target]``).
        input_channel_first: True if incoming images are channel-first.
        model_channel_first: True if the model works on channel-first data.
        numpy_out: return NumPy arrays instead of tensors.
        input_size: size (without batch axis) of the dummy input used when
            feature maps are injected at an intermediate layer.
        batch_size: batch size used when a tensor/array is passed in.
    """
    def __init__(self,
                 model,
                 layer_dict = None,
                 predict_target = None,
                 input_channel_first = False, # True if input image is channel first
                 model_channel_first = True, #True if model use channel first
                 numpy_out = True,
                 input_size = [3,224,224], #model's input size
                 batch_size=128):

        super().__init__(model,batch_size)
        # A fresh dict per instance: the previous `{}` default was one shared
        # object that every instance updated with its own model's layers.
        self.layer_dict = {} if layer_dict is None else layer_dict
        self.layer_dict.update(dict(model.named_children()))
        self.predict_target = predict_target
        self.input_channel = 'f' if input_channel_first else 'l'
        self.model_channel = 'f' if model_channel_first else 'l'
        self.numpy_out = numpy_out
        self.input_size = list(input_size)

        # When True, outputs of _fun are passed through ReLU.
        self.non_negative = False

        # NOTE(review): tensors are moved with .cuda() whenever CUDA is available;
        # this assumes the model itself lives on the GPU in that case — confirm.
        self.CUDA = torch.cuda.is_available()

    def _to_tensor(self,x):
        """Return a cloned tensor for ``x``; 3-D inputs get a leading batch axis."""
        if type(x) == np.ndarray:
            x = torch.from_numpy(x)
        x = torch.clone(x)
        if x.ndim == 3:
            x = x.unsqueeze(0)
        return x

    def _switch_channel_f_to_l(self,x):
        """Permute a 3-D or 4-D tensor from channel-first to channel-last."""
        if x.ndim == 3:
            x = x.permute(1,2,0)
        if x.ndim == 4:
            x = x.permute(0,2,3,1)

        return x

    def _switch_channel_l_to_f(self,x):
        """Permute a 3-D or 4-D tensor from channel-last to channel-first."""
        if x.ndim == 3:
            x = x.permute(2,0,1)
        if x.ndim == 4:
            x = x.permute(0,3,1,2)

        return x

    def _switch_channel(self,x,layer_in='input',layer_out='output',to_model=True):
        """Convert the channel layout on the way into or out of the model.

        Towards the model (``to_model=True``) the source layout is the
        configured input layout for raw inputs and channel-last for feature
        maps; the target is the model layout. Away from the model the result
        is always converted to channel-last. ``layer_out`` is accepted but not
        used.
        """
        if to_model:
            c_from = self.input_channel if layer_in == 'input' else 'l'
            c_to = self.model_channel
        else:
            c_from = self.model_channel
            c_to = 'l'

        if c_from == 'f' and c_to == 'l':
            x = self._switch_channel_f_to_l(x)
        if c_from == 'l' and c_to == 'f':
            x = self._switch_channel_l_to_f(x)
        return x


    def _fun(self,x,layer_in = "input",layer_out = "output"):
        """Run one batch through the model (CPU tensor in, CPU tensor out).

        With ``layer_in`` other than "input", a zero tensor of ``input_size``
        is fed to the model and a forward hook replaces the output of
        ``layer_in`` with ``x``. With ``layer_out`` other than "output", the
        output of that layer is captured by a hook and returned instead of the
        model output.
        """
        x = x.type(torch.FloatTensor)

        data_in = x.clone()
        if self.CUDA:
            data_in = data_in.cuda()
        captured = []

        def hook_in(m,i,o):
            # Returning a value from a forward hook replaces the layer's output.
            return data_in
        def hook_out(m,i,o):
            captured.append(o)

        handles = []
        try:
            if layer_in == "input":
                nx = x
            else:
                handles.append(self.layer_dict[layer_in].register_forward_hook(hook_in))
                nx = torch.zeros([x.size()[0]]+self.input_size)

            if not layer_out == "output":
                handles.append(self.layer_dict[layer_out].register_forward_hook(hook_out))

            if self.CUDA:
                nx = nx.cuda()

            with torch.no_grad():
                ny = self.model(nx)
        finally:
            # Always detach the hooks, even when the forward pass raises;
            # otherwise they would stay registered and corrupt later calls.
            for handle in handles:
                handle.remove()

        if layer_out == "output":
            data_out = ny
        else:
            data_out = captured[0]

        data_out = data_out.cpu()

        if self.non_negative:
            data_out = torch.relu(data_out)

        return data_out

    def _batch_fn(self,x,layer_in = "input",layer_out = "output"):
        """Apply ``_fun`` batch by batch and concatenate the results.

        ``x`` is either a tensor/array (wrapped in a DataLoader with
        ``self.batch_size``) or any iterable of batches whose first element is
        the data. The concatenated result is converted to channel-last and,
        when ``numpy_out`` is set, to a NumPy array.
        """
        if type(x) == torch.Tensor or type(x) == np.ndarray:
            x = self._to_tensor(x)

            dataset = TensorDataset(x)
            x = DataLoader(dataset,batch_size=self.batch_size)

        out = []

        for nx in x:
            # Each batch is a tuple/list; only its first element (the data) is used.
            nx = nx[0]
            nx = self._switch_channel(nx,layer_in=layer_in,layer_out=layer_out,to_model=True)
            out.append(self._fun(nx,layer_in,layer_out))

        res = torch.cat(out,0)

        res = self._switch_channel(res,layer_in=layer_in,layer_out=layer_out,to_model=False)
        if self.numpy_out:
            res = res.detach().numpy()

        return res

    def set_predict_target(self,predict_target):
        """Set the index applied to the second axis of predictions."""
        self.predict_target = predict_target

    def get_feature(self,x,layer_name):
        """Return the output of ``layer_name`` for ``x``.

        Prints a message and returns ``None`` when the layer is unknown.
        """
        if layer_name not in self.layer_dict:
            print ("Target layer not exists")
            return None

        out = self._batch_fn(x,layer_out = layer_name)

        return out

    def feature_predict(self,feature,layer_name = None):
        """Return the model output when ``feature`` replaces the output of ``layer_name``.

        Prints a message and returns ``None`` when the layer is unknown. When
        ``predict_target`` is set, only ``out[:, predict_target]`` is returned.
        """
        if layer_name not in self.layer_dict:
            print ("Target layer not exists")
            return None

        out = self._batch_fn(feature,layer_in = layer_name)
        if self.predict_target is not None:
            out = out[:,self.predict_target]
        return out


    def predict(self,x):
        """Return the model output for ``x`` (restricted to ``predict_target`` if set)."""
        out = self._batch_fn(x)
        if self.predict_target is not None:
            out = out[:,self.predict_target]

        return out

In [14]:
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
from skimage.transform import resize


#npdir = '/dataset/ILSVRC2012/nparray_train'
#imdir = '/dataset/ILSVRC2012/ILSVRC2012_img_train'

# Per-channel means; the same values are added back by img_utils.deprocessing in 'caffe' mode.
mean = [103.939, 116.779, 123.68]
# Reference image size; img_utils.contour_img scales its figure dpi by SIZE[0].
SIZE = [224,224]
# Small constant added to denominators in img_utils.img_filter to avoid division by zero.
EPSILON = 1e-8

class img_utils():
    """Helpers for undoing image preprocessing, resizing heatmaps and plotting.

    Args:
        img_size: spatial size ``(height, width)`` used as resize target.
        nchannels: number of image channels.
        img_format: "channels_last" or "channels_first" layout of the images.
        mode: preprocessing scheme undone by ``deprocessing``: ``None``,
            "tf", "torch", "caffe" or "norm".
        std, mean: per-channel statistics used by mode "norm".
    """

    def __init__(self,
                img_size = (224,224),
                nchannels = 3,
                img_format = "channels_last",
                mode = None,
                std = None,
                mean = None):
        self.img_format = img_format
        self.nchannels = nchannels
        self.fsize = list(img_size)
        # Stored channel-last regardless of img_format.
        self.img_size = self.fsize + [self.nchannels]

        self.std = std
        self.mean = mean
        self.mode = mode


    def deprocessing(self,x):
        """Undo the preprocessing selected by ``self.mode`` on a copy of ``x``.

        Channel-first inputs are transposed to channel-last first. With mode
        ``None`` the (transposed) copy is returned as is. Otherwise:

        * "tf":    ``(x + 1) * 127.5``
        * "torch": ``(x * std + mean) * 255`` with the std/mean listed below
        * "caffe": adds the per-channel means, then reverses the channel order
        * "norm":  ``x * self.std + self.mean`` (``std`` may be ``None``)

        Integer inputs are converted to float64 before any arithmetic.

        Raises:
            ValueError: for an unknown mode, or mode "norm" without ``mean``.
        """
        mode = self.mode
        # np.array copies its input, so the caller's data is never modified.
        X = np.array(x)

        if self.img_format == "channels_first":
            if X.ndim == 3:
                X = np.transpose(X,(1,2,0))
            else:
                X = np.transpose(X,(0,2,3,1))

        if mode is None:
            return X

        if mode not in ("tf", "torch", "caffe", "norm"):
            # Previously this fell through to an UnboundLocalError on `std`.
            raise ValueError("Unknown deprocessing mode '%s'." % mode)

        # The in-place float arithmetic below is rejected by NumPy on integer
        # arrays (cannot cast float result to int), so switch to float first.
        if not np.issubdtype(X.dtype, np.floating):
            X = X.astype(np.float64)

        if mode == "tf":
            X +=1
            X *=127.5
            return X

        if mode == "torch":
            mean = [0.485, 0.456, 0.406]
            std = [0.229, 0.224, 0.225]
        elif mode == 'caffe':
            mean = [103.939, 116.779, 123.68]
            std = None
        else:  # mode == "norm"
            mean = self.mean
            std = self.std
            if mean is None:
                raise ValueError("mode 'norm' requires `mean` to be set.")

        if std is not None:
            X[..., 0] *= std[0]
            X[..., 1] *= std[1]
            X[..., 2] *= std[2]
        X[..., 0] += mean[0]
        X[..., 1] += mean[1]
        X[..., 2] += mean[2]

        if mode == 'caffe':
            # Reverse the order of the channel axis.
            X = X[..., ::-1]

        if mode == "torch":
            X *= 255
        return X

    def resize_img(self,array, smooth = False):
        """Resize axes 1 and 2 of ``array`` to ``self.fsize``.

        With ``smooth`` the array is resized with ``skimage.transform.resize``
        (order 1, reflect padding). Otherwise each element is repeated an
        integer number of times (``fsize // current size``), so the result only
        reaches ``fsize`` exactly when ``fsize`` is a multiple of the current
        size.
        """
        fsize = self.fsize
        size = array.shape
        if smooth:
            tsize = list(size)
            tsize[1] = fsize[0]
            tsize[2] = fsize[1]
            res = resize(array, tsize, order=1, mode='reflect',anti_aliasing=False)
        else:
            res = []
            for i in range(size[0]):
                temp = array[i]
                temp = np.repeat(temp,fsize[0]//size[1],axis = 0)
                temp = np.repeat(temp,fsize[1]//size[2],axis = 1)
                res.append(temp)
            res = np.array(res)
        return res


    def flatten(self,array):
        """Collapse all axes except the last one: result has shape ``(-1, last)``."""
        size = array.shape
        return array.reshape(-1,size[-1])

    def show_img(self,X,nrows=1,ncols=1,heatmaps = None,useColorBar = True, deprocessing = True, save_path = None):
        """Plot a batch of images on an ``nrows`` x ``ncols`` grid.

        ``X`` must be 4-D; otherwise a message is printed and nothing is drawn.
        Images are optionally deprocessed, rescaled to [0, 1] when they are not
        already in that range, and overlaid with ``heatmaps[i]`` when given.
        The figure is saved to ``save_path`` if set and then shown.
        """
        X = np.array(X)
        if not heatmaps is None:
            heatmaps = np.array(heatmaps)
        if len(X.shape)<4:
            print ("Dim should be 4")
            return

        if deprocessing:
            X = self.deprocessing(X)

        if (not X.min() == 0) or X.max()>1:
            X = X - X.min()
            peak = X.max()
            # A constant image has peak 0 after the shift; skip the division
            # instead of producing NaNs.
            if peak > 0:
                X = X / peak

        if X.shape[0] == 1:
            X = np.squeeze(X)
            X = np.expand_dims(X,axis=0)
        else:
            X = np.squeeze(X)

        if self.nchannels == 1:
            cmap = "Greys"
        else:
            cmap = "viridis"

        # Never index past the batch when the grid has more cells than images.
        l = min(nrows*ncols, X.shape[0])
        plt.figure(figsize=(5*ncols, 5*nrows))
        for i in range(l):
            plt.subplot(nrows, ncols, i+1)
            plt.axis('off')
            img = X[i]
            img = np.clip(img,0,1)
            plt.imshow(img,cmap = cmap)
            if not heatmaps is None:
                if not heatmaps[i] is None:
                    heatmap = heatmaps[i]
                    plt.imshow(heatmap, cmap='jet', alpha=0.5,interpolation = "bilinear")
                    if useColorBar:
                        plt.colorbar()

        if save_path is not None:
            plt.savefig(save_path)
        plt.show()


    def img_filter(self,x,h,threshold=0.5,background = 0.2,smooth = True, minmax = False):
        """Dim the parts of images ``x`` where heatmaps ``h`` are below ``threshold``.

        Each heatmap is clipped at zero, divided by its own maximum, shifted by
        ``threshold`` and resized to the image size. Pixels above the threshold
        keep weight 1, all others get weight ``background``.

        Returns:
            ``(x * mask, h)`` where ``h`` is the mask rescaled to [0, 1].
        """
        x = x.copy()
        h = h.copy()

        if minmax:
            h = h - h.min()

        h = h * (h>0)
        for i in range(h.shape[0]):
            h[i] = h[i] / (h[i].max() + EPSILON)

        h = (h - threshold) * (1/(1-threshold))

        h = self.resize_img(h,smooth = smooth)
        h = ((h>0).astype("float")*(1-background) + background)
        # Repeat the mask along a new trailing channel axis.
        h_mask = np.repeat(h,self.nchannels).reshape(list(h.shape)+[-1])
        if self.img_format == 'channels_first':
            h_mask = np.transpose(h_mask,(0,3,1,2))
        x = x * h_mask

        h = h - h.min()
        h = h / (h.max() + EPSILON)

        return x,h


    def contour_img(self,x,h,dpi = 100):
        """Return a borderless figure showing image ``x`` with red contours of ``h``.

        ``x`` is divided by its maximum when that exceeds 1. The figure dpi is
        scaled by ``SIZE[0] / x.shape[0]``.
        """
        dpi = float(dpi)
        size = x.shape
        if x.max()>1:
            x = x/x.max()
        fig = plt.figure(figsize=(size[1]/dpi,size[0]/dpi),dpi=dpi*SIZE[0]/size[0])
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)
        xa = np.linspace(0, size[1]-1, size[1])
        ya = np.linspace(0, size[0]-1, size[0])
        X, Y = np.meshgrid(xa, ya)
        if x.shape[-1] == 1:
            x = np.squeeze(x)
            ax.imshow(x,cmap = "Greys")
        else:
            ax.imshow(x)
        ax.contour(X, Y, h,colors = 'r')
        return fig

    def res_ana(self,model,classesLoader,classNos,reducer, layer_name = "conv5_block3_out"):
        """Compare class scores computed from features with scores from their reduction.

        For every class number the spatially averaged feature ``X`` is split
        into a reduced part ``S @ V_`` and a residual ``U``; the four returned
        rows per class are the full score, the recombined score, the score of
        the reduced part and the score of the residual.

        NOTE(review): relies on ``model.model.layers[-1].get_weights()`` and
        ``classesLoader.load_val``, which are not provided by the PyTorch
        wrapper defined in this notebook — presumably written for a Keras
        model; confirm before use.
        """
        w,b = model.model.layers[-1].get_weights()
        V_ = reducer._reducer.components_
        w_ = np.dot(V_,w)
        ana = []

        for No in classNos:
            target = No
            tX, ty = classesLoader.load_val([No])
            tX = np.concatenate(tX)
            ty = np.concatenate(ty).astype(int)
            X = model.get_feature(tX,layer_name=layer_name).mean(axis = (1,2))
            S = reducer.transform(X)
            U = X-np.dot(S,V_)

            C1 = np.dot(X,w[:,target])
            C2 = np.dot(S,w_[:,target]) + np.dot(U,w[:,target])

            C = np.dot(S,w_[:,target])
            res = np.dot(U,w[:,target])

            ana.append(np.array([C1,C2,C,res]))
        return [np.array(ana),reducer]

In [15]:
# Copyright 2018 The Lucid Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Helper for using sklearn.decomposition on high-dimensional tensors.
Provides ChannelReducer, a wrapper around sklearn.decomposition to help them
apply to arbitrary rank tensors. It saves lots of annoying reshaping.
"""

import numpy as np
import sklearn.decomposition
import sklearn.cluster


from sklearn.base import BaseEstimator 

# Registry of reducer names: maps the name of every BaseEstimator subclass exposed by
# sklearn.decomposition / sklearn.cluster to the string 'decomposition' or 'cluster'.
# A name present in both modules ends up as 'cluster' because that loop runs last.
ALGORITHM_NAMES = {}
for name in dir(sklearn.decomposition):
    obj = sklearn.decomposition.__getattribute__(name)
    if isinstance(obj, type) and issubclass(obj, BaseEstimator):
        ALGORITHM_NAMES[name] = 'decomposition'
for name in dir(sklearn.cluster):
    obj = sklearn.cluster.__getattribute__(name)
    if isinstance(obj, type) and issubclass(obj, BaseEstimator):
        ALGORITHM_NAMES[name] = 'cluster'


class ChannelDecompositionReducer(object):
    """Apply an sklearn.decomposition estimator along the last axis of a tensor.

    Inputs of any rank are flattened to ``(-1, channels)``, handed to the
    estimator and reshaped back so that only the last axis changes.

    Args:
        n_components: number of components; must be an ``int``.
        reduction_alg: estimator class, or the name of a BaseEstimator
            subclass found in ``sklearn.decomposition``.
        **kwargs: forwarded to the estimator constructor.

    Raises:
        ValueError: if ``n_components`` is not an int or the name is unknown.
    """

    def __init__(self, n_components=3, reduction_alg="NMF", **kwargs):
        if not isinstance(n_components, int):
            raise ValueError("n_components must be an int, not '%s'." % n_components)

        # Collect the estimator classes exposed by sklearn.decomposition.
        algorithm_map = {}
        for attr_name in dir(sklearn.decomposition):
            candidate = sklearn.decomposition.__getattribute__(attr_name)
            if not isinstance(candidate, type):
                continue
            if not issubclass(candidate, BaseEstimator):
                continue
            algorithm_map[attr_name] = candidate

        # Resolve a string to its class; anything else is used as given.
        if isinstance(reduction_alg, str):
            if reduction_alg not in algorithm_map:
                raise ValueError("Unknown dimensionality reduction method '%s'." % reduction_alg)
            reduction_alg = algorithm_map[reduction_alg]

        self.n_components = n_components
        self._reducer = reduction_alg(n_components=n_components, **kwargs)
        self._is_fit = False

    def _apply_flat(self, f, acts):
        """Call ``f`` on ``acts`` flattened to 2-D and restore the leading axes.

        Results that are not NumPy arrays (e.g. a fitted estimator) are
        returned untouched.
        """
        leading = list(acts.shape[:-1])
        result = f(acts.reshape([-1, acts.shape[-1]]))
        if isinstance(result, np.ndarray):
            return result.reshape(leading + [-1])
        return result

    def fit(self, acts):
        """Fit the estimator, using ``partial_fit`` when the estimator offers it."""
        if hasattr(self._reducer,'partial_fit'):
            trainer = self._reducer.partial_fit
        else:
            trainer = self._reducer.fit
        outcome = self._apply_flat(trainer, acts)
        self._is_fit = True
        return outcome

    def fit_transform(self, acts):
        """Fit the estimator on ``acts`` and return the reduced activations."""
        reduced = self._apply_flat(self._reducer.fit_transform, acts)
        self._is_fit = True
        return reduced

    def transform(self, acts):
        """Return the reduced activations for ``acts``."""
        return self._apply_flat(self._reducer.transform, acts)

    def inverse_transform(self, acts):
        """Map reduced activations back to the original channel space.

        Falls back to ``acts @ components_`` when the estimator has no
        ``inverse_transform``.
        """
        if not hasattr(self._reducer,'inverse_transform'):
            return np.dot(acts,self._reducer.components_)
        return self._apply_flat(self._reducer.inverse_transform, acts)


class ChannelClusterReducer(object):
    """Apply an sklearn.cluster estimator along the last axis of a tensor.

    Cluster assignments are returned one-hot encoded over ``n_components``
    clusters, and the cluster centres are exposed as ``_reducer.components_``
    so the reducer can be used like a decomposition.

    Args:
        n_components: number of clusters; must be an ``int``.
        reduction_alg: estimator class, or the name of a BaseEstimator
            subclass found in ``sklearn.cluster``.
        **kwargs: forwarded to the estimator constructor.

    Raises:
        ValueError: if ``n_components`` is not an int or the name is unknown.
    """

    def __init__(self, n_components=3, reduction_alg="KMeans", **kwargs):

        if not isinstance(n_components, int):
            raise ValueError("n_components must be an int, not '%s'." % n_components)

        # Defensively look up reduction_alg if it is a string and give useful errors.
        algorithm_map = {}
        for name in dir(sklearn.cluster):
            obj = sklearn.cluster.__getattribute__(name)
            if isinstance(obj, type) and issubclass(obj, BaseEstimator):
                algorithm_map[name] = obj
        if isinstance(reduction_alg, str):
            if reduction_alg in algorithm_map:
                reduction_alg = algorithm_map[reduction_alg]
            else:
                raise ValueError("Unknown dimensionality reduction method '%s'." % reduction_alg)

        self.n_components = n_components
        self._reducer = reduction_alg(n_clusters=n_components, **kwargs)
        self._is_fit = False

    def _apply_flat(self, f, acts):
        """Utility for applying f to inner dimension of acts.

        Flattens acts into a 2D tensor, applies f, then unflattens so that all
        dimensions except the innermost are unchanged. A result with a single
        value per position (cluster labels) is expanded to a one-hot encoding
        over ``n_components``. Results that are not NumPy arrays are returned
        untouched.
        """
        orig_shape = acts.shape
        acts_flat = acts.reshape([-1, acts.shape[-1]])
        new_flat = f(acts_flat)
        if not isinstance(new_flat, np.ndarray):
            return new_flat
        shape = list(orig_shape[:-1]) + [-1]
        new_flat = new_flat.reshape(shape)

        if new_flat.shape[-1] == 1:
            # One label per position: turn the labels into one-hot rows.
            new_flat = new_flat.reshape(-1)
            t_flat = np.zeros([new_flat.shape[0],self.n_components])
            t_flat[np.arange(new_flat.shape[0]),new_flat] = 1
            new_flat = t_flat.reshape(shape)

        return new_flat

    def fit(self, acts):
        """Fit the estimator, using ``partial_fit`` when the estimator offers it."""
        if hasattr(self._reducer,'partial_fit'):
            res = self._apply_flat(self._reducer.partial_fit, acts)
        else:
            res = self._apply_flat(self._reducer.fit, acts)
        self._reducer.components_ = self._reducer.cluster_centers_
        self._is_fit = True
        return res

    def fit_predict(self, acts):
        """Fit the estimator and return the one-hot cluster assignment of ``acts``."""
        res = self._apply_flat(self._reducer.fit_predict, acts)
        self._reducer.components_ = self._reducer.cluster_centers_
        self._is_fit = True
        return res

    def fit_transform(self, acts):
        """Same as ``fit_predict``.

        Provided so that this reducer offers the ``fit_transform`` entry point
        that ChannelDecompositionReducer has and that Explainer calls when
        training a reducer; without it every cluster-type reducer failed with
        an AttributeError.
        """
        return self.fit_predict(acts)

    def transform(self, acts):
        """Return the one-hot cluster assignment of ``acts``."""
        res = self._apply_flat(self._reducer.predict, acts)
        return res

    def inverse_transform(self, acts):
        """Map (one-hot) assignments back to channel space: ``acts @ cluster centres``."""
        res = np.dot(acts,self._reducer.components_)
        return res


In [16]:
import os
from pathlib import Path
import pickle

import pydotplus

import numpy as np
import matplotlib.pyplot as plt
import time


# Default value stored in Explainer.font.
FONT_SIZE = 30
#CALC_LIMIT = 3e4
# Largest number of feature-map elements used to fit the reducer; above this,
# Explainer._train_reducer fits on a random subset of the instances.
CALC_LIMIT = 1e8
#CALC_LIMIT = 1e9
# NOTE(review): TRAIN_LIMIT, REDUCER_PATH and USE_TRAINED_REDUCER are not referenced in
# the visible part of this notebook; REDUCER_PATH names resnet50 while the model loaded
# above is resnet34 — confirm whether these are still needed.
TRAIN_LIMIT = 50
REDUCER_PATH = "reducer/resnet50"
USE_TRAINED_REDUCER = False
# Number of feature maps taken from each loader when estimating concept weights.
ESTIMATE_NUM = 10

class Explainer():
    def __init__(self,
                 title = "",
                 layer_name = "",
                 class_names = None,
                 utils = None,
                 keep_feature_images = True,
                 useMean = True,
                 reducer_type = "NMF",
                 n_components = 10,
                 featuretopk = 20,
                 featureimgtopk = 5,
                 epsilon = 1e-4):
        """Store the explainer configuration and initialise its empty state.

        Args:
            title: name of the explainer; also used for its folder and pickle file.
            layer_name: name of the model layer whose feature maps are reduced.
            class_names: names of the explained classes (``None`` for no classes).
            utils: image helper object used when saving concept crops.
            keep_feature_images: if False, ``generate_features`` clears
                ``self.features`` after saving.
            useMean: reduce feature maps with the spatial mean (True) or max (False).
            reducer_type: name of the sklearn reducer to train.
            n_components: number of components/concepts.
            featuretopk: number of top features considered per class.
            featureimgtopk: number of images kept per feature.
            epsilon: step size / stabiliser used in the estimates.
        """
        # Identification and target layer.
        self.title = title
        self.layer_name = layer_name

        # Explained classes and their count.
        self.class_names = class_names
        if class_names is not None:
            self.class_nos = len(class_names)
        else:
            self.class_nos = 0

        # Reducer settings.
        self.reducer_type = reducer_type
        self.n_components = n_components
        self.useMean = useMean
        self.epsilon = epsilon

        # Feature-visualisation settings.
        self.keep_feature_images = keep_feature_images
        self.featuretopk = featuretopk
        self.featureimgtopk = featureimgtopk #number of images for a feature

        self.utils = utils

        # State filled in during training / feature generation.
        self.reducer = None
        self.feature_distribution = None
        self.feature_base = []
        self.features = {}

        # Output location and plotting font size.
        self.exp_location = Path('Explainers')
        self.font = FONT_SIZE
        
    def load(self):
        title = self.title
        with open(self.exp_location / title / (title+".pickle"),"rb") as f:
            tdict = pickle.load(f)
            self.__dict__.update(tdict)
            
    def save(self):
        if not os.path.exists(self.exp_location):
            os.mkdir(self.exp_location)
        title = self.title
        if not os.path.exists(self.exp_location / title):
            os.mkdir(self.exp_location / title)
        with open(self.exp_location / title / (title+".pickle"),"wb") as f:
            pickle.dump(self.__dict__,f)

    def train_model(self,model,loaders):
        self._train_reducer(model,loaders)
        self._estimate_weight(model,loaders)

    def _train_reducer(self,model,loaders):

        print ("Training reducer:")

        if self.reducer is None:
            if not self.reducer_type in ALGORITHM_NAMES:
                print ('reducer not exist')
                return 

            if ALGORITHM_NAMES[self.reducer_type] == 'decomposition':
                self.reducer = ChannelDecompositionReducer(n_components = self.n_components,reduction_alg = self.reducer_type)
            else:
                self.reducer = ChannelClusterReducer(n_components = self.n_components,reduction_alg = self.reducer_type)
        
        X_features = []
        for loader in loaders:
            X_features.append(model.get_feature(loader,self.layer_name))
        print ('1/5 Featuer maps gathered.')

        if not self.reducer._is_fit:
            nX_feature = np.concatenate(X_features)
            total = np.product(nX_feature.shape)
            l = nX_feature.shape[0]
            if total > CALC_LIMIT:
                p = CALC_LIMIT / total
                print ("dataset too big, train with {:.2f} instances".format(p))
                idx = np.random.choice(l,int(l*p),replace = False)
                nX_feature = nX_feature[idx]

            print ("loading complete, with size of {}".format(nX_feature.shape))
            start_time = time.time()
            nX = self.reducer.fit_transform(nX_feature)

            print ("2/5 Reducer trained, spent {} s.".format(time.time()-start_time))
        

        self.cavs = self.reducer._reducer.components_
        nX = nX.mean(axis = (1,2))
        self.feature_distribution = {"overall":[(nX[:,i].mean(),nX[:,i].std(),nX[:,i].min(),nX[:,i].max()) for i in range(self.n_components)]}

        reX = []
        self.feature_distribution['classes'] = []
        for X_feature in X_features:
            t_feature = self.reducer.transform(X_feature)
            pred_feature = self._feature_filter(t_feature)
            self.feature_distribution['classes'].append([pred_feature.mean(axis=0),pred_feature.std(axis=0),pred_feature.min(axis=0),pred_feature.max(axis=0)])
            reX.append(self.reducer.inverse_transform(t_feature))

        err = []
        for i in range(len(self.class_names)):
            res_true = model.feature_predict(X_features[i],layer_name=self.layer_name)[:,i] #
            res_recon = model.feature_predict(reX[i],layer_name=self.layer_name)[:,i] #
            err.append(abs(res_true-res_recon).mean(axis=0) / (abs(res_true.mean(axis=0))+self.epsilon))


        self.reducer_err = np.array(err)
        if type(self.reducer_err) is not np.ndarray:
            self.reducer_err = np.array([self.reducer_err])

        print ("3/5 Error estimated, fidelity: {}.".format(self.reducer_err))

        return self.reducer_err

    def _estimate_weight(self,model,loaders):
        if self.reducer is None:
            return

        X_features = []

        for loader in loaders:
            X_features.append(model.get_feature(loader,self.layer_name)[:ESTIMATE_NUM])
        X_feature = np.concatenate(X_features)

        print ('4/5 Weight estimator initialized.')
        
        self.test_weight = []
        for i in range(self.n_components):
            cav = self.cavs[i,:]

            res1 =  model.feature_predict(X_feature - self.epsilon * cav,layer_name=self.layer_name)
            res2 =  model.feature_predict(X_feature + self.epsilon * cav,layer_name=self.layer_name)

            res_dif = res2 - res1
            dif = res_dif.mean(axis=0) / (2 * self.epsilon)
            if type(dif) is not np.ndarray:
                dif = np.array([dif])
            self.test_weight.append(dif)
        
        print ('5/5 Weight estimated.')

        self.test_weight = np.array(self.test_weight)

    def generate_features(self,model,loaders):
        self._visulize_features(model,loaders)
        self._save_features()
        if self.keep_feature_images == False:
            self.features = {}
        return

    def _feature_filter(self,featureMaps,threshold = None):
    #filter feature map to feature value with threshold for target value
        if self.useMean:
            res = featureMaps.mean(axis = (1,2))
        else:
            res = featureMaps.max(axis=(1,2))
        if threshold is not None:
            res = - abs(res - threshold) 
        return res
   
    def _update_feature_dict(self,x,h,nx,nh,threshold = None):

        if type(x) == type(None):
            return nx,nh
        else:
            x = np.concatenate([x,nx])
            h = np.concatenate([h,nh])

            nidx = self._feature_filter(h,threshold = threshold).argsort()[-self.featureimgtopk:]
            x = x[nidx,...]
            h = h[nidx,...]
            return x,h

    def save_concept_crops(self, concept_id, images, heatmaps, save_dir, threshold=0.5):
        """
        Save cropped image regions corresponding to a concept's activation.

        For every image the heatmap is resized through ``self.utils.resize_img``,
        rescaled to [0, 1] and thresholded; the bounding box of the pixels above
        the threshold is cut out of the image and written to
        ``<save_dir>/concept_<concept_id>/crop_<i>.png``. Images without any
        pixel above the threshold, or whose box spans fewer than 5 pixels in
        either direction, are skipped. File names only depend on the position
        ``i``, so a later call for the same concept overwrites earlier crops.

        Args:
            concept_id (int): Concept index.
            images (numpy.ndarray): Shape [N, H, W, C], in [0,1] or normalized.
                Values are clipped to [0, 1] before saving; no de-normalization
                is applied here.
            heatmaps (numpy.ndarray): Shape [N, H_small, W_small], one per image.
            save_dir (Path): Directory where crops will be saved.
            threshold (float): Threshold for activation masking.
        """

        concept_dir = save_dir / f"concept_{concept_id}"
        concept_dir.mkdir(parents=True, exist_ok=True)

        for i in range(len(images)):
            image = images[i]
            heatmap = heatmaps[i]

            # Resize heatmap to match image resolution
            # (a batch axis is added for resize_img and removed again with [0]).
            # NOTE(review): the target size is self.utils.fsize, which is assumed
            # to equal the image size — confirm.
            heatmap_resized = self.utils.resize_img(np.expand_dims(heatmap, 0), smooth=True)[0]  # [H, W]

            # Normalize heatmap to [0, 1] just in case
            heatmap_resized -= heatmap_resized.min()
            heatmap_resized /= (heatmap_resized.max() + 1e-8)

            # Threshold to get binary mask
            binary_mask = (heatmap_resized > threshold).astype(np.uint8)

            # Find bounding box
            ys, xs = np.where(binary_mask == 1)
            if len(xs) == 0 or len(ys) == 0:
                continue  # skip images with no activated region

            x_min, x_max = xs.min(), xs.max()
            y_min, y_max = ys.min(), ys.max()

            # Minimum-size check: boxes spanning fewer than 5 pixels are dropped
            if x_max - x_min < 5 or y_max - y_min < 5:
                continue  # skip too small crops

            # Clip coordinates to image bounds
            h_img, w_img = image.shape[:2]
            x_min = max(x_min, 0)
            y_min = max(y_min, 0)
            x_max = min(x_max, w_img - 1)
            y_max = min(y_max, h_img - 1)

            # Crop and save
            crop = image[y_min:y_max+1, x_min:x_max+1, :]  # [H, W, C]
            crop = np.clip(crop, 0, 1)  # Make sure values are in [0, 1]
            # Convert to numpy if needed
            # NOTE(review): this check runs after np.clip; whether `crop` is still a
            # torch.Tensor at this point depends on how np.clip handles tensors — confirm.
            if isinstance(crop, torch.Tensor):
                crop = crop.detach().cpu().numpy()

            # Convert to uint8 image and save
            crop_img = Image.fromarray((crop * 255).astype(np.uint8))
            crop_img.save(concept_dir / f"crop_{i}.png")
        
    def _visulize_features(self,model,loaders, featureIdx = None, inter_dict = None):
        """Gather the top-activating samples and heatmaps for a set of concepts.

        Args:
            model: wrapper exposing ``get_feature(X, layer_name)``.
            loaders: iterable of data loaders; each batch is indexed with ``[0]``
                to obtain the inputs. One loader is reported per entry of
                ``self.class_names``.
            featureIdx: concept indices to gather. When ``None``, the union over
                all classes of the ``featuretopk`` highest-weighted concepts in
                ``self.test_weight`` is used.
            inter_dict: optional dict keyed by a fraction ``k``; for each key the
                method collects samples with the threshold
                ``(vmax - vmin) * k + vmin`` taken from
                ``self.feature_distribution['overall']``.

        Returns:
            ``inter_dict`` (possibly ``None``) after gathering, or ``None`` early
            when every requested concept is already cached in ``self.features``.

        Side effects: prints progress, calls ``self.save_concept_crops`` for every
        concept on every batch, updates ``self.features`` and calls ``self.save()``.
        """
        featuretopk = min(self.featuretopk, self.n_components)

        imgTopk = self.featureimgtopk
        if featureIdx is None:
            featureIdx = []
            tidx = []
            w = self.test_weight
            # For each class column, keep the indices of the largest weights.
            for i,_ in enumerate(self.class_names):
                tw = w[:,i]
                tidx += tw.argsort()[::-1][:featuretopk].tolist()
            featureIdx += list(set(tidx))

        # Only gather concepts that are not cached yet.
        nowIdx = set(self.features.keys())
        featureIdx = list(set(featureIdx) - nowIdx)
        featureIdx.sort()

        if len(featureIdx) == 0:
            print ("All feature gathered")
            return

        print ("visulizing features:")
        print (featureIdx)

        # features[No] holds [samples, heatmaps]; both start empty.
        features = {}
        for No in featureIdx:
            features[No] = [None,None]

        if inter_dict is not None:
            # NOTE(review): each list has len(featureIdx) entries but is indexed
            # below with the concept id `No`; the two only line up when
            # featureIdx is exactly 0..n-1 -- confirm against callers.
            for k in inter_dict.keys():
                inter_dict[k] = [[None,None] for No in featureIdx]

        print ("loading training data")
        for i,loader in enumerate(loaders):

            for X in loader:
                X = X[0]
                # Project the layer activations into concept space.
                featureMaps = self.reducer.transform(model.get_feature(X,self.layer_name))

                X_feature = self._feature_filter(featureMaps)

                for No in featureIdx:
                    samples,heatmap = features[No]
                    # Indices of the (up to) imgTopk highest-scoring samples in this batch.
                    idx = X_feature[:,No].argsort()[-imgTopk:]

                    nheatmap = featureMaps[idx,:,:,No]
                    nsamples = X[idx,...]

                    # Merge this batch's candidates into the running selection.
                    samples,heatmap = self._update_feature_dict(samples,heatmap,nsamples,nheatmap)

                    features[No] = [samples,heatmap]
                    # Crops are re-exported from the running selection after every batch.
                    self.save_concept_crops(
                                        concept_id=No,
                                        images=samples,
                                        heatmaps=heatmap,
                                        save_dir=self.exp_location / self.title / "concept_crops",
                                        threshold=0.5
                                    )

                    if inter_dict is not None:
                        for k in inter_dict.keys():
                            # Interpolate a threshold between entries 2 and 3 of the
                            # per-concept distribution record (named vmin/vmax here).
                            vmin = self.feature_distribution['overall'][No][2]
                            vmax = self.feature_distribution['overall'][No][3]
                            temp_v = (vmax - vmin) * k + vmin
                            inter_dict[k][No] = self._update_feature_dict(inter_dict[k][No][0],inter_dict[k][No][1],X,featureMaps[:,:,:,No],threshold = temp_v)


            print ("Done with class: {}, {}/{}".format(self.class_names[i],i+1,len(loaders)))
        # Pad slots whose heatmap is all zero with the sample that has the highest
        # mean activation, so every concept ends up with a full set of prototypes.
        # NOTE(review): if no batch was seen, `h` is still None here and `.mean`
        # would raise -- confirm loaders are never empty.
        for no,(x,h) in features.items():
            idx = h.mean(axis = (1,2)).argmax()
            for i in range(h.shape[0]):
                if h[i].max() == 0:
                    x[i] = x[idx]
                    h[i] = h[idx]

        self.features.update(features)
        self.save()
        return inter_dict

    def _save_features(self,threshold=0.5,background = 0.2,smooth = True):
        feature_path = self.exp_location / self.title / "feature_imgs"
        #utils = self.utils

        if not os.path.exists(feature_path):
            os.mkdir(feature_path)

        for idx in self.features.keys(): 

            x,h = self.features[idx]
            #x = self.gen_masked_imgs(x,h,threshold=threshold,background = background,smooth = smooth)
            minmax = False
            if self.reducer_type == 'PCA':
                minmax = True
            x,h = self.utils.img_filter(x,h,threshold=threshold,background = background,smooth = smooth,minmax = minmax)
            
            nsize = self.utils.img_size.copy()
            nsize[1] = nsize[1]* self.featureimgtopk
            nimg = np.zeros(nsize)
            nh = np.zeros(nsize[:-1])
            for i in range(x.shape[0]):
                timg = self.utils.deprocessing(x[i])
                if timg.max()>1:
                    timg = timg / 255.0
                    timg = abs(timg)
                timg = np.clip(timg,0,1)
                nimg[:,i*self.utils.img_size[1]:(i+1)*self.utils.img_size[1],:] = timg
                nh[:,i*self.utils.img_size[1]:(i+1)*self.utils.img_size[1]] = h[i]
            fig = self.utils.contour_img(nimg,nh)
            fig.savefig(feature_path / (str(idx)+".jpg"),bbox_inches='tight',pad_inches=0)
            plt.close(fig)

    def global_explanations(self):        
        title = self.title
        fpath = (self.exp_location / self.title / "feature_imgs").absolute()
        feature_topk = min(self.featuretopk,self.n_components)
        feature_weight = self.test_weight
        class_names = self.class_names
        Nos = range(self.class_nos)

        font = self.font
        
        def LR_graph(wlist,No):
            def node_string(count,fidx,w,No):
                nodestr = ""
                nodestr += "{} [label=< <table border=\"0\">".format(count)

                nodestr+="<tr>"
                nodestr+="<td><img src= \"{}\" /></td>".format(str(fpath / ("{}.jpg".format(fidx)))) 
                nodestr+="</tr>"


                #nodestr +="<tr><td><FONT POINT-SIZE=\"{}\"> ClassName: {} </FONT></td></tr>".format(font,classes.No2Name[No])
                nodestr +="<tr><td><FONT POINT-SIZE=\"{}\"> FeatureRank: {} </FONT></td></tr>".format(font,count)

                nodestr +="<tr><td><FONT POINT-SIZE=\"{}\"> Feature: {}, Weight: {:.3f} </FONT></td></tr>".format(font,fidx,w)

                nodestr += "</table>  >];\n" 
                return nodestr

            resstr = "digraph Tree {node [shape=box] ;rankdir = LR;\n"


            count = len(wlist)
            for k,v in wlist:
                resstr+=node_string(count,k,v,No)
                count-=1
            
            resstr += "0 [label=< <table border=\"0\">" 
            resstr += "<tr><td><FONT POINT-SIZE=\"{}\"> ClassName: {} </FONT></td></tr>" .format(font,class_names[No])
            resstr += "<tr><td><FONT POINT-SIZE=\"{}\"> Fidelity error: {:.3f} % </FONT></td></tr>" .format(font,self.reducer_err[No]*100)
            resstr += "<tr><td><FONT POINT-SIZE=\"{}\"> First {} features out of {} </FONT></td></tr>" .format(font,feature_topk,self.n_components)
            resstr += "</table>  >];\n"
            

            resstr += "}"

            return resstr

        if not os.path.exists(self.exp_location / title / "GE"):
            os.mkdir(self.exp_location / title / "GE")
                    
        print ("Generate explanations with fullset condition")

        for i in Nos:
            wlist = [(j,feature_weight[j][i]) for j in feature_weight[:,i].argsort()[-feature_topk:]]
            graph = pydotplus.graph_from_dot_data(LR_graph(wlist,i))  
            graph.write_jpg(str(self.exp_location / title / "GE" / ("{}.jpg".format(class_names[i]))))
 
    def local_explanations(self,x,model,background = 0.2,name = None,with_total = True,display_value = True):
        """Explain the prediction for a single input with per-class concept graphs.

        Writes, under ``<exp_location>/<title>/explanations/<name>``, one
        ``feature_<k>.jpg`` overlay per relevant concept and one
        ``explanation_<class>.jpg`` graph per class; each graph is also copied
        to ``explanations/all/<name>_<class>.jpg``.

        Args:
            x: input batch; only the first sample's concept maps are used.
            model: wrapper exposing ``predict`` and ``get_feature``.
            background: forwarded to ``utils.img_filter``.
            name: output sub-folder. When ``None`` the first unused integer
                name is chosen. When the folder already exists the method
                prints "Folder exists" and returns without writing anything.
            with_total: append the summed contribution and the model prediction.
            display_value: include class / similarity / weight rows per concept.
        """
        utils = self.utils
        font = self.font
        featuretopk = min(self.featuretopk,self.n_components)

        target_classes = list(range(self.class_nos))
        w = self.test_weight

        pred = model.predict(x)[0][target_classes]

        fpath = self.exp_location / self.title / "explanations"
        afpath = fpath / "all"
        # One call creates "explanations" and "explanations/all", including any
        # missing parents, and is a no-op when they already exist.
        os.makedirs(afpath, exist_ok=True)

        if name is not None:
            fpath = fpath / name
            if os.path.exists(fpath):
                print ("Folder exists")
                return
            os.mkdir(fpath)
        else:
            # Pick the first unused integer folder name.
            count = 0
            while os.path.exists(fpath / str(count)):
                count+=1
            fpath = fpath / str(count)
            os.mkdir(fpath)
            name = str(count)

        # Concept maps of the first sample (raw layer features when no reducer is set).
        if self.reducer is not None:
            h = self.reducer.transform(model.get_feature(x,self.layer_name))[0]
        else:
            h = model.get_feature(x,self.layer_name)[0]

        # Highest-weighted concepts per class, computed once and reused below.
        top_by_class = {}
        for cidx in target_classes:
            top_by_class[cidx] = w[:,cidx].argsort()[::-1][:featuretopk]
        feature_idx = list(set(np.concatenate(list(top_by_class.values())).tolist())) if top_by_class else []

        minmax = self.reducer_type == "PCA"
        for k in feature_idx:
            x1,h1 = utils.img_filter(x,np.array([h[:,:,k]]),background=background,minmax = minmax)
            x1 = utils.deprocessing(x1)
            # NOTE(review): divides by the image maximum; an all-zero image
            # would divide by zero -- confirm this cannot happen.
            x1 = x1 / x1.max()
            x1 = abs(x1)
            fig = utils.contour_img(x1[0],h1[0])
            fig.savefig(fpath / ("feature_{}.jpg".format(k)))
            # Close this specific figure rather than whatever is "current".
            plt.close(fig)

        fpath = fpath.absolute()
        gpath = self.exp_location.absolute() / self.title / 'feature_imgs'

        def node_string(cidx,fidx,score,weight):
            """Build the HTML-like table for one concept of class ``cidx``."""
            nodestr = ""
            nodestr += "<table border=\"0\">\n"
            nodestr+="<tr>"
            nodestr+="<td><img src= \"{}\" /></td>".format(str(fpath / ("feature_{}.jpg".format(fidx))))
            nodestr+="<td><img src= \"{}\" /></td>".format(str(gpath / ("{}.jpg".format(fidx))))
            nodestr+="</tr>\n"
            if display_value:
                nodestr +="<tr><td colspan=\"2\"><FONT POINT-SIZE=\"{}\"> ClassName: {}, Feature: {}</FONT></td></tr>\n".format(font,self.class_names[cidx],fidx)
                nodestr +="<tr><td colspan=\"2\"><FONT POINT-SIZE=\"{}\"> Similarity: {:.3f}, Weight: {:.3f}, Contribution: {:.3f}</FONT></td></tr> \n".format(font,score,weight,score*weight)
            nodestr += "</table>  \n"
            return nodestr

        # Per-concept score: spatial mean of the sample's concept map.
        s = h.mean(axis = (0,1))
        for cidx in target_classes:
            tw = w[:,cidx]
            tw_idx = top_by_class[cidx]

            total = 0

            resstr = "digraph Tree {node [shape=plaintext] ;\n"
            resstr += "1 [label=< \n<table border=\"0\"> \n"
            for fidx in tw_idx:
                resstr+="<tr><td>\n"

                resstr+=node_string(cidx,fidx,s[fidx],tw[fidx])
                total+=s[fidx]*tw[fidx]

                resstr+="</td></tr>\n"

            if with_total:
                resstr +="<tr><td><FONT POINT-SIZE=\"{}\"> Total Conrtibution: {:.3f}, Prediction: {:.3f}</FONT></td></tr> \n".format(font,total,pred[cidx])
            resstr += "</table> \n >];\n"
            resstr += "}"

            graph = pydotplus.graph_from_dot_data(resstr)
            graph.write_jpg(str(fpath / ("explanation_{}.jpg".format(cidx))))
            graph.write_jpg(str(afpath / ("{}_{}.jpg".format(name,cidx))))

In [17]:
def get_explainer(paras,model):
  """Load the explainer saved under ``Explainers/<title>`` or train a new one.

  Args:
    paras: dict with the keys 'classes_names', 'target_classes',
      'n_components', 'layer_name', 'title' and 'overwrite'.
    model: model wrapper; its prediction target is set to
      ``paras['target_classes']``.

  Returns:
    ``(Exp, loaders)``: the loaded or freshly trained Explainer and the data
    loaders built by ``get_loaders`` for the target classes.
  """
  classes_names = paras['classes_names']
  target_classes = paras['target_classes']
  n_components = paras['n_components']
  layer_name = paras['layer_name']
  title = paras['title']
  overwrite = paras['overwrite']

  exp_dir = Path('Explainers')/title

  if overwrite:
    # Best-effort cleanup: ignore_errors keeps the "never fail here" behaviour
    # without a bare `except` that would also swallow KeyboardInterrupt.
    shutil.rmtree(exp_dir, ignore_errors=True)

  model.set_predict_target(target_classes)
  loaders = get_loaders(target_classes)
  if os.path.exists(exp_dir):
    Exp = Explainer(title = title)
    Exp.load()
    print ('model loaded')
  else:

    # create an Explainer
    Exp = Explainer(title = title,
                    layer_name = layer_name,
                    class_names = classes_names,
                    utils = img_utils(mode = "torch"),
                    n_components = n_components,
                    reducer_type = "NMF"
                  )

    # train reducer based on target classes
    Exp.train_model(model,loaders)
    # generate features
    Exp.generate_features(model, loaders)
    # generate global explanations
    Exp.global_explanations()
    # save the explainer, use load to load it with the same title
    Exp.save()

  return Exp, loaders

In [18]:
# Rebinds `model` from the raw network to the ICE wrapper.
# NOTE(review): re-running this cell would wrap the wrapper again, and the
# `h` / `h_2d` lambdas look up `model.fc` at call time -- later cells
# re-create the raw ResNet before using them.
model = PytorchModelWrapper(model,batch_size=8,predict_target=paras['target_classes'],input_channel_first = False,model_channel_first = True)

In [19]:
# Train (or load) the explainer; `e` is the Explainer and `loaders` are the
# data loaders for paras['target_classes'].
e, loaders= get_explainer(paras,model)

Training reducer:
1/5 Featuer maps gathered.
loading complete, with size of (888, 7, 7, 512)
2/5 Reducer trained, spent 13.558522701263428 s.
3/5 Error estimated, fidelity: [0.06096106].
4/5 Weight estimator initialized.
5/5 Weight estimated.
visulizing features:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20]
loading training data
Done with class: church, 1/1
Generate explanations with fullset condition


In [20]:
def extract_ice_concepts_and_metrics(exp: Explainer, model: PytorchModelWrapper, loaders):
    """
    Return the concept activations, concept basis and concept importances of a fitted explainer.

    Args:
        exp: explainer whose reducer has already been fitted.
        model: wrapper used to read the activations of ``exp.layer_name``.
        loaders: iterable of loaders; the features of all of them are stacked
            along the first axis.

    Returns:
        ``(U, W, importance)`` where ``U`` is the reducer's transform of the
        stacked features, ``W`` is ``components_`` of the underlying reducer and
        ``importance`` is the mean absolute value of ``exp.test_weight`` along
        axis 1 (one value per concept).
    """
    print("Extracting features from loader...")
    per_loader = [model.get_feature(loader, exp.layer_name) for loader in loaders]
    stacked = np.concatenate(per_loader, axis=0)

    # Project the stacked activations into concept space.
    assert exp.reducer._is_fit, "Reducer is not fitted."
    U = exp.reducer.transform(stacked)

    # Concept basis and a per-concept importance score.
    W = exp.reducer._reducer.components_
    importance = np.abs(exp.test_weight).mean(axis=1)

    print(f"Concepts extracted: U shape = {U.shape}, W shape = {W.shape}, importance = {importance.shape}")

    return U, W, importance

In [21]:
from pathlib import Path
from tqdm import tqdm

In [22]:
# U: concept activations per spatial position, W: concept basis,
# importances: one score per concept (shapes are printed by the call).
U, W, importances = extract_ice_concepts_and_metrics(e, model, loaders)

Extracting features from loader...
Concepts extracted: U shape = (888, 7, 7, 25), W shape = (25, 512), importance = (25,)


In [23]:
# Reconstruct the layer activations from the concept representation:
# [N, H, W, K] @ [K, C] -> [N, H, W, C].
UW = U@W
# Use the notebook-wide `device` (CUDA when available, otherwise CPU) instead
# of hard-coding 'cuda', so the cell also runs on CPU-only machines.
UW = torch.tensor(UW, device=device)

In [24]:
# Channels-last -> channels-first ([N, H, W, C] -> [N, C, H, W]) so the
# classifier head `h` can average over the spatial axes (2, 3).
UW_reconstructed = UW.permute(0, 3, 1, 2)

In [25]:
# `model` was rebound to the ICE wrapper in an earlier cell; re-create the raw
# ResNet-34 so that `model.fc` in the lambdas below resolves to the linear head.
model = torchvision.models.resnet34(pretrained=True)
model = model.eval().to(device)
g = nn.Sequential(*(list(model.children())[:-2]))  # All layers except the last two
# h: global-average-pool over the spatial axes, then the final linear layer.
h = lambda x: model.fc(torch.mean(x, (2, 3)))
# h_2d: the final linear layer applied to already-pooled features.
h_2d = lambda x: model.fc(x)
# Classify the NMF-reconstructed activations and compare with the target class.
# NOTE(review): `target_class` is defined outside this chunk -- confirm it is
# the ImageNet index of the class being explained.
pred  = h(UW_reconstructed)
c= torch.argmax(pred, dim=-1)
accuracy = torch.sum(c==target_class)/len(c)
print('accuracy after NMF is:', accuracy)

accuracy after NMF is: tensor(0.9899, device='cuda:0')


### Complexity with C-Gini

In [26]:
#gini index
def compute_gini_index(concept_importance):
    """Return the Gini index of a 1-D vector of concept importances.

    Uses the sorted-rank formula
    ``sum((2*i - n - 1) * x_i) / (n * sum(x))`` with ``x`` sorted ascending and
    ``i`` running from 1 to ``n``. A value of 0 means all importances are equal.

    Args:
        concept_importance: 1-D tensor or array-like (list, NumPy array) of
            importances. Integer inputs are converted to float.

    Returns:
        The Gini index as a Python float; ``0.0`` for empty or all-zero input.

    Raises:
        ValueError: if the input is not one-dimensional.
    """
    values = torch.as_tensor(concept_importance)
    if values.dim() != 1:
        raise ValueError("Input concept_importance must be a 1D tensor.")
    # Integer tensors cannot be averaged/divided as-is.
    if not values.is_floating_point():
        values = values.float()

    sorted_importance = torch.sort(values)[0]
    n = sorted_importance.size(0)
    total = torch.sum(sorted_importance)

    # Empty or all-zero input would otherwise yield 0/0 = NaN.
    if n == 0 or total == 0:
        return 0.0

    ranks = torch.arange(1, n + 1, dtype=sorted_importance.dtype, device=sorted_importance.device)
    numerator = torch.sum((2 * ranks - n - 1) * sorted_importance)
    denominator = n * total

    gini_index = numerator / denominator
    return gini_index.item()

In [27]:
# Gini index of the per-concept importances returned by extract_ice_concepts_and_metrics.
compute_gini_index(torch.tensor(importances))

0.5211431980133057

### Faithfulness using C-Ins and C-Del 

In [28]:
def concept_insertion(U, W, importances, h, imagenet_class, device=None):
    """Accuracy curve obtained by inserting concepts from most to least important.

    Starts from an all-zero concept representation and, at step ``i``, restores
    the activations of the ``i`` most important concepts, reconstructs the
    features as ``u @ W`` and classifies them with ``h``.

    Args:
        U: concept activations, indexed as ``[N, H, W, K]``.
        W: concept basis, indexed as ``[K, C]``.
        importances: one importance value per concept (length ``K``).
        h: callable mapping channels-first features ``[N, C, H, W]`` to logits.
        imagenet_class: class index the predictions are compared against.
        device: device to run on. Defaults to CUDA when available, else CPU.

    Returns:
        dict mapping the number of inserted concepts (``0..K``) to accuracy.
    """
    if device is None:
        # Do not assume a GPU: fall back to the CPU when CUDA is unavailable.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    U = torch.as_tensor(U, dtype=torch.float32, device=device)
    W = torch.as_tensor(W, dtype=torch.float32, device=device)
    importances = torch.as_tensor(importances, device=device)
    sorted_indices = torch.argsort(importances, descending=True)

    def accuracy(u):
        # Pure evaluation: no autograd graph is needed.
        with torch.no_grad():
            logits = h((u @ W).permute(0, 3, 1, 2))
        return (logits.argmax(dim=1) == imagenet_class).float().mean().item()

    u_insert = torch.zeros_like(U)
    results = {0: accuracy(u_insert)}

    for i in range(1, W.shape[0] + 1):
        # Earlier concepts are already in place; only the i-th one is new.
        newest = sorted_indices[i - 1]
        u_insert[..., newest] = U[..., newest]
        results[i] = accuracy(u_insert)
    return results

In [29]:
def concept_deletion(U, W, importances, h, imagenet_class, device=None):
    """Accuracy curve obtained by deleting concepts from most to least important.

    Starts from the full concept representation and, at step ``i``, zeroes the
    activations of the ``i`` most important concepts, reconstructs the features
    as ``u @ W`` and classifies them with ``h``.

    Args:
        U: concept activations, indexed as ``[N, H, W, K]``.
        W: concept basis, indexed as ``[K, C]``.
        importances: one importance value per concept (length ``K``).
        h: callable mapping channels-first features ``[N, C, H, W]`` to logits.
        imagenet_class: class index the predictions are compared against.
        device: device to run on. Defaults to CUDA when available, else CPU.

    Returns:
        dict mapping the number of deleted concepts (``0..K``) to accuracy.
    """
    if device is None:
        # Do not assume a GPU: fall back to the CPU when CUDA is unavailable.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    U = torch.as_tensor(U, dtype=torch.float32, device=device)
    W = torch.as_tensor(W, dtype=torch.float32, device=device)
    importances = torch.as_tensor(importances, device=device)
    sorted_indices = torch.argsort(importances, descending=True)

    def accuracy(u):
        # Pure evaluation: no autograd graph is needed.
        with torch.no_grad():
            logits = h((u @ W).permute(0, 3, 1, 2))
        return (logits.argmax(dim=1) == imagenet_class).float().mean().item()

    results = {0: accuracy(U)}

    # Work on a copy so the caller's data is never modified.
    u_mod = U.clone()
    for i in range(1, W.shape[0] + 1):
        # Earlier concepts are already zeroed; only the i-th one is new.
        u_mod[..., sorted_indices[i - 1]] = 0
        results[i] = accuracy(u_mod)
    return results


In [30]:
def compute_insertion_auc(insertion_scores):
    """Normalised area under an insertion curve.

    Args:
        insertion_scores: dict mapping the number of inserted concepts to the
            accuracy at that step. Keys may be given in any order.

    Returns:
        The trapezoidal area under the curve divided by the step range, i.e.
        the mean height of the curve. For a single point, that point's value.
    """
    # Sort by step so the result does not depend on dict insertion order.
    steps = sorted(insertion_scores)
    x = np.array(steps)
    y = np.array([insertion_scores[s] for s in steps])

    span = x[-1] - x[0]
    if span == 0:
        # A single step has no width; its value is the curve's mean height.
        return y[0]

    # NumPy 2.0 renamed trapz to trapezoid; fall back for older versions.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return trapezoid(y, x) / span

def compute_deletion_score(deletion_scores):
    """Mean accuracy drop relative to the baseline of a deletion curve.

    Args:
        deletion_scores: dict mapping the number of deleted concepts to the
            accuracy at that step. Keys may be given in any order.

    Returns:
        ``mean(y[0] - y)`` over all steps, where ``y[0]`` is the accuracy at
        the smallest step. The baseline's own zero drop is part of the average.
    """
    # Sort by step so the baseline is the smallest step, not the first inserted key.
    steps = sorted(deletion_scores)
    y = np.array([deletion_scores[s] for s in steps])
    return (y[0] - y).mean()


In [31]:
# Re-create the raw ResNet-34 and the classifier-head helpers; this repeats the
# setup from the top of the notebook so `model.fc` is available again after
# `model` was rebound to the ICE wrapper.
model = torchvision.models.resnet34(pretrained=True)
model = model.eval().to(device)
g = nn.Sequential(*(list(model.children())[:-2]))  # All layers except the last two
# h: global-average-pool over the spatial axes, then the final linear layer.
h = lambda x: model.fc(torch.mean(x, (2, 3)))
# h_2d: the final linear layer applied to already-pooled features.
h_2d = lambda x: model.fc(x)

In [32]:
# Accuracy curves when concepts are inserted / deleted in order of importance.
insertion_scores = concept_insertion(U, W, importances, h, imagenet_class=target_class)
deletion_scores = concept_deletion(U, W, importances, h, imagenet_class=target_class)

print("Insertion AUC (C_Ins):", compute_insertion_auc(insertion_scores))
# NOTE(review): compute_deletion_score returns the mean accuracy drop from the
# baseline, not an area under the curve, although the label says "AUC".
print("Deletion AUC (C_Del):", compute_deletion_score(deletion_scores))


Insertion AUC (C_Ins): 0.8861937177181244
Deletion AUC (C_Del): 0.5872314721614552


In [38]:
import torch
import torch.nn.functional as F
import numpy as np
from scipy.stats import spearmanr
from scipy.optimize import linear_sum_assignment

class ICEEvaluator:
    """Collects activations and concept representations and computes ICE metrics.

    On construction every batch of every loader is pushed through the model;
    the images, the layer activations ``A``, the concept activations ``U`` and
    the reconstruction ``UW = U @ W`` are kept on ``device``.

    Args:
        model: raw network exposing an ``fc`` layer (used as the classifier head).
        exp: fitted explainer providing ``reducer`` and ``layer_name``.
        loaders: iterable of loaders yielding ``(inputs, labels)`` batches.
        device: device the tensors are moved to.
    """

    def __init__(self, model, exp, loaders, device='cuda'):
        self.model = model
        self.exp = exp
        self.loaders = loaders
        self.device = device

        self.h_2d = lambda x: self.model.fc(x)
        self.W = torch.tensor(self.exp.reducer._reducer.components_).float().to(self.device)  # [K, C]

        self.images, self.A, self.U, self.UW = self._extract_all()

    def _extract_all(self):
        """Return ``(images, A, U, UW)`` gathered over all loaders."""
        # Build the wrapper once instead of once per batch.
        # Relies on the notebook-level `paras` dict for the prediction targets.
        wrapper = PytorchModelWrapper(self.model,batch_size=8,predict_target=paras['target_classes'],input_channel_first = False,model_channel_first = True)

        features = []
        images = []
        with torch.no_grad():
            for loader in self.loaders:
                for x, _ in loader:
                    images.append(x)
                    features.append(wrapper.get_feature(x.to(self.device), self.exp.layer_name))  # [N, H, W, C]
        A = torch.cat([torch.as_tensor(f) for f in features], dim=0).to(self.device)  # [N, H, W, C]
        images = torch.cat(images, dim=0).to(self.device)

        # The reducer works on NumPy arrays, so round-trip through the CPU.
        U = torch.tensor(self.exp.reducer.transform(A.detach().cpu().numpy())).float().to(self.device)  # [N, H, W, K]
        W = self.W  # [K, C]
        UW = torch.einsum('nhwk,kc->nhwc', U, W)  # [N, H, W, C]

        return images, A, U, UW

    def evaluate_projection(self):
        """Compare original and reconstructed activations.

        Returns:
            ``(mse, kl)``: mean squared error between the spatially averaged
            activations and their reconstruction, and the KL divergence
            (``batchmean``) between the class distributions the head predicts
            from each of them.
        """
        with torch.no_grad():
            A_avg = torch.mean(self.A, dim=(1, 2))
            UW_avg = torch.mean(self.UW, dim=(1, 2))

            logits_A = self.h_2d(A_avg)
            logits_UW = self.h_2d(UW_avg)

            mse = F.mse_loss(UW_avg, A_avg)
            kl = F.kl_div(F.log_softmax(logits_UW, dim=-1), F.softmax(logits_A, dim=-1), reduction='batchmean')
        return mse.item(), kl.item()

    def compute_sparsity(self):
        """Return the mean fraction of non-zero concept activations per position.

        Note that a higher value means a denser (less sparse) representation.
        """
        k = self.U.shape[-1]
        # reshape (unlike view) also works when U is not contiguous.
        U_flat = self.U.reshape(-1, k)
        non_zero = (U_flat != 0).float().sum(dim=1)
        sparsity_scores = non_zero / k
        return sparsity_scores.mean().item()

In [35]:
# `model` must be the raw ResNet here: ICEEvaluator reads `model.fc` and wraps
# the model itself for feature extraction.
ICEeval = ICEEvaluator( model, e, loaders)

### NMF projection error

In [36]:
# MSE between pooled activations and their NMF reconstruction, and the KL
# divergence between the class distributions predicted from each.
mse_loss, kl_loss = ICEeval.evaluate_projection()
mse_loss, kl_loss

(0.2321694791316986, 0.3566300868988037)

### Representation sparsity (U_Spr)

In [37]:
# Mean fraction of non-zero concept activations per spatial position
# (higher = denser representation).
rep_sparsity = ICEeval.compute_sparsity()
rep_sparsity

0.5597802996635437

### Stability (C-Stab)

In [43]:
def extract_w_list_ice(loaders, model, layer_name, n_components=10, k_folds=5, random_state=42):
    """
    Fit NMF on k different subsets of the layer activations and return each concept basis.

    The activations of all loaders are flattened to one row per spatial
    position, so the K-fold split is over positions rather than over images.

    Args:
        loaders: iterable of data loaders passed to ``get_feature``.
        model: raw model; it is wrapped in a PytorchModelWrapper internally.
        layer_name: name of the layer whose activations are decomposed.
        n_components: number of NMF concepts fitted per fold.
        k_folds: number of K-fold splits (one NMF fit per split).
        random_state: seed for the K-fold shuffle and for NMF, so repeated
            runs give the same bases.

    Returns:
        List of ``k_folds`` arrays of shape ``[n_components, C]``.
    """
    from sklearn.model_selection import KFold
    from sklearn.decomposition import NMF

    # Step 1: Collect all features into one array
    # Relies on the notebook-level `paras` dict for the prediction targets.
    wrapper = PytorchModelWrapper(model,batch_size=8,predict_target=paras['target_classes'],input_channel_first = False,model_channel_first = True)
    all_features = np.concatenate(
        [wrapper.get_feature(loader, layer_name) for loader in loaders],  # each [N, H, W, C]
        axis=0,
    )
    n_channels = all_features.shape[-1]
    flat_feats = all_features.reshape(-1, n_channels)

    # Step 2: Fit one NMF per training split
    kf = KFold(n_splits=k_folds, shuffle=True, random_state=random_state)
    w_list = []

    for train_idx, _ in kf.split(flat_feats):
        nmf = NMF(n_components=n_components, init='nndsvda', max_iter=500, random_state=random_state)
        nmf.fit(flat_feats[train_idx])
        w_list.append(nmf.components_)  # Shape: [K, D]

    return w_list


In [42]:
import torch
import torch.nn.functional as F
import numpy as np
from scipy.optimize import linear_sum_assignment

def compute_stability(w_list):
    """
    Mean matched cosine similarity between concept bases from different runs.

    For every pair of bases the rows are L2-normalised, the cosine-similarity
    matrix is computed, rows are matched one-to-one with the Hungarian
    algorithm (maximising similarity) and the matched similarities are
    averaged. The result is the mean over all pairs.

    Args:
        w_list: list of concept bases, each array-like of shape ``[K, D]``.

    Returns:
        Mean similarity over all pairs, or ``0.0`` when fewer than two bases
        are given.
    """
    # The matrices are tiny (K x K), so everything runs on the CPU; this also
    # removes the hard dependency on a CUDA device.
    bases = []
    for w in w_list:
        basis = torch.as_tensor(w).detach().cpu()
        if not basis.is_floating_point():
            basis = basis.float()
        bases.append(F.normalize(basis, p=2, dim=1))

    similarities = []
    for i in range(len(bases)):
        for j in range(i + 1, len(bases)):
            sim_matrix = (bases[i] @ bases[j].T).numpy()  # shape [K, K]
            # Hungarian matching minimises cost, hence 1 - similarity.
            row_ind, col_ind = linear_sum_assignment(1 - sim_matrix)
            similarities.append(sim_matrix[row_ind, col_ind].mean().item())

    return np.mean(similarities) if similarities else 0.0


In [41]:
# Fit NMF on 5 folds and measure how similar the resulting concept bases are.
# NOTE(review): n_components is not passed, so extract_w_list_ice uses its
# default of 10, while the explainer above reported 25 concepts -- confirm
# this is intended.
w_list = extract_w_list_ice(loaders, model, paras['layer_name'], k_folds=5)
stability_score = compute_stability(w_list)
print("Stability Score:", stability_score)

Stability Score: 0.9802432954311371
