In [27]:
import sys
sys.path.append("../src")

from dataset import ImageDataset, TRANSFORM, LABEL_ID_INV_DIC
from models import ResnetPredictor

import os
import os.path as osp
import hashlib
import shutil
import random 
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

import torch

import sklearn.cluster as cluster
import sklearn.metrics.pairwise as metrics
from skimage.segmentation import slic

DATA_DIR = "../imagenette"
DATA_FILE = "noisy_imagenette.csv"

MODEL_DIR = "../models"
MODEL_FILE = "resnet_state_dict.pth"

CONCEPT_DIR = "../concepts"

In [39]:
def delete_subdirectories(folder_path):
    if not os.path.isdir(folder_path):
        print("Path does not lead to a valid directory.")
        return
    
    for item in os.listdir(folder_path):
        item_path = os.path.join(folder_path, item)
        if os.path.isdir(item_path):
            shutil.rmtree(item_path)
            print(f"Directory deleted : {item_path}")

folder_path = CONCEPT_DIR
delete_subdirectories(folder_path)

Directory deleted : ../concepts\n03888257


In [34]:
def save_images(addresses, images):
    if not isinstance(addresses, list):
        image_addresses = []
        for i, image in enumerate(images):
            image_name = "0" * (3 - int(np.log10(i + 1))) + str(i + 1) + ".png"
            image_addresses.append(os.path.join(addresses, image_name))
            addresses = image_addresses
    assert len(addresses) == len(images), "Invalid number of addresses"
    for address, image in zip(addresses, images):
        with open(address, "wb") as f:
            Image.fromarray(image).save(f, format="PNG")

In [40]:
class ConceptDiscovery:

    def __init__(
        self,
        data_dir,
        data_file,
        label,
        image_shape,
        average_image_value,
        model,
        bottleneck,
        concept_dir,
        max_imgs=40,
        min_imgs=20,
    ):
        self.data_dir = data_dir
        self.data_file = data_file
        self.concept_dir = concept_dir
        if not os.path.exists(self.concept_dir):
            os.makedirs(self.concept_dir)

        self.label = label

        self.image_shape = image_shape
        self.average_image_value = average_image_value

        self.model = model
        self.model.eval()
        self.bottleneck = bottleneck

        self.max_imgs = max_imgs
        self.min_imgs = min_imgs

    def extract_discovery_images(self, n_image):
        data_df = pd.read_csv(osp.join(self.data_dir, self.data_file))

        label_data_df = data_df[data_df.noisy_labels_0 == self.label]
        label_data_df = label_data_df[:n_image]

        label_data_df.to_csv(
            osp.join(self.data_dir, self.label + "_" + str(n_image) + "_discovery_images.csv"),
            index=False,
        )

    def create_patches(self, param_dict):

        if param_dict is None:
            param_dict = {}

        discovery_image_file = param_dict.pop("discovery_image_file", None)
        n_image = param_dict.pop("n_image", 50)

        n_segment = param_dict.pop("n_segment", [15, 50, 80])
        compactness = param_dict.pop("compactness", 20)
        sigma = param_dict.pop("sigma", 1.0)

        if discovery_image_file is None:
            discovery_image_file = self.label + "_" + str(n_image) + "_discovery_images.csv"
            if not osp.exists(osp.join(self.data_dir, discovery_image_file)):
                self.extract_discovery_images(n_image)
        else:
            if not osp.exists(osp.join(self.data_dir, discovery_image_file)):
                raise ValueError("File does not exist.")
            
        discovery_image_dataset = ImageDataset(
            self.data_dir, discovery_image_file, TRANSFORM, False
        )
        self.discovery_image_dataset = discovery_image_dataset

        # Generate a unique identifier for the parameter set
        param_str = f"{n_segment}_{compactness}_{sigma}"
        param_hash = hashlib.md5(param_str.encode()).hexdigest()
        subfolder = osp.join(self.data_dir, self.label + "_" + param_hash)
        if not osp.exists(subfolder):
            os.makedirs(subfolder)
        patch_param_file = osp.join(subfolder, "patch_params.npy")
        dataset_file = osp.join(subfolder, "dataset.npy")
        image_numbers_file = osp.join(subfolder, "image_numbers.npy")
        patches_file = osp.join(subfolder, "patches.npy")

        # Check if existing patches and parameters match
        if (
            osp.exists(patch_param_file)
            and osp.exists(dataset_file)
            and osp.exists(image_numbers_file)
            and osp.exists(patches_file)
        ):
            saved_params = np.load(patch_param_file, allow_pickle=True).item()
            if (
                saved_params["n_segment"] == n_segment
                and saved_params["compactness"] == compactness
                and saved_params["sigma"] == sigma
            ):
                self.dataset = np.load(dataset_file)
                self.image_numbers = np.load(image_numbers_file)
                self.patches = np.load(patches_file)
                print("Loaded existing patches and parameters from", subfolder)
                return

        # Extract patches if non existent
        dataset, image_numbers, patches = [], [], []

        for i in tqdm(range(len(discovery_image_dataset)), desc="Superpixels and Patches"):
            img, _ = discovery_image_dataset[i]
            img = np.transpose(img.numpy(), (1, 2, 0))

            image_superpixels, image_patches = self._return_superpixels(
                img, n_segment, compactness, sigma
            )

            for superpixel, patch in zip(image_superpixels, image_patches):
                dataset.append(superpixel)
                patches.append(patch)
                image_numbers.append(i)

        self.dataset, self.image_numbers, self.patches = (
            np.array(dataset),
            np.array(image_numbers),
            np.array(patches),
        )

        # Save patches and parameters to files with unique identifier
        np.save(
            patch_param_file, {"n_segment": n_segment, "compactness": compactness, "sigma": sigma}
        )
        np.save(dataset_file, self.dataset)
        np.save(image_numbers_file, self.image_numbers)
        np.save(patches_file, self.patches)

    def _return_superpixels(self, image, n_segment=[15, 50, 80], compactness=20, sigma=1.0):

        if not isinstance(n_segment, list):
            n_segments = [n_segment]
        else:
            n_segments = n_segment

        unique_masks = []
        for i, n_segment in enumerate(n_segments):
            segments = slic(
                image,
                n_segments=n_segment,
                compactness=compactness,
                sigma=sigma,
            )

            param_masks = []
            for s in range(segments.max()):
                mask = (segments == s).astype(float)
                unique = True
                # negligiblity check
                if np.mean(mask) < 0.001:
                    unique = False
                # similarity check
                for seen_mask in unique_masks:
                    jaccard = np.sum(seen_mask * mask) / np.sum((seen_mask + mask) > 0)
                    if jaccard > 0.5:
                        unique = False
                        break
                if unique:
                    param_masks.append(mask)

            unique_masks.extend(param_masks)

        superpixels, patches = [], []
        while unique_masks:
            superpixel, patch = self._extract_patch(image, unique_masks.pop())
            superpixels.append(superpixel)
            patches.append(patch)

        return superpixels, patches

    def _extract_patch(self, image, mask):

        mask_expanded = np.expand_dims(mask, -1)
        patch = mask_expanded * image + (1 - mask_expanded) * float(self.average_image_value) / 255

        ones = np.where(mask == 1)
        h1, h2, w1, w2 = ones[0].min(), ones[0].max(), ones[1].min(), ones[1].max()
        image = Image.fromarray((patch[h1:h2, w1:w2] * 255).astype(np.uint8))
        image_resized = (
            np.array(image.resize(self.image_shape[:2], Image.BICUBIC)).astype(float) / 255
        )

        return image_resized, patch

    def discover_concepts(self, param_dict=None):

        if param_dict is None:
            param_dict = {}

        activation_file = param_dict.pop("activation_file", None)
        batch_size = param_dict.pop("batch_size", 64)

        n_clusters = param_dict.pop("n_clusters", 25)
        save = param_dict.pop("save", True)

        # Activations
        activations = self.load_activations(activation_file)
        if activations is None or self.bottleneck not in activations.keys():
            self.bottleneck_activations = self._patch_activations(batch_size)
            self.save_activations(activation_file)
        else:
            self.bottleneck_activations = activations[self.bottleneck]       

        # Concepts
        self.bottleneck_dic = {}

        self.bottleneck_dic["label"], self.bottleneck_dic["cost"], centers = self._cluster(
            self.bottleneck_activations, n_clusters
        )

        concept_number = 0
        self.bottleneck_dic["concepts"] = []

        for i in tqdm(range(self.bottleneck_dic["label"].max() + 1), desc="Concept selection"):
            label_idxs = np.where(self.bottleneck_dic["label"] == i)[0]

            if len(label_idxs) > self.min_imgs:
                concept_costs = self.bottleneck_dic["cost"][label_idxs]
                concept_idxs = label_idxs[np.argsort(concept_costs)[: self.max_imgs]]
                concept_image_numbers = set(self.image_numbers[label_idxs])
                discovery_size = len(self.discovery_image_dataset)

                highly_common_concept = len(concept_image_numbers) > 0.5 * len(label_idxs)

                mildly_common_concept = len(concept_image_numbers) > 0.25 * len(label_idxs)
                mildly_populated_concept = len(concept_image_numbers) > 0.25 * discovery_size
                cond2 = mildly_populated_concept and mildly_common_concept

                non_common_concept = len(concept_image_numbers) > 0.1 * len(label_idxs)
                highly_populated_concept = len(concept_image_numbers) > 0.5 * discovery_size
                cond3 = non_common_concept and highly_populated_concept

                if highly_common_concept or cond2 or cond3:
                    concept_number += 1
                    concept = "concept{}".format(concept_number)
                    self.bottleneck_dic["concepts"].append(concept)
                    self.bottleneck_dic[concept] = {
                        "images": self.dataset[concept_idxs],
                        "patches": self.patches[concept_idxs],
                        "image_numbers": self.image_numbers[concept_idxs],
                    }
                    self.bottleneck_dic[concept + "_center"] = centers[i]

        self.bottleneck_dic.pop("label", None)
        self.bottleneck_dic.pop("cost", None)

        self.save_concepts()

    def _patch_activations(self, batch_size=64, channel_mean=True):

        if isinstance(self.model, ResnetPredictor):
            bottleneck_name = "resnet." + self.bottleneck

        bottleneck_layer = None
        for name, module in self.model.named_modules():
            if name == bottleneck_name:
                bottleneck_layer = module
                break

        activation = []

        def hook(module, input, output):
            activation.append(output.detach().cpu().numpy())

        hook_handle = bottleneck_layer.register_forward_hook(hook)
        for i in tqdm(range(int(self.dataset.shape[0] / batch_size) + 1), desc="Activations"):
            batch_input = self.dataset[i * batch_size : (i + 1) * batch_size]
            batch_input = torch.tensor(np.transpose(batch_input, (0, 3, 1, 2))).float()
            _ = self.model(batch_input)

        hook_handle.remove()

        activation = np.concatenate(activation, axis=0)

        if channel_mean and len(activation.shape) > 3:
            activation = np.mean(activation, axis=(1, 2))
        else:
            activation = np.reshape(activation, [activation.shape[0], -1])

        return activation

    def _cluster(self, acts, n_clusters=25):

        centers = None
        km = cluster.KMeans(n_clusters)
        d = km.fit(acts)
        centers = km.cluster_centers_

        d = np.linalg.norm(np.expand_dims(acts, 1) - np.expand_dims(centers, 0), ord=2, axis=-1)
        asg, cost = np.argmin(d, -1), np.min(d, -1)

        if centers is None:  ## If clustering returned cluster centers, use medoids
            centers = np.zeros((asg.max() + 1, acts.shape[1]))
            cost = np.zeros(len(acts))
            for cluster_label in range(asg.max() + 1):
                cluster_idxs = np.where(asg == cluster_label)[0]
                cluster_points = acts[cluster_idxs]
                pw_distances = metrics.euclidean_distances(cluster_points)
                centers[cluster_label] = cluster_points[np.argmin(np.sum(pw_distances, -1))]
                cost[cluster_idxs] = np.linalg.norm(
                    acts[cluster_idxs] - np.expand_dims(centers[cluster_label], 0), ord=2, axis=-1
                )
        return asg, cost, centers

    # def save_activations(self, file_name=None):
    #     if file_name is None:
    #         file_name = "patch_activations.pickle"

    #     activation_dict = {self.bottleneck: self.bottleneck_activations}

    #     activations_dir = osp.join(
    #         self.concept_dir,
    #         self.label,
    #         type(self.model).__name__,
    #         self.bottleneck,
    #         "activations",
    #     )
    #     os.makedirs(activations_dir, exist_ok=True)

    #     file_path = osp.join(activations_dir, file_name)

    #     with open(file_path, "wb") as f:
    #         pickle.dump(activation_dict, f)

    # def load_activations(self, file_name=None):
    #     if file_name is None:
    #         file_name = "patch_activations.pickle"

    #     file_path = osp.join(
    #         self.concept_dir,
    #         self.label,
    #         type(self.model).__name__,
    #         self.bottleneck,
    #         "activations",
    #         file_name,
    #     )

    #     if osp.exists(file_path):
    #         with open(file_path, "rb") as f:
    #             activation_dict = pickle.load(f)
    #         bottleneck_activations = activation_dict
    #     else:
    #         bottleneck_activations = None

    #     return bottleneck_activations

    def save_activations(self, file_name=None):
        if file_name is None:
            file_name = "patch_activations.pickle"

        activations_dir = osp.join(
            self.concept_dir,
            self.label,
            type(self.model).__name__,
            "activations",
        )
        os.makedirs(activations_dir, exist_ok=True)

        file_path = osp.join(activations_dir, file_name)
        if osp.exists(file_path):
            with open(file_path, "rb") as f:
                activation_dict = pickle.load(f)
        else:
            activation_dict = {}

        activation_dict[self.bottleneck] = self.bottleneck_activations
        with open(file_path, "wb") as f:
            pickle.dump(activation_dict, f)
    
    def load_activations(self, file_name=None):
        if file_name is None:
            file_name = "patch_activations.pickle"

        file_path = osp.join(
            self.concept_dir,
            self.label,
            type(self.model).__name__,
            "activations",
            file_name,
        )

        if osp.exists(file_path):
            with open(file_path, "rb") as f:
                activation_dict = pickle.load(f)
            return activation_dict
        else:
            return None

    def save_concepts(self):
        for concept in self.bottleneck_dic["concepts"]:
            patches_dir = osp.join(
                self.concept_dir,
                self.label,
                type(self.model).__name__,
                self.bottleneck,
                concept + "_patches",
            )
            images_dir = osp.join(
                self.concept_dir,
                self.label,
                type(self.model).__name__,
                self.bottleneck,
                concept,
            )
            os.makedirs(patches_dir, exist_ok=True)
            os.makedirs(images_dir, exist_ok=True)

            patches = (np.clip(self.bottleneck_dic[concept]["patches"], 0, 1) * 256).astype(
                np.uint8
            )
            images = (np.clip(self.bottleneck_dic[concept]["images"], 0, 1) * 256).astype(np.uint8)
            image_numbers = self.bottleneck_dic[concept]["image_numbers"]

            image_addresses, patch_addresses = [], []
            for i in range(len(images)):
                image_name = "0" * int(np.ceil(2 - np.log10(i + 1))) + "{}_{}".format(
                    i + 1, image_numbers[i]
                )
                patch_addresses.append(os.path.join(patches_dir, image_name + ".png"))
                image_addresses.append(os.path.join(images_dir, image_name + ".png"))

            save_images(patch_addresses, patches)
            save_images(image_addresses, images)

In [41]:
train_dataset = ImageDataset(DATA_DIR, DATA_FILE, TRANSFORM, False)
toy_image, toy_label_vec = train_dataset[5000]
toy_image = np.transpose(toy_image.numpy(), (1, 2, 0))
toy_label = LABEL_ID_INV_DIC[torch.argmax(toy_label_vec).item()]
toy_label

'n03888257'

In [42]:
model_state_dict = torch.load(osp.join(MODEL_DIR, MODEL_FILE))

loaded_model = ResnetPredictor()
loaded_model.load_state_dict(model_state_dict)
loaded_model.eval()

type(loaded_model).__name__

'ResnetPredictor'

In [43]:
concept_discovery = ConceptDiscovery(
    DATA_DIR,
    DATA_FILE,
    toy_label,
    image_shape=(320, 320, 3),
    average_image_value=115,
    model=loaded_model,
    bottleneck="avgpool",
    concept_dir=CONCEPT_DIR,
)

In [44]:
param_dict = {
    "discovery_image_file": None,
    "n_image": 50,
    "n_segment": [15, 50, 80],
    "compactness": 20,
    "sigma": 1.0,
}

concept_discovery.create_patches(param_dict=param_dict)

Loaded existing patches and parameters from ../imagenette\n03888257_ad3617c91c903ff915e3913972217c63


In [45]:
param_dict = {
    "load_activations": True,
    "activation_file": None,
    "batch_size": 64,
    "n_clusters": 25,
    "save": True
}

concept_discovery.discover_concepts(param_dict=param_dict)

Activations: 100%|██████████| 89/89 [03:54<00:00,  2.64s/it]
Concept selection:   0%|          | 0/25 [00:00<?, ?it/s]


AttributeError: 'ConceptDiscovery' object has no attribute 'discovery_image_dataset'

In [7]:
concept_discovery.bottleneck_activations

AttributeError: 'ConceptDiscovery' object has no attribute 'bottleneck_activations'

In [13]:
concept_discovery.save_activations(None)

In [3]:
concept_discovery.bottleneck_dic
concept_discovery.bottleneck_dic.keys()
concept_discovery.bottleneck_dic[concept_discovery.bottleneck_dic["concepts"][0]]
concept_discovery.bottleneck_dic[concept_discovery.bottleneck_dic["concepts"][0]].keys()
concept_discovery.bottleneck_dic[concept_discovery.bottleneck_dic["concepts"][0]]["images"]

NameError: name 'concept_discovery' is not defined

In [None]:
del concept_discovery.dataset  
del concept_discovery.image_numbers
del concept_discovery.patches