# In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import numpy as np

class MyClassifier:
    '''Flower/weed image classifier backed by a fine-tuned ResNet-18.

    Creating an instance builds the network and loads the trained weights
    (see ``load_model``); ``test_image`` then maps an image file to one of
    the names in ``class_labels``.
    '''

    def __init__(self):
        # Output index i of the network is reported as class_labels[i], so this
        # order must match the label order used when the weights were trained.
        self.class_labels = ['flower_1', 'flower_2', 'flower_3', 'flower_4', 'flower_5',
                             'weed_1', 'weed_2', 'weed_3', 'weed_4', 'weed_5']

        self.full_name = 'Naman Khosla'
        self.student_id = 'N11507721'

        # Build the network and load the trained weights.
        self.load_model()

    def load_model(self, weights_path='Project1_ResNet_student_best.pth'):
        '''Build the ResNet-18 classifier and load the trained weights into it.

        Sets ``self.model`` (in eval mode) and ``self.m``, the device the model
        is placed on (``cuda:0`` when a GPU is available, otherwise the CPU).

        Args:
            weights_path: path of the saved ``state_dict`` to load. Defaults to
                the submitted best-model file.
        '''
        # Device the model (and, in test_image, every input) lives on.
        self.m = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # No pretrained weights: load_state_dict below overwrites every parameter
        # and buffer, so fetching the ImageNet checkpoint first would only cost a
        # download (and fail on a machine without network access).
        self.model = torchvision.models.resnet18(weights=None)
        # Replace the ImageNet classification head with one output per class label.
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, len(self.class_labels))

        # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
        # machine; the model is moved to self.m afterwards.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
        self.model.load_state_dict(state_dict)
        self.model.to(self.m)
        # Inference mode: batch-norm layers use their stored running statistics.
        self.model.eval()

    def test_image(self, file_name):
        '''Predict the class of a single image file.

        Args:
            file_name: path of the image to classify.

        Returns:
            The entry of ``class_labels`` whose model score is highest.
        '''
        # Resize to 224x224, scale pixels to [0, 1] (ToTensor), then map them to
        # [-1, 1] with mean/std 0.5. Presumably this mirrors the training-time
        # preprocessing -- confirm before changing any of these values.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # The context manager closes the file handle once the pixels are decoded;
        # convert('RGB') yields an independent 3-channel copy, so greyscale, RGBA
        # and palette images all reach the model with the channel count it expects.
        with Image.open(file_name) as img:
            batch = transform(img.convert('RGB')).unsqueeze(0)  # add a batch dimension

        # Put the input on the same device as the model.
        batch = batch.to(self.m)

        with torch.no_grad():
            scores = self.model(batch)
        return self.class_labels[scores.argmax(dim=1).item()]
