In [3]:
import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms

class MyModel(nn.Module):
  """Image-embedding model: ImageNet-pretrained Inception v3 with a linear head.

  The torchvision Inception v3 backbone keeps its pretrained weights. Its final
  ``fc`` classifier is replaced by a freshly (randomly) initialised
  ``nn.Linear(2048, embedding_dim)``, so the head is untrained here.

  Args:
    embedding_dim: size of the output embedding (default 64).
    image_size: side length inputs are resized to before the backbone
      (default 224). NOTE(review): torchvision documents 299x299 as
      Inception v3's expected input size; 224 is kept as the default for
      backward compatibility.
  """

  def __init__(self, embedding_dim: int = 64, image_size: int = 224):
    super().__init__()
    inception_model = models.inception_v3(pretrained=True)
    # Swap the 1000-class ImageNet classifier (2048 -> 1000) for the embedding head.
    inception_model.fc = nn.Linear(2048, embedding_dim)
    self.feature_extractor = inception_model
    self.image_size = image_size

  def forward(self, x):
    """Embed a batch of images.

    Args:
      x: image batch, assumed (N, 3, H, W) with raw pixel values in
        [0, 255] (e.g. uint8 from ``PILToTensor``) -- TODO confirm with callers.

    Returns:
      Tensor of shape (N, embedding_dim) produced by the replaced ``fc`` head.
    """
    x = transforms.functional.resize(x, size=[self.image_size, self.image_size])
    x = x / 255.0  # scale raw pixels to [0, 1] before mean/std normalisation
    x = transforms.functional.normalize(x,
                                        mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])
    out = self.feature_extractor(x)
    # Inception3 returns a plain Tensor in eager eval mode, but an
    # InceptionOutputs(logits, aux_logits) tuple when scripted (and in eager
    # training mode with aux logits). Accessing ``.logits`` unconditionally
    # therefore failed in eager mode. TorchScript evaluates this isinstance
    # check at compile time, so only the ``.logits`` path is compiled there.
    if isinstance(out, torch.Tensor):
      return out
    return out.logits

# Build the embedding model, compile it with TorchScript and save it.
# The new linear head in MyModel is randomly initialised, so seed torch first to
# make the saved weights (and therefore the embeddings) reproducible on re-run.
SEED = 0
INCEPTION_SAVE_PATH = '../input/google-img-embed/inception_v3_saved_model.pt'

torch.manual_seed(SEED)
model = MyModel()
model.eval()
saved_model = torch.jit.script(model)

# Sanity-check the scripted model on a dummy uint8 image of arbitrary size
# before writing it out: it must yield one 64-d embedding per image.
with torch.no_grad():
  dummy_images = torch.randint(0, 256, (1, 3, 300, 400), dtype=torch.uint8)
  assert saved_model(dummy_images).shape == (1, 64), "unexpected embedding shape"

# NOTE(review): on Kaggle, ../input is mounted read-only, so this save may fail
# there -- consider writing to the working directory instead.
saved_model.save(INCEPTION_SAVE_PATH)

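##### Note: `.logits` is only available on the TorchScript output — a scripted `Inception3` always returns an `InceptionOutputs` tuple, whereas in eager eval mode it returns a plain tensor.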

In [1]:
from PIL import Image
import torch
from torchvision import transforms

# Model loading.
# map_location='cpu' lets a checkpoint saved from a GPU session load on a
# CPU-only machine; the input batch below is also built on the CPU.
EMBED_MODEL_PATH = '../input/google-img-embed/v6_efficientnet_b7.pt'
IMAGE_PATH = '../input/images/dogs/collie-beach-bokeh.jpg'

model = torch.jit.load(EMBED_MODEL_PATH, map_location='cpu')
model.eval()
embedding_fn = model

# Load image and extract its embedding.
# The context manager closes the file handle; convert() returns a new,
# fully loaded image, so it remains usable after the file is closed.
with Image.open(IMAGE_PATH) as raw_image:
  input_image = raw_image.convert("RGB")
# PILToTensor keeps raw uint8 pixel values (no scaling), channels-first.
convert_to_tensor = transforms.PILToTensor()
input_tensor = convert_to_tensor(input_image)
input_batch = input_tensor.unsqueeze(0)  # add a leading batch dimension
with torch.no_grad():
  # Outputs created under no_grad do not track gradients, so no .data/.detach()
  # is needed before converting to numpy. [0] selects the only batch item.
  embedding = torch.flatten(embedding_fn(input_batch)[0]).cpu().numpy()
  assert embedding.shape == (64,), f"unexpected embedding shape {embedding.shape}"
  print(embedding)

[-0.00422927  0.07208674  0.0107396   0.02217594  0.00476109  0.01163914
 -0.00700159  0.11872683  0.01967308 -0.01705721  0.03650533  0.01849472
 -0.02746117 -0.02818933 -0.00114972 -0.07334472  0.08396803 -0.03251896
 -0.01345569 -0.04495557  0.01206339 -0.01713189  0.07153784  0.09358376
  0.00809748  0.01467024 -0.05876682  0.04816409  0.0459237   0.12587154
 -0.0076489   0.03158112 -0.0370841   0.05056956  0.01578933  0.06055729
 -0.04742613  0.05835805  0.05572148  0.02805956  0.11293344  0.05028643
  0.04443515 -0.0168138   0.07156266 -0.00231754  0.00723598 -0.00807834
  0.01296688  0.07115047 -0.02048157 -0.02985774  0.03019468  0.01321307
 -0.0046351   0.01219096  0.0584572   0.04004542  0.03315777  0.02869853
  0.03391372 -0.01115438  0.03653382  0.01278363]
