In [None]:
!pip3 install lpips

In [None]:
from transformers import AutoProcessor, CLIPVisionModel
import torch
import torchvision
import PIL.Image
import lpips
import torch.optim
from tqdm.auto import tqdm

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Bare expression so the notebook displays the selected device.
device

device(type='cuda')

In [None]:
# Preprocessor paired with the CLIP ViT-B/32 checkpoint.
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Vision tower of the same checkpoint, moved onto the selected device.
model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
model = model.to(device)

In [32]:
def load_image(path: str) -> torch.Tensor:
    """Load an image file as a float RGB tensor on the active device.

    The image is converted to RGB, resized with ``Resize(512)`` and turned
    into a tensor by ``ToTensor()`` before being moved to ``device``.

    Args:
        path: Path of the image file to open.

    Returns:
        The transformed image tensor, located on ``device``.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Resize(512),
        torchvision.transforms.ToTensor(),
    ])
    # Close the underlying file deterministically; `convert` returns an
    # independent, fully loaded copy, so it stays valid after the `with` exits.
    with PIL.Image.open(path) as image_file:
        image = image_file.convert("RGB")
    return transform(image).to(device)

def get_image_features(image) -> torch.Tensor:
    """Return the pooled output of the CLIP vision model for ``image``.

    The image is run through ``processor`` with rescaling disabled and
    normalisation enabled, then fed to ``model``.

    Args:
        image: Image (or images) accepted by ``processor``.

    Returns:
        ``pooler_output`` of the model, on ``device``.
    """
    batch = processor(
        images=image,
        return_tensors="pt",
        do_rescale=False,
        do_normalize=True,
    )
    batch = batch.to(device)
    outputs = model(**batch)
    return outputs.pooler_output.to(device)

In [None]:
# (image to poison, anchor image) path pairs.
# NOTE(review): the first two pairs are identical - possibly a copy/paste slip; confirm.
tests = [
    (
        "../examples/airplane-train/k63368-01.jpg",
        "../examples/train-gift/img_presse_retro-tgv1_30062017.jpg",
    ),
    (
        "../examples/airplane-train/k63368-01.jpg",
        "../examples/train-gift/img_presse_retro-tgv1_30062017.jpg",
    ),
    (
        "../examples/airplane-train/560458.jpg",
        "../examples/train-gift/img_presse_retro-tgv1_30062017.jpg",
    ),
]

# Load every pair as tensors, keeping the (original, anchor) order.
test_data = [
    (load_image(original_path), load_image(anchor_path))
    for original_path, anchor_path in tests
]

In [None]:
# LPIPS perceptual distance with the VGG backbone, placed on the active device.
# TODO: set cache dir
lpips_criterion = lpips.LPIPS(net="vgg")
lpips_criterion = lpips_criterion.to(device)

In [None]:
def clamp_inplace(t: torch.Tensor):
    """Clamp ``t`` to the range [0, 1] in place, without recording autograd history.

    Works on leaf and non-leaf tensors alike and on any device; the tensor's
    ``requires_grad`` flag is left untouched.

    Args:
        t: Tensor to modify in place.

    Returns:
        None. ``t`` itself is updated.
    """
    # `no_grad` permits the in-place write on tensors that require grad, with no
    # need to toggle `requires_grad` (which raises for non-leaf tensors) and no
    # throw-away CPU scalar tensors.
    with torch.no_grad():
        t.clamp_(min=0.0, max=1.0)

def nightshade(data: list[tuple], lr: float = 0.01, num_epochs: int = 50, alpha: float = 1000.0, p: float = 0.07):
    """Perturb each image so its CLIP embedding moves toward an anchor image's.

    For every ``(original_image, anchor_image)`` pair the processed original is
    optimised with Adam to minimise

        ``||E(poisoned) - E(anchor)|| + alpha * max(LPIPS(poisoned, original) - p, 0)``

    where ``E`` is the pooled output of the CLIP vision model.

    Args:
        data: ``(original_image, anchor_image)`` pairs accepted by ``processor``.
        lr: Adam learning rate; the update is applied to the CLIP-normalised pixels.
        num_epochs: Number of optimisation steps per image.
        alpha: Weight of the LPIPS penalty.
        p: LPIPS distance allowed before the penalty becomes active.

    Returns:
        One detached tensor per pair: the poisoned image at the processor's
        output resolution, de-normalised and clamped to [0, 1].
    """
    def preprocess(image) -> torch.Tensor:
        # Processor output: the CLIP-normalised `pixel_values` batch the model consumes.
        return processor(images=image, return_tensors="pt", do_rescale=False, do_normalize=True).to(device)['pixel_values']

    # Per-channel statistics used by the processor's normalisation step, shaped
    # to broadcast over a (batch, channel, height, width) tensor.
    image_processor = processor.image_processor
    mean = torch.tensor(image_processor.image_mean, device=device).view(1, -1, 1, 1)
    std = torch.tensor(image_processor.image_std, device=device).view(1, -1, 1, 1)

    def to_unit_range(pixel_values: torch.Tensor) -> torch.Tensor:
        # Differentiable inverse of the processor's normalisation.
        return pixel_values * std + mean

    # Only the image is optimised: freeze the model so backward does not compute
    # (and endlessly accumulate) gradients for its weights. The previous flags
    # are restored on exit.
    previous_flags = [(param, param.requires_grad) for param in model.parameters()]
    model.requires_grad_(False)

    poisoned_images = []

    try:
        for original_image, anchor_image in data:
            original_image = preprocess(original_image)
            anchor_image = preprocess(anchor_image)

            # Reference embeddings are constants: no autograd graph needed.
            with torch.no_grad():
                original_embedding = model(pixel_values=original_image).pooler_output
                anchor_embedding = model(pixel_values=anchor_image).pooler_output

            # LPIPS is called with normalize=True, which expects inputs in [0, 1].
            original_unit = to_unit_range(original_image)

            # Start from the (already processed) original image.
            poisoned_image = original_image.clone().requires_grad_(True)

            optimizer = torch.optim.Adam([poisoned_image], lr=lr)

            pbar = tqdm(range(num_epochs))

            for epoch in pbar:

                optimizer.zero_grad()

                poisoned_embedding = model(pixel_values=poisoned_image).pooler_output

                # Logged only; not part of the objective.
                original_loss = torch.linalg.norm(poisoned_embedding - original_embedding)
                embedding_loss = torch.linalg.norm(poisoned_embedding - anchor_embedding)

                lpips_loss = lpips_criterion(to_unit_range(poisoned_image), original_unit, normalize=True)[0][0][0][0]

                # Hinge penalty: active only once the LPIPS distance exceeds `p`.
                loss = embedding_loss + alpha * torch.clamp(lpips_loss - p, min=0.0)

                pbar.set_description(f'embed {embedding_loss.item():.3f} [to orig {original_loss.item():.3f}] | lpips {lpips_loss.item():.3f} | total {loss.item():.3f}')

                loss.backward()
                optimizer.step()

            # Back to displayable pixel space *before* clamping; clamping the
            # normalised tensor to [0, 1] would corrupt the image.
            result = torch.clamp(to_unit_range(poisoned_image.detach()), min=0.0, max=1.0)
            poisoned_images.append(result.squeeze())
    finally:
        for param, flag in previous_flags:
            param.requires_grad_(flag)

    return poisoned_images



# generate_anchor_image("train")

# Drop cached CUDA allocations before and after the optimisation run.
torch.cuda.empty_cache()
poisoned_images = nightshade(
    test_data[-1:],  # last test pair only
    lr=0.003,
    num_epochs=100,
    alpha=10,
    p=0.07,
)
torch.cuda.empty_cache()

# Last expression of the cell: render the first poisoned image.
torchvision.transforms.ToPILImage()(poisoned_images[0])