# ImageSI: Semantic Interaction for Deep Learning Image Projections

### Instructions:

1. Set the `imgFolder` variable in the settings cell (marked `# ChangeHere`) to the path of your own image folder.
2. Click "Kernel" in the menu bar, then select "Restart Kernel and Run All".
3. Navigate to "ChangeHere" in the notebook and customize the settings.
4. Utilize the interactive plots located near the bottom of the notebook.
5. Interaction Types:
   - **Projection Interaction:** Drag images within the projection plot to create user-defined clusters. Then, click "Fine-tune Model" to refine the convolutional layers and view the updated projection plot.

Reference: 
1. Self, J. Z., House, L., Leman, S., & North, C. (2015). Andromeda: Observation-level and parametric interaction for exploratory data analysis. Technical report.

2. Han, H., Faust, R., Keith Norambuena, B. F., Lin, J., Li, S., & North, C. (2023). Explainable interactive projections of images. Machine Vision and Applications, 34(6), 100.

3. Bian, Y., & North, C. (2021, April). Deepsi: Interactive deep learning for semantic interaction. In 26th International Conference on Intelligent User Interfaces (pp. 197-207).

In [50]:
%matplotlib notebook
#%matplotlib inline

### interactive notebook format is required for the interactive plot
import numpy as np
import pandas as pd
import math
from math import isnan
import random
import os
from os import listdir
from os.path import isfile, join
import cv2
from skimage.transform import resize
import csv
from functools import partial
from tqdm import tqdm

from sklearn.decomposition import PCA
from sklearn.manifold import MDS
import sklearn.metrics.pairwise
from sklearn.metrics import silhouette_score

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from matplotlib.patches import FancyBboxPatch
from matplotlib.widgets import Slider, Button

import ipywidgets as widgets
from ipywidgets import interact, Layout, Button, GridBox, ButtonStyle
from IPython.display import display, clear_output, Image
from torch.utils.data import TensorDataset

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor
from torchvision import transforms
from torchvision.models import resnet18, resnet34, resnet101, resnet152

import copy
import time
from PIL import Image

from torchvision import models
import shutil
import torchvision
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torchvision.models import resnet18
import torch.nn as nn
import torch.nn.functional as F
import zipfile
from sklearn.preprocessing import LabelEncoder
from scipy.spatial.distance import squareform


import torch
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import widgets, Layout, HBox, VBox, Label
from IPython.display import display
from torchvision import transforms
from torch.utils.data import DataLoader


from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

import warnings
warnings.filterwarnings('ignore')
from PIL import Image
import numpy as np
import requests
import cv2
import json
import torch
from pytorch_grad_cam import DeepFeatureFactorization
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image, deprocess_image
from pytorch_grad_cam.utils.image import show_factorization_on_image
from transformers import ResNetForImageClassification

from sklearn.metrics.pairwise import euclidean_distances, manhattan_distances
from sklearn.metrics import silhouette_score
from sklearn.utils import check_random_state

from sklearn.metrics.pairwise import pairwise_distances

from scipy.spatial.distance import pdist, squareform
from sklearn.preprocessing import StandardScaler

# Load and Pre-process Data

Change **imgFolder** to load an image dataset.

The image folder structure should be arranged in standard pytorch data loading format i.e: <br>
* root/dog/xxx.png
* root/dog/xxy.png
* root/dog/xxz.png


* root/cat/123.png
* root/cat/nsdf3.png
* root/cat/asd932_.png

In [51]:
## settings
# Notebook-wide configuration read by the cells below.
imgFolder = '/Users/jiayuelin/InfoVis/Datasets/Animals_Mouth_EXP' # ChangeHere: root folder of the image dataset
sampleSizePerCat = 10  # number of images sampled from each category subfolder
imgDisplaySize = 0.25 # default image display size; can be adjusted interactively in the UI
total_img = 20  # maximum total number of images to display in the UI
folderName = True  # True when imgFolder has one subfolder per category (e.g. the fish dataset)
# NOTE(review): presumably selects loading feature weights from a saved CSV
# (get_weights has a parameter of the same name) — confirm where it is consumed.
load_weights_from_file = False

In [52]:
class FilenameDataset(Dataset):
    """Dataset that loads images from a collection of file paths.

    Each item is the image at ``files[idx]`` converted to RGB and passed
    through ``transform`` when one was supplied; otherwise a default pipeline
    is applied (resize to 224x224, convert to tensor, normalize with the
    per-channel mean/std constants listed in ``_get_default_transform``).

    @parameters:
        files[iterable of str]: image file paths; copied into a list
        transform[callable or None]: optional self-defined transform applied to the RGB PIL image
    """

    def __init__(self, files, transform=None):
        self.files = list(files)
        self.transform = transform
        # Built lazily on first use so the default pipeline is created once,
        # not once per item.
        self._default_transform = None

    def __len__(self):
        return len(self.files)

    def _get_default_transform(self):
        """Return the default preprocessing pipeline, creating it on first use."""
        if self._default_transform is None:
            self._default_transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])
        return self._default_transform

    def __getitem__(self, idx):
        # Image.open is lazy and keeps the file open; convert() decodes the
        # pixels into a new image, so the handle can be closed right after.
        with Image.open(self.files[idx]) as img:
            sample = img.convert('RGB')
        if self.transform:  # whether self-defined transform
            return self.transform(sample)
        return self._get_default_transform()(sample)
    
def get_path(imgFolder, sampleSizePerCat, folderName, max_img_num=100):
    """
    Walk ``imgFolder`` and collect image paths keyed by their image index.

    Only files whose lower-cased name ends with 'jpg', 'jpeg' or 'png' are
    collected. Scanning of a directory stops once ``sampleSizePerCat`` images
    were taken from it; the whole walk stops once ``max_img_num`` images were
    taken in total. Both limits are checked after every file name, including
    names that were not images.

    @parameters:
        imgFolder[str]: path for the image folder. e.g. "/root/"
        sampleSizePerCat[int]: the number of samples from each subfolder. e.g. 10 images each from
                          "root/cat/" and "root/dog/"
        folderName[boolean]: if structure is: root/cat/123.png then True, if root/123.png then False.
        max_img_num[int]: upper limit for TOTAL image samples
    @return[dict]: {image index (from extractIdx_pattern): full path of the image}
    """
    image_suffixes = ('jpg', 'jpeg', 'png')
    collected = {}
    n_total = 0
    for folder, _subfolders, names in os.walk(imgFolder):
        n_in_folder = 0
        for name in names:
            if name.lower().endswith(image_suffixes):
                full_path = f"{folder}/{name}"
                collected[extractIdx_pattern(full_path, folderName)] = full_path
                n_in_folder += 1
                n_total += 1
            # Per-folder quota reached, or global cap reached: leave this folder.
            if n_in_folder == sampleSizePerCat or n_total == max_img_num:
                break
        if n_total == max_img_num:
            break
    return collected


def extractIdx_pattern(path, folderName):
    """
    Extract the image index string from a single image path.

    The index is the file name without its final extension, optionally
    prefixed with the name of the directory that contains the file. Only the
    last extension is removed, so "img.v1.png" yields "img.v1" and a name
    without any extension is returned unchanged.

    @parameters:
        path[str]: single image path using '/' separators. e.g. "/root/cat/cat_01.png"
        folderName[boolean]: if structure is: root/cat/123.png then True, if root/123.png then False.
    @return[str]: "<folder>/<stem>" when folderName is true, otherwise "<stem>"
    """
    parts = path.split('/')
    # splitext strips only the last suffix; the previous split('.')[-2] kept
    # just the segment before the last dot and failed on dot-less names.
    stem = os.path.splitext(parts[-1])[0]
    if folderName:
        return parts[-2] + '/' + stem
    return stem
    
    
def data_loader(imgIdx_path):
    """
    Wrap the collected image paths in a DataLoader with default settings.

    @parameters:
        imgIdx_path[dict]: a dict get from get_path function storing {image index: full path of the image}
    @return: the image loader, or None (after printing "Invalid path") when the loader is falsy
             (presumably when it contains no images — TODO confirm)
    """
    image_loader = DataLoader(FilenameDataset(imgIdx_path.values()))
    if not image_loader:
        print("Invalid path")
        return None
    #print("{} images loaded".format(len(loader)))
    return image_loader
            
def feature_extractor_custom(model, loader, imgIdx_path):
    """
    Run every batch of ``loader`` through ``model`` and tabulate the flattened outputs.

    Gradient tracking is not disabled while the model runs; each output is
    detached before it is converted to numpy. Row labels are derived from the
    paths in ``imgIdx_path`` using the module-level ``folderName`` setting.

    @parameters:
        model[neuron network]: model used to extract features; each output must flatten to 512 values
        loader: image loader returned from the data_loader function
        imgIdx_path[dict]: {image index: full path of the image}, as returned by get_path
    @return[tuple]: (dataframe of extracted features with columns '1'..'512', indexed by image index;
                     raw model output of the last batch, or None when the loader yields nothing)
    """
    # The row labels depend only on imgIdx_path, so build them once instead of
    # rebuilding the full list for every batch.
    index = [extractIdx_pattern(path, folderName) for path in imgIdx_path.values()]

    features = []
    model_output = None  # stays None for an empty loader instead of raising NameError
    for img in loader:
        model_output = model(img)
        features.append(model_output.detach().numpy().flatten())

    df = pd.DataFrame(features,
                      columns=[str(i) for i in range(1, 513)],
                      index=index)
    df.index.name = 'Image'
    return df, model_output

def feature_extractor_vb(model, loader, imgIdx_path):
    """
    Run every batch of ``loader`` through ``model`` without gradient tracking and tabulate the outputs.

    Row labels are derived from the paths in ``imgIdx_path`` using the
    module-level ``folderName`` setting.

    @parameters:
        model[neuron network]: model used to extract features
        loader: image loader returned from the data_loader function
        imgIdx_path[dict]: {image index: full path of the image}, as returned by get_path
    @return[dataframe]: a dataframe of extracted features with columns '1'..'512', indexed by image index
    """
    # The row labels depend only on imgIdx_path, so build them once instead of
    # rebuilding the full list for every batch.
    index = [extractIdx_pattern(path, folderName) for path in imgIdx_path.values()]

    features = []
    with torch.no_grad():
        for img in loader:
            # NOTE(review): unlike the sibling extractors, the raw model output
            # is stored without .numpy().flatten() — confirm that the model
            # returns something pandas can lay out as a 512-column row.
            features.append(model(img))

    df = pd.DataFrame(features,
                      columns=[str(i) for i in range(1, 513)],
                      index=index)
    df.index.name = 'Image'
    return df

def feature_extractor_triplet(model, loader, imgIdx_path):
    """
    Put ``model`` in evaluation mode, run every batch of ``loader`` through it and tabulate the 2-value outputs.

    Gradient tracking is not disabled while the model runs; each output is
    detached before it is converted to numpy. Row labels are derived from the
    paths in ``imgIdx_path`` using the module-level ``folderName`` setting.

    @parameters:
        model[neuron network]: model used to extract features; each output must flatten to 2 values
        loader: image loader returned from the data_loader function
        imgIdx_path[dict]: {image index: full path of the image}, as returned by get_path
    @return[tuple]: (dataframe of extracted features with columns '1' and '2', indexed by image index;
                     raw model output of the last batch, or None when the loader yields nothing)
    """
    # The row labels depend only on imgIdx_path, so build them once instead of
    # rebuilding the full list for every batch.
    index = [extractIdx_pattern(path, folderName) for path in imgIdx_path.values()]

    features = []
    model_output = None  # stays None for an empty loader instead of raising NameError
    model.eval()  # Set the model to evaluation mode
    for img in loader:
        model_output = model(img)
        features.append(model_output.detach().numpy().flatten())

    df = pd.DataFrame(features,
                      columns=[str(i) for i in range(1, 3)],
                      index=index)
    df.index.name = 'Image'
    return df, model_output

def feature_extractor_triplet_ae(model, loader, imgIdx_path):
    """
    Put ``model`` in evaluation mode, run every batch of ``loader`` through it
    without gradient tracking and tabulate the 2-value outputs.

    Row labels are derived from the paths in ``imgIdx_path`` using the
    module-level ``folderName`` setting.

    @parameters:
        model[neuron network]: model used to extract features; each output must flatten to 2 values
        loader: image loader returned from the data_loader function
        imgIdx_path[dict]: {image index: full path of the image}, as returned by get_path
    @return[tuple]: (dataframe of extracted features with columns '1' and '2', indexed by image index;
                     raw model output of the last batch, or None when the loader yields nothing)
    """
    # The row labels depend only on imgIdx_path, so build them once instead of
    # rebuilding the full list for every batch.
    index = [extractIdx_pattern(path, folderName) for path in imgIdx_path.values()]

    features = []
    model_output = None  # stays None for an empty loader instead of raising NameError
    model.eval()  # Set the model to evaluation mode
    for img in loader:
        with torch.no_grad():
            model_output = model(img)
        features.append(model_output.detach().numpy().flatten())

    df = pd.DataFrame(features,
                      columns=[str(i) for i in range(1, 3)],
                      index=index)
    df.index.name = 'Image'
    return df, model_output

def feature_extractor_tripletvb(model, loader, imgIdx_path):
    """
    Put ``model`` in evaluation mode, run every batch of ``loader`` through it and tabulate the 2-value outputs.

    ``model`` must return a 3-tuple per batch; only its first element is
    used, the other two are discarded. Gradient tracking is not disabled
    while the model runs; each output is detached before it is converted to
    numpy. Row labels are derived from the paths in ``imgIdx_path`` using the
    module-level ``folderName`` setting.

    @parameters:
        model[neuron network]: model used to extract features; its first output must flatten to 2 values
        loader: image loader returned from the data_loader function
        imgIdx_path[dict]: {image index: full path of the image}, as returned by get_path
    @return[tuple]: (dataframe of extracted features with columns '1' and '2', indexed by image index;
                     first model output of the last batch, or None when the loader yields nothing)
    """
    # The row labels depend only on imgIdx_path, so build them once instead of
    # rebuilding the full list for every batch.
    index = [extractIdx_pattern(path, folderName) for path in imgIdx_path.values()]

    features = []
    model_output = None  # stays None for an empty loader instead of raising NameError
    model.eval()  # Set the model to evaluation mode
    for img in loader:
        model_output, _vis, _target_feature_map = model(img)
        features.append(model_output.detach().numpy().flatten())

    df = pd.DataFrame(features,
                      columns=[str(i) for i in range(1, 3)],
                      index=index)
    df.index.name = 'Image'
    return df, model_output

def df_preprocess(df_image, normalize=False):
    """
    Sort the feature dataframe by index and optionally z-score normalize it.

    The input dataframe is sorted IN PLACE. Without normalization the same
    (now sorted) object is returned. With normalization, a new dataframe is
    returned that holds only the numeric columns with duplicate rows dropped,
    each column centered on its mean and divided by its standard deviation.
    Columns with zero standard deviation are divided by 1 instead, so they
    become 0 rather than NaN.

    @parameters:
        df_image[dataframe]: image features dataframe
        normalize[boolean]: whether to apply z-score normalization
                            (do not normalize animal dataset, all columns are 0-100 scale)
    @return[dataframe]: preprocessed dataframe
    """
    df_image.sort_index(inplace=True)
    if not normalize:
        return df_image

    # 'int32' or 'int64' or 'float32' or 'float64'
    df_numeric = df_image.select_dtypes(include='number').drop_duplicates()
    # A constant column has std == 0; dividing by it would fill the column
    # with NaN and break the distance computations downstream.
    column_std = df_numeric.std().replace(0, 1)
    return (df_numeric - df_numeric.mean()) / column_std

def set_parameter_requires_grad(model, feature_extracting):
    """
    Enable gradients on every parameter of ``model`` when ``feature_extracting`` is truthy.

    Does nothing when ``feature_extracting`` is falsy.
    NOTE(review): helpers with this name usually set requires_grad to False
    to freeze the model; this one sets it to True — confirm that is intended.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = True

## ResNet Models

In [53]:
class TripletResNetModel(nn.Module):
    """
    Pretrained ResNet-18 whose final fully connected layer is replaced by a
    small head, followed by a linear projection to 2 outputs.

    Submodules:
        Model: resnet18(pretrained=True) with ``fc`` swapped for
               Linear(in_features, 512) -> LeakyReLU -> Linear(512, 10)
        Triplet_Loss: Linear(10, 2); despite the name this is a projection
               layer applied to the 10 backbone outputs, not a loss function
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Number of input features of the original fully connected layer.
        fc_inputs = backbone.fc.in_features
        # LeakyReLU keeps a non-zero gradient for negative inputs.
        backbone.fc = nn.Sequential(
            nn.Linear(fc_inputs, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 10),
        )
        self.Model = backbone
        # Maps the 10 backbone outputs to a 2D space.
        self.Triplet_Loss = nn.Sequential(nn.Linear(10, 2))

    def forward(self, x):
        # The backbone is switched to evaluation mode on every forward pass.
        self.Model.eval()
        return self.Triplet_Loss(self.Model(x))
    
class Autoencoder(nn.Module):
    """
    Single-layer linear autoencoder.

    @parameters:
        input_size[int]: number of features of the input (and of the reconstruction)
        output_size[int]: number of features of the encoded representation
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.encoder = nn.Linear(input_size, output_size)
        self.decoder = nn.Linear(output_size, input_size)

    def forward(self, x):
        """Return the tuple (encoded, decoded) for input ``x``."""
        code = self.encoder(x)
        return code, self.decoder(code)


class InitialTripletResNetModel(nn.Module):
    """
    Pretrained ResNet-18 with a replaced head, followed by a linear autoencoder;
    the forward pass returns only the encoded representation.

    @parameters:
        input_size[int]: output size of the ResNet head, which is also the autoencoder input size
        hidden_size[int]: size of the encoded representation returned by forward
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        fc_inputs = backbone.fc.in_features
        # The head's output size must match the input size of the autoencoder.
        backbone.fc = nn.Sequential(
            nn.Linear(fc_inputs, 512),
            nn.LeakyReLU(),
            nn.Linear(512, input_size),
        )
        self.Model = backbone
        self.Autoencoder = Autoencoder(input_size, hidden_size)

    def forward(self, x):
        # The autoencoder returns (encoded, decoded); the reconstruction is discarded.
        encoded, _decoded = self.Autoencoder(self.Model(x))
        return encoded


class CustomResNetModel(nn.Module):
    """
    Pretrained ResNet-18 used as a feature extractor.

    Attributes:
        pretrained_model: the full resnet18(pretrained=True)
        model: nn.Sequential of all child modules of ``pretrained_model``
               except the last one (the fully connected layer)
    """

    def __init__(self):
        super().__init__()
        # Load the pre-trained ResNet model, then drop its last child module.
        self.pretrained_model = resnet18(pretrained=True)
        kept_layers = list(self.pretrained_model.children())[:-1]
        self.model = nn.Sequential(*kept_layers)

    def forward(self, x):
        return self.model(x)

## Human-defined Loss

In [54]:
# Collect {image index: path} for the images to load, using the limits from the settings cell.
imageIndex_path_dict = get_path(imgFolder,
                                sampleSizePerCat,
                                folderName=folderName,
                                max_img_num=total_img)

# Wrap the collected paths in a DataLoader.
img_loader = data_loader(imageIndex_path_dict)

# Use only the `.model` attribute: the ResNet-18 layers without the final fully connected layer.
resnet_model = CustomResNetModel().model
resnet_model.eval()

# resnet_df: one row of features per image; output: raw model output of the last batch.
resnet_df, output = feature_extractor_custom(resnet_model, img_loader, imageIndex_path_dict)
# df_preprocess is called with its default normalize=False, so it only sorts the rows by index;
# despite its name, normalized_df holds the un-normalized features.
normalized_df = df_preprocess(resnet_df)
normalized_df.head()

Unnamed: 0_level_0,1,2,3,4,5,6,7,8,9,10,...,503,504,505,506,507,508,509,510,511,512
Image,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
close/0adeed6415,1.045141,0.960275,0.693633,0.661743,2.799178,0.331651,1.483747,2.410092,0.0,0.265786,...,0.642623,0.171148,0.349618,1.177837,2.463365,0.247215,1.164082,0.700432,0.127325,0.252951
close/1a5179549c,0.542045,0.396689,0.014007,3.699647,0.601964,0.034061,1.429823,3.10607,0.376622,2.004841,...,0.480173,0.664481,0.634084,0.010682,0.0,0.893567,1.099559,0.340215,1.031656,1.34006
close/27c00574e5,0.016142,0.242051,0.608979,0.233586,1.552251,1.471979,0.199548,0.555423,0.553794,0.489983,...,1.220631,0.676823,2.766618,0.106216,2.08473,0.284578,0.784091,0.389585,1.290607,1.154332
close/2b60bdb95d,0.608418,0.600565,0.289687,6.294758,2.181289,0.426867,1.314924,2.980889,0.407968,1.14658,...,0.258605,0.769757,1.851032,0.082375,1.079353,0.920231,1.719315,1.033523,0.782651,1.425852
close/2e263a2559,0.145794,0.010137,0.264306,1.543341,1.336485,0.627371,0.644012,0.326648,1.049979,0.535992,...,1.017135,0.026396,0.794135,0.0,0.215944,0.200235,0.023284,1.034789,0.492878,1.451761


## Triplet

In [55]:
# imageIndex_path_dict = get_path(imgFolder,
#                                 sampleSizePerCat,
#                                 folderName=folderName,
#                                 max_img_num=total_img)

# img_loader = data_loader(imageIndex_path_dict)

# resnet_model = CustomResNetModel()
# resnet_model.eval()

# resnet_df, output = feature_extractor_custom(resnet_model, img_loader, imageIndex_path_dict)
# normalized_df = df_preprocess(resnet_df)
# normalized_df.head()

In [56]:
# normalized_df.rename(columns={'1': 'x', '2': 'y'}, inplace=True)
# normalized_df

## Visual BackProp

In [57]:
# # imageIndex_path_dict = get_path(imgFolder,
# #                                 sampleSizePerCat,
# #                                 folderName=folderName,
# #                                 max_img_num=total_img)
# # img_loader = data_loader(imageIndex_path_dict)
# # Specify the save path
# #save_path = "/Users/jiayuelin/InfoVis/Datasets/Edamame/Edamame_map.pth"
# # model = VBP.resnet18(pretrained=True).eval()
# model_triplet = TripletResNetModel()
# model_vbp_before = VBP.ResnetVisualizer(model_triplet.eval())

## GradCAM

In [58]:
# grad_model = resnet18(pretrained=True)
# gradcam = GCAM.GradCAM(grad_model, target_layer='layer4')

# grad_cam_model = resnet18(pretrained=True)
# target_layers = [model.layer4[-1]]

# cam_methods = [ScoreCAM]
# #cam_methods = [GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad]
# cam_objects = [cam(model=grad_cam_model, target_layers=target_layers) for cam in cam_methods]

#  Dimension Reduction Model:  Weighted MDS

For DR, we use the Multi-Dimensional Scaling (MDS) algorithm on a weighted data space. **Dimension weights** are applied to the high-dimensional (HD) data.  Weights are normalized to sum to 1, so as to normalize the HD distances to roughly constant size space independent of p.

The **distance function for the high-dimensional (HD) data** is L1 manhattan distance. L1 is good for general purpose use with multi-dimensional quantitative datasets. 

The **distance function for the 2D projected points** is L2 Euclidean distance, which makes sense for human perception in the plot.

In [59]:
def distance_matrix_HD(dataHDw):
    """
    Pairwise L1 (manhattan) distances between the rows of the high-dimensional data.
    Any weighting must already have been applied to the input.

    @parameters:
        dataHDw[pd.df or np.array]: weighted high-dimensional data
    @return[array]: distance matrix for input weighted high-dimensional data
    """
    return sklearn.metrics.pairwise.manhattan_distances(dataHDw)


def distance_matrix_2D(data2D):
    """
    Pairwise L2 (euclidean) distances between the rows of the projected 2D data.

    @parameters:
        data2D[pd.df or np.array]: projected 2D data
    @return[np.array]: distance matrix for 2D input data
    """
    return sklearn.metrics.pairwise.euclidean_distances(data2D)

**MDS** projects the weighted high-dimensional data to 2D. Tune the algorithm's parameters for performance.

In [60]:
def stress(distHD, dist2D):
    """
    MDS stress between high-dimensional and 2D distances: the sum of squared
    differences divided by the sum of squared HD distances. The square root
    is left out for efficiency.

    @parameters:
        distHD[np.array]: distance matrix for high-dimensional data
        dist2D[np.array]: distance matrix for 2D data
    @return[float]: stress value
    """
    squared_error = ((distHD - dist2D)**2).sum()
    squared_scale = (distHD**2).sum()
    return squared_error / squared_scale

# def compute_mds(dataHDw):
#     """
#     apply MDS to high-dimensional data to get 2D data
#     @parameters:
#         dataHDw[pd.df or np.array]: weighted high-dimensional data
#     @return[dataframe]: a dataframe of 2D data 
#     """
#     distHD = distance_matrix_HD(dataHDw)
#     ### Adjust these parameters for performance/accuracy tradeoff
#     mds = sklearn.manifold.MDS(n_components=2,
#                                dissimilarity='precomputed',
#                                n_init=10,
#                                max_iter=1000,
#                                random_state=3)
#     # Reduction algorithm happens here:  data2D is n*2 matrix
#     data2D = mds.fit_transform(distHD)

#     ### Rotate the resulting 2D projection to make it more consistent across multiple runs.
#     ### Set the 1st PC to the y axis, plot looks better to spread data vertically with horizontal text labels
#     pca = sklearn.decomposition.PCA(n_components=2)
#     data2D = pca.fit_transform(data2D)
#     data2D = pd.DataFrame(data2D, columns=['y', 'x'], index=dataHDw.index)

#     # data2D.stress_value = stress(distHD, distance_matrix_2D(data2D))
#     return data2D

# def dimension_reduction(dataHD, wts):  # dataHD, wts -> data2D (pandas)
#     """
#     apply weights to high-dimensional data then apply MDS to get 2D data
#     @parameters:
#         dataHD[pd.df or np.array]: original high-dimensional data
#     @return[dataframe]: a dataframe of projected 2D data
#     """
#     ### Normalize the weights to sum to 1
#     wts = wts / wts.sum()

#     ### Apply weights to the HD data
#     dataHDw = dataHD * wts

#     ### DR algorithm
#     data2D = compute_mds(dataHDw)

#     ### Compute row relevances as:  data dot weights
#     ### High relevance means large values in upweighted dimensions
#     #     data2D['relevance'] = dataHDw.sum(axis=1)
#     return data2D

def compute_mds(dataHD):
    """
    Apply MDS to high-dimensional data to get 2D data.

    The L1 distance matrix of the input is embedded in 2D with metric MDS on
    precomputed dissimilarities, then the embedding is rotated with PCA.

    @parameters:
        dataHD[pd.df or np.array]: high-dimensional data
    @return[dataframe]: a dataframe of 2D data with columns 'y' and 'x'; it carries the
                        input's index when a dataframe is given, a default index otherwise
    """
    distHD = distance_matrix_HD(dataHD)
    ### Adjust these parameters for performance/accuracy tradeoff
    mds = sklearn.manifold.MDS(n_components=2,
                               dissimilarity='precomputed',
                               n_init=10,
                               max_iter=1000,
                               random_state=3)
    # Reduction algorithm happens here:  data2D is n*2 matrix
    data2D = mds.fit_transform(distHD)

    ### Rotate the resulting 2D projection to make it more consistent across multiple runs.
    ### Set the 1st PC to the y axis, plot looks better to spread data vertically with horizontal text labels
    pca = sklearn.decomposition.PCA(n_components=2)
    data2D = pca.fit_transform(data2D)

    # A plain array has no row labels; previously dataHD.index was read
    # unconditionally, which failed for the documented np.array input.
    row_index = dataHD.index if isinstance(dataHD, pd.DataFrame) else None
    data2D = pd.DataFrame(data2D, columns=['y', 'x'], index=row_index)

    #data2D.stress_value = stress(distHD, distance_matrix_2D(data2D))
    return data2D

def dimension_reduction(dataHD):  # dataHD -> data2D (pandas)
    """
    Project high-dimensional data to 2D by delegating to compute_mds.

    @parameters:
        dataHD[pd.df or np.array]: original high-dimensional data
    @return[dataframe]: a dataframe of projected 2D data
    """
    return compute_mds(dataHD)

def get_weights(min_weight, max_weight, index, load_weights_from_file=False,
                init_weight=None, weights_path=None):
    """
    Initialize weights for high-dimensional data.

    @parameters:
        min_weight[float]: lower weight bound; used as the initial weight when no other
                           initial value is available
        max_weight[float]: upper weight bound (not used by this function)
        index: labels for the returned weights (normally the feature column names)
        load_weights_from_file[boolean]: whether load weights from a saved csv file or not
        init_weight[float or None]: initial value for every weight. When None, the module-level
                           ``init_weight`` is used if it exists, otherwise ``min_weight``
        weights_path[str, file-like or None]: csv with a 'Weight' column. When None, the
                           module-level ``weights_path`` is used
    @return[pd.Series]: the 'Weight' column of the csv when loading from file, otherwise a series
                        named "Weight" with the same initial value for each feature
    @raises NameError: when loading from file is requested but no path is available
    """
    if load_weights_from_file:
        if weights_path is None:
            # Fall back to the module-level setting this function used to read directly.
            if 'weights_path' not in globals():
                raise NameError(
                    "weights_path is not defined: pass weights_path=... to get_weights")
            weights_path = globals()['weights_path']
        return pd.read_csv(weights_path)['Weight']

    if init_weight is None:
        # Initialize to min to make the sliders easier to use, unless the
        # notebook defines a module-level init_weight.
        init_weight = globals().get('init_weight', min_weight)
    # the current weight list
    return pd.Series(init_weight, index=index, name="Weight")

## Triplet

In [61]:
# df_2D = normalized_df
# df_2D['label'] = df_2D.index.str.split('/').str[0]
# #print("The original normalized_df is:", df_2D)
# df_2D = df_2D[['x', 'y', 'label']]
# print("The updated df is:", df_2D)
# silhouette = silhouette_score(df_2D[['x', 'y']], df_2D['label'])
# print("The Silhouette Score after interaction is: ", silhouette)
# adjusted_silhouette = silhouette * 2
# print("The Adjusted Silhouette Score after interaction is: ", adjusted_silhouette)

In [62]:
# init_weight, min_weight, max_weight = 0.00001, 0.00001, 0.9999
# weights = get_weights(min_weight, max_weight, normalized_df.columns)
# df_2D = dimension_reduction(normalized_df, weights)
# df_2D['label'] = df_2D.index.str.split('/').str[0]
# #print("The original normalized_df is:", df_2D)
# df_2D = df_2D[['x', 'y', 'label']]
# print("The updated df is:", df_2D)
# silhouette = silhouette_score(df_2D[['x', 'y']], df_2D['label'])
# print("The Silhouette Score after interaction is: ", silhouette)
# adjusted_silhouette = silhouette * 2
# print("The Adjusted Silhouette Score after interaction is: ", adjusted_silhouette)

In [63]:
# Project the image features to 2D with MDS.
df_2D = dimension_reduction(normalized_df)
# The category label is the part of the image index before the first '/'.
df_2D['label'] = df_2D.index.str.split('/').str[0]
# Reorder the columns to x, y, label.
df_2D = df_2D[['x', 'y', 'label']]
# data_to_normalize = df_2D[['x', 'y']]
# # Initialize the StandardScaler
# scaler = StandardScaler()
# # Fit the scaler to the data and transform the data
# normalized_data = scaler.fit_transform(data_to_normalize)
# # Update the 'x' and 'y' columns in update_df_2D with the normalized values
# df_2D['x'] = normalized_data[:, 0]  
# df_2D['y'] = normalized_data[:, 1] 
print("The updated df is:", df_2D)
# Silhouette score of the 2D coordinates grouped by category label.
silhouette = silhouette_score(df_2D[['x', 'y']], df_2D['label'])
print("The Silhouette Score after MDS is: ", silhouette)
# NOTE(review): the reason for doubling the score is not stated anywhere in view — confirm.
adjusted_silhouette = silhouette * 2
print("The Adjusted Silhouette Score after MDS is: ", adjusted_silhouette)

The updated df is:                            x           y  label
Image                                          
close/0adeed6415  120.289885  167.160318  close
close/1a5179549c  -13.733990 -157.551025  close
close/27c00574e5   -3.124378  223.228423  close
close/2b60bdb95d  -88.917810 -280.697837  close
close/2e263a2559 -196.168363  170.539979  close
close/4bbff77559   94.350522 -214.669982  close
close/4bf4f30634  244.686559  187.710198  close
close/5b3c8cbcb9 -122.014330 -142.745847  close
close/60afeda678  178.006786  335.751650  close
close/7eb99c36c6   79.687377  -97.797975  close
open/351d5ce5a9   184.357166 -110.027376   open
open/36cc96437a  -204.831060  276.198358   open
open/41ffeeae88     2.736976 -319.155388   open
open/45e11254ff  -212.452695 -256.967283   open
open/6307ae132f  -297.026851  106.452975   open
open/6d52c38d91   259.999678 -297.645684   open
open/6e6f16d262   -41.877142  144.902801   open
open/77bc1747df   346.761679   53.492011   open
open/7d65b05fd1  -258

# Inverse Dimension-Reduction Learning Algorithm

Computes the inverse-Dimension-Reduction: given input 2D points, compute new weights.
Optimizes the MDS stress function that compares 2D pairwise distances ($||x_i-x_j||$) to weighted HD pairwise distances ($d_{ij}$):
![Stress](https://wikimedia.org/api/rest_v1/media/math/render/svg/7989b3afc0d8795a78c1631c7e807f260d9cfe68)

Technically, we compute the inverse weighted distance function. We shortcut the optimization by eliminating MDS from the process, and assume that the user input 2D distances are actually the desired HD distances, not the 2D distances after re-projection. Thus, given the input (HD) distances, we find weights that would produce these distances in the HD space.

In [64]:
# This method is used to propose a new weight for current column in a smart fashion
def new_proposal(current, step, direction):
    """
    Propose a new weight: move ``current`` by a random fraction of ``step`` in
    the given ``direction`` and clip the result to [0.00001, 0.9999].
    """
    shift = direction * step * random.random()
    candidate = current + shift
    return np.clip(candidate, 0.00001, 0.9999)


def inverse_DR(dataHD, data2D, curWeights=None):
    """
    Search for dimension weights whose weighted HD distances best match the 2D distances.

    Starting from the current weights, every dimension is perturbed in turn
    with new_proposal; the weights are rescaled to sum to 1, and the change is
    kept only when it lowers the stress between the weighted L1 distances of
    ``dataHD`` and the L2 distances of ``data2D``. The per-dimension step size
    doubles after repeated successes and halves after repeated failures.
    Progress is reported with print at the start and at the end.

    @parameters:
        dataHD[pd.df]: high-dimensional data (its columns name the returned weights)
        data2D[pd.df or np.array]: projected 2D data
        curWeights[pd.Series or None]: starting weights, normalized to sum to 1 before use;
                                       None means equal weights 1/p
    @return[pd.Series]: new weights named "Weight", indexed by the columns of dataHD
    """
    dist2D = distance_matrix_2D(data2D)  # compute 2D distances only once
    col_names = dataHD.columns
    dataHD = dataHD.to_numpy()  # use numpy for efficiency
    _, col = dataHD.shape

    # `is None` is required here: comparing a Series with `== None` yields a
    # Series, whose truth value is ambiguous and raises ValueError.
    if curWeights is None:
        curWeights = np.array([1.0 / col] * col)  # default weights = 1/p
    else:
        curWeights = curWeights.to_numpy()
        curWeights = curWeights / curWeights.sum()  # Normalize weights to sum to 1
    newWeights = curWeights.copy()  # re-use this array for efficiency

    # Initialize state
    flag = [0] * col  # degree of success of a weight change
    direction = [1] * col  # direction to move a weight, pos or neg
    step = [1.0 / col] * col  # how much to change each weight

    dataHDw = dataHD * curWeights  # weighted space, re-use this array for efficiency
    distHD = distance_matrix_HD(dataHDw)
    curStress = stress(distHD, dist2D)
    print('Starting stress =', curStress, 'Processing...')

    MAX = 500  # default setting of the number of iterations

    # Try to minorly adjust each weight to see if it reduces stress
    for _ in range(MAX):
        for dim in range(col):
            # Get a new weight for current column
            nw = new_proposal(curWeights[dim], step[dim], direction[dim])

            # Scale the weight list such that it sums to 1
            s = 1.0 + nw - curWeights[dim]  # 1.0 == curWeights.sum()
            # transfers to other array, while doing /
            np.true_divide(curWeights, s, out=newWeights)
            newWeights[dim] = nw / s

            # Apply new weights to HD data
            # dataHDw = dataHD * newWeights; efficiently reuses dataHDw array
            np.multiply(dataHD, newWeights, out=dataHDw)
            distHD = distance_matrix_HD(dataHDw)

            # Get the new stress
            newStress = stress(distHD, dist2D)

            # If new stress is lower, then update weights and flag this success
            if newStress < curStress:
                # Swap the buffers: the old array is reused next iteration
                curWeights, newWeights = newWeights, curWeights
                curStress = newStress
                flag[dim] += 1
            else:
                flag[dim] -= 1
                direction[dim] = -direction[dim]  # Reverse course

            # If recent success, then speed up the step rate
            if flag[dim] >= 5:
                step[dim] = step[dim] * 2
                flag[dim] = 0
            elif flag[dim] <= -5:
                step[dim] = step[dim] / 2
                flag[dim] = 0

    print('Solution stress =', curStress, 'Done.')
    return pd.Series(curWeights, index=col_names, name="Weight")

# Visualization and UI code

Use these functions to create the GUI components in any cell.

## Sliders

In [65]:
def create_size_slider(imgDisplaySize=imgDisplaySize):
    """
    Build the slider used to pick the displayed image size.
    @parameters:
        imgDisplaySize[float]: starting value of the slider; the default is the
            notebook-level imgDisplaySize captured when this cell is executed
    @return[widgets slider]: FloatSlider ranging over [0, 1] in steps of 0.01
    """
    slider_options = dict(
        value=imgDisplaySize,
        min=0,
        max=1,
        step=0.01,
        description='Adjust image size',
        readout_format='.5f',
        continuous_update=False,
        style={'description_width': 'initial'},
    )
    slider = widgets.FloatSlider(**slider_options)
    slider.style.handle_color = 'lightblue'
    return slider

## Checkbox

In [66]:
def create_checkbox(ax):
    """
    Create the checkboxes that toggle the images and the titles (image index
    in the dataframe) and hook both up to draw_plot.
    @parameters:
        ax: axes passed through to draw_plot on every toggle
    @return: tuple (title_checkbox, image_checkbox)
    """
    def make_checkbox(initial_value, text):
        # Both checkboxes share the same compact layout.
        return widgets.Checkbox(initial_value,
                                description=text,
                                indent=False,
                                layout=Layout(width='20%', height='20px'))

    title_checkbox = make_checkbox(False, 'Toggle Titles')
    image_checkbox = make_checkbox(True, 'Toggle Images')

    # Each callback redraws with its own new value plus the *current* value of
    # the other checkbox; df_2D and size_slider are read from notebook globals.
    def on_title_toggle(x):
        show_images = image_checkbox.value
        draw_plot(ax, df_2D, x, show_images, imgSize=size_slider.value)

    interact(on_title_toggle, x=title_checkbox)

    def on_image_toggle(x):
        show_titles = title_checkbox.value
        draw_plot(ax, df_2D, show_titles, x, imgSize=size_slider.value)

    interact(on_image_toggle, x=image_checkbox)
    return title_checkbox, image_checkbox

## ImageSI<sub>MDS<sup>-1</sup></sub>


In [67]:
class Sample:
    """
    One image of the projection together with everything the visualization
    needs to know about it (path, 2-D position, class label, id).
    """

    # ChangeHere
    # Maps the leading folder of ``idx_key`` (the text before the first '/')
    # to the integer class label of the sample.
    LABEL_BY_FOLDER = {
        'open': 0,
        'close': 1,
    }
    # Alternative mappings used with other datasets:
    # LABEL_BY_FOLDER = {'diseased': 0, 'late': 1, 'ready': 2}
    # LABEL_BY_FOLDER = {'horse_and_human': 0, 'more_than_one': 1, 'one': 2}
    # LABEL_BY_FOLDER = {'one': 0, 'two': 1, 'three': 2}

    def __init__(self,
                 img_path,
                 idx_key,
                 position=(0.0, 0.0),
                 label=None,
                 id=None,
                 ):
        """
        Everything need for Visualization about each of the sample.
        @parameters:
            img_path: path of the image file
            idx_key[str]: key of the image; the part before the first '/' is
                looked up in LABEL_BY_FOLDER
            position: default position in the 2d spatialization
            label: fallback label used when the folder is not in LABEL_BY_FOLDER
            id: identifier reported by to_dict(); None means "use the index
                handed to to_dict()"
        """
        self.img_path = img_path
        self.position = position  # default position in 2d spatialization
        self.idx_key = idx_key
        self.distances = None

        folder = idx_key.split('/')[0]
        # Always define self.label: an unknown folder used to leave the
        # attribute unset, which made to_dict() fail with AttributeError.
        self.label = self.LABEL_BY_FOLDER.get(folder, label)

        self.id = id
        self.representation = None

    def to_dict(self, i=0):
        """
        Return the sample as a plain dict.
        @parameters:
            i: value used for 'id' when the sample has no id of its own
        """
        return {
            'id': i if self.id is None else self.id,
            'img': self.img_path,
            'position': self.position,
            'init': True,
            'label': self.label,
        }


class SamplesForTraining:
    """Plain holder pairing the images to train on with their target positions."""

    def __init__(self, img_ids, positions):
        # Stored exactly as given; no copies or conversions are made.
        self.img_ids, self.positions = img_ids, positions


class Samples:
    def __init__(self,
                 data_path,
                 samplesPerCat,
                 folderName,
                 max_img,
                 labels=None,):
        """
        Control how to manipulate samples, and provide models with features.

        Image paths are collected with get_path(); one Sample is created per
        entry, using the dictionary key both as ``idx_key`` and as ``id``.
        ``labels`` is accepted but not used by this constructor.
        """
        self.data_path = data_path

        path_by_key = get_path(data_path,
                               samplesPerCat,
                               folderName=folderName,
                               max_img_num=max_img)
        self.imageIndex_path_dict = path_by_key
        self.data_loader = data_loader(path_by_key)
        # img_ids is the very same mapping object, not a copy.
        self.img_ids = path_by_key

        self.positions = []
        self.samples = [
            Sample(img_path, key, id=key)  # Assign the key as the id
            for key, img_path in path_by_key.items()
        ]

    def get_samples_for_training(self, updated_samples):
        """
        Collect the samples the user moved and package them for training.

        ``updated_samples`` maps a sample id to its new position. Every sample
        whose id is present gets that position stored on it.
        Return:
            SamplesForTraining with
              img_ids: dict idx_key -> image path of the moved samples
              positions: torch tensor of their new positions, in sample order
        """
        moved_paths = {}
        moved_positions = []

        for sample in self.samples:
            if sample.id not in updated_samples:
                continue
            sample.position = updated_samples[sample.id]
            moved_paths[sample.idx_key] = sample.img_path
            moved_positions.append(sample.position)

        return SamplesForTraining(moved_paths, torch.tensor(moved_positions))

    def update_position(self, positions):
        """Assign positions to the samples pairwise, in order (zip semantics)."""
        for target, new_position in zip(self.samples, positions):
            target.position = new_position

    def update_representation(self, representations):
        """Assign representations to the samples pairwise, in order (zip semantics)."""
        for target, new_representation in zip(self.samples, representations):
            target.representation = new_representation

    def to_json(self):
        """Return a list with one to_dict() entry per sample, passing its list index."""
        return [sample.to_dict(i) for i, sample in enumerate(self.samples)]

In [69]:
class Deep_crop(nn.Module):
    """
    ImageSI MDS^-1 fine-tuning.

    Wraps a ``CustomResNetModel`` feature extractor and optimises it with Adam
    so that the pairwise feature distances of the interacted images approach
    the "intended" distances derived from the positions the user dragged the
    images to.
    """

    # ChangeHere
    # File train() writes the best weights to when no save_path is given.
    DEFAULT_SAVE_PATH = '/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/Deep_Crop_pth/best_deep_crop_animal_mouth_model.pth'

    def __init__(self, learning_rate=3e-5, epochs=10): # ChangeHere
        """
        @parameters:
            learning_rate[float]: learning rate of the Adam optimizer
            epochs[int]: number of optimisation steps performed by train()
        """
        super(Deep_crop, self).__init__()

        self.model = CustomResNetModel()
        self.learning_rate = learning_rate
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.loader = None
        self.epochs = epochs

        # Define the optimizer (all parameters of the wrapped model are tuned)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

    @staticmethod
    def _snapshot_state(model):
        """Return an independent copy of ``model.state_dict()``.

        ``state_dict()`` hands out tensors that share storage with the live
        parameters, and ``dict.copy()`` is shallow, so a snapshot taken that
        way silently follows every later optimizer step.
        """
        import copy
        return copy.deepcopy(model.state_dict())

    def set_data_loader(self,loader):
        """Store ``loader`` on the instance."""
        self.loader = loader

    def forward(self, img_ids, intended_distances):
        """
        Loss of the current model: the summed absolute difference between the
        predicted pairwise feature distances and ``intended_distances``.
        """
        predicted_distances = self.get_distances(img_ids)

        return torch.sum(torch.abs(predicted_distances-intended_distances))

    def get_representation(self, img_paths):
        """
        Run every image produced by ``data_loader(img_paths)`` through the
        model and stack the results into one tensor (one row per image).
        """
        imgs_output = []
        for img in data_loader(img_paths):
            # Per the original note the model output is (1,512,1,1) per image
            # (TODO confirm); the two trailing singleton axes are dropped here.
            model_output = self.model(img)
            model_output = torch.squeeze(torch.squeeze(model_output,3),2)
            imgs_output.append(model_output)

        imgs_output = torch.stack(imgs_output)
        imgs_output = torch.squeeze(imgs_output,1)

        return imgs_output

    def get_distances(self, img_ids):
        """Condensed pairwise distances (``F.pdist``) between the image features."""
        representation = self.get_representation(img_ids)

        return F.pdist(representation)

    def get_projected_representation(self, img_ids):
        """
        Project the images to 2-D with metric MDS on their feature distances
        and min-max normalise each axis.

        @return[tensor]: normalised 2-D coordinates, moved to ``self.device``
        """
        representations = self.get_representation(img_ids)

        # MDS runs in scikit-learn, so the distances leave the autograd graph.
        predicted_distances = F.pdist(representations).to('cpu').detach().numpy()

        seed = 0
        max_iter = 1600
        mds = MDS(n_components=2,
                  metric=True,
                  n_init=10,
                  max_iter=max_iter,
                  random_state=seed,
                  eps=1e-9,
                  dissimilarity='precomputed',
                  n_jobs=12)
        coordinates = mds.fit_transform(squareform(predicted_distances))

        # Column-wise min/max, i.e. each projected axis is scaled on its own.
        y_min, y_max = np.min(coordinates, 0), np.max(coordinates, 0)
        normalized_coordinates = (coordinates - y_min) / (y_max - y_min)

        return torch.tensor(normalized_coordinates).to(self.device)

    def get_projected_distances(self, img_ids):
        """Condensed pairwise distances between the normalised 2-D projections."""
        projected_representation = self.get_projected_representation(img_ids)
        return F.pdist(projected_representation)

    def get_intended_distance(self, img_ids, interactions, eps = 1e-8):
        """
        Translate the user's 2-D layout into target feature-space distances.

        Every current feature distance is scaled by the ratio
        (distance in the user's layout) / (distance in the current projection).

        @parameters:
            img_ids: images that were interacted with
            interactions[tensor]: user-defined 2-D positions, one row per image
            eps[float]: added to the denominator to avoid division by zero
        @return[tensor]: detached target distances in condensed (pdist) form
        """
        distances = self.get_distances(img_ids)

        # The three distance vectors can originate on different devices (the
        # interactions from the caller, the projection on self.device, the
        # features wherever the model lives); align them before mixing them.
        target_device = distances.device
        intended_projected_distances = F.pdist(interactions).to(target_device)
        projected_distances = self.get_projected_distances(img_ids).to(target_device)

        relative_distances = intended_projected_distances / (projected_distances + eps)

        intended_distances = distances * relative_distances

        return intended_distances.clone().detach()

    def train(self, samples, early_stopping=False, save_path=None):
        """
        Fine-tune the model on the interacted samples and save the best weights.

        @parameters:
            samples: object exposing ``img_ids`` and ``positions``
                (e.g. SamplesForTraining)
            early_stopping[bool]: stop as soon as an epoch reaches a loss of 0
            save_path[str]: where the best state dict is written; defaults to
                ``DEFAULT_SAVE_PATH``

        NOTE: this method shadows ``nn.Module.train(mode)``, which
        ``nn.Module.eval()`` relies on. A boolean first argument is therefore
        forwarded to the base implementation so that ``.eval()`` /
        ``.train(True)`` on this module keep working.
        """
        if isinstance(samples, bool):
            return super().train(samples)

        if save_path is None:
            save_path = self.DEFAULT_SAVE_PATH

        best_loss = None
        best_model_state = None
        zero_loss_epochs = 0  # Counter to track consecutive epochs with zero loss

        img_ids = samples.img_ids
        positions = samples.positions
        relative_distances = self.get_intended_distance(img_ids, positions)

        # The wrapped model is put in evaluation mode while it is fine-tuned
        # (the original comment said "training mode", the code never did).
        self.model.eval()

        for epoch in range(self.epochs):
            self.optimizer.zero_grad()

            with torch.set_grad_enabled(True):
                loss = self.forward(img_ids, relative_distances)

                loss.backward()
                self.optimizer.step()

            loss_value = loss.item()
            print("Epoch: {}/{} — Loss: {:.10f}".format(epoch + 1, self.epochs, loss_value))

            # As before, the snapshot is taken right after the optimizer step
            # of the epoch that reported the lowest strictly positive loss.
            if loss_value > 0 and (best_loss is None or loss_value < best_loss):
                best_loss = loss_value
                best_model_state = self._snapshot_state(self.model)

            # Check for early stopping condition
            if early_stopping:
                if loss_value == 0:
                    zero_loss_epochs += 1
                else:
                    zero_loss_epochs = 0

                # A single zero-loss epoch is enough to stop
                if zero_loss_epochs >= 1:
                    print("Early stopping due to consecutive zero loss epochs.")
                    break

        if best_loss is not None and best_loss > 0:
            print("Best loss:", best_loss)
            torch.save(best_model_state, save_path)

## ImageSI<sub>Triplet</sub>

In [73]:
def comp_image_path_dict(df, root_dir=imgFolder):
    """
    Map every index label of ``df`` to the path of its ``.jpg`` image.

    With a ``root_dir`` the file name is joined onto it; with ``root_dir=None``
    the bare ``<index>.jpg`` name is used. The default root is the
    notebook-level ``imgFolder`` captured when this cell is executed.
    """
    path_by_index = {}
    for df_index in df.index:
        file_name = f"{df_index}.jpg"
        if root_dir is None:
            path_by_index[df_index] = file_name
        else:
            path_by_index[df_index] = os.path.join(root_dir, file_name)
    return path_by_index

def user_defined_train_data(ax):
    """
    Return the image paths of the points currently selected in the plot.

    The selected artists of ``ax.dragpoint`` provide the labels; those labels
    are looked up in a fresh projection of the notebook-level ``normalized_df``
    and turned into an index -> path dict via comp_image_path_dict().
    """
    chosen = [artist for artist in ax.dragpoint.artists if artist.selected]
    data2Dnew = pd.DataFrame(
        [artist.center for artist in chosen],
        columns=['x', 'y'],
        index=[artist.label for artist in chosen])

    # Only the labels of data2Dnew are used below; .loc raises if a selected
    # label is missing from the projection.
    projection = dimension_reduction(normalized_df)
    selected_rows = projection.loc[data2Dnew.index]

    selected_paths = comp_image_path_dict(selected_rows)
    print("The selected images are:", selected_paths.keys())

    return selected_paths

def get_updated_coordinates(ax):
    """
    Fetch the coordinates of the dragged points from ``ax.dragpoint``.

    Thin wrapper around ``ax.dragpoint.get_dragged_coordinates()``; its result
    is returned unchanged. According to the original notes that is a dict
    mapping artist labels to (x, y) — confirm against the dragpoint class.
    """
    return ax.dragpoint.get_dragged_coordinates()

def load_best_model(model, model_path):
    """
    Load a saved state dict into ``model`` and switch it to evaluation mode.

    @parameters:
        model: module whose ``load_state_dict`` receives the stored weights
        model_path: path of the file written with ``torch.save``
    @return: the same ``model`` object, in evaluation mode
    """
    # map_location='cpu' lets a checkpoint written on a GPU machine be read on
    # a CPU-only one. load_state_dict copies the values into the model's
    # existing parameters, so the model stays on whatever device it is on.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()  # Set the model to evaluation mode
    return model

# data = {
#     'image': [f"{key.split('/')[0]}/{key.split('/')[1]}.jpg" for key in imageIndex_path_dict.keys()]
# }

# train_df = pd.DataFrame(data)

# # Define the file path where you want to save the CSV file
# file_path = "/Users/jiayuelin/InfoVis/Datasets/train_data.csv"

# # Save the DataFrame as a CSV file
# train_df.to_csv(file_path, index=False)

# print(f"CSV file saved at: {file_path}")

# train_data = pd.read_csv(file_path)

# # Directory where the original images are located
# original_image_dir = '/Users/jiayuelin/InfoVis/Datasets/animals/animals'

# # Directory where you want to copy the images
# output_image_dir = "/Users/jiayuelin/InfoVis/Datasets/train"

# # Create the output directory if it doesn't exist
# os.makedirs(output_image_dir, exist_ok=True)

# # Read the CSV file into a DataFrame
# csv_df = pd.read_csv(file_path)

# # Loop through each row in the DataFrame and copy the images
# for index, row in csv_df.iterrows():
#     image_filename = row['image']
#     # image_filename = image_filename.replace('/', '_')
#     source_path = os.path.join(original_image_dir, image_filename)
#     output_path = os.path.join(output_image_dir, image_filename)

#     # Create the subdirectories if they don't exist in the output directory
#     os.makedirs(os.path.dirname(output_path), exist_ok=True)

#     # Check if the source image file exists and copy it to the output directory
#     if os.path.exists(source_path):
#         shutil.copy(source_path, output_path)
#         print(f"Copied {image_filename} to {output_image_dir}")
#     else:
#         print(f"Image {image_filename} not found in {original_image_dir}")

# print("Image extraction completed.")

## Coordinate-based Triplet Margin Loss

In [74]:
class Andromeda_IMG:
    """
    ImageSI Triplet fine-tuning: train the image model with a triplet margin
    loss whose positives/negatives are chosen from 2-D coordinates in a CSV.
    """

    # ChangeHere
    # File train() writes the best weights to when no save_path is given.
    # Previously used alternative:
    # '/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/Andromeda_pth/best_andromeda_model_pods_num_mix.pth'
    DEFAULT_SAVE_PATH = '/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/Andromeda_pth/best_andromeda_model_animal_mouth_any.pth'

    def __init__(self, data_csv_path, image_dir, image_size=224, batch_size=256, learning_rate=1e-5, num_epochs=10, triplet_threshold=80): # ChangeHere
        """
        Initialize the Andromeda_IMG class.

        Parameters:
        - data_csv_path (str): Path to the CSV file containing data information.
        - image_dir (str): Directory containing the images.
        - image_size (int): Size to which images will be resized.
        - batch_size (int): Number of samples in each batch for training.
        - learning_rate (float): Learning rate for the optimizer.
        - num_epochs (int): Number of training epochs.
        - triplet_threshold (float): Per-axis coordinate gap below which another
          image counts as a positive for an anchor (at or above: negative).

        Returns:
        - None
        """
        self.data_csv_path = data_csv_path
        self.image_dir = image_dir
        self.image_size = image_size
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        self.triplet_threshold = triplet_threshold
        self.device = self.get_default_device()
        self.train_dataset = self.get_train_dataset()
        self.train_dl = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
        # train() moves every batch to self.device, so the model has to live
        # there too; this is a no-op when it already does.
        self.resnet_model = self.create_resnet_model().to(self.device)
        self.optimizer = torch.optim.Adam(self.resnet_model.parameters(), lr=self.learning_rate)
        self.criterion = self.TripletLoss()

    def get_default_device(self):
        """
        Get the default device for PyTorch based on GPU availability.

        Returns:
        - torch.device: Device (CPU or GPU).
        """
        if torch.cuda.is_available():
            return torch.device('cuda')
        else:
            return torch.device('cpu')

    @staticmethod
    def _sample_triplet_indices(coordinates, item, threshold=80, max_attempts=1000):
        """
        Pick row indices (anchor, positive, negative) for one triplet.

        A row is a positive for the anchor when its largest per-axis coordinate
        gap (Chebyshev distance) is below ``threshold`` and it is not the
        anchor itself; it is a negative when that gap is at least ``threshold``.
        If the requested anchor lacks a positive or a negative, a random other
        anchor is tried, at most ``max_attempts`` times in total.

        For simulated evaluation the notebook previously used a different
        rule: positives were the rows whose |dx| equals |dy| relative to the
        anchor, negatives all remaining rows.

        Parameters:
        - coordinates (np.ndarray): One coordinate row per image.
        - item (int): Index of the requested anchor.
        - threshold (float): Positive/negative cut-off described above.
        - max_attempts (int): Upper bound on the number of anchors tried.

        Returns:
        - tuple: (anchor_index, positive_index, negative_index).

        Raises:
        - RuntimeError: If no usable anchor was found within max_attempts.
        """
        row_ids = np.arange(len(coordinates))
        for _ in range(max_attempts):
            gaps = np.max(np.abs(coordinates - coordinates[item]), axis=1)
            positive_candidates = np.where((gaps < threshold) & (row_ids != item))[0]
            negative_candidates = np.where(gaps >= threshold)[0]
            if len(positive_candidates) > 0 and len(negative_candidates) > 0:
                return item, random.choice(positive_candidates), random.choice(negative_candidates)
            # Retry with a different anchor instead of recursing without bound.
            item = np.random.randint(len(coordinates))
        raise RuntimeError(
            "Could not build a triplet: no anchor with both a positive and a "
            "negative sample was found for threshold {}".format(threshold))

    @staticmethod
    def _snapshot_state(model):
        """
        Return an independent copy of ``model.state_dict()``.

        ``state_dict()`` hands out tensors that share storage with the live
        parameters, and ``dict.copy()`` is shallow, so a snapshot taken that
        way silently follows every later optimizer step.
        """
        import copy
        return copy.deepcopy(model.state_dict())

    def get_train_dataset(self):
        """
        Build the triplet dataset from the CSV file.

        The CSV must provide the image file name (relative to ``image_dir``)
        in its first column and the columns ``x`` and ``y``.

        Returns:
        - Dataset object yielding (anchor, positive, negative) image tensors.
        """
        train_data = pd.read_csv(self.data_csv_path)

        class AnimalDataset_Triplet():
            def __init__(self, df, path, transform=None, threshold=80):
                self.df = df
                self.path = path
                self.transform = transform
                self.threshold = threshold
                self.coordinates = self.df[['x', 'y']].to_numpy()

            def __len__(self):
                return len(self.df)

            def _load_rgb(self, image_name):
                # Close the file handle as soon as the pixels are converted.
                with Image.open(os.path.join(self.path, image_name)) as img:
                    return img.convert('RGB')

            def __getitem__(self, item):
                # The anchor may differ from the requested item when that item
                # has no positive or no negative partner.
                item, positive_item, negative_item = Andromeda_IMG._sample_triplet_indices(
                    self.coordinates, item, self.threshold)

                # Get the file names for the anchor, positive, and negative images
                anchor_image_name = self.df.iloc[item, 0]
                positive_image_name = self.df.iloc[positive_item, 0]
                negative_image_name = self.df.iloc[negative_item, 0]

                anchor_img = self._load_rgb(anchor_image_name)
                positive_img = self._load_rgb(positive_image_name)
                negative_img = self._load_rgb(negative_image_name)

                if self.transform is not None:
                    anchor_img = self.transform(anchor_img)
                    positive_img = self.transform(positive_img)
                    negative_img = self.transform(negative_img)

                return anchor_img, positive_img, negative_img

        # Resize, convert to tensor and normalise with the ImageNet mean/std.
        transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        train_dataset = AnimalDataset_Triplet(df=train_data,
                                              path=self.image_dir,
                                              transform=transform,
                                              threshold=self.triplet_threshold)

        return train_dataset

    class TripletLoss(nn.Module):
        def __init__(self, margin=1.0):
            """
            Initialize the TripletLoss.

            Parameters:
            - margin (float): Margin value for the triplet loss.

            Returns:
            - None
            """
            super(Andromeda_IMG.TripletLoss, self).__init__()
            self.margin = margin

        def calc_euclidean(self, x1, x2):
            """
            Calculate the squared Euclidean distance between two sets of feature vectors.

            Parameters:
            - x1 (torch.Tensor): Feature vectors from the first set.
            - x2 (torch.Tensor): Feature vectors from the second set.

            Returns:
            - torch.Tensor: Squared Euclidean distance between x1 and x2.
            """
            # Sum over dimension 1, leaving one value per row of the batch.
            return (x1 - x2).pow(2).sum(1)

        def forward(self, anchor, positive, negative):
            """
            Calculate the triplet loss based on anchor, positive, and negative feature vectors.

            Parameters:
            - anchor (torch.Tensor): Feature vectors of anchor samples.
            - positive (torch.Tensor): Feature vectors of positive samples.
            - negative (torch.Tensor): Feature vectors of negative samples.

            Returns:
            - torch.Tensor: Triplet loss value (mean over the batch).
            """
            distance_positive = self.calc_euclidean(anchor, positive)
            distance_negative = self.calc_euclidean(anchor, negative)
            # relu keeps the loss non-negative: it is zero once the positive is
            # closer to the anchor than the negative by at least the margin.
            losses = torch.relu(distance_positive - distance_negative + self.margin)
            return losses.mean()

    def create_resnet_model(self):
        """
        Create the model that is trained with the triplet loss.

        Returns:
        - nn.Module: Currently the ``model`` attribute of a fresh CustomResNetModel.
        """
        class ResNet_Triplet(nn.Module):
            """Alternative ResNet-18 with a small embedding head (not used by default)."""

            def __init__(self):
                super().__init__()
                self.Model = resnet18(pretrained=True)
                # Input size of the original final fully connected layer.
                num_filters = self.Model.fc.in_features
                # Replace the fc layer: num_filters -> 512 -> LeakyReLU -> 10.
                self.Model.fc = nn.Sequential(
                    nn.Linear(num_filters, 512),
                    nn.LeakyReLU(),
                    nn.Linear(512, 10))
                # Final linear layer reducing the 10 features to 2 outputs.
                self.Triplet_Loss = nn.Sequential(
                    nn.Linear(10, 2))

            def forward(self, x):
                features = self.Model(x)
                triplets = self.Triplet_Loss(features)
                return triplets

        #resnet_model = ResNet_Triplet()
        resnet_model = CustomResNetModel().model # ChangeHere
        return resnet_model

    def train(self, early_stopping=False, save_path=None):
        """
        Train the model with the triplet loss and save the best weights.

        Parameters:
        - early_stopping (bool): Stop after 10 consecutive epochs whose last
          batch had a loss of exactly zero.
        - save_path (str): Where the best state dict is written; defaults to
          DEFAULT_SAVE_PATH.

        Returns:
        - None
        """
        if save_path is None:
            save_path = self.DEFAULT_SAVE_PATH

        best_loss = None
        best_model_state = None
        zero_loss_epochs = 0  # Counter to track consecutive epochs with zero loss

        for epoch in range(self.num_epochs):
            running_loss = []

            for anchor_img, positive_img, negative_img in self.train_dl:
                anchor_img = anchor_img.to(self.device)
                positive_img = positive_img.to(self.device)
                negative_img = negative_img.to(self.device)

                self.optimizer.zero_grad()

                anchor_out = self.resnet_model(anchor_img)
                positive_out = self.resnet_model(positive_img)
                negative_out = self.resnet_model(negative_img)

                loss = self.criterion(anchor_out, positive_out, negative_out)
                loss.backward()
                self.optimizer.step()

                loss_value = loss.item()
                running_loss.append(loss_value)
                print("Epoch: {}/{} — Loss: {:.10f}".format(epoch + 1, self.num_epochs, np.mean(running_loss)))

                # As before, the snapshot is taken right after the optimizer
                # step of the batch with the lowest strictly positive loss.
                if loss_value > 0 and (best_loss is None or loss_value < best_loss):
                    best_loss = loss_value
                    best_model_state = self._snapshot_state(self.resnet_model)

            # Check for early stopping condition
            if early_stopping:
                # An epoch without any batch does not count as a zero-loss epoch.
                if running_loss and running_loss[-1] == 0:
                    zero_loss_epochs += 1
                else:
                    zero_loss_epochs = 0

                # Check if consecutive zero loss epochs reach 10
                if zero_loss_epochs >= 10:
                    print("Early stopping due to consecutive zero loss epochs.")
                    break

        if best_loss is not None and best_loss > 0:
            print("Best loss:", best_loss)
            torch.save(best_model_state, save_path)

## Buttons

In [76]:
def create_size_slider_button(size_slider, ax):
    """
    Create the button that redraws the plot with the slider's image size.
    @parameters:
        size_slider: slider whose current value is used as the image size
        ax: axes that are inspected for the current toggles and redrawn
    @return[button]: return image size adjust button
    """
    size_apply_button = widgets.Button(
        description='Apply Slider Size',
        style=ButtonStyle(button_color='lightblue'))

    def size_slider_button_clicked(change):
        # Read the toggle state from the axes this button was created for.
        # (This used to consult the notebook global plot_ax for the image
        # toggle while using ax for everything else.)
        first_artist = ax.dragpoint.artists[0]
        # Images count as shown when the artist carries an 'ab' attribute,
        # titles when its text is non-empty.
        toggle_image = hasattr(first_artist, 'ab')
        toggle_tittle = first_artist.text.get_text() != ''
        # df_2D is the notebook-level projection.
        draw_plot(ax,
                  df_2D,
                  toggle_tittle,
                  toggle_image,
                  imgSize=size_slider.value)

    size_apply_button.on_click(size_slider_button_clicked)
    return size_apply_button

## reset button
def create_reset_buttons(ax, sliders=None):
    """
    Create the button that recomputes the projection and redraws the plot.
    @parameters:
        ax: axes that are inspected for the current toggles and redrawn
        sliders: currently unused (weight sliders are no longer reset here)
    @return[button]: return reset plot button
    """
    reset_button = widgets.Button(description='Reset Plot',
                                  style=ButtonStyle(button_color='salmon'))

    def reset_button_clicked(change):
        global df_2D  # the notebook-level projection is replaced
        df_2D = dimension_reduction(normalized_df)

        # Read the toggle state from the axes this button was created for.
        # (This used to consult the notebook global plot_ax for the image
        # toggle while using ax for everything else.)
        first_artist = ax.dragpoint.artists[0]
        toggle_image = hasattr(first_artist, 'ab')
        toggle_tittle = first_artist.text.get_text() != ''
        # Redraw the plot; the image size comes from the notebook-level slider.
        draw_plot(ax,
                  df_2D,
                  toggle_tittle,
                  toggle_image,
                  imgSize=size_slider.value)

    reset_button.on_click(reset_button_clicked)

    return reset_button

## Visual Back Prop button
def create_visual_explainations(ax):
    """
    Build the 'Visual Explainations' button and the output area it renders into.

    When this function is called, a ResNet-18 is created, the checkpoint at the
    hard-coded path below is loaded into it, and the network is wrapped in
    ``VBP.ResnetVisualizer``. Clicking the returned button runs that visualizer
    on the images of the points currently selected in ``ax`` and shows and
    saves one map per image.

    @parameters:
        ax: axes whose ``ax.circles`` items expose ``index`` and ``selected``
    @return[button, output]: the button and the Output widget that receives
        the maps drawn by the click handler
    """

    print_button = widgets.Button(description='Visual Explainations')
    print_output = widgets.Output()

    def transformer():
        """
        transformer for image dataset, using the ImageNet mean and std since Resnet18 is pre-trained on ImageNet
        """
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def plot_GradCAM_map(image, model, gradcam, target_class=0):
        """
        Show ``image`` next to the map returned by ``gradcam.generate``.

        Not called by the click handler below; ``model`` is accepted but unused.
        """
        # Generate the GradCAM visualization for the provided image
        visualization = gradcam.generate(image, target_class=target_class)

        # Drop the batch axis and move channels last for matplotlib
        image_np = image.squeeze(0).cpu().detach().numpy().transpose(1, 2, 0)
        visualization_np = visualization.squeeze(0).cpu().detach().numpy()

        # Rescale GradCAM visualization to match image size
        visualization_resized = np.array(
            Image.fromarray((visualization_np[0] * 255).astype(np.uint8)).resize(
                (image_np.shape[1], image_np.shape[0])))

        plt.figure(figsize=(6, 6))
        plt.subplot(1, 2, 1)
        plt.imshow(image_np, interpolation='bilinear')
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(visualization_resized, cmap='viridis', interpolation='bilinear')
        plt.axis('off')

        plt.tight_layout()
        plt.show()

    def plot_VBP_map(loader, model_bp, counter):
        """
        Show and save one Visual Backprop map per batch yielded by ``loader``.

        ``counter`` is only used as a tag inside the saved file names.
        """
        if len(loader) == 0:
            print('Select points in the plot to see details here')
            return

        plt.figure(figsize=(6, 5))  # Adjust figure size as needed

        for i, img in enumerate(loader):
            with torch.no_grad():
                _, vis, _ = model_bp(img)

            # First item of the batch, first channel, rescaled to [0, 1]
            vis = vis[0].numpy().transpose(1, 2, 0)[:, :, 0]
            vis = np.interp(vis, [vis.min(), vis.max()], [0, 1])

            plt.imshow(vis, cmap='viridis', interpolation='bilinear')
            plt.axis('off')
            plt.axis('tight')  # Adjust axis limits to fit content tightly

            # Save each visualization separately
            filename = f"/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/Triplet_VBP/AndromedaSI_noDR_{counter}_{i}.png"
            plt.savefig(filename, bbox_inches='tight', pad_inches=0)  # Save with tight bounding box
            plt.show()

    # Build the network the visualizer wraps and load the saved weights into it.
    resnet_ft = models.resnet18(pretrained=True)
    state_dict = torch.load('/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/Andromeda_pth/best_andromeda_model_animal_mouth_noDR.pth')

    # Strip the 'Model.' prefix so the keys match the plain ResNet key names.
    # NOTE(review): 'Triplet_Loss.' keys are passed through unchanged and
    # load_state_dict is strict by default, so any such key would be rejected
    # unless the ResNet has a parameter with that exact name - confirm that
    # the checkpoint contains none.
    mapped_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith('Model.'):
            mapped_state_dict[key[6:]] = value
        elif key.startswith('Triplet_Loss.'):
            mapped_state_dict[key] = value

    resnet_ft.load_state_dict(mapped_state_dict)

    model_vbp_after = VBP.ResnetVisualizer(resnet_ft)

    # Settings for the Deep Feature Factorization helpers below; the click
    # handler does not use them at the moment.
    n_components = 2
    top_k = 2

    def create_labels(concept_scores, top_k):
        """ Create a list with the ImageNet category names of the top scoring categories """
        imagenet_categories_url = "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
        # NOTE(security): eval() executes whatever text the URL returns. If the
        # file is a plain dict literal, ast.literal_eval is the safe parser -
        # confirm the file format before switching.
        labels = eval(requests.get(imagenet_categories_url).text)
        # Per row, category indices sorted by descending score, top_k kept
        concept_categories = np.argsort(concept_scores, axis=1)[:, ::-1][:, :top_k]
        concept_labels_topk = []
        for concept_index in range(concept_categories.shape[0]):
            categories = concept_categories[concept_index, :]
            concept_labels = []
            for category in categories:
                score = concept_scores[concept_index, category]
                label = f"{labels[category].split(',')[0]}:{score:.2f}"
                concept_labels.append(label)
            concept_labels_topk.append("\n".join(concept_labels))
        return concept_labels_topk

    def get_img_path(file_path):
        """A function that gets a file path of an image,
        and returns a numpy image and a preprocessed
        torch tensor ready to pass to the model """

        img = np.array(Image.open(file_path))
        rgb_img_float = np.float32(img) / 255
        input_tensor = preprocess_image(rgb_img_float,
                                        mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])

        return img, rgb_img_float, input_tensor

    def plot_DFF(model, target_layer, img_path, n_components, top_k):
        """
        Return the image at ``img_path`` stacked side by side with its
        Deep Feature Factorization overlay, scaled down to at most 400 rows.
        """
        img, rgb_img_float, input_tensor = get_img_path(img_path)
        dff = DeepFeatureFactorization(model=model, target_layer=target_layer, computation_on_concepts=model.classifier)
        concepts, batch_explanations, concept_outputs = dff(input_tensor, n_components)

        concept_outputs = torch.softmax(torch.from_numpy(concept_outputs), axis=-1).numpy()
        concept_label_strings = create_labels(concept_outputs, top_k=top_k)
        visualization = show_factorization_on_image(rgb_img_float,
                                                    batch_explanations[0],
                                                    image_weight=0.3,
                                                    concept_labels=concept_label_strings)

        result = np.hstack((img, visualization))

        max_height = 400
        if result.shape[0] > max_height:
            scaling_factor = max_height / result.shape[0]
            result = cv2.resize(result, (0, 0), fx=scaling_factor, fy=scaling_factor)

        return result

    def get_img_list(model, target_layer, files, n_components, top_k):
        """Display the ``plot_DFF`` result for every path in ``files``."""
        for img_path in files:
            display(Image.fromarray(plot_DFF(model, target_layer, img_path, n_components, top_k)))

    def feature_button_clicked(change):
        """Draw the Visual Backprop maps of the selected points into the output widget."""
        print_output.clear_output()
        # Positions of the selected circles -> index labels -> image paths
        c_index = [c.index for c in ax.circles if c.selected]
        s_index = list(normalized_df.iloc[c_index, :].index)
        files = [imageIndex_path_dict[i] for i in s_index]

        ds = FilenameDataset(files, transformer())
        loader = DataLoader(ds)
        with print_output:
            plot_VBP_map(loader, model_vbp_after, "animal_mouth")

    # Without this registration the button does nothing, and without the
    # return the caller has no way to display the button or its output.
    print_button.on_click(feature_button_clicked)

    return print_button, print_output
            
## inverse DR button
def create_inverse_button(ax, fig_show=False):
    """
    @return[button]: return 'Learn New Weights' and 'Update Projections' buttons
    """
    # inverse_button = widgets.Button(
    #     description='Learn New Weights',
    #     style=ButtonStyle(button_color='darkseagreen'))
    
    train_model_button = widgets.Button(
        description='Fine-tune Model',
        style=ButtonStyle(button_color='darkseagreen'))
    
    evaluation_simulator_button = widgets.Button(
    description='Evaluation Simulator',
    style=ButtonStyle(button_color='darkseagreen'))
    
    # copy_button = widgets.Button(
    #     description='Update Projections',
    #     style=ButtonStyle(button_color='darkseagreen'))
    
#     def inverse_button_clicked(change):
#         # Check minimum number of points moved
#         n = sum([i.selected for i in ax.dragpoint.artists])
#         if n < 2:
#             print(
#                 'Need to select or move at least 2 points in the plot first.')
#             return

#         # Get selected data points
#         data2Dnew = pd.DataFrame(
#             [c.center for c in ax.dragpoint.artists if c.selected],
#             columns=['x', 'y'],
#             index=[c.label for c in ax.dragpoint.artists if c.selected])
#         global normalized_df
#         dataHDpart = normalized_df.loc[data2Dnew.index]

#         # Learn new weights
#         global weights
#         weights = inverse_DR(dataHDpart, data2Dnew)
# #         print(weights)

#     inverse_button.on_click(inverse_button_clicked)

    def train_model_button_clicked(change):
        """
        Fine-tune the model on the selected points and plot the new projection.

        Writes the selected images and their coordinates to the training CSV,
        trains ``Andromeda_IMG`` on it, reloads the saved checkpoint,
        re-extracts features for all images, projects them to 2D, prints the
        doubled silhouette score and saves the projection as CSV and PNG.

        @parameters:
            change: click argument supplied by the widget (unused)
        """
        # Training needs at least three selected/moved points.
        selected_total = sum(artist.selected for artist in ax.dragpoint.artists)
        if selected_total < 3:
            print(
                'Need to select or move at least 3 points in the plot first.')
            return

        # Selected images and the coordinates reported for them.
        chosen_images = user_defined_train_data(ax)
        moved_xy = get_updated_coordinates(ax)
        xs, ys = zip(*moved_xy.values())

        # '<folder>/<name>' keys become '<folder>/<name>.jpg' file names.
        image_names = []
        for key in chosen_images.keys():
            parts = key.split('/')
            image_names.append(f"{parts[0]}/{parts[1]}.jpg")

        train_df = pd.DataFrame({'image': image_names, 'x': list(xs), 'y': list(ys)})
        print("User defined train data is:", train_df)

        # Persist the training set; the trainer reads it back from this path.
        train_csv_path = "/Users/jiayuelin/InfoVis/Datasets/train_data.csv"
        train_df.to_csv(train_csv_path, index=False)

        andromeda = Andromeda_IMG(data_csv_path=train_csv_path, image_dir=imgFolder)
        andromeda.train()
        print("Training complete")

        # Put the saved checkpoint into the trainer's network, in eval mode.
        best_model = andromeda.resnet_model
        checkpoint = torch.load('/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/Andromeda_pth/best_andromeda_model_animal_mouth.pth')  # ChangeHere
        best_model.load_state_dict(checkpoint)
        best_model.eval()

        # Re-extract features for every image and project them to 2D.
        features_df, _ = feature_extractor_custom(best_model, img_loader, imageIndex_path_dict)
        projection = dimension_reduction(features_df)
        projection['label'] = projection.index.str.split('/').str[0]
        projection = projection[['x', 'y', 'label']]

        adjusted_silhouette = 2 * silhouette_score(projection[['x', 'y']], projection['label'])
        print(f"The Adjusted Silhouette Score after simulated dragging 8 points per label is: {adjusted_silhouette}")

        # Save the 2D projection as CSV
        csv_save_path = '/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/Triplet-Margin/animal_mouth_DR_8_EXP.csv'  # ChangeHere
        projection.to_csv(csv_save_path, index=True)

        # Create and save the plot. The image toggle is read from the new
        # plot, the title toggle from the axes this button was created for.
        new_plot_ax = create_plot(projection)
        fig_save_path = '/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/Triplet-Margin/animal_mouth_DR_8_EXP.png'  # ChangeHere
        show_images = hasattr(new_plot_ax.dragpoint.artists[0], 'ab')
        show_titles = ax.dragpoint.artists[0].text.get_text() != ''
        draw_plot(new_plot_ax, projection, show_titles, show_images, size_slider.value)
        plt.savefig(fig_save_path, bbox_inches='tight', pad_inches=0.1)

    train_model_button.on_click(train_model_button_clicked)
    
    # def train_model_button_clicked(change):
    #     # Check minimum number of points moved
    #     n = sum([i.selected for i in ax.dragpoint.artists])
    #     if n < 3:
    #         print(
    #             'Need to select or move at least 3 points in the plot first.')
    #         return
    #     #### for fine-tuning ####
    #     # using the same model in backprop file to make sure
    #     # the forward and backward process happens to the same network
        
    #     #Get selected data points
    #     selected_imageIndex_path_dict = user_defined_train_data(ax)
    #     updated_coordinates = get_updated_coordinates(ax)
    #     #print("Updated Coordinates:", updated_coordinates)

    #     x_values, y_values = zip(*updated_coordinates.values())
    #     #print("X values:", x_values)
    #     #print("Y values:", y_values)

        # data = {
        #     'image': [f"{key.split('/')[0]}/{key.split('/')[1]}.jpg" for key in selected_imageIndex_path_dict.keys()],
        #     'x': list(x_values),
        #     'y': list(y_values)
        # }
        
        # train_df = pd.DataFrame(data)

        # print("User defined train data is:", train_df)

        # #Define the file path 
        # file_path = "/Users/jiayuelin/InfoVis/Datasets/train_data.csv"

        # # Save the DataFrame as a CSV file
        # train_df.to_csv(file_path, index=False)
        
        # # Create an instance of the Andromeda_IMG class
        # andromeda = Andromeda_IMG(data_csv_path="/Users/jiayuelin/InfoVis/Datasets/train_data.csv", image_dir=imgFolder)
        # andromeda.train()

        # # Load the best model weights into the andromeda's custom model
        # best_model = andromeda.resnet_model
        # best_model.load_state_dict(torch.load('/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/Andromeda_pth/best_andromeda_model_pods_num_mix_noDR.pth')) # ChangeHere
        # best_model.eval()  # Set the model to evaluation mode

        # # Use the best model to extract features
        # update_normalized_df, model_output = feature_extractor_triplet(best_model, img_loader, imageIndex_path_dict)
        # update_normalized_df.rename(columns={'1': 'x', '2': 'y'}, inplace=True)
        
    #     update_df_2D = update_normalized_df
    #     update_df_2D['label'] = update_df_2D.index.str.split('/').str[0]
    #     #print("The original normalized_df is:", df_2D)
    #     update_df_2D = update_df_2D[['x', 'y', 'label']]
        
    #     silhouette = silhouette_score(update_df_2D[['x', 'y']], update_df_2D['label'])
    #     #print(f"The Silhouette Score after simulated dragging {n_points_per_label} points per label is: {silhouette}")
    #     adjusted_silhouette = silhouette * 2
    #     print(f"The Adjusted Silhouette Score after simulated dragging 8 points per label is: {adjusted_silhouette}")

    #     # Save the 2D projection as both CSV
    #     csv_save_path = f'/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/Triplet-Margin/pods_num_mix_noDR_8_EXP.csv' # ChangeHere
    #     update_df_2D.to_csv(csv_save_path, index=True)

    #     # Create and save the plot
    #     plot_ax = create_plot(update_df_2D)
    #     fig_save_path = f'/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/Triplet-Margin/pods_num_mix_noDR_8_EXP.png' # ChangeHere
    #     toggle_image = hasattr(plot_ax.dragpoint.artists[0], 'ab')
    #     toggle_title = ax.dragpoint.artists[0].text.get_text() != ''
    #     draw_plot(plot_ax, update_df_2D, toggle_title, toggle_image, size_slider.value)
    #     plt.savefig(fig_save_path, bbox_inches='tight', pad_inches=0.1)

    # train_model_button.on_click(train_model_button_clicked)
        
    # def train_model_button_clicked(change):
    #     # Check minimum number of points moved
    #     n = sum([i.selected for i in ax.dragpoint.artists])
    #     if n < 2:
    #         print(
    #             'Need to select or move at least 2 points in the plot first.')
    #         return
    #     #### for fine-tuning ####
    #     # using the same model in backprop file to make sure
    #     # the forward and backward process happens to the same network
    #     # Instantiate Samples and obtain moved samples
    
    #     moved_samples = {}
    #     for c in ax.dragpoint.artists:
    #         if c.selected:
    #             moved_samples[c.label] = list(c.center)
    #     print("The moved samples are:", moved_samples)

    #     samples = Samples(imgFolder, sampleSizePerCat, folderName, total_img)
    #     # Get samples for training from moved samples
    #     samples_for_training = samples.get_samples_for_training(moved_samples)
    #     #print("The samples for training are:", samples_for_training)
        
    #     # Instantiate the Deep_crop class
    #     deep_crop = Deep_crop()

    #     # if len(samples_for_training.img_ids) > 0:
    #     print("The samples for training are:", samples_for_training.img_ids.keys())
    #     # Train the model using the obtained samples for training
    #     print("Start training...")
    #     deep_crop.train(samples_for_training)

    #     #Load the best model weights into the andromeda's custom model
    #     best_model = deep_crop.model
    #     #print("Initial deep_crop.model state_dict keys:", deep_crop.model.state_dict().keys())
    #     state_dict = torch.load('/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/Deep_Crop_pth/best_deep_crop_animal_mouth_model.pth')
    #     best_model.load_state_dict(state_dict)
    #     #print("Loaded best_model state_dict keys:", best_model.state_dict().keys())
    #     best_model.eval()  # Set the model to evaluation mode

    #     # Use the best model to extract features
    #     update_normalized_df, model_output = feature_extractor_custom(best_model, img_loader, imageIndex_path_dict)
    #     #update_normalized_df = df_preprocess(update_normalized_df)
        
    #     update_df_2D = dimension_reduction(update_normalized_df)
    #     update_df_2D['label'] = update_df_2D.index.str.split('/').str[0]
    #     #print("The original normalized_df is:", df_2D)
    #     update_df_2D = update_df_2D[['x', 'y', 'label']]
    #     #print("The updated df is:", update_df_2D)
        
    #     silhouette = silhouette_score(update_df_2D[['x', 'y']], update_df_2D['label'])
    #     print("The Silhouette Score after interaction is: ", silhouette)
    #     adjusted_silhouette = silhouette * 2
    #     print("The Adjusted Silhouette Score after interaction is: ", adjusted_silhouette)
        
    #     # Save the 2D projection as both CSV
    #     csv_save_path = f'/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/User-Defined/animal_mouth_DR_8_Exp.csv'
    #     update_df_2D.to_csv(csv_save_path, index=True)

    #     # Create and save the plot
    #     plot_ax = create_plot(update_df_2D)
    #     fig_save_path = f'/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/User-Defined/animal_mouth_DR_8_Exp.png'
    #     toggle_image = hasattr(plot_ax.dragpoint.artists[0], 'ab')
    #     toggle_title = ax.dragpoint.artists[0].text.get_text() != ''
    #     draw_plot(plot_ax, update_df_2D, toggle_title, toggle_image, size_slider.value)
    #     plt.savefig(fig_save_path, bbox_inches='tight', pad_inches=0.1)

    # train_model_button.on_click(train_model_button_clicked)

    def simulate_dragging_user(data_2D, n_points_per_label):
        """
        Simulate dragging randomly sampled points of each label onto their mean.

        @parameters:
            data_2D: DataFrame with 'x', 'y' and 'label' columns whose index
                entries are '<label>/<image id>' strings
            n_points_per_label: number of points sampled from every label
        @return[DataFrame, dict]: a copy of data_2D with the sampled points
            moved to their mean position, and a dict mapping
            '<label>/<image id>' to that [x, y] position
        """
        print(f"Simulating dragging with {n_points_per_label} points per label")

        dragged_data = data_2D.copy()
        moved_samples = {}

        for label in data_2D['label'].unique():
            # Coordinates of the current label, then a random subset of them
            candidates = data_2D.loc[data_2D['label'] == label, ['x', 'y']]
            picked = candidates.sample(n=n_points_per_label)

            # Nothing usable if every sampled coordinate is NaN
            if picked.isnull().all().all():
                continue

            # Image ids are the part of the index after the label
            img_ids = picked.index.str.split('/').str[1]

            # Target position: NaN-skipping mean of the sampled points
            centroid = picked.mean(skipna=True)
            if centroid.isnull().any():
                continue

            # "Drag" every sampled point onto the mean position
            dragged_data.loc[picked.index, ['x', 'y']] = centroid.values

            for img_id in img_ids:
                moved_samples[f"{label}/{img_id}"] = list(centroid.values)

        print("The moved samples are:", moved_samples)
        return dragged_data, moved_samples
    
    def simulate_dragging_triplet(data_2D, n_points_per_label, file_path='/Users/jiayuelin/InfoVis/Datasets/train_data.csv'):
        """
        Simulate dragging sampled points of each label onto their mean and
        record the result as a training table.

        @parameters:
            data_2D: DataFrame with 'x', 'y' and 'label' columns whose index
                entries are '<label>/<image id>' strings
            n_points_per_label: number of points sampled from every label
            file_path: CSV path the training table is written to; a falsy
                value skips writing
        @return[DataFrame, DataFrame]: a copy of data_2D with the sampled
            points moved to their mean position, and a table with columns
            'image' ('<label>/<image id>.jpg'), 'x' and 'y'
        """
        print(f"Simulating dragging with {n_points_per_label} points per label")

        dragged_data = data_2D.copy()
        columns = {'image': [], 'x': [], 'y': []}

        for label in data_2D['label'].unique():
            # Coordinates of the current label, then a random subset of them
            candidates = data_2D.loc[data_2D['label'] == label, ['x', 'y']]
            picked = candidates.sample(n=n_points_per_label)

            # Nothing usable if every sampled coordinate is NaN
            if picked.isnull().all().all():
                continue

            # Image ids are the part of the index after the label
            img_ids = picked.index.str.split('/').str[1]

            # Target position: NaN-skipping mean of the sampled points
            centroid = picked.mean(skipna=True)
            if centroid.isnull().any():
                continue

            # "Drag" every sampled point onto the mean position
            dragged_data.loc[picked.index, ['x', 'y']] = centroid.values

            # One training row per sampled image, all at the mean position
            count = len(img_ids)
            columns['image'] += [f"{label}/{img_id}.jpg" for img_id in img_ids]
            columns['x'] += [centroid['x']] * count
            columns['y'] += [centroid['y']] * count

        moved_samples_df = pd.DataFrame(columns)
        print("The moved samples are:", moved_samples_df)

        # If file_path is provided, save moved_samples to a CSV file
        if file_path:
            moved_samples_df.to_csv(file_path, index=False)

        return dragged_data, moved_samples_df

    def get_update_df_user(img_loader, imageIndex_path_dict, samples_for_training):
        """
        Train ``Deep_crop`` on the given samples and return the new 2D projection.

        @parameters:
            img_loader: loader handed to ``feature_extractor_custom``
            imageIndex_path_dict: mapping handed to ``feature_extractor_custom``
            samples_for_training: samples passed to ``Deep_crop.train``
        @return[DataFrame]: output of ``dimension_reduction`` with an added
            'label' column taken from the index text before the first '/'
        """
        print("Training the model with moved samples")
        trainer = Deep_crop()
        trainer.train(samples_for_training)
        print("Training complete")

        # Put the saved checkpoint into the trained network, in eval mode
        best_model = trainer.model
        best_model.load_state_dict(
            torch.load('/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/Deep_Crop_pth/best_deep_crop_animal_mouth_model.pth'))
        best_model.eval()

        # Features from the trained model, then their 2D representation
        features_df, _ = feature_extractor_custom(best_model, img_loader, imageIndex_path_dict)
        projection = dimension_reduction(features_df)
        projection['label'] = projection.index.str.split('/').str[0]

        return projection
    
    def get_update_df_triplet(img_loader, imageIndex_path_dict):
        """
        Train ``Andromeda_IMG`` on the saved training CSV and return the model
        output as the 2D projection (no dimension reduction step).

        @parameters:
            img_loader: loader handed to ``feature_extractor_triplet``
            imageIndex_path_dict: mapping handed to ``feature_extractor_triplet``
        @return[DataFrame]: columns 'x', 'y' (renamed from '1', '2') and
            'label' (index text before the first '/')
        """
        print("Training the model with moved samples")

        trainer = Andromeda_IMG(data_csv_path='/Users/jiayuelin/InfoVis/Datasets/train_data.csv', image_dir=imgFolder)
        trainer.train()
        print("Training complete")

        # Put the saved checkpoint into the trainer's network, in eval mode
        best_model = trainer.resnet_model
        best_model.load_state_dict(
            torch.load('/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/Andromeda_pth/best_andromeda_model_pods_maturity_all_noDR.pth'))
        best_model.eval()

        # The extractor's columns '1' and '2' are used directly as coordinates
        features_df, _ = feature_extractor_triplet(best_model, img_loader, imageIndex_path_dict)
        features_df.rename(columns={'1': 'x', '2': 'y'}, inplace=True)
        features_df['label'] = features_df.index.str.split('/').str[0]

        return features_df[['x', 'y', 'label']]
    
    def get_update_df_triplet_DR(img_loader, imageIndex_path_dict):
        """
        Retrain via Andromeda_IMG, reload a saved checkpoint, extract features
        and project them to 2D with dimension_reduction.
        @parameters:
            img_loader: passed through unchanged to feature_extractor_custom
            imageIndex_path_dict: passed through unchanged to feature_extractor_custom
        @return[dataframe]: columns 'x', 'y', 'label'; 'label' is the part of
            each index entry before the first '/'
        """
        train_csv = '/Users/jiayuelin/InfoVis/Datasets/train_data.csv'
        checkpoint = '/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/Andromeda_pth/best_andromeda_model_animal_human.pth'

        print("Training the model with moved samples")
        trainer = Andromeda_IMG(data_csv_path=train_csv, image_dir=imgFolder)
        trainer.train()
        print("Training complete")

        # Overwrite the trainer's network weights with the saved checkpoint
        # (presumably the one train() just wrote -- TODO confirm) and switch
        # the network to evaluation mode before extracting features.
        tuned_model = trainer.resnet_model
        saved_weights = torch.load(checkpoint)
        tuned_model.load_state_dict(saved_weights)
        tuned_model.eval()

        # Only the dataframe part of the extractor's result is used.
        features_df, _ = feature_extractor_custom(
            tuned_model, img_loader, imageIndex_path_dict)

        # Unlike get_update_df_triplet, the features go through
        # dimension_reduction to obtain the 'x'/'y' coordinates.
        projected_df = dimension_reduction(features_df)
        projected_df['label'] = projected_df.index.str.split('/').str[0]

        return projected_df[['x', 'y', 'label']]
    
    # def evaluation_simulator_button_clicked(change):
    #     # After training the model and obtaining the 2D representation
    #     for _ in range(1):    
    #         for n_points_per_label in range(20, 21):
    #             #n_points_per_label = 60
    #             # Simulate dragging using the deep_crop model
    #             dragged_data, moved_samples_df = simulate_dragging_triplet(df_2D, n_points_per_label)

    #             update_df_2D = get_update_df_triplet_DR(img_loader, imageIndex_path_dict)
                
    #             # # Select only the 'x' and 'y' columns for normalization
    #             # data_to_normalize = update_df_2D[['x', 'y']]

    #             # # Initialize the StandardScaler
    #             # scaler = StandardScaler()

    #             # # Fit the scaler to the data and transform the data
    #             # normalized_data = scaler.fit_transform(data_to_normalize)

    #             # # Update the 'x' and 'y' columns in update_df_2D with the normalized values
    #             # update_df_2D['x'] = normalized_data[:, 0]  # Update 'x' column with normalized 'x'
    #             # update_df_2D['y'] = normalized_data[:, 1]  # Update 'y' column with normalized 'y'

    #             silhouette = silhouette_score(update_df_2D[['x', 'y']], update_df_2D['label'])
    #             #print(f"The Silhouette Score after simulated dragging {n_points_per_label} points per label is: {silhouette}")
    #             adjusted_silhouette = silhouette * 2
    #             print(f"The Adjusted Silhouette Score after simulated dragging {n_points_per_label} points per label is: {adjusted_silhouette}")

    #             # Save the 2D projection as both CSV
    #             csv_save_path = f'/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/Triplet-Margin/animal_human_DR_{n_points_per_label}.csv'
    #             update_df_2D.to_csv(csv_save_path, index=True)

    #             # Create and save the plot
    #             plot_ax = create_plot(update_df_2D)
    #             fig_save_path = f'/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/Triplet-Margin/animal_human_DR_{n_points_per_label}.png'
    #             toggle_image = hasattr(plot_ax.dragpoint.artists[0], 'ab')
    #             toggle_title = ax.dragpoint.artists[0].text.get_text() != ''
    #             draw_plot(plot_ax, update_df_2D, toggle_title, toggle_image, size_slider.value)
    #             plt.savefig(fig_save_path, bbox_inches='tight', pad_inches=0.1)

    # evaluation_simulator_button.on_click(evaluation_simulator_button_clicked)

    def evaluation_simulator_button_clicked(change):
        """
        Click handler for the evaluation-simulator button.

        For each configured number of points per label it simulates a drag
        interaction, retrains through get_update_df_user, prints an adjusted
        silhouette score (silhouette * 2) for the resulting 2D layout, and
        writes that layout to a CSV file and a PNG figure.
        @parameters:
            change: argument supplied by on_click; not used here
        """
        # Both loops currently run exactly once (range(1) and range(20, 21)),
        # so a single simulation with 20 points per label is performed.
        for _ in range(1):
            for n_points_per_label in range(20, 21):
                # Simulate the user dragging n_points_per_label points per label
                # on the current projection; only moved_samples is used below.
                dragged_data, moved_samples = simulate_dragging_user(df_2D, n_points_per_label)

                # Turn the simulated moves into training samples and retrain
                samples = Samples(imgFolder, sampleSizePerCat, folderName, total_img)
                samples_for_training = samples.get_samples_for_training(moved_samples)
                update_df_2D = get_update_df_user(img_loader, imageIndex_path_dict, samples_for_training)

                # Cluster quality of the new layout w.r.t. the folder-derived labels
                silhouette = silhouette_score(update_df_2D[['x', 'y']], update_df_2D['label'])
                adjusted_silhouette = silhouette * 2
                print(f"The Adjusted Silhouette Score after simulated dragging {n_points_per_label} points per label is: {adjusted_silhouette}")

                # Save the 2D projection as CSV (index kept as a column)
                csv_save_path = f'/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/User-Defined/animal_mouth_DR_{n_points_per_label}.csv'
                update_df_2D.to_csv(csv_save_path, index=True)

                # Create and save the plot. plot_ax is a local here: it is the
                # axes of the new figure returned by create_plot.
                plot_ax = create_plot(update_df_2D)
                fig_save_path = f'/Users/jiayuelin/InfoVis/Code/Andromeda_IMG_FineTune/User-Defined/animal_mouth_DR_{n_points_per_label}.png'
                # Image mode is detected by the presence of the 'ab' attribute
                # that draw_plot attaches to each circle only when image=True.
                toggle_image = hasattr(plot_ax.dragpoint.artists[0], 'ab')
                # NOTE(review): this reads the title state from `ax` (enclosing
                # scope) while the line above reads from the local `plot_ax` --
                # confirm the mix is intentional.
                toggle_title = ax.dragpoint.artists[0].text.get_text() != ''
                draw_plot(plot_ax, update_df_2D, toggle_title, toggle_image, size_slider.value)
                # plt.savefig writes pyplot's current figure; presumably that is
                # the figure create_plot just opened -- TODO confirm.
                plt.savefig(fig_save_path, bbox_inches='tight', pad_inches=0.1)

    evaluation_simulator_button.on_click(evaluation_simulator_button_clicked)

    # def copy_button_clicked(change):

    #     global df_2D, size_slider, normalized_df

    #     df_2D = dimension_reduction(normalized_df)
    #     toggle_image = hasattr(plot_ax.dragpoint.artists[0], 'ab')
    #     toggle_tittle = ax.dragpoint.artists[0].text.get_text() != ''

    #     draw_plot(ax, df_2D, toggle_tittle, toggle_image, size_slider.value)
    #     # update weights!!!
    #     global model_bp, prev_model_bp
    #     prev_model_bp = copy.deepcopy(model_bp)
    #     model_bp.update_weights(torch.tensor(np.float32(weights)))
        
    #     print("The update weights are:", weights)
        

    # copy_button.on_click(copy_button_clicked)

    if fig_show:
        fig, ax = plt.subplots(
            figsize=(5, 7))  # reserve a fig for the weights bar chart
        weights.sort_index(ascending=False).plot.barh(ax=ax1)
        ax1.set_xlabel("Weight")
        fig.tight_layout()
        
    return train_model_button, evaluation_simulator_button

    # return inverse_button,train_model_button, copy_button

## Draggable Dimension-Reduction 2D Plot

In [77]:
# Handles mouse drag interaction events in the plot, users can select and drag points.
class DraggablePoints(object):
    """
    Mouse-drag controller for the point artists of a projection plot.

    Registers pick / motion / release callbacks on the axes' canvas so that a
    picked artist follows the mouse until the button is released.

    Expected attributes (set by the caller before use):
        ax.selected_text: text artist used as the "Selected: ..." status line
        artist.index, artist.label, artist.selected, artist.text: per-point
            state and its text label
        artist.ab (optional): image annotation box that is moved along with
            the point; absent when points are drawn as plain circles
    """

    def __init__(self, ax, artists):
        """
        @parameters:
            ax: axes holding the artists; must expose `selected_text`
            artists: iterable of pickable point artists
        """
        self.ax = ax
        self.artists = artists
        self.current_artist = None  # artist being dragged, None when idle
        self.last_selected = None   # `index` of the most recently picked artist
        ax.selected_text.set_text('Selected: none')
        self.offset = (0, 0)        # artist centre minus mouse position at pick time
        # Set up mouse listeners
        canvas = ax.figure.canvas
        canvas.mpl_connect('pick_event', self.on_pick)
        canvas.mpl_connect('motion_notify_event', self.on_motion)
        canvas.mpl_connect('button_release_event', self.on_release)

    def on_pick(self, event):
        """Select the picked artist, recolour it green and start the drag."""
        # Clicking on overlapped points sends one pick event per artist;
        # only the first one received starts a drag.
        if self.current_artist is not None:
            return

        picked = event.artist
        self.last_selected = picked.index
        self.current_artist = picked
        picked.selected = True
        # The previous colour is stored on the artist; nothing in this class
        # restores it.
        picked.savecolor = picked.get_facecolor()
        picked.set_facecolor('green')
        self.ax.selected_text.set_text("Selected: " + picked.label)

        # Remember where inside the artist the mouse grabbed it so the artist
        # does not jump to be centred under the cursor.
        x0, y0 = picked.center
        self.offset = (x0 - event.mouseevent.xdata), (y0 - event.mouseevent.ydata)

    def on_motion(self, event):
        """Move the dragged artist, its text label and (if any) its image."""
        dragged = self.current_artist
        # xdata/ydata are None while the mouse is outside the axes.
        if dragged is None or event.xdata is None or event.ydata is None:
            return

        dx, dy = self.offset
        dragged.center = x0, y0 = event.xdata + dx, event.ydata + dy
        dragged.text.set_position((x0 + dragged.radius, y0))

        # Artists drawn without images have no `ab` attribute at all, so a
        # plain `dragged.ab` would raise AttributeError on every mouse move.
        image_box = getattr(dragged, 'ab', None)
        if image_box is not None:
            image_box.xybox = (x0, y0)

    def on_release(self, event):
        """Stop the drag; the artist stays flagged as selected."""
        self.current_artist = None

    def get_dragged_coordinates(self):
        """
        Get the current x, y coordinates of every artist that has been picked.
        Returns a dictionary where keys are the labels of the artists and
        values are (x, y) coordinates.
        """
        return {
            artist.label: tuple(artist.center)
            for artist in self.artists
            if artist.selected
        }

In [78]:
def create_plot(data2D, title=False):
    """
    Open a new figure for the dimension-reduction plot and draw the data in it.
    @parameters:
        data2D[dataframe]: projected 2D data, forwarded to draw_plot
        title[boolean]: whether to show a text title next to each point
    @return[AxesSubplot]: plotting axes; carries the extra attributes
        `selected_text` and `dragpoint`
    """
    figure, axes = plt.subplots(dpi=80, figsize=(10, 10))

    # Status line in the bottom-left corner of the figure
    axes.selected_text = figure.text(0, 0.005, 'Selected: none',
                                     wrap=True, color='green')

    # Hide both tick sets
    for hide_ticks in (axes.set_xticks, axes.set_yticks):
        hide_ticks([])
    figure.tight_layout()

    # Placeholder for the drag handler; draw_plot installs the real one
    axes.dragpoint = None
    draw_plot(axes, data2D, title)

    return axes


def image_preprocessing(imgIndex):
    """
    Load an image, resize it to 224x224 and make its pure-black pixels
    transparent.
    @parameters:
        imgIndex[str]: image index, looked up in the module-level
            imageIndex_path_dict to obtain the file path
    @return[np.array(image)]: RGBA image; alpha is 255 where the grayscale
        value is above 0 and 0 elsewhere
    @raises:
        KeyError: imgIndex is not in imageIndex_path_dict
        FileNotFoundError: the file could not be read as an image
    """
    path = imageIndex_path_dict[imgIndex]

    src = cv2.imread(path, 1)  # flag 1: load as 3-channel colour (BGR)
    if src is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with the offending path instead of inside cv2.resize.
        raise FileNotFoundError(
            f"Could not read image for index {imgIndex!r} at path {path!r}")

    up_points = (224, 224)  # (width, height)
    src = cv2.resize(src, up_points, interpolation=cv2.INTER_LINEAR)

    # Alpha mask: any pixel whose grayscale value is > 0 becomes opaque.
    gray = cv2.cvtColor(src, cv2.COLOR_BGR2GRAY)
    _, alpha = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY)

    # OpenCV loads BGR; reorder to RGB and append the alpha channel.
    b, g, r = cv2.split(src)
    rgba = [r, g, b, alpha]
    processed_img = cv2.merge(rgba, 4)

    return processed_img

def draw_plot(ax, data2D, title=False, image=True, imgSize=imgDisplaySize):
    """
    Clear `ax` and redraw every row of `data2D` as a draggable point.

    Side effects: adds/overwrites the columns 'label' and 'label_num' on
    `data2D`, and sets `ax.circles` (list of circle patches) and
    `ax.dragpoint` (a DraggablePoints handler).
    @parameters:
        ax[AxesSubplot]: plot ax; must already carry `selected_text`
        data2D[dataframe]: projected 2D dataframe with columns 'x' and 'y'
            and a string index of the form '<label>/...'
        title[boolean]: whether toggle title or not
        image[boolean]: whether toggle image or not
        imgSize[float]: zoom parameter of OffsetImage, determine size of displaying images
    """
    # Class label = part of the index before the first '/'
    data2D['label'] = data2D.index.str.split('/').str[0]

    # Number the labels in sorted order so colours are reproducible between
    # runs, and map them explicitly instead of relying on replace() downcasting.
    label_list = sorted(set(data2D['label']))
    conversion_dict = {label: num for num, label in enumerate(label_list)}
    data2D['label_num'] = data2D['label'].map(conversion_dict)

    ax.clear()
    wid = max(data2D.x.max() - data2D.x.min(),
              data2D.y.max() - data2D.y.min())  # max range of x,y axes

    if image:
        # Invisible pick targets; the visible part is the image added below.
        circles = [
            mpl.patches.Circle(xy=(x0, y0),
                               radius=wid * imgSize / 3,
                               alpha=0.5,
                               color='none',
                               picker=True)
            for x0, y0 in zip(data2D.x, data2D.y)
        ]
    else:
        # Visible circles coloured by label number.
        cnorm = mpl.colors.Normalize(vmin=data2D.label_num.min(),
                                     vmax=data2D.label_num.max())
        circles = [
            mpl.patches.Circle(xy=(x0, y0),
                               radius=wid / 70,
                               alpha=0.95,
                               label=num,
                               color=plt.cm.Set3(cnorm(num)),
                               picker=True)
            for x0, y0, num in zip(data2D.x, data2D.y, data2D.label_num)
        ]

    # A plain list in both modes, so positional access such as
    # `ax.dragpoint.artists[0]` behaves the same regardless of `image`.
    ax.circles = circles

    for i, (c, row_label) in enumerate(zip(circles, data2D.index)):
        # Store state data:
        c.index, c.label, c.selected = i, row_label, False
        ax.add_patch(c)

        if image:
            # The `ab` attribute exists only in image mode; other code uses
            # hasattr(circle, 'ab') to detect that mode, so it is deliberately
            # not set in circle mode.
            processed_img = image_preprocessing(c.label)
            img = OffsetImage(processed_img, zoom=imgSize)
            c.ab = AnnotationBbox(img, (c.center[0], c.center[1]),
                                  frameon=False)
            ax.add_artist(c.ab)

        # Every circle gets a text artist; with titles off it is empty and
        # colourless so `c.text.get_text() != ''` reflects the title toggle.
        if title:
            text, text_color = c.label, 'black'
        else:
            text, text_color = "", 'none'
        c.text = ax.text(c.center[0] + c.radius,
                         c.center[1],
                         text,
                         color=text_color)

    # Make plot circles draggable
    ax.dragpoint = DraggablePoints(ax, ax.circles)
    # Clean up the plot
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis('equal')

# Interactive Visualization

## Dimension Reduction Plot
This shows the HD data in 2D form, such that **proximity == similarity**, based on the current slider weights.  Distances between points in the plot approximately reflect their distances in the weighted HD data.  Thus points near each other have similar HD data values in the up-weighted dimensions, and points far away have very different HD data values in those dimensions.

Points can be **selected** to highlight them in green and view their details below. Points can be **dragged** to specify a new projection for learning weights (see below). To reset the plot and clear the selections, click the **Reset** button below the plot.

In [80]:
# Build the draggable projection figure from df_2D. The returned axes are kept
# in plot_ax because the widget cells below are handed this same object.
plot_ax = create_plot(df_2D)
plt.show()

<IPython.core.display.Javascript object>

In [81]:
# Image-size slider plus the button that applies its value to plot_ax.
size_slider = create_size_slider()
size_apply_button = create_size_slider_button(size_slider, plot_ax)
# create_inverse_button returns (train_model_button, evaluation_simulator_button).
finetune_button, evaluation_simulator_button = create_inverse_button(
    plot_ax)
# First row: size controls. Second row: fine-tune, evaluation simulator and
# whatever create_reset_buttons returns for this plot.
display(widgets.HBox([size_slider, size_apply_button]))
display(
    widgets.HBox([
        finetune_button, 
        evaluation_simulator_button,
        create_reset_buttons(plot_ax)
    ]))

HBox(children=(FloatSlider(value=0.25, continuous_update=False, description='Adjust image size', max=1.0, read…

HBox(children=(Button(description='Fine-tune ResNet18', style=ButtonStyle(button_color='darkseagreen')), Butto…

In [82]:
# Create the checkbox controls for plot_ax (the cell output shows
# "Toggle Titles" and "Toggle Images" checkboxes).
point_attribute = create_checkbox(plot_ax)

interactive(children=(Checkbox(value=False, description='Toggle Titles', indent=False, layout=Layout(height='2…

interactive(children=(Checkbox(value=True, description='Toggle Images', indent=False, layout=Layout(height='20…