In [3]:
from pathlib import Path

import torch
import torch.nn as nn
from torchvision import models
from torchvision import transforms

class MyModel(nn.Module):
  """Inception-v3 image embedder packaged for TorchScript export.

  The forward pass takes a batch of raw pixel values and does four things:
  resizes it to `input_size`, scales it by 1/255, applies ImageNet mean/std
  normalization, and returns the main-head output of Inception-v3. The final
  classifier is replaced by a freshly initialised `embedding_dim`-wide linear
  layer.

  Args:
    input_size: target (height, width), or a single int for a square size.
      Defaults to 224x224 to match the original export. Note that torchvision
      documents 299x299 as Inception-v3's native input resolution.
    embedding_dim: width of the output embedding, i.e. of the new `fc` layer.
  """

  def __init__(self, input_size=(224, 224), embedding_dim=64):
    super().__init__()
    # NOTE(review): `pretrained=` is deprecated in newer torchvision in favour
    # of `weights=`. It is kept here for compatibility with older versions.
    inception_model = models.inception_v3(pretrained=True)
    # The replacement head is randomly initialised and not trained here, so the
    # embedding is a random projection of the backbone's pooled features.
    # Seed torch before constructing the model for a reproducible export.
    inception_model.fc = nn.Linear(inception_model.fc.in_features, embedding_dim)
    self.feature_extractor = inception_model
    if isinstance(input_size, int):
      input_size = (input_size, input_size)
    # Stored as a plain list of ints: under TorchScript, `resize` expects
    # `size: List[int]`, and a tuple attribute would be typed Tuple[int, int].
    self.input_size = [int(dim) for dim in input_size]

  def forward(self, x):
    # NOTE(review): assumes x holds 0-255 pixel values shaped (N, 3, H, W),
    # e.g. `transforms.PILToTensor()` output with a batch dim added.
    x = transforms.functional.resize(x, size=self.input_size)
    x = x / 255.0
    x = transforms.functional.normalize(x,
                                        mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])
    # Under torch.jit.script, Inception-v3 returns an InceptionOutputs tuple,
    # and `.logits` selects its main head. In eager eval mode the backbone
    # returns a plain tensor instead, so this attribute access only works on
    # the scripted module.
    return self.feature_extractor(x).logits

# The replacement `fc` head is randomly initialised, so seed torch before
# building the model. Otherwise every export yields a different embedding space.
SEED = 42
EXPORT_PATH = Path('../input/google-img-embed/inception_v3_saved_model.pt')

torch.manual_seed(SEED)
model = MyModel()
model.eval()  # inference mode: dropout off, batch-norm uses running statistics
saved_model = torch.jit.script(model)

# Smoke-check the scripted module on a dummy 0-255 batch before writing it out.
with torch.no_grad():
  sample_embedding = saved_model(torch.rand(1, 3, 256, 256) * 255.0)
assert sample_embedding.shape == (1, 64), sample_embedding.shape

EXPORT_PATH.parent.mkdir(parents=True, exist_ok=True)
saved_model.save(str(EXPORT_PATH))

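##### `.logits` works only on the TorchScript-compiled model (in eager eval mode `inception_v3` returns a plain tensor, which has no `.logits`)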

In [2]:
from PIL import Image
import torch
from torchvision import transforms

# Inputs for this cell: an exported TorchScript embedding model and a sample image.
MODEL_PATH = '../models_pt/v7_swinv2_base_window16_256.pt'
IMAGE_PATH = '../input/images/dogs/collie-beach-bokeh.jpg'

# Model loading.
model = torch.jit.load(MODEL_PATH)
model.eval()
embedding_fn = model

# Load the image and close its file handle as soon as the RGB copy exists.
with Image.open(IMAGE_PATH) as raw_image:
  input_image = raw_image.convert("RGB")

# PILToTensor keeps raw uint8 pixel values with shape (3, H, W).
# NOTE(review): assumes the exported model resizes and normalises internally,
# as MyModel does. Confirm this for this checkpoint.
convert_to_tensor = transforms.PILToTensor()
input_tensor = convert_to_tensor(input_image)
input_batch = input_tensor.unsqueeze(0)  # add a batch dim -> (1, 3, H, W)

with torch.no_grad():
  # Take the first (only) batch item and flatten it to a 1-D NumPy vector.
  embedding = torch.flatten(embedding_fn(input_batch)[0]).cpu().numpy()
print(embedding)

[ 0.04315272 -0.0235291   0.0312604  -0.00041633  0.02815877  0.12055283
 -0.03224912 -0.01205595 -0.00194191  0.04195309  0.00266817 -0.08809998
 -0.10464027  0.02160696  0.07408638  0.04939732  0.03640577  0.00602856
  0.01544393  0.04408798  0.19019389 -0.09772927 -0.0610893   0.00477735
 -0.10000484 -0.07538507  0.14449085 -0.04995918 -0.04938088  0.06414583
 -0.06592316 -0.07091008  0.10838479 -0.00762417 -0.09494048  0.10305584
  0.01002913 -0.0395183   0.12190846  0.01152168  0.00955782  0.01028253
  0.01631404 -0.00458292 -0.10308661  0.00719626 -0.00534671  0.07603404
 -0.02168461  0.02141615 -0.00448317  0.00851323 -0.02167099  0.06061072
  0.13131109  0.17078684  0.04667639  0.06882825 -0.05217126  0.01565712
  0.02451139 -0.00508972  0.02675567 -0.08019949]
