# TrashNet Inference

## Download image

In [1]:
!wget https://live.staticflickr.com/1/898231_ca259fd6e0_b.jpg

'wget' is not recognized as an internal or external command,
operable program or batch file.


In [1]:
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import os
import argparse
import matplotlib.pyplot as plt

def get_model(model_name, num_classes):
    """Function to return the specified model."""
    model = getattr(models, model_name)(pretrained=True)

    if hasattr(model, 'classifier') and isinstance(model.classifier, nn.Sequential):
        num_ftrs = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(num_ftrs, num_classes)
    elif hasattr(model, 'fc'):
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, num_classes)
    elif hasattr(model, 'classifier') and isinstance(model.classifier, nn.Conv2d):
        model.classifier = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
    else:
        raise ValueError(f"Model {model_name} is not supported or needs custom handling.")

    return model


def load_image(image_path, transform=None):
    """Load an image and apply transformations."""
    image = Image.open(image_path).convert('RGB')
    if transform is not None:
        image = transform(image)
    return image


def predict(image_path, model, transform, class_names, device):
    """Run inference on an image and return the predicted class."""
    model.eval()
    image = load_image(image_path, transform)
    image = image.unsqueeze(0)  # Add batch dimension
    image = image.to(device)  # Move the input to the same device as the model

    with torch.no_grad():
        outputs = model(image)
        probabilities = torch.softmax(outputs, dim=1).squeeze()
        _, predicted = torch.max(outputs, 1)
        predicted_class = class_names[predicted.item()]
        # print(probabilities.shape, probabilities[predicted.item()], predicted.item())
    return predicted_class, probabilities[predicted.item()]


def get_class_names(data_path):
    """Get class names from the subdirectories of the given path."""
    class_names = [d.name for d in os.scandir(data_path) if d.is_dir()]
    class_names.sort()
    return class_names


def main_predict(data_path, model, image_path, model_path):
    # Define transformations for the input image
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Get class names from the training data path
    class_names = get_class_names(data_path)
    # print(class_names)
    # Load the specified model
    model = get_model(model, len(class_names))
    model.load_state_dict(torch.load(model_path, map_location=device))

    # Use GPU if available

    model = model.to(device)

    # Perform inference
    predicted_class, predicted_probability = predict(image_path, model, transform, class_names, device)
    print(f'Predicted class: {predicted_class}')
    print(f'Prediction probability: {predicted_probability*100:0.2f}%')

    img = Image.open(image_path).convert('RGB')
    plt.imshow(img)
    plt.title(f'Predicted: {predicted_class}\nProbability: {predicted_probability*100:.2f}%')
    plt.axis('off')
    plt.show()

ModuleNotFoundError: No module named 'torch'

In [2]:
# test dataset location
data_path='dataset/test_resized'
# model name eg. mobilenet_v2, mobilenet_v3_large, mobilenet_v2_small, resnet152
model='resnet152'
# Test image localtion
image_path='/content/898231_ca259fd6e0_b.jpg'
# model weight localtion. Note: change the model weights too according to model name
model_path='log/resnet152_data_aug_imagenet_pretrained/model-best.pth'
main_predict(data_path, model, image_path, model_path)

NameError: name 'main_predict' is not defined