In [None]:
from cortexlib.mouse import CortexlabMouse

# Analysis parameters for selecting reliable neurons (kept here instead of as
# inline literals so they are easy to find and tune).
N_SHUFFLES = 100             # passed as `n_shuffles` when building the null distribution
RELIABILITY_PERCENTILE = 99  # passed as `percentile_threshold` for reliability selection
NUM_NEURONS = 500            # passed as `num_neurons` when gathering responses
EXAMPLE_NEURON_INDEX = 0     # neuron whose null distribution is plotted below

# Load the Cortexlab mouse recording.
mouse = CortexlabMouse()

# Build the shuffled null and the real SRV values for every neuron, then keep
# the indices of neurons judged reliable against the null at the chosen percentile.
null_srv_all_neurons = mouse.compute_null_all_neurons(n_shuffles=N_SHUFFLES)
real_srv_all_neurons = mouse.compute_real_srv_all_neurons()

reliable_neuron_indices = mouse.get_reliable_neuron_indices(
    null_srv_all_neurons,
    real_srv_all_neurons,
    percentile_threshold=RELIABILITY_PERCENTILE,
)

# Responses of the selected reliable neurons; the third return value is not
# used in this notebook.
neural_responses_mean, neural_responses, _ = mouse.get_responses_for_reliable_neurons(
    reliable_neuron_indices,
    real_srv_all_neurons,
    num_neurons=NUM_NEURONS,
)

# Diagnostic plots: the null distribution for one example neuron, and the
# real SRV distribution with the reliable neurons.
mouse.plot_null_distribution_for_neuron(null_srv_all_neurons, neuron_index=EXAMPLE_NEURON_INDEX)
mouse.plot_real_srv_distribution(real_srv_all_neurons, reliable_neuron_indices)

In [None]:
from cortexlib.images import CortexlabImages

# How many images `show_sample` should preview.
N_SAMPLE_IMAGES = 5

images = CortexlabImages()

# Plot the raw image for the first stimulus id. The id is cast to int before
# the call — presumably stimulus_ids is not an int dtype; confirm.
first_stimulus_id = int(mouse.stimulus_ids[0])
images.plot_raw_image(first_stimulus_id)

# Load the images corresponding to every stimulus id and preview a few.
image_dataset = images.load_images_shown_to_mouse(mouse.stimulus_ids)
images.show_sample(image_dataset, n=N_SAMPLE_IMAGES)

In [None]:
from cortexlib.simclr import PreTrainedSimCLRModel

# Encode the image dataset with a pre-trained SimCLR model. The first return
# value is consumed via `.items()` in the sweep cell (layer -> features); the
# second return value is discarded.
simclr = PreTrainedSimCLRModel()
(simclr_features, _) = simclr.extract_features(image_dataset)

In [None]:
from cortexlib.predictor import NeuralResponsePredictor
import matplotlib.pyplot as plt
import pandas as pd

# Number of principal components to reduce the image features to before
# fitting. NOTE(review): `None` is assumed to mean "no reduction, use all
# features" — confirm against NeuralResponsePredictor.
N_PCS_OPTIONS = [None, 10, 20, 50, 100, 200, 300, 400, 500]

# Sweep every SimCLR layer x PC setting, recording test R^2 and mean FEV for
# predicting the reliable neurons' responses.
results = []

for layer, feats in simclr_features.items():
    for n_pcs in N_PCS_OPTIONS:
        predictor = NeuralResponsePredictor(reduce_image_representation_to_n_pcs=n_pcs, neural_data_pc_index=0)
        r_squared = predictor.compute_r_squared(feats, neural_responses)
        fev = predictor.compute_fev(feats, neural_responses)

        results.append({
            'layer': layer,
            'n_pcs': n_pcs,
            'test_r2': r_squared['test_r2'],
            'mean_fev': fev['mean_fev']
        })

        print(f"SimCLR Layer: {layer}, Feats PCs: {n_pcs}, R^2: {r_squared['test_r2']:.4f} FEV: {fev['mean_fev']:.4f}")

# For each layer, plot mean FEV against the number of PCs.
# The `n_pcs=None` rows become NaN in the DataFrame and matplotlib silently
# skips NaN x-values, so that result is drawn as a horizontal reference line
# instead of being lost from the plot.
results_df = pd.DataFrame(results)
for layer, layer_results in results_df.groupby('layer', sort=False):
    is_unreduced = layer_results['n_pcs'].isna()
    reduced_results = layer_results[~is_unreduced].sort_values('n_pcs')

    fig, ax = plt.subplots()
    ax.plot(reduced_results['n_pcs'], reduced_results['mean_fev'], marker='o', label=f'{layer} (PCA-reduced)')
    for unreduced_fev in layer_results.loc[is_unreduced, 'mean_fev']:
        ax.axhline(unreduced_fev, linestyle='--', color='gray', label=f'{layer} (all features)')
    ax.set_xlabel('Number of PCs')
    ax.set_ylabel('Mean FEV')
    ax.set_title(f'FEV vs Number of PCs for {layer}')
    ax.legend()
    plt.show()
    # Free the figure; one is created per layer.
    plt.close(fig)