In [3]:
import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms

class MyModel(nn.Module):
  """Pretrained Inception v3 wrapped as an image-embedding extractor.

  The backbone's classification layer is replaced by a new ``nn.Linear``
  projection, so the module maps a batch of images to
  ``embedding_dim``-dimensional vectors. All preprocessing (resize,
  scaling by 1/255, ImageNet mean/std normalisation) happens inside
  ``forward`` so the exported module can be fed raw pixel tensors.

  Note: the projection layer is newly created here and is not covered by
  the pretrained weights; it stays at its default initialisation unless
  the model is trained afterwards.

  Args:
    embedding_dim: Size of the output embedding. Defaults to 64.
    input_size: Side length (in pixels) images are resized to before
      entering the backbone. Defaults to 299, the input resolution
      torchvision documents for Inception v3.
  """

  def __init__(self, embedding_dim: int = 64, input_size: int = 299):
    super().__init__()
    # Stored as a plain int attribute so it is available inside the
    # scripted forward().
    self.input_size = input_size
    inception_model = models.inception_v3(pretrained=True)
    # Read the head's input width from the existing layer instead of
    # hard-coding 2048.
    inception_model.fc = nn.Linear(inception_model.fc.in_features, embedding_dim)
    self.feature_extractor = inception_model

  def forward(self, x):
    """Preprocess a batch of images and return their embeddings.

    Args:
      x: Image batch. Assumed to be (N, 3, H, W) with pixel values in
        [0, 255], since it is divided by 255 and normalised with
        three-channel statistics below.

    Returns:
      The ``logits`` field of the backbone output, i.e. the output of
      the replaced ``fc`` layer.
    """
    # Inception v3 is documented to expect 299x299 inputs; the previous
    # hard-coded 224x224 fed the pretrained backbone a resolution it was
    # not trained on.
    x = transforms.functional.resize(x, size=[self.input_size, self.input_size])
    x = x / 255.0
    x = transforms.functional.normalize(x,
                                        mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])
    # `.logits` requires the backbone to return its InceptionOutputs named
    # tuple, which torchvision does under TorchScript (and in training
    # mode). In eager eval mode the backbone returns a plain tensor and
    # this attribute access fails, so run this module scripted.
    return self.feature_extractor(x).logits

# Build the embedding network and switch it to inference mode in one step
# (nn.Module.eval() returns the module itself).
model = MyModel().eval()
# Compile to TorchScript, then serialize the compiled module to disk.
saved_model = torch.jit.script(model)
torch.jit.save(saved_model, '../input/google-img-embed/inception_v3_saved_model.pt')

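##### Note: `.logits` only works on the TorchScript (scripted) model; in eager eval mode `inception_v3` returns a plain tensor, which has no `.logits` attribute.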

In [1]:
from PIL import Image
import torch
from torchvision import transforms

# Restore the serialized TorchScript module and put it in inference mode.
model = torch.jit.load('../input/google-img-embed/v4_resnet50_saved_model.pt')
model.eval()
embedding_fn = model

# Read the sample picture as RGB and turn it into a single-image batch.
# PILToTensor keeps the raw pixel values (no rescaling to [0, 1]).
input_image = Image.open('../input/images/dogs/collie-beach-bokeh.jpg').convert("RGB")
convert_to_tensor = transforms.Compose([transforms.PILToTensor()])
input_tensor = convert_to_tensor(input_image)
input_batch = input_tensor.unsqueeze(0)

# Run the model without autograd bookkeeping, then flatten the first output
# into a NumPy vector and show it.
with torch.no_grad():
  embedding = embedding_fn(input_batch)[0].flatten().cpu().data.numpy()
  print(embedding)

[0.49236295 0.43916446 0.4789883  0.38781932 0.4920384  0.36499846
 0.38398114 0.401796   0.5004146  0.5484579  0.3844925  0.50782156
 0.627752   0.39378378 0.4953413  0.37244242 0.49768174 0.5064805
 0.5457502  0.47064242 0.44930282 0.50932884 0.57070726 0.5515556
 0.40232548 0.44075444 0.43162206 0.5554637  0.37993324 0.64674383
 0.46709436 0.49590132 0.4667424  0.43552575 0.47120625 0.58384526
 0.3633921  0.41693127 0.5534935  0.36043447 0.43745026 0.52090955
 0.36178485 0.41308272 0.49650058 0.41861808 0.60746115 0.44994262
 0.54319316 0.3947586  0.61874783 0.4871928  0.5517022  0.5635879
 0.4268709  0.44035915 0.49573123 0.3042859  0.45980042 0.33129784
 0.44358158 0.34073335 0.50487447 0.40105692]
