In [6]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import pickle

In [7]:
# Load the first N_FRAMES frames of the ARPES variable plus the saved
# acquisition history. The dataset is opened in a context manager so its
# file handle is released once the slice has been read.
# NOTE(review): axis meanings of 'ARPES' are not visible here; the code below
# assumes axis 0 indexes frames (one per measurement step) — confirm.
N_FRAMES = 500
with xr.open_dataset('data.nc') as arpes_dataset:
    # `.values` materialises the slice as an in-memory numpy array, so it stays
    # valid after the dataset is closed.
    data = arpes_dataset['ARPES'][:N_FRAMES].values
# Total intensity per frame: sum over axes 1 and 2 (same as the original
# chained `.sum(axis=1).sum(axis=1)`).
counts = data.sum(axis=(1, 2))

# SECURITY: pickle.load can execute arbitrary code. Only load these files if
# they come from a trusted source.
with open('positions_measured.pkl', 'rb') as f:
    positions_measured = pickle.load(f)
with open('cluster_label_history.pkl', 'rb') as f:
    cluster_label_history = pickle.load(f)
    

In [9]:
from PIL import Image

def fig_to_pil(fig):
    """Render a Matplotlib figure and return it as an RGB PIL image.

    Uses the canvas's RGBA buffer instead of ``tostring_rgb``, which was
    deprecated in Matplotlib 3.8 and removed in 3.10.

    The pixel size is taken from the buffer itself rather than from
    ``get_width_height()``. That method reports logical size, which differs
    from the rendered buffer when the canvas has a device-pixel-ratio != 1.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure whose canvas provides ``buffer_rgba`` (Agg-based canvases do).

    Returns
    -------
    PIL.Image.Image
        RGB image of the rendered figure. It is an independent copy, so it
        remains valid after ``fig`` is closed.
    """
    fig.canvas.draw()
    # (height, width, 4) uint8 view onto the renderer's buffer.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # convert() copies the pixels, so the returned image does not alias the
    # canvas buffer, which is freed when the figure is closed.
    return Image.fromarray(rgba).convert('RGB')

def order_cluster_labels_by_counts(positions_measured, labels, counts):
    """Relabel clusters so that label order follows each cluster's total counts.

    The distinct label values are sorted ascending. The cluster with the
    smallest summed ``counts`` receives the smallest label value, the next
    smallest total receives the next label value, and so on. The set of label
    values is unchanged; only their assignment to clusters is permuted.

    Parameters
    ----------
    positions_measured : array-like
        Measured positions. Only its length (``shape[0]``) is used, to cap how
        many ``counts`` entries are considered.
    labels : array-like
        One cluster label per point (1-D).
    counts : array-like
        Per-point weights. Truncated to ``len(labels)`` and to
        ``len(positions_measured)``, so it may be longer than ``labels``.

    Returns
    -------
    list
        Relabelled values, one per entry of ``labels``, in the same order.

    Raises
    ------
    ValueError
        If the truncated ``counts`` cannot be broadcast to the shape of
        ``labels``.
    """
    labels = np.asarray(labels)
    num_pts = np.asarray(positions_measured).shape[0]
    point_counts = np.asarray(counts)[:len(labels)][:num_pts]
    # Keep the original broadcasting semantics (a length-1 `counts` applies to
    # every point); raises ValueError on incompatible shapes.
    point_counts = np.broadcast_to(point_counts, labels.shape)

    # Sorted unique labels. Unlike iterating a `set`, this gives a well-defined
    # order, so the mapping below is monotone in the totals, including for
    # negative labels such as -1.
    unique_labels = np.unique(labels)
    totals = np.array([point_counts[labels == lab].sum() for lab in unique_labels])

    # Labels arranged by ascending total. A stable sort keeps ties in
    # ascending-label order, so ties are resolved deterministically.
    labels_by_total = unique_labels[np.argsort(totals, kind='stable')]
    # k-th smallest total -> k-th smallest label value.
    remap = dict(zip(labels_by_total, unique_labels))
    return [remap.get(item, item) for item in labels]

from scipy.interpolate import griddata
# Regular GRID_SIZE x GRID_SIZE grid spanning the bounding box of the first two
# coordinates of the measured positions; used to rasterise the label map.
GRID_SIZE = 200
pos_lo = np.min(positions_measured, axis=0)
pos_hi = np.max(positions_measured, axis=0)
minx, miny = pos_lo[0], pos_lo[1]
maxx, maxy = pos_hi[0], pos_hi[1]
xgrid, ygrid = np.meshgrid(
    np.linspace(minx, maxx, GRID_SIZE),
    np.linspace(miny, maxy, GRID_SIZE),
)
# Sanity check: the grid extents.
print(xgrid.min(), xgrid.max(), ygrid.min(), ygrid.max())


# Render one frame per measurement step i (1..num).
# - Left panel: nearest-neighbour rasterisation of the count-ordered cluster
#   labels for the first i positions, with those positions overlaid.
# - Right panel: the i-th ARPES frame.
# The frames are written to a GIF, and the last frame is also saved as a PNG.
# NOTE(review): imshow autoscales the colour range per frame, so a label's
# colour may shift between frames as new labels appear. Consider fixing
# vmin/vmax if that matters.
num = len(positions_measured)
# Step i reads cluster_label_history[i-1] and data[i-1]; fail fast instead of
# after rendering hundreds of frames.
assert len(cluster_label_history) >= num, 'cluster_label_history is shorter than positions_measured'
assert len(data) >= num, 'fewer ARPES frames loaded than measured positions'

positions_arr = np.asarray(positions_measured)
imgs = []
for i in range(1, num + 1):
    seen = positions_arr[:i]  # positions measured up to and including step i
    labels_ordered = order_cluster_labels_by_counts(positions_measured, cluster_label_history[i-1], counts)
    label_nearest_matrix = griddata(seen, labels_ordered, (xgrid, ygrid), method='nearest')

    fig, (ax_map, ax_frame) = plt.subplots(1, 2, figsize=(10, 5))
    ax_map.imshow(label_nearest_matrix, extent=[minx, maxx, miny, maxy], origin='lower', cmap='jet')
    ax_map.scatter(*seen.T, marker='.', c='k', s=1)
    ax_frame.imshow(data[i-1], origin='lower', cmap='gray_r')
    imgs.append(fig_to_pil(fig))
    plt.close(fig)  # close this figure explicitly to avoid accumulating open figures
    # `.2%` already appends the percent sign.
    print(f'{i/num:.2%}', end='\r')

if imgs:
    # `quality` is a JPEG option and has no effect on GIF output, so it is omitted.
    imgs[0].save('sim_gr_expt.gif', format='GIF',
                 append_images=imgs[1:],
                 save_all=True,
                 duration=100, loop=0)
    imgs[-1].save('sim_gr_expt_final.png', format='PNG')

-45.50715255737305 45.03228759765625 -42.68745040893555 45.02500915527344
100.00%%