# Training

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Subset
import random
import os

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Constants
# The model below sizes its linear layer from these two values, so they must
# match the spatial size of the tensors produced by `transform`.
# NOTE(review): presumably height x width of the aligned CelebA crops — confirm.
IMG_X_SIZE = 218
IMG_Y_SIZE = 178
NUM_CLASSES = 10  # For demo
MODEL_CHECKPOINT_PATH = './face_recognition_model.pth'
SEED = 42

# Seed every RNG this cell relies on so that the random subset, the DataLoader
# shuffling order and the model's initial weights are the same on every
# Restart & Run All.
random.seed(SEED)
torch.manual_seed(SEED)

# Transform: grayscale + tensor
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])

# Load CelebA (aligned & cropped)
# NOTE: download=True can fail with a gdown "FileURLRetrievalError" when the
# hosting quota is exhausted (see the captured output of this cell). In that
# case fetch the files manually into `root` as described in the torchvision
# CelebA documentation and re-run.
dataset = datasets.CelebA(
    root='./data',
    split='train',
    target_type='identity',
    download=True,
    transform=transform
)

# Create a subset with 10% of the dataset
subset_size = int(0.1 * len(dataset))  # 10% of dataset
subset_indices = random.sample(range(len(dataset)), subset_size)  # Random sampling (seeded above)
subset_dataset = Subset(dataset, subset_indices)

# DataLoader
dataloader = DataLoader(subset_dataset, batch_size=40, shuffle=True)
example_batch, example_labels = next(iter(dataloader))
example_batch = example_batch.to(device)
example_labels = example_labels.to(device)

# Limit to 10 samples for inversion demo
example_batch = example_batch[:10]
# Identity ids are folded into NUM_CLASSES buckets, the same mapping the
# training loop applies to its targets.
example_labels = example_labels[:10] % NUM_CLASSES

# Model definition
class FaceRecognitionCNN(nn.Module):
    """Minimal CNN classifier for single-channel face images.

    Architecture: Conv2d(1 -> 36, 7x7, no padding) -> ReLU -> Flatten ->
    Linear(-> num_classes). The output is raw logits (no softmax).

    Args:
        num_classes: Number of output logits. Defaults to the notebook-level
            ``NUM_CLASSES`` when omitted.
        img_height: Size of the input's second-to-last dimension. Defaults to
            the notebook-level ``IMG_X_SIZE`` when omitted.
        img_width: Size of the input's last dimension. Defaults to the
            notebook-level ``IMG_Y_SIZE`` when omitted.

    Calling ``FaceRecognitionCNN()`` with no arguments builds exactly the same
    layers (and state-dict keys) as before the arguments were introduced.
    """

    KERNEL_SIZE = 7
    CONV_CHANNELS = 36

    def __init__(self, num_classes=None, img_height=None, img_width=None):
        super().__init__()
        # Resolve defaults at construction time so the globals are only
        # required when the caller does not pass explicit sizes.
        if num_classes is None:
            num_classes = NUM_CLASSES
        if img_height is None:
            img_height = IMG_X_SIZE
        if img_width is None:
            img_width = IMG_Y_SIZE

        # An unpadded, stride-1 convolution shrinks each spatial dimension by
        # (kernel_size - 1), i.e. by 6 for the 7x7 kernel.
        shrink = self.KERNEL_SIZE - 1
        self.conv1 = nn.Conv2d(1, self.CONV_CHANNELS, kernel_size=self.KERNEL_SIZE)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(
            self.CONV_CHANNELS * (img_height - shrink) * (img_width - shrink),
            num_classes,
        )

    def forward(self, x):
        """Return class logits for a batch ``x`` of shape (N, 1, img_height, img_width)."""
        x = F.relu(self.conv1(x))
        x = self.flatten(x)
        x = self.fc1(x)
        return x

model = FaceRecognitionCNN().to(device)

# Training setup
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Early stopping setup
best_loss = float('inf')
best_accuracy = 0.0
patience = 5
epochs_without_improvement = 0

# Training loop with early stopping
EPOCHS = 20
print("Training model...")

for epoch in range(1, EPOCHS + 1):
    model.train()
    # Accumulate per-sample totals (not per-batch means) so that a smaller
    # final batch does not get the same weight as a full one.
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        # Fold identity ids into NUM_CLASSES buckets (demo-sized label space).
        y = y % NUM_CLASSES

        optimizer.zero_grad()
        output = model(X)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        batch_size = y.size(0)
        preds = output.argmax(dim=1)

        # criterion returns the batch mean, so scale back to a batch sum.
        total_loss += loss.item() * batch_size
        total_correct += (preds == y).sum().item()
        total_samples += batch_size

    avg_loss = total_loss / total_samples
    avg_acc = total_correct / total_samples
    print(f"Epoch {epoch}/{EPOCHS} — Loss: {avg_loss:.4f} — Accuracy: {avg_acc:.4f}")

    # Track the best accuracy for reporting, but drive early stopping from the
    # loss alone: requiring accuracy to improve as well would leave best_loss
    # stale and stop training once accuracy saturates even if loss still falls.
    best_accuracy = max(best_accuracy, avg_acc)
    if avg_loss < best_loss:
        best_loss = avg_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"Early stopping triggered after {epoch} epochs.")
            break

# Save model checkpoint
print("Saving model checkpoint...")
torch.save(model.state_dict(), MODEL_CHECKPOINT_PATH)


Using device: cuda


FileURLRetrievalError: Failed to retrieve file url:

	Too many users have viewed or downloaded this file recently. Please
	try accessing the file again later. If the file you are trying to
	access is particularly large or is shared with many people, it may
	take up to 24 hours to be able to view or download the file. If you
	still can't access a file after 24 hours, contact your domain
	administrator.

You may still be able to access the file from the browser:

	https://drive.google.com/uc?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM

but Gdown can't. Please check connections and permissions.

# Reconstruction

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os

# Constants
IMG_X_SIZE = 218
IMG_Y_SIZE = 178
# Must equal the value used when the checkpoint was trained: the final linear
# layer has NUM_CLASSES outputs, so any other value makes load_state_dict fail
# with a size mismatch, and the demo labels were already reduced modulo 10.
NUM_CLASSES = 10
MODEL_CHECKPOINT_PATH = './face_recognition_model.pth'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Model definition
class FaceRecognitionCNN(nn.Module):
    """Minimal CNN classifier for single-channel face images.

    Architecture: Conv2d(1 -> 36, 7x7, no padding) -> ReLU -> Flatten ->
    Linear(-> num_classes). The output is raw logits (no softmax).

    The layer names (``conv1``, ``fc1``) define the state-dict keys, so they
    must stay in sync with the definition used when the checkpoint was saved.

    Args:
        num_classes: Number of output logits. Defaults to the notebook-level
            ``NUM_CLASSES`` when omitted.
        img_height: Size of the input's second-to-last dimension. Defaults to
            the notebook-level ``IMG_X_SIZE`` when omitted.
        img_width: Size of the input's last dimension. Defaults to the
            notebook-level ``IMG_Y_SIZE`` when omitted.
    """

    KERNEL_SIZE = 7
    CONV_CHANNELS = 36

    def __init__(self, num_classes=None, img_height=None, img_width=None):
        super().__init__()
        # Resolve defaults at construction time so the globals are only
        # required when the caller does not pass explicit sizes.
        if num_classes is None:
            num_classes = NUM_CLASSES
        if img_height is None:
            img_height = IMG_X_SIZE
        if img_width is None:
            img_width = IMG_Y_SIZE

        # An unpadded, stride-1 convolution shrinks each spatial dimension by
        # (kernel_size - 1), i.e. by 6 for the 7x7 kernel.
        shrink = self.KERNEL_SIZE - 1
        self.conv1 = nn.Conv2d(1, self.CONV_CHANNELS, kernel_size=self.KERNEL_SIZE)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(
            self.CONV_CHANNELS * (img_height - shrink) * (img_width - shrink),
            num_classes,
        )

    def forward(self, x):
        """Return class logits for a batch ``x`` of shape (N, 1, img_height, img_width)."""
        x = F.relu(self.conv1(x))
        x = self.flatten(x)
        x = self.fc1(x)
        return x

# Load model
model = FaceRecognitionCNN().to(device)
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a CUDA machine still loads on a CPU-only one.
model.load_state_dict(torch.load(MODEL_CHECKPOINT_PATH, map_location=device))
model.eval()

# Inversion function (corrected leaf tensor handling)
def sarahs_inversion(model, label_idx, steps=100, lr=0.01, noise_strength=0.001, img_shape=None):
    """Reconstruct an input image that the model classifies as ``label_idx``.

    Starts from a small random image and runs plain gradient descent on the
    pixels to minimise the cross-entropy between the model's output and the
    target label, clamping pixel values to [0, 1] after every step.

    Args:
        model: Classifier taking a (1, 1, H, W) tensor and returning logits.
        label_idx: Integer index of the class to reconstruct.
        steps: Number of gradient-descent updates.
        lr: Step size applied to the pixel gradient.
        noise_strength: Accepted for interface compatibility; currently unused
            because the noise-injection step is disabled.
        img_shape: Optional ``(height, width)`` of the image to optimise.
            Defaults to the notebook-level ``(IMG_X_SIZE, IMG_Y_SIZE)``.

    Returns:
        A detached tensor of shape (1, 1, height, width) with values in [0, 1].
    """
    model.eval()

    # Run on whatever device the model lives on instead of relying on a
    # notebook-level `device` variable that may be stale or missing.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    if img_shape is None:
        img_shape = (IMG_X_SIZE, IMG_Y_SIZE)
    height, width = img_shape

    label = torch.tensor([label_idx], device=target_device)
    loss_fn = nn.CrossEntropyLoss()

    # Build the start image first, then mark it as requiring grad, so that it
    # is a leaf tensor that can be updated in place under no_grad below.
    img_data = torch.randn((1, 1, height, width), device=target_device) * 0.1
    img_data = img_data.clamp(0, 1)
    img_data.requires_grad_()

    for step in range(steps):
        output = model(img_data)
        loss = loss_fn(output, label)

        # Differentiate w.r.t. the image only. loss.backward() would also
        # accumulate (unused) gradients into every model parameter each step.
        (pixel_grad,) = torch.autograd.grad(loss, img_data)

        with torch.no_grad():
            img_data -= lr * pixel_grad
            img_data.clamp_(0, 1)

        if step % 10 == 0:
            print(f"Step {step}/{steps} — Loss: {loss.item():.4f}")

    return img_data.detach()

# Side-by-side plot
def plot_comparison(original, reconstructed, index):
    """Show the original and the reconstructed image next to each other.

    Both tensors are squeezed and moved to the CPU before being drawn in
    grayscale; ``index`` is only used in the figure's super-title.
    """
    fig, axs = plt.subplots(1, 2, figsize=(6, 3))

    # Left panel: original, right panel: reconstruction.
    panels = (("Original", original), ("Reconstructed", reconstructed))
    for ax, (caption, tensor) in zip(axs, panels):
        ax.imshow(tensor.squeeze().cpu(), cmap='gray')
        ax.set_title(caption)
        ax.axis("off")

    plt.suptitle(f"Example {index}")
    plt.tight_layout()
    plt.show()

# Reconstruction
# For each of the 10 demo examples: invert the model for that example's label
# and display the result beside the real image.
print("Reconstructing images...")
for idx in range(10):
    target_img = example_batch[idx:idx+1]  # slicing keeps the leading batch dimension
    label = example_labels[idx].item()
    recon_img = sarahs_inversion(
        model,
        label_idx=label,
        steps=100,
        lr=0.01,
        noise_strength=0.001,
    )
    plot_comparison(target_img, recon_img, idx)


Using device: cuda
