In [None]:
import torch
import pandas as pd
from scripts.no_training import (
    sample_images,
    predict_images,
    process_points_results, 
    log_results, 
    get_data, 
    process_shapes_results,
    predict_images_without_deciding
)

%matplotlib inline
import seaborn as sns
# Two-colour seaborn palette. Note the list order is gold/yellow (#FFD700)
# first and blue (#0000FF) second, despite the variable name; this matches the
# ["yellow", "blue"] sub-directory order used for the points images below.
blue_yellow = ["#FFD700", "#0000FF"]
# blue_yellow.reverse()  # uncomment to put blue first
sns.set_palette(blue_yellow)

# NOTE(review): `device` and `dtype` are not referenced by any cell visible in
# this notebook — confirm whether they are still needed.
device = torch.device("cpu")
dtype = torch.float32

# num_passes is forwarded to the predict_* helpers and pairs_num to
# sample_images in the cells below.
# NOTE(review): no random seed is set anywhere in view; if sample_images draws
# images at random, results will differ between runs — confirm.
num_passes = 4
pairs_num = 100

In [None]:
# Unpack the five objects returned by the project's get_data helper.
connections, shuffled_connections, all_neurons, neuron_data, all_coords = get_data()

# Right join on root_id keeps every row of all_coords; only the id and the
# cell type are retained, and missing values are labelled "Unknown".
neurons_in_coords = (
    all_neurons.merge(all_coords, on="root_id", how="right")
    .loc[:, ["root_id", "cell_type"]]
    .fillna("Unknown")
)

# Collapse every cell type that has fewer than `n` rows into one "others" label.
n = 20
counts = neurons_in_coords["cell_type"].value_counts()
small_categories = counts.index[(counts < n).to_numpy()]
is_rare = neurons_in_coords["cell_type"].isin(small_categories)
neurons_in_coords["cell_type"] = neurons_in_coords["cell_type"].mask(is_rare, "others")

In [None]:
# Sample images for the points task from the "yellow" and "blue" sub-folders
# of base_dir.
# NOTE(review): pairs_num presumably controls how many image pairs are drawn —
# confirm against sample_images.
base_dir = "images/zero_to_five/train"
sub_dirs = ["yellow", "blue"]
sampled_images = sample_images(base_dir, sub_dirs, pairs_num)

In [None]:
# Run the sampled images through the model using the unshuffled `connections`.
# NOTE(review): presumably returns a DataFrame with one row per neuron and one
# column per image — the groupby/transpose cells below rely on that; confirm.
predictions = predict_images_without_deciding(
    sampled_images, neuron_data, connections, all_coords, num_passes
)
# Attach a cell-type label to every row. Column assignment from a Series aligns
# on the index, so this assumes `predictions` and `neurons_in_coords` share the
# same row index — TODO confirm.
predictions["cell_type"] = neurons_in_coords["cell_type"]

In [None]:
# Average the predictions within each cell type, then transpose so that the
# original columns become rows and each cell type becomes a column.
means = predictions.groupby("cell_type").mean().T

# Each row label is split on "_" and its 2nd and 3rd fields are summed to give
# the total number of points for that row.
point_totals = []
for image_name in means.index:
    fields = image_name.split("_")
    point_totals.append(int(fields[1]) + int(fields[2]))
means["num_points"] = point_totals

In [None]:
# Correlate every column of `means` with "num_points" and rank the result from
# the most positive correlation to the most negative.
corr_matrix = means.corr()
correlations = corr_matrix["num_points"].sort_values(ascending=False)

In [None]:
# Display the ranked correlations with "num_points" computed in the previous cell.
correlations

In [None]:
# Transpose the predictions so that every original column becomes a row, then
# tag each row with the total number of points parsed from its label.
#
# The "cell_type" label column added earlier must be removed first: otherwise
# it becomes a row labelled "cell_type", whose name cannot be parsed
# (int("type") raises ValueError) and whose string values would make the
# transposed frame non-numeric. The drop is non-mutating, so `predictions`
# itself is left untouched and this cell can be re-run safely;
# errors="ignore" keeps it working if the column was already removed.
df = predictions.drop(columns=["cell_type"], errors="ignore").T
# Row labels are split on "_"; the 2nd and 3rd fields are summed.
df["num_points"] = [int(a.split("_")[1]) + int(a.split("_")[2]) for a in df.index]

In [None]:
# Mean of every column of `df` for each distinct num_points value.
means = df.groupby("num_points").mean()

# Min-max scale each column to the 0-1 range. A column with no spread divides
# 0 by 0 and becomes NaN; any column containing NaN is then discarded.
col_min = means.min()
col_range = means.max() - col_min
means = ((means - col_min) / col_range).dropna(axis=1)

In [None]:
# Hand-written reference curve that is highest for the first num_points group
# and decreases afterwards; the correlation cell below compares every column against it.
# NOTE(review): the list has exactly 5 entries, so this assumes `means` has
# exactly 5 num_points groups — pandas raises ValueError otherwise; confirm.
means["one_tuning_curve"] = [1, 0.4, 0.3, 0.2, 0.1]

In [None]:
# Transpose `means` and show only the rows whose value under label 1 equals the
# row maximum, i.e. the columns of `means` that peak at num_points == 1.
# NOTE(review): this assumes 1 is the first num_points group — confirm.
temp = means.T
row_peak = temp.max(axis=1)
temp.loc[temp[1] == row_peak]

In [None]:
# Correlate every column of `means` with the reference "one_tuning_curve"
# column, rank from most positive to most negative, and display the ranking.
tuning_corr_matrix = means.corr()
correlations = tuning_corr_matrix["one_tuning_curve"].sort_values(ascending=False)
correlations

## Shuffling

In [None]:
# Points task with `shuffled_connections` passed in place of `connections`.
# `sampled_images` is still the points sample drawn from
# "images/zero_to_five/train" earlier in the notebook.
predictions = predict_images(
    sampled_images,
    neuron_data,
    shuffled_connections,
    all_coords,
    all_neurons,
    num_passes,
)

In [None]:
# Post-process the shuffled-connection predictions for the points task and log them.
# NOTE(review): log_results presumably returns (figure, accuracy) — inferred
# only from the tuple unpacking and the names; confirm.
results = process_points_results(predictions)
fig, acc = log_results(results, "points", shuffled=True)
# A bare last expression lets the notebook render the figure as the cell output.
fig

In [None]:
# Report `acc` as returned by log_results for the shuffled points run above.
print(f"accuracy = {acc}")

# With shapes

In [None]:
# Shapes task: resample images from the "circle" and "triangle" sub-folders and
# predict with the unshuffled `connections`.
# Note that base_dir, sub_dirs, sampled_images and predictions are all rebound
# here, replacing the values used for the points task.
base_dir = "images/black_blue_80_110_jitter/train/two_shapes"
sub_dirs = ["circle", "triangle"]

sampled_images = sample_images(base_dir, sub_dirs, pairs_num)
predictions = predict_images(
    sampled_images, neuron_data, connections, all_coords, all_neurons, num_passes
)

In [None]:
# Post-process the shapes predictions and log them; unlike the points variant,
# process_shapes_results also takes the sampled images.
# NOTE(review): log_results presumably returns (figure, accuracy) — confirm.
results = process_shapes_results(predictions, sampled_images)
fig, acc = log_results(results, "shapes")

In [None]:
# Render the figure returned by log_results for the shapes run.
fig

In [None]:
# Report `acc` as returned by log_results for the shapes run above.
print(f"accuracy = {acc}")

# Reshuffled weights

In [None]:
# Shapes task repeated with `shuffled_connections` in place of `connections`;
# the same `sampled_images` from the shapes cell above are reused.
predictions = predict_images(
    sampled_images, neuron_data, shuffled_connections, all_coords, all_neurons, num_passes
)

In [None]:
# Post-process and log the shuffled-connection shapes predictions, then render
# the returned figure as the cell output.
# NOTE(review): log_results presumably returns (figure, accuracy) — confirm.
results = process_shapes_results(predictions, sampled_images)
fig, acc = log_results(results, "shapes", shuffled=True)
fig

In [None]:
# Report `acc` as returned by log_results for the shuffled shapes run above.
print(f"accuracy = {acc}")

In [None]:
# NOTE(review): late import plus an empty 1x2 figure that no visible cell draws
# on. It also rebinds `fig`, replacing the results figure from the cells above.
# This looks like leftover scratch work — confirm before keeping, and move the
# import to the top import cell if the figure is needed.
import matplotlib.pyplot as plt
fig, axes = plt.subplots(1, 2, figsize=(10, 6))

In [None]:
"_hola".replace("_", " ")