In [7]:
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
import td_load_data 
import td_run_model2

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
num_classes = 3  # Adjust this to match the number of classes (e.g., high, medium, low)
num_epochs = 20
batch_size = 16
learning_rate = 0.05

In [None]:
class Resnet50MultiClassClassifier(nn.Module):
    def __init__(self, num_classes):
        super(Resnet50MultiClassClassifier, self).__init__()
        self.base_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        
        in_features = self.base_model.fc.in_features
        self.base_model.fc = nn.Sequential(
            nn.Linear(in_features, num_classes),
            nn.Softmax(dim=1)
        )
    
    
    def forward(self, x):
        x = self.base_model(x)
        return x

In [9]:

if __name__ == "__main__":
    # Model initialization
    model = Resnet50MultiClassClassifier(num_classes=num_classes).to(device)
    
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()  # CrossEntropyLoss for multi-class classification
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
    

    # Load data
    train_loader, validation_loader, test_loader, classes = td_load_data.create_data(batch_size=batch_size)
    
    # Train and validate the model
    td_run_model2.train(num_epochs, device, model, criterion, optimizer, train_loader, validation_loader, num_classes)
    
    # Test the model
    td_run_model2.test(device, model, test_loader, num_classes)

Epoch [1/20], Loss: 0.8744
Validation Accuracy: 83.41% Macro-averaged Precision: 0.8313, Macro-averaged Recall: 0.8154
ROC AUC scores for each class: {'Class 0': 0.9676, 'Class 1': 0.9715, 'Class 2': 0.9039}
Epoch [2/20], Loss: 0.7297
Validation Accuracy: 86.63% Macro-averaged Precision: 0.8625, Macro-averaged Recall: 0.8539
ROC AUC scores for each class: {'Class 0': 0.9762, 'Class 1': 0.9818, 'Class 2': 0.937}
Epoch [3/20], Loss: 0.6877
Validation Accuracy: 88.00% Macro-averaged Precision: 0.8764, Macro-averaged Recall: 0.8673
ROC AUC scores for each class: {'Class 0': 0.9808, 'Class 1': 0.9846, 'Class 2': 0.9473}
Epoch [4/20], Loss: 0.6618
Validation Accuracy: 89.45% Macro-averaged Precision: 0.8891, Macro-averaged Recall: 0.8861
ROC AUC scores for each class: {'Class 0': 0.9831, 'Class 1': 0.9861, 'Class 2': 0.9544}
Epoch [5/20], Loss: 0.6440
Validation Accuracy: 89.69% Macro-averaged Precision: 0.8931, Macro-averaged Recall: 0.8860
ROC AUC scores for each class: {'Class 0': 0.984, 

In [10]:
# Define the model architecture with the same number of classes used during training
model = Resnet50MultiClassClassifier(num_classes=num_classes).to(device)

# Load the state dict
# state_dict = torch.load('weights/38.pth')
state_dict = torch.load('weights/38.pth', map_location=torch.device('cpu'))     # Running locally

# Load the state dict into the model
model.load_state_dict(state_dict)

model.eval()  # Set the model to evaluation mode

# Load the test data using the function in td_load_data.py
_, _, test_loader, classes = td_load_data.create_data(batch_size=batch_size)

# Evaluate the model on the test set
td_run_model2.test_model_on_test_data(device, model, test_loader)

Test Accuracy: 87.97%

Class: high | Accuracy: 91.08% | Precision: 0.9060 | Recall: 0.9108 | ROC AUC: 0.9714
Class: low | Accuracy: 93.82% | Precision: 0.9033 | Recall: 0.9382 | ROC AUC: 0.9851
Class: medium | Accuracy: 75.86% | Precision: 0.8098 | Recall: 0.7586 | ROC AUC: 0.9429

Macro-averaged Precision: 0.8731, Macro-averaged Recall: 0.8692
