## Helper Code to Load `CelebA_resnet18.pth`

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.datasets import CelebA
from torchvision import transforms
from torch.utils.data import DataLoader

### Run this cell to load the model weights from `CelebA_resnet18.pth` in the current working directory

In [5]:
# define model architecture
class MultiLabelResNet(nn.Module):
    """ResNet-18 backbone with a 40-output sigmoid head for multi-label prediction.

    torchvision's ResNet-18 is built with ``pretrained=False`` and its final
    fully connected layer is swapped for ``Linear(in_features, 40)`` followed by
    ``Sigmoid``. As a result, ``forward`` yields one independent score in (0, 1)
    per output. Trained weights are meant to be supplied afterwards through
    ``load_state_dict``.

    The submodule layout (``self.model.fc[0]`` is the Linear layer and
    ``self.model.fc[1]`` is the Sigmoid) fixes the state-dict key names.
    Changing it would stop existing checkpoints from loading.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(pretrained=False)
        num_features = backbone.fc.in_features
        # Replace the ImageNet classifier with a 40-way sigmoid head.
        backbone.fc = nn.Sequential(nn.Linear(num_features, 40), nn.Sigmoid())
        self.model = backbone

    def forward(self, x):
        """Return the per-output sigmoid scores for the input batch ``x``."""
        return self.model(x)

# Path of the saved weights, relative to the current working directory.
CHECKPOINT_PATH = "CelebA_resnet18.pth"

# instantiate the architecture; its weights are replaced by the checkpoint below
model = MultiLabelResNet()

# Load only the saved WEIGHTS (a state dict).
# - weights_only=True restricts unpickling to tensors and plain containers, so a
#   tampered .pth file cannot execute arbitrary code when loaded (torch >= 1.13).
# - map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# load_state_dict is strict by default, so missing/unexpected keys raise.
state_dict = torch.load(CHECKPOINT_PATH, map_location=torch.device("cpu"), weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # evaluation mode: BatchNorm layers use their stored running statistics

print("model loaded successfully")



model loaded successfully


### The cells below evaluate the model on the CelebA test split

**This is only a sanity check that the model loaded correctly and runs.**

In [6]:
# Preprocessing: resize every image to 128x128, then convert it to a tensor.
# No normalization is applied; presumably this matches the preprocessing used
# when the weights were trained (TODO confirm).
IMAGE_SIZE = (128, 128)
transform = transforms.Compose([transforms.Resize(IMAGE_SIZE), transforms.ToTensor()])

# Set DOWNLOAD_CELEBA = True the first time you run this notebook so CelebA is
# fetched into DATA_ROOT. Set it back to False for every run after that.
DATA_ROOT = 'data'
DOWNLOAD_CELEBA = False
BATCH_SIZE = 64


def _load_celeba_split(split):
    """Return the CelebA ``split`` with attribute targets and the shared ``transform``."""
    return CelebA(root=DATA_ROOT, split=split, target_type='attr',
                  download=DOWNLOAD_CELEBA, transform=transform)


train_dataset = _load_celeba_split('train')
val_dataset = _load_celeba_split('valid')
test_dataset = _load_celeba_split('test')

# Only the training loader is shuffled; validation/test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [9]:
# Evaluate on the test split:
# 1. Collect the sigmoid scores and ground-truth labels for every test image.
# 2. Threshold the scores.
# 3. Report accuracy averaged over all (image, attribute) pairs.
THRESHOLD = 0.5  # a score >= THRESHOLD counts as a positive prediction

model.eval()
# Send inputs to whichever device the model's parameters live on. The model may
# be on CPU or GPU, and the cell works in both cases.
device = next(model.parameters()).device

all_outputs = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.to(device))
        # Move each batch back to CPU right away so accelerator memory does not
        # grow with the size of the test set.
        all_outputs.append(outputs.cpu())
        all_labels.append(labels.float())

all_outputs = torch.cat(all_outputs)
all_labels = torch.cat(all_labels)

preds = (all_outputs >= THRESHOLD).float()

# Fraction of (image, attribute) entries predicted correctly.
mean_accuracy = (preds == all_labels).float().mean().item()
print(f"test set mean accuracy: {mean_accuracy:.4f}")

test set mean accuracy: 0.8988


**If the output above is 0.8988, the model loaded correctly and is ready to use.**