In [None]:
import os
from PIL import Image
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from pathlib import Path
import random
from models.vqgan import GumbelVQ
from NudeNet.nudenet import nudenet
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
from utils import *
from extraction import *

In [None]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}\n")
config_path = "checkpoints/config_openimages_gumbel.yaml"
ckpt_path = "checkpoints/check_openimages_gumbel.ckpt"

# Load the configuration
config = OmegaConf.load(config_path)

# Build the model from the constructor arguments stored under model.params
model = GumbelVQ(**config['model']['params'])

# Load the weights
# NOTE(review): torch.load unpickles the checkpoint file, which can execute
# arbitrary code; only load checkpoints from a trusted source. Consider
# `weights_only=True` once the checkpoint is confirmed to load that way.
checkpoint = torch.load(ckpt_path, map_location=device)

# strict=False tolerates key mismatches between checkpoint and model, which
# would silently leave the affected parameters at their initial values.
# Surface any mismatch instead of discarding the result.
load_result = model.load_state_dict(checkpoint["state_dict"], strict=False)
# getattr keeps this working if the model overrides load_state_dict and
# returns nothing.
missing_keys = list(getattr(load_result, "missing_keys", None) or [])
unexpected_keys = list(getattr(load_result, "unexpected_keys", None) or [])
if missing_keys or unexpected_keys:
    print(
        "WARNING: checkpoint/model key mismatch: "
        f"{len(missing_keys)} missing, {len(unexpected_keys)} unexpected"
    )
    print(f"  missing (first 5): {missing_keys[:5]}")
    print(f"  unexpected (first 5): {unexpected_keys[:5]}")

# Move the model to the selected device and switch it to evaluation mode.
model = model.to(device).eval()

print("\nVQGAN loaded successfully!")
# Instantiate the NudeNet detector with its default constructor arguments.
detector = nudenet.NudeDetector()
print("NudeNet loaded successfully!")

Using device: cpu

Working with z of shape (1, 256, 32, 32) = 262144 dimensions.


  checkpoint = torch.load(ckpt_path, map_location=device)



VQGAN loaded successfully!
NudeNet loaded successfully!


# Retain Set on ImageNet

In [None]:
# Seed both Python's `random` module and PyTorch with the same value.
seed = 123
random.seed(seed)
torch.manual_seed(seed)

# Subset sizing: the total is split evenly across the classes.
num_classes = 1000  # ImageNet-1k
forget_size = 60000  # Slightly increased to be divisible by 1000
samples_per_class = forget_size // num_classes

# Build the class-balanced subset from the ImageNet training directory.
# NOTE(review): presumably returns a list of image paths; confirm in utils.
retain_paths = make_balanced_subset(root="/media/pinas/datasets/imagenet_zeus/train", samples_per_class=samples_per_class, seed=seed)

# Wrap the selection in a dataset and iterate it in a fixed order,
# one sample per batch.
retain_dataset_imagenet = ImageListDataset(retain_paths)
retain_loader_imagenet = DataLoader(retain_dataset_imagenet, batch_size=1, shuffle=False)

print(f'Number of images in retain set: {len(retain_dataset_imagenet)}')
print(f'Number of images per class: {samples_per_class}')

Number of images in retain set: 60000
Number of images per class: 60


In [None]:
# Echo the device that the extraction below is given.
print(device)

# Run the retain loader through the model to obtain its index codes,
# then persist the result to disk with torch.save.
index_codes_retain_imagenet = extract_codes(
    model,
    retain_loader_imagenet,
    device,
    structured=False,
)
torch.save(
    index_codes_retain_imagenet,
    "index_codes_retain_imagenet.pt",
)

print("\nFinished extracting and saving index codes for retain set!")

cuda

Finished extracting and saving index codes for retain set!


# Retain set using masked explicit images (only for breasts)

In [None]:
# Dataset over the scraped explicit-image training directory.
masked_explicit_dataset = ImageDataset(
    "/media/pinas/datasets/nsfw_images_scraped/data/train/porn/"
)

# Fixed-order loader yielding one sample per batch.
masked_explicit_dataloader = DataLoader(
    masked_explicit_dataset,
    batch_size=1,
    shuffle=False,
)

In [None]:
# Restrict the NudeNet-based extraction to a single detection class.
target_classes = [
    'FEMALE_BREAST_EXPOSED',
]

# Only the first returned value (the index codes) is kept; the second is discarded.
# NOTE(review): `expand=0.3` presumably enlarges each detected box before masking;
# confirm against extract_codes_nudenet.
index_codes_mask, _ = extract_codes_nudenet(
    model,
    detector,
    masked_explicit_dataloader,
    target_classes,
    expand=0.3,
    mask_breasts=True,
)
torch.save(index_codes_mask, "index_codes_masked_breasts.pt")


EXTRACTING CODES (Strategy: masked, Occlusion: True)...
  Processed 10 batches (4 items)...
  Processed 20 batches (10 items)...
  Processed 30 batches (13 items)...
  Processed 40 batches (17 items)...
  Processed 50 batches (22 items)...
  Processed 60 batches (30 items)...
  Processed 70 batches (34 items)...
  Processed 80 batches (42 items)...
  Processed 90 batches (45 items)...
  Processed 100 batches (48 items)...
  Processed 110 batches (50 items)...
  Processed 120 batches (55 items)...
  Processed 130 batches (59 items)...
  Processed 140 batches (61 items)...
  Processed 150 batches (65 items)...
  Processed 160 batches (70 items)...
  Processed 170 batches (72 items)...
  Processed 180 batches (78 items)...
  Processed 190 batches (83 items)...
  Processed 200 batches (88 items)...
  Processed 210 batches (90 items)...
  Processed 220 batches (94 items)...
  Processed 230 batches (95 items)...
  Processed 240 batches (97 items)...
  Processed 250 batches (100 items)...
  

# Retain set using masked explicit images (all explicit content)

In [None]:
# Detection classes targeted in this pass (rebinds `target_classes`).
target_classes = [
    'BUTTOCKS_EXPOSED',
    'FEMALE_BREAST_EXPOSED',
    'FEMALE_GENITALIA_EXPOSED',
    'ANUS_EXPOSED',
    'MALE_GENITALIA_EXPOSED',
]

# Only the first returned value (the index codes) is kept; the second is discarded.
# NOTE(review): unlike the breast-only extraction earlier in this notebook
# (which passes expand=0.3), no `expand` is given here, so the helper's default
# applies; confirm this difference is intended.
# NOTE(review): `mask_breasts=True` is passed although the targets cover more
# than breasts; presumably the flag enables masking for all target classes. Verify.
index_codes_mask_total, _ = extract_codes_nudenet(
    model,
    detector,
    masked_explicit_dataloader,
    target_classes,
    mask_breasts=True,
)
torch.save(index_codes_mask_total, "index_codes_masked_total.pt")


EXTRACTING CODES (Strategy: masked, Occlusion: True)...
  Processed 10 batches (7 items)...
  Processed 20 batches (14 items)...
  Processed 30 batches (21 items)...
  Processed 40 batches (30 items)...
  Processed 50 batches (39 items)...
  Processed 60 batches (49 items)...
  Processed 70 batches (54 items)...
  Processed 80 batches (62 items)...
  Processed 90 batches (67 items)...
  Processed 100 batches (72 items)...
  Processed 110 batches (80 items)...
  Processed 120 batches (88 items)...
  Processed 130 batches (92 items)...
  Processed 140 batches (98 items)...
  Processed 150 batches (105 items)...
  Processed 160 batches (112 items)...
  Processed 170 batches (119 items)...
  Processed 180 batches (126 items)...
  Processed 190 batches (136 items)...
  Processed 200 batches (142 items)...
  Processed 210 batches (146 items)...
  Processed 220 batches (153 items)...
  Processed 230 batches (159 items)...
  Processed 240 batches (165 items)...
  Processed 250 batches (172 it