In [1]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import pickle

from model_training_ck.models import MobileNetV3

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class InferencePipeline:
    """Predict a class label for a single image with a trained MobileNetV3.

    Model weights and a pickled label encoder are loaded once at construction;
    :meth:`predict` then maps an image path to a decoded label.

    Args:
        model_path: Path to a ``state_dict`` checkpoint for ``MobileNetV3``.
        label_encoder_path: Path to a pickled encoder exposing
            ``inverse_transform`` (presumably a scikit-learn ``LabelEncoder``
            -- TODO confirm).
        image_size: Size passed to ``transforms.Resize``.
        num_labels: Number of output classes the checkpoint was trained with.
            Defaults to 7, the value that used to be hard-coded here.
    """

    def __init__(self, model_path, label_encoder_path, image_size=(224, 224),
                 num_labels=7):
        self.model = self.load_model(model_path, num_labels=num_labels)
        self.model.eval()
        self.label_encoder = self.load_label_encoder(label_encoder_path)

        self.transforms = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            # Maps ToTensor's [0, 1] range to [-1, 1]; presumably mirrors the
            # training-time preprocessing -- TODO confirm.
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def load_model(self, model_path, num_labels=7):
        """Build a ``MobileNetV3`` with ``num_labels`` outputs and load its weights.

        ``map_location="cpu"`` lets a checkpoint saved from a GPU load on a
        CPU-only machine (``predict`` feeds CPU tensors anyway).
        ``weights_only=True`` restricts unpickling to tensors and plain
        containers instead of arbitrary objects, and avoids the ``weights_only``
        warning newer PyTorch versions emit for the default.
        """
        model = MobileNetV3(num_labels=num_labels)
        state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        return model

    def load_label_encoder(self, label_encoder_path):
        """Unpickle and return the label encoder stored at ``label_encoder_path``.

        SECURITY: ``pickle.load`` can execute arbitrary code -- only load
        encoder files from a trusted source.
        """
        with open(label_encoder_path, 'rb') as f:
            return pickle.load(f)

    def load_image(self, image_path):
        """Open ``image_path`` and return it as an RGB ``PIL.Image``.

        The ``with`` block closes the underlying file handle. ``convert``
        returns a new, fully loaded image, so the result stays usable after
        the file is closed.
        """
        with Image.open(image_path) as image:
            return image.convert('RGB')

    def preprocess_image(self, image_path):
        """Load ``image_path``, apply ``self.transforms`` and add a batch dimension."""
        image = self.load_image(image_path)
        if self.transforms:
            image = self.transforms(image)
        # unsqueeze(0): single image -> batch of one, as the model expects.
        return image.unsqueeze(0)

    def predict(self, image_path):
        """Return the decoded label with the highest model score for ``image_path``."""
        batch = self.preprocess_image(image_path)

        with torch.no_grad():
            scores = self.model(batch)
        # Index of the highest-scoring class for the single image in the batch.
        predicted_idx = int(torch.argmax(scores, dim=1).item())
        return self.label_encoder.inverse_transform([predicted_idx])[0]

In [3]:
# Relative paths: resolved against the kernel's working directory.
MODEL_DIR = "model_training_ck/models"
best_model_path = f"{MODEL_DIR}/distilled_lottery_ticket_10_590669.pt"
label_encoder_path = f"{MODEL_DIR}/label_encoder.pkl"

In [4]:
# Build the pipeline once: the checkpoint and label encoder are loaded here.
pipeline = InferencePipeline(
    model_path=best_model_path,
    label_encoder_path=label_encoder_path,
)

  model.load_state_dict(torch.load(model_path))


In [5]:
# NOTE(review): absolute, machine-specific path -- move under a configurable data dir.
SAMPLE_IMAGE_PATH = "/home/diogoalves/thesis/ck_data/contempt.png"
pipeline.predict(SAMPLE_IMAGE_PATH)

'contempt'