In [None]:
from IPython.display import HTML, display

## Widen the notebook container to 95% of the browser window.
css_wideContainer = "<style>.container { width:95% !important; }</style>"
display(HTML(css_wideContainer))

In [None]:
from pathlib import Path
import numpy as np
import copy
import torch
from tqdm.notebook import tqdm
import pickle
import matplotlib.pyplot as plt

In [None]:
## Environment sanity check: show which file the `pickle` module is loaded from.
import inspect

pickle_source_file = inspect.getfile(pickle)
pickle_source_file

In [None]:
## Directory containing cloned repos like GCaMP_ROI_classifier and basic_neural_processing_modules
dir_github = Path(r'/n/data1/hms/neurobio/sabatini/gyu/github_clone/Mothership_Zeta/MZ').resolve()

## Directory where this notebook's outputs end up (dir_save below points here).
## NOTE(review): this is a Windows path while every other path in this cell is a Linux
## cluster path, and its date (20221111) differs from dir_s2p's (20221113). On Linux,
## Path(...).resolve() turns it into '<cwd>/D:\RH_local\...', which does not exist, so the
## np.save at the end of the notebook would fail. Point it at the matching session's
## analysis directory.
dir_analysisFiles = Path(r'D:\RH_local\data\BMI_cage_g2F\mouse_g2FB\20221111\analysis_data').resolve()

## Directory with F.npy, stat.npy etc.
dir_s2p         = Path(r'/n/data1/hms/neurobio/sabatini/gyu/analysis/suite2p_output/mouse_g2FB/20221113/scanimage_data/exp/jobNum_0/suite2p/plane0').resolve()

## Directory to save outputs from this notebook like iscell
dir_save = dir_analysisFiles


## NN fileNames
fileName_NN_pth = 'ResNet18_simCLR_model_202112078_EOD_transfmod=norm.pth' # name of pth file in dir_NNmodels directory
fileName_NN_py  = 'ResNet18_simCLR_model_202112078_EOD_transfmod=norm' # EXCLUDE THE .PY AT THE END. name of py file in dir_NNmodels directory.
fileName_classifier = 'logreg_model_0.01.pkl' # name of the logistic-regression classifier pickle file in dir_classifiers

## Fail fast, before the expensive NN steps, if a configured directory is missing.
## np.save at the end of the notebook does not create dir_save.
_dirs_toCheck = {'dir_github': dir_github, 'dir_s2p': dir_s2p, 'dir_save': dir_save}
_dirs_missing = [f'{name}: {path}' for name, path in _dirs_toCheck.items() if not path.is_dir()]
if _dirs_missing:
    raise FileNotFoundError('Configured directories not found:\n  ' + '\n  '.join(_dirs_missing))

In [None]:
## Sanity check: the resolved root of the cloned repos.
Path(dir_github)

In [None]:
## Sanity check: location of the GCaMP ROI classifier repo under dir_github.
str(dir_github.joinpath('classify', 'GCaMP_ROI_classifier'))

In [None]:
## Directories of Classifier stuff
dir_classify = dir_github / 'classify'
dir_GRC_repo = dir_classify / 'GCaMP_ROI_classifier'
dir_GRC_EndUser = dir_GRC_repo / 'End_User'
dir_NNmodels = dir_GRC_EndUser / 'simclr-models'
dir_classifiers = dir_GRC_EndUser / 'classifier-models'
dir_GRC_util = dir_GRC_repo / 'new_stuff'

## Paths to NN and LR classifiers
path_NN_pth = dir_NNmodels / fileName_NN_pth
path_NN_py = dir_NNmodels / fileName_NN_py
path_classifier = dir_classifiers / fileName_classifier

path_statFile = dir_s2p / 'stat.npy'
path_opsFile = dir_s2p / 'ops.npy'

In [None]:
test

In [None]:
for session in test.rglob('*'):
    if "stat.npy" in str(session):
        print(str(session))

In [None]:
import sys
sys.path.append(str(dir_github))
sys.path.append(str(dir_classify))
# sys.path.append(str(dir_GRC_repo))
# sys.path.append(str(dir_GRC_util))

%load_ext autoreload
%autoreload 2
from GCaMP_ROI_classifier.new_stuff import util
# from basic_neural_processing_modules import *
from utils.basic_neural_processing_modules import torch_helpers, plotting_helpers, file_helpers

In [None]:
## Device to use for NN model
## Device to use for NN model (a GPU is requested from torch_helpers.set_device)
USE_GPU = True
DEVICE = torch_helpers.set_device(use_GPU=USE_GPU)

In [None]:
## TODO: Troubleshoot the runtime on this
# def drop_nan_imgs(rois):
#     ROIs_without_NaNs = torch.where(~torch.any(torch.any(torch.isnan(rois), dim=1), dim=1))[0]
#     return rois[ROIs_without_NaNs]

def dataloader_to_latents(dataloader, model, DEVICE='cpu'):
    """Run every batch of `dataloader` through `model` and stack the head outputs.

    The input tensor of each batch is taken as ``data[0][0]``.
    NOTE(review): presumably the first view of a (views, labels) batch; confirm against
    ``model_file.get_dataset_dataloader``.

    Args:
        dataloader: iterable of batches.
        model: object exposing ``base_model`` and ``get_head`` callables; the output of
            ``base_model`` is passed to ``get_head``.
        DEVICE: device the input tensors are moved to before the forward pass.

    Returns:
        torch.Tensor on the CPU, with the per-batch outputs concatenated along dim 0.
    """
    def subset_to_latents(data):
        return model.get_head(model.base_model(data[0][0].to(DEVICE))).detach().cpu()

    # Inference only: disabling autograd avoids recording a graph (and its saved
    # activations) for every forward pass, reducing peak memory and compute.
    with torch.no_grad():
        return torch.cat([subset_to_latents(data) for data in tqdm(dataloader)], dim=0)

def load_classifier_model(classifier_name):
    """Load and return a pickled classifier object.

    Args:
        classifier_name: path (str or path-like) of the pickle file.

    Returns:
        The unpickled object.

    Note:
        ``pickle.load`` can execute arbitrary code; only load classifier files you trust.
    """
    with open(classifier_name, mode='rb') as fh:
        return pickle.load(fh)

In [None]:
## Load the suite2p ROI footprints as 36x36 crops for the NN.
spatial_footprints_raw = torch.as_tensor(
    util.statFile_to_spatialFootprints(path_statFile, out_height_width=[36,36], max_footprint_width=455)
)

## Normalize each footprint to sum to 1 (sum over dims 1 and 2 of each ROI's image).
## A footprint whose total is zero would otherwise become all-NaN (0/0) and feed NaN into
## the NN / classifier. Such footprints are kept as all-zero images rather than dropped,
## so ROI indices stay aligned with stat.npy.
footprint_sums = torch.sum(spatial_footprints_raw, dim=(1,2), keepdim=True)
is_emptyFootprint = footprint_sums == 0
n_emptyFootprints = int(is_emptyFootprint.sum())
if n_emptyFootprints > 0:
    print(f'WARNING: {n_emptyFootprints} ROI footprint(s) sum to zero; leaving them as all-zero images.')
spatial_footprints = spatial_footprints_raw / torch.where(
    is_emptyFootprint, torch.ones_like(footprint_sums), footprint_sums
)

print(spatial_footprints.shape[0], 'ROIs loaded.')

In [None]:
# Instantiate Model
## Import the model-definition module by name from dir_NNmodels, then build the
## network with the weights stored at path_NN_pth and switch it to eval mode.
import importlib

sys.path.append(str(dir_NNmodels))
model_file = importlib.import_module(name=fileName_NN_py)
model = model_file.get_model(path_NN_pth)
model.eval();

In [None]:
# Create Data Sets / Data Loaders
## TODO: Troubleshoot the runtime on this
batch_size_NN = 64
dataset, dataloader = model_file.get_dataset_dataloader(
    spatial_footprints,
    batch_size=batch_size_NN,
    device=DEVICE,
)

In [None]:
## Move the model to DEVICE (the dataloader above was also built with device=DEVICE).
model.to(DEVICE);

In [None]:
# Get Model Latents
## Latents come back as a CPU torch tensor; convert to numpy for classifier_model.predict_proba.
latents_torch = dataloader_to_latents(dataloader, model, DEVICE=DEVICE)
latents = latents_torch.numpy()

In [None]:
# Load Logistic Model
## Unpickle the logistic-regression classifier stored at path_classifier.
classifier_model = load_classifier_model(classifier_name=path_classifier)

In [None]:
# Predict ROIs — Save to File
## Class probabilities per ROI, the argmax column index as the hard prediction, and an
## uncertainty score from util.loss_uncertainty. (Saving happens in a later cell.)
## NOTE(review): preds are column indices of `proba`; they equal the class labels only if
## the classifier's classes are 0..k-1 — confirm.
proba = classifier_model.predict_proba(latents)
preds = proba.argmax(axis=-1)
uncertainty = (
    util.loss_uncertainty(torch.as_tensor(proba), temperature=1, class_value=None)
    .detach()
    .cpu()
    .numpy()
)

params = classifier_model.get_params()

In [None]:
## Bundle every classifier output so it can be saved together.
ROI_classifier_outputs = dict(
    latents=latents,
    proba=proba,
    preds=preds,
    uncertainty=uncertainty,
    LR_params=params,
)

In [None]:
%matplotlib inline
plt.figure()
plt.hist(preds, 50);

In [None]:
%matplotlib inline
plt.figure()
plt.hist(preds, 50);

In [None]:
## Number of classified ROIs. This previously read `len(iscell_NN)`, which is only defined
## in the next cell (NameError on a fresh kernel); iscell_NN has one entry per prediction.
## The assertion guards the alignment the saved iscell mask relies on.
assert len(preds) == spatial_footprints.shape[0], 'expected exactly one prediction per ROI'
len(preds)

In [None]:
preds_toUse = [0, 1]  # predicted classes that count as cells

## Boolean mask over ROIs plus the integer indices of the included ones.
iscell_NN = np.isin(element=preds, test_elements=preds_toUse)
iscell_NN_idx = np.flatnonzero(iscell_NN)

print('number of included ROIs: {}'.format(iscell_NN_idx.size))

In [None]:
%matplotlib notebook

grid_shape = (7,7)

print('including')
plotting_helpers.plot_image_grid(
    spatial_footprints[np.random.choice(iscell_NN_idx, np.prod(grid_shape))],
    grid_shape=grid_shape, 
    show_axis='off', 
);

print('excluding')
plotting_helpers.plot_image_grid(
    spatial_footprints[np.random.choice(np.where(~iscell_NN)[0], np.prod(grid_shape))],
    grid_shape=grid_shape, 
    show_axis='off', 
);

In [None]:
## NOTE(review): autoreload is already loaded and configured in an earlier cell; this is redundant.
%load_ext autoreload
%autoreload 2
## NOTE(review): star import. `display_toggle_image_stack` (used below) and presumably
## `convert_stat_to_sparse_spatial_footprints` (used in later cells) come from here;
## consider importing them explicitly.
from Big_Ugly_ROI_Tracker.multiEps.multiEps_modules import *

## Footprints of the ROIs selected by iscell_NN, shown as a toggleable image stack.
sf_toShow = spatial_footprints[iscell_NN]

%matplotlib notebook
display_toggle_image_stack(sf_toShow, clim=None)

In [None]:
## Save the boolean iscell mask and the bundled classifier outputs into dir_save.
path_iscell = dir_save / 'iscell_NN.npy'
np.save(file=path_iscell, arr=iscell_NN)

path_classifierOutputs = dir_save / 'ROI_classifier_outputs.pkl'
file_helpers.pickle_save(obj=ROI_classifier_outputs, path_save=path_classifierOutputs)

In [None]:
sys.path.append('/n/data1/hms/neurobio/sabatini/gyu/github_clone')

In [None]:
import Mothership_Zeta.MZ.extract_process

In [None]:
sf = convert_stat_to_sparse_spatial_footprints(path_statFile, path_ops=path_opsFile, normalize='max')

sf_classes = [sf[preds==ii].sum(0).todense() for ii in np.unique(preds)]

%matplotlib notebook
display_toggle_image_stack(sf_classes, clim=None)

In [None]:
sf = convert_stat_to_sparse_spatial_footprints(path_statFile, path_ops=path_opsFile, normalize='max')

sf_classes = [sf[preds==ii].sum(0).todense() for ii in np.unique(preds)]

%matplotlib notebook
display_toggle_image_stack(sf_classes, clim=None)