In [2]:
import pickle
import torch

from ldm.util import instantiate_from_config

from torchvision.transforms import ToTensor, ToPILImage
import torchvision
from PIL import Image

from FSSAAD.pytorch_msssim import msssim
from torch.nn import functional as F

from torcheval.metrics import BinaryAUROC

In [3]:
def load_model_from_config(config, ckpt, verbose=False):
    print(f"Loading model from {ckpt}")
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    config.model.params.ckpt_path = ckpt
    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)
    if len(m) > 0 and verbose:
        print("missing keys:")
        print(m)
    if len(u) > 0 and verbose:
        print("unexpected keys:")
        print(u)
    
    return model

In [4]:
device = torch.device(f"cuda:0" if torch.cuda.is_available() else "cpu")

with open('config.pkl', 'rb') as f:
    config = pickle.load(f)
model = load_model_from_config(config, "models/ldm/text2img-large/model.ckpt").first_stage_model.to(device)

for param in model.parameters():
    param.requires_grad = False

Loading model from models/ldm/text2img-large/model.ckpt
LatentDiffusion: Running in eps-prediction mode


: 

In [4]:
# load model to PyTorch Tensor
def load_image_toTensor(path):
    image_path = path
    image = Image.open(image_path)  # Load the image using PIL
    image = image.convert("RGB")


    transform = ToTensor()  # Create a ToTensor transform
    input_image_tensor = transform(image)  # Convert image to tensor

    # Reshape the tensor to match the expected input shape of the encoder
    input_image_tensor = input_image_tensor.unsqueeze(0)  # Add batch dimension
    return input_image_tensor

In [5]:
def perception_loss(a, b):
    msssim_noise_loss = (1 - msssim(a, b)) / 2
    f1_noise_loss = F.l1_loss(a, b)
    return msssim_noise_loss + f1_noise_loss
    # return torch.norm(a-b, p=2)

def dists(base_pt, target_pt, poisoned_pt):
    decoded_base, _ = model(base_pt)
    decoded_target, _ = model(target_pt)
    decoded_poisoned, _ = model(poisoned_pt)
    
    base_decoded_base = perception_loss(base_pt, decoded_base)
    target_decoded_target = perception_loss(target_pt, decoded_target)
    poisoned_decoded_poisoned = perception_loss(poisoned_pt, decoded_poisoned)

    base_poisoned = perception_loss(base_pt, poisoned_pt)
    decoded_base_decoded_poisoned = perception_loss(decoded_base, decoded_poisoned)

    target_poisoned = perception_loss(target_pt, poisoned_pt)
    decoded_target_decoded_poisoned = perception_loss(decoded_target, decoded_poisoned)

    # print(f"Distance between base and its decoded output: {base_decoded_base}")
    # print(f"Distance between target and its decoded output: {target_decoded_target}")
    # print(f"Distance between poison and its decoded output: {poisoned_decoded_poisoned}")

    # print(f"Distance between base and poison: {base_poisoned}")
    # print(f"Distance between base and poison decoded outputs: {decoded_base_decoded_poisoned}")

    # print(f"Distance between target and poison: {target_poisoned}")
    # print(f"Distance between target and poison decoded outputs: {decoded_target_decoded_poisoned}")

    return base_decoded_base, target_decoded_target, poisoned_decoded_poisoned, base_poisoned, decoded_base_decoded_poisoned, target_poisoned, decoded_target_decoded_poisoned

def all_dists(bases, targets, poisons):
    bdbs, tdts, pdps = [], [], []

    for (b, t, p) in zip(bases, targets, poisons):
        bdb, tdt, pdp, bp, dbdp, tp, dtdp = dists(b, t, p)
        bdbs.append(bdb)
        tdts.append(tdt)
        pdps.append(pdp)
    
    return torch.tensor(bdbs), torch.tensor(tdts), torch.tensor(pdps)



In [42]:
def statistics(bdbs, pdps):
    print("x_b mean:", bdbs.mean(), "x_d mean:", pdps.mean())
    print("threshold:", pdps.min())
    print("False positive rate:", sum(bdbs > pdps.min()), "/", len(bdbs))
    metric = BinaryAUROC()
    metric.update(bdbs, torch.zeros(len(bdbs)))
    metric.update(pdps, torch.ones(len(pdps)))
    print("AUROC:", metric.compute())
    metric.reset()

In [27]:
# bottle base
bottle_bases = [load_image_toTensor(f'poison/bottle_watermark_clipped/{i+1}.png').to(device) for i in range(4)]

# bottle target
bottle_targets = [load_image_toTensor(f'poison/bottle_watermark_clipped/{i+1}o.png').to(device) for i in range(4)]

# bottle poison
bottle_poisons = [torch.load(f'poison/bottle_watermark_clipped/img_train_{i+1}/poison_100000.pt').to(device) for i in range(4)]

In [28]:
bottle_bdbs, bottle_tdts, bottle_pdps = all_dists(bottle_bases, bottle_targets, bottle_poisons)

In [43]:
statistics(bottle_bdbs, bottle_pdps)

x_b mean: tensor(0.0502) x_d mean: tensor(0.0856)
threshold: tensor(0.0703)
False positive rate: tensor(1) / 4
AUROC: tensor(0.8750, dtype=torch.float64)


In [30]:
# sunflower base
sunflower_bases = [load_image_toTensor(f'poison/sunflowers_clipped/{i}o.jpg').to(device) for i in [2, 5, 6]]

# sunflower target
sunflower_targets = [load_image_toTensor(f'poison/sunflowers_clipped/{i}.jpg').to(device) for i in [2, 5, 6]]

# sunflower poison
sunflower_poisons = [torch.load(f'poison/sunflowers_clipped/img_train_{i}/poison_100000.pt').to(device) for i in [2, 5, 6]]

In [31]:
sunflower_bdbs, sunflower_tdts, sunflower_pdps = all_dists(sunflower_bases, sunflower_targets, sunflower_poisons)

In [32]:
statistics(sunflower_bdbs, sunflower_tdts, sunflower_pdps)

x_b mean: tensor(0.0256) x_c mean: tensor(0.0890) x_d mean: tensor(0.1632)
threshold: tensor(0.1412)
tensor([False, False, False])
AUROC: tensor(1., dtype=torch.float64)


In [33]:
# style base
style_base_1 = load_image_toTensor('../textual_inversion_2/poison/style_with_cons/water_building.jpg').to(device)
style_base_2 = load_image_toTensor('../textual_inversion_2/poison/style_with_cons/water_image.jpg').to(device)
style_base_3 = load_image_toTensor('../textual_inversion_2/poison/style_with_cons/water_van_gogh_ht.jpg').to(device)
style_bases = [style_base_1, style_base_2, style_base_3]

# style target
style_target_1 = load_image_toTensor('../textual_inversion_2/poison/style_with_cons/building.jpg').to(device)
style_target_2 = load_image_toTensor('../textual_inversion_2/poison/style_with_cons/target_image.jpg').to(device)
style_target_3 = load_image_toTensor('../textual_inversion_2/poison/style_with_cons/van_gogh_ht.jpg').to(device)
style_targets = [style_target_1, style_target_2, style_target_3]

# style poison
style_poison_1 = torch.load('../textual_inversion_2/poison/img_train_style_building_clip_2000/poison_100000.pt').to(device)
style_poison_2 = torch.load('../textual_inversion_2/poison/img_train_style_river_clip_2000/poison_100000.pt').to(device)
style_poison_3 = torch.load('../textual_inversion_2/poison/img_train_style_ht_clip_2000/poison_100000.pt').to(device)
style_poisons = [style_poison_1, style_poison_2, style_poison_3]


In [34]:
style_bdbs, style_tdts, style_pdps = all_dists(style_bases, style_targets, style_poisons)

In [35]:
statistics(style_bdbs, style_tdts, style_pdps)

x_b mean: tensor(0.0551) x_c mean: tensor(0.1046) x_d mean: tensor(0.2667)
threshold: tensor(0.2506)
tensor([False, False, False])
AUROC: tensor(1., dtype=torch.float64)


# Imagenet

In [36]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch

In [37]:
# Define transformations
transform = transforms.Compose([
    transforms.Resize(256),          # Resize to 256x256
    transforms.CenterCrop(256),      # Crop the image to 256x256
    transforms.ToTensor(),           # Convert image to tensor
])

# Assuming you have the ImageNet dataset at 'path/to/imagenet'
# Note: You will need to replace 'path/to/imagenet' with the actual path to your ImageNet dataset.
imagenet_data = datasets.Imagenette(root='.', split='val', transform=transform)

# Since we want to load only 100 images, let's use a Subset
indices = torch.arange(100)
subset_data = torch.utils.data.Subset(imagenet_data, indices)

# DataLoader to load the images
data_loader = DataLoader(subset_data, batch_size=1, shuffle=False)

losses = torch.zeros(100)

# Now, you can iterate over data_loader to access your images.
for i, (images, labels) in enumerate(data_loader):
    # Here, images are your transformed images, and labels are the class labels
    image = images.to(device)
    decoded, _ = model(image)
    losses[i] = perception_loss(image, decoded)

In [44]:
statistics(losses, bottle_pdps)

x_b mean: tensor(0.0655) x_d mean: tensor(0.0856)
threshold: tensor(0.0703)
False positive rate: tensor(42) / 100
AUROC: tensor(0.7550, dtype=torch.float64)


In [45]:
statistics(losses, sunflower_pdps)

x_b mean: tensor(0.0655) x_d mean: tensor(0.1632)
threshold: tensor(0.1412)
False positive rate: tensor(2) / 100
AUROC: tensor(0.9933, dtype=torch.float64)


In [46]:
statistics(losses, style_pdps)

x_b mean: tensor(0.0655) x_d mean: tensor(0.2667)
threshold: tensor(0.2506)
False positive rate: tensor(0) / 100
AUROC: tensor(1., dtype=torch.float64)
