# WL phi feature map distributions

In [None]:
from notebook_prelude import *

In [None]:
# Space-separated names of the datasets that have a concept map (shown as the cell output).
' '.join(name for name in dataset_helper.get_dataset_names_with_concept_map())

In [None]:
# Folder passed to save_fig for the phi-distribution figures produced below.
SAVE_FOLDER = 'tmp/wl_phi_distributions'

In [None]:
# Default matplotlib figure size (width, height) in inches used by print_phi_distribution.
FIGSIZE = (15, 12)

def print_phi_distribution(
    title,
    X_phi,
    Y,
    sort_y=True,
    add_class_label=True,
    add_class_line=True,
    figsize=FIGSIZE,
    flip_axis=False,
    draw_vertice_count_line=True,
    h_lines=(),
    sns_color_palette="bright",
    class_v_line_kwargs=None,
    class_h_line_kwargs=None,
):
    """Scatter-plot the non-zero entries of per-iteration phi feature maps.

    One subplot is drawn per element of ``X_phi`` (titled "Iteration: <h>");
    every non-zero entry ``phi[graph_idx, phi_idx]`` becomes a dot coloured by
    the class of its graph.

    Args:
        title: figure super-title.
        X_phi: sequence of 2-D matrices indexed ``phi[graph_idx, phi_idx]``;
            anything supporting ``.shape``, ``.nonzero()`` and row indexing
            (presumably numpy arrays or scipy sparse matrices — confirm).
        Y: class label of every graph (one per row of each phi).
        sort_y: stably sort graphs by label before plotting.
        add_class_label: write each class name at the first graph of that class.
        add_class_line: draw a line at the first graph of every class.
        figsize: matplotlib figure size.
        flip_axis: False -> x = phi index, y = graph; True -> axes swapped.
        draw_vertice_count_line: draw a red line just below
            ``X_phi[0].shape[1]`` on the phi axis and limit that axis to 103%.
        h_lines: per-iteration sequences of phi-axis positions (e.g. max phi
            index per class). An item may be a ``(position, color)`` tuple;
            otherwise the i-th item uses the i-th palette colour.
        sns_color_palette: seaborn palette name.
        class_v_line_kwargs: style of the class-occurrence lines
            (default: thin, faint, solid black).
        class_h_line_kwargs: style of the ``h_lines`` (default: thin, solid).

    Returns:
        ``(fig, ax)`` where ``ax`` is the last subplot.
    """
    if class_v_line_kwargs is None:
        class_v_line_kwargs = dict(linestyle='solid', color='black', linewidth=1, alpha=0.1)
    if class_h_line_kwargs is None:
        class_h_line_kwargs = dict(linestyle='solid', linewidth=1, alpha=0.6)

    if sort_y:
        # Stable sort keeps the original relative order of graphs within a class.
        order = np.argsort(Y, kind='mergesort')
        Y = np.array(Y)[order]
        X_phi = [phi[order] for phi in X_phi]

    # Colour index per class in order of first appearance in Y, plus the graph
    # index where each class first appears (used for class lines/labels).
    class_to_color_idx = {}
    first_occurrence = {}
    for idx, clazz in enumerate(Y):
        if clazz not in class_to_color_idx:
            class_to_color_idx[clazz] = len(class_to_color_idx)
            first_occurrence[clazz] = idx
    graph_color_idx = [class_to_color_idx[clazz] for clazz in Y]

    # Enough colours for every class and for the longest h_lines entry
    # (the original only looked at h_lines[0] and crashed when it was empty).
    num_colors = max([len(class_to_color_idx)] + [len(lines) for lines in h_lines])
    palette = sns.color_palette(sns_color_palette, num_colors)

    ax_labels = ['graph', 'phi index'] if flip_axis else ['phi index', 'graph']

    # squeeze=False: always get a 2-D array of axes, even for a single phi.
    fig, axes = plt.subplots(ncols=len(X_phi), figsize=figsize, sharey=True, squeeze=False)
    axes = axes[0]
    axes[0].set_ylabel(ax_labels[1])

    num_vertices = X_phi[0].shape[1]

    for h, (phi, ax) in enumerate(zip(X_phi, axes)):
        graph_idx, phi_idx = phi.nonzero()
        # Colour by graph (row) index, independent of the axis orientation.
        dot_colors = [palette[graph_color_idx[g]] for g in graph_idx]
        x, y = (graph_idx, phi_idx) if flip_axis else (phi_idx, graph_idx)

        # Lines across the phi axis (e.g. max phi index of each class).
        phi_axis_line = ax.axhline if flip_axis else ax.axvline
        if h < len(h_lines):
            for class_idx, hline in enumerate(h_lines[h]):
                if isinstance(hline, tuple):
                    hline, line_color = hline
                else:
                    line_color = palette[class_idx]
                phi_axis_line(hline, color=line_color, **class_h_line_kwargs)

        ax.scatter(x=x, y=y, c=dot_colors, s=1)
        ax.set_title('Iteration: {}'.format(h))

    for ax in axes:
        ax.set_xlabel(ax_labels[0])
        ax.grid(False)
        ax.set_xticks([])
        ax.set_yticks([])

        # Class-occurrence lines/labels sit on the graph axis.
        graph_axis_line = ax.axvline if flip_axis else ax.axhline
        for clazz, occurrence_idx in first_occurrence.items():
            if add_class_line:
                graph_axis_line(occurrence_idx, **class_v_line_kwargs)
            if add_class_label:
                text_x, text_y = (occurrence_idx, 1) if flip_axis else (1, occurrence_idx)
                ax.text(x=text_x, y=text_y, s=clazz, color='red')

        if draw_vertice_count_line:
            marker = num_vertices - (num_vertices / 100)
            if flip_axis:
                ax.set_ylim(0, num_vertices * 1.03)
                ax.axhline(marker, color='red', linewidth=1)
            else:
                ax.set_xlim(0, num_vertices * 1.03)
                ax.axvline(marker, color='red', linewidth=1)

    fig.suptitle(title)
    fig.tight_layout()
    fig.subplots_adjust(top=.9)

    return fig, axes[-1]

def filter(cache_file, require_same_label=False, require_concept_graph=False, dataset=None):
    """Return True if the cached phi file ``cache_file`` should be plotted.

    By default only split caches (path contains 'splitted') are selected. The
    optional arguments narrow the selection further; they replace the toggles
    that used to live in commented-out code, and their defaults keep the
    original behaviour.

    Args:
        cache_file: path of a cached phi dataset.
        require_same_label: also require 'same-label' in the path.
        require_concept_graph: also require 'concept-map' in the path.
        dataset: if given, require filename_utils.get_dataset_from_filename
            to return exactly this dataset name.

    NOTE: this shadows the builtin ``filter`` in the notebook namespace; the
    name is kept because the plotting loop calls it.
    """
    if 'splitted' not in cache_file:
        return False
    if require_same_label and 'same-label' not in cache_file:
        return False
    if require_concept_graph and 'concept-map' not in cache_file:
        return False
    # Only resolve the dataset name when asked for it.
    if dataset is not None and filename_utils.get_dataset_from_filename(cache_file) != dataset:
        return False
    return True

def get_hlines(phis, Y):
    """For each phi matrix, return the highest non-zero phi column per class.

    Args:
        phis: sequence of 2-D matrices indexed ``phi[graph_idx, phi_idx]``.
        Y: class label per graph (row); labels must be sortable.

    Returns:
        One list per phi with the per-class maxima ordered by class label, so
        that with a sorted Y the i-th entry lines up with the i-th class
        colour. (Sorting by value, as before, broke that correspondence.)
        Classes without any non-zero entry are omitted.
    """
    highest_per_phi = []
    for phi in phis:
        highest = {}
        for row, col in zip(*phi.nonzero()):
            clazz = Y[row]
            highest[clazz] = max(highest.get(clazz, -1), col)
        highest_per_phi.append([highest[clazz] for clazz in sorted(highest)])
    return highest_per_phi


def plot(X_phi, Y, suffix='', h_lines=(), plot_kwargs=None, *, cache_file):
    """Plot the phi distribution of one cache file and save it via save_fig.

    The figure is saved as '<basename of cache_file><suffix>.png' into
    SAVE_FOLDER and is always closed afterwards.

    Args:
        X_phi: per-iteration phi matrices.
        Y: class labels; must already be sorted because the plot uses sort_y=False.
        suffix: appended to the file name and title (e.g. '_train').
        h_lines: forwarded to print_phi_distribution.
        plot_kwargs: overrides for the default print_phi_distribution options.
        cache_file: path of the cache file the data came from.

    Raises:
        ValueError: if Y is not sorted.
    """
    base_name = cache_file.split('/')[-1]
    filename = '{}{}.png'.format(base_name, suffix)
    if not np.array_equal(np.sort(Y), Y):
        raise ValueError('Labels must be sorted before plotting: {}'.format(filename))
    print(filename)
    kwargs = dict(sort_y=False, add_class_label=False, flip_axis=True, add_class_line=True, h_lines=h_lines)
    kwargs.update(plot_kwargs or {})
    title = 'File: {} {} (#vertices: {})'.format(base_name, suffix, X_phi[0].shape[1])
    fig, _ = print_phi_distribution(title, X_phi, Y, **kwargs)
    try:
        save_fig(fig, filename, folder=SAVE_FOLDER)
    finally:
        # Close even if saving fails, so figures do not pile up in memory.
        plt.close(fig)


def _get_plot_kwargs(cache_file, Y_train):
    """Per-file plot options: no vertex-count line for 'same-label' caches,
    no class lines when there are more than 20 classes."""
    kwargs = dict()
    if 'same-label' in cache_file:
        kwargs['draw_vertice_count_line'] = False
    if len(set(Y_train)) > 20:
        kwargs['add_class_line'] = False
    return kwargs


for cache_file in dataset_helper.get_all_cached_graph_phi_datasets():
    if not filter(cache_file):
        continue
    print(cache_file)
    phi_res = dataset_helper.get_dataset_cached(cache_file, check_validity=False)
    if len(phi_res) == 2:
        phi_train, Y_train = phi_res
        splits = [(phi_train, Y_train, '')]
    elif len(phi_res) == 6:
        phi_train, phi_test, X_train, X_test, Y_train, Y_test = phi_res
        splits = [(phi_train, Y_train, '_train'), (phi_test, Y_test, '_test')]
    else:
        raise ValueError('Unexpected number of cached items ({}) in {}'.format(len(phi_res), cache_file))

    # Class-max lines always come from the training phis (also for the test plot).
    h_lines = get_hlines(phi_train, Y_train)
    plot_kwargs = _get_plot_kwargs(cache_file, Y_train)
    for X_phi, Y, suffix in splits:
        plot(X_phi, Y, suffix, h_lines=h_lines, plot_kwargs=plot_kwargs, cache_file=cache_file)
print("Finished")

## Copy saved figures into categorized folder

In [None]:
# NOTE(review): the plotting cell passes folder=SAVE_FOLDER ('tmp/wl_phi_distributions')
# to save_fig, not this folder — confirm which folder is intended here.
FOLDER_PHI_DIST = 'tmp/phi-distributions'

# One row per figure directly inside FOLDER_PHI_DIST (non-recursive glob).
# Building the frame from an explicit column keeps 'file' present even when
# no figure is found.
df = pd.DataFrame({'file': glob('{}/*.png'.format(FOLDER_PHI_DIST))})
df['dataset'] = df.file.apply(filename_utils.get_dataset_from_filename)
df['same_label'] = df.file.str.contains('same[-_]label')
# expand=False -> Series (not a one-column DataFrame), so results can be chained.
df['type'] = df.file.str.extract(r'dataset_graph_(.+?)_', expand=False)
df['window_size'] = df.file.str.extract(r'dataset_graph_cooccurrence_(.+?)_', expand=False)
df['split'] = df.file.str.contains('splitted', regex=False)
df['test_train'] = df.file.str.extract(r'phi\.npy_(.+?)\.png$', expand=False)
# regex=False: '.splitted' is a literal suffix, not a pattern.
df['original_file'] = (
    df.file.str.extract(r'/([^/]+?\.npy)', expand=False)
    .str.replace('.splitted', '', regex=False)
)

# Copy every figure to <old_folder>/_/<dataset>/<type>/<same_label>/<filename>.
for row in df.itertuples(index=False):
    old_folder, old_filename = row.file.rsplit('/', 1)
    new_folder = '{}/_/{}/{}/{}'.format(old_folder, row.dataset, row.type, row.same_label)
    os.makedirs(new_folder, exist_ok=True)
    shutil.copyfile(src=row.file, dst='{}/{}'.format(new_folder, old_filename))