In [None]:
import pandas as pd
import torch
from tqdm import tqdm
from numbasom.core import lattice_closest_vectors

from analysis.som import SOM
from utils import save_output_to_pickle

from models.olfaction.olfaction import Olfaction
from models.vision.vision import Vision
from models.audio.audio import Audio
from models.touch.touch import Touch
from models.memory.memory import Memory

from config import Args

In [None]:
# SOM learning-rate sweep. For each learning rate, every modality goes through:
# model setup -> per-patch activations -> SOM fit -> BMU lookup -> pickle.
# Requires the imports cell (including `pandas as pd`, used for the progress timestamp).
SOM_LR_SWEEP = [0.01, 0.1, 0.2]
MODALITY_CLASSES = [Olfaction, Vision, Audio, Touch, Memory]

for lr in SOM_LR_SWEEP:
    args = Args()
    args.setup_models = False  # Set to True if first ever time running
    args.experiment_name = "main_results_lr_sweep"
    args.som_lr = lr
    # NOTE(review): experiment_name does not include `lr`; confirm that
    # save_output_to_pickle writes a distinct file per lr/modality, otherwise
    # later sweep values may overwrite earlier results.

    # Build each modality only when it is processed (rather than all five up
    # front) so earlier modalities are not all kept alive for the whole sweep.
    for modality_cls in MODALITY_CLASSES:
        modality = modality_cls(args)
        print("\n", modality.modality, pd.Timestamp.now())

        # 1. Train/Download Model
        modality.setup_model()
        modality.setup_som()

        # 2. Get activations for each patch. Each entry is [coord, activation];
        # the first value returned by generate_static is used below as the
        # patch coordinate.
        patches = modality.get_patches()
        activation_list = []
        for patch in tqdm(patches):
            coord, static = modality.generate_static(patch)
            activation = modality.calculate_activations(static)
            activation_list.append([coord, activation])
        if not activation_list:
            # torch.stack on an empty list fails with an unhelpful message.
            raise ValueError(f"{modality.modality}: get_patches() returned no patches")

        # 3. Fit SOM. detach()/cpu() make .numpy() safe if the activations
        # carry autograd history or live on an accelerator (no-op otherwise).
        x_mat = torch.stack([act for _, act in activation_list]).detach().cpu().numpy()
        som = modality.initialize_som(SOM)
        lattice = som.train(
            x_mat,
            num_iterations=args.som_epochs,
            initialize=args.som_init,
            normalize=False,
            start_lrate=args.som_lr,
        )

        # 4. Get coordinates for each BMU
        coordinate_list = [coord for coord, _ in activation_list]
        closest = lattice_closest_vectors(x_mat, lattice, additional_list=coordinate_list)

        # 5. Save. The SOM object itself is not stored ("som": None); its lattice is.
        output = {
            "closest": closest,
            "coord_map": coordinate_list,
            "x_range": (0, max(c[0] for c in coordinate_list)),
            "y_range": (0, max(c[1] for c in coordinate_list)),
            "lattice": lattice,
            "som": None,
            "samples": modality.sample_data,
            "modality": modality.modality,
            "args": args,
            "activations": activation_list,
        }
        save_output_to_pickle(output, args.experiment_name)

In [2]:
import torch

# Load onto the CPU so the file opens on machines without the device it was
# saved from; end with the shape so the recorded output is reproducible.
imgs = torch.load("models/vision/saved_data/processed_images.pt", map_location="cpu")
imgs.shape

torch.Size([100, 3, 224, 224])

In [8]:
# Keep only the first 50 images (requires `imgs` from the load cell) and
# overwrite the source file to shrink it. `.clone()` copies the slice into its
# own storage so torch.save does not also serialize the full backing buffer of `imgs`.
imgs_subset = imgs[0:50].clone().contiguous()
torch.save(obj=imgs_subset, f="models/vision/saved_data/processed_images.pt")

In [7]:

# Confirm the trimmed vision tensor's shape (recorded output: 50 x 3 x 224 x 224).
imgs_subset.size()

torch.Size([50, 3, 224, 224])

In [18]:
# Load the olfaction validation data onto the CPU regardless of the saving device.
data = torch.load(
    "models/olfaction/saved_data/val_dataset.pt",
    map_location=torch.device("cpu"),
)

In [17]:
# Re-save the olfaction validation data as a compact standalone copy.
# Assumes `data` (from the olfaction load cell) is a tensor — it must support
# .clone(); cloning gives it its own storage before torch.save.
# Uses a dedicated name so the vision `imgs_subset` is not clobbered.
olfaction_val_compact = data.clone().contiguous()
torch.save(olfaction_val_compact, "models/olfaction/saved_data/val_dataset.pt")


In [16]:
# List files and on-disk sizes in the working directory (e.g. to check .pt sizes after re-saving).
!ls -l

total 3280
-rw-r--r--  1 matthewkielo  staff     2745 Feb  7 00:53 README.md
drwxr-xr-x  6 matthewkielo  staff      192 Feb  7 00:53 [34manalysis[m[m
-rw-r--r--  1 matthewkielo  staff     2110 Feb  7 00:57 config.py
-rw-r--r--  1 matthewkielo  staff  1639511 Feb  7 01:02 foobar.pt
-rw-r--r--  1 matthewkielo  staff     2157 Feb  7 00:53 main.py
drwxr-xr-x  9 matthewkielo  staff      288 Feb  7 00:53 [34mmodels[m[m
-rw-r--r--  1 matthewkielo  staff     4260 Feb  7 00:57 quickstart.ipynb
-rw-r--r--  1 matthewkielo  staff     7420 Feb  7 00:53 requirements.txt
-rw-r--r--  1 matthewkielo  staff      689 Feb  7 00:53 setup.py
-rw-r--r--  1 matthewkielo  staff      507 Feb  7 00:53 utils.py


In [None]:
# NOTE(review): scratch byte counts. 1_639_511 matches the size of foobar.pt in
# the `ls -l` listing; the origin of the second figure is not shown here.
1_639_511
163_841_200