In [None]:
import os
import seaborn as sns
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from torchvision.models import inception_v3
from PIL import Image
import japanize_matplotlib
from sklearn.decomposition import PCA


class TensorPlotter:
    """Reduce groups of samples to 2-D/3-D and draw them as labelled scatter plots.

    Each entry of ``tensors`` is one group.  Its first axis indexes samples and
    every remaining axis is flattened into a single feature vector.  All groups
    are stacked, reduced together (PCA, t-SNE or UMAP) and drawn with one
    colour and one legend entry per group.

    Args:
        tensors: List of array-likes (torch tensors or NumPy arrays), one per
            group.
        device: Stored on the instance as ``self.device``; no method of this
            class reads it.
        labels: Optional legend labels, one per group.  When omitted (or
            empty), labels of the form ``"Tensor <k>(<n samples>)"`` are
            generated.

    Raises:
        ValueError: If the number of labels differs from the number of groups.
    """

    # Number of distinct colours in matplotlib's ``tab10`` palette.
    _PALETTE_SIZE = 10
    # Groups with fewer points than this get a scatter only (no KDE overlay).
    _MIN_KDE_POINTS = 3

    def __init__(self, tensors: list, device="cpu", labels=None):
        self.tensors = tensors
        self.num_tensors = len(tensors)
        self.labels = labels or [
            f"Tensor {i+1}({len(tensor)})" for i, tensor in enumerate(tensors)
        ]
        # Explicit exception instead of ``assert`` so the check survives ``python -O``.
        if self.num_tensors != len(self.labels):
            raise ValueError("Number of tensors and labels should match.")

        # Wrap the index so that more than ten groups still receive distinct
        # in-palette positions instead of all running past the end of tab10.
        self.color_map = plt.cm.tab10(
            [i % self._PALETTE_SIZE for i in range(self.num_tensors)]
        )

        self.device = device

    @staticmethod
    def _as_array(tensor):
        """Return ``tensor`` as a NumPy array; torch tensors are detached and moved to CPU."""
        if hasattr(tensor, "detach"):
            return tensor.detach().cpu().numpy()
        return np.asarray(tensor)

    @staticmethod
    def _check_dim(dim):
        """Raise ``ValueError`` unless ``dim`` is 2 or 3."""
        if dim not in (2, 3):
            raise ValueError("Only 2D and 3D visualizations are supported.")

    def _stack_samples(self):
        """Flatten every sample to one row and stack all groups vertically."""
        rows = []
        for tensor in self.tensors:
            arr = self._as_array(tensor)
            rows.append(arr.reshape(len(arr), -1))
        return np.vstack(rows)

    def _group_boundaries(self):
        """Return the row indices at which ``reduced_data`` switches to the next group."""
        lengths = [len(tensor) for tensor in self.tensors]
        return np.cumsum(lengths, dtype=int)[:-1]

    def _reduce_and_plot(self, reducer, dim, plot_title):
        """Fit ``reducer`` on the stacked samples, store the result and draw it."""
        self.reduced_data = reducer.fit_transform(self._stack_samples())
        if dim == 2:
            self._plot_2d(plot_title)
        else:
            self._plot_3d(plot_title)

    def _prepare_data(self, dim):
        """PCA-reduce the tensors after reshaping each one into rows of 4 values.

        NOTE(review): no method of this class calls this helper.  Because it
        reshapes to ``(-1, 4)`` rather than one row per sample, the row count
        of ``reduced_data`` need not match the per-group lengths that
        ``_plot_2d`` / ``_plot_3d`` split on.  Behaviour is kept as it was;
        confirm whether it is still needed.
        """
        reshaped_data = [self._as_array(tensor).reshape(-1, 4) for tensor in self.tensors]
        combined_data = np.vstack(reshaped_data)
        pca = PCA(n_components=dim)
        self.reduced_data = pca.fit_transform(combined_data)

    def pca_plot(self, dim=2, plot_title_append=""):
        """Reduce all groups with PCA and plot the result.

        Args:
            dim: Target dimensionality, 2 or 3.
            plot_title_append: Text appended to the plot title.

        Raises:
            ValueError: If ``dim`` is neither 2 nor 3.
        """
        self._check_dim(dim)
        plot_title = f"{dim}D PCA of 4D Tensor Data" + plot_title_append
        self._reduce_and_plot(PCA(n_components=dim), dim, plot_title)

    def tsne_plot(self, dim=2, plot_title_append=""):
        """Reduce all groups with t-SNE and plot the result.

        Note carried over from the original author: when a GPU is present the
        following warning may appear(?):
        ``OpenBLAS Warning : Detect OpenMP Loop and this application may hang.
        Please rebuild the library with USE_OPENMP=1 option.``

        Args:
            dim: Target dimensionality, 2 or 3.
            plot_title_append: Text appended to the plot title.

        Raises:
            ValueError: If ``dim`` is neither 2 nor 3.
        """
        self._check_dim(dim)
        # ``TSNE`` is not imported at file level; import it here so the method
        # does not fail with NameError.
        from sklearn.manifold import TSNE

        plot_title = f"{dim}D t-SNE of 4D Tensor Data" + plot_title_append
        self._reduce_and_plot(TSNE(n_components=dim), dim, plot_title)

    def umap_plot(self, dim=2, random_state=0, plot_title_append=""):
        """Reduce all groups with UMAP and plot the result.

        Args:
            dim: Target dimensionality, 2 or 3.
            random_state: Seed forwarded to ``UMAP``.
            plot_title_append: Text appended to the plot title.

        Raises:
            ValueError: If ``dim`` is neither 2 nor 3.
        """
        self._check_dim(dim)
        # NOTE(review): ``UMAP`` is not imported anywhere in the visible file;
        # presumably ``from umap import UMAP`` (umap-learn) is expected to be
        # in scope -- confirm and add the import.
        reducer = UMAP(n_components=dim, random_state=random_state)
        plot_title = f"{dim}D UMAP of 4D Tensor Data" + plot_title_append
        self._reduce_and_plot(reducer, dim, plot_title)

    def _plot_2d(self, title="", point_size=10):
        """Draw ``self.reduced_data`` as a 2-D scatter with a KDE overlay per group."""
        # Draw on the current axes so a figure prepared by the caller is reused.
        ax = plt.gca()
        groups = np.split(self.reduced_data, self._group_boundaries())
        for idx, data in enumerate(groups):
            colour = self.color_map[idx]
            # A 2-D density estimate needs several points; tiny groups are
            # shown as scatter points only.
            if len(data) >= self._MIN_KDE_POINTS:
                sns.kdeplot(
                    x=data[:, 0],
                    y=data[:, 1],
                    fill=True,
                    color=colour * 0.7,  # darker shade of the group colour
                    alpha=0.3,
                    ax=ax,
                )
            ax.scatter(data[:, 0], data[:, 1], color=colour, label=self.labels[idx], s=point_size)

        ax.set_title(title)
        ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
        plt.show()
        # Close rather than cla()/clf(): clearing after show() can create and
        # leave behind a new, empty figure.
        plt.close(ax.figure)

    def _plot_3d(self, title="", point_size=10):
        """Draw ``self.reduced_data`` as a 3-D scatter, one colour per group."""
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        groups = np.split(self.reduced_data, self._group_boundaries())
        for idx, data in enumerate(groups):
            ax.scatter(
                data[:, 0],
                data[:, 1],
                data[:, 2],
                color=self.color_map[idx],
                label=self.labels[idx],
                s=point_size,
            )

        ax.set_title(title)
        ax.legend(loc='upper left', bbox_to_anchor=(1.3, 1))
        plt.show()
        # Close rather than cla()/clf(): clearing after show() can create and
        # leave behind a new, empty figure.
        plt.close(fig)


# Assumption: features_1 and features_2 are arrays of feature vectors extracted from two different datasets.
# These feature vectors are extracted using an Inception network.
def extract_features(directory):
    """Run every ``*.jpg`` in ``directory`` through a pretrained InceptionV3.

    Args:
        directory: Folder to scan (non-recursively) for ``*.jpg`` files, given
            as a ``str`` or ``pathlib.Path``.

    Returns:
        ``np.ndarray`` with one row per image, in sorted path order, each row
        being ``model(x)`` with size-1 axes squeezed out.  When no file
        matches, an empty array (``np.array([])``) is returned and the model
        is not loaded.

    NOTE(review): the rows are whatever ``inception_v3`` returns from its
    forward pass -- presumably the classifier output rather than pooled
    features; confirm that this is what downstream code expects.
    """
    # Sorted so the row order is reproducible across filesystems.
    image_paths = sorted(Path(directory).glob("*.jpg"))
    if not image_paths:
        # Nothing to process: skip loading the pretrained weights entirely.
        return np.array([])

    model = inception_v3(pretrained=True)
    model = model.eval()  # switch to evaluation mode
    transform = transforms.Compose([
        transforms.Resize((299, 299)),  # resize to the InceptionV3 input size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet normalization parameters
    ])

    features = []
    with torch.no_grad():  # disable gradient computation
        for img_path in image_paths:
            # Context manager releases the file handle once the RGB copy exists.
            with Image.open(str(img_path)) as img:
                x = transform(img.convert('RGB')).unsqueeze(0)  # add the batch dimension
            feature = model(x)
            features.append(feature.squeeze().numpy())
    return np.array(features)



In [None]:
# Image folders to compare; each entry becomes one labelled group in the plots.
data_list = [
    {"name": "Parade", "path": Path("../../sample_data/WIDER_OpenData/0--Parade/")},
    {"name": "Handshaking", "path": Path("../../sample_data/WIDER_OpenData/1--Handshaking/")},
    {"name": "Demonstration", "path": Path("../../sample_data/WIDER_OpenData/2--Demonstration/")},
    {"name": "Riot", "path": Path("../../sample_data/WIDER_OpenData/3--Riot/")},
]

# One feature tensor per folder, in the same order as data_list.
features_list = [
    torch.tensor(extract_features(entry["path"])) for entry in data_list
]

# Legend labels follow the folder names.
group_names = [entry["name"] for entry in data_list]
plotter = TensorPlotter(features_list, labels=group_names)

# Draw the 3-D projection first, then the 2-D one.
for n_dims in (3, 2):
    plotter.pca_plot(dim=n_dims, plot_title_append=" (PCA)")