# Imports

In [3]:
import os
import random
import sys
sys.path.append('..')

from matplotlib import pyplot as plt
import numpy as np
import torch
import torchmetrics

from codeclr import DenseGraph
from codeclr.cass import CassConfig
from codeclr.model import ContrastiveLearner

# Configuration / Helper Functions

In [None]:
# CASS configuration; its `tag` selects both the preprocessed data directory
# and the checkpoint directory names used below.
CONFIG = CassConfig(
    annot_mode=2,
    compound_mode=1,
    gfun_mode=1,
    gvar_mode=3,
    fsig_mode=1)
AUGMENTATIONS = ['identity', 'node_drop', 'node_mask', 'subtree_mask']
EPOCH = 0  # checkpoint epoch loaded by get_encoder
DATA_DIR = os.path.join(
    'data',
    'preprocessed',
    'Project_CodeNet_C++1000',
    CONFIG.tag
)
VOCAB_FILE = os.path.join(DATA_DIR, 'vocab.pt')
VOCAB = torch.load(VOCAB_FILE)
NUM_ANALYSIS_DIRS = 10
SEED = 0  # fixes which problem directories are sampled for analysis

# os.listdir returns entries in arbitrary order, so sort before the seeded
# shuffle; together they select the same problems on every run and machine.
# Only directories are kept because each one is listed again later.
directories = sorted(
    d for d in os.listdir(DATA_DIR)
    if d.startswith('p') and os.path.isdir(os.path.join(DATA_DIR, d))
)
# A private Random instance avoids reseeding the global `random` state.
random.Random(SEED).shuffle(directories)
ANALYSIS_DIRS = directories[:NUM_ANALYSIS_DIRS]
print(ANALYSIS_DIRS)

In [None]:
# Memoized encoders keyed by (augment_1, augment_2, mask_frac).
ENCODER_CACHE = {}
def get_encoder(augment_1, augment_2, mask_frac: float = 0.25):
    """Load the trained encoder for one augmentation pair, memoized.

    The checkpoint is read from
    ``logs/<parameter_tag>/checkpoint_{EPOCH}.pt``, where ``parameter_tag``
    encodes the two augmentations, ``mask_frac``, the fixed training
    hyper-parameters (batch size 64, lr 0.001) and ``CONFIG.tag``.

    Args:
        augment_1: name of the first augmentation used during training.
        augment_2: name of the second augmentation used during training.
        mask_frac: masking fraction used during training. It is part of both
            the checkpoint path and the cache key.

    Returns:
        The ``encoder`` sub-module of the loaded ``ContrastiveLearner``,
        with the model switched to eval mode.
    """
    # mask_frac selects a different checkpoint, so it must be in the key too;
    # otherwise a call with another mask_frac would return a stale encoder.
    cache_key = (augment_1, augment_2, mask_frac)
    if cache_key in ENCODER_CACHE:
        return ENCODER_CACHE[cache_key]

    parameter_tag = f'augment_1={augment_1}_augment_2={augment_2}_mask_frac={mask_frac}_batch_size=64_lr=0.001_{CONFIG.tag}'
    checkpoint_file = os.path.join(
        'logs',
        parameter_tag,
        f'checkpoint_{EPOCH}.pt'
    )
    gcn_layers = [128, 128, 64, 32]
    # NOTE(review): the +1 presumably reserves an extra token id
    # (padding/unknown); confirm against the training code.
    model = ContrastiveLearner(gcn_layers, len(VOCAB) + 1)
    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
    # load_state_dict copies into the (CPU) model parameters either way.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    # The encoder is only used for inference; eval mode disables any
    # training-only behaviour (e.g. dropout) inside the model.
    model.eval()
    encoder = model.encoder
    ENCODER_CACHE[cache_key] = encoder
    return encoder

In [None]:
# Embeddings keyed by (augment_1, augment_2, problem_name).
ANALYSIS_GRAPH_EMBEDDINGS = {}

# A problem's graphs do not depend on the augmentation pair, so load each
# problem directory once instead of once per pair (16x fewer disk reads).
# Sorting makes the row order of the embeddings deterministic.
# NOTE(review): this assumes the encoder does not mutate its input graphs;
# confirm if the encoder changes.
analysis_graphs = {}
for problem_name in ANALYSIS_DIRS:
    problem_dir = os.path.join(DATA_DIR, problem_name)
    analysis_graphs[problem_name] = [
        DenseGraph(**torch.load(os.path.join(problem_dir, graph_file)))
        for graph_file in sorted(os.listdir(problem_dir))
    ]

# Inference only: without no_grad every stored embedding would keep its whole
# autograd graph (all intermediate activations) alive.
with torch.no_grad():
    for augment_1 in AUGMENTATIONS:
        for augment_2 in AUGMENTATIONS:
            encoder = get_encoder(augment_1, augment_2)
            for problem_name, graphs in analysis_graphs.items():
                ANALYSIS_GRAPH_EMBEDDINGS[(augment_1, augment_2, problem_name)] = encoder(graphs)

# $L_1$ Distance Metrics

In [None]:
def l1_distance(anchor_embeddings, auxiliary_embeddings, num_bins: int = 100):
    """Compare two cosine-distance distributions with a histogram L1 metric.

    The "distance" between two embeddings is their negated cosine similarity,
    which lies in [-1, 1] (-1 means same direction). Two empirical
    distributions are built from it:

    * anchor-anchor: every unordered pair of *distinct* anchor embeddings;
    * anchor-auxiliary: every (anchor, auxiliary) pair.

    Each distribution is histogrammed into ``num_bins`` equal-width bins over
    [-1, 1] and normalized to a probability mass. The result is the
    total-variation distance ``0.5 * sum(|p - q|)`` between the two masses,
    which lies in [0, 1].

    Args:
        anchor_embeddings: 2-D tensor, one embedding per row.
        auxiliary_embeddings: 2-D tensor with the same number of columns.
        num_bins: number of histogram bins.

    Returns:
        The distance as a float. Returns NaN when either distribution is
        empty, e.g. when there are fewer than two anchor embeddings.
    """
    anchor = torch.nn.functional.normalize(anchor_embeddings.detach(), dim=1)
    auxiliary = torch.nn.functional.normalize(auxiliary_embeddings.detach(), dim=1)

    # Keep only the strict upper triangle. The diagonal compares each
    # embedding with itself and must not enter the distribution; the old
    # torchmetrics call zeroed it but still counted those values.
    num_anchor = anchor.shape[0]
    rows, cols = torch.triu_indices(num_anchor, num_anchor, offset=1)
    anchor_distances = -(anchor @ anchor.T)[rows, cols]
    anchor_auxiliary_distances = -(anchor @ auxiliary.T).flatten()

    # num_bins + 1 edges give exactly num_bins bins.
    bins = np.linspace(-1, 1, num_bins + 1)
    # Rounding can push similarities marginally outside [-1, 1], and
    # np.histogram silently drops out-of-range values, so clamp first.
    anchor_counts, _ = np.histogram(
        anchor_distances.clamp(-1, 1).cpu().numpy(), bins=bins)
    anchor_auxiliary_counts, _ = np.histogram(
        anchor_auxiliary_distances.clamp(-1, 1).cpu().numpy(), bins=bins)

    anchor_total = anchor_counts.sum()
    anchor_auxiliary_total = anchor_auxiliary_counts.sum()
    if anchor_total == 0 or anchor_auxiliary_total == 0:
        return float('nan')

    anchor_mass = anchor_counts / anchor_total
    anchor_auxiliary_mass = anchor_auxiliary_counts / anchor_auxiliary_total
    return float(0.5 * np.abs(anchor_mass - anchor_auxiliary_mass).sum())

In [None]:
def l1_distance_stats():
    """Summarize the pairwise L1 distances for every augmentation pair.

    For each ``(augment_1, augment_2)`` combination from AUGMENTATIONS, every
    ordered pair of distinct problems in ANALYSIS_DIRS is scored with
    ``l1_distance(anchor, auxiliary)``, using the cached embeddings in
    ANALYSIS_GRAPH_EMBEDDINGS. Both orders are scored because l1_distance
    treats its anchor and auxiliary arguments differently.

    Returns:
        ``(average_l1_distances, max_l1_distances)``: two dicts keyed by
        ``(augment_1, augment_2)`` that hold the mean and the maximum score.
    """
    averages = {}
    maxima = {}
    augmentation_pairs = [(first, second) for first in AUGMENTATIONS for second in AUGMENTATIONS]
    for pair in augmentation_pairs:
        scores = [
            l1_distance(
                ANALYSIS_GRAPH_EMBEDDINGS[pair + (anchor_name,)],
                ANALYSIS_GRAPH_EMBEDDINGS[pair + (auxiliary_name,)],
            )
            for anchor_name in ANALYSIS_DIRS
            for auxiliary_name in ANALYSIS_DIRS
            if anchor_name != auxiliary_name
        ]
        averages[pair] = np.mean(scores)
        maxima[pair] = np.max(scores)
    return averages, maxima

In [None]:
# Per-augmentation-pair mean and maximum of the pairwise L1 distances.
(average_l1_distances, max_l1_distances) = l1_distance_stats()

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(5, 8))
ax1, ax2 = axes

# Heat-map matrices: row i is augment_1 = AUGMENTATIONS[i] and column j is
# augment_2 = AUGMENTATIONS[j].
average_data_array = np.array([[average_l1_distances[(augment_1, augment_2)] for augment_2 in AUGMENTATIONS] for augment_1 in AUGMENTATIONS])
max_data_array = np.array([[max_l1_distances[(augment_1, augment_2)] for augment_2 in AUGMENTATIONS] for augment_1 in AUGMENTATIONS])

tick_positions = np.arange(len(AUGMENTATIONS))
# The two panels differ only in data, colormap and title, so draw them in
# one loop instead of duplicating the plotting code.
panels = (
    (ax1, average_data_array, 'Blues', r'Average $L_1$ Distance'),
    (ax2, max_data_array, 'Purples', r'Max $L_1$ Distance'),
)
for ax, data_array, cmap, title in panels:
    im = ax.imshow(data_array, cmap=cmap)
    ax.set_xticks(tick_positions, labels=AUGMENTATIONS)
    ax.set_yticks(tick_positions, labels=AUGMENTATIONS)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right', rotation_mode='anchor')
    # Label the axes so readers know which augmentation indexes rows/columns.
    ax.set_xlabel('augment_2')
    ax.set_ylabel('augment_1')
    fig.colorbar(im, ax=ax)
    # imshow draws column j at x=j and row i at y=i; annotate each cell.
    for (i, j), value in np.ndenumerate(data_array):
        ax.text(j, i, f'{value:.2f}', ha='center', va='center')
    ax.set_title(title)

fig.tight_layout()
fig.savefig('l1_distance.svg', format='svg', dpi=1200)
plt.show()