In [None]:
import fiftyone as fo
import fiftyone.brain as fob
import fiftyone.zoo as foz
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from facenet_pytorch import MTCNN
from PIL import Image
from tqdm.notebook import tqdm, tnrange

In [None]:
# Square size handed to the Resize step of the preprocessing pipeline.
DIM = (224, 224)

# Prefer CUDA when PyTorch reports it as available; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

In [None]:
class CustomNormalize:
    """Callable transform: convert to a tensor, then apply ``(t * 255 - 128) / 128``.

    The input is first passed through ``transforms.ToTensor()``. For inputs that
    ``ToTensor`` scales to [0, 1], the result lies in roughly [-1, 1).
    """

    def __call__(self, img):
        to_tensor = transforms.ToTensor()
        # Undo the unit scaling so the 128 offset/divisor act on 0-255 style values.
        scaled = to_tensor(img) * 255.0
        return (scaled - 128.0) / 128.0

# Inference-time preprocessing: resize to DIM with Lanczos resampling, then
# scale with CustomNormalize. The pipeline is deliberately deterministic -- a
# random flip (or any other training-style augmentation) would make the
# embedding of the same image differ from one run to the next.
preprocess = transforms.Compose([
    transforms.Resize(DIM, interpolation=transforms.InterpolationMode.LANCZOS),
    CustomNormalize(),
])

In [None]:
class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages added to a skip path.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by both convolutions.
        stride: stride of the first convolution and, when a projection is
            required, of the 1x1 shortcut convolution.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The skip path is an empty Sequential (identity) unless the main path
        # changes the spatial size or channel count; then x is projected with a
        # 1x1 convolution followed by batch-norm.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return self.relu(out)

class ResNet18(nn.Module):
    """ResNet-18 style network: conv stem, four two-block stages, pooled linear head.

    Args:
        num_classes: output size of the final linear layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Channel count of the tensor entering the next block; _make_layer
        # reads it and advances it as blocks are created.
        self.in_channels = 64

        # Stem: 7x7 stride-2 convolution, batch-norm, ReLU, 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages of two BasicBlocks each; stages 2-4 start with stride 2.
        self.layer1 = self._make_layer(BasicBlock, 64, 2, stride=1)
        self.layer2 = self._make_layer(BasicBlock, 128, 2, stride=2)
        self.layer3 = self._make_layer(BasicBlock, 256, 2, stride=2)
        self.layer4 = self._make_layer(BasicBlock, 512, 2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest use stride 1."""
        per_block_strides = [stride]
        per_block_strides.extend([1] * (num_blocks - 1))

        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_channels, out_channels, block_stride))
            # Every later block in the network sees this stage's output width.
            self.in_channels = out_channels
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem, the four stages, global average pooling and the linear head."""
        out = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)

        out = self.avgpool(out)
        # Flatten everything except the batch axis before the linear layer.
        out = out.view(out.size(0), -1)
        return self.fc(out)

In [None]:
def load_model(model, path, device):
    """Load the state dict stored at ``path`` into ``model`` and return the model.

    Keys beginning with ``'module.'`` (presumably left by a parallel wrapper
    at save time) have that prefix removed before loading. The model is then
    switched to evaluation mode and a confirmation line is printed.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        path: checkpoint file handed to ``torch.load``.
        device: ``map_location`` used while reading the checkpoint.

    Returns:
        The same ``model`` object, in eval mode.
    """
    from collections import OrderedDict

    prefix = 'module.'
    checkpoint = torch.load(path, map_location=device)

    # Rebuild the mapping in its original order, dropping the prefix where present.
    cleaned = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, tensor)
        for key, tensor in checkpoint.items()
    )

    model.load_state_dict(cleaned)
    model.eval()
    print(f"Model loaded from {path}")
    return model

In [None]:
# Release cached CUDA memory, then build the 128-output ResNet18 on the
# selected device and restore its weights from the checkpoint file.
torch.cuda.empty_cache()
embedding_model = ResNet18(128).to(device)
embedding_model = load_model(
    embedding_model,
    '224x224_ResNet18_AMSoftmax_validation_20250409-172656.pt',
    device,
)

In [None]:
# MTCNN face detector instance, created on the same device as the model.
mtcnn = MTCNN(
    image_size=112,
    margin=0,
    keep_all=False,
    device=device,
)

# Function to perform face detection and crop face from image
def detect_and_crop_face(image, mtcnn, target_size=(224, 224)):
    """Detect a face, crop it to its bounding box and resize it to ``target_size``.

    The crop is resized straight to ``target_size``; no squaring step is
    applied, so a non-square box is stretched to fit.

    Args:
        image: a numpy array, a torch tensor, a file path string, or a PIL image.
        mtcnn: detector whose ``detect(image)`` returns ``(boxes, probs)``.
        target_size: size tuple passed to ``Image.resize`` for the face crop.

    Returns:
        The resized face crop as an RGB PIL image. When no usable face box is
        found, the RGB-converted input is returned at its original size.

    Raises:
        ValueError: if ``image`` is not one of the supported input types.
    """
    # Normalise every supported input type to a PIL image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif isinstance(image, torch.Tensor):
        image = transforms.ToPILImage()(image)
    elif isinstance(image, str):
        image = Image.open(image)
    elif not isinstance(image, Image.Image):
        raise ValueError(
            "Input image must be a numpy array, torch tensor, file path, or PIL Image."
        )

    if image.mode != 'RGB':
        image = image.convert('RGB')

    boxes, _ = mtcnn.detect(image)
    if boxes is None or len(boxes) == 0:
        # Best-effort fallback: hand the whole image back to the caller.
        print("No face detected.")
        return image

    # Detector boxes can reach past the image border. Clamp them so the crop
    # covers only real pixels, and bail out if nothing is left after clamping.
    width, height = image.size
    x1, y1, x2, y2 = boxes[0].astype(int).tolist()
    x1, y1 = max(x1, 0), max(y1, 0)
    x2, y2 = min(x2, width), min(y2, height)
    if x2 <= x1 or y2 <= y1:
        print("No face detected.")
        return image

    face = image.crop((x1, y1, x2, y2))
    return face.resize(target_size, Image.Resampling.LANCZOS)


In [None]:
# get images from test dataset by iterating over it
def get_image_from_path(image_path):
    """Open the image at ``image_path`` and return it as an RGB PIL image."""
    # convert() returns a separate image, so the source file can be closed as
    # soon as the conversion is done instead of lingering until garbage collection.
    with Image.open(image_path) as img:
        return img.convert("RGB")


def get_test_images(dataset):
    """Collect the ``"filepath"`` of every sample in ``dataset``.

    Despite the name, no image is loaded here: the returned list holds the
    path values only.

    Args:
        dataset: sized iterable of samples indexable by ``"filepath"``.

    Returns:
        A list of file paths, in iteration order, with a progress bar shown
        while iterating.
    """
    return [
        sample["filepath"]
        for sample in tqdm(dataset, desc="Samples", total=len(dataset))
    ]
def get_embeddings_from_images(images, model):
    """Compute one embedding per image with ``model``.

    Each image is passed through ``detect_and_crop_face`` (using the
    module-level ``mtcnn``) and ``preprocess``, moved to the module-level
    ``device`` as a batch of one, and fed to ``model`` with gradients disabled.

    Args:
        images: sized iterable of inputs accepted by ``detect_and_crop_face``.
        model: module called on a single-image batch; it is put in eval mode.

    Returns:
        A list of CPU tensors, one per input, with the batch axis removed.
    """
    embeddings = []
    model.eval()
    with torch.no_grad():
        for image in tqdm(images, desc="Images", total=len(images)):
            face = detect_and_crop_face(image, mtcnn, target_size=(224, 224))
            batch = preprocess(face).unsqueeze(0).to(device)
            # squeeze(0) removes only the batch axis added above; a bare
            # squeeze() would also collapse any other size-1 axis.
            embeddings.append(model(batch).squeeze(0).cpu())

    return embeddings

In [None]:
# Load the "test" split of the "lfw" dataset from the FiftyOne zoo.
lfw_test = foz.load_zoo_dataset(
    "lfw",
    split="test",
)

In [None]:
# Gather the sample file paths, then embed each one with the loaded model.
test_images = get_test_images(lfw_test)
test_embed = get_embeddings_from_images(images=test_images, model=embedding_model)

# Quick sanity check on the size of a single embedding.
print(test_embed[0].shape)

# Alias under the name used by the visualization cell.
test_embeddings = test_embed

In [None]:
# Compute a low-dimensional visualization of the embeddings once per method.
# The three runs differ only in `method` and in the suffix of `brain_key`, so
# they share one call site; the label query is evaluated a single time.
# NOTE(review): label_field / classes / output_dir / overwrite are forwarded
# exactly as before -- confirm that compute_visualization honours them.
ground_truth_labels = lfw_test.values("ground_truth.label")

results_by_method = {}
for method in ("tsne", "pca", "umap"):
    results_by_method[method] = fob.compute_visualization(
        lfw_test,           # samples
        None,               # patches_field (not applicable)
        test_embeddings,    # embeddings
        label_field="ground_truth.label",
        classes=ground_truth_labels,
        brain_key=f"ResNet_face_detection_embeddings_{method}",
        output_dir="ResNet_face_detection_embeddings",
        overwrite=True,
        method=method,
    )

# `results` keeps pointing at the last run (UMAP), as it did when the calls
# were written out one after another.
results = results_by_method["umap"]

In [None]:
# Start a FiftyOne App session on the dataset (the visualizations computed
# above were stored on it under their brain keys), then ask the session to
# open itself in a new tab.
sess = fo.launch_app(lfw_test)
sess.open_tab()