In [80]:
from pathlib import Path

import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms

class MyModel(nn.Module):
  """ResNet-50 based image embedder.

  Runs every child layer of a torchvision ResNet-50 except the final ``fc``
  classifier, flattens the result, and average-pools it down to
  ``embedding_dim`` values per image. ``base_model`` is kept as an attribute
  (including its ``fc`` head), so the original classifier can still be called
  directly via ``model.base_model(...)``.

  Args:
    embedding_dim: Length of each output embedding. Defaults to 64, the value
      that was previously hard-coded.
    pretrained: Forwarded unchanged to ``torchvision.models.resnet50``.
      Defaults to True (previous behavior). Pass False to build the network
      without downloading weights, e.g. for quick shape checks.
      NOTE(review): newer torchvision releases deprecate ``pretrained`` in
      favor of ``weights=``; confirm the installed version before migrating.
  """

  def __init__(self, embedding_dim: int = 64, pretrained: bool = True):
    super().__init__()
    self.base_model = models.resnet50(pretrained=pretrained)
    self.feature_extractor = nn.Sequential(
        # All ResNet-50 children except the trailing ``fc`` layer. These are
        # the same module objects as in ``self.base_model``, so weights are
        # shared, not copied.
        *list(self.base_model.children())[:-1],
        nn.Flatten(),
        # After Flatten the tensor is 2-D (batch, features). AdaptiveAvgPool1d
        # treats 2-D input as unbatched (C, L), so each row is averaged down
        # to ``embedding_dim`` values independently.
        nn.AdaptiveAvgPool1d(embedding_dim),
    )

  def forward(self, x):
    """Preprocess a batch of images and return their embeddings.

    Args:
      x: Image tensor. Pixel values are divided by 255 below, so they are
        expected on a 0-255 scale. NOTE(review): presumably (N, 3, H, W) --
        the normalization uses 3 channel statistics; confirm with callers.

    Returns:
      For a batched input, a tensor of shape (N, embedding_dim).
    """
    # Resize to 256x256 (aspect ratio is not preserved), then take the central
    # 224x224 crop expected by ResNet-50.
    x = transforms.functional.resize(x,size=[256, 256])
    x = transforms.functional.center_crop(x, 224)

    # Scale to [0, 1] and normalize with the standard ImageNet statistics.
    x = x/255.0
    x = transforms.functional.normalize(x,
                                            mean=[0.485, 0.456, 0.406],
                                            std=[0.229, 0.224, 0.225])
    return self.feature_extractor(x)

# Build the embedder in inference mode. nn.Module.eval() returns the module
# itself, so chaining it keeps the assignment and leaves the cell without output.
model = MyModel().eval()

In [81]:
# model

In [82]:
# Shape sanity check on a dummy batch. torch.rand yields values in [0, 1),
# whereas forward() divides by 255, so only the shapes are meaningful here.
# no_grad skips building autograd graphs for this inference-only check.
with torch.no_grad():
  dummy_batch = torch.rand(2, 3, 500, 500)
  classifier_shape = model.base_model(dummy_batch).shape
  embedding_shape = model(dummy_batch).shape
print(classifier_shape)
print(embedding_shape)
assert tuple(classifier_shape) == (2, 1000), classifier_shape
assert tuple(embedding_shape) == (2, 64), embedding_shape

torch.Size([2, 1000])
torch.Size([2, 64])


In [83]:
# Destination of the TorchScript export.
# NOTE(review): on Kaggle, ../input is normally a read-only mount; confirm this
# location is writable (e.g. /kaggle/working) before relying on it.
SAVE_PATH = Path('../input/google-img-embed/resnet50_saved_model.pt')

saved_model = torch.jit.script(model)
# ScriptModule.save does not create missing directories, so ensure the parent exists.
SAVE_PATH.parent.mkdir(parents=True, exist_ok=True)
saved_model.save(str(SAVE_PATH))