In [1]:
import os
import json
import matplotlib.pyplot as plt
import numpy as np

def read_json_files(directory):
    """Load every ``*.json`` file located directly inside *directory*.

    The scan is not recursive.

    Args:
        directory: Path of the folder to scan.

    Returns:
        dict mapping each JSON file's base name to its parsed content. Keys
        are inserted in sorted filename order, so anything built from the
        result (e.g. subplots) comes out in a stable, reproducible order.

    Raises:
        FileNotFoundError: if *directory* does not exist.
        json.JSONDecodeError: if a ``.json`` file does not contain valid JSON.
    """
    data = {}
    # os.listdir order is arbitrary and filesystem-dependent; sort it so the
    # returned dict (and any figure drawn from it) is deterministic.
    for filename in sorted(os.listdir(directory)):
        path = os.path.join(directory, filename)
        # Skip non-JSON entries, and any directory whose name ends in ".json".
        if not filename.endswith('.json') or not os.path.isfile(path):
            continue
        with open(path, 'r', encoding='utf-8') as f:
            data[filename] = json.load(f)
    return data

def _extract_segment_stats(file_data):
    """Collect per-segment means/stds for both methods, ordered by segment count.

    Only keys of the form ``n_segments_<N>_counters`` are read; all other keys
    are ignored. Each matching entry is indexed as
    ``entry['our_method']['mean'|'std']`` and ``entry['pq']['mean'|'std']``.

    Returns:
        Tuple ``(segments, our_means, our_stds, pq_means, pq_stds)`` of
        aligned lists, sorted by ascending segment count. All five lists are
        empty when no matching key is present.
    """
    rows = []
    for key, entry in file_data.items():
        if key.startswith('n_segments_') and key.endswith('_counters'):
            # 'n_segments_8_counters'.split('_') -> ['n', 'segments', '8', 'counters']
            segment = int(key.split('_')[2])
            rows.append((
                segment,
                entry['our_method']['mean'],
                entry['our_method']['std'],
                entry['pq']['mean'],
                entry['pq']['std'],
            ))
    # The file's key order need not be numeric (e.g. "16" may precede "2"),
    # so sort to make the x-axis read left-to-right by segment count.
    rows.sort(key=lambda row: row[0])
    if not rows:
        return [], [], [], [], []
    return tuple(list(column) for column in zip(*rows))


def plot_data(data, output_path='counter_comparison.png'):
    """Draw one grouped bar chart per results file and save them as one image.

    Each subplot compares the ``our_method`` and ``pq`` counter means for
    every segment count. The std is shown as error bars, and the rounded
    mean is written above each bar.

    Args:
        data: Mapping of filename -> parsed JSON results, as returned by
            ``read_json_files``.
        output_path: File the figure is written to. Defaults to
            ``'counter_comparison.png'`` in the current working directory.

    Raises:
        ValueError: if *data* is empty, since there is nothing to plot.
    """
    if not data:
        raise ValueError("plot_data: no data to plot (was the results directory empty?)")

    num_files = len(data)
    # squeeze=False keeps `axes` 2-D even when there is a single file.
    fig, axes = plt.subplots(num_files, 1, figsize=(12, 6 * num_files), squeeze=False)
    width = 0.35

    for ax, (filename, file_data) in zip(axes[:, 0], data.items()):
        segments, our_means, our_stds, pq_means, pq_stds = _extract_segment_stats(file_data)
        x = np.arange(len(segments))

        ax.bar(x - width / 2, our_means, width, label='Our Method', yerr=our_stds, capsize=5)
        ax.bar(x + width / 2, pq_means, width, label='PQ', yerr=pq_stds, capsize=5)

        ax.set_xlabel('Number of Segments')
        ax.set_ylabel('Counter Mean')
        ax.set_title(f'Comparison of Counters - {filename}')
        ax.set_xticks(x)
        ax.set_xticklabels(segments)
        ax.legend()

        # Value labels sit at mean + std, i.e. on top of each error bar, so
        # they do not overlap it.
        for offset, means, stds in ((-width / 2, our_means, our_stds),
                                    (width / 2, pq_means, pq_stds)):
            for i, (mean, std) in enumerate(zip(means, stds)):
                ax.text(i + offset, mean + std, f'{mean:.0f}', ha='center', va='bottom')

    # tight_layout must run after titles, labels and tick labels exist.
    # Calling it before plotting only lays out empty axes.
    fig.tight_layout(pad=5.0)
    fig.savefig(output_path)
    # Close this specific figure (not merely "the current" one) to free memory.
    plt.close(fig)

# Usage: plot every sweep-result JSON file in `directory` into
# counter_comparison.png.
# The guard stops the filesystem read and figure write from running when this
# file is imported as a module. Inside a notebook cell, __name__ is
# '__main__', so the cell still runs and leaves `directory` and `data`
# defined for later cells.
if __name__ == '__main__':
    # NOTE(review): machine-specific absolute path; adjust for your environment.
    directory = '/home/jxu680/image-concept-compression/sweep_results_train2017_fixed_maskclip_mobilesam_20241008_115054'
    data = read_json_files(directory)
    plot_data(data)
    print("Graph saved as 'counter_comparison.png'")

Graph saved as 'counter_comparison.png'
