In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.datasets import CelebA
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
import numpy as np

# Check for GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [8]:
# Define transformations for the dataset
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Resize images to 128x128
    transforms.ToTensor(),          # Convert images to tensors
])

# Load the CelebA dataset
train_dataset = CelebA(root='data', split='train', target_type='attr', download=False, transform=transform)
val_dataset = CelebA(root='data', split='valid', target_type='attr', download=False, transform=transform)
test_dataset = CelebA(root='data', split='test', target_type='attr', download=False, transform=transform)

In [4]:
# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Define the model
class MultiLabelResNet(nn.Module):
    def __init__(self):
        super(MultiLabelResNet, self).__init__()
        self.model = models.resnet18(pretrained=True)
        self.model.fc = nn.Sequential(
            nn.Linear(self.model.fc.in_features, 40),
            nn.Sigmoid()  # Sigmoid activation for multi-label classification
        )

    def forward(self, x):
        return self.model(x)

model = MultiLabelResNet().to(device)

# Define loss function and optimizer
criterion = nn.BCELoss()  # Binary Cross Entropy Loss
optimizer = optim.Adam(model.parameters(), lr=0.001)



In [5]:
# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    print(f"Epoch {epoch+1}\n-------------------------------")
    
    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device).float()  # For BCELoss

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        if (batch_idx + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

    avg_train_loss = running_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}]-------- Average Training Loss: {avg_train_loss:.4f}')

    # Validation loop
    model.eval()
    val_loss = 0.0
    all_outputs = []
    all_labels = []

    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device).float()

            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            all_outputs.append(outputs.cpu())
            all_labels.append(labels.cpu())

    avg_val_loss = val_loss / len(val_loader)

    # Compute accuracy
    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)
    
    preds = (all_outputs >= 0.5).float()  # Threshold at 0.5
    correct = (preds == all_labels).float().mean().item()  # Mean accuracy over all attributes

    print(f'Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {correct:.4f}\n')

Epoch 1
-------------------------------
Epoch [1/10], Batch [100/2544], Loss: 0.2475
Epoch [1/10], Batch [200/2544], Loss: 0.2497
Epoch [1/10], Batch [300/2544], Loss: 0.2252
Epoch [1/10], Batch [400/2544], Loss: 0.2395
Epoch [1/10], Batch [500/2544], Loss: 0.2283
Epoch [1/10], Batch [600/2544], Loss: 0.2414
Epoch [1/10], Batch [700/2544], Loss: 0.2217
Epoch [1/10], Batch [800/2544], Loss: 0.2240
Epoch [1/10], Batch [900/2544], Loss: 0.2064
Epoch [1/10], Batch [1000/2544], Loss: 0.2075
Epoch [1/10], Batch [1100/2544], Loss: 0.1963
Epoch [1/10], Batch [1200/2544], Loss: 0.2152
Epoch [1/10], Batch [1300/2544], Loss: 0.2081
Epoch [1/10], Batch [1400/2544], Loss: 0.2310
Epoch [1/10], Batch [1500/2544], Loss: 0.2085
Epoch [1/10], Batch [1600/2544], Loss: 0.2213
Epoch [1/10], Batch [1700/2544], Loss: 0.2156
Epoch [1/10], Batch [1800/2544], Loss: 0.2195
Epoch [1/10], Batch [1900/2544], Loss: 0.2139
Epoch [1/10], Batch [2000/2544], Loss: 0.2033
Epoch [1/10], Batch [2100/2544], Loss: 0.2156
Epo

In [7]:
# Evaluation function
def evaluate(model, data_loader):
    model.eval()
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for images, labels in data_loader:
            images = images.to(device)
            labels = labels.to(device).float()
            outputs = model(images)
            preds = (outputs > 0.5).float()  # Apply threshold
            all_labels.append(labels.cpu().numpy())
            all_preds.append(preds.cpu().numpy())

    all_labels = np.vstack(all_labels)
    all_preds = np.vstack(all_preds)
    accuracies = []
    for i in range(40):
        acc = accuracy_score(all_labels[:, i], all_preds[:, i])
        accuracies.append(acc)
        print(f'Attribute {i+1} Accuracy: {acc:.4f}')
    print(f'Average Accuracy: {np.mean(accuracies):.4f}')

# Evaluate on test set

print("Test Set Evaluation:")
evaluate(model, test_loader)

Test Set Evaluation:
Attribute 1 Accuracy: 0.9378
Attribute 2 Accuracy: 0.8232
Attribute 3 Accuracy: 0.8112
Attribute 4 Accuracy: 0.8104
Attribute 5 Accuracy: 0.9885
Attribute 6 Accuracy: 0.9537
Attribute 7 Accuracy: 0.6895
Attribute 8 Accuracy: 0.8077
Attribute 9 Accuracy: 0.8868
Attribute 10 Accuracy: 0.9535
Attribute 11 Accuracy: 0.9584
Attribute 12 Accuracy: 0.8577
Attribute 13 Accuracy: 0.8992
Attribute 14 Accuracy: 0.9521
Attribute 15 Accuracy: 0.9602
Attribute 16 Accuracy: 0.9952
Attribute 17 Accuracy: 0.9720
Attribute 18 Accuracy: 0.9814
Attribute 19 Accuracy: 0.9036
Attribute 20 Accuracy: 0.8507
Attribute 21 Accuracy: 0.9802
Attribute 22 Accuracy: 0.9314
Attribute 23 Accuracy: 0.9668
Attribute 24 Accuracy: 0.8609
Attribute 25 Accuracy: 0.9553
Attribute 26 Accuracy: 0.7449
Attribute 27 Accuracy: 0.9684
Attribute 28 Accuracy: 0.7481
Attribute 29 Accuracy: 0.9301
Attribute 30 Accuracy: 0.9456
Attribute 31 Accuracy: 0.9754
Attribute 32 Accuracy: 0.9189
Attribute 33 Accuracy: 0.818