In [10]:
import torch
import torch.nn as nn
from torchvision.models import vgg16, VGG16_Weights
import td_load_data
import td_run_model2

In [11]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [12]:
num_classes = 3  # Adjust this to match the number of classes (e.g., high, medium, low)
num_epochs = 23
batch_size = 16
learning_rate = 0.005

In [None]:
class VGG16MultiClassClassifier(nn.Module):
    """torchvision VGG16 whose last classifier layer is replaced by a `num_classes`-way head.

    Args:
        num_classes: Number of output scores produced by the new final layer.
        pretrained: When True (the default, matching the previous behaviour), the backbone is
            initialised from ``VGG16_Weights.IMAGENET1K_V1``. Pass False when a full checkpoint
            is going to be loaded with ``load_state_dict`` right afterwards: the ImageNet
            weights would be overwritten anyway, so fetching them is wasted work.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        weights = VGG16_Weights.IMAGENET1K_V1 if pretrained else None
        self.base_model = vgg16(weights=weights)

        # Width of the features entering VGG16's last classifier layer (index 6).
        in_features = self.base_model.classifier[6].in_features

        # The single Linear is wrapped in nn.Sequential, which makes its state_dict keys
        # `base_model.classifier.6.0.weight/bias`. Checkpoints saved from this class use those
        # names, so do not unwrap it without migrating existing weight files.
        self.base_model.classifier[6] = nn.Sequential(
            nn.Linear(in_features, num_classes)
        )

    def forward(self, x):
        """Return unnormalised class scores (no softmax), as consumed by nn.CrossEntropyLoss.

        `x` is presumably an image batch of shape (N, 3, H, W); the notebook's sanity check
        feeds (1, 3, 224, 224). The output has shape (N, num_classes).
        """
        return self.base_model(x)

In [None]:
if __name__ == "__main__":
    # Hyperparameters (num_classes, num_epochs, batch_size, learning_rate) and `device`
    # are defined in the earlier configuration cells.

    # Model initialization
    model = VGG16MultiClassClassifier(num_classes=num_classes).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    # Load data
    train_loader, validation_loader, test_loader, classes = td_load_data.create_data(batch_size=batch_size)

    # Sanity-check the replaced head with a dummy input before spending time on training.
    # This is only a shape check, so run it under no_grad to avoid building an autograd
    # graph through the whole VGG16 network.
    sample_input = torch.randn(1, 3, 224, 224).to(device)
    with torch.no_grad():
        sample_output = model(sample_input)
    print("Output shape:", sample_output.shape)
    assert sample_output.shape == (1, num_classes), (
        f"unexpected output shape {tuple(sample_output.shape)}, expected (1, {num_classes})"
    )

    # Train and validate the model
    td_run_model2.train(num_epochs, device, model, criterion, optimizer, train_loader, validation_loader, num_classes)

    # Test the model
    td_run_model2.test(device, model, test_loader, num_classes)


In [None]:
# Rebuild the architecture used during training (same num_classes) so the checkpoint's keys and
# head shape match. Note: this rebinds `model`, replacing any model trained in an earlier cell.
model = VGG16MultiClassClassifier(num_classes=num_classes).to(device)

# Checkpoint to evaluate. map_location=device remaps tensors that were saved on another device
# (e.g. during a CUDA training run) onto the device available here, so the same line works both
# locally on CPU and on a GPU machine without editing.
checkpoint_path = 'weights/43.pth'
# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources
# (consider weights_only=True on PyTorch versions that support it).
state_dict = torch.load(checkpoint_path, map_location=device)

# Load the state dict into the model
model.load_state_dict(state_dict)

model.eval()  # evaluation mode, e.g. disables the dropout layers in VGG16's classifier

# Load the test data using the function in td_load_data.py
_, _, test_loader, classes = td_load_data.create_data(batch_size=batch_size)

# Evaluate the model on the test set
td_run_model2.test_model_on_test_data(device, model, test_loader)

###### 