## Demo
This file provides an example on how to load the model and use it to classify a sample image.

In [11]:
import torch
from torchvision import transforms, models
from PIL import Image

#### Define a global variable that maps class index to class name

In [12]:
dct = {
    0: 'cardboard',
    1: 'glass',
    2: 'metal',
    3: 'paper',
    4: 'plastic',
    5: 'trash'
}

#### Import model

ATTENTION: Do not forget to modify the last layer to the number of classes

In [None]:
# Load MobileNet model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.mobilenet_v3_large(pretrained=False)

# Modify the last layer to the number of classes
model.classifier[3] = torch.nn.Linear(in_features=1280, out_features=6)

model.load_state_dict(torch.load("classifier.pth", map_location=device))
model.eval()

  model.load_state_dict(torch.load("classifier.pth", map_location=device))


MobileNetV3(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
        )
      )
    )
    (2): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bi

#### Run inference
ATTENTION: Transform the image before passing it to the model!

In [9]:
# Define preprocessing transformations
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [None]:
# Load and preprocess the image
# Modify the image path to see how it makes inferences on other images
image = Image.open("data/Garbage classification/Garbage classification/metal/metal1.jpg").convert("RGB")
input_tensor = transform(image).unsqueeze(0)

# Perform inference
with torch.no_grad():
    output = model(input_tensor)

# Get the predicted class
predicted_class = torch.argmax(output[0]).item()

Predicted Class: 2


In [15]:
# Visualize the example
from PIL import ImageDraw, ImageFont

# Create a drawing context
draw = ImageDraw.Draw(image)

# Define text and font
text = f"Predicted Class: {dct[predicted_class]}"
font = ImageFont.load_default()

# Specify position and color
text_position = (10, 10)
text_color = (0, 0, 0)

# Add text to the image
draw.text(text_position, text, fill=text_color, font=font)

# Display the image
image.show()
