## Loading pretrained model weights
If the parameters load successfully, a message of the form `Loaded parameters from /PATH/TO/WEIGHTS` is printed.

Note: before running this notebook, change the `BASE_DIR` and `TORCH_HOME` directories so that they point to locations on your own machine.

In [1]:
%load_ext autoreload
%autoreload 2

import os

from mouse_vision.core.model_loader_utils import load_model
from mouse_vision.models.model_paths import MODEL_PATHS

In [2]:
def load_pretrained_model(model_name, model_family="imagenet", state_dict_key="model_state_dict"):
    """Load a trained model using the checkpoint path registered in ``MODEL_PATHS``.

    Parameters
    ----------
    model_name : str
        Key into ``MODEL_PATHS``; also passed to ``load_model`` as the model name.
    model_family : str, optional
        Forwarded to ``load_model`` (default ``"imagenet"``, the previously
        hard-coded value).
    state_dict_key : str, optional
        Key under which the weights are stored in the checkpoint file; make sure
        this key is present in the ``*.pt`` file (default ``"model_state_dict"``).

    Returns
    -------
    tuple
        ``(model, layers)`` exactly as returned by ``load_model``.

    Raises
    ------
    KeyError
        If ``model_name`` has no entry in ``MODEL_PATHS``.
    FileNotFoundError
        If the registered checkpoint path is not an existing file.
    """
    try:
        model_path = MODEL_PATHS[model_name]
    except KeyError:
        raise KeyError(
            f"No checkpoint path registered for {model_name!r} in MODEL_PATHS"
        ) from None

    # An explicit check instead of `assert`: it survives `python -O` and tells the
    # user which path to fix (e.g. after changing BASE_DIR / TORCH_HOME).
    if not os.path.isfile(model_path):
        raise FileNotFoundError(
            f"Checkpoint for {model_name!r} not found at {model_path!r}"
        )

    model, layers = load_model(
        model_name,
        trained=True,
        model_path=model_path,
        model_family=model_family,
        state_dict_key=state_dict_key,
    )

    return model, layers

### AlexNet (instance recognition)

In [3]:
name = "alexnet_bn_ir_64x64_input_pool_6"
model, model_layers = load_pretrained_model(name)
# Both headers use the same f-string form, so the model repr starts on its own
# line without the extra leading space that print's default " " separator added.
print(f"======= Model architecture =======\n{model}")
print(f"======= Model layers =======\n{model_layers}")

Loading alexnet_bn_ir_64x64_input_pool_6. Pretrained: True. Model Family: imagenet.
Loaded parameters from /home/nclkong/plos_mouse_vision/mouse-vision/model_ckpts/alexnet_bn_ir.pt
 AlexNetBN(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(384, 256, ker

### MouseNet of Shi et al. (instance recognition)

#### a) Original architecture

In [4]:
name = "shi_mousenet_ir"
# Keyword form makes explicit which argument selects the registered checkpoint.
model, model_layers = load_pretrained_model(model_name=name)

Loading shi_mousenet_ir. Pretrained: True. Model Family: imagenet.
Loaded parameters from /home/nclkong/plos_mouse_vision/mouse-vision/model_ckpts/shi_mousenet_ir.pt


#### b) Our variant

In [5]:
name = "shi_mousenet_vispor5_ir"
# Keyword form makes explicit which argument selects the registered checkpoint.
model, model_layers = load_pretrained_model(model_name=name)

Loading shi_mousenet_vispor5_ir. Pretrained: True. Model Family: imagenet.
Loaded parameters from /home/nclkong/plos_mouse_vision/mouse-vision/model_ckpts/shi_mousenet_vispor5_ir.pt


### Dual stream (instance recognition)

In [6]:
name = "simplified_mousenet_dual_stream_visp_3x3_ir"
# Keyword form makes explicit which argument selects the registered checkpoint.
model, model_layers = load_pretrained_model(model_name=name)

Loading simplified_mousenet_dual_stream_visp_3x3_ir. Pretrained: True. Model Family: imagenet.
Single stream set to False
Using {'type': 'BN'} normalization
Loaded parameters from /home/nclkong/plos_mouse_vision/mouse-vision/model_ckpts/dual_stream_ir.pt


### Six stream (SimCLR)

In [7]:
name = "simplified_mousenet_six_stream_visp_3x3_simclr"
# Keyword form makes explicit which argument selects the registered checkpoint.
model, model_layers = load_pretrained_model(model_name=name)

Loading simplified_mousenet_six_stream_visp_3x3_simclr. Pretrained: True. Model Family: imagenet.
Single stream set to False
Using {'type': 'SyncBN'} normalization
Loaded parameters from /home/nclkong/plos_mouse_vision/mouse-vision/model_ckpts/six_stream_simclr.pt
