In [1]:
import torch
import torch.nn as nn
from torchvision.models import vgg16, VGG16_Weights
import td_load_data
import td_run_model2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
num_classes = 3  # Adjust this to match the number of classes (e.g., high, medium, low)
num_epochs = 23
batch_size = 16
learning_rate = 0.005

In [4]:
class VGG16MultiClassClassifier(nn.Module):
    """VGG16 backbone whose last classifier layer is replaced by a `num_classes` head.

    Args:
        num_classes: Number of output logits (3 in this notebook, e.g.
            high / medium / low).
        pretrained: If True (the default, identical to the previous behaviour),
            initialise the backbone from torchvision's IMAGENET1K_V1 weights.
            Pass False when every parameter will be overwritten by a saved
            checkpoint anyway; this skips downloading/loading ImageNet weights.
    """

    def __init__(self, num_classes: int, pretrained: bool = True):
        super().__init__()
        weights = VGG16_Weights.IMAGENET1K_V1 if pretrained else None
        self.base_model = vgg16(weights=weights)

        # Width of the input to the original final classifier layer.
        in_features = self.base_model.classifier[6].in_features

        # The single Linear is deliberately wrapped in nn.Sequential: this puts
        # the head's parameters under 'base_model.classifier.6.0.*', and
        # checkpoints saved from this class (loaded with load_state_dict) depend
        # on those key names. Unwrapping it would break loading them.
        self.base_model.classifier[6] = nn.Sequential(
            nn.Linear(in_features, num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (un-normalised) class logits for a batch of images.

        NOTE(review): presumably x is (batch, 3, H, W), normalised the way the
        ImageNet weights expect — confirm against td_load_data's transforms.
        """
        return self.base_model(x)

In [5]:
# Build the architecture with the same number of classes used during training,
# so the checkpoint's head weights fit the replaced classifier layer.
model = VGG16MultiClassClassifier(num_classes=num_classes).to(device)

# Checkpoint to evaluate, relative to the notebook's working directory.
# Other runs (e.g. '../weights/42.pth') can be evaluated by changing this path.
CHECKPOINT_PATH = '../weights/43.pth'

# map_location=device loads the tensors straight onto the device in use, so a
# checkpoint saved on GPU also loads on a CPU-only machine. weights_only=True
# restricts unpickling to tensors and plain containers, so a tampered .pth
# file cannot execute arbitrary code while being loaded.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

model.eval()  # inference behaviour for layers such as Dropout

# Load the test data using the helper in td_load_data.py.
# NOTE(review): assumes create_data now returns (test_loader, classes); an
# earlier version returned (train, val, test, classes) — confirm.
test_loader, classes = td_load_data.create_data(batch_size=batch_size)

# Evaluate the model on the test set; judging by the cell output, the helper
# prints the true and predicted class for each test image.
td_run_model2.test_model_on_test_data(device, model, test_loader)

Filename: 1711_19-06-2024_08-05-01.jpg, True class: high, Predicted class: high
Filename: 2701_22-05-2024_11-45-02.jpg, True class: high, Predicted class: high
Filename: 2702_04-08-2024_06-15-01.jpg, True class: high, Predicted class: high
Filename: 3704_05-08-2024_19-15-02.jpg, True class: high, Predicted class: high
Filename: 4701_20-06-2024_18-25-01.jpg, True class: high, Predicted class: high
Filename: 4703_18-06-2024_19-55-01.jpg, True class: high, Predicted class: high
Filename: 4708_23-05-2024_09-55-01.jpg, True class: high, Predicted class: high
Filename: 6703_06-08-2024_19-20-02.jpg, True class: high, Predicted class: high
Filename: 7793_20-06-2024_08-35-02.jpg, True class: high, Predicted class: high
Filename: 9702_08-08-2024_09-20-02.jpg, True class: high, Predicted class: high
Filename: 1006_15-07-2024_20-50-02.jpg, True class: low, Predicted class: low
Filename: 1705_14-07-2024_02-30-01.jpg, True class: low, Predicted class: low
Filename: 1709_06-06-2024_05-00-02.jpg, True

###### 