In [None]:
import open_clip
import torch.utils.data
import torchvision
import tqdm
import tqdm.notebook


In [None]:
import dotenv
import os
import sys
import pathlib

# Load environment variables (e.g. DATA_ROOT) from a .env file, if one is found
dotenv.load_dotenv()

# Enable loading of the project module: <parent of the current directory>/src
MODULE_DIR = os.path.join(os.path.abspath(os.path.join(os.path.curdir, os.path.pardir)), 'src')
# Guard the append so re-running this cell does not pile up duplicate entries
if MODULE_DIR not in sys.path:
    sys.path.append(MODULE_DIR)


In [None]:
import data

In [None]:
# Load the SigLIP image/text model with WebLI weights onto the GPU, together
# with the image preprocessing transform used for inference. The second return
# value (the training-time transform, per the open_clip docs) is not needed here.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-SO400M-14-SigLIP-384",
    pretrained='webli',
    device='cuda',
)
# open_clip hands back the model in training mode; switch to eval mode so any
# dropout / stochastic-depth / norm layers behave deterministically when the
# embeddings are extracted. Assigning the result keeps the cell output empty.
model = model.eval()

In [None]:
# Configure the dataset generator: label-noise canaries and no poison samples
# (num_poison=0). The dataset location comes from the DATA_ROOT environment
# variable; indexing os.environ directly makes a missing variable fail with a
# KeyError naming DATA_ROOT instead of an opaque TypeError inside pathlib.Path.
data_generator = data.DatasetGenerator(
    num_shadow=64,
    num_canaries=500,
    canary_type=data.CanaryType.LABEL_NOISE,
    num_poison=0,
    poison_type=data.PoisonType.CANARY_DUPLICATES,
    data_dir=pathlib.Path(os.environ["DATA_ROOT"]),
    seed=0,
    download=False,
)

# shadow model index only matters for membership, hence can use any
full_data, membership_mask_any, canary_mask, poison_mask = data_generator.build_train_data_full_with_poison(shadow_model_idx=0)
canary_indices = data_generator.get_canary_indices()

In [None]:
# Encode every image in `full_data` with the model's image encoder and save
# the resulting embedding matrix (one row per sample) to "embeddings.pt".

# Check the precondition before doing any work: the generator was configured
# without poison samples, so none may be flagged in the mask.
assert not poison_mask.any()

# Width of one image embedding; each row written below must broadcast to this size.
EMBEDDING_DIM = 1152
embeddings = torch.zeros((len(full_data), EMBEDDING_DIM), dtype=torch.float32)

# Pure feature extraction: disable autograd once around the whole loop rather
# than re-entering the context manager for every sample.
with torch.no_grad():
    for sample_idx in tqdm.notebook.trange(len(full_data), unit="image", desc="Encoding dataset"):
        # Element 0 of each sample is the image; it is converted to PIL before
        # being passed through `preprocess`.
        image = torchvision.transforms.functional.to_pil_image(full_data[sample_idx][0])
        # Encode as a batch of one on the GPU and move the result back to the CPU.
        embedding = model.encode_image(preprocess(image).unsqueeze(0).to("cuda")).cpu()
        embeddings[sample_idx] = embedding
torch.save(embeddings, "embeddings.pt")
