In [1]:
import os

import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch import nn
from torchvision import transforms
from torchvision.models import resnet152, ResNet152_Weights, resnet50, ResNet50_Weights, resnet18, ResNet18_Weights, resnet34, ResNet34_Weights, resnet101, ResNet101_Weights

In [2]:
# ImageNet channel statistics used to normalize inputs for the pretrained
# torchvision ResNet weights; every input branch below applies them.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def preprocess(image, size=(224, 224)):
    """Turn an image into a normalized CHW tensor for the ResNet backbones.

    Args:
        image: One of
            * a file path (``str`` or ``os.PathLike``): read with OpenCV,
              converted BGR -> RGB, resized, and divided by 255;
            * a ``numpy.ndarray`` in HWC layout: converted with
              ``transforms.ToTensor`` (which rescales to [0, 1] only for
              uint8 arrays), then resized. No BGR -> RGB conversion is
              applied, so it is presumably already RGB (TODO confirm with
              callers);
            * a ``torch.Tensor``: only resized and normalized. It is
              assumed to be a 3-channel CHW (or NCHW) float tensor already
              scaled to [0, 1].
        size: Target ``(height, width)``. Defaults to 224x224.

    Returns:
        A ``torch.Tensor`` normalized with the ImageNet mean/std, spatially
        resized to ``size``. For file paths the dtype is float64 (the
        ``/ 255`` division promotes to float64), so cast with ``.float()``
        before feeding float32 models.

    Raises:
        FileNotFoundError: if ``image`` is a path OpenCV cannot read.
        TypeError: if ``image`` is not a path, ndarray or tensor.
    """
    pixels = image
    if isinstance(image, (str, os.PathLike)):
        path = os.fspath(image)
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread reports a missing or unreadable file by returning
            # None instead of raising; fail here with a clear message.
            raise FileNotFoundError(f"Could not read image file: {path!r}")
        height, width = size
        # cv2.resize takes dsize as (width, height), unlike transforms.Resize.
        resized = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), (width, height))
        # ToTensor does not rescale non-uint8 arrays, so this division is the
        # only [0, 1] scaling on this branch.
        pixels = resized / 255
        steps = [transforms.ToTensor()]
    elif isinstance(image, np.ndarray):
        steps = [transforms.ToTensor(), transforms.Resize(size)]
    elif isinstance(image, torch.Tensor):
        steps = [transforms.Resize(size)]
    else:
        raise TypeError(
            "preprocess expects a file path, numpy.ndarray or torch.Tensor, "
            f"got {type(image).__name__}"
        )

    steps.append(transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD))
    return transforms.Compose(steps)(pixels)

In [7]:
class _ResNetFeatureExtractor(nn.Module):
    """Pretrained torchvision ResNet with its classification head removed.

    The final ``fc`` layer is replaced by ``nn.Identity``, so ``forward``
    returns the feature vector that would have been fed to the classifier,
    of shape ``(batch, out_features)``.

    Like any ``nn.Module``, instances start in training mode. Call
    ``.eval()`` before extracting features so BatchNorm uses its running
    statistics instead of the current batch's.

    Args:
        builder: A torchvision ResNet constructor such as ``resnet18``.
        weights: The pretrained-weights enum value passed to ``builder``.
    """

    def __init__(self, builder, weights):
        super().__init__()
        self.backbone = builder(weights=weights)
        # Record the feature width before dropping the head. The classifier's
        # input size is exactly the size of the vector forward() now returns.
        self.out_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

    def forward(self, x):
        """Return backbone features for a normalized image batch.

        ``x`` is presumably ``(batch, 3, H, W)``, e.g. ``preprocess(...)``
        with a leading batch dimension. TODO confirm for other callers.
        """
        return self.backbone(x)


class ResNet18(_ResNetFeatureExtractor):
    """ResNet-18 feature extractor (512-d) with IMAGENET1K_V1 weights."""

    def __init__(self):
        super().__init__(resnet18, ResNet18_Weights.IMAGENET1K_V1)


class ResNet34(_ResNetFeatureExtractor):
    """ResNet-34 feature extractor (512-d) with IMAGENET1K_V1 weights."""

    def __init__(self):
        super().__init__(resnet34, ResNet34_Weights.IMAGENET1K_V1)


class ResNet50(_ResNetFeatureExtractor):
    """ResNet-50 feature extractor (2048-d) with IMAGENET1K_V2 weights."""

    def __init__(self):
        super().__init__(resnet50, ResNet50_Weights.IMAGENET1K_V2)


class ResNet101(_ResNetFeatureExtractor):
    """ResNet-101 feature extractor (2048-d) with IMAGENET1K_V2 weights."""

    def __init__(self):
        super().__init__(resnet101, ResNet101_Weights.IMAGENET1K_V2)


class ResNet152(_ResNetFeatureExtractor):
    """ResNet-152 feature extractor (2048-d) with IMAGENET1K_V2 weights."""

    def __init__(self):
        super().__init__(resnet152, ResNet152_Weights.IMAGENET1K_V2)

In [15]:
# Colab upload location; change this when running elsewhere.
IMAGE_PATH = "/content/unknown.png"

# Preprocess once and add a leading batch dimension. preprocess returns
# float64 for file paths (the /255 division promotes), while the pretrained
# weights are float32, hence .float().
batch = preprocess(IMAGE_PATH).float()[None]

# Pure inference: no autograd bookkeeping needed.
with torch.inference_mode():
    for model_cls in (ResNet18, ResNet34, ResNet50, ResNet101, ResNet152):
        # eval() makes BatchNorm use its running statistics rather than this
        # single image's statistics (and stops it from updating them).
        model = model_cls().eval()
        print(f"{model_cls.__name__}:", model(batch).shape)

ResNet18: torch.Size([1, 512])
ResNet34: torch.Size([1, 512])
ResNet50: torch.Size([1, 2048])
ResNet101: torch.Size([1, 2048])
ResNet152: torch.Size([1, 2048])
