In [1]:
import os
# Expose only physical GPU 0 to this process (it appears as cuda:0 inside torch).
# This must be set before torch initialises CUDA, hence it precedes `import torch`.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# TensorFlow C++ log filter: '2' suppresses INFO and WARNING messages, so only errors are printed.
# NOTE(review): only TensorFlow reads this variable; nothing in this notebook visibly imports TF.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import torch
# Report whether a CUDA device is visible after the restriction above.
if torch.cuda.is_available():
    print('CUDA is available!')
else:
    print('CUDA is not available!')

# NOTE(review): names such as `np`, `EPOCH`, `MODEL_NAME`, `get_path_metrics`, `get_path`,
# `compute_FID`, `ensure_dir_exists` and `load_processed_dataset` are used below without an
# explicit import; they presumably come from these star imports — confirm and prefer
# explicit imports so readers can see where each name is defined.
from utils import *
import matplotlib.pyplot as plt
import warnings
from models import *
from dataset import *

%matplotlib inline
warnings.filterwarnings("ignore")  # Silence every Python warning (including deprecations) for the whole session

In [12]:
import pickle
def generate_file_path(out_dir, middle_size, class_num, model_name, train_idx):
    """Return the path of the pickled loss/accuracy history of one training run.

    The path is resolved by ``get_path_metrics`` beneath ``out_dir + "loss_acc"``,
    using the global ``EPOCH`` and a ``.pickle`` extension.
    """
    metrics_root = out_dir + "loss_acc"
    return get_path_metrics(metrics_root, middle_size, class_num, model_name, EPOCH, train_idx,
                            extension=".pickle")

def load_data_from_file(path):
    """Load and return the single pickled object stored at ``path``.

    Security: ``pickle.load`` can execute arbitrary code, so only call this on
    files written by this project.
    """
    with open(path, mode="rb") as handle:
        payload = pickle.load(handle)
    return payload

def compute_stats(data, ddof=0):
    """Summarise repeated runs by their mean and a +/- one-standard-deviation band.

    Args:
        data: array-like whose first axis indexes runs (e.g. a list of per-run arrays).
        ddof: delta degrees of freedom forwarded to ``np.std``. The default 0 gives the
            population standard deviation; 1 gives the sample standard deviation.

    Returns:
        ``(mean, mean - std, mean + std)``, each reduced over axis 0, i.e.
        ``(average, lower_bound, upper_bound)``. The plotting helpers index these as
        ``[1]`` = mean, ``[2]`` = lower and ``[3]`` = upper after they are stored.
    """
    runs = np.asarray(data)
    avg = runs.mean(axis=0)
    std = runs.std(axis=0, ddof=ddof)
    lower_bound = avg - std
    upper_bound = avg + std
    return avg, lower_bound, upper_bound

def evaluate_history_contrast_loss_train_test(out_dir, middle_size, class_num, num_times=100):
    """Aggregate the saved loss/accuracy histories of all runs for the first five models.

    For each model in ``MODEL_NAME[:5]`` the pickled histories of runs ``1..num_times``
    (located by ``generate_file_path``) are loaded and summarised with ``compute_stats``.
    The aggregate is also pickled to
    ``./result/loss_accuracy/all_data_middle_size_{middle_size}_class_num_{class_num}_times{num_times}.pkl``.

    Args:
        out_dir: root directory containing the ``loss_acc`` outputs.
        middle_size: model width used in the output path.
        class_num: number of classes used in the output path.
        num_times: number of training runs (run indices start at 1).

    Returns:
        dict: model name -> ``[all_runs, mean, mean - std, mean + std]``, in ``MODEL_NAME`` order.
    """
    data_to_save = {}
    # Exactly the first five models are aggregated; the plotting helpers in this
    # notebook provide five colours.
    for model_name in MODEL_NAME[:5]:
        runs = [
            load_data_from_file(generate_file_path(out_dir, middle_size, class_num, model_name, i))
            for i in range(1, num_times + 1)
        ]
        avg, lower_bound, upper_bound = compute_stats(runs)
        data_to_save[model_name] = [runs, avg, lower_bound, upper_bound]

    # Format filename with parameters
    file_name = f"./result/loss_accuracy/all_data_middle_size_{middle_size}_class_num_{class_num}_times{num_times}.pkl"
    ensure_dir_exists("./result/loss_accuracy/")

    with open(file_name, 'wb') as file:
        pickle.dump(data_to_save, file)

    return data_to_save

from textwrap import wrap

def plot_variance_comparison_legend_center_right(data, metric_index, metric_name, ax):
    """Shade each model's ``lower..upper`` band of one metric across epochs on ``ax``.

    Args:
        data: dict model name -> ``[runs, mean, lower, upper]`` as built by
            ``evaluate_history_contrast_loss_train_test``; ``lower[metric_index]`` and
            ``upper[metric_index]`` are per-epoch arrays.
        metric_index: metric row to draw (the callers use 0 = Loss, 1 = Training Accuracy,
            2 = Validation Accuracy).
        metric_name: display name used in labels and title.
        ax: matplotlib Axes to draw on.

    Only the first five models get a band (one per colour).
    """
    colors = ['skyblue', 'orange', 'green', 'red', 'purple']  # Distinct colors for each CNN type
    for cnn_type, color in zip(data.keys(), colors):
        lower_bounds = data[cnn_type][2][metric_index]
        upper_bounds = data[cnn_type][3][metric_index]
        # Derive the epoch axis from the stored history instead of assuming 10 epochs.
        epochs = np.arange(1, len(lower_bounds) + 1)
        ax.fill_between(epochs, lower_bounds, upper_bounds, color=color, alpha=0.3, label=f'{cnn_type} Bounds')
    ax.set_xlabel('Epochs')
    ax.set_ylabel(f'{metric_name} Bounds')
    title = f'Comparison of {metric_name} Variance Bounds Across CNN Types'
    ax.set_title("\n".join(wrap(title, 40)))  # Wrap title based on the plot width
    ax.legend(loc='center right')

def plot_fig0_bounds_comparison_loss_train_test(out_dir, middle_size, class_num, num_times=100):
    """Build a 1x3 figure with each model's mean +/- std band, one panel per metric.

    The data comes from ``evaluate_history_contrast_loss_train_test`` (which also
    re-saves the aggregate pickle). Returns the Figure.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    metric_names = ('Loss', 'Training Accuracy', 'Validation Accuracy')
    data_instance = evaluate_history_contrast_loss_train_test(out_dir=out_dir, middle_size=middle_size,
                                                              class_num=class_num, num_times=num_times)
    for metric_index, (panel, metric_name) in enumerate(zip(axes, metric_names)):
        plot_variance_comparison_legend_center_right(data_instance, metric_index, metric_name, panel)
    plt.tight_layout()
    return fig


def plot_variance_comparison_error_bars(data, metric_index, metric_name, ax):
    """Plot each model's per-epoch mean of one metric with asymmetric error bars on ``ax``.

    Args:
        data: dict model name -> ``[runs, mean, lower, upper]``; entry ``[metric_index]`` of
            mean/lower/upper is a per-epoch array.
        metric_index: metric row to draw.
        metric_name: display name used in labels and title.
        ax: matplotlib Axes to draw on.

    The error bar extends from ``lower`` to ``upper`` around the mean. Only the first five
    models are drawn (one per colour).
    """
    colors = ['skyblue', 'orange', 'green', 'red', 'purple']  # Distinct colors for each CNN type
    for cnn_type, color in zip(data.keys(), colors):
        means = data[cnn_type][1][metric_index]
        lower_bounds = data[cnn_type][2][metric_index]
        upper_bounds = data[cnn_type][3][metric_index]
        # The epoch axis follows the stored history length rather than the global EPOCH.
        epochs = np.arange(1, len(means) + 1)
        error = [means - lower_bounds, upper_bounds - means]  # Asymmetric error bars
        ax.errorbar(epochs, means, yerr=error, fmt='-o', color=color, ecolor=color, elinewidth=2, capsize=4, label=f'{cnn_type}')
    ax.set_xlabel('Epochs')
    ax.set_ylabel(f'{metric_name} with Variance')
    title = f'Comparison of {metric_name} with Variance Across CNN Types'
    ax.set_title("\n".join(wrap(title, 40)))  # Wrap title based on the plot width
    ax.legend(loc='center right')
    ax.grid(True)


def plot_fig1_loss_train_test(out_dir, middle_size, class_num, num_times=100):
    """Build a 1x3 figure of per-epoch means with error bars, one panel per metric.

    Loads and aggregates the run histories via
    ``evaluate_history_contrast_loss_train_test``. Returns the Figure.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    panel_metrics = ['Loss', 'Training Accuracy', 'Validation Accuracy']
    aggregated = evaluate_history_contrast_loss_train_test(out_dir=out_dir, middle_size=middle_size,
                                                           class_num=class_num, num_times=num_times)
    for position in range(len(panel_metrics)):
        plot_variance_comparison_error_bars(aggregated, position, panel_metrics[position], axes[position])
    plt.tight_layout()
    return fig


def plot_final_epoch_comparison_adjusted(data, metric_index, metric_name, ax, y_limit=None):
    """Bar chart of each model's mean ``metric_name`` at the last recorded epoch.

    Args:
        data: dict model name -> ``[runs, mean, lower, upper]``; ``mean`` is indexed as
            ``[metric_index, epoch]``.
        metric_index: metric row to read.
        metric_name: display name used for the y label and title.
        ax: matplotlib Axes to draw on.
        y_limit: optional ``(bottom, top)`` passed to ``ax.set_ylim`` when truthy.
    """
    cnn_types = list(data.keys())
    # Read the last stored epoch (-1) rather than EPOCH - 1 so the bars follow the data.
    means = [data[cnn_type][1][metric_index, -1] for cnn_type in cnn_types]
    colors = ['skyblue', 'orange', 'green', 'red', 'purple']  # Distinct colors for each CNN type
    ax.bar(cnn_types, means, color=colors, alpha=0.5)
    ax.set_xlabel('CNN Types')
    ax.set_ylabel(metric_name)
    ax.set_title(f'Comparison of {metric_name} at Final Epoch')
    # Pin one tick per bar (categorical positions 0..n-1) before relabelling them, so the
    # labels cannot drift from their bars.
    ax.set_xticks(range(len(cnn_types)))
    ax.set_xticklabels(cnn_types, rotation=45, ha='right')
    if y_limit:
        ax.set_ylim(y_limit)


def plot_fig2_bar_plots_final_epoch(out_dir, middle_size, class_num, num_times=100):
    """Build a 1x3 figure of final-epoch bar charts, one panel per metric.

    Accuracy panels are zoomed to ``(0.9, 1.0)``; the loss panel keeps autoscaling.
    Returns the Figure.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    metric_names = ['Loss', 'Training Accuracy', 'Validation Accuracy']
    aggregated = evaluate_history_contrast_loss_train_test(out_dir=out_dir, middle_size=middle_size,
                                                           class_num=class_num, num_times=num_times)
    for metric_index, metric_name in enumerate(metric_names):
        if metric_name == 'Loss':
            limits = None
        else:
            limits = (0.9, 1.0)
        plot_final_epoch_comparison_adjusted(aggregated, metric_index, metric_name, axes[metric_index], y_limit=limits)
    plt.tight_layout()
    return fig


def plot_metrics_autowrap(data, cnn_type, metric_index, metric_name, sub_plot_index, ax, include_bounds=False):
    """Plot one model's mean curve for a metric, plus either every run or the std band.

    Args:
        data: dict model name -> ``[runs, mean, lower, upper]``; ``np.array(runs)`` is indexed
            as ``[:, metric_index, :]``, so each run is a (metric, epoch) array.
        cnn_type: key of the model in ``data``.
        metric_index: metric row to draw.
        metric_name: display name used in labels and title.
        sub_plot_index: index into the five-colour palette (the model's row in the grid).
        ax: matplotlib Axes to draw on.
        include_bounds: if True, shade ``lower..upper``; otherwise overlay every individual
            run as a faint dashed line.
    """
    colors = ['skyblue', 'orange', 'green', 'red', 'purple']  # Distinct colors for each CNN type
    color = colors[sub_plot_index]
    observations = np.array(data[cnn_type][0])[:, metric_index, :]
    means = data[cnn_type][1][metric_index]
    lower_bounds = data[cnn_type][2][metric_index]
    upper_bounds = data[cnn_type][3][metric_index]
    # Derive the epoch axis from the stored history instead of assuming 10 epochs.
    epochs = np.arange(1, len(means) + 1)

    ax.plot(epochs, means, label=f'{cnn_type} Mean {metric_name}', marker='o', color=color)

    if include_bounds:
        ax.fill_between(epochs, lower_bounds, upper_bounds, alpha=0.2, label=f'{cnn_type} {metric_name} Bounds', color=color)
    else:
        for obs in observations:
            ax.plot(epochs, obs, linestyle='--', alpha=0.3, color=color)

    ax.set_xlabel('Epochs')
    ax.set_ylabel(metric_name)
    title = f'{cnn_type} {metric_name} Across Epochs'
    ax.set_title("\n".join(wrap(title, 40)))  # Wrap title based on the plot width
    ax.legend()


def plot_fig3_plots_the_metric_across_epochs(out_dir, middle_size, class_num, num_times=100):
    """Build a 5x3 grid of metric curves: one row per model, one column per metric.

    The loss column (column 0) shows the shaded mean +/- std band; the accuracy columns
    overlay every individual run instead. Returns the Figure.
    """
    fig, axes = plt.subplots(5, 3, figsize=(15, 25), sharex=True, sharey='row')
    metric_names = ['Loss', 'Training Accuracy', 'Validation Accuracy']
    aggregated = evaluate_history_contrast_loss_train_test(out_dir=out_dir, middle_size=middle_size,
                                                           class_num=class_num, num_times=num_times)
    model_keys = list(aggregated.keys())
    for row, model_key in enumerate(model_keys):
        for col, metric_name in enumerate(metric_names):
            show_band = col == 0
            plot_metrics_autowrap(aggregated, model_key, col, metric_name, row, axes[row, col],
                                  include_bounds=show_band)
    plt.tight_layout()
    return fig


def plot_metric_distribution_autowrap(data, cnn_type, metric_index, metric_name, sub_plot_index, ax):
    """Box plot of one metric's distribution over runs, with one box per epoch.

    ``np.array(data[cnn_type][0])[:, metric_index, :]`` gives a (run, epoch) matrix; every
    column becomes a box labelled ``Epoch k``. Box outlines and medians use palette colour
    ``sub_plot_index``.
    """
    palette = ['skyblue', 'orange', 'green', 'red', 'purple']  # one colour per model row
    observations = np.array(data[cnn_type][0])[:, metric_index, :]
    epoch_labels = ['Epoch {}'.format(k + 1) for k in range(observations.shape[1])]
    artists = ax.boxplot(observations, labels=epoch_labels)
    for outline in artists['boxes']:
        outline.set(color=palette[sub_plot_index])
    for median_line in artists['medians']:
        median_line.set_color(palette[sub_plot_index])

    ax.set_xlabel('Epochs')
    ax.set_ylabel(metric_name)
    heading = 'Distribution of {} Across Epochs for {}'.format(metric_name, cnn_type)
    ax.set_title("\n".join(wrap(heading, 40)))  # keep the title within the panel width
    ax.tick_params(axis='x', rotation=45)


def plot_fig4_box_plots_the_metric_across_epochs(out_dir, middle_size, class_num, num_times=100):
    """Build a 5x3 grid of per-epoch box plots: one row per model, one column per metric.

    Rows share their y axis and all panels share the x axis. Returns the Figure.
    """
    fig, axes = plt.subplots(5, 3, figsize=(15, 25), sharex=True, sharey='row')
    metric_names = ['Loss', 'Training Accuracy', 'Validation Accuracy']
    aggregated = evaluate_history_contrast_loss_train_test(out_dir=out_dir, middle_size=middle_size,
                                                           class_num=class_num, num_times=num_times)
    model_keys = list(aggregated.keys())
    for row in range(len(model_keys)):
        for col in range(len(metric_names)):
            plot_metric_distribution_autowrap(aggregated, model_keys[row], col, metric_names[col], row, axes[row, col])
    plt.tight_layout()
    return fig

In [None]:
# Show the example dot images (helper presumably provided by one of the star imports;
# disp=True presumably renders them inline — confirm against its definition).
dot_without_number_display_images(disp=True)

In [None]:
# Aggregate num_times runs per model for middle_size=16 / class_num=20 and draw the
# error-bar figure (fig1) and the final-epoch bar figure (fig2).
# NOTE(review): hard-coded absolute server path; consider a configurable root in a config cell.
# NOTE(review): each plot_fig* call re-runs evaluate_history_contrast_loss_train_test, so every
# run pickle is loaded (and the aggregate re-saved) twice in this cell.
out_dir = "/data0/user/gfhao/code_cpm/ServerCode/output/"
middle_size = 16
class_num = 20
num_times = 100
fig1 = plot_fig1_loss_train_test(out_dir, middle_size, class_num, num_times)
fig2 = plot_fig2_bar_plots_final_epoch(out_dir, middle_size, class_num, num_times)

In [28]:
import pickle
from tqdm import tqdm
def generate_file_path_model(out_dir, middle_size, class_num, model_name, train_idx, epoch):
    """Return the checkpoint path of run ``train_idx`` of ``model_name`` at ``epoch``.

    The file name is ``model_{train_idx}_epoch_{epoch}.pt``; the directory is resolved by
    ``get_path`` beneath ``out_dir + "model_weights"``.
    """
    checkpoint_name = "model_{}_epoch_{}.pt".format(train_idx, epoch)
    return get_path(out_dir + "model_weights", middle_size, class_num, model_name, checkpoint_name)


def generate_fid_path(middle_size, class_num, num_times,
                      result_dir="/data0/user/gfhao/code_cpm/ServerCode/result/fid"):
    """Return the path of the pickled FID results for one (middle_size, class_num) setting.

    Args:
        middle_size: model width used in the file name.
        class_num: number of classes used in the file name.
        num_times: number of runs used in the file name.
        result_dir: directory holding the FID pickles. Defaults to the original absolute
            server path. NOTE(review): ``get_history_fid`` writes to the *relative*
            ``./result/fid/``; both only agree when the notebook runs from ServerCode/.

    Returns:
        ``{result_dir}/fid_data_middle_size_{middle_size}_class_num_{class_num}_times{num_times}.pkl``.
    """
    file_name = f"fid_data_middle_size_{middle_size}_class_num_{class_num}_times{num_times}.pkl"
    return os.path.join(result_dir, file_name)


def load_data_from_file(path):
    """Return the object unpickled from ``path`` (trusted, project-written files only).

    NOTE(review): identical to the helper defined in an earlier cell; this redefinition
    shadows it.
    """
    with open(path, "rb") as stream:
        return pickle.Unpickler(stream).load()


def get_history_fid(out_dir, middle_size, class_num, num_times=100,
                    data_dir="/data0/user/gfhao/datasets/mnist"):
    """Compute final-epoch FID statistics for every run of the first five models.

    For each run ``1..num_times`` and each model in ``MODEL_NAME[:5]`` (in that order
    within a run), ``compute_FID`` is evaluated on the processed test split at epoch
    ``EPOCH - 1``. The results are pickled to
    ``./result/fid/fid_data_middle_size_{middle_size}_class_num_{class_num}_times{num_times}.pkl``
    (relative to the current working directory).

    Args:
        out_dir: root passed to ``compute_FID``.
        middle_size: width used to build ``model_config``.
        class_num: number of classes for the dataset and models.
        num_times: number of training runs (indices start at 1).
        data_dir: dataset root for ``load_processed_dataset``.

    Returns:
        dict: model name -> list (one entry per run) of ``[fid_all, fid_matrix_all]``.
    """
    model_config = [8, 16, middle_size, middle_size * middle_size // 4]
    epoch = EPOCH - 1
    model_names = MODEL_NAME[:5]
    # Only the test split is needed for FID.
    _, test_data = load_processed_dataset(data_dir=data_dir, batch_size=64, display=False, process=True, class_num=class_num)
    fid_history = {name: [] for name in model_names}
    for i in tqdm(range(1, num_times + 1), desc="Progress"):
        for name in model_names:
            fid_all, fid_matrix_all = compute_FID(test_data, i, epoch, model_name=name, class_num=class_num,
                                                  lambda_sparse=0.01, model_config=model_config, out_dir=out_dir)
            fid_history[name].append([fid_all, fid_matrix_all])

    # Format filename with parameters
    file_name = f"./result/fid/fid_data_middle_size_{middle_size}_class_num_{class_num}_times{num_times}.pkl"
    ensure_dir_exists("./result/fid/")

    with open(file_name, 'wb') as file:
        pickle.dump(fid_history, file)

    return fid_history


def analyze_fid(middle_size, class_num, num_times):
    """Count, per model, the runs whose two FID differences have opposite signs.

    For run ``j`` with ``fid_all = fid_data[model][j][0]``, the differences
    ``fid_all[0] - fid_all[1]`` and ``fid_all[2] - fid_all[3]`` are rounded to two
    decimals; a strictly negative product counts as an opposite-sign run (a zero
    difference never counts).

    Returns:
        dict with ``middle_size`` and ``class_num`` followed by one integer count per model.
        Despite the variable names used by callers, these are counts; they equal
        percentages only when ``num_times == 100``.
    """
    counts = {"middle_size": round(middle_size, 0), "class_num": round(class_num, 0)}
    fid_data = load_data_from_file(generate_fid_path(middle_size, class_num, num_times))
    for cnn_type in list(fid_data.keys()):
        for j in range(num_times):
            fid_all = fid_data[cnn_type][j][0]
            # Compute the two differences
            first_gap = round(fid_all[0] - fid_all[1], 2)
            second_gap = round(fid_all[2] - fid_all[3], 2)
            if first_gap * second_gap < 0:
                counts[cnn_type] = counts.get(cnn_type, 0) + 1
            else:
                counts.setdefault(cnn_type, 0)
    return dict(counts)

In [None]:
# Recompute and cache final-epoch FID statistics for every (middle_size, class_num)
# combination: 5 middle sizes x 22 class numbers x num_times runs. Very expensive.
# NOTE(review): this list starts at 5 (class_num 50), while the analysis cells below use
# class numbers from 20 upward; the 20/30/40 pickles must come from a different run.
out_dir = "/data0/user/gfhao/code_cpm/ServerCode/output/"
num_times = 100

class_num_runs = [5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,25,30,35,40,45,50]
class_num_runs = [x * 10 for x in class_num_runs]  # class numbers 50..500
for middle_size in [16,32,64,128,256]:
    for class_num_run in class_num_runs:
        print("class num: ", class_num_run)  # NOTE(review): middle_size is not printed, so progress is ambiguous
        get_history_fid(out_dir, middle_size, class_num_run, num_times)

In [13]:
def cal_fid_percentages(middle_size, num_times, class_num_runs,
                        model_names=('SimpleCNN', 'DotProductCNN', 'CrossProductCNN',
                                     'DotProductSparseCNN', 'CrossProductSparseCNN')):
    """Collect each model's opposite-sign FID counts across ``class_num_runs``.

    ``analyze_fid`` is called once per class number; its per-model counts are regrouped
    so each model maps to a list aligned with ``class_num_runs``.

    Args:
        middle_size: model width identifying the FID pickles.
        num_times: number of runs per setting.
        class_num_runs: class numbers to analyse, in plotting order.
        model_names: models to extract (defaults to the five CNN variants).

    Returns:
        ``{'middle_size_<middle_size>': {model: [count for each class number]}}``.

    Raises:
        KeyError: if a model in ``model_names`` is missing from an ``analyze_fid`` result.
    """
    per_class_num = [analyze_fid(middle_size, class_num_run, num_times) for class_num_run in class_num_runs]
    model_data = {model: [counts[model] for counts in per_class_num] for model in model_names}
    return {'middle_size_' + str(middle_size): model_data}

In [16]:
import pickle
# Build the opposite-sign FID counts for every middle size and class number (via
# cal_fid_percentages) and save them as one list of per-middle-size dicts for the
# plotting cells below.
num_times = 100
class_num_runs = [2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,25,30,35,40,45,50]
class_num_runs = [x * 10 for x in class_num_runs]  # class numbers 20..500
middle_sizes = [16, 32, 64, 128, 256]
fids_all = []
file_name = f"./result/fid_all/all_fid_data.pkl"
ensure_dir_exists("./result/fid_all/")
for middle_size in middle_sizes:
    fids_all.append(cal_fid_percentages(middle_size, num_times, class_num_runs))
with open(file_name, 'wb') as file:
    pickle.dump(fids_all, file)

In [26]:
import pickle
import pandas as pd

def load_data_from_file(path):
    """Read ``path`` and return the pickled object it contains.

    NOTE(review): same helper as in the earlier cells; only use on trusted files,
    since unpickling can execute arbitrary code.
    """
    with open(path, "rb") as fh:
        obj = pickle.load(fh)
    return obj


def plot_fid_percentages_in_one_fig(file_name="./result/fid_all/all_fid_data.pkl",
                                    class_num_runs=None, break_point=200):
    """Line charts of opposite-sign FID counts per model, one subplot per middle size.

    Args:
        file_name: pickle holding a list of ``{'middle_size_<n>': {model: [value per class number]}}``
            dicts (as written by the cal_fid_percentages loop).
        class_num_runs: class numbers the values correspond to, in order. Defaults to
            20, 30, ..., 200, 250, ..., 500.
        break_point: class numbers at or above this are drawn at half spacing so the
            sparse tail does not dominate the x axis; tick labels still show the real values.

    Returns:
        matplotlib Figure. The plotted "Percentage" values are the stored counts, which
        equal percentages when each setting has 100 runs.
    """
    fid_data = load_data_from_file(file_name)
    if class_num_runs is None:
        class_num_runs = [x * 10 for x in [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
                                           25, 30, 35, 40, 45, 50]]
    class_num_runs = list(class_num_runs)
    # Compressed x positions: values above break_point are squeezed by a factor of two.
    adjusted_class_nums = [num if num < break_point else int(break_point + (num - break_point) / 2)
                           for num in class_num_runs]

    records = []
    for element in fid_data:
        for middle_size, models in element.items():
            for model, percentages in models.items():
                for class_num_index, percentage in enumerate(percentages):
                    records.append({
                        'Middle Size': middle_size,
                        'Model': model,
                        # Look the class number up: the default list is not a uniform step of 10.
                        'Class Number': class_num_runs[class_num_index],
                        'Percentage': percentage,
                    })
    df = pd.DataFrame(records)

    middle_sizes = df['Middle Size'].unique()
    model_names = df['Model'].unique()
    n_rows = len(middle_sizes)
    # One subplot per middle size; squeeze=False keeps `axes` 2-D even for a single row.
    fig, axes = plt.subplots(n_rows, 1, figsize=(20, 5 * n_rows), sharey=True, squeeze=False)
    model_colors = ['skyblue', 'orange', 'green', 'red', 'purple']
    model_color_map = {model: color for model, color in zip(model_names, model_colors)}

    for i, middle_size in enumerate(middle_sizes):
        ax = axes[i, 0]
        middle_size_data = df[df['Middle Size'] == middle_size]
        for model in model_names:
            model_data = middle_size_data[middle_size_data['Model'] == model]
            ax.plot(adjusted_class_nums, model_data['Percentage'], label=model, color=model_color_map[model], marker='o')
        ax.set_title(f'Middle Size: {middle_size}')
        # Ticks sit at the compressed positions but are labelled with the real class numbers.
        ax.set_xticks(adjusted_class_nums)
        ax.set_xticklabels(class_num_runs, rotation=45)
        if i == 0:
            ax.legend()

    fig.suptitle('Line Charts of Percentage Values by Class Number (Scaled), Model, and Middle Size')
    fig.text(0.5, 0.04, 'Class Number', ha='center', va='center')
    fig.text(0.06, 0.5, 'Percentage', ha='center', va='center', rotation='vertical')
    fig.tight_layout(rect=[0.03, 0.05, 1, 0.97])
    return fig

In [None]:
# Draw the per-middle-size line charts from ./result/fid_all/all_fid_data.pkl.
fid_all_fig = plot_fid_percentages_in_one_fig()

In [31]:
import pickle
import pandas as pd
import matplotlib.pyplot as plt

def load_data_from_file(path):
    """Deserialize the pickle stored at ``path`` and return it (trusted files only).

    NOTE(review): this re-declares the helper from earlier cells.
    """
    with open(path, "rb") as source:
        raw_bytes = source.read()
    return pickle.loads(raw_bytes)


def plot_fid_percentages_in_one_fig_all(file_name="./result/fid_all/all_fid_data.pkl",
                                        class_num_runs=None, break_point=200,
                                        models=('DotProductCNN', 'CrossProductCNN')):
    """One subplot per model comparing opposite-sign FID counts across middle sizes.

    Args:
        file_name: pickle holding a list of ``{'middle_size_<n>': {model: [value per class number]}}``
            dicts.
        class_num_runs: class numbers the values correspond to. Defaults to
            20, 30, ..., 200, 250, ..., 500.
        break_point: class numbers at or above this are drawn at half spacing; tick labels
            still show the real class numbers.
        models: models to plot, one subplot each.

    Returns:
        matplotlib Figure.
    """
    fid_data = load_data_from_file(file_name)
    if class_num_runs is None:
        class_num_runs = [x * 10 for x in [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
                                           25, 30, 35, 40, 45, 50]]
    class_num_runs = list(class_num_runs)
    adjusted_class_nums = [num if num < break_point else int(break_point + (num - break_point) / 2)
                           for num in class_num_runs]
    models = list(models)

    df = pd.DataFrame([
        {
            'Middle Size': middle_size,
            'Model': model,
            'Class Number': class_num_runs[class_num_index],
            'Percentage': percentage,
        }
        for element in fid_data
        for middle_size, by_model in element.items()
        for model, percentages in by_model.items()
        if model in models
        for class_num_index, percentage in enumerate(percentages)
    ])

    def _middle_size_key(label):
        # Keys look like 'middle_size_16'; sort by the numeric suffix so 128 follows 64.
        suffix = str(label).rsplit('_', 1)[-1]
        return (0, int(suffix), '') if suffix.isdigit() else (1, 0, str(label))

    fig, axes = plt.subplots(len(models), 1, figsize=(15, 12), sharex=True, squeeze=False)
    for i, model in enumerate(models):
        ax = axes[i, 0]
        model_data = df[df['Model'] == model]
        middle_sizes = sorted(model_data['Middle Size'].unique(), key=_middle_size_key)
        # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap(name, lut) works on old and new releases.
        cmap = plt.get_cmap('Accent', max(len(middle_sizes), 1))
        for j, middle_size in enumerate(middle_sizes):
            middle_size_data = model_data[model_data['Middle Size'] == middle_size]
            ax.plot(adjusted_class_nums, middle_size_data['Percentage'], label=middle_size, color=cmap(j), marker='o')
        ax.set_title(f'FID Percentages for {model}')
        ax.set_xlabel('Class Number')
        ax.set_ylabel('Percentage')
        ax.legend()
        ax.grid(True)

    fig.suptitle('FID Percentages by Class Number and Middle Size', fontsize=16)
    # Shared x axis: ticks at the compressed positions, labelled with the real class numbers.
    bottom_ax = axes[-1, 0]
    bottom_ax.set_xticks(adjusted_class_nums)
    bottom_ax.set_xticklabels(class_num_runs, rotation=45)
    fig.tight_layout(rect=[0, 0.03, 1, 0.97])
    return fig

In [None]:
# Compare DotProductCNN and CrossProductCNN across middle sizes (one subplot per model).
fig = plot_fid_percentages_in_one_fig_all()

In [None]:
# Smoke test: compute the FID for one (class_num, model, run, epoch) combination.
# The two `break`s stop after the first model of the first class number, so only
# class_num 20 and MODEL_NAME[0] are evaluated.
out_dir = "/data0/user/gfhao/code_cpm/ServerCode/output/"

middle_size = 16
model_config = [8, 16, middle_size, middle_size * middle_size // 4]
print("middle size:", middle_size)
# class_num_runs = [2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,25,30,35,40,45,50]
class_num_runs = [2,3]
class_num_runs = [x * 10 for x in class_num_runs]
for class_num_run in class_num_runs:
    print("class num run: {}".format(class_num_run))
    train_data, test_data = load_processed_dataset(data_dir="/data0/user/gfhao/datasets/mnist", batch_size=64, display=False, process=True, class_num=class_num_run)
    print("datasets loaded!")
    for name in MODEL_NAME:
        print("model name: ", name)
        train_idx = 1
        epoch = 9  # NOTE(review): hard-coded; the other FID cells use EPOCH - 1 — confirm they agree
        print("train_idx: ", train_idx)
        print("epoch: ", epoch)
        fid_all, fid_matrix_all = compute_FID(test_data, train_idx, epoch, model_name=name, class_num=class_num_run, lambda_sparse=0.01, model_config=model_config, out_dir=out_dir)
        break  # only the first model
    break  # only the first class number

In [43]:
def analyze_train_fid(middle_size, class_num, num_times, model_index=2):
    """Find the run with the largest opposite-sign FID difference product for one model.

    For each run ``j`` the differences ``fid_all[0] - fid_all[1]`` and
    ``fid_all[2] - fid_all[3]`` (rounded to two decimals) are multiplied. Among runs
    where the product is negative, the one with the largest ``|product|`` is kept; the
    earliest run wins ties.

    Args:
        middle_size, class_num, num_times: identify the FID pickle via ``generate_fid_path``.
        model_index: position of the model in the pickle's key order (default 2, the model
            commented as 'CrossProductCNN').

    Returns:
        dict with ``middle_size``, ``class_num``, ``<model name>`` (1-based index of the
        selected run, or None), ``max_diff_fid`` and ``max_diff_fid_matrix`` (both None
        when no run has opposite-sign differences).
    """
    max_sign = {"middle_size": round(middle_size, 0), "class_num": round(class_num, 0)}
    fid_data = load_data_from_file(generate_fid_path(middle_size, class_num, num_times))
    cnn_type = list(fid_data.keys())[model_index]
    max_sign[cnn_type] = None
    max_sign["max_diff_fid"] = None
    max_sign["max_diff_fid_matrix"] = None
    max_val = 0
    for j in range(num_times):
        fid_all = fid_data[cnn_type][j][0]
        # Compute the two differences
        diff1 = round(fid_all[0] - fid_all[1], 2)
        diff2 = round(fid_all[2] - fid_all[3], 2)
        product = diff1 * diff2
        # The running maximum is updated for every candidate, including the first one found,
        # so a later smaller product can never replace a larger earlier one.
        if product < 0 and abs(product) > max_val:
            max_val = abs(product)
            max_sign[cnn_type] = j + 1
            max_sign["max_diff_fid"] = fid_all
            max_sign["max_diff_fid_matrix"] = fid_data[cnn_type][j][1]
    return max_sign

def analyze_train_fid_positive(middle_size, class_num, num_times, model_index=2):
    """Find the run with the largest same-sign FID difference product for one model.

    For each run ``j`` the differences ``fid_all[0] - fid_all[1]`` and
    ``fid_all[2] - fid_all[3]`` (rounded to two decimals) are multiplied. Among runs
    where the product is positive, the one with the largest product is kept; the earliest
    run wins ties.

    Args:
        middle_size, class_num, num_times: identify the FID pickle via ``generate_fid_path``.
        model_index: position of the model in the pickle's key order (default 2, the model
            commented as 'CrossProductCNN').

    Returns:
        dict with ``middle_size``, ``class_num``, ``<model name>`` (1-based index of the
        selected run, or None), ``max_diff_fid_positive`` and
        ``max_diff_fid_matrix_positive`` (both None when no run qualifies).
    """
    max_sign = {"middle_size": round(middle_size, 0), "class_num": round(class_num, 0)}
    fid_data = load_data_from_file(generate_fid_path(middle_size, class_num, num_times))
    cnn_type = list(fid_data.keys())[model_index]
    max_sign[cnn_type] = None
    max_sign["max_diff_fid_positive"] = None
    max_sign["max_diff_fid_matrix_positive"] = None
    max_val = 0
    for j in range(num_times):
        fid_all = fid_data[cnn_type][j][0]
        diff1 = round(fid_all[0] - fid_all[1], 2)
        diff2 = round(fid_all[2] - fid_all[3], 2)
        product = diff1 * diff2
        # The running maximum is updated for every candidate, including the first one found.
        if product > 0 and product > max_val:
            max_val = product
            max_sign[cnn_type] = j + 1
            max_sign["max_diff_fid_positive"] = fid_all
            max_sign["max_diff_fid_matrix_positive"] = fid_data[cnn_type][j][1]
    return max_sign

def get_train_history_fid(out_dir, middle_size, class_num, epoch, num_times=100,
                          data_dir="/data0/user/gfhao/datasets/mnist"):
    """Count, per model, the runs whose FID differences have opposite signs at ``epoch``.

    For every run ``1..num_times`` and every model in ``MODEL_NAME[:5]`` (in that order
    within a run), ``compute_FID`` is evaluated on the processed test split. A run counts
    when the rounded differences ``fid_all[0] - fid_all[1]`` and ``fid_all[2] - fid_all[3]``
    have a strictly negative product. The counts are pickled to
    ``./result/fid_train/fid_data_middle_size_{middle_size}_class_num_{class_num}_times{num_times}_epoch{epoch}.pkl``.

    Args:
        out_dir: root passed to ``compute_FID``.
        middle_size: width used to build ``model_config``.
        class_num: number of classes for the dataset and models.
        epoch: checkpoint epoch to evaluate.
        num_times: number of training runs (indices start at 1).
        data_dir: dataset root for ``load_processed_dataset``.

    Returns:
        dict: model name -> number of opposite-sign runs.
    """
    model_config = [8, 16, middle_size, middle_size * middle_size // 4]
    model_names = MODEL_NAME[:5]
    # Only the test split is needed for FID.
    _, test_data = load_processed_dataset(data_dir=data_dir, batch_size=64, display=False, process=True, class_num=class_num)
    opposite_counts = [0] * len(model_names)
    for i in tqdm(range(1, num_times + 1), desc="Progress"):
        for k, name in enumerate(model_names):
            fid_all, _fid_matrix = compute_FID(test_data, i, epoch, model_name=name, class_num=class_num,
                                               lambda_sparse=0.01, model_config=model_config, out_dir=out_dir)
            diff1 = round(fid_all[0] - fid_all[1], 2)
            diff2 = round(fid_all[2] - fid_all[3], 2)
            if diff1 * diff2 < 0:
                opposite_counts[k] += 1

    data_to_save = dict(zip(model_names, opposite_counts))
    # Format filename with parameters
    file_name = f"./result/fid_train/fid_data_middle_size_{middle_size}_class_num_{class_num}_times{num_times}_epoch{epoch}.pkl"
    ensure_dir_exists("./result/fid_train/")

    with open(file_name, 'wb') as file:
        pickle.dump(data_to_save, file)

    return data_to_save

In [None]:
# For middle_size=32 / class_num=100, pick the run of the third model in the FID pickle
# (commented as CrossProductCNN in analyze_train_fid) whose two FID differences have
# opposite signs with the largest |product|. Save its FID matrices for the heat-map cells.
class_num_runs = [2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,25,30,35,40,45,50]
class_num_runs = [x * 10 for x in class_num_runs]
middle_size = 32
num_times = 100
import itertools
import pickle
class_num = 100
max_sign_fid = analyze_train_fid(middle_size, class_num, num_times)
max_diff_fid_matrix = max_sign_fid["max_diff_fid_matrix"]
# First four entries: middle_size, class_num, the selected run index and its FID values.
print({k: max_sign_fid[k] for k in itertools.islice(max_sign_fid, 4)})
if max_diff_fid_matrix is None:
    # Pickling None here would only fail later, inside plot_fid_matrix_max.
    raise ValueError(
        f"No opposite-sign run found for middle_size={middle_size}, class_num={class_num}; nothing to save."
    )
# Format filename with parameters
file_name = f"./result/matrix/middle_size_{middle_size}_class_num_{class_num}_times{num_times}_max_diff_fid_matrix.pkl"
ensure_dir_exists("./result/matrix/")
with open(file_name, 'wb') as file:
    pickle.dump(max_diff_fid_matrix, file)


In [23]:
import seaborn as sns
import pandas as pd
import matplotlib.colors as mcolors

start_color_hex = "#428abc"  # blue
end_color_hex = "#ed4074"   # red
# Converting HEX colors to RGB
start_color_rgb = mcolors.hex2color(start_color_hex)
end_color_rgb = mcolors.hex2color(end_color_hex)
white_rgb = mcolors.hex2color("#ffffff")

# Creating the colormap from white to blue (#428abc)
white_to_blue_colors = [white_rgb, start_color_rgb]
white_to_blue_cmap = mcolors.LinearSegmentedColormap.from_list("white_to_blue", white_to_blue_colors)

# Creating the colormap from white to red (#ed4074)
white_to_red_colors = [white_rgb, end_color_rgb]
white_to_red_cmap = mcolors.LinearSegmentedColormap.from_list("white_to_red", white_to_red_colors)
# One colour map per heat-map panel, indexed by matrix position in plot_fid_matrix_*:
# red shades for the two F1-group panels (0, 1), blue shades for the two F2-group panels (2, 3).
custom_cmap = [white_to_red_cmap, white_to_red_cmap, white_to_blue_cmap, white_to_blue_cmap]

def plot_fid_matrix_max(middle_size, class_num, num_times, file_path=None):
    """Plot the four saved FID matrices of the selected run as a 2x2 grid of heat maps.

    Args:
        middle_size, class_num, num_times: identify the default pickle
            ``./result/matrix/middle_size_{middle_size}_class_num_{class_num}_times{num_times}_max_diff_fid_matrix.pkl``.
        file_path: optional explicit pickle path overriding the default (for example the
            ``..._max_diff_fid_matrix_positive.pkl`` file).

    Returns:
        matplotlib Figure. Each row shares a colour scale starting at 0: the top row (F1 group)
        is capped at the max of matrices 0-1, the bottom row (F2 group) at the max of
        matrices 2-3. Panel colours come from the module-level ``custom_cmap``.

    Raises:
        ValueError: if the pickle does not hold at least four matrices (e.g. it stores None).
    """
    if file_path is None:
        file_path = f"./result/matrix/middle_size_{middle_size}_class_num_{class_num}_times{num_times}_max_diff_fid_matrix.pkl"
    data = pd.read_pickle(file_path)
    if data is None or len(data) < 4:
        raise ValueError(f"expected four FID matrices in {file_path}")
    row_vmax = (max(np.max(data[0]), np.max(data[1])), max(np.max(data[2]), np.max(data[3])))
    fig, axs = plt.subplots(2, 2, figsize=(14, 12))
    titles = [
        "FID matrix about dot of F1 group",
        "FID matrix about number of F1 group",
        "FID matrix about dot of F2 group",
        "FID matrix about number of F2 group"
    ]
    for i, ax in enumerate(axs.flatten()):
        sns.heatmap(data[i], annot=False, cmap=custom_cmap[i], ax=ax, vmin=0, vmax=row_vmax[i // 2])
        ax.set_title(titles[i], fontsize=20, pad=20)
        ax.set_xlabel("")
        ax.set_ylabel("")

    fig.tight_layout()
    return fig

def plot_fid_matrix_positive_max(middle_size, class_num, num_times):
    file_path = f"./result/matrix/middle_size_{middle_size}_class_num_{class_num}_times{num_times}_max_diff_fid_matrix_positive.pkl"
    data = pd.read_pickle(file_path)
    max_value_row1 = max(np.max(data[0]), np.max(data[1]))
    max_value_row2 = max(np.max(data[2]), np.max(data[3]))
    # custom_cmap = sns.diverging_palette(240, 0, s=80, l=55, n=9, as_cmap=True)
    fig, axs = plt.subplots(2, 2, figsize=(14, 12))
    # Plotting each matrix with the given titles
    titles = [
        "FID matrix about dot of F1 group",
        "FID matrix about number of F1 group",
        "FID matrix about dot of F2 group",
        "FID matrix about number of F2 group"
    ]
    for i, ax in enumerate(axs.flatten()):
        if i < len(data):
            vmax = max_value_row1 if i < 2 else max_value_row2
            # sns.heatmap(data[i], annot=False, cmap="coolwarm", ax=ax, vmin=0, vmax=vmax)
            sns.heatmap(data[i], annot=False, cmap=custom_cmap[i], ax=ax, vmin=0, vmax=vmax)
            ax.set_title(titles[i], fontsize=20, pad=20)
            ax.set_xlabel("")
            ax.set_ylabel("")
        else:
            ax.axis('off')  # Hide axis if there's no corresponding matrix

    plt.tight_layout()
    return fig

In [None]:
# Class-number sweep (2..20, then 25..50 in steps of 5, all x10); not used by this plot.
class_num_runs = [x * 10 for x in [*range(2, 21), 25, 30, 35, 40, 45, 50]]
middle_size = 32
num_times = 100
class_num = 100
fig = plot_fid_matrix_max(middle_size=middle_size, class_num=class_num, num_times=num_times)

In [29]:
import shap
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# Diverging colormap: blue -> light grey -> white -> light grey -> red.
# The light-grey (#d9d9d9) stops on either side of white keep the band around the midpoint pale.
start_color_hex = "#428abc"  # blue
end_color_hex = "#ed4074"   # red
middle_color_hex = "#ffffff" # white
middle_middle_color_hex = "#d9d9d9" # gray

# HEX -> RGB triples (floats in [0, 1]).
start_color_rgb, end_color_rgb, middle_color_rgb, middle_middle_color_rgb = (
    mcolors.hex2color(hex_code)
    for hex_code in (start_color_hex, end_color_hex, middle_color_hex, middle_middle_color_hex)
)

# Colour stops are evenly spaced along the map.
custom_colors = [start_color_rgb, middle_middle_color_rgb, middle_color_rgb, middle_middle_color_rgb, end_color_rgb]
custom_cmap_expanded_white = mcolors.LinearSegmentedColormap.from_list("custom_cmap", custom_colors)

class NeuronExtractor(nn.Module):
    """Expose a model's first two feature maps, concatenated, as the forward output.

    The wrapped model's own output is discarded. After calling it, ``model.feature_map[0]``
    and ``model.feature_map[1]`` are concatenated along dim 1 and returned. This gives SHAP
    an output made of those intermediate activations.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # Assumes the model (re)populates ``feature_map`` during its forward pass — TODO confirm.
        self.model(x)
        maps = self.model.feature_map
        return torch.cat([maps[0], maps[1]], dim=1)

def plot_shap(datasets, train_idx, epoch, model_name='', class_num=100, lambda_sparse=0.01, model_config=[], out_dir='', num=4):
    """Show SHAP image attributions of a saved model's concatenated feature maps.

    The function first rebuilds the network with ``initialize_model``. It then loads the
    checkpoint ``model_{train_idx}_epoch_{epoch}.pt`` found via
    ``get_path(out_dir + "model_weights", ...)`` and wraps the network in
    ``NeuronExtractor``. Finally it runs ``shap.DeepExplainer`` with the first batch of
    ``datasets`` as background. The first ``num`` samples of that batch are explained and
    drawn with ``shap.image_plot``.

    Args:
        datasets: iterable of ``(inputs, labels)`` batches; only the first batch is used.
        train_idx: index of the training run whose checkpoint is loaded.
        epoch: epoch number embedded in the checkpoint filename.
        model_name, class_num, lambda_sparse, model_config: forwarded to ``initialize_model``;
            ``model_config[2]`` is also the middle size used to locate the checkpoint.
        out_dir: output root, prefixed to ``"model_weights"``.
        num: number of samples from the first batch to explain.
    """
    middle_size = model_config[2]
    net = initialize_model(model_name, class_num, model_config, lambda_sparse)
    weights_path = get_path(out_dir + "model_weights", middle_size, class_num, model_name,
                            f"model_{train_idx}_epoch_{epoch}.pt")
    ensure_dir_exists(weights_path)
    net.load_state_dict(torch.load(weights_path))
    for weight in net.parameters():
        weight.requires_grad = True

    batch, _ = next(iter(datasets))
    batch.requires_grad = True
    # NOTE(review): shap's PyTorch DeepExplainer may switch the wrapped module to eval mode
    # itself — confirm whether this train() call has any effect.
    net.train()
    explainer = shap.DeepExplainer(NeuronExtractor(net), batch)
    shap_values = explainer.shap_values(batch[0:num])

    # Move axis 1 to the end (assumes NCHW batches — TODO confirm) for shap.image_plot.
    shap_numpy = [np.transpose(values, (0, 2, 3, 1)) for values in shap_values]
    input_numpy = np.transpose(batch[0:num].detach().numpy(), (0, 2, 3, 1))
    colormap = LinearSegmentedColormap.from_list('custom_blue_white_red', ['#397AA3', 'white', '#E71555'])
    # Pixel values are negated before display — presumably to invert the grey background; confirm.
    shap.image_plot(shap_numpy, -input_numpy, cmap=colormap)
    

In [None]:
# SHAP attributions for run 4, epoch 9 of MODEL_NAME[2] (middle size 32, 100 classes).
out_dir = "/data0/user/gfhao/code_cpm/ServerCode/output/"
middle_size = 32
model_config = [8, 16, middle_size, middle_size ** 2 // 4]
class_num = 100
print("middle size:", middle_size)
train_data, test_data = load_processed_dataset(data_dir="/data0/user/gfhao/datasets/mnist", batch_size=64, display=False, process=True, class_num=class_num)
train_idx = 4
epoch = 9
name = MODEL_NAME[2]
plot_shap(datasets=test_data, train_idx=train_idx, epoch=epoch, model_name=name, class_num=class_num,
          lambda_sparse=0.01, model_config=model_config, out_dir=out_dir)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

# Blue -> white -> red colormap (the same stops plot_shap uses) and a two-row 0..1 ramp to preview it.
colors = ['#397AA3', 'white', '#E71555']
colormap = LinearSegmentedColormap.from_list('custom_blue_white_red', colors)
gradient = np.vstack([np.linspace(0, 1, 256)] * 2)

def plot_color_gradients(cmap_category, cmap_list):
    """Render every colormap in ``cmap_list`` as a labelled horizontal gradient strip.

    Args:
        cmap_category: group name shown in the figure title ("<category> colormaps").
        cmap_list: non-empty list of matplotlib Colormap objects; each gets its own row,
            labelled with its ``.name``.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if ``cmap_list`` is empty.
    """
    if not cmap_list:
        raise ValueError("cmap_list must contain at least one colormap")
    # Two identical 0..1 rows, so imshow draws a horizontal ramp.
    ramp = np.vstack((np.linspace(0, 1, 256),) * 2)

    nrows = len(cmap_list)
    # Figure height scales with the number of strips; fixed margins leave room for title/labels.
    figh = 0.35 + 0.15 + (nrows + (nrows - 1) * 0.1) * 0.22
    # squeeze=False keeps a 2-D array even for a single row, so indexing works for any nrows.
    fig, axs = plt.subplots(nrows=nrows, figsize=(6.4, figh), squeeze=False)
    fig.subplots_adjust(top=1 - .35 / figh, bottom=.15 / figh, left=0.2, right=0.99)
    axs = axs[:, 0]

    axs[0].set_title(f"{cmap_category} colormaps", fontsize=14)
    for ax, cmap in zip(axs, cmap_list):
        ax.imshow(ramp, aspect='auto', cmap=cmap)
        ax.text(-.01, .5, cmap.name, va='center', ha='right', fontsize=10, transform=ax.transAxes)
        ax.set_axis_off()
    return fig
# Preview the custom blue-white-red colormap.
plot_color_gradients(cmap_category='Custom', cmap_list=[colormap])
plt.show()

In [None]:
# FID statistics per (middle_size, epoch) for the 100-class sweep.
out_dir = "/data0/user/gfhao/code_cpm/ServerCode/output/"
class_num = 100
for middle_size in [16, 32, 64, 128, 256]:
    for epoch in range(10):
        # (16, epoch 0) is skipped — reason not recorded here (presumably already computed); confirm.
        if (middle_size, epoch) == (16, 0):
            continue
        print("middle_size: {}, epoch: {}".format(middle_size, epoch))
        percent = get_train_history_fid(out_dir, middle_size, class_num, epoch, num_times=100)
        print(percent)

In [None]:
def load_data_from_file(path):
    """Unpickle and return the object stored at ``path`` (same helper as in the first analysis cell).

    pickle can execute arbitrary code: only load files this project produced.
    """
    with open(path, mode="rb") as handle:
        payload = pickle.load(handle)
    return payload

# Gather the per-epoch pickles under ./result/fid_train/ (presumably produced by the
# get_train_history_fid loop above — confirm). The result is one nested list: one record per
# middle size, each holding ten {"epoch", "data"} entries. The list is saved and printed.
class_num = 100
num_times = 100
all_fid_trend = []
for middle_size in [16, 32, 64, 128, 256]:
    train_fid_epoch_trend = {"middle_size": middle_size, "data": []}
    for epoch in range(10):
        train_fid_epoch = {"epoch": epoch}
        file_name = (f"./result/fid_train/fid_data_middle_size_{middle_size}"
                     f"_class_num_{class_num}_times{num_times}_epoch{epoch}.pkl")
        ensure_dir_exists("./result/fid_train/")
        train_fid_epoch["data"] = load_data_from_file(file_name)
        train_fid_epoch_trend["data"].append(train_fid_epoch)
    all_fid_trend.append(train_fid_epoch_trend)

file_name = (f"./result/fid_train_all/all_fid_trend_class_num_{class_num}"
             f"_times{num_times}_all_epoch.pkl")
ensure_dir_exists("./result/fid_train_all/")
with open(file_name, 'wb') as file:
    pickle.dump(all_fid_trend, file)
print(all_fid_trend)

In [None]:
import pickle
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import textwrap


# Inputs for the combined FID figure: the class-number sweep plus the per-epoch trends
# for class_num 100 and 500. These are project-generated pickles; never point pickle.load at untrusted files.
num_times = 100
path_fid_data = "./result/fid_all/all_fid_data.pkl"
path_fid_trend_100, path_fid_trend_500 = (
    f"./result/fid_train_all/all_fid_trend_class_num_{n_classes}_times{num_times}_all_epoch.pkl"
    for n_classes in (100, 500)
)

with open(path_fid_data, "rb") as f:
    fid_data = pickle.load(f)
with open(path_fid_trend_100, "rb") as f:
    fid_trend_100_data = pickle.load(f)
with open(path_fid_trend_500, "rb") as f:
    fid_trend_500_data = pickle.load(f)


def sort_legend(ax):
    """Return ``(labels, handles)`` of ``ax``'s legend entries, ordered by their trailing integer.

    Each label is first reduced to the text inside its last ``"(...)"`` group, if any. Labels
    are then ordered by the integer after the final underscore (e.g. ``"middle_size_16"`` -> 16).
    The sort is stable, so equal keys keep their original order.
    """
    handles, labels = ax.get_legend_handles_labels()

    def _inner_text(label):
        return label.split("(")[-1].rstrip(")")

    def _trailing_int(entry):
        return int(entry[0].split("_")[-1])

    ordered = sorted(((_inner_text(lbl), hdl) for lbl, hdl in zip(labels, handles)), key=_trailing_int)
    sorted_labels = [lbl for lbl, _ in ordered]
    sorted_handles = [hdl for _, hdl in ordered]
    return sorted_labels, sorted_handles

# Combined figure: two full-width rows of FID-vs-class-number curves (DotProductCNN,
# CrossProductCNN), then a row with two per-epoch FID trend panels.
fig_combined_sorted_legend = plt.figure(figsize=(12, 20), constrained_layout=True)
# Build the grid through the figure so constrained_layout manages its spacing
# (constrained_layout also ignores plt.subplots_adjust, so spacing is not set manually).
# NOTE(review): grid row 3 (height ratio 3) is never populated and renders as blank space — confirm intended.
gs = fig_combined_sorted_legend.add_gridspec(4, 2, height_ratios=[4, 4, 3, 3])
ax1_sorted_legend = fig_combined_sorted_legend.add_subplot(gs[0, :])  # First row, full width
ax2_sorted_legend = fig_combined_sorted_legend.add_subplot(gs[1, :])  # Second row, full width
ax3_sorted_legend = fig_combined_sorted_legend.add_subplot(gs[2, :-1])  # Third row, first column
ax4_sorted_legend = fig_combined_sorted_legend.add_subplot(gs[2, -1])  # Third row, second column

# Set font size for the legends and labels
fontsize = 14
legend_fontsize = fontsize
label_fontsize = fontsize

# Class numbers of the sweep. On the x axis, values above 200 are placed at half spacing.
class_num_runs = [2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,25,30,35,40,45,50]
class_num_runs = [x * 10 for x in class_num_runs]
adjusted_class_nums = [num if num < 200 else int(200 + (num - 200) / 2) for num in class_num_runs]

# Flatten fid_data into long-form rows; percentage i belongs to class_num_runs[i]
# (the same pairing the plots below use via adjusted_class_nums).
df_data = []
for element in fid_data:
    for middle_size, models in element.items():
        for model, percentages in models.items():
            if model not in ('SimpleCNN', 'DotProductCNN', 'CrossProductCNN'):
                continue
            for class_num_index, percentage in enumerate(percentages):
                df_data.append({
                    'Middle Size': middle_size,
                    'Model': model,
                    'Class Number': class_num_runs[class_num_index] if class_num_index < len(class_num_runs) else None,
                    'Percentage': percentage,
                })
df = pd.DataFrame(df_data)

# One fixed colour per middle size, shared by every panel of the figure.
all_middle_sizes = sorted(set(df['Middle Size'].unique()))
# pyplot.get_cmap(name, lut): plt.cm.get_cmap was removed in Matplotlib 3.9.
cmap = plt.get_cmap('Accent', len(all_middle_sizes))
color_map = {size: cmap(i) for i, size in enumerate(all_middle_sizes)}

models = ['DotProductCNN', 'CrossProductCNN']

for i, (model, ax) in enumerate(zip(models, [ax1_sorted_legend, ax2_sorted_legend])):
    model_data = df[df['Model'] == model]
    for middle_size in sorted(model_data['Middle Size'].unique()):
        middle_size_data = model_data[model_data['Middle Size'] == middle_size]
        ax.plot(adjusted_class_nums, middle_size_data['Percentage'], label=middle_size, color=color_map[middle_size])
    ax.set_title("\n".join(textwrap.wrap(f'FID Score for {model}', width=30)), fontsize=label_fontsize)
    ax.set_xlabel('Class Number', fontsize=label_fontsize)
    ax.set_ylabel('FID Score', fontsize=label_fontsize)
    ax.set_ylim(0, 100)
    ax.grid(True)
    if i == 0:
        # One legend suffices: both panels colour middle sizes through the same color_map.
        sorted_labels, sorted_handles = sort_legend(ax)
        ax.legend(sorted_handles, sorted_labels, fontsize=legend_fontsize)
    # Ticks sit at the compressed positions but are labelled with the real class numbers.
    ax.set_xticks(adjusted_class_nums, class_num_runs, rotation=45, fontsize=label_fontsize)
    ax.tick_params(axis='y', labelsize=label_fontsize)

def plot_fid_train_all_one_fig(data, class_num, model_name='CrossProductCNN'):
    """Plot one model's per-epoch FID trend, averaged per middle size, in a new figure.

    Args:
        data: list of ``{"middle_size": int, "data": [{"epoch": e, "data": {model: fid, ...}}, ...]}``
            records, as assembled by the fid_train aggregation cell.
        class_num: class count, shown in the title only.
        model_name: key selecting the model inside each epoch's ``"data"`` dict.

    Returns:
        The matplotlib Figure. Only epochs 0..9 are plotted. Line colours come from the
        module-level ``color_map``, keyed ``f"middle_size_{size}"``.
    """
    epochs = range(10)
    # middle size -> list of per-epoch value lists (one list per record with that size).
    runs_by_size = {}
    for item in data:
        per_epoch = [epoch_data['data'][model_name] for epoch_data in item['data']]
        runs_by_size.setdefault(item['middle_size'], []).append(per_epoch)

    fig, ax = plt.subplots(figsize=(10, 6))
    for size in sorted(runs_by_size):
        runs = runs_by_size[size]
        avg_data = [sum(run[j] for run in runs) / len(runs) for j in epochs]
        ax.plot(epochs, avg_data, label=f'Middle Size {size}', color=color_map[f'middle_size_{size}'])

    ax.set_title(f'FID Trend for {model_name}, Class Num: {class_num}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('FID Score')
    ax.set_ylim([0, 100])
    ax.set_xticks(epochs)
    ax.grid(True)
    fig.tight_layout()

    return fig

# Draw each trend in a temporary figure, then copy its lines into the bottom row of the combined figure.
fig_trend_100 = plot_fid_train_all_one_fig(fid_trend_100_data, 100)
fig_trend_500 = plot_fid_train_all_one_fig(fid_trend_500_data, 500)

for src_fig, ax in zip([fig_trend_100, fig_trend_500], [ax3_sorted_legend, ax4_sorted_legend]):
    src_ax = src_fig.axes[0]
    for line in src_ax.get_lines():
        ax.plot(line.get_xdata(), line.get_ydata(), label=line.get_label(), color=line.get_color())
    ax.set_title("\n".join(textwrap.wrap(src_ax.get_title(), width=30)), fontsize=label_fontsize)
    ax.set_xlabel(src_ax.get_xlabel(), fontsize=label_fontsize)
    ax.set_ylabel(src_ax.get_ylabel(), fontsize=label_fontsize)
    ax.set_ylim(0, 100)
    # Epochs are indexed 0..9 in the data; label them 1..10.
    ax.set_xticks(range(10))
    ax.set_xticklabels([str(x + 1) for x in range(10)], fontsize=label_fontsize)
    ax.tick_params(axis='y', labelsize=label_fontsize)
    ax.grid(True)
    # The temporary figure is no longer needed; closing it keeps plt.show() from rendering it on its own.
    plt.close(src_fig)

# The combined figure uses constrained_layout, so no tight_layout call is needed.
plt.show()

In [None]:
# Loss / accuracy curves for the middle_size=16, class_num=100 runs.
class_num = 100
num_times = 100
middle_size = 16
out_dir = "/data0/user/gfhao/code_cpm/ServerCode/output/"
fig1 = plot_fig1_loss_train_test(out_dir, middle_size, class_num, num_times)

In [48]:
def _fid_sign_bucket(fid_all):
    """Return the sign-pattern bucket (0-3) of one run's four FID scores, or None if it fits none.

    Two differences are computed, each rounded to 2 decimals: ``fid_all[0] - fid_all[1]``
    and ``fid_all[2] - fid_all[3]``. The index order appears to be (F1 dot, F1 number,
    F2 dot, F2 number), per the tags used when plotting these counts — confirm against compute_FID.
    Buckets: 0 = (+, +), 1 = (+, -), 2 = (-, +), 3 = (-, -). A difference that rounds to
    exactly 0 (or is NaN) yields None, so that run is not counted.
    """
    diff1 = round(fid_all[0] - fid_all[1], 2)
    diff2 = round(fid_all[2] - fid_all[3], 2)
    if diff1 > 0 and diff2 > 0:
        return 0
    if diff1 > 0 and diff2 < 0:
        return 1
    if diff1 < 0 and diff2 > 0:
        return 2
    if diff1 < 0 and diff2 < 0:
        return 3
    return None


def contrast_crossproduct_symmetry_and_asymmetric(out_dir, middle_size, class_num, epoch, num_times=100):
    """Count FID sign-pattern buckets for MODEL_NAME[2] vs MODEL_NAME[5] over many runs and pickle them.

    For every run ``1..num_times``, ``compute_FID`` is evaluated at ``epoch`` for both
    models on the test split. Each run's four scores are then assigned a bucket by
    ``_fid_sign_bucket``.

    Returns:
        ``{MODEL_NAME[2]: [c0, c1, c2, c3], MODEL_NAME[5]: [c0, c1, c2, c3]}``. The same dict
        is written to ``./result/fid_crossproduct_contrast/...epoch{epoch}.pkl``. Tied runs
        are not counted, so the four counts may sum to less than ``num_times``.
    """
    model_config = [8, 16, middle_size, middle_size * middle_size // 4]
    compared_models = (MODEL_NAME[2], MODEL_NAME[5])
    bucket_counts = {name: [0, 0, 0, 0] for name in compared_models}
    _, test_data = load_processed_dataset(data_dir="/data0/user/gfhao/datasets/mnist", batch_size=64, display=False, process=True, class_num=class_num)
    for i in tqdm(range(1, num_times + 1), desc="Progress"):
        for name in compared_models:
            fid_all, _ = compute_FID(test_data, i, epoch, model_name=name, class_num=class_num, lambda_sparse=0.01, model_config=model_config, out_dir=out_dir)
            bucket = _fid_sign_bucket(fid_all)
            if bucket is not None:
                bucket_counts[name][bucket] += 1

    data_to_save = {name: bucket_counts[name] for name in compared_models}
    # Format filename with parameters
    file_name = f"./result/fid_crossproduct_contrast/fid_crossproduct_contrast_middle_size_{middle_size}_class_num_{class_num}_times{num_times}_epoch{epoch}.pkl"
    ensure_dir_exists("./result/fid_crossproduct_contrast/")

    with open(file_name, 'wb') as file:
        pickle.dump(data_to_save, file)

    return data_to_save

In [None]:
# Sign-pattern counts for the symmetric vs asymmetric cross-product models, every (middle_size, epoch).
out_dir = "/data0/user/gfhao/code_cpm/ServerCode/output/"
class_num = 100
for middle_size in [16, 32, 64, 128, 256]:
    for epoch in range(10):
        print("middle_size: {}, epoch: {}".format(middle_size, epoch))
        percent = contrast_crossproduct_symmetry_and_asymmetric(
            out_dir=out_dir, middle_size=middle_size, class_num=class_num, epoch=epoch, num_times=100)
        print(percent)


In [54]:
def evaluate_CrossProduct_history_contrast_loss_train_test(out_dir, middle_size, class_num, num_times=100):
    """Load the saved loss/accuracy histories of MODEL_NAME[2] and MODEL_NAME[5] and summarise them.

    For each model, the history pickles of runs ``1..num_times`` are loaded and reduced with
    ``compute_stats`` (mean over runs, mean - std, mean + std).

    Returns:
        ``{model_name: [histories, mean, mean - std, mean + std]}`` for MODEL_NAME[2] then
        MODEL_NAME[5]; ``histories`` is the list of loaded per-run objects.
    """
    data_to_save = {}
    for model_name in (MODEL_NAME[2], MODEL_NAME[5]):
        histories = [
            load_data_from_file(generate_file_path(out_dir, middle_size, class_num, model_name, i))
            for i in range(1, num_times + 1)
        ]
        # compute_stats returns (mean, mean - std, mean + std) despite its var_up/var_lower naming.
        avg, lower, upper = compute_stats(histories)
        data_to_save[model_name] = [histories, avg, lower, upper]

    return data_to_save

def plot_CrossProduct_variance_comparison_error_bars(data, metric_index, metric_name, ax):
    """Draw per-epoch mean ± spread error bars of one metric for each model in ``data``.

    Args:
        data: mapping ``model name -> [runs, mean, lower, upper]`` as returned by
            ``evaluate_CrossProduct_history_contrast_loss_train_test``. Models are drawn in
            insertion order, the first in green and the second in blue (at most two).
        metric_index: row of the statistics to draw; the caller uses 0 = loss,
            1 = training accuracy, 2 = validation accuracy.
        metric_name: human-readable metric name for the axis label and title.
        ax: matplotlib Axes to draw on.
    """
    # Imported locally so the function does not depend on a star-import providing ``wrap``.
    from textwrap import wrap

    epochs = np.arange(1, EPOCH + 1)
    colors = ['green', 'blue']  # one colour per model, in ``data`` order
    for cnn_type, color in zip(data, colors):
        means = data[cnn_type][1][metric_index]
        lower_bounds = data[cnn_type][2][metric_index]
        upper_bounds = data[cnn_type][3][metric_index]
        # yerr = [distance below, distance above]; for compute_stats output (mean ∓ std) both equal std.
        error = [means - lower_bounds, upper_bounds - means]
        ax.errorbar(epochs, means, yerr=error, fmt='-o', color=color, ecolor=color, elinewidth=2, capsize=4, label=f'{cnn_type}')
    ax.set_xlabel('Epochs')
    ax.set_ylabel(f'{metric_name} with Variance')
    title = f'Comparison of {metric_name} with Variance Across CNN Types'
    ax.set_title("\n".join(wrap(title, 40)))  # Wrap title based on the plot width
    ax.legend(loc='center right')
    ax.grid(True)


def plot_CrossProduct_fig1_loss_train_test(out_dir, middle_size, class_num, num_times=100):
    """Three side-by-side error-bar panels (loss, training accuracy, validation accuracy).

    Each panel compares the two cross-product models over ``num_times`` runs for one
    (middle_size, class_num) configuration. The figure is returned, not shown.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    metrics = ['Loss', 'Training Accuracy', 'Validation Accuracy']
    data_instance = evaluate_CrossProduct_history_contrast_loss_train_test(
        out_dir=out_dir, middle_size=middle_size, class_num=class_num, num_times=num_times)
    for metric_index, (metric_name, panel) in enumerate(zip(metrics, axes)):
        plot_CrossProduct_variance_comparison_error_bars(data_instance, metric_index, metric_name, panel)
    plt.tight_layout()
    return fig

In [None]:
# Training-curve comparison of the two cross-product models, one figure per middle size.
out_dir = "/data0/user/gfhao/code_cpm/ServerCode/output/"
class_num = 100
num_times = 100
fig1 = plot_CrossProduct_fig1_loss_train_test(out_dir, middle_size=16, class_num=class_num, num_times=num_times)
fig2 = plot_CrossProduct_fig1_loss_train_test(out_dir, middle_size=32, class_num=class_num, num_times=num_times)
fig3 = plot_CrossProduct_fig1_loss_train_test(out_dir, middle_size=64, class_num=class_num, num_times=num_times)
fig4 = plot_CrossProduct_fig1_loss_train_test(out_dir, middle_size=128, class_num=class_num, num_times=num_times)
fig5 = plot_CrossProduct_fig1_loss_train_test(out_dir, middle_size=256, class_num=class_num, num_times=num_times)

In [None]:
def load_data_from_file(path):
    """Return the object unpickled from the file at ``path``.

    Same helper as in the first analysis cell. pickle can run arbitrary code, so only load
    files this project wrote.
    """
    with open(path, mode="rb") as stream:
        obj = pickle.load(stream)
    return obj

# Gather the per-epoch sign-pattern pickles written by
# contrast_crossproduct_symmetry_and_asymmetric into one nested list. There is one record per
# middle size, each holding ten {"epoch", "data"} entries. The list is saved and printed.
class_num = 100
num_times = 100
all_fid_trend = []
for middle_size in [16, 32, 64, 128, 256]:
    train_fid_epoch_trend = {"middle_size": middle_size, "data": []}
    for epoch in range(10):
        train_fid_epoch = {"epoch": epoch}
        file_name = (f"./result/fid_crossproduct_contrast/fid_crossproduct_contrast_middle_size_{middle_size}"
                     f"_class_num_{class_num}_times{num_times}_epoch{epoch}.pkl")
        ensure_dir_exists("./result/fid_crossproduct_contrast/")
        train_fid_epoch["data"] = load_data_from_file(file_name)
        train_fid_epoch_trend["data"].append(train_fid_epoch)
    all_fid_trend.append(train_fid_epoch_trend)

file_name = (f"./result/fid_crossproduct_contrast_all/all_fid_crossproduct_contrast_trend_class_num_{class_num}"
             f"_times{num_times}_all_epoch.pkl")
ensure_dir_exists("./result/fid_crossproduct_contrast_all/")
with open(file_name, 'wb') as file:
    pickle.dump(all_fid_trend, file)
print(all_fid_trend)

In [112]:
import pandas as pd

def plot_crosspoduct_fid_train_all(class_num=100, num_times=100):
    """Per middle size, plot the share of runs whose F1 and F2 FID orderings disagree, per epoch.

    The counts come from buckets 1 and 2 of ``contrast_crossproduct_symmetry_and_asymmetric``
    (F1 dot-vs-number and F2 dot-vs-number point in opposite directions). They are read from
    the aggregated pickle and expressed as a percentage of ``num_times`` runs. There is one
    subplot per middle size and one line per model (green for the first, blue for the second).

    Args:
        class_num: class count of the sweep; selects the pickle file.
        num_times: runs per configuration; selects the pickle file and is the percentage
            denominator (with the default 100, counts equal percentages).

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if the pickle contains no records.
    """
    file_name = f"./result/fid_crossproduct_contrast_all/all_fid_crossproduct_contrast_trend_class_num_{class_num}_times{num_times}_all_epoch.pkl"
    # pd.read_pickle unpickles arbitrary objects: only point it at files this project wrote.
    data = pd.read_pickle(file_name)
    if not data:
        raise ValueError(f"No FID trend records found in {file_name}")
    model_types = list(data[0]['data'][0]['data'].keys())
    epochs = range(10)
    # middle size -> model -> list of per-epoch bucket-count lists (one per record).
    plot_data_by_size = {}
    for item in data:
        per_model = plot_data_by_size.setdefault(item['middle_size'], {model: [] for model in model_types})
        for model in model_types:
            per_model[model].append([epoch_data['data'][model] for epoch_data in item['data']])
    sorted_middle_sizes = sorted(plot_data_by_size)
    color_map = dict(zip(model_types, ['green', 'blue']))

    # squeeze=False keeps a 2-D array of axes even when there is only one middle size.
    fig, axs = plt.subplots(len(sorted_middle_sizes), 1, figsize=(8, 6 * len(sorted_middle_sizes)), squeeze=False)
    for ax, size in zip(axs[:, 0], sorted_middle_sizes):
        for model in model_types:
            for run_idx, config_data in enumerate(plot_data_by_size[size][model]):
                disagree_pct = [100 * (counts[1] + counts[2]) / num_times for counts in config_data]
                # Label only the first line of each model so the legend has one entry per model.
                ax.plot(epochs, disagree_pct, label=model if run_idx == 0 else "_nolegend_", color=color_map[model])
        ax.set_title(f'Middle Size {size}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('FID percentage')
        ax.set_ylim([0, 100])
        ax.set_xticks(epochs)
        ax.set_xticklabels([r + 1 for r in range(10)])  # epochs shown 1-based
        ax.legend()  # one legend per subplot
    plt.tight_layout()
    return fig


In [None]:
fig = plot_crosspoduct_fid_train_all(class_num=100)  # 100-class sweep (the default)

In [106]:
import pandas as pd
from matplotlib.patches import Patch

def plot_crosspoduct_isfixed_fid_train_all(class_num=100, num_times=100, show_legend=False):
    """Stacked bars of the four FID sign-pattern buckets per epoch, one subplot per middle size.

    The data are read from the aggregated pickle of
    ``contrast_crossproduct_symmetry_and_asymmetric`` results. For every epoch each model
    gets one bar: the green palette for the first model, blue for the second, side by side.
    Each bar stacks the four bucket counts (meanings in ``tags``) as a percentage of
    ``num_times`` runs.

    Args:
        class_num: class count of the sweep; selects the pickle file.
        num_times: runs per configuration; selects the pickle file and is the percentage
            denominator (with the default 100, counts equal percentages).
        show_legend: if True, draw a legend of the bucket colours beside the first subplot.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if the pickle contains no records.
    """
    file_name = f"./result/fid_crossproduct_contrast_all/all_fid_crossproduct_contrast_trend_class_num_{class_num}_times{num_times}_all_epoch.pkl"
    # pd.read_pickle unpickles arbitrary objects: only point it at files this project wrote.
    data = pd.read_pickle(file_name)
    if not data:
        raise ValueError(f"No FID trend records found in {file_name}")
    model_types = list(data[0]['data'][0]['data'].keys())
    # middle size -> model -> list of per-epoch bucket-count lists (one per record).
    plot_data_by_size = {}
    for item in data:
        per_model = plot_data_by_size.setdefault(item['middle_size'], {model: [] for model in model_types})
        for model in model_types:
            per_model[model].append([epoch_data['data'][model] for epoch_data in item['data']])
    sorted_middle_sizes = sorted(plot_data_by_size)

    tags = ['F1_dot > F1_number && F2_dot > F2_number',
            'F1_dot > F1_number && F2_dot < F2_number',
            'F1_dot < F1_number && F2_dot > F2_number',
            'F1_dot < F1_number && F2_dot < F2_number']
    colors_blue = ['#5789ba', '#18305e', '#3961a0', '#8cb5d3']
    colors_green = ['#9dcdc2', '#335635', '#4e8753', '#c8e2de']
    colors = [colors_green, colors_blue]  # one palette per model, in model_types order

    n_epochs = 10
    bar_width = 0.25
    # squeeze=False keeps a 2-D array of axes even when there is only one middle size.
    fig, axs = plt.subplots(len(sorted_middle_sizes), 1, figsize=(8, 6 * len(sorted_middle_sizes)), squeeze=False)
    for i, (ax, size) in enumerate(zip(axs[:, 0], sorted_middle_sizes)):
        for j, model in enumerate(model_types):
            # Shift each model's bars right by one bar width so the models sit side by side.
            bar_positions = [k + j * bar_width for k in range(n_epochs)]
            for config_data in plot_data_by_size[size][model]:
                # config_data holds one 4-bucket count list per epoch; transpose to one series per bucket.
                bucket_series = [list(col) for col in zip(*config_data)]
                bottom_y = [0] * n_epochs
                for ind, counts in enumerate(bucket_series):
                    y = [100 * c / num_times for c in counts]
                    ax.bar(bar_positions, y, width=bar_width, bottom=bottom_y,
                           label=tags[ind] if j == 0 else '', color=colors[j][ind])
                    bottom_y = [a + b for a, b in zip(y, bottom_y)]
        ax.set_title(f'Middle Size {size}')
        ax.set_xlabel('Epoch')
        ax.set_ylabel('FID percentage')
        ax.set_ylim([0, 100])
        # Centre each tick between the two models' bars; epochs shown 1-based.
        ax.set_xticks([r + 0.5 * bar_width for r in range(n_epochs)])
        ax.set_xticklabels([r + 1 for r in range(n_epochs)])
        if show_legend and i == 0:
            # Legend entries: for each bucket, its green then its blue swatch.
            legend_handles = []
            for green, blue, tag in zip(colors_green, colors_blue, tags):
                legend_handles.append(Patch(color=green, label=tag))
                legend_handles.append(Patch(color=blue, label=tag))
            ax.legend(handles=legend_handles, loc='upper left', bbox_to_anchor=(1, 1))
    plt.tight_layout()
    return fig


In [None]:
# Stacked sign-pattern bars for the two cross-product models, one subplot per middle size.
fig = plot_crosspoduct_isfixed_fid_train_all()