In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import time
import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.brain as fob
from tqdm.notebook import tqdm, tnrange
from PIL import Image, ImageOps
import matplotlib.pyplot as plt
import torchsummary
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import os
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
# import resnet18 model from pytorch
from torchvision.models import resnet18
from facenet_pytorch import MTCNN
import cv2
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Preferred GPU ordinal for this notebook; it is only used when that card exists.
PREFERRED_GPU = 3

if not torch.cuda.is_available():
    device = torch.device("cpu")
elif torch.cuda.device_count() > PREFERRED_GPU:
    device = torch.device(f"cuda:{PREFERRED_GPU}")
else:
    # CUDA is present but exposes fewer cards than expected: use the first one
    # rather than failing later with an invalid device ordinal.
    device = torch.device("cuda:0")

# Spatial size handed to transforms.Resize for every image fed to the network.
DIM = (224, 224)

In [3]:
class CustomNormalize:
    """Callable transform: convert an image to a tensor and rescale it as (v * 255 - 128) / 128."""

    def __call__(self, img):
        # Turn the incoming image into a tensor first
        to_tensor = transforms.ToTensor()
        pixels = to_tensor(img) * 255.0
        # Centre on 128 and divide by 128
        return (pixels - 128.0) / 128.0

# Inference-time preprocessing: a deterministic resize followed by the
# (v * 255 - 128) / 128 scaling implemented by CustomNormalize.
#
# No random augmentation is applied here. This pipeline feeds embedding
# extraction, and a random flip would make the embeddings (and every plot
# derived from them) differ from one run to the next. Training-style
# augmentations (RandomHorizontalFlip, RandomRotation, RandomVerticalFlip,
# ColorJitter) and the ImageNet mean/std Normalize are deliberately left out.
preprocess = transforms.Compose([
    transforms.Resize(DIM, interpolation=transforms.InterpolationMode.LANCZOS),  # Resize the image to the desired dimensions
    CustomNormalize(),
])

In [4]:
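# Options considered for getting embeddings of size 128 out of the model:
# 1. Create a new model that outputs embeddings
# 2. Modify the last layer of the model to output embeddings
# 3. Use a hook to extract embeddings from the model
# 4. Use a custom loss function to train the model
# This notebook follows option 2: ResNet18 is instantiated with num_classes=128,
# so its final fully connected layer emits the 128-dimensional vector directly.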

class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers with a shortcut added before the last ReLU.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both convolutions.
        stride: Stride of the first convolution and of the projection shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path (attribute names and order are kept so checkpoints keep loading)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # A 1x1 projection is only needed when the main path changes the
        # spatial size or the channel count; otherwise the shortcut is empty.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Run the block on ``x`` and return the activated sum of main path and shortcut."""
        residual = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return self.relu(out)

class ResNet18(nn.Module):
    """ResNet-18 style network: a strided 7x7 stem, four stages of two BasicBlocks, and a linear head.

    Args:
        num_classes: Number of outputs of the final ``fc`` layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Running channel count; _make_layer reads and updates it stage by stage
        self.in_channels = 64

        # Stem: 7x7 stride-2 convolution, batch norm, ReLU and a stride-2 max pool
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages registered as layer1..layer4, each made of two blocks
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for index, (width, first_stride) in enumerate(stage_specs, start=1):
            setattr(self, f"layer{index}", self._make_layer(BasicBlock, width, 2, stride=first_stride))

        # Head: global average pooling followed by the linear projection
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack blocks into one stage; only the first block receives ``stride``, the rest use 1."""
        block_strides = [stride]
        block_strides.extend([1] * (num_blocks - 1))

        stage = []
        for block_stride in block_strides:
            stage.append(block(self.in_channels, out_channels, block_stride))
            # Every following block consumes what this one produced
            self.in_channels = out_channels
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the ``fc`` output for the image batch ``x``."""
        features = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        pooled = self.avgpool(features)
        flattened = pooled.view(pooled.size(0), -1)
        return self.fc(flattened)

In [5]:
def load_model(model, path, device):
    """Load the state dict stored at ``path`` into ``model`` and put the model in eval mode.

    Keys that start with 'module.' have that prefix removed before loading.

    Args:
        model: Module whose parameters are overwritten.
        path: Location of the file read with ``torch.load``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The same ``model`` object, now in evaluation mode.
    """
    from collections import OrderedDict

    prefix = 'module.'
    # NOTE(review): torch.load unpickles the file; only point it at trusted checkpoints.
    checkpoint = torch.load(path, map_location=device)

    # Rebuild the mapping in its original order, dropping the prefix where present
    cleaned = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, tensor)
        for key, tensor in checkpoint.items()
    )

    model.load_state_dict(cleaned)
    model.eval()
    print(f"Model loaded from {path}")
    return model

In [6]:
# Release cached CUDA memory before the network is created
torch.cuda.empty_cache()
# Build the 128-output ResNet18 on `device`, then load the checkpoint weights into it
embedding_model = load_model(
    ResNet18(128).to(device),
    './runs/224x224_ResNet18_AMSoftmax_20250409-172656/224x224_ResNet18_AMSoftmax_validation_20250409-172656.pt',
    device,
)

Model loaded from ./runs/224x224_ResNet18_AMSoftmax_20250409-172656/224x224_ResNet18_AMSoftmax_validation_20250409-172656.pt


In [7]:
# torchsummary.summary(embedding_model, (3, 224, 224), device='cuda')
# torchsummary.summary(embedding_model, (3, 112, 96), device='cuda')
# torch.cuda.empty_cache()

In [8]:
# Initialize the MTCNN face detector on the same device as the embedding model.
# Reusing `device` instead of a hard-coded 'cuda:3' keeps the CPU fallback working.
mtcnn = MTCNN(keep_all=False, device=device, image_size=112, margin=0)

# Function to perform face detection and crop face from image
def detect_and_crop_face(image, mtcnn, target_size=(224, 224)):
    """Detect a face, crop it to its bounding box and resize the crop to ``target_size``.

    The crop is resized straight to ``target_size`` with no padding, so a
    non-square bounding box is stretched rather than squared.

    Args:
        image: A numpy array, torch tensor, PIL image, or a path string that
            ``Image.open`` can read.
        mtcnn: Detector exposing ``detect(image)`` that returns ``(boxes, probs)``,
            with ``boxes`` being ``None`` when nothing is found.
        target_size: Size passed to ``Image.resize`` for the face crop.

    Returns:
        The resized face crop as an RGB PIL image. When no usable face box is
        found, the RGB-converted input image is returned as is (not resized).

    Raises:
        ValueError: If ``image`` is none of the supported input types.
    """
    # Normalise every supported input type to a PIL image
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif isinstance(image, torch.Tensor):
        image = transforms.ToPILImage()(image)
    elif isinstance(image, str):
        image = Image.open(image)
    elif not isinstance(image, Image.Image):
        raise ValueError("Input image must be a numpy array, torch tensor, PIL Image, or a file path string.")

    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Detect face
    boxes, _ = mtcnn.detect(image)

    if boxes is None or len(boxes) == 0:
        # If no face is detected, return the original image
        print("No face detected.")
        return image

    # Only the first reported box is used
    x1, y1, x2, y2 = (int(v) for v in boxes[0].astype(int))

    # Clamp to the image in case the detector reports a box that extends past
    # it; PIL would otherwise pad the out-of-range area of the crop.
    width, height = image.size
    x1, y1 = max(x1, 0), max(y1, 0)
    x2, y2 = min(x2, width), min(y2, height)

    if x2 <= x1 or y2 <= y1:
        # An empty box cannot be cropped and resized; treat it like a miss
        print("No face detected.")
        return image

    face = image.crop((x1, y1, x2, y2))

    # Resize to final target size
    return face.resize(target_size, Image.Resampling.LANCZOS)


In [9]:
# get images from test dataset by iterating over it
def get_image_from_path(image_path):
    """Open the image stored at ``image_path`` and return it converted to RGB."""
    opened = Image.open(image_path)
    return opened.convert("RGB")


def get_test_images(dataset):
    """Return the "filepath" of every sample in ``dataset``, in iteration order.

    The images themselves are not opened here; only their paths are collected,
    with a "Samples" progress bar shown while iterating.
    """
    progress = tnrange(len(dataset), desc="Samples")
    return [sample["filepath"] for _, sample in zip(progress, dataset)]
def get_embeddings_from_images(images, model):
    """Compute one embedding per entry of ``images`` and return them as a list of CPU tensors.

    Each entry goes through ``detect_and_crop_face`` (using the notebook-level
    ``mtcnn``), then ``preprocess``, and is pushed through ``model`` as a batch
    of one on the notebook-level ``device`` with gradients disabled.
    """
    model.eval()
    collected = []
    progress = tnrange(len(images), desc="Images")
    with torch.no_grad():
        for _, source in zip(progress, images):
            face = detect_and_crop_face(source, mtcnn, target_size=(224, 224))
            # Add the batch dimension and move the tensor next to the model
            batch = preprocess(face).unsqueeze(0).to(device)
            # Drop the batch dimension again and bring the result back to the CPU
            collected.append(model(batch).squeeze().cpu())

    return collected

In [10]:
# Load the LFW dataset from the FiftyOne zoo. No split is requested, so both the
# 'train' and 'test' splits are loaded: despite its name, `lfw_test` holds all
# of them (see the loader output below).
lfw_test = foz.load_zoo_dataset("lfw")

Split 'train' already downloaded
Split 'test' already downloaded
Loading 'lfw' split 'train'
 100% |███████████████| 9525/9525 [2.6s elapsed, 0s remaining, 3.6K samples/s]      
Loading 'lfw' split 'test'
 100% |███████████████| 3708/3708 [994.6ms elapsed, 0s remaining, 3.7K samples/s]      
Dataset 'lfw' created


In [11]:
# Collect the sample file paths, then embed every image with the loaded model
test_images = get_test_images(lfw_test)
test_embeddings = get_embeddings_from_images(test_images, embedding_model)
# Both names refer to the same list; `test_embed` is kept as an alias
test_embed = test_embeddings
print(test_embeddings[0].shape)

Samples:   0%|          | 0/13233 [00:00<?, ?it/s]

Images:   0%|          | 0/13233 [00:00<?, ?it/s]

torch.Size([128])


In [12]:


# Compute one visualization of the embeddings per method. Each run is stored on
# the dataset under its own brain key, and every result object is kept instead
# of being overwritten by the next run.
visualization_labels = lfw_test.values("ground_truth.label")  # queried once, reused by every run
results_by_method = {}
for method in ("tsne", "pca", "umap"):
    results_by_method[method] = fob.compute_visualization(
        lfw_test,              # samples
        None,                  # patches_field (set to None if not applicable)
        test_embeddings,       # embeddings
        label_field="ground_truth.label",
        classes=visualization_labels,
        brain_key=f"ResNet_face_detection_embeddings_{method}",
        output_dir="ResNet_face_detection_embeddings",
        overwrite=True,
        method=method,
    )

# Backward compatible: `results` still ends up as the last (UMAP) run
results = results_by_method["umap"]

Generating visualization...




[t-SNE] Computing 91 nearest neighbors...
[t-SNE] Indexed 13233 samples in 0.001s...
[t-SNE] Computed neighbors for 13233 samples in 0.306s...
[t-SNE] Computed conditional probabilities for sample 1000 / 13233
[t-SNE] Computed conditional probabilities for sample 2000 / 13233
[t-SNE] Computed conditional probabilities for sample 3000 / 13233
[t-SNE] Computed conditional probabilities for sample 4000 / 13233
[t-SNE] Computed conditional probabilities for sample 5000 / 13233
[t-SNE] Computed conditional probabilities for sample 6000 / 13233
[t-SNE] Computed conditional probabilities for sample 7000 / 13233
[t-SNE] Computed conditional probabilities for sample 8000 / 13233
[t-SNE] Computed conditional probabilities for sample 9000 / 13233
[t-SNE] Computed conditional probabilities for sample 10000 / 13233
[t-SNE] Computed conditional probabilities for sample 11000 / 13233
[t-SNE] Computed conditional probabilities for sample 12000 / 13233
[t-SNE] Computed conditional probabilities for sam



UMAP( verbose=True)
Tue Apr 29 11:35:50 2025 Construct fuzzy simplicial set
Tue Apr 29 11:35:50 2025 Finding Nearest Neighbors
Tue Apr 29 11:35:50 2025 Building RP forest with 11 trees
Tue Apr 29 11:35:56 2025 NN descent for 14 iterations
	 1  /  14
	 2  /  14
	 3  /  14
	 4  /  14
	 5  /  14
	 6  /  14
	 7  /  14
	Stopping threshold met -- exiting after 7 iterations
Tue Apr 29 11:36:05 2025 Finished Nearest Neighbor Search
Tue Apr 29 11:36:07 2025 Construct embedding


Epochs completed:   0%|            0/200 [00:00]

	completed  0  /  200 epochs
	completed  20  /  200 epochs
	completed  40  /  200 epochs
	completed  60  /  200 epochs
	completed  80  /  200 epochs
	completed  100  /  200 epochs
	completed  120  /  200 epochs
	completed  140  /  200 epochs
	completed  160  /  200 epochs
	completed  180  /  200 epochs
Tue Apr 29 11:36:14 2025 Finished embedding


In [13]:
# Launch the FiftyOne App on the dataset and keep the session handle in `sess`
sess = fo.launch_app(lfw_test)
# Ask the session to open the App in its own tab
sess.open_tab()

<IPython.core.display.Javascript object>