In [9]:
# load model
# run all evaluations
import os
import torch
import itertools
import numpy as np
from sklearn.metrics import confusion_matrix
from dataset import MnistDataset
from utils import load_original_paper_model
from hydra import initialize, compose
from torch.utils.data import DataLoader

import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('TkAgg')

from time import time
from sklearn import metrics
from sklearn.cluster import KMeans, SpectralClustering
from scipy.optimize import linear_sum_assignment

In [10]:
device = torch.device("mps")
paper_model = load_original_paper_model()
paper_model.to(device)
cfg_name = "train"
with initialize(version_base=None, config_path="../configs"):
    cfg = compose(config_name=cfg_name)
model_dict = {
    "ucc_original": {
        "experiment_id":"546114710104461663",
        "run_id":"e3e102e875734beb84c3ccb6ca2f46e4"
    },
    # "drn_1_wasserstein": {
    #     "experiment_id": "545254671228892812",
    #     "run_id":"59194205951142bd832a43b511edb19d"
    # },
    # "drn_cross_entropy": {
    #     "experiment_id": "189454739472380536",
    #     "run_id":"98d0cca708cc4f5ba40314aa134af4cd"
    # },
    # "drn_jsd": {
    #     "experiment_id": "657963627447301561",
    #     "run_id": "8b5480c32a224f23803e7d69082e41ba"
    # },
    # "ucc_ori_djs":{
    #     "experiment_id": "596489502680300638",
    #     "run_id": "83bdbc58e8ff4ffaab1eb67ed8f49623"
    # }
}

In [11]:
def plot_confusion_matrix(cm, classes,
						  normalize=False,
						  title='Confusion matrix',
						  cmap=plt.cm.Blues):
	"""
	This function prints and plots the confusion matrix.
	Normalization can be applied by setting `normalize=True`.
	"""
	if normalize:
		cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
		print("Normalized confusion matrix")
	else:
		print('Confusion matrix, without normalization')

	# print(cm)

	plt.imshow(cm, interpolation='nearest', cmap=cmap)
	plt.title(title)
	plt.colorbar()
	tick_marks = np.arange(len(classes))
	plt.xticks(tick_marks, classes, rotation=0)
	plt.yticks(tick_marks, classes)

	fmt = '.3f' if normalize else 'd'
	thresh = cm.max() / 2.
	for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
		plt.text(j, i, format(cm[i, j], fmt),
				 horizontalalignment="center",
				 color="white" if cm[i, j] > thresh else "black")

	plt.ylabel('True label')
	plt.xlabel('Predicted label')
	plt.tight_layout()
 
def kde(data: torch.Tensor, num_nodes, sigma):
    device = data.device
    # data shape: (batch_size, num_instances, num_features)
    batch_size, num_instances, num_features = data.shape

    # Create sample points
    k_sample_points = (
        torch.linspace(0, 1, steps=num_nodes)
        .repeat(batch_size, num_instances, 1)
        .to(device)
    )

    # Calculate constants
    k_alpha = 1 / np.sqrt(2 * np.pi * sigma**2)
    k_beta = -1 / (2 * sigma**2)

    # Iterate over features and calculate kernel density estimation for each feature
    out_list = []
    for i in range(num_features):
        one_feature = data[:, :, i: i + 1].repeat(1, 1, num_nodes)
        k_diff_2 = (k_sample_points - one_feature) ** 2
        k_result = k_alpha * torch.exp(k_beta * k_diff_2)
        k_out_unnormalized = k_result.sum(dim=1)
        k_norm_coeff = k_out_unnormalized.sum(dim=1).view(-1, 1)
        k_out = k_out_unnormalized / k_norm_coeff.repeat(
            1, k_out_unnormalized.size(1)
        )
        out_list.append(k_out)

    # Concatenate the results
    concat_out = torch.cat(out_list, dim=-1)
    return concat_out

def js_divergence(p,q):
    m = 0.5*(p+q)
    log_p_over_m = np.log2(p/m)
    log_q_over_m = np.log2(q/m)

    return 0.5*np.sum(p*log_p_over_m) + 0.5*np.sum(q*log_q_over_m)

def cluster(estimator=None, data=None):
	t0 = time()
	estimator.fit(data)
	predicted_clustering_labels = estimator.labels_

	return predicted_clustering_labels

In [None]:
mode = "test"
for model_name, values in model_dict.items():
    print(f"Running evaluation for {model_name} ......")
    experiment_id = values["experiment_id"]
    run_id = values["run_id"]
    model_name = f"{model_name}_{experiment_id}_{run_id}"
    save_directory = f"evaluation/{model_name}/{mode}"
    model = torch.load(f"mlruns/{experiment_id}/{run_id}/artifacts/best_model/data/model.pth", weights_only=False)
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)
    truth_labels_filename = f"{save_directory}/ucc_truth_labels.txt"
    predicted_labels_filename = f"{save_directory}/ucc_predicted_labels.txt"
    acc_filename = f"{save_directory}/ucc_accuracy.txt"
    conf_mat_filename = f"{save_directory}/ucc_confusion_matrix.png"
    confusion_matrix_normalized_filename = f"{save_directory}/ucc_confusion_matrix_normalized.png"
    confusion_matrix_unnormalized_filename = f"{save_directory}/ucc_confusion_matrix_unnormalized.png"
    features_filename = f"{save_directory}/extracted_features.txt"
    clustering_file_name = f"{save_directory}/cluster.txt"
    predicted_cluster_nn_labels_filename = f"{save_directory}/cluster_predicted_labels_nn.txt"
    predicted_cluster_spectral_labels_filename = f"{save_directory}/cluster_predicted_labels_spectral.txt"
    clustering_spectral_acc_filename = f"{save_directory}/cluster_accuracy_spectral.txt"
    clustering_nn_acc_filename = f"{save_directory}/cluster_accuracy_nn.txt"
    distributions_filename = f"{save_directory}/distributions.txt"
    mean_features_filename = f"{save_directory}/mean_features.txt"
    fig_filename = f"{save_directory}/distributions.png"
    js_divergence_filename = f"{save_directory}/js_divergence.txt"
    js_divergence_fig_filename = f"{save_directory}/js_divergence.png"

    class_names = ["ucc1", "ucc2", "ucc3", "ucc4"]
    truth_labels_list = list()
    predicted_labels_list = list()
    num_classes = len(class_names)
    num_instances = 32
    num_samples_per_class = int(40/num_classes)
    num_batches = (252*20)//num_samples_per_class

    dataset = MnistDataset(
        num_instances = num_instances,
        num_samples_per_class = num_samples_per_class,
        digit_arr = list(range(0,10)),
        ucc_start = 1,
        ucc_end = 4,
        mode=mode,
        length=num_batches
    )
    dataloader = DataLoader(dataset, batch_size=20)
    with torch.no_grad():
        for batch_samples, batch_labels in dataloader:
            batch_samples = batch_samples.to(device)
            truth_labels_list+=batch_labels.numpy().flatten().tolist()    
            pred = model(batch_samples)
            _, pred = torch.max(pred, dim=1)
            predicted_labels_list+=(pred.to("cpu").numpy().flatten()).tolist()
        with open(truth_labels_filename,'ab') as f_truth_labels_filename:
            np.savetxt(f_truth_labels_filename, truth_labels_list, fmt='%d', delimiter='\t')
        with open(predicted_labels_filename,'ab') as f_predicted_labels_filename:
            np.savetxt(f_predicted_labels_filename, predicted_labels_list, fmt='%d', delimiter='\t')

    truth_labels_arr = np.array(truth_labels_list)
    print('truth_labels_arr shape:{}'.format(truth_labels_arr.shape))

    predicted_labels_arr = np.array(predicted_labels_list)
    print('predicted_labels_arr shape:{}'.format(predicted_labels_arr.shape))

    conf_mat = confusion_matrix(truth_labels_arr, predicted_labels_arr)
    print('conf_mat shape:{}'.format(conf_mat.shape))

    np.savetxt(conf_mat_filename, conf_mat, fmt='%d', delimiter='\t')

    ucc_acc = np.sum(conf_mat.diagonal())/np.sum(conf_mat)
    np.savetxt(acc_filename, ucc_acc.reshape((-1,1)), fmt='%.4f', delimiter='\t')

    fig1 = plt.figure(figsize=(9,9))
    plot_confusion_matrix(conf_mat, classes=class_names, normalize=True, title='Confusion matrix')
    fig1.savefig(confusion_matrix_normalized_filename, bbox_inches='tight')

    fig2 = plt.figure(figsize=(9,9))
    plot_confusion_matrix(conf_mat, classes=class_names, normalize=False, title='Confusion matrix, without normalization')
    fig2.savefig(confusion_matrix_unnormalized_filename, bbox_inches='tight')
    
    
#   extract_features
    splitted_dataset = np.load('../data/mnist/splitted_mnist_dataset.npz')
    x_data = splitted_dataset[f'x_{mode}']
    y_data = splitted_dataset[f'y_{mode}']
    batch_size= model.batch_size 

    x_data = x_data.reshape(x_data.shape[0], 1, x_data.shape[1], x_data.shape[2])
    x_data = x_data.astype('float32')
    x_data /= 255
    x_data = (x_data-np.mean(x_data,axis=(1,2,3))[:,np.newaxis,np.newaxis,np.newaxis])/np.std(x_data,axis=(1,2,3))[:,np.newaxis,np.newaxis,np.newaxis]

    num_samples = x_data.shape[0]
    num_batches = num_samples//batch_size
    last_batch_size = int(num_samples%batch_size)
    with torch.no_grad():
        for i in range(num_batches):
            print('Batch %d/%d' % (i,num_batches))
            batch_data = x_data[i*batch_size:(i+1)*batch_size]
            batch_data = torch.tensor(batch_data).to(device)
            # batch_label = y_data[i*batch_size:(i+1)*batch_size]
            features_out = model.encoder(batch_data).detach().to("cpu").numpy()
            with open(features_filename, 'ab') as f_features_file:
                np.savetxt(f_features_file, features_out.reshape((-1,cfg.args.num_features)), fmt='%5.3f', delimiter='\t')
   
        # obtain distributions
    labels_arr = splitted_dataset[f'y_{mode}']
    features_arr = np.loadtxt(features_filename, comments='#', delimiter='\t', dtype='float32')
    num_features = features_arr.shape[1]

    print(features_arr.shape)
    for i in range(10):
        digit_key = 'digit' + str(i)
        digit_value = i
        # print('digit_key: {}'.format(digit_key))

        temp_indices = np.where(labels_arr == digit_value)[0]
        # concat_size = len(temp_indices)

        batch_data = (features_arr[temp_indices,:])[np.newaxis,:,:]
        print('batch_data shape:{}'.format(batch_data.shape))

        temp_mean_features = np.mean(batch_data[0,:,:], axis=0).reshape((1,-1))

        temp_distributions = kde(torch.tensor(batch_data), num_nodes=11, sigma=0.1)
        # print(temp_distributions.shape)

        with open(distributions_filename, 'ab') as f_distributions_file:
            np.savetxt(f_distributions_file, temp_distributions, fmt='%5.3f', delimiter='\t')
        
        with open(mean_features_filename, 'ab') as f_mean_features_file:
            np.savetxt(f_mean_features_file, temp_mean_features, fmt='%5.3f', delimiter='\t')
            
    # plot distributions
    num_bins = 11
    num_features = 10
    classes = [str(x) for x in range(0,10)]
    distributions_arr = np.loadtxt(distributions_filename, comments='#', delimiter='\t', dtype='float32')
    hist_max = np.amax(distributions_arr)

    m_color = ['#e6194b', '#3cb44b', '#ffe119', '#4363d8', '#f58231', '#911eb4', '#46f0f0', '#f032e6', '#bcf60c', '#fabebe', '#008080', '#e6beff', '#9a6324', '#fffac8', '#800000', '#aaffc3', '#808000', '#ffd8b1', '#000075', '#808080', '#ffffff', '#000000']
    m_shape=['-',':']
    fig, ax = plt.subplots(len(classes),num_features,figsize=(16, 10))
    fig.subplots_adjust(left=0.04, right=0.98, top=0.95, bottom=0.1)
    # plt.suptitle('deneme')
    # fig.canvas.set_window_title("heh")

    legend_color = list()
    for h in range(num_features):
        temp_hist_data = distributions_arr[:,h*num_bins:(h+1)*num_bins]

        for c in range(len(classes)):
            data = temp_hist_data[c,:]

            ax1 = ax[c,h]
            bp, = ax1.plot(np.arange(0,1.1,0.1),data,c=m_color[c],linestyle='-')

            if h == (num_features -1):
                legend_color.append(bp)

            ax1.tick_params(axis='both',labelsize=6)
            ax1.tick_params(axis='x',rotation=90)
            ax1.set_ylim((-0.05,hist_max+0.05))
            ax1.yaxis.set_ticks(np.arange(0, hist_max+0.05, 0.1))
            ax1.yaxis.grid(True, linestyle='-', which='major', color='lightgrey', alpha=0.5)
            ax1.set_xlim((-0.1,1.1))
            ax1.xaxis.set_ticks(np.arange(0, 1.1, 0.2))
            # ax1.xaxis.set_tick_params(labelsize=6, label1On=False)
            ax1.xaxis.grid(True, linestyle='-', which='major', color='lightgrey', alpha=0.5)
            ax1.set_axisbelow(True)

    fig.tight_layout()
    fig.subplots_adjust(left=0.03, right=0.98, top=0.96, bottom=0.10)
    fig.legend(legend_color, classes, loc='lower center', bbox_to_anchor=(0.5, 0.01), fancybox=True, shadow=True, ncol=10, fontsize=15)

    # fig.savefig(fig_filename, bbox_inches='tight')
    fig.savefig(fig_filename, bbox_inches='tight')
            
            # plot js divergence
    distributions_arr = np.loadtxt(distributions_filename, comments='#', delimiter='\t', dtype='float32')

    num_features = distributions_arr.shape[1]/num_bins

    distributions_arr /= num_features

    js_divergence_arr = np.zeros((10,10))
    for i in range(10):
        p = np.clip(distributions_arr[i,:],1e-12,1)
        for k in range(i,10):
            q = np.clip(distributions_arr[k,:],1e-12,1)
            js_divergence_arr[i,k] = js_divergence(p,q)
            js_divergence_arr[k,i] = js_divergence_arr[i,k]

    print('JS divergence: min={:.2f} - max={:.3f} - mean={:.3f} - std={:.3f}'.format(np.amin(js_divergence_arr[js_divergence_arr>0]),np.amax(js_divergence_arr),np.mean(js_divergence_arr),np.std(js_divergence_arr)))

    np.savetxt(js_divergence_filename, js_divergence_arr, fmt='%.3f', delimiter='\t', comments='# ')

    fig = plt.figure(figsize=(10,9))

    plt.imshow(js_divergence_arr, interpolation='nearest', cmap=plt.cm.Blues, vmin=0, vmax=1 )
    plt.title('$D_{\mathcal{JS}}(\mathcal{P}||\mathcal{Q})$')
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=0)
    plt.yticks(tick_marks, classes)

    fmt = '.3f'
    for m, n in itertools.product(range(js_divergence_arr.shape[0]), range(js_divergence_arr.shape[1])):
        plt.text(n, m, format(js_divergence_arr[m, n], fmt),
                    horizontalalignment="center",
                    color="black")

    plt.ylabel('$\mathcal{P}$')
    plt.xlabel('$\mathcal{Q}$')

    fig.tight_layout()
    fig.savefig(js_divergence_fig_filename, bbox_inches='tight')
    
# clustering
    features_arr = np.loadtxt(features_filename, comments='#', delimiter='\t', dtype='float32')
    clustering_type = "kmeans"
    all_classes_arr = np.arange(10)

    # cluster all data
    num_clusters = len(all_classes_arr)

    estimator = KMeans(n_clusters=num_clusters, init='k-means++', n_init=10)
    predicted_labels_arr = cluster(estimator=estimator, data=features_arr)
    np.savetxt(predicted_cluster_nn_labels_filename, predicted_labels_arr.reshape((-1,1)), fmt='%d', delimiter='\t')

    estimator = SpectralClustering(n_clusters=num_clusters, eigen_solver='arpack', affinity="nearest_neighbors")
    predicted_labels_arr = cluster(estimator=estimator, data=features_arr)
    np.savetxt(predicted_cluster_spectral_labels_filename, predicted_labels_arr.reshape((-1,1)), fmt='%d', delimiter='\t')

#   calculating clustering accuracy
    truth_labels_arr = splitted_dataset['y_test']
    predicted_nn_labels_arr = np.loadtxt(predicted_cluster_nn_labels_filename, delimiter='\t', dtype='int')
    predicted_spectral_labels_arr = np.loadtxt(predicted_cluster_spectral_labels_filename, delimiter='\t', dtype='int')

    cost_matrix = np.zeros((10,10))
    num_samples = np.zeros(10)
    for truth_val in range(10):
        # print('truth_val:{}'.format(truth_val))
        temp_sample_indices = np.where(truth_labels_arr == truth_val)[0]
        num_samples[truth_val] = temp_sample_indices.shape[0]

        temp_predicted_labels = predicted_nn_labels_arr[temp_sample_indices]

        for predicted_val in range(10):

            temp_matching_pairs = np.where(temp_predicted_labels == predicted_val)[0]

            cost_matrix[truth_val,predicted_val] = 1- (temp_matching_pairs.shape[0]/temp_sample_indices.shape[0])

            # print('predicted_val:{}'.format(predicted_val))
            # print('num samples:{}'.format(temp_sample_indices.shape[0]))
            # print('num matching pairs:{}'.format(temp_matching_pairs.shape[0]))
            # print('accuracy:{}'.format(temp_matching_pairs.shape[0]/temp_sample_indices.shape[0]))
            # print('cost:{}'.format(1- (temp_matching_pairs.shape[0]/temp_sample_indices.shape[0])))

    # print(np.round(cost_matrix,3))

    row_ind, col_ind = linear_sum_assignment(cost_matrix)

    cost = cost_matrix[row_ind,col_ind]

    clustering_acc = ((1-cost)*num_samples).sum() / num_samples.sum()

    # print(row_ind)
    # print(col_ind)
    # print(np.round(cost,3))
    # print(num_samples)
    print('Clustering acc:{}'.format( ((1-cost)*num_samples).sum() / num_samples.sum() ) )

    np.savetxt(clustering_nn_acc_filename, clustering_acc.reshape((-1,1)), fmt='%.4f', delimiter='\t')
    
    predicted_spectral_labels_arr = np.loadtxt(predicted_cluster_spectral_labels_filename, delimiter='\t', dtype='int')

    cost_matrix = np.zeros((10,10))
    num_samples = np.zeros(10)
    for truth_val in range(10):
        # print('truth_val:{}'.format(truth_val))
        temp_sample_indices = np.where(truth_labels_arr == truth_val)[0]
        num_samples[truth_val] = temp_sample_indices.shape[0]

        temp_predicted_labels = predicted_spectral_labels_arr[temp_sample_indices]

        for predicted_val in range(10):

            temp_matching_pairs = np.where(temp_predicted_labels == predicted_val)[0]

            cost_matrix[truth_val,predicted_val] = 1- (temp_matching_pairs.shape[0]/temp_sample_indices.shape[0])

            # print('predicted_val:{}'.format(predicted_val))
            # print('num samples:{}'.format(temp_sample_indices.shape[0]))
            # print('num matching pairs:{}'.format(temp_matching_pairs.shape[0]))
            # print('accuracy:{}'.format(temp_matching_pairs.shape[0]/temp_sample_indices.shape[0]))
            # print('cost:{}'.format(1- (temp_matching_pairs.shape[0]/temp_sample_indices.shape[0])))

    # print(np.round(cost_matrix,3))

    row_ind, col_ind = linear_sum_assignment(cost_matrix)

    cost = cost_matrix[row_ind,col_ind]

    clustering_acc = ((1-cost)*num_samples).sum() / num_samples.sum()

    # print(row_ind)
    # print(col_ind)
    # print(np.round(cost,3))
    # print(num_samples)
    print('Clustering acc:{}'.format( ((1-cost)*num_samples).sum() / num_samples.sum() ) )

    np.savetxt(clustering_spectral_acc_filename, clustering_acc.reshape((-1,1)), fmt='%.4f', delimiter='\t')

Running evaluation for ucc_original ......
x_train shape: torch.Size([50000, 1, 28, 28])
50000 train samples
10000 val samples
tensor(0.3105)
tensor(0.1325)
10000 test samples
truth_labels_arr shape:(504,)
predicted_labels_arr shape:(504,)
conf_mat shape:(4, 4)
Normalized confusion matrix
Confusion matrix, without normalization
Batch 0/500
Batch 1/500
Batch 2/500
Batch 3/500
Batch 4/500
Batch 5/500
Batch 6/500
Batch 7/500
Batch 8/500
Batch 9/500
Batch 10/500
Batch 11/500
Batch 12/500
Batch 13/500
Batch 14/500
Batch 15/500
Batch 16/500
Batch 17/500
Batch 18/500
Batch 19/500
Batch 20/500
Batch 21/500
Batch 22/500
Batch 23/500
Batch 24/500
Batch 25/500
Batch 26/500
Batch 27/500
Batch 28/500
Batch 29/500
Batch 30/500
Batch 31/500
Batch 32/500
Batch 33/500
Batch 34/500
Batch 35/500
Batch 36/500
Batch 37/500
Batch 38/500
Batch 39/500
Batch 40/500
Batch 41/500
Batch 42/500
Batch 43/500
Batch 44/500
Batch 45/500
Batch 46/500
Batch 47/500
Batch 48/500
Batch 49/500
Batch 50/500
Batch 51/500
Batc

: 