#### TODO:
1. Work on Assembly in "# get activations" part instead of the StimulusSet
2. Translate to ``behavior.py`` and the test.

In [16]:
import functools

import numpy as np

import brainscore_vision
from brainio.assemblies import DataAssembly, BehavioralAssembly, walk_coords
from brainscore_vision.benchmark_helpers.screen import place_on_screen
from brainscore_vision.model_helpers.activations import PytorchWrapper
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment

In [17]:
def pytorch_custom():
    import torch
    from torch import nn
    from brainscore_vision.model_helpers.activations.pytorch import load_preprocess_images

    class MyModel(nn.Module):
        def __init__(self):
            super(MyModel, self).__init__()
            np.random.seed(0)
            torch.random.manual_seed(0)
            self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=2, kernel_size=3)
            self.relu1 = torch.nn.ReLU()
            linear_input_size = np.power((224 - 3 + 2 * 0) / 1 + 1, 2) * 2
            self.linear = torch.nn.Linear(int(linear_input_size), 1000)
            self.relu2 = torch.nn.ReLU()

        def forward(self, x):
            x = self.conv1(x)
            x = self.relu1(x)
            x = x.view(x.size(0), -1)
            x = self.linear(x)
            x = self.relu2(x)
            return x

    preprocessing = functools.partial(load_preprocess_images, image_size=224)
    return PytorchWrapper(model=MyModel(), preprocessing=preprocessing)

In [81]:
def calculate_similarity_matrix(features, similarity_measure='dot'):
   print(features, type(features))
   features = features.transpose('presentation', 'neuroid')
   values = features.values
   if similarity_measure == 'dot':
      similarity_matrix = np.dot(values, np.transpose(values))
   elif similarity_measure == 'cosine':
      row_norms = np.linalg.norm(values, axis=1).reshape(-1, 1)
      norm_product = np.dot(row_norms, row_norms.T)
      dot_product = np.dot(values, np.transpose(values))
      similarity_matrix = dot_product / norm_product
   else:
      raise ValueError(
      f"Unknown similarity_measure {similarity_measure} -- expected one of 'dot' or 'cosine'")

   similarity_matrix = DataAssembly(similarity_matrix, coords={
        **{f"{coord}_left": ('presentation_left', values) for coord, _, values in
           walk_coords(features['presentation'])},
        **{f"{coord}_right": ('presentation_right', values) for coord, _, values in
           walk_coords(features['presentation'])}
   }, dims=['presentation_left', 'presentation_right'])
   return similarity_matrix

In [82]:
def calculate_choices(similarity_matrix, triplets):
    triplets = np.array(triplets).reshape(-1, 3)
    choice_predictions = []
    for triplet in triplets:
        i, j, k = triplet
        sims = similarity_matrix[i, j], similarity_matrix[i, k],  similarity_matrix[j, k]
        idx = triplet[2 - np.argmax(sims)]
        choice_predictions.append(idx)
    # TODO return as DataAssembly
    return choice_predictions

In [20]:
import numpy as np
from brainscore_vision import load_stimulus_set, load_dataset 

assembly = load_dataset('Hebart2023')
stimulus_set = load_stimulus_set("Hebart2023")
triplets = np.array([
    assembly.coords["image_1"].values,
    assembly.coords["image_2"].values,
    assembly.coords["image_3"].values
]).T.reshape(-1, 1)

triplets = np.array([f"{triplet[0]}.jpg" for triplet in triplets])

In [21]:
# create model
activations_model = pytorch_custom()
layers = ["relu2"]

# create brain model
brain_model = ModelCommitment(
    identifier=activations_model.identifier, 
    activations_model=activations_model, 
    layers=[None], 
    behavioral_readout_layer='relu2')

# get activations
assy = brainscore_vision.load_dataset(f'Hebart2023')
stimuli = place_on_screen(
    stimulus_set=assy.stimulus_set,
    target_visual_degrees=brain_model.visual_degrees(),
    source_visual_degrees=8)

In [23]:
stimuli

Unnamed: 0,stimulus_id,top_down_1,rank,Wordnet_ID4,unique_id,example_image,top_down_2,filename,Wordnet_ID2,dispersion,bottom_up,word_freq,dominant_part,freq_1,WordNet_synonyms,freq_2,WordNet_ID,Wordnet_ID3,word_freq_online
0,0,animal,51507.0,aardvark.n.01,aardvark,https://imgur.com/LAJGlN0,animal,0.jpg,aardvark%1:05:00::,0.78,animal,28.0,Noun,,"aardvark, ant_bear, anteater, Orycteropus_afer",21.0,n02082791,aardvark#1,53
1,1,,34578.0,abacus.n.02,abacus,https://imgur.com/peZeM0l,home decor,1.jpg,abacus%1:06:00::,0.86,,97.0,Noun,,abacus,12.0,n02666196,abacus#2,188
2,2,musical instrument,15132.0,accordion.n.01,accordion,https://imgur.com/GgGvdZR,musical instrument,2.jpg,accordion%1:06:00::,0.90,musical instrument,735.0,Noun,,"accordion, piano_accordion, squeeze_box",67.0,n02672831,accordion#1,816
3,3,fruit,16007.0,acorn.n.01,acorn,https://imgur.com/YfIB5lM,,3.jpg,acorn%1:20:00::,0.85,,692.0,Noun,238.0,acorn,37.0,n12267677,acorn#1,1289
4,4,,,air_conditioner.n.01,air_conditioner,https://imgur.com/KqYNwWH,electronic device,4.jpg,air_conditioner%1:06:00::,,,,,,"air_conditioner, air_conditioning",0.0,n02686379,air_conditioner#1,943
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1849,1849,,16792.0,yoke.n.07,yoke,https://imgur.com/nOt3K3f,,1849.jpg,yoke%1:06:00::,0.90,,597.0,Noun,143.0,yoke,22.0,n04612840,yoke#7,692
1850,1850,food,11647.0,yolk.n.01,yolk,https://imgur.com/gWY0jPO,food,1850.jpg,yolk%1:13:00::,0.89,,1224.0,Noun,108.0,"egg_yolk, yolk",21.0,n07841345,yolk#1,881
1851,1851,animal,14397.0,zebra.n.01,zebra,https://imgur.com/xg5AAHb,animal,1851.jpg,zebra%1:05:00::,0.87,animal,839.0,Noun,224.0,zebra,128.0,n02391049,zebra#1,1066
1852,1852,,10687.0,zipper.n.01,zipper,https://imgur.com/T2RLBxe,,1852.jpg,zipper%1:06:00::,0.88,,1452.0,Noun,62.0,"slide_fastener, zip, zipper, zip_fastener",144.0,n04238321,zipper#1,1478


In [24]:
# determine unique stimuli with numpy
unique_stimuli = np.unique(triplets)
features = activations_model(unique_stimuli, layers=layers)
features = features.transpose('presentation', 'neuroid')

activations:   0%|          | 0/1856 [00:00<?, ?it/s]

FileNotFoundError: [Errno 2] No such file or directory: '1146.jpg'

In [None]:
# TODO
assy = brainscore_vision.load_dataset(f'Hebart2023')
triplets = np.array([assy['image_1'], assy['image_2'], assy['image_3']]).T
triplets = triplets.reshape(-1)
sample = triplets[:10*3]
sim = calculate_similarity_matrix(features, similarity_measure='cosine')
choices = calculate_choices(similarity_matrix=sim, triplets=sample)