In [1]:
from ga_regression import *
from matplotlib import pyplot as plt
from functools import partial
import PIL
import matplotlib as mpl
from matplotlib.cm import ScalarMappable
from pvutils import iter_subplots
from pyvista import PolyData
import pyvista as pv
from scipy.stats import pearsonr
from typing import cast

  from pandas.core.computation.check import NUMEXPR_INSTALLED


In [2]:
class Summarizer:
    """Build summary figures for one reader together with a corpus export file.

    Every public plotting method returns the matplotlib figure it created and
    leaves saving/closing to the caller.
    """

    def __init__(self, reader, cp_file: Path):
        """Load the corpus data and the reader's experiment/dataset.

        Args:
            reader: Project reader object. This class uses its ``metadata``
                (``channel``, ``spike_window``, ``load_dataset``), ``experiment()``,
                ``plot_training``, ``scatter_plot`` and ``scatter_data``.
            cp_file: Corpus data file handed to ``GaDataset.load_data``. Render
                image paths are resolved relative to this file's parent directory.
        """
        self.reader = reader
        # Bar width in the unit-square coordinates set up by `img_with_response_bar`.
        self.response_bar_width = 0.025
        self.cp_file = cp_file

        # Only the first two values returned by `load_data` are kept.
        self.cp_scenes, self.cp_responses, *_ = GaDataset.load_data(
            data_file=cp_file,
            channel=reader.metadata.channel,
            file_mode='render',
            spike_window=reader.metadata.spike_window,
            weight_error=None,
            n_faces=None,
            features=None,
            n_min_reps=None,
        )

        self.expt = reader.experiment()
        self.dataset = reader.metadata.load_dataset(weights=None, augment=None)

    def channel(self, ch_idx: int) -> int:
        """Return the entry of ``reader.metadata.channel`` at position ``ch_idx``."""
        return self.reader.metadata.channel[ch_idx]

    def all_channels(self):
        """Return a figure with the training plot on top and two scatter panels below.

        The scatter panels are produced by ``reader.scatter_plot`` with
        ``channel=None``.
        """
        fig, axs = plt.subplot_mosaic(
            [['loss', 'loss'], ['scatter_train', 'scatter_test']],
            constrained_layout=True,
            figsize=(10, 8),
        )
        self.reader.plot_training(ax=axs['loss'])
        axs['loss'].legend()
        self.reader.scatter_plot(channel=None, axs=[axs['scatter_train'], axs['scatter_test']])
        return fig

    def channel_scatter(self, ch_idx: int):
        """Return a two-panel scatter figure for the channel at ``ch_idx``.

        The figure is titled ``Channel <id>`` where ``<id>`` is ``self.channel(ch_idx)``.
        """
        fig, axs = plt.subplots(1, 2, figsize=(10, 5))
        ch = self.channel(ch_idx)
        self.reader.scatter_plot(channel=ch, axs=axs)
        fig.suptitle(f'Channel {ch}')
        return fig

    def img_with_response_bar(self, ax, img, norm_resp: float, bar_color: str):
        """Draw ``img`` on ``ax`` with a vertical bar of height ``norm_resp``.

        The image is mapped onto the unit square, so the bar (anchored at the
        lower-left corner) spans the full image height when ``norm_resp`` is 1.

        Args:
            ax: Axes to draw on; its axis decorations are turned off.
            img: Anything ``ax.imshow`` accepts.
            norm_resp: Bar height in unit-square coordinates.
            bar_color: Face colour of the bar.
        """
        ax.imshow(img, extent=(0, 1, 0, 1))
        ax.axis('off')
        rect = mpl.patches.Rectangle(xy=(0, 0), width=self.response_bar_width, height=norm_resp, facecolor=bar_color)
        ax.add_patch(rect)

    def best_corpus_by_channel(self, ch_idx: int, grid_shape: tuple[int, int] = (2, 5), subfig_sz=4):
        """Return a grid of the corpus renders with the highest response on one channel.

        Args:
            ch_idx: Column position in ``self.cp_responses`` to rank by.
            grid_shape: ``(rows, cols)`` of the image grid; at most ``rows * cols``
                stimuli are shown, highest response first.
            subfig_sz: Size of each grid cell in inches.

        Returns:
            The created figure. Grid cells without a stimulus are left blank.
        """
        ch_responses = self.cp_responses.iloc[:, ch_idx].sort_values(ascending=False)
        n = int(np.prod(grid_shape))
        figsize = np.array(grid_shape[::-1]) * subfig_sz
        # squeeze=False keeps `axs` two-dimensional even for a 1x1 or 1xN grid,
        # so flattening below works for every grid shape.
        fig, axs = plt.subplots(*grid_shape, figsize=figsize, squeeze=False)
        flat_axs = axs.reshape(-1)
        top_responses = ch_responses.iloc[:n]

        for (stim_id, norm_resp), ax in zip(top_responses.items(), flat_axs):
            img_file = self.cp_file.parent / self.cp_scenes.render.loc[stim_id]
            # Pillow keeps the file open until the image is closed; drawing
            # happens inside the block, so the handle is released right after.
            with PIL.Image.open(img_file) as img:
                self.img_with_response_bar(ax, img, norm_resp, bar_color='red')

        # Hide cells that received no image (fewer stimuli than grid cells).
        for ax in flat_axs[len(top_responses):]:
            ax.axis('off')

        fig.tight_layout()
        return fig

    def weighted_meshes(self, ch_idx: int, n: int, subfig_sz=4):
        """Return a 2 x ``n`` figure of the top-ranked stimuli for one channel.

        Stimuli are ranked by the sum of the two arrays returned by
        ``reader.scatter_data.loc`` (unpacked here as observed and predicted
        values), largest first. Each column shows the render image (top, red bar
        = first value) and the mesh image (bottom, blue bar = second value).

        Args:
            ch_idx: Channel position, forwarded to ``expt.load_mesh_img`` and
                mapped to a channel id via ``self.channel``.
            n: Number of columns (stimuli) to show.
            subfig_sz: Size of each grid cell in inches.

        Returns:
            The created figure. Columns without a stimulus are left blank.
        """
        ch = self.channel(ch_idx)
        obs, preds = self.reader.scatter_data.loc(channel=ch, scene_ids=None)
        obs_preds = np.stack([obs, preds], axis=1)
        priority_idx = obs_preds.sum(axis=1).argsort()[::-1]
        figsize = np.array([n, 2]) * subfig_sz
        fig, axs = plt.subplots(2, n, figsize=figsize, squeeze=False)
        columns = axs.T

        for stim_idx, axs_i in zip(priority_idx, columns):
            # Only the two images are needed; the first two return values are unused.
            _, _, render_img, mesh_img = self.expt.load_mesh_img(
                dataset=self.dataset, stim_idx=stim_idx, ch_idx=ch_idx, upsample=False, background_color=None)

            for ax, img, norm_resp, bar_color in zip(axs_i, (render_img, mesh_img), obs_preds[stim_idx], ('red', 'blue')):
                self.img_with_response_bar(ax, img, norm_resp=norm_resp, bar_color=bar_color)

        # Hide columns that received no stimulus (fewer stimuli than `n`).
        for axs_i in columns[len(priority_idx):]:
            for ax in axs_i:
                ax.axis('off')

        fig.tight_layout()
        return fig

# Known runs. Each entry holds the `opts_and_metadata.pt` file given to
# `Readers.from_file`, the exported `.hdf` file given to `Summarizer` as
# `cp_file`, and the index of the reader to summarize.
#
# Not configured yet (only the data file is recorded):
# Path(r"D:\resynth\run_20_21\run00020_resynth\2025-07-29-15-37-40\opts_and_metadata.pt")
RUNS = {
    'run_38_39': dict(
        data_file=Path(r"D:\resynth\run_38_39\run00038_resynth\2025-07-30-11-58-55\opts_and_metadata.pt"),
        cp_file=Path(r"D:\resynth\run_38_39\run00039_exported\run00039_exported.hdf"),
        idx=1,
    ),
    'run_48_49': dict(
        data_file=Path(r"D:\resynth\run_48_49\run00048_resynth\2025-08-03-12-13-13\opts_and_metadata.pt"),
        cp_file=Path(r"D:\resynth\run_48_49\run00049_exported\run00049_exported.hdf"),
        idx=0,
    ),
}

# Change this key to switch runs instead of reordering or commenting out assignments.
RUN_NAME = 'run_48_49'

data_file, cp_file, idx = (RUNS[RUN_NAME][key] for key in ('data_file', 'cp_file', 'idx'))

readers = Readers.from_file(data_file)
s = Summarizer(readers[idx], cp_file=cp_file)

# TODO:
- Also dump the hyperparameter log into the PDF.
- How do we get the corpus image ID? It is needed to find the population vector.

In [3]:
from matplotlib.backends.backend_pdf import PdfPages


def _save_page(pdf, fig, title: str) -> None:
    """Set ``title`` on ``fig``, append it to ``pdf`` as one page, then close it."""
    try:
        fig.suptitle(title)
        pdf.savefig(fig)
    finally:
        # Close even when saving fails so open figures do not pile up in pyplot.
        plt.close(fig)


# Report name: the part of the data file's stem before the first underscore.
name = s.reader.metadata.opts.data_file.stem.split('_')[0]

with PdfPages(name + '.pdf') as pdf:
    # First page: training overview across all channels.
    _save_page(pdf, s.all_channels(), f"{name} all channels training")

    # Then three pages per channel: scatter, best corpus stimuli, best GA stimuli.
    for ch_idx in tqdm(range(len(s.reader.metadata.channel))):
        ch = s.channel(ch_idx)
        _save_page(pdf, s.channel_scatter(ch_idx=ch_idx), f'Channel {ch}')
        _save_page(
            pdf,
            s.best_corpus_by_channel(ch_idx=ch_idx, grid_shape=(2, 5)),
            f'Channel {ch} best corpus stimuli',
        )
        _save_page(pdf, s.weighted_meshes(ch_idx=ch_idx, n=5), f'Channel {ch} best GA stimuli')

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 15/15 [05:03<00:00, 20.24s/it]
