In [None]:
# How to load a saved model and make predictions

In [None]:
### Loading a model

In [None]:
import torch
from src.models import MultiLabelResnet, MultiLabelCNN
from src import data_handling
from src.dataset import CustomImageDataset

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print("Found cuda device")

In [None]:
#Which model to use
#model_name = "MultiLabelResnet"
model_name = "MultiLabelResnet"

#Load model class
if model_name=="MultiLabelResnet":
    model = MultiLabelResnet().to(device)
elif model_name=="MultiLabelCNN":
    model = MultiLabelCNN().to(device)

In [None]:
# Load the saved weights into the model.
# map_location remaps every stored tensor onto the current device, so a
# checkpoint saved on a GPU can still be loaded on a CPU-only machine (and
# vice versa) instead of failing inside torch.load.
model_state_dict = torch.load(
    f"saved_models/{model_name}/model_state.pt", map_location=device
)
model.load_state_dict(model_state_dict)

In [None]:
### Predictions

# Directory holding the image files read by CustomImageDataset.
DATA_DIR = 'data/images'

# Recompute an 80/20 train/test split of the target table. It is rebuilt here,
# so it is not guaranteed to match the split used when the model was trained.
train, test = data_handling.get_target_dfs(train=0.8, test=0.2)

# New data: when there is no ground truth, set the labels to 0 so the same
# dataset class can still be used.
test_dataset = CustomImageDataset(test, DATA_DIR, transform=None)

# Keep the original order (no shuffling) so predictions line up with `test`.
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=50,
    shuffle=False,
)

In [None]:
# Element-wise (per-label) accuracy of the thresholded predictions on the test split.
# NOTE(review): a 0.5 cut-off assumes the model outputs probabilities
# (e.g. a final sigmoid); if it returns raw logits use 0.0 instead. Confirm.
THRESHOLD = 0.5

model.eval()  # Inference mode: disables dropout and other train-only behaviour
n_correct = 0
n_total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        # The model lives on `device`; the batch must be moved there too,
        # otherwise a CUDA run fails with a device-mismatch error.
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        predicted_labels = (outputs > THRESHOLD).int()
        matches = predicted_labels == labels

        # Accumulate raw counts rather than averaging per-batch means, so a
        # smaller final batch is not over-weighted.
        n_correct += matches.sum().item()
        n_total += matches.numel()

if n_total == 0:
    raise ValueError("test_loader yielded no samples; cannot compute accuracy")

# Note: the "test" split is regenerated and may overlap the data the model was
# trained on, so this accuracy can differ from (likely overstate) held-out accuracy.
test_accuracy = n_correct / n_total
print(f"Test Accuracy: {test_accuracy*100:.2f}%")