In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import torchvision.utils as vutils
import matplotlib.pyplot as plt 
import clip 

# --- 1. Poisoner Class --- #
class MultiModelEmbeddingPoisoner(nn.Module):
    """Single-step sign-gradient (FGSM-style) perturbation averaged over several models.

    For every model the gradient of an embedding objective with respect to the
    input is computed; the gradients are averaged and the input is moved by
    ``epsilon`` along the sign of that average, then clamped to ``[0, 1]``.

    Args:
        models: iterable of ``nn.Module`` objects mapping an input batch to an
            embedding tensor. Each one is switched to eval mode in place.
        epsilon: step size applied to the sign of the averaged gradient.
        targeted: if ``True``, push each model's embedding *toward* the
            matching entry of ``target_embeddings`` (cosine similarity);
            otherwise increase the L2 norm of the embedding.
        target_embeddings: one target embedding per model (same order as
            ``models``); required when ``targeted`` is ``True``.
    """

    def __init__(self, models, epsilon=0.03, targeted=False, target_embeddings=None):
        super().__init__()
        # Stored as a plain Python list, so the models are NOT registered as
        # submodules: ``.to(device)`` / ``.train()`` on the poisoner do not
        # propagate to them. Callers must place the models on the right device.
        self.models = [m.eval() for m in models]
        self.epsilon = epsilon
        self.targeted = targeted
        self.target_embeddings = target_embeddings

    def forward(self, x):
        """Return the perturbed copy of ``x`` (detached, clamped to ``[0, 1]``).

        Raises:
            ValueError: if the poisoner holds no models (the gradient average
                would otherwise be 0/0 and yield NaNs).
            AssertionError: in targeted mode, if ``target_embeddings`` is
                missing or its length differs from the number of models.
        """
        if not self.models:
            raise ValueError("MultiModelEmbeddingPoisoner needs at least one model.")

        # Validate once up front rather than on every loop iteration.
        if self.targeted:
            assert self.target_embeddings is not None and len(self.target_embeddings) == len(self.models), \
                "Provide one target embedding per model."

        # Work on a private leaf tensor so the caller's tensor is never touched.
        x = x.clone().detach().requires_grad_(True)
        total_grad = torch.zeros_like(x)

        for i, model in enumerate(self.models):
            emb = model(x)

            # The update below is gradient *ascent* (x + eps * sign(grad)), so
            # each branch defines a quantity we want to INCREASE.
            if self.targeted:
                # Ascending the similarity moves the embedding toward the
                # target. (Ascending its negative, as an earlier version did,
                # pushed the embedding away from the target.)
                objective = F.cosine_similarity(emb, self.target_embeddings[i]).mean()
            else:
                # NOTE(review): if a model L2-normalizes its output this norm is
                # constant and the gradient is ~0 — confirm this objective is
                # what is intended for such models.
                objective = torch.norm(emb, p=2)

            # Every model builds its own graph from the leaf ``x``, so nothing
            # needs to be retained between iterations.
            total_grad += torch.autograd.grad(objective, x)[0]

        avg_grad = total_grad / len(self.models)
        x_adv = torch.clamp(x + self.epsilon * avg_grad.sign(), 0, 1)

        return x_adv.detach()


# --- 2. Load Real Pretrained Face Models --- #
def load_face_models():
    """Build the three pretrained embedding backbones used for cloaking.

    Returns:
        list: ``[model1, model2, model3]``, every model in eval mode:

        * ``InceptionResnetV1(pretrained='vggface2')`` from ``facenet_pytorch``;
        * timm ``resnet18`` with its ``fc`` layer replaced by ``nn.Identity``;
        * timm ``tf_efficientnet_b0`` with its ``classifier`` layer replaced by
          ``nn.Identity``.

    NOTE(review): ``pretrained=...`` presumably fetches weights when they are
    not cached locally (the recorded cell output shows a ``URLError`` raised
    from ``urlopen``) — confirm network/certificate setup before running.
    """
    # Imported here so the third-party packages are only required when the
    # models are actually loaded.
    from facenet_pytorch import InceptionResnetV1
    import timm

    model1 = InceptionResnetV1(pretrained='vggface2')

    # Swap the final layer for Identity so forward() returns the tensor that
    # would otherwise have been fed to that layer.
    model2 = timm.create_model('resnet18', pretrained=True)
    model2.fc = nn.Identity()
    model3 = timm.create_model('tf_efficientnet_b0', pretrained=True)
    model3.classifier = nn.Identity()

    # Put ALL models in eval mode (previously only model1 was), so callers that
    # use the returned models directly get consistent inference-mode behavior.
    return [m.eval() for m in (model1, model2, model3)]



def cloak_single_image(image_path, epsilon=0.03):
    """Cloak one image, plot it next to the original, and print per-model similarity.

    The image is loaded as RGB, resized to 160x160, converted to a tensor, and
    perturbed with ``MultiModelEmbeddingPoisoner`` using the models returned by
    ``load_face_models()``. For every model the cosine similarity between the
    embedding of the original and of the cloaked image is printed.

    Args:
        image_path: path of the image file to cloak.
        epsilon: perturbation step size forwarded to the poisoner
            (defaults to the previously hard-coded 0.03).

    Returns:
        None. Results are shown via matplotlib and ``print``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    models_list = [m.to(device) for m in load_face_models()]

    poisoner = MultiModelEmbeddingPoisoner(models=models_list, epsilon=epsilon).to(device)

    transform = transforms.Compose([
        transforms.Resize((160, 160)),
        transforms.ToTensor(),
    ])

    image = Image.open(image_path).convert('RGB')
    x = transform(image).unsqueeze(0).to(device)

    # Cloak the image (the poisoner manages its own gradient computation).
    x_adv = poisoner(x)

    # Evaluation only: no gradients are needed, so avoid building autograd
    # graphs for these forward passes.
    with torch.no_grad():
        cosine_similarities = [
            F.cosine_similarity(model(x), model(x_adv)).mean().item()
            for model in models_list
        ]

    # (1, C, H, W) tensor on any device -> (H, W, C) numpy array for imshow.
    panels = [
        ("Original Image", x.cpu().detach().squeeze(0).permute(1, 2, 0).numpy()),
        ("Cloaked Image", x_adv.cpu().detach().squeeze(0).permute(1, 2, 0).numpy()),
    ]

    # Plot the original and cloaked images side by side.
    plt.figure(figsize=(12, 6))
    for position, (title, pixels) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(pixels)
        plt.title(title)
        plt.axis('off')
    plt.show()

    # Report how far each model's embedding moved.
    for i, similarity in enumerate(cosine_similarities):
        print(f"Cosine Similarity for model {i+1}: {similarity:.4f}")

# Entry point: run the cloaking demo on a single hard-coded image.
if __name__ == '__main__':
    # NOTE(review): relative path, resolved against the current working
    # directory; the file name suggests an already-cloaked image is used as
    # input — confirm this is the intended source image.
    image_path = "models/cloaked_image.png" 
    cloak_single_image(image_path)

  from .autonotebook import tqdm as notebook_tqdm


URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:997)>

In [6]:
from facenet_pytorch import InceptionResnetV1
import timm
# Instantiate InceptionResnetV1 with the 'vggface2' pretrained setting and
# switch it to eval mode.
model=InceptionResnetV1(pretrained='vggface2').eval() 
# Serialize the module object itself to "model.pt" in the working directory.
# NOTE(review): passing the module (rather than model.state_dict()) pickles
# the whole object, so loading presumably requires the same class definition
# to be importable — confirm this is the format consumers expect.
torch.save(model, "model.pt")