In [None]:
# pm.render()

In [None]:
import pyvista as pv
from seagullmesh import Mesh3
import numpy as np
import pymeshfix
from ga_population import *
import numpy as np
from sklearn.decomposition import PCA
from torchvision.utils import make_grid
import torchvision.transforms.functional as F
from scipy.stats import pearsonr
import warnings
from tqdm.autonotebook import tqdm

# Robin scene: load the extracted mesh together with its source image and the
# camera / target positions it is viewed from, then render it.
# Options (per the keyword names): wrap with relative_alpha/relative_offset,
# simplify to ~5000 faces, no repair or hole filling, and cache=True with
# recache=False (presumably reuses an existing cached result — confirm).
# NOTE(review): absolute local paths; consider a configurable data dir.
robin_mesh_file = r"D:\mesh_extract_work\robin\robin_animated_v4_mesh_extraction.ply"
robin_img_file = r"C:\corpus2_temp\robin_animated_v4_orig_Main.png"

pm = ProbeMesh.load(
    mesh_file=robin_mesh_file,
    index='robin_animated_v4_orig_Main.png',
    img_file=robin_img_file,
    cam_pos=(-31.297409057617188, -16.788650512695312, 8.073495864868164),
    tgt_pos=(-0.15069210529327393, -2.3357276916503906, 4.40774393081665),
    repair=False,
    fill_holes=False,
    cache=True,
    recache=False,
    wrap=True,
    relative_alpha=600,
    relative_offset=1000,
    simplify_n_faces=5000,
    verbose=True,
)

pm.render()

In [None]:
# Look at the raw extracted robin mesh (triangulated and cleaned by pyvista)
# before any alpha wrapping is applied.
mesh_file = Path(r"D:\mesh_extract_work\robin\robin_animated_v4_mesh_extraction.ply")
pv_mesh = (
    pv.read(mesh_file)
    .triangulate()
    .clean()
)
pv_mesh.plot(show_edges=True)

In [None]:
# Alpha-wrap the cleaned mesh by hand with the same relative_alpha /
# relative_offset values as the ProbeMesh.load call, and view the result.
sm_mesh = Mesh3.from_alpha_wrapping(
    pv_mesh.points,
    faces=pv_mesh.regular_faces,
    relative_alpha=600,
    relative_offset=1000,
)
wrapped_pv_mesh = sm_mesh.to_pyvista()
wrapped_pv_mesh.plot(show_edges=True)

In [None]:
# Simplify a copy of the wrapped mesh with edge collapse and view it.
# edge_collapse is only called on the copy, so sm_mesh is left as-is
# (assuming copy() is a deep copy — confirm). The ('face', N) target
# presumably means "collapse down to about N faces".
n_faces_target = 5000
sm_mesh2 = sm_mesh.copy()
sm_mesh2.edge_collapse('face', n_faces_target)
simplified_pv_mesh = sm_mesh2.to_pyvista()
simplified_pv_mesh.plot(show_edges=True)

In [None]:
# repair=False on purpose: the triangle orientation of this mesh can't be
# fixed, and running the repair throws out most of the mesh.
# NOTE(review): recache=True here (unlike the robin cell) — presumably forces
# the cached result to be rebuilt on every run; confirm that is intended.
coffee_mesh_file = r"D:\mesh_extract_work\nespr4esso\Coffee_Machine_v3_mesh_extraction.ply"
coffee_img_file = r"C:\corpus2_temp\Coffee_Machine_v3_orig.png"

pm = ProbeMesh.load(
    mesh_file=coffee_mesh_file,
    index='Coffee_Machine_v3_orig.png',
    img_file=coffee_img_file,
    cam_pos=(0.4897646903991699, 1.369809627532959, 1.192300796508789),
    tgt_pos=(0.039121244102716446, 0.007926076650619507, 0.2581562101840973),
    repair=False,
    fill_holes=False,
    cache=True,
    recache=True,
    wrap=True,
    relative_alpha=600,
    relative_offset=1000,
    simplify_n_faces=3000,
)
pm.render()

In [None]:
# Load the currently trained source models. Later cells index this collection
# positionally (models[model_idx]).
models = SourceModels.current_trained()

In [None]:
# Per-vertex model weights for the corpus meshes (cp_pms), loaded from a local
# cache file. To regenerate the cache:
#   cp_vertex_weights = models.get_weights(cp_pms, expt_kwargs=dict(outputs_at='vertices'))
#   torch.save(cp_vertex_weights, 'temp_weights.pt')
# NOTE(review): cp_pms is not defined in any visible cell — confirm where it
# comes from before re-running on a fresh kernel.
#
# The cache holds pickled Python objects (items with .weights / .pca(), as used
# below), not plain tensors. torch>=2.6 defaults to weights_only=True, which
# refuses such objects, so opt out explicitly. This unpickles arbitrary code:
# only load cache files you produced yourself. (weights_only needs torch>=1.13.)
cp_vertex_weights = torch.load('temp_weights.pt', weights_only=False)

In [None]:
# # PRINCIPAL COMPONENTS - PLOT SEPARATE

# for pm, weights in zip(cp_pms, cp_vertex_weights):
#     pm.plot_weights(weights=weights.pca(), shape=(1, 3), window_size=(1500, 500), scalar_bar=True, titles=['PC1', 'PC2', 'PC3'], render=False)

In [None]:
# # PRINCIPAL COMPONENTS - PLOT RGB

# plotter = pv.Plotter(shape=(1, 3))
# for i, (pm, weights) in enumerate(zip(cp_pms, cp_vertex_weights)):
#     plotter.subplot(0, i)
#     plotter.camera = pm.camera
#     plotter.add_mesh(pm.mesh, scalars=weights.pca(), rgb=True)
# plotter.show()

In [None]:
# Model predictions for the corpus meshes, loaded from a local cache file.
# The object is used like a pandas DataFrame below (.shape, .items()).
# To regenerate the cache:
#   cp_pm_preds = models.get_predictions(pms)
#   torch.save(cp_pm_preds, 'temp_preds.pt')
# NOTE(review): `pms` is not defined in any visible cell; presumably `cp_pms`
# was meant — confirm.
#
# A pickled DataFrame is not a plain tensor, so torch>=2.6's default
# weights_only=True would refuse it; opt out explicitly. This unpickles
# arbitrary code: only load cache files you produced yourself.
cp_pm_preds = torch.load('temp_preds.pt', weights_only=False)

In [None]:
# Corpus responses per unit (from models.cp_responses), cached to HDF5. Columns
# appear to be (model_idx, channel_idx) unit keys, matching cp_pm_preds.
# To regenerate the cache with rows aligned to the prediction table:
#   cp_responses = models.cp_responses.loc[cp_pm_preds.index]
#   cp_responses.to_hdf('cp_responses.hdf', key='cp_responses')
# NOTE(review): the original recipe used `preds.index`, but `preds` is not
# defined anywhere visible; `cp_pm_preds` is presumably what was meant.
cp_responses_file = 'cp_responses.hdf'
cp_responses = pd.read_hdf(cp_responses_file, key='cp_responses')

In [None]:
# Rank units by how well the predictions on the corpus meshes correlate
# (Pearson r) with the corpus responses for the same unit.
rvals = np.zeros(cp_pm_preds.shape[1])
for col_pos, (unit_idx, r_mesh) in enumerate(cp_pm_preds.items()):
    r_corpus = cp_responses.loc[:, unit_idx]
    # Only the correlation itself runs with warnings silenced (e.g. pearsonr's
    # warning on constant input, which yields NaN).
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore')
        pearson_result = pearsonr(r_mesh, r_corpus)
    rvals[col_pos] = pearson_result.statistic

# Undefined correlations (NaN) are treated as the worst possible score.
rvals = np.where(np.isnan(rvals), -1, rvals)

# Unit keys ordered from highest to lowest correlation.
descending_order = np.argsort(rvals)[::-1]
best_units = cp_responses.columns[descending_order]

In [None]:
def load_unit_best(model_idx: int, channel_idx: int, n_best: int):
    """Load the GA stimuli that respond most strongly for one model unit, plus their vertex weights.

    Ranks the scenes returned by the model metadata's ``load_data`` by the
    response column at position ``channel_idx`` (descending). It keeps the top
    ``n_best``, loads those scenes as ProbeMeshes, and runs them through the
    model with ``outputs_at='vertices'`` to obtain per-vertex outputs.

    Args:
        model_idx: Position of the model in the global ``models`` collection.
        channel_idx: Positional column index into the model's response table.
        n_best: Number of top-responding scenes to return.

    Returns:
        ``(ga_pms, ga_weights)``: a list of ProbeMesh and a list of
        VertexWeights, both built from the same ranked scene table.
        Presumably they are parallel and ordered from highest response down
        (the dataloader does not shuffle) — confirm against load_ga_stim.
    """
    model = models[model_idx]
    meta = model.reader.metadata
    scenes, responses, _weights, _fit_fns = meta.load_data(weights=meta.weight_error)

    # Top-n scenes for this channel, highest response first.
    unit_responses = responses.iloc[:, channel_idx].sort_values(ascending=False)
    unit_best_scenes = unit_responses.index[:n_best]
    best_scenes = scenes.loc[unit_best_scenes]

    # Load the meshes
    ga_pms = list(ProbeMesh.load_ga_stim(meta.opts.data_dir, best_scenes))

    # Present the same scenes, in the same order, to the network to get vertex weights
    dataset = GaDataset(
        df=best_scenes,  # type: ignore
        responses=responses.loc[unit_best_scenes],  # type: ignore
        root_dir=meta.opts.data_file.parent,
        k_eig=meta.k_eig,
        op_cache_dir=meta.opts.data_dir / 'op_cache',
        file_mode=meta.opts.mesh_file_mode,
        weights=meta.weight_error,
        use_visible=meta.use_visible,
        use_color=meta.use_color,
        norm_verts=meta.norm_verts,
        features=meta.input_features,
        augment=None,
    )
    dataloader = DataLoader(dataset, shuffle=False, batch_size=None)
    expt = model.reader.experiment(outputs_at='vertices')
    assert expt.model.outputs_at == 'vertices'
    # Identity agg_fn: keep each scene's outputs as-is rather than aggregating them.
    _ga_obs, ga_weights = expt.predict(dataloader, agg_fn=lambda x: x)
    ga_weights = [VertexWeights(weights=w) for w in ga_weights]

    return ga_pms, ga_weights

In [None]:
def plot_images(pms: list[ProbeMesh], vertex_weights: list[VertexWeights], channel_idx: int, img_sz=(5, 5), title=None):
    """Plot, for each mesh, its rendered vertex weights above its source image.

    Args:
        pms: Meshes to plot, one figure column each.
        vertex_weights: One VertexWeights per mesh, parallel to ``pms``.
        channel_idx: Column of each ``VertexWeights.weights`` rendered as scalars.
        img_sz: (width, height) of each panel, in inches.
        title: Optional format string. ``{index}`` is replaced with
            ``pm.index``, and the result titles the source-image (bottom) panel.

    Returns:
        The matplotlib Figure with 2 rows x ``len(pms)`` columns.

    Raises:
        ValueError: If ``pms`` and ``vertex_weights`` differ in length (zip
            would otherwise silently drop the extras and leave blank axes).
    """
    if len(vertex_weights) != len(pms):
        raise ValueError(
            f"pms and vertex_weights must have the same length "
            f"(got {len(pms)} and {len(vertex_weights)})"
        )

    img_sz = np.array(img_sz)
    n = len(pms)
    fig, axs = plt.subplots(2, n, figsize=(img_sz[0] * n, img_sz[1] * 2), squeeze=False)
    for axs_i, pm, vw in zip(axs.T, pms, vertex_weights):
        rendered_weights = pm.render(
            weights=vw.weights[:, channel_idx],
            ground=False,
            show_scalar_bar=False,
            window_size=(1024, 1024),
        )
        # Image.open is lazy and keeps the file open; copy() loads the pixels so
        # the handle is closed here instead of leaking until garbage collection
        # (this function is called repeatedly in the export loop).
        with PIL.Image.open(pm.img_file) as src_img:
            rendered_image = src_img.copy()

        # Top row: rendered weights; bottom row: the original image.
        for ax, img in zip(axs_i, (rendered_weights, rendered_image)):
            ax.imshow(img)
            ax.axis('off')

        if title:
            axs_i[1].set_title(title.format(index=pm.index))

    fig.tight_layout()
    return fig

In [None]:
# from matplotlib.backends.backend_pdf import PdfPages
# from tqdm.autonotebook import tqdm

# img_sz = (3, 3)
# n_best_units = 10
# n_best_ga_stim = 5

# with PdfPages('synthetic_mesh_model_predictions2.pdf') as pdf:
#     for model_idx, channel_idx in tqdm(best_units[:n_best_units]):
#         model = models[model_idx]
#         run_name = Path(model.trained_file).parts[2]
#         channel = model.reader.metadata.channel[channel_idx]
#         model_title = f'{run_name} ch{channel}'
    
#         ga_pms, ga_vertex_weights = load_unit_best(model_idx=model_idx, channel_idx=channel_idx, n_best=n_best_ga_stim)
#         fig = plot_images(ga_pms, ga_vertex_weights, channel_idx=channel_idx, img_sz=img_sz, title='Scene {index}')
#         fig.suptitle(model_title)
#         pdf.savefig(fig)
#         plt.close(fig)

#         idx = cp_responses.columns == (model_idx, channel_idx)
#         this_unit_vws = [VertexWeights(vw.weights[:, idx]) for vw in cp_vertex_weights]
#         fig = plot_images(cp_pms, this_unit_vws, channel_idx=0, img_sz=(3, 3))
#         # fig.suptitle(model_title)
#         pdf.savefig(fig)
#         plt.close(fig)

In [None]:
from tqdm.autonotebook import tqdm
from zipfile import ZipFile

import io

# Export, for each of the top-ranked units, two PNGs into one zip archive:
#   "<run> ch<channel> - ga.png": the unit's best GA stimuli with vertex weights
#   "<run> ch<channel> - corpus.png": the corpus meshes with this unit's weights
# NOTE(review): cp_pms is not defined in any visible cell — confirm its source.
img_sz = (3, 3)
n_best_units = 10
n_best_ga_stim = 5


def _write_fig_to_zip(zip_file: ZipFile, fig, arcname: str) -> None:
    """Save ``fig`` as a PNG directly into ``zip_file`` under ``arcname``, then close the figure."""
    # Render into memory rather than a temp file in the CWD. This leaves no stray
    # files if saving fails and avoids filesystem restrictions on the name.
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format='png')
    finally:
        # This loop creates 2 * n_best_units figures; without closing them
        # pyplot keeps every one alive.
        plt.close(fig)
    zip_file.writestr(arcname, buf.getvalue())


with ZipFile('synthetic_mesh_model_predictions.zip', 'w') as zip_file:
    for model_idx, channel_idx in tqdm(best_units[:n_best_units]):
        model = models[model_idx]
        run_name = Path(model.trained_file).parts[2]
        channel = model.reader.metadata.channel[channel_idx]
        model_title = f'{run_name} ch{channel}'

        ga_pms, ga_vertex_weights = load_unit_best(model_idx=model_idx, channel_idx=channel_idx, n_best=n_best_ga_stim)
        fig = plot_images(ga_pms, ga_vertex_weights, channel_idx=channel_idx, img_sz=img_sz, title='Scene {index}')
        fig.suptitle(model_title)
        _write_fig_to_zip(zip_file, fig, f'{model_title} - ga.png')

        # Select this unit's single weight column from each corpus VertexWeights;
        # the result has one column, hence channel_idx=0 below.
        idx = cp_responses.columns == (model_idx, channel_idx)
        this_unit_vws = [VertexWeights(vw.weights[:, idx]) for vw in cp_vertex_weights]
        fig = plot_images(cp_pms, this_unit_vws, channel_idx=0, img_sz=img_sz)
        _write_fig_to_zip(zip_file, fig, f'{model_title} - corpus.png')