In [8]:
from transformers import ResNetForImageClassification
import torch
from torchvision.transforms import Compose, Normalize, ToTensor, Resize
from PIL import Image
import requests

In [2]:
# Load the ResNet-50 model
model = ResNetForImageClassification.from_pretrained('microsoft/resnet-50')
model.eval()  # Set the model to evaluation mode

Downloading config.json: 100%|██████████| 69.6k/69.6k [00:00<00:00, 80.7MB/s]
Downloading pytorch_model.bin: 100%|██████████| 103M/103M [00:02<00:00, 39.1MB/s] 


ResNetForImageClassification(
  (resnet): ResNetModel(
    (embedder): ResNetEmbeddings(
      (embedder): ResNetConvLayer(
        (convolution): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (normalization): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): ReLU()
      )
      (pooler): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (encoder): ResNetEncoder(
      (stages): ModuleList(
        (0): ResNetStage(
          (layers): Sequential(
            (0): ResNetBottleNeckLayer(
              (shortcut): ResNetShortCut(
                (convolution): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
                (normalization): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              )
              (layer): Sequential(
                (0): ResNetConvLayer(
                  (convolution): Conv2d(64

In [3]:
def process_image(image_url):
    # Load the image
    response = requests.get(image_url)
    image = Image.open(requests.get(image_url, stream=True).raw)

    # Define the image transformation
    transform = Compose([
        Resize((224, 224)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Apply the transformation to the image
    return transform(image)


In [11]:
# function to extract image features
def extract_image_features(image_urls):
    features_dict = {}
    
    for url in image_urls:
        processed_image = process_image(url)
        
        with torch.no_grad():
            features = model(processed_image.unsqueeze(0))
        
        features_dict[url] = features
    
    return features_dict


In [13]:
# example
image_urls = ['https://images.wsj.net/im-694446?width=700&height=467', 'https://images.wsj.net/im-694460?width=700&height=466', 'https://images.wsj.net/im-694449?width=700&height=466', 'https://images.wsj.net/im-694454?width=700&height=466', 'https://images.wsj.net/im-694457?width=700&height=466']
dict = extract_image_features(image_urls)