In [9]:
import os
import torch
import torchvision
import torch.nn as nn
import pickle
import pylab
import numpy as np
import scipy
import torch.optim as optim
import pandas as pd
import torchvision.datasets as datasets

from sklearn.neighbors import KNeighborsClassifier
from sklearn.neighbors import KernelDensity
from sklearn.preprocessing import MinMaxScaler

from scipy.stats import shapiro, normaltest

from torchvision import transforms

from copy import deepcopy

# Local imports
from local_models import *
from helper_functions import *
from piece_hurdle_model import *
from optimize_explanations import *
from evaluation_metrics import *

from IPython.display import Image

In [10]:
def load_fashion_dataloaders():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    train_set = torchvision.datasets.FashionMNIST(
        root='./data/after_anon_review',
        train=True,
        download=True,
        transform=transform
    )
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=1,
        shuffle=False
    )

    test_set = torchvision.datasets.FashionMNIST(
        root='./data/after_anon_review',
        train=False,
        download=True,
        transform=transform
    )
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=1,
        shuffle=False
    )

    return train_loader, test_loader

In [11]:
G, cnn = load_models(CNN, Generator)

#train_loader, test_loader = load_dataloaders()
train_loader, test_loader = load_fashion_dataloaders()
#X_train, y_train, X_test, y_test = get_MNIST_data(datasets)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/after_anon_review\FashionMNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:00<00:00, 27398872.31it/s]


Extracting ./data/after_anon_review\FashionMNIST\raw\train-images-idx3-ubyte.gz to ./data/after_anon_review\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/after_anon_review\FashionMNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 1593672.45it/s]


Extracting ./data/after_anon_review\FashionMNIST\raw\train-labels-idx1-ubyte.gz to ./data/after_anon_review\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/after_anon_review\FashionMNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:00<00:00, 17149572.65it/s]


Extracting ./data/after_anon_review\FashionMNIST\raw\t10k-images-idx3-ubyte.gz to ./data/after_anon_review\FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/after_anon_review\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<?, ?it/s]

Extracting ./data/after_anon_review\FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to ./data/after_anon_review\FashionMNIST\raw






In [12]:
def get_MNIST_data(datasets):
	mnist_trainset = datasets.MNIST(root='./data/mnist_train', train=True, download=True, transform=None)
	mnist_testset = datasets.MNIST(root ='./data/mnist_test', train=False, download=True, transform=None)
	X_train = mnist_trainset.data
	y_train = mnist_trainset.targets
	X_test = mnist_testset.data
	y_test = mnist_testset.targets
	return X_train, y_train, X_test, y_test

In [13]:
def return_feature_contribution_data(data_loader, cnn, num_classes=10):
    
    full_data = dict()
    pred_idx = dict() 

    for class_name in list(range(num_classes)):
        pred_idx[class_name] = list()
        
    for i, data in enumerate(data_loader):
        # print progress
        if i % 10000 == 0:
            print(  100 * round(i / len(data_loader), 2), "% complete..."  )     
        image, label = data
        label = int(label.detach().numpy())
        acts = cnn(image)[1][0].detach().numpy()
        pred = int(torch.argmax(  cnn(image)[0]  ).detach().numpy()) 
        pred_idx[pred].append(acts.tolist())
                
    return pred_idx

In [14]:
collected_data = return_feature_contribution_data(train_loader, cnn)
dist_data = {}

# 假设 num_classes 是类别的数量
num_classes = 10

# 为每个类别创建一个空列表
for class_name in range(num_classes):
    dist_data[class_name] = {'activations': []}

# 将 pred_idx_train 中的数据填充到 dist_data 中
for class_name, activations_list in collected_data.items():
    # 将 activations_list 转换为 numpy 数组
    activations_array = np.array(activations_list)
    # 将 activations_array 存储到 dist_data 对应的类别中
    dist_data[class_name]['activations'] = activations_array

  label = int(label.detach().numpy())


0.0 % complete...
17.0 % complete...
33.0 % complete...
50.0 % complete...
67.0 % complete...
83.0 % complete...


In [15]:
with open('collected_fashion.pickle', 'wb') as handle:
    pickle.dump(dist_data, handle, protocol=pickle.HIGHEST_PROTOCOL)

print("数据已成功存储为 pickle 文件。")

数据已成功存储为 pickle 文件。
