In [3]:
import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms

class MyModel(nn.Module):
  """Inception-v3 based image embedder that outputs a 64-dim vector per image.

  Wraps a pretrained torchvision Inception-v3 whose final ``fc`` layer is
  replaced by a new ``nn.Linear(2048, 64)`` head. That head is freshly
  initialised (not trained here), so its weights are random unless the global
  torch seed is fixed before construction. ``forward`` performs its own
  preprocessing (resize, scale to [0, 1], per-channel normalisation), so the
  module can be fed raw pixel tensors.
  """

  def __init__(self, image_size=224):
    """Build the wrapped network.

    Args:
      image_size: side length in pixels that ``forward`` resizes inputs to.
        Defaults to 224 (the original behaviour). NOTE(review): Inception-v3
        was trained at 299x299; 224 runs but may give weaker features.
    """
    super().__init__()
    inception_model = models.inception_v3(pretrained=True)
    inception_model.fc = nn.Linear(2048, 64)
    self.feature_extractor = inception_model
    self.image_size = image_size

  def forward(self, x):
    """Preprocess ``x`` and return the 64-dim head output.

    Args:
      x: image batch tensor. The ``/ 255.0`` step means it is expected to hold
        raw [0, 255] pixel values; presumably shaped (N, 3, H, W) — confirm.

    Returns:
      Tensor of logits from the replaced ``fc`` head.
    """
    x = transforms.functional.resize(x, size=[self.image_size, self.image_size])
    x = x / 255.0
    # ImageNet channel statistics.
    x = transforms.functional.normalize(x,
                                        mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])
    out = self.feature_extractor(x)
    # A scripted Inception3 always returns an InceptionOutputs named tuple,
    # whereas eager Inception3 returns a plain tensor in eval mode (and the
    # tuple only when training with aux logits). Handle both so the module also
    # works un-scripted; TorchScript resolves this isinstance check statically
    # and compiles only the `.logits` branch.
    if isinstance(out, torch.Tensor):
      return out
    return out.logits

from pathlib import Path

# MyModel's replacement `fc` head is randomly initialised. Without a fixed
# seed, every Restart & Run All would save a different model and produce
# different embeddings.
torch.manual_seed(0)

SAVED_MODEL_PATH = Path('../input/google-img-embed/inception_v3_saved_model.pt')
# Create the output directory so the save does not fail on a fresh checkout.
SAVED_MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)

model = MyModel()
model.eval()  # switch to inference mode before scripting/export
saved_model = torch.jit.script(model)
saved_model.save(str(SAVED_MODEL_PATH))

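##### `.logits` works only under TorchScript: a scripted Inception3 always returns an `InceptionOutputs` named tuple, whereas in eager eval mode it returns a plain tensor.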

In [2]:
from PIL import Image
import torch
from torchvision import transforms

# Model loading. map_location='cpu' keeps the weights on the same device as the
# CPU input tensor built below. It also lets a model exported from a GPU session
# load on a CPU-only machine.
model = torch.jit.load('../models_pt/v11_swin_base_patch4_window7_224_in22k.pt',
                       map_location='cpu')
model.eval()
embedding_fn = model

# Load image and extract its embedding. The context manager releases the file
# handle once the RGB copy has been made.
with Image.open('../input/images/dogs/collie-beach-bokeh.jpg') as raw_image:
  input_image = raw_image.convert("RGB")
# Unscaled pixel tensor; presumably the scripted model does its own /255 and
# normalisation like MyModel above — confirm for this checkpoint.
convert_to_tensor = transforms.PILToTensor()
input_tensor = convert_to_tensor(input_image)
input_batch = input_tensor.unsqueeze(0)  # add a leading batch dimension of size 1
with torch.no_grad():
  # [0] selects the single image's output. `.data` is unnecessary under no_grad,
  # so convert straight to NumPy.
  embedding = torch.flatten(embedding_fn(input_batch)[0]).cpu().numpy()
  print(embedding)

[ 0.05671266  0.02757249 -0.1023538   0.024033   -0.17680858  0.3449868
 -0.4970118  -0.21332997 -0.25042892  0.08485723  0.315588    0.30897635
  0.25909576 -0.07186266  0.09054497 -0.1933078   0.29890564  0.18681553
  0.16552402 -0.34627232  0.13632624  0.01161262 -0.07296698  0.11399545
 -0.08872433 -0.25143853  0.0376635  -0.16649503 -0.06944361  0.13528076
  0.29189467  0.04448745 -0.22423002 -0.30426562 -0.1743249  -0.1865472
 -0.15636072 -0.13407327 -0.29968047 -0.24464166  0.290393   -0.00605993
 -0.03487606  0.2089283   0.12893851  0.10404649  0.17819619 -0.06499936
  0.01899411 -0.09067152 -0.08642957  0.04255173  0.12876575 -0.07403735
  0.08794733  0.15112586 -0.09704171  0.06054981  0.1371095   0.179655
  0.02621843  0.09447531  0.1027277  -0.15676223]
