In [4]:
import torch
from torchvision.models import resnet50, ResNet50_Weights

In [13]:
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
model.fc = torch.nn.Linear(model.fc.in_features, 1)
torch.nn.init.constant_(model.fc.weight, 1e-4)
torch.nn.init.constant_(model.fc.bias, 0)

Parameter containing:
tensor([0.], requires_grad=True)

In [14]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [38]:
test_inp = torch.randn(1, 3, 224, 224)

In [23]:
test_out = model(test_inp)

In [24]:
test_out.shape

torch.Size([1, 1])

In [45]:
from torchvision.models import vit_b_16, ViT_B_16_Weights

features = {}

def get_features(name):
    def hook(model, x, output):
        features[name] = output
    return hook

In [46]:
vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
vit.heads.head = torch.nn.Linear(vit.heads.head.in_features, 1)
torch.nn.init.constant_(vit.heads.head.weight, 1e-4)
torch.nn.init.constant_(vit.heads.head.weight, 0)
vit.encoder.register_forward_hook(get_features('encoder'))

<torch.utils.hooks.RemovableHandle at 0x7587ea853350>

In [47]:
print(features)

{}


In [48]:
test_out = vit(test_inp)

In [50]:
print(features['encoder'].shape)

torch.Size([1, 197, 768])
