In [None]:
# Standard library imports
import os
import sys

# Third-party imports
import numpy as np  # NumPy for numerical operations
from sklearn.metrics import average_precision_score  # Calculate average precision score
import torch
import torch.nn as nn
import torch.nn.functional as F  # PyTorch functions
import tqdm  # Progress bar library
import torchvision.transforms as transforms  # Image transformation methods
from torchvision.transforms.functional import to_pil_image  # Convert tensor to PIL image
from torchvision.utils import make_grid  # Create grid of images

# Insert custom paths into system path
sys.path.insert(0, '/root/autodl-tmp/SL-KD')
sys.path.insert(0, '/root/autodl-tmp/SL-KD/models/stylegan2')

# Local application imports
from models.dataset import dataset_dict  # Dataset dictionary containing different datasets
from models.decoder import StyleGANDecoder  # StyleGAN decoder for image generation
from models.e4e.psp_encoders import Encoder4Editing  # Encoder model for generating w
from models.ops import load_network, age2group  # Network loading and age-group conversion
from models.ops.loggerx import LoggerX  # Logging tool
from models.modules import Classifier  # Classifier for attributes and age classification

# Make GPU 0 the current CUDA device, so bare `.cuda()` calls below target it.
torch.cuda.set_device(0)

# Turn autograd off globally for the whole notebook. The training cells
# re-enable it locally with `torch.autograd.set_grad_enabled(True)` around
# their forward/backward passes; any new training code must do the same.
torch.autograd.set_grad_enabled(False)

# Auto-reload edited project modules without restarting the kernel
%load_ext autoreload
%autoreload 2

# Print tensors with 3 decimals and without scientific notation
torch.set_printoptions(precision=3, sci_mode=False)

# Render matplotlib figures inline in the notebook
%matplotlib inline

In [None]:
# Configuration
img_size = 256  # Target size handed to Resize and to the dataset classes
bs = 64  # Batch size used by every loader built in this cell

# Channel-wise (x - 0.5) / 0.5 normalisation, applied in place
normalize = transforms.Normalize(
    mean=[0.5, 0.5, 0.5],
    std=[0.5, 0.5, 0.5],
    inplace=True
)


def _resize_to_normalized_tensor():
    """Return fresh Resize -> ToTensor steps followed by the shared normalize."""
    return [transforms.Resize(img_size), transforms.ToTensor(), normalize]


def _make_loader(dataset, shuffle=False, drop_last=False):
    """Wrap `dataset` in a DataLoader with the notebook-wide batch size and 8 workers."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=bs,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=8,
    )


# Training pipeline adds a random horizontal flip in front of the shared steps
train_transform = transforms.Compose(
    [transforms.RandomHorizontalFlip()] + _resize_to_normalized_tensor()
)

# Evaluation pipeline is deterministic
test_transform = transforms.Compose(_resize_to_normalized_tensor())

# Dataset paths
data_root_celeba = '/root/autodl-tmp/DATASET/celeba/img_align_celeba_ffhq'
data_root_ffhq = '/root/autodl-tmp/DATASET/FFHQ/images256x256'

# CelebA dataset: ordered test loader, shuffled training loader without the ragged last batch
dataset_type_celeba = dataset_dict['CelebA']
test_celeba_loader = _make_loader(
    dataset_type_celeba(data_root_celeba, img_size=img_size, split='test', transform=test_transform)
)
train_celeba_loader = _make_loader(
    dataset_type_celeba(data_root_celeba, img_size=img_size, split='train', transform=train_transform),
    shuffle=True,
    drop_last=True,
)

# FFHQ dataset
dataset_type_ffhq = dataset_dict['FFHQAge']
train_ffhq_loader = _make_loader(
    dataset_type_ffhq(data_root_ffhq, img_size=img_size, split='train', transform=train_transform),
    shuffle=True,
    drop_last=True,
)
test_ffhq_loader = _make_loader(
    dataset_type_ffhq(data_root_ffhq, img_size=img_size, split='test', transform=test_transform)
)

# FFHQ training split seen through the deterministic test transform, in
# dataset order (no shuffling, no dropped batch)
mem_ffhq_loader = _make_loader(
    dataset_type_ffhq(data_root_ffhq, img_size=img_size, split='train', transform=test_transform)
)

# Training the Classifier

In [None]:
# Focal Loss definition
class FocalLoss(nn.Module):
    """Binary focal loss on raw logits with optional per-element weights.

    For every element the loss is ``alpha * (1 - pt) ** gamma * BCE`` where
    ``BCE`` is the element-wise binary cross-entropy with logits and
    ``pt = exp(-BCE)``.

    Args:
        alpha: Constant multiplier applied to every element of the loss.
        gamma: Focusing exponent applied to ``(1 - pt)``.
        reduction: ``'mean'`` or ``'sum'`` reduce the (weighted) loss to a
            scalar; any other value returns the unreduced element-wise loss.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets, weights=None):
        """Compute the focal loss.

        Args:
            inputs: Raw logits.
            targets: Float targets with the same shape as ``inputs``.
            weights: Optional tensor broadcastable to ``inputs`` that is
                multiplied into the element-wise loss before reduction.
                ``None`` (the default) leaves the loss unweighted.

        Returns:
            A scalar tensor for ``'mean'``/``'sum'`` reductions, otherwise the
            element-wise loss with the shape of ``inputs``.
        """
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # Recover pt from the BCE value instead of from sigmoid outputs, so no
        # log of a zero probability is ever taken (prevents NaNs).
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss

        # Apply sample weights when the caller supplies them
        if weights is not None:
            F_loss = F_loss * weights

        if self.reduction == 'mean':
            return F_loss.mean()
        if self.reduction == 'sum':
            return F_loss.sum()
        return F_loss

In [None]:
# One target positive rate per attribute; every attribute aims for a balanced 0.5
num_attributes = 40
P_T = torch.full((num_attributes,), 0.5)

In [None]:
def compute_weights(labels, P_T):
    """Build per-sample, per-attribute loss weights that rebalance a batch.

    For each attribute column ``a`` the batch positive rate ``P_B`` is
    compared with the target rate ``P_T[a]``:

    * ``P_B > P_T[a]`` (positives over-represented): up to
      ``int(P_T[a] * batch_size)`` randomly chosen positives get weight 1,
      the remaining positives keep weight 0, and every negative gets
      ``P_T[a] / (1 - P_B)``.
    * otherwise: up to ``int((1 - P_T[a]) * batch_size)`` randomly chosen
      negatives get weight 1, the remaining negatives keep weight 0, and
      every positive gets ``P_T[a] / P_B``.

    Args:
        labels: 2-D tensor ``(batch_size, num_attributes)`` whose entries are
            compared against 0 and 1. Integer and bool tensors are accepted
            and treated as floats.
        P_T: Indexable collection of target positive rates, one per attribute.

    Returns:
        Floating-point tensor of weights with the same shape and device as
        ``labels``. The subsampling is random (``torch.randperm``).
    """
    batch_size, num_attributes = labels.shape
    # `mean()` is undefined for integer/bool tensors and integer weights would
    # truncate the fractional re-weighting factors, so work in floating point.
    if not labels.is_floating_point():
        labels = labels.float()
    weights = torch.zeros_like(labels)

    for a in range(num_attributes):
        column = labels[:, a]
        P_B_a = column.mean().item()

        if P_B_a > P_T[a]:  # Over-represented class
            positive_indices = (column == 1).nonzero(as_tuple=True)[0]
            num_to_keep = int(P_T[a] * batch_size)
            keep_indices = positive_indices[torch.randperm(len(positive_indices))[:num_to_keep]]
            weights[keep_indices, a] = 1
            # When the batch has no negatives the mask is empty, so the
            # infinite ratio is never written anywhere.
            weights[column == 0, a] = P_T[a] / (1 - P_B_a)

        else:  # Under-represented class
            negative_indices = (column == 0).nonzero(as_tuple=True)[0]
            num_to_keep = int((1 - P_T[a]) * batch_size)
            keep_indices = negative_indices[torch.randperm(len(negative_indices))[:num_to_keep]]
            weights[keep_indices, a] = 1
            # Same reasoning: with no positives the mask below selects nothing.
            weights[column == 1, a] = P_T[a] / P_B_a

    return weights

In [None]:
# Binary Classifier
backbone = 'r50'  # `backbone = 'r34'` can be used for ResNet-34
image_classifier = Classifier(backbone).cuda()

# Logger settings
logger = LoggerX(save_root=None, print_freq=100, enable_wandb=False)

# Optimizer settings with a learning rate of 0.0001 and no weight decay
optimizer = torch.optim.Adam(image_classifier.parameters(), lr=1e-4, weight_decay=0.00)

# Training settings
num_epochs = 10
criterion = FocalLoss()
n_iter = 0
best_mAP = 0

for epoch in range(num_epochs):
    print(f'Epoch {epoch + 1}/{num_epochs}:')

    for images, labels in train_celeba_loader:
        # Move images and labels to GPU
        images, labels = images.cuda(), labels.cuda()

        # Back to training mode (the periodic evaluation below switches to eval)
        image_classifier.train()

        # Autograd is disabled notebook-wide, so enable it for the update step
        with torch.autograd.set_grad_enabled(True):
            # Forward pass
            outs = image_classifier.forward_attr(images)[0]

            # Per-sample, per-attribute weights that rebalance this batch
            weights = compute_weights(labels, P_T)

            # Calculate loss
            loss = criterion(outs, labels, weights)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Log the loss
        logger.msg([loss], n_iter)
        n_iter += 1

        # Evaluate on the CelebA test split every 1000 iterations
        if n_iter % 1000 == 0:
            print('Evaluating model...')

            # Set the classifier to evaluation mode
            image_classifier.eval()

            # Initialize lists to store predictions and labels
            all_preds, all_labels = [], []

            with torch.no_grad():
                # Distinct names keep the training batch (`images`, `labels`)
                # from being overwritten by the evaluation loop.
                for test_images, test_labels in tqdm.tqdm(test_celeba_loader):
                    test_images = test_images.cuda()

                    # Forward pass and apply sigmoid to predictions
                    preds = torch.sigmoid(image_classifier.forward_attr(test_images)[0])

                    # Keep results on the CPU so the whole test set never
                    # accumulates in GPU memory; the metrics run in NumPy anyway.
                    all_preds.append(preds.cpu())
                    all_labels.append(test_labels.cpu())

            all_preds = torch.cat(all_preds, dim=0).numpy()
            all_labels = torch.cat(all_labels, dim=0).numpy()

            # Average precision per attribute column and their mean (mAP);
            # the column count comes from the data rather than a literal.
            average_precisions = [
                average_precision_score(all_labels[:, i], all_preds[:, i])
                for i in range(all_labels.shape[1])
            ]
            mAP = np.mean(average_precisions)

            # Print average precision scores and mAP
            print('Average Precision Scores:', average_precisions)
            print('Mean Average Precision (mAP):', mAP)

            # Checkpoint whenever the score matches or beats the best so far
            if mAP >= best_mAP:
                best_mAP = mAP
                print('Saving model...')
                torch.save(
                    image_classifier.state_dict(),
                    f'/root/autodl-tmp/SL-KD/data/focal_loss_{backbone}_{n_iter}.pth'
                )

print('Training completed.')

In [None]:
# Reload a saved attribute classifier from disk.
# The iteration suffix (19000) is hard-coded: it must match a file that the
# attribute-training cell actually wrote (it only saves at multiples of 1000
# iterations, and only when the mAP matched or beat the best seen so far).
backbone = 'r50'
classifier_checkpoint = f'/root/autodl-tmp/SL-KD/data/focal_loss_{backbone}_19000.pth'

# Build a fresh classifier with the same backbone and place it on the GPU
image_classifier = Classifier(backbone=backbone).cuda()

# Read the checkpoint onto the CPU first, then copy it into the model.
# NOTE(review): `load_network` is a project helper applied to the raw state
# dict before loading; what it changes is not visible here.
state_dict = torch.load(classifier_checkpoint, map_location='cpu')
image_classifier.load_state_dict(load_network(state_dict))

# Redundant safety call: the model was already moved to the GPU above
image_classifier = image_classifier.cuda()

In [None]:
# Age classifier
# Continues training whatever `image_classifier` is currently loaded, using the
# FFHQ age labels and an ordinal binary cross-entropy objective.

# Logger settings
logger = LoggerX(save_root=None, print_freq=100)

# Optimizer settings with a learning rate of 0.0001 and no weight decay
optimizer = torch.optim.Adam(image_classifier.parameters(), lr=1e-4, weight_decay=0.00)

# Training settings
num_epochs = 10
n_iter = 0
num_age_groups = 7  # Single source of truth for the ordinal targets and the per-group report

for epoch in range(num_epochs):
    print(f'Epoch {epoch + 1}/{num_epochs}:')

    for images, labels in train_ffhq_loader:
        # Move images and labels to GPU
        images, labels = images.cuda(), labels.cuda()

        # Back to training mode (the periodic evaluation below switches to eval)
        image_classifier.train()

        # Autograd is disabled notebook-wide, so enable it for the update step
        with torch.autograd.set_grad_enabled(True):
            # Forward pass
            outs = image_classifier.forward_age(images)[0]
            # Convert the labels into ordinal targets with the project helper
            ordinal_labels = age2group(ordinal=True, groups=labels, age_group=num_age_groups).float()
            loss = F.binary_cross_entropy_with_logits(outs, ordinal_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Log the loss
        logger.msg([loss, ], n_iter)
        n_iter += 1

        # Evaluate on the FFHQ test split every 1000 iterations
        if n_iter % 1000 == 0:
            print('Evaluating model...')

            # Set the classifier to evaluation mode
            image_classifier.eval()

            # Initialize lists to store predictions and labels
            all_preds, all_labels = [], []

            # Disable gradient calculation for evaluation
            with torch.no_grad():
                # Distinct names keep the training batch (`images`, `labels`)
                # from being overwritten by the evaluation loop.
                for test_images, test_labels in tqdm.tqdm(test_ffhq_loader):
                    test_images, test_labels = test_images.cuda(), test_labels.cuda()

                    # Second output of the classifier's full forward pass
                    preds = image_classifier(test_images)[1]

                    # NOTE(review): the last column of that output is compared
                    # directly with the group labels below, which is only
                    # meaningful if it holds a predicted group index; confirm
                    # against Classifier.forward.
                    all_preds.append(preds[:, -1])
                    all_labels.append(test_labels)

                # Concatenate all predictions and labels
                all_preds = torch.cat(all_preds, dim=0)
                all_labels = torch.cat(all_labels, dim=0)

                # One boolean mask per age group
                masks = [(all_labels == i) for i in range(num_age_groups)]

                # Per-group accuracy; a group with no test samples yields NaN
                correct = all_preds == all_labels
                accuracies = [correct[mask].float().mean(dim=0) for mask in masks]

                # Stack all accuracies into a single tensor
                accuracies_tensor = torch.stack(accuracies)

                # Print the accuracy tensor
                print(accuracies_tensor)

print('Training completed.')

# Define the save path; the file name encodes the backbone and final iteration count
save_path = os.path.join(
    '/root/autodl-tmp/SL-KD/data',
    f'focal_loss_{backbone}_age_{n_iter}.pth'
)

# Save the model state dictionary
torch.save(image_classifier.state_dict(), save_path)

In [None]:
# Checkpoint files for the three pretrained networks used from here on
stylegan2_checkpoint = '/root/autodl-tmp/SL-KD/data/ffhq.pkl'
e4e_checkpoint = '/root/autodl-tmp/SL-KD/data/e4e_ffhq_encode.pt'
classifier_checkpoint = '/root/autodl-tmp/SL-KD/data/focal_loss_r34_age_8410.pth'

# Resolution requested from the decoder
output_size = 256

# StyleGAN decoder, on the GPU and in eval mode
G = StyleGANDecoder(
    stylegan2_checkpoint,
    start_from_latent_avg=False,
    output_size=output_size
).cuda().eval()

# e4e encoder, on the GPU and in eval mode
encoder = Encoder4Editing(
    num_layers=50,
    mode='ir_se',
    stylegan_size=1024,
    checkpoint_path=e4e_checkpoint
).cuda().eval()

# Attribute/age classifier: build on the GPU, restore the weights from the
# checkpoint (read onto the CPU first), then switch to eval mode.
# NOTE(review): the backbone here ('r34') has to match the backbone the
# checkpoint file was trained with.
backbone = 'r34'
image_classifier = Classifier(backbone=backbone).cuda()
classifier_state = load_network(torch.load(classifier_checkpoint, map_location='cpu'))
image_classifier.load_state_dict(classifier_state)
image_classifier = image_classifier.cuda().eval()

# Generate latent vectors and corresponding labels

In [None]:
# Initialize lists to store all latents and predictions
all_latents = []
all_preds = []

# Set models to evaluation mode
encoder.eval()
image_classifier.eval()

# Run both networks on every batch in a single pass over the loader. The
# loader is unshuffled and uses the deterministic test transform, so this
# gives the same row-aligned latents and predictions as two separate passes
# while reading and decoding each image only once.
with torch.no_grad():
    for images, _ in tqdm.tqdm(mem_ffhq_loader):
        images = images.cuda()

        # Latent code(s) produced by the encoder for this batch
        all_latents.append(encoder(images))

        # Sigmoid of the classifier's first output for the same batch
        preds = image_classifier(images)[0].sigmoid()
        all_preds.append(preds)

# Concatenate per-batch results into single tensors
all_latents = torch.cat(all_latents, dim=0)
all_preds = torch.cat(all_preds, dim=0)

# Save latents and predictions with updated filenames
torch.save(all_latents, '/root/autodl-tmp/SL-KD/data/ffhq_train_latents.pth')
torch.save(all_preds, '/root/autodl-tmp/SL-KD/data/focal_loss_ffhq_train_preds.pth')

In [None]:
from IPython.display import display

# Reload the cached tensors from disk (via the CPU) and move them to the GPU
latents_path = '/root/autodl-tmp/SL-KD/data/ffhq_train_latents.pth'
preds_path = '/root/autodl-tmp/SL-KD/data/focal_loss_ffhq_train_preds.pth'

encoded_latents = torch.load(latents_path, map_location='cpu').cuda()
# The stored latents are shifted by the generator's average latent before use
all_latents = encoded_latents + G.latent_avg

all_preds = torch.load(preds_path, map_location='cpu').cuda()

# Preview the first 16 samples
indices = torch.arange(16)

# Decode, tile into a grid, clamp to [-1, 1] and rescale to [0, 1] for PIL
decoded_batch = G(all_latents[indices])
image_grid = make_grid(decoded_batch).clamp(-1., 1.)
generated_images = to_pil_image(image_grid * 0.5 + 0.5)
display(generated_images)

# Show three prediction columns for the same samples
preview_columns = [15, 20, 31]
display(all_preds[indices[:, None], preview_columns])

In [None]:
from models.modules import FLOW

# Build the flow model sized from the generator's style dimensions and an
# AdamW optimizer over its parameters (lr 1e-4, no weight decay).
# The `.eval()` here is overridden by `realnvp.train()` inside the loop.
realnvp = FLOW(style_dim=G.style_dim, n_styles=G.n_styles, n_layer=10).cuda().eval()
optimizer = torch.optim.AdamW(list(realnvp.parameters()), lr=0.0001, weight_decay=0.00)

# Initialize logger (prints every 10 iterations)
logger = LoggerX(save_root=None, print_freq=10)

# Set training parameters.
# NOTE: this rebinds the notebook-level `bs` (64 in the configuration cell);
# the data loaders built earlier keep the batch size they were created with.
max_iter = 10000
bs = 256

# Training loop: the iteration counter starts at 1
for n_iter in range(1, max_iter + 1):
    # Draw a random batch of cached latents (sampled with replacement)
    indices = torch.randint(0, len(all_latents), (bs,))
    latents = all_latents[indices]

    realnvp.train()

    # Autograd is disabled notebook-wide, so enable it for the update step
    with torch.autograd.set_grad_enabled(True):
        # The model returns four values; the first is used as the training
        # loss, the next two are only logged and the last is ignored.
        loss, logz, log_det_jacobian, _ = realnvp(latents)

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Log the results
    logger.msg([logz, log_det_jacobian, loss], n_iter)