<a href="https://colab.research.google.com/github/mkielo3/mammalian_brains/blob/main/quickstart.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import sys
if 'google.colab' in sys.modules:
	%pip install -q numpy pandas scipy scikit-learn matplotlib seaborn pillow ipykernel torch torchvision torchaudio altair datasets numba numbasom==0.0.5 tensorboard pysal vl-convert-python soundfile

	!git clone https://github.com/mkielo3/mammalian_brains.git
	%cd mammalian_brains

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.6/56.6 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m61.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m47.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m40.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m11.5 MB/s[0m eta [36m0

In [2]:
import random

import numpy as np
import torch
import pandas as pd

from tqdm import tqdm
from numbasom.core import lattice_closest_vectors
from analysis.som import SOM
from utils import save_output_to_pickle, save_som_plot, save_rf_plot
from models.olfaction.olfaction import Olfaction
from models.vision.vision import Vision
from models.audio.audio import Audio
from models.touch.touch import Touch
from models.memory.memory import Memory

from config import Args

# --- Configuration -----------------------------------------------------------
# `fast` shrinks the SOM lattice and its training budget for a quick end-to-end
# run. Set it to False for the full-size run; results are written under a
# different experiment name, so the two runs do not overwrite each other.
SEED = 0
fast = True

# Seed the Python, NumPy and PyTorch global RNGs so reruns are comparable.
# NOTE(review): Numba-jitted code keeps its own RNG state, which np.random.seed()
# called from plain Python does not touch. If the SOM initialisation runs inside
# numba, it is still unseeded — TODO confirm.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

args = Args()
args.som_size = (5,5) if fast else (25, 25)
args.experiment_name = "main_results_fast" if fast else "main_results"
args.fast = fast
args.som_epochs = 10000 if fast else 100000

# One wrapper per modality; the pipeline cell processes them in this order.
modality_list = [Olfaction(args), Vision(args), Audio(args), Touch(args), Memory(args)]

In [3]:
# For each modality: build the model, probe it once per patch, fit a SOM on the
# collected activations, and pickle everything the plotting cells need.
for modality in modality_list:
	print("\n", modality.modality, pd.Timestamp.now())

	# 1. Train/download the model, then run the modality-specific SOM setup.
	modality.setup_model()
	modality.setup_som()

	# 2. Probe the model once per patch. The first value returned by
	#    generate_static is kept as that patch's coordinate.
	activation_list = []
	patches = modality.get_patches()
	for patch in tqdm(patches):
		coord, static = modality.generate_static(patch)
		activation_list.append([coord, modality.calculate_activations(static)])

	coordinate_list = [coord for coord, _ in activation_list]

	# 3. Stack the activations (leading axis = patch) and fit the SOM on them.
	x_mat = torch.stack([act for _, act in activation_list]).numpy()
	som = modality.initialize_som(SOM)
	lattice = som.train(
		x_mat,
		num_iterations=args.som_epochs,
		initialize=args.som_init,
		normalize=False,
		start_lrate=args.som_lr,
	)

	# 4. Find the closest lattice entries, passing the patch coordinates along.
	closest = lattice_closest_vectors(x_mat, lattice, additional_list=coordinate_list)

	# 5. Persist. Only the trained lattice is stored; the SOM object itself is not.
	output = {
		"closest": closest,
		"coord_map": coordinate_list,
		"x_range": (0, max(coord[0] for coord in coordinate_list)),
		"y_range": (0, max(coord[1] for coord in coordinate_list)),
		"lattice": lattice,
		"som": None,
		"samples": modality.sample_data,
		"modality": modality.modality,
		"args": args,
		"activations": activation_list,
	}
	save_output_to_pickle(output, args.experiment_name)



  model.load_state_dict(torch.load(f"{self.project_path}/models/olfaction/saved_data/model.pt", map_location=torch.device('cpu')))
  data = torch.load(f"{self.project_path}/models/olfaction/saved_data/val_dataset.pt", map_location=torch.device('cpu'))



 olfaction 2025-02-12 17:07:52.488785


100%|██████████| 4096/4096 [00:05<00:00, 696.86it/s]


Initializing SOM with Random
Done Init
SOM training took: 3.899162 seconds.
Finding closest data points took: 0.208529 seconds.
Saved output to: output/main_results_fast/olfaction.pkl

 vision 2025-02-12 17:08:02.884584


Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /root/.cache/torch/hub/v0.10.0.zip
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 155MB/s]
  model.load_state_dict(torch.load(f"{self.project_path}/models/vision/saved_data/resnet_model.pt"))
  self.sample_data = torch.load(f"{self.project_path}/models/vision/saved_data/processed_images.pt")
100%|██████████| 225/225 [00:21<00:00, 10.42it/s]


Initializing SOM with Random
Done Init
SOM training took: 11.256604 seconds.
Finding closest data points took: 0.231657 seconds.
Saved output to: output/main_results_fast/vision.pkl

 audio 2025-02-12 17:08:37.662170


100%|██████████| 106k/106k [00:00<00:00, 21.6MB/s]
Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_base_ls960_asr_ls960.pth" to /root/.cache/torch/hub/checkpoints/wav2vec2_fairseq_base_ls960_asr_ls960.pth
100%|██████████| 360M/360M [00:01<00:00, 281MB/s]
  model.load_state_dict(torch.load(f"{self.project_path}/models/audio/saved_data/wav2vec2_model.pt"))
  data = torch.load(f"{self.project_path}/models/audio/saved_data/waveform.pt").to('cpu')
100%|██████████| 205/205 [01:41<00:00,  2.02it/s]


Initializing SOM with Random
Done Init
SOM training took: 9.764035 seconds.
Finding closest data points took: 0.151955 seconds.
Saved output to: output/main_results_fast/audio.pkl

 touch 2025-02-12 17:10:34.603244
[TouchDataset] Refreshing tuples...


  model.model.load_state_dict(torch.load(f"{self.project_path}/models/touch/saved_data/model.pt", map_location=torch.device('cpu')))
  data = torch.load(f"{self.project_path}/models/touch/saved_data/val_dataset.pt", map_location=torch.device('cpu'))


Loaded "ObjectClusterDataset" - split "test" with 16119 records...
[TouchDataset] Refreshing tuples...
Base LR = 1.000000e-03
tensor(0.5800)


100%|██████████| 548/548 [00:03<00:00, 167.47it/s]


Initializing SOM with Random
Done Init
SOM training took: 3.444031 seconds.
Finding closest data points took: 0.181586 seconds.
Saved output to: output/main_results_fast/touch.pkl

 memory 2025-02-12 17:10:46.784998


100%|██████████| 4455/4455 [00:00<00:00, 77794.53it/s]


Initializing SOM with Random
Done Init
SOM training took: 0.199127 seconds.
Finding closest data points took: 0.081490 seconds.
Saved output to: output/main_results_fast/memory.pkl


In [4]:
# Plot the SOM results for every modality in modality_list.
# NOTE(review): presumably reads the pickles written by the pipeline cell under
# the same args.experiment_name — confirm in utils.save_som_plot.
save_som_plot(args.experiment_name, modality_list, args)

In [5]:
# Plot per-modality "rf" figures (presumably receptive fields — TODO confirm in
# utils.save_rf_plot) for the results saved under args.experiment_name.
save_rf_plot(args.experiment_name, modality_list)

range(0, 5)
