### Warning: This notebook uses the full `spacetorch` installation instead of the src/ demo code!

In [1]:
from pathlib import Path

import torch
import torch.nn as nn

from vissl.config import AttrDict
from vissl.models.heads.linear_eval_mlp import LinearEvalMLP

from spacetorch.datasets.imagenet import imagenet_validation_performance
from spacetorch.models.trunks.resnet import VisslResNet
from spacetorch.paths import CHECKPOINT_DIR

** fvcore version of PathManager will be deprecated soon. **
** Please migrate to the version in iopath repo. **
https://github.com/facebookresearch/iopath 



## Construct the path to the linear-eval checkpoint

In [2]:
# Directory under CHECKPOINT_DIR that holds the linear-evaluation checkpoints.
LIN_EVAL_DIR = CHECKPOINT_DIR / "linear_eval"
weight_dir = Path(
    LIN_EVAL_DIR / "relu_rescue__simclr_spatial_resnet18_swappedon_SineGrating2019_isoswap_3_linear_eval_checkpoints"
)
weight_path = weight_dir / "model_final_checkpoint_phase27.torch"
# Fail early, and name the missing file, rather than erroring later inside torch.load.
assert weight_path.is_file(), f"Linear-eval checkpoint not found: {weight_path}"

In [3]:
# Load the checkpoint with its tensors mapped onto the CPU, then keep only the nested
# entry at 'classy_state_dict' -> 'base_model' -> 'model'. Later cells read the
# 'trunk' and 'heads' entries of this mapping as state dicts.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
ckpt = torch.load(weight_path, map_location='cpu')['classy_state_dict']['base_model']['model']

## Reconstruct trunk (self-supervised) and head (linear readout)

In [4]:
# Model config consumed by both the trunk and the head constructors.
# TRUNK selects the "custom_resnet" trunk with a resnet18 architecture;
# HEAD carries the batch-norm and parameter-multiplier settings.
model_config = AttrDict(
    dict(
        TRUNK=dict(
            NAME="custom_resnet",
            TRUNK_PARAMS=dict(
                VisslResNet=dict(ARCH="resnet18"),
            ),
        ),
        HEAD=dict(
            BATCHNORM_EPS=1e-5,
            BATCHNORM_MOMENTUM=0.1,
            PARAMS_MULTIPLIER=1.0,
        ),
    )
)

In [5]:
# Build the "resnet18" trunk from the config and restore its weights from the
# checkpoint's 'trunk' entry.
trunk = VisslResNet(model_config, "resnet18")
# load_state_dict returns a status object; printing it shows whether all keys matched.
load_status = trunk.load_state_dict(ckpt['trunk'])
print(load_status)

<All keys matched successfully>


In [6]:
# Linear readout head, built without batch norm and without ReLU.
# NOTE(review): in_channels=512 presumably matches the resnet18 trunk's feature size and
# dims[-1]=1000 the number of ImageNet classes -- confirm against the trunk/head definitions.
head = LinearEvalMLP(
    model_config=model_config, 
    in_channels=512, 
    dims=[512, 1000], 
    use_bn=False, 
    use_relu=False
)

# Helper used to remove the leading "0." from the checkpoint's head state-dict keys.
def remove_prefix(key, prefix: str = "0."):
    """Return ``key`` without a leading ``prefix``.

    Args:
        key: String to strip (a state-dict key when used in this notebook).
        prefix: Text to remove when ``key`` starts with it.

    Returns:
        ``key`` with one leading occurrence of ``prefix`` removed, or ``key``
        unchanged when it does not start with ``prefix``.
    """
    return key[len(prefix):] if key.startswith(prefix) else key

# Re-key the checkpoint's head weights without the leading "0." so the names
# line up with the freshly built head, then load them and report the result.
modified_head_params = {
    remove_prefix(param_name, prefix="0."): param_value
    for param_name, param_value in ckpt["heads"].items()
}
load_status = head.load_state_dict(modified_head_params)
print(load_status)

<All keys matched successfully>


## Create a combined model by fusing trunk and head

In [7]:
class CombinedModel(nn.Module):
    """Chain a trunk and a head into a single module.

    The input is passed through ``trunk`` and the result is fed to ``head``.
    """

    def __init__(self, trunk: nn.Module, head: nn.Module):
        """Register ``trunk`` and ``head`` as submodules of this model."""
        super().__init__()
        self.trunk = trunk
        self.head = head

    def forward(self, x, **trunk_kwargs):
        """Run ``x`` through the trunk and then the head.

        Args:
            x: Input passed to the trunk.
            **trunk_kwargs: Extra keyword arguments forwarded to the trunk only.

        Returns:
            The head's output.
        """
        features = self.trunk(x, **trunk_kwargs)
        # When the trunk returns a tuple or list, only its first element goes to the head.
        if isinstance(features, (tuple, list)):
            features = features[0]
        return self.head(features)

In [8]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [9]:
# Put both modules in evaluation mode, turn off gradient tracking for their
# parameters, and move them to DEVICE.
trunk = trunk.eval().requires_grad_(False).to(DEVICE)
head = head.eval().requires_grad_(False).to(DEVICE)

In [10]:
# Wrap the trunk and head in a single module that maps inputs to the head's outputs.
combined_model = CombinedModel(trunk=trunk, head=head)

## Get logits for some fake inputs

In [11]:
# Smoke test: push a small batch of random inputs through the combined model
# and check the shape of the result.
batch_size, in_channels, height, width = 5, 3, 224, 224
# Allocate the batch directly on DEVICE instead of creating it on the CPU and copying it over.
inputs = torch.rand(batch_size, in_channels, height, width, device=DEVICE)
outputs = combined_model(inputs)

print(f"{outputs.shape=}")

outputs.shape=torch.Size([5, 1000])


## Test on real ImageNet images

In [12]:
# Evaluate the combined model with spacetorch's ImageNet validation helper.
# NOTE(review): presumably `output_layer="head"` names the submodule whose output is scored,
# and the run covers n_batches * batch_size = 6400 images -- confirm against the helper.
top1 = imagenet_validation_performance(
    model=combined_model,
    output_layer="head",
    batch_size=64,
    n_batches=100
)

batch: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:03<00:00,  1.57it/s]


In [13]:
# Print the result as `top1=<value>` with two decimal places.
# NOTE(review): presumably top-1 accuracy expressed as a fraction -- confirm against the helper.
print(f"{top1=:.2f}")

top1=0.42
