In [1]:
import torch
from torchvision import models

class ResNet(torch.nn.Module):
    """Feature extractor built from a torchvision model with its last child removed.

    Args:
        net_name: Name of a constructor in ``torchvision.models`` (looked up in
            ``models.__dict__``), e.g. ``'resnet50'``.
        pretrained: Forwarded unchanged to the torchvision constructor.
        use_fc: When True, append a trainable ``Linear(in_features, 512)``
            projection after the flattened encoder output.

    ``forward`` returns the flattened encoder output, or the 512-dimensional
    projection of it when ``use_fc`` is True.
    """

    def __init__(self, net_name, pretrained=False, use_fc=False):
        super().__init__()
        base_model = models.__dict__[net_name](pretrained=pretrained)
        # Keep every child except the last one. For torchvision ResNets the
        # last child is presumably the `fc` classifier -- confirm before using
        # this wrapper with other architectures.
        self.encoder = torch.nn.Sequential(*list(base_model.children())[:-1])

        self.use_fc = use_fc
        if self.use_fc:
            # Size the projection from the backbone's own head instead of
            # assuming 2048, so narrower backbones work with use_fc=True.
            # Fall back to the previous constant when no such head is exposed.
            head = getattr(base_model, 'fc', None)
            in_features = getattr(head, 'in_features', 2048)
            self.fc = torch.nn.Linear(in_features, 512)

    def forward(self, x):
        """Encode ``x``, flatten everything after the batch dim, optionally project."""
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        if self.use_fc:
            x = self.fc(x)
        return x

device = torch.device('cpu')
model = ResNet('resnet50', pretrained=False, use_fc=False).to(device)

# load encoder
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model_path = 'resnet50_byol_imagenet2012.pth.tar'
checkpoint = torch.load(model_path, map_location=device)['online_backbone']
# The checkpoint's own key names are ignored: its tensors are paired with the
# encoder's keys purely by position, using only the first `length` entries.
# load_state_dict(strict=True) then rejects any missing or unexpected key.
length = len(model.encoder.state_dict())
state_dict = dict(zip(model.encoder.state_dict(), list(checkpoint.values())[:length]))
model.encoder.load_state_dict(state_dict, strict=True)
model.eval()

example = torch.ones(1, 3, 224, 224)

# convert to torch.jit.ScriptModule via tracing
traced_script_module = torch.jit.trace(model, example)
for p in traced_script_module.parameters():
    p.requires_grad = False

print(traced_script_module)

# Check the trace against the eager model BEFORE writing it to disk, so a bad
# trace never leaves an artifact behind. Compare with a tolerance rather than
# exact float equality.
with torch.no_grad():
    assert torch.allclose(model(example), traced_script_module(example)), \
        'traced module output differs from the eager model'

traced_script_module.save('resnet50_byol_imagenet2012.pt')



ResNet(
  original_name=ResNet
  (encoder): Sequential(
    original_name=Sequential
    (0): Conv2d(original_name=Conv2d)
    (1): BatchNorm2d(original_name=BatchNorm2d)
    (2): ReLU(original_name=ReLU)
    (3): MaxPool2d(original_name=MaxPool2d)
    (4): Sequential(
      original_name=Sequential
      (0): Bottleneck(
        original_name=Bottleneck
        (conv1): Conv2d(original_name=Conv2d)
        (bn1): BatchNorm2d(original_name=BatchNorm2d)
        (conv2): Conv2d(original_name=Conv2d)
        (bn2): BatchNorm2d(original_name=BatchNorm2d)
        (conv3): Conv2d(original_name=Conv2d)
        (bn3): BatchNorm2d(original_name=BatchNorm2d)
        (relu): ReLU(original_name=ReLU)
        (downsample): Sequential(
          original_name=Sequential
          (0): Conv2d(original_name=Conv2d)
          (1): BatchNorm2d(original_name=BatchNorm2d)
        )
      )
      (1): Bottleneck(
        original_name=Bottleneck
        (conv1): Conv2d(original_name=Conv2d)
        (bn1): 

In [4]:

class BYOL(torch.nn.Module):
    """Frozen ResNet-50 encoder loaded from a BYOL checkpoint, plus two trainable heads.

    Args:
        device: Passed to ``torch.load`` as ``map_location`` for the checkpoint.
            The module itself is not moved to this device. If a CUDA device is
            requested but CUDA is unavailable, the checkpoint is mapped to CPU.
        model_path: Path to the checkpoint file; it must contain an
            ``'online_backbone'`` entry.

    ``forward`` returns ``(linear_out, classifier_out)``, both computed from the
    same flattened encoder output: one through a single ``Linear(2048, 312)``
    and one through a small MLP ending in ``Linear(312, 312)``.
    """

    def __init__(self, device = 'cuda', model_path = 'models/resnet50_byol_imagenet2012.pth.tar'):
        super().__init__()
        # use_fc=True also creates `model.fc`, which forward() below never
        # calls; it is kept so this module's state_dict keys stay the same.
        model = ResNet('resnet50', pretrained=False, use_fc=True)

        # The default 'cuda' would make torch.load fail on a CPU-only machine,
        # so fall back to CPU there. Other map_location values pass through.
        map_location = device
        if (isinstance(device, (str, torch.device))
                and torch.device(device).type == 'cuda'
                and not torch.cuda.is_available()):
            map_location = 'cpu'

        # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=map_location)['online_backbone']
        # Checkpoint tensors are paired with the encoder's keys by position;
        # the checkpoint's own key names are ignored.
        length = len(model.encoder.state_dict())
        state_dict = dict(zip(model.encoder.state_dict(), list(checkpoint.values())[:length]))
        model.encoder.load_state_dict(state_dict, strict=True)
        model.eval()

        # Freeze the backbone directly. This replaces tracing the model with
        # torch.jit.trace and discarding the result just to reach its parameters.
        for p in model.parameters():
            p.requires_grad = False

        self.model = model
        self.linear1 = torch.nn.Linear(2048, 312)
        self.classifier1 = torch.nn.Sequential(torch.nn.Linear(2048, 312), torch.nn.ReLU(), torch.nn.BatchNorm1d(312), torch.nn.Linear(312, 312))

    def forward(self, x):
        """Encode ``x`` and return ``(linear_out, classifier_out)``."""
        x = self.model.encoder(x)
        x = torch.flatten(x, 1)
        classifier_out = self.classifier1(x)
        linear_out = self.linear1(x)
        return (linear_out, classifier_out)
model = BYOL()



In [34]:
# NOTE(review): the repr saved under this cell is a bare ResNet, not the BYOL
# wrapper assigned to `model` in the cell above, and the execution counts are
# out of order (In [34] here, In [31] in the next cell). Presumably stale
# output -- re-run with Restart Kernel -> Run All to confirm.
model

ResNet(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 2

In [31]:
import os
# List the current working directory (os.listdir() with no argument).
# NOTE(review): the first cell loads 'resnet50_byol_imagenet2012.pth.tar' from
# this directory, while the saved listing shows only a '.tar.gz' archive --
# presumably it has to be decompressed first; confirm.
os.listdir()

['__pycache__',
 'attributes.txt',
 'cub_dataset.py',
 'models',
 'test.ipynb',
 'training_scripts',
 'utils',
 'port_tf_pt.ipynb',
 'resnet50_byol_imagenet2012.pth.tar.gz']