In [3]:
import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms

class MyModel(nn.Module):
  """Pretrained Inception-v3 backbone used as an image-embedding model.

  The ImageNet classification head (``fc``) is replaced by a freshly
  initialised ``nn.Linear(in_features, embedding_dim)``. The output is
  therefore a linear projection of Inception's pooled features. That
  projection is NOT trained here: its weights depend on the torch RNG state
  at construction time, so seed the RNG before building the model if you
  need reproducible embeddings.

  Args:
    embedding_dim: size of the output embedding vector (default 64).
    image_size: side length inputs are resized to before the backbone.
      torchvision documents that inception_v3 expects 299x299 inputs, so
      that is the default (the original code used 224).
  """

  def __init__(self, embedding_dim: int = 64, image_size: int = 299):
    super().__init__()
    # NOTE(review): newer torchvision deprecates `pretrained=True` in favour of
    # `weights=models.Inception_V3_Weights.IMAGENET1K_V1`; kept as-is for
    # compatibility with the installed version.
    inception_model = models.inception_v3(pretrained=True)
    # Swap the 1000-class ImageNet head for an (untrained) projection layer.
    inception_model.fc = nn.Linear(inception_model.fc.in_features, embedding_dim)
    self.feature_extractor = inception_model
    self.image_size = image_size

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Embed a batch of images.

    Args:
      x: image batch with pixel values in [0, 255] (it is divided by 255
        below). Presumably a uint8 (N, 3, H, W) tensor such as
        ``PILToTensor`` produces — TODO confirm against callers.

    Returns:
      The output of the replaced ``fc`` layer, i.e. (N, embedding_dim).
    """
    x = transforms.functional.resize(x, size=[self.image_size, self.image_size])
    x = x / 255.0
    # ImageNet channel statistics.
    x = transforms.functional.normalize(x,
                                        mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])
    out = self.feature_extractor(x)
    # Eager eval mode: Inception3 returns a plain Tensor.
    # Scripted, or eager training mode: it returns InceptionOutputs.
    # TorchScript evaluates this isinstance statically, so only the
    # `.logits` path is compiled when scripting.
    if isinstance(out, torch.Tensor):
      return out
    return out.logits

# Build the embedding model and export it as TorchScript.
# MyModel's projection head is randomly initialised, so seed the RNG to make
# the exported weights reproducible across runs.
from pathlib import Path

SEED = 42
# NOTE(review): if this runs on Kaggle, ../input is typically read-only;
# consider a writable location such as ../working — confirm.
SAVED_MODEL_PATH = Path('../input/google-img-embed/inception_v3_saved_model.pt')

torch.manual_seed(SEED)
model = MyModel()
model.eval()  # export with training-mode behaviour switched off
saved_model = torch.jit.script(model)
SAVED_MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
saved_model.save(str(SAVED_MODEL_PATH))

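##### Note: `.logits` works only on the TorchScript (scripted) model — in eager `eval()` mode torchvision's Inception-v3 returns a plain tensor, which has no `.logits` attribute.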

In [1]:
from PIL import Image
import torch
from torchvision import transforms

# Model loading: a TorchScript module saved to disk.
model = torch.jit.load('../models_pt/v8_efficientnet_b0.pt')
model.eval()
embedding_fn = model

# Load image and extract its embedding.
# The context manager closes the file handle once the RGB copy exists;
# convert() returns a new image, so input_image stays valid afterwards.
with Image.open('../input/images/dogs/collie-beach-bokeh.jpg') as raw_image:
  input_image = raw_image.convert("RGB")
# PILToTensor produces a CHW tensor and does not rescale pixel values.
convert_to_tensor = transforms.PILToTensor()
input_tensor = convert_to_tensor(input_image)
input_batch = input_tensor.unsqueeze(0)  # add a leading batch dimension
with torch.no_grad():
  # Take the single batch item and flatten it to a 1-D vector. Under no_grad
  # the result does not require grad, so `.data` is unnecessary before numpy().
  embedding = torch.flatten(embedding_fn(input_batch)[0]).cpu().numpy()
  print(embedding)

[ 0.15418938  0.13132873  0.10568013 -0.02319951  0.0513504   0.28369874
  0.09418299  0.04545451  0.1120355   0.01482786  0.10111406  0.13878836
  0.05167865  0.0767236   0.12491049  0.13515668  0.0036693   0.09988344
  0.02310264  0.01448677  0.19196111  0.02595637  0.31589824  0.13172874
  0.07139672 -0.00108883  0.05925114  0.12693705  0.34643418  0.05538007
  0.03742011  0.03647575  0.05398829  0.07095765  0.12663479  0.01357992
  0.03738907  0.01871396  0.12489943  0.12197356  0.13598627  0.02653401
  0.08901893  0.05628708  0.16498378  0.0433558   0.08133106  0.09273782
  0.1087992   0.088459   -0.00665544  0.21822235 -0.01935812  0.17107223
  0.02439914  0.05979742  0.05605851  0.01209688  0.01066151  0.0687295
  0.19390936  0.02068388 -0.02530537  0.0815108 ]
