In [7]:
import torch
from torchvision.models.resnet import Bottleneck, ResNet

from torchvision import models


In [8]:
class ResNetTrunk(ResNet):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        del self.fc  # remove FC layer (...)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

def get_pretrained_url(key):
    URL_PREFIX = "https://github.com/lunit-io/benchmark-ssl-pathology/releases/download/pretrained-weights"
    model_zoo_registry = {
        "BT": "bt_rn50_ep200.torch",
        "MoCoV2": "mocov2_rn50_ep200.torch",
        "SwAV": "swav_rn50_ep200.torch",
    }
    pretrained_url = f"{URL_PREFIX}/{model_zoo_registry.get(key)}"
    return pretrained_url

def resnet50(pretrained, progress, key, **kwargs):
    model = ResNetTrunk(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        pretrained_url = get_pretrained_url(key)
        verbose = model.load_state_dict(
            torch.hub.load_state_dict_from_url(pretrained_url, progress=progress)
        )
        print(verbose)
    return model


In [9]:
model_tcga = resnet50(pretrained=True, progress=False, key="SwAV")

<All keys matched successfully>


In [17]:
dummy_input = torch.randn(1, 3, 256, 256) 
out = model_tcga(dummy_input)
out.shape #

torch.Size([1, 2048, 16, 16])

In [14]:

model = models.resnet50(pretrained=False)
model = torch.nn.Sequential(*list(model.children())[:-2]) #this matches! 

output = model(dummy_input)
output.shape

torch.Size([1, 2048, 8, 8])

In [16]:
print(8*8*2048)
print(3*256*256)

131072
196608
