In [None]:
import json
import numpy as np
from matplotlib import pyplot as plt

def draw(x, ys, legends=None, font_size=12, title='Ours', xlim=(-1, 305), ylim=(3, 8)):
    """
        Draw non-locality plots: one line per series in ``ys`` against ``x``.

        Args:
            x: shared x-axis values (e.g. epoch numbers).
            ys: sequence of y-series; each one is plotted against ``x``.
            legends: optional per-series labels. When given it must have at
                least one entry per series in ``ys`` (extra entries are ignored).
            font_size: font size for axis labels and tick labels; the title
                uses ``font_size + 1``.
            title: plot title.
            xlim, ylim: axis limits as ``(low, high)``. Pass ``None`` to let
                matplotlib autoscale that axis.

        Raises:
            ValueError: if ``legends`` has fewer entries than ``ys``.
    """
    if legends is None:
        labels = [None] * len(ys)
    else:
        labels = list(legends)
        if len(labels) < len(ys):
            raise ValueError(
                f'got {len(labels)} legend labels for {len(ys)} series'
            )

    for y, label in zip(ys, labels):
        plt.plot(x, y, label=label)

    if xlim is not None:
        plt.xlim(*xlim)
    if ylim is not None:
        plt.ylim(*ylim)
    plt.title(title, fontsize=font_size + 1)
    plt.xlabel('Epochs', fontsize=font_size)
    plt.ylabel('Non-locality', fontsize=font_size)

    plt.xticks(fontsize=font_size)
    plt.yticks(fontsize=font_size)

    # Only build a legend when labels exist; otherwise matplotlib warns
    # that no labelled artists were found.
    if legends is not None:
        plt.legend()
    plt.show()

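"""
    Per-epoch non-locality layout (one list of per-head values per layer), e.g.:
    {
        0: [a, b, c, ...]   # len: n_heads
        1: [a, b, c, ...]   # len: n_heads
        ...
        n_layer-1: [a, b, c, ...] # len: n_heads
    }
    In the log file, each epoch is one JSON line. Layer l's list is stored
    under the key 'nonlocality_{l}', and the number of layers is stored
    under 'nonlocality_len'.
"""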
def get_layer_nonlocalities(layer_data):
    """
        Average non-locality over heads, producing one curve per layer.

        layer_data: nested sequence shaped (epochs, n_layer, n_heads).
        Returns an ndarray shaped (n_layer, epochs); row l holds layer l's
        head-averaged non-locality for every epoch.
    """
    per_epoch_layer_means = np.array(layer_data).mean(axis=2)  # (epochs, n_layer)
    return per_epoch_layer_means.T

log_path = '/data/ljc/convit_logs/' + '20240319105616.txt'

# One JSON object per line; blank lines (e.g. a trailing newline) are skipped.
# NOTE(review): lines are assumed to already be in epoch order.
with open(log_path, 'r', encoding='utf-8') as log_file:
    log_dics = [json.loads(line) for line in log_file if line.strip()]
if not log_dics:
    raise ValueError(f'no log records found in {log_path}')

# For each epoch, collect layer l's per-head list from key 'nonlocality_{l}'
# -> nested list shaped (epochs, n_layer, n_heads).
nonlocal_data = [
    [ep_data[f'nonlocality_{l}'] for l in range(ep_data['nonlocality_len'])]
    for ep_data in log_dics
]

nonlocal_data = get_layer_nonlocalities(nonlocal_data)  # shape: (n_layer, epochs)
# Epochs and layers are numbered from 1 on the plot.
x = list(range(1, nonlocal_data.shape[1] + 1))
legends = [f'Layer {i}' for i in range(1, nonlocal_data.shape[0] + 1)]

draw(x, nonlocal_data, legends, font_size=15)