In [None]:

import torch
import torchvision.models as models
import urllib.request
import os
import json
import pickle


# Optional prefix for every label file below, read from the DATASET_PREFIX
# environment variable (empty string when unset). It is concatenated as-is,
# so a directory prefix must carry its own trailing path separator.
DATASET_PREFIX = os.environ.get('DATASET_PREFIX', '')
# Label-file paths: the prefix followed by a fixed file name.
IMAGENET_LABELS_FILE = DATASET_PREFIX + "imagenet_class_index.json"
CIFAR100_LABELS_FILE = DATASET_PREFIX + "cifar100_labels.txt"
CIFAR10_LABELS_FILE = DATASET_PREFIX + "cifar10_labels.meta"
PASCAL_VOC_LABELS_FILE = DATASET_PREFIX + "pascal_voc_labels.txt"
PLACES365_LABELS_FILE = DATASET_PREFIX + "categories_places365.txt"
COCO_LABELS_FILE = DATASET_PREFIX + "coco_labels.txt"


def get_imagenet_labels():
    """Return the ImageNet class names ordered by class index.

    The class-index JSON is cached at ``IMAGENET_LABELS_FILE`` and downloaded
    on first use. The file maps ``"<index>"`` to a two-element list whose
    second element is the human-readable class name.

    Returns:
        A list where position ``i`` holds the name of class ``i``.

    Raises:
        urllib.error.URLError: If the file is missing and cannot be downloaded.
        OSError: If the cached file cannot be written or read.
    """
    if not os.path.exists(IMAGENET_LABELS_FILE):
        url = "https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json"
        # Download next to the target and rename afterwards, so an interrupted
        # transfer never leaves a truncated file that later runs would accept
        # as a valid cache.
        partial_path = IMAGENET_LABELS_FILE + ".part"
        try:
            urllib.request.urlretrieve(url, partial_path)
            os.replace(partial_path, IMAGENET_LABELS_FILE)
        finally:
            # Only still present when the download or the rename failed.
            if os.path.exists(partial_path):
                os.remove(partial_path)

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(IMAGENET_LABELS_FILE, encoding="utf-8") as f:
        class_idx = json.load(f)

    # Keys are stringified integers, so walk 0..N-1 explicitly instead of
    # relying on the key order stored in the file.
    return [class_idx[str(k)][1] for k in range(len(class_idx))]

# Load the ImageNet labels once at module level. post_processor() indexes into
# this list to turn a predicted class index into a label. Running this line
# may download the label file if it is not already on disk.
imagenet_labels = get_imagenet_labels()

# Model factory: maps a short name to a pretrained torchvision classifier
def get_model(model_name):
    """Build a pretrained torchvision classifier ready for inference.

    Args:
        model_name: One of ``'resnet50'``, ``'mobilenetv2'`` or
            ``'shufflenetv2'``.

    Returns:
        The model, moved to ``cuda:0`` when CUDA is available (CPU otherwise)
        and switched to eval mode.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # Supported name -> attribute of ``models`` that constructs the network.
    # The attribute is resolved only for the requested model.
    constructors = (
        ('resnet50', 'resnet50'),
        ('mobilenetv2', 'mobilenet_v2'),
        ('shufflenetv2', 'shufflenet_v2_x1_0'),
    )
    for supported_name, attr in constructors:
        if model_name == supported_name:
            build = getattr(models, attr)
            break
    else:
        raise ValueError('Invalid model name')

    net = build(pretrained=True)

    # Prefer the first GPU, fall back to CPU.
    if torch.cuda.is_available():
        target = torch.device("cuda:0")
    else:
        target = torch.device("cpu")
    net.to(target)

    # Inference mode: freezes batch-norm statistics and disables dropout.
    net.eval()
    return net


def _load_image_tensor(path):
    """Read an image file and apply the standard ImageNet preprocessing."""
    # Imported locally: the file only binds ``torchvision.models``.
    from torchvision.io import ImageReadMode, read_image
    from torchvision.transforms import functional as F

    # uint8 tensor shaped (3, H, W); grayscale/alpha inputs are converted.
    pixels = read_image(os.fspath(path), mode=ImageReadMode.RGB)
    pixels = F.resize(pixels, [256])
    pixels = F.center_crop(pixels, [224, 224])
    pixels = F.convert_image_dtype(pixels, torch.float)
    # Channel statistics the torchvision ImageNet models were trained with.
    return F.normalize(
        pixels,
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )


# Define the classification function for multi-class classification
def classification(image, model):
    """Classify an image with ``model`` and return its ImageNet label.

    Args:
        image: Either a path to an image file (``str`` or ``os.PathLike``),
            which is loaded and preprocessed here, or an already preprocessed
            float tensor shaped ``(3, H, W)`` or ``(N, 3, H, W)``.
        model: The classifier to run, e.g. one returned by ``get_model``.

    Returns:
        The label that ``post_processor`` derives from the model output
        (it labels the first image of the batch).
    """
    # A path cannot be fed to the network; turn it into a tensor first.
    if isinstance(image, (str, os.PathLike)):
        image = _load_image_tensor(image)

    # The networks expect a batch dimension.
    if image.dim() == 3:
        image = image.unsqueeze(0)

    # Run on whatever device the model's weights live on (get_model may have
    # moved it to the GPU). Plain callables without parameters are left alone.
    parameters = getattr(model, "parameters", None)
    first_param = next(parameters(), None) if callable(parameters) else None
    if first_param is not None:
        image = image.to(first_param.device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(image)

    # Return the post-processed label instead of discarding it.
    return post_processor(output)

# Post-processor: converts the model output into a human-readable label
def post_processor(output):
    """Translate class scores into an ImageNet label.

    Args:
        output: Score tensor that is reduced along dimension 1.

    Returns:
        The entry of ``imagenet_labels`` for the highest-scoring class of the
        first row of ``output``.
    """
    # torch.max over a dimension yields (values, indices); keep the indices.
    best_class = torch.max(output, 1)[1]
    # Only the first row of the batch is labelled.
    return imagenet_labels[best_class[0]]









In [None]:
# Build the pretrained ResNet-50 and pass the sample image path to
# classification(), then print whatever it returns.
model = get_model("resnet50")
result = classification("/workspace/tests/pexels-pixabay-45201.jpg",model)
print(result)

In [None]:
# Same sample image, this time with the pretrained MobileNetV2.
model = get_model("mobilenetv2")
result = classification("/workspace/tests/pexels-pixabay-45201.jpg",model)
print(result)

In [None]:
# Same sample image, this time with the pretrained ShuffleNetV2 (x1.0).
model = get_model("shufflenetv2")
result = classification("/workspace/tests/pexels-pixabay-45201.jpg",model)
print(result)