In [1]:
%matplotlib qt5
import numpy as np
import h5py
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button
import sunpy.io.fits
import matplotlib.pyplot as plt
from pathlib import Path
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import MultipleLocator

  '{0}.{1}.{2}'.format(*version.hdf5_built_version_tuple)


In [2]:
# Directory containing the processed STiC input profiles.
base_path: Path = Path('/home/harsh/SpinorNagaraju/maps_1/stic/processed_inputs/')

In [3]:
# Input profile file inside base_path (read with h5py in the next cell).
filename: str = 'alignedspectra_scan1_map01_Ca.fits_stic_profiles.nc'

In [4]:
# Load the wavelength grid and the profiles; `with` guarantees the HDF5 file
# is closed even if one of the reads fails.
with h5py.File(base_path / filename, 'r') as f:
    # Keep only wavelength samples that are non-zero for the first pixel
    # (last-axis index 0, presumably Stokes I); zeros are presumably padding
    # -- TODO confirm.
    ind = np.where(f['profiles'][0, 0, 0, :, 0] != 0)[0]
    wave = f['wav'][ind]
    # NOTE: the integer indices and the index array `ind` are all "advanced"
    # indices separated by slices, so NumPy moves the wavelength axis FIRST:
    # the result has shape (len(ind), ny, nx).  The reshape in the next cell
    # relies on this layout, so keep the full read via [()] followed by
    # NumPy indexing.
    data = f['profiles'][()][0, :, :, ind, 0]

In [5]:
# Merge the two spatial axes and transpose so that each row is one pixel's
# spectrum: result shape (19 * 60, 306).
data = np.transpose(np.reshape(data, (306, 19 * 60)))

In [68]:
# Labelling state shared by the widget callbacks below.
profile_ind = 0  # row of `data` (pixel) currently displayed
pixel_cat_arr = np.full(19 * 60, -1, dtype=np.int64)  # per-pixel category index; -1 = unlabelled
cat_data = np.zeros((30, 306), dtype=np.float64)  # one mean profile per category
suggestion = -1  # category index last proposed by suggest(); -1 = none yet
write_path = Path('/home/harsh/SpinorNagaraju/maps_1/stic/manual_categories/')
filename = 'manual_labels.txt'  # NOTE: reuses the name of the input-file variable
index_prev_color = None  # category currently highlighted by set_color(), if any

# Figure layout: a 5 x 6 grid of category-mean panels in the upper part and a
# single larger panel in the lower left for the profile being labelled.
fig = plt.figure(figsize=(19, 19 * 9 / 16))
gs1 = GridSpec(5, 6)
gs1.update(left=0.05, bottom=0.4, right=0.99, top=0.95)
gs2 = GridSpec(1, 1)
gs2.update(left=0.05, bottom=0.05, right=0.5, top=0.35)

# axs[0..29]: category panels (row-major over gs1); axs[30]: pixel panel.
axs = [fig.add_subplot(gs1[cell]) for cell in range(30)]
axs.append(fig.add_subplot(gs2[0]))


# 36 button axes, 8 per row, placed left to right starting at (0.51, 0.3);
# each new row sits 0.07 lower.  The last button (index 35) is shifted one
# slot to the right.
btn_axs = list()
start_x = 0.51
start_y = 0.3
for count in range(36):
    k = count % 8
    slot = k + 1 if count == 35 else k
    btn_axs.append(plt.axes([start_x + slot * 0.06, start_y, 0.05, 0.05]))
    if k == 7 or count == 35:
        # Row finished (or last button placed): step down for the next row.
        start_y -= 0.07

# Labels for the 36 buttons: one per category, followed by the control buttons.
btn_text_list = ['Cat {}'.format(n) for n in range(1, 31)]
btn_text_list += ['Load KMeans', 'Reset', 'Load', 'Previous', 'Next', 'Suggest']

# One gray Button per axes, paired positionally with its label.
btn_list = [
    Button(btn_ax, btn_label, color='gray')
    for btn_ax, btn_label in zip(btn_axs, btn_text_list)
]

# Line artists: im[0..29] show the category means (flat until categories are
# loaded); im[30] (== im[-1]) shows the current pixel's profile.
im = [axs[cell].plot(np.zeros(306), color='blue')[0] for cell in range(30)]
for cat_ax in axs[:30]:
    cat_ax.set_xticklabels([])
im.append(axs[30].plot(data[profile_ind], color='blue')[0])

# Status texts, positioned in the axes coordinates of the last category panel
# (axs[-2]); the negative offsets place them outside that panel.
title = axs[-2].text(
    -1.23, -3.77, 'Profile: {}'.format(profile_ind), transform=axs[-2].transAxes
)
suggest_text = axs[-2].text(-0.35, -3.77, '', transform=axs[-2].transAxes)

# Name every category panel near its top.
for cat_no, cat_ax in enumerate(axs[:30], start=1):
    cat_ax.text(0.45, 0.85, 'Cat {}'.format(cat_no), transform=cat_ax.transAxes)

def recalculate_categories():
    """Recompute ``cat_data[i]`` as the mean profile of the pixels labelled ``i``.

    Reads the module-level ``pixel_cat_arr`` (one label per row of ``data``)
    and writes each row of ``cat_data`` in place.  A category with no pixels
    is reset to all zeros; an empty category does not stop the loop, so every
    category is refreshed.
    """
    for cat in range(len(cat_data)):
        members = np.flatnonzero(pixel_cat_arr == cat)
        if members.size == 0:
            cat_data[cat] = 0.0
            continue
        cat_data[cat] = data[members].mean(axis=0)

def prepare_update_cat_func(cat_num):
    """Build the click callback for the category button ``cat_num``.

    The returned callback assigns ``cat_num`` to the currently displayed pixel
    in ``pixel_cat_arr``, refreshes the category highlight, writes the labels
    to disk and requests a redraw.  Category means are not recomputed here
    (the calls to recalculate_categories() / update_cat_plots() were
    commented out in this callback).
    """
    def update_cat(*args, **kwargs):
        # Item assignment mutates the module-level array; no rebinding needed.
        pixel_cat_arr[profile_ind] = cat_num
        set_color()
        save_data()
        fig.canvas.draw_idle()

    return update_cat

def update_cat_plots():
    """Push ``cat_data`` into the category line artists and rescale their y-axes.

    Each panel's y-range is ``[0.95 * min, 1.05 * max]`` of its mean profile.
    When those limits coincide (e.g. an all-zero, empty category) the limits
    are left unchanged: identical limits make Matplotlib warn and expand them.
    """
    for cat, mean_profile in enumerate(cat_data):
        im[cat].set_ydata(mean_profile)
        lo = mean_profile.min() * 0.95
        hi = mean_profile.max() * 1.05
        if lo == hi:
            continue
        axs[cat].set_ylim(lo, hi)

def set_color():
    """Highlight the category panel and button of the current pixel.

    The previously highlighted category, if any, is restored to blue / gray
    first.  Unlabelled pixels (category -1) get no highlight: indexing
    ``im`` / ``btn_list`` with -1 would otherwise recolour the main profile
    plot and the last ('Suggest') button.
    """
    global index_prev_color
    if index_prev_color is not None:
        im[index_prev_color].set_color('blue')
        btn_list[index_prev_color].color = 'gray'
        index_prev_color = None
    current = int(pixel_cat_arr[profile_ind])
    if current >= 0:
        im[current].set_color('darkgreen')
        btn_list[current].color = 'green'
        index_prev_color = current
    fig.canvas.draw_idle()

def load_cat_from_file(*args, **kwargs):
    """Button callback: reload the saved manual labels from ``write_path / filename``.

    Replaces ``pixel_cat_arr`` with the file contents cast to int64, rebuilds
    the category means and their plots, re-highlights the current pixel's
    category and schedules a redraw.
    """
    global pixel_cat_arr
    label_file = write_path / filename
    saved_labels = np.loadtxt(label_file)
    pixel_cat_arr = saved_labels.astype(np.int64)
    recalculate_categories()
    update_cat_plots()
    set_color()
    fig.canvas.draw_idle()

def _show_profile(new_ind):
    """Make pixel ``new_ind`` current: redraw its profile, title and highlight."""
    global profile_ind
    profile_ind = new_ind
    spectrum = data[profile_ind]
    im[-1].set_ydata(spectrum)
    axs[-1].set_ylim(spectrum.min() * 0.95, spectrum.max() * 1.05)
    title.set_text('Profile: {}'.format(profile_ind))
    set_color()
    # Any previous suggestion referred to the old pixel.
    suggest_text.set_text('')
    fig.canvas.draw_idle()


def update_prev(*args, **kwargs):
    """Button callback: step back one pixel (no-op on the first pixel)."""
    if profile_ind > 0:
        _show_profile(profile_ind - 1)


def update_next(*args, **kwargs):
    """Button callback: step forward one pixel (no-op on the last row of ``data``)."""
    if profile_ind < len(data) - 1:
        _show_profile(profile_ind + 1)

def save_data(*args, **kwargs):
    """Write the current per-pixel labels to ``write_path / filename``.

    One integer label per line (``-1`` = unlabelled).  The output directory is
    created if it does not exist.  The file is read back by
    ``load_cat_from_file`` via ``np.loadtxt(...).astype(np.int64)``, which
    accepts both this integer format and files written earlier in float format.
    """
    write_path.mkdir(parents=True, exist_ok=True)
    np.savetxt(write_path / filename, pixel_cat_arr, fmt='%d')

def reset_data(*args, **kwargs):
    """Button callback: discard the manual labels in favour of the KMeans labels.

    Reloads the KMeans labels via ``load_kmeans()`` and then writes them out
    with ``save_data()``, overwriting the manual-label file.
    """
    for step in (load_kmeans, save_data):
        step()

def suggest(*args, **kwargs):
    """Button callback: propose the category closest to the current pixel.

    Compares the current profile with every category mean over wavelength
    indices 171:224 (Euclidean distance) and shows the nearest as "Cat N";
    the 0-based index is stored in ``suggestion``.  Categories whose mean is
    all zeros (no pixels when the means were last recomputed) are skipped.
    If no category is populated, ``suggestion`` is set to -1 and a short
    notice is shown instead.
    """
    global suggestion
    # Wavelength-index window that is compared; presumably the line core -- TODO confirm.
    window = slice(171, 224)
    populated = cat_data.any(axis=1)
    if not populated.any():
        suggestion = -1
        suggest_text.set_text('No populated categories')
    else:
        # Squared distance has the same argmin as the Euclidean norm.
        dist = np.sum(
            np.square(cat_data[:, window] - data[profile_ind][np.newaxis, window]),
            axis=1,
        )
        dist[~populated] = np.inf
        suggestion = int(np.argmin(dist))
        suggest_text.set_text('Cat {}'.format(suggestion + 1))
    # Like the other callbacks, request a redraw so the new text is shown.
    fig.canvas.draw_idle()

def load_kmeans(*args, **kwargs):
    """Button callback: replace the labels with the saved KMeans labels.

    Reads ``final_labels`` from the KMeans output file (opened read-only),
    flattens it to one label per pixel (19 * 60), then rebuilds the category
    means and plots and re-highlights the current pixel's category.
    """
    global pixel_cat_arr
    kmeans_path = '/home/harsh/SpinorNagaraju/maps_1/stic/chosen_out_30.h5'
    # Explicit read-only mode: h5py < 3.0 opens files read/write by default.
    with h5py.File(kmeans_path, 'r') as f:
        labels = f['final_labels'][()]
    # Cast so labels can index the plot/button lists, matching load_cat_from_file.
    pixel_cat_arr = labels.reshape(19 * 60).astype(np.int64)
    recalculate_categories()
    update_cat_plots()
    set_color()

# Wire each button to its handler.  Indices follow btn_text_list:
# 0-29 'Cat 1'..'Cat 30', 30 'Load KMeans', 31 'Reset', 32 'Load',
# 33 'Previous', 34 'Next', 35 'Suggest'.
for i in range(30):
    btn_list[i].on_clicked(prepare_update_cat_func(i))
btn_list[30].on_clicked(load_kmeans)
btn_list[31].on_clicked(reset_data)          # 'Reset'
btn_list[32].on_clicked(load_cat_from_file)  # 'Load'
btn_list[33].on_clicked(update_prev)
btn_list[34].on_clicked(update_next)
btn_list[35].on_clicked(suggest)

0