# Loading the state-of-the-art v4 model

In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
from torchvision import transforms 

In [2]:
# Blank out the DataJoint connection settings in one go so that
# datajoint does not ask for credentials.
os.environ.update({"DJ_USER": "", "DJ_HOST": "", "DJ_PASS": ""})

In [3]:
# get model
from nnvision.models.trained_models.v4_data_driven import v4_multihead_attention_model as model

Connecting @:3306
datajoint connection not established, skipping model imports from nnfabrik tables


In [4]:
# bare expression as the last line of the cell: displays the model's architecture
model

Encoder(
  (core): SE2dCore(
    (_input_weights_regularizer): LaplaceL2norm(
      (laplace): Laplace()
    )
    (features): Sequential(
      (layer0): Sequential(
        (conv): Conv2d(1, 96, kernel_size=(9, 9), stride=(1, 1), bias=False)
        (norm): BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0, inplace=True)
      )
      (layer1): Sequential(
        (ds_conv): DepthSeparableConv2d(
          (in_depth_conv): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (spatial_conv): Conv2d(96, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=96, bias=False)
          (out_depth_conv): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (norm): BatchNorm2d(96, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0, inplace=True)
      )
      (layer2): Sequential(
        (ds_conv): DepthSeparableConv2d(
          (i

## Model predictions

In [5]:
# By default the model is instantiated in "train" mode, so switch it to
# evaluation mode before computing predictions.
# The trailing ";" suppresses the cell's output.
model.eval();

In [6]:
# The model was trained on 100 x 100 images and can only handle inputs of that size.
# Build a single all-zero input with dimensions (batch_size, 1, 100, 100).
example_input = torch.zeros(size=(1, 1, 100, 100))

In [7]:
# Inference only: no_grad() disables gradient tracking for the forward pass.
with torch.no_grad():
    # data_key is forwarded to the model; presumably it selects which readout
    # (here the one covering all recording sessions) is used — TODO confirm
    responses = model(example_input, data_key="all_sessions")

In [8]:
# predicted responses for the all-zero example input
print(responses)

tensor([[0.5138, 0.1225, 0.4294,  ..., 0.3131, 0.1410, 0.5676]])


In [9]:
# (batch_size, n_neurons): one predicted response for each of the 1244 neurons
responses.shape

torch.Size([1, 1244])

## Input transformations

In [10]:
tform = transforms.Compose([
    # uint8 (H, W) array -> PIL image; PIL-based transforms (e.g. resizing)
    # can be inserted after this step if needed
    transforms.ToPILImage(),
    # PIL image -> float tensor of shape (1, H, W); 8-bit pixel values are
    # rescaled from [0, 255] to [0, 1], which is the range Normalize expects
    transforms.ToTensor(),
    # images are grayscale (one channel); mean and std are taken from our actual monkey data
    transforms.Normalize([0.4876], [0.2756]),
])

# Random 8-bit grayscale test image drawn uniformly from the full [0, 255] range.
# (Casting np.random.randn() output to uint8 would truncate almost every sample
# to 0 and leave the result for negative samples platform-dependent.)
image = np.random.randint(0, 256, size=(100, 100), dtype=np.uint8)
input_tensor = tform(image)

In [11]:
# the transformed image carries a leading channel dimension: (1, 100, 100)
input_tensor.shape

torch.Size([1, 100, 100])

In [12]:
with torch.no_grad():
    # the model always expects dimensions (batch_size, 1, 100, 100);
    # unsqueeze(0) prepends the missing batch dimension
    responses = model(input_tensor.unsqueeze(0), data_key="all_sessions")

In [13]:
# shape of the predictions for the single transformed image
print(responses.shape)

torch.Size([1, 1244])


## Using the GPU

In [14]:
# Move the model to the GPU when one is available.
# is_available has to be *called*: the bare function object is always truthy,
# so without the parentheses model.cuda() would also run on CPU-only machines.
if torch.cuda.is_available():
    model = model.cuda()

In [15]:
# The input has to live on the same device as the model. Reading the device off
# the model's parameters (instead of calling .cuda() unconditionally) keeps this
# cell working on CPU-only machines as well.
device = next(model.parameters()).device
with torch.no_grad():
    # the model always expects dimensions (batch_size, 1, 100, 100)
    responses = model(input_tensor.unsqueeze(0).to(device), data_key="all_sessions")