In [1]:
# Dependencies (uncomment to install):
#!pip install ftfy regex tqdm
#!pip install git+https://github.com/openai/CLIP.git
#!pip install --upgrade "nudenet>=3.4.2"

# References:
#   Q16: https://github.com/ml-research/Q16
#   NudeNet (provides NudeDetector): https://github.com/notAI-tech/NudeNet

In [2]:
import numpy as np
import torch
import os
import PIL
import pickle
import clip
import pandas as pd

from tqdm import tqdm
from IPython.display import Image
from nudenet import NudeDetector

In [3]:
class ClipWrapper(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper around a CLIP model's image encoder.

    Attributes:
        clip_model: The model returned by ``clip.load``, switched to eval mode.
        preprocess: The preprocessing callable returned by ``clip.load``.
    """

    def __init__(self, device, model_name='ViT-L/14'):
        """Load the CLIP model ``model_name`` onto ``device``.

        Args:
            device: Device passed through to ``clip.load``.
            model_name: Name of the CLIP model to load.
        """
        super().__init__()
        # Import the module locally instead of going through the global name
        # ``clip``: a later notebook cell rebinds that global to a ClipWrapper
        # instance, after which ``clip.load`` would no longer resolve and this
        # class could not be instantiated a second time.
        import clip as clip_module

        self.clip_model, self.preprocess = clip_module.load(
            model_name,
            device,
            jit=False
        )
        self.clip_model.eval()

    def forward(self, x):
        """Return ``self.clip_model.encode_image(x)``."""
        return self.clip_model.encode_image(x)


class SimClassifier(torch.nn.Module):
    """Scores image features against a set of prompt embeddings.

    The score is 100 times the cosine similarity between each input row and
    each embedding row.
    """

    def __init__(self, embeddings, device):
        """Register ``embeddings`` as a module parameter.

        Args:
            embeddings: Tensor of prompt embeddings
                (presumably ``(num_prompts, dim)`` -- TODO confirm).
            device: Not used by this class.
        """
        super().__init__()
        self.embeddings = torch.nn.Parameter(embeddings)

    def forward(self, x):
        """Return the scaled cosine similarities for ``x``.

        Args:
            x: Image features; the last dimension is the feature dimension.

        Returns:
            ``100 * cos_sim(x, embeddings)`` with every size-1 dimension
            removed by ``squeeze()``. Note that a single input row therefore
            comes back without its batch dimension.
        """
        # L2-normalise both sides along the feature dimension.
        prompt_unit = self.embeddings / self.embeddings.norm(dim=-1, keepdim=True)
        image_unit = x / x.norm(dim=-1, keepdim=True)

        # Scale first, then project onto the prompts (same evaluation order
        # as writing ``100.0 * a @ b``).
        scaled_scores = (100.0 * image_unit) @ prompt_unit.T
        return scaled_scores.squeeze()

def initialize_prompts(clip_model, text_prompts, device):
    """Tokenize ``text_prompts`` and encode them with ``clip_model``.

    Args:
        clip_model: Object exposing ``encode_text``.
        text_prompts: Prompt text(s) accepted by ``clip.tokenize``.
        device: Device the token tensor is moved to before encoding.

    Returns:
        The result of ``clip_model.encode_text`` on the tokenized prompts.
    """
    # Import the module locally: a later notebook cell rebinds the global name
    # ``clip`` to a ClipWrapper instance, which has no ``tokenize`` attribute.
    import clip as clip_module

    tokens = clip_module.tokenize(text_prompts).to(device)
    return clip_model.encode_text(tokens)


def save_prompts(classifier, save_path):
    """Pickle the classifier's prompt embeddings to ``save_path``.

    Args:
        classifier: Object whose ``embeddings`` attribute is a tensor.
        save_path: Destination file path; overwritten if it exists.
    """
    prompts = classifier.embeddings.detach().cpu().numpy()
    # Use a context manager so the handle is flushed and closed
    # deterministically, even if pickling raises.
    with open(save_path, 'wb') as prompt_file:
        pickle.dump(prompts, prompt_file)


def load_prompts(file_path, device):
    """Load pickled prompt embeddings as a half-precision tensor on ``device``.

    Args:
        file_path: Path to a pickle file (as written by ``save_prompts``).
        device: Device the resulting tensor is moved to.

    Returns:
        A ``torch.float16`` tensor holding the unpickled data.
    """
    # NOTE: pickle.load can execute arbitrary code -- only load prompt files
    # that come from a trusted source.
    with open(file_path, 'rb') as prompt_file:
        prompts = pickle.load(prompt_file)
    # Explicit dtype conversion instead of the legacy typed constructor.
    return torch.as_tensor(prompts, dtype=torch.float16).to(device)

def compute_embeddings(image_paths):
    """Preprocess the images at ``image_paths`` and encode them.

    Relies on two notebook globals: ``device`` and ``clip``. Here ``clip`` is
    the ``ClipWrapper`` instance created in a later cell (it provides
    ``preprocess`` and is callable), not the ``clip`` module.

    Args:
        image_paths: Iterable of image file paths that PIL can open.

    Returns:
        The encoder output for the stacked batch, cast to half precision
        (presumably one row per path -- TODO confirm against the encoder).
    """
    images = []
    for image_path in image_paths:
        # Close every file as soon as it has been preprocessed so that large
        # runs do not accumulate open file handles.
        with PIL.Image.open(image_path) as image:
            images.append(clip.preprocess(image))
    batch = torch.stack(images).to(device)
    # Inference only: do not build an autograd graph through the encoder.
    with torch.no_grad():
        return clip(batch).half()

def classify_images_batches(image_files, batch_size=30):
    """Classify images in batches and collect the results in a DataFrame.

    Relies on the notebook globals ``classifier``, ``device`` and
    ``compute_embeddings``.

    Args:
        image_files: Sequence of image file paths.
        batch_size: Number of images processed per batch.

    Returns:
        A ``pandas.DataFrame`` with one row per image and the columns
        ``file`` (the path), ``q16`` (index of the highest classifier score)
        and ``nude`` (that image's entry from ``NudeDetector.detect_batch``).
        An empty input yields an empty frame with those columns.
    """
    results = []
    detector = NudeDetector()
    # Inference only: the classifier's embeddings are a Parameter, so without
    # no_grad every batch would build an autograd graph.
    with torch.no_grad():
        for start in tqdm(range(0, len(image_files), batch_size), desc="Processing batches"):
            batch = image_files[start:start + batch_size]
            batch_embeddings = compute_embeddings(batch).to(device)
            nudes = detector.detect_batch(batch)
            # The classifier squeezes its output, so a batch holding a single
            # image comes back 1-D; restore the (batch, classes) layout before
            # taking the per-image argmax.
            scores = classifier(batch_embeddings).reshape(len(batch), -1)
            labels = torch.argmax(scores, dim=1)
            for file, q16, nude in zip(batch, labels.tolist(), nudes):
                results.append({'file': file, 'q16': q16, 'nude': nude})
            torch.cuda.empty_cache()
    return pd.DataFrame(results, columns=['file', 'q16', 'nude'])

In [6]:
# Run configuration: target device and location of the pickled prompts.
device = 'cuda'
prompt_path = '../data/q16/prompts.p'

# Prompt embeddings as a half-precision tensor on `device`.
trained_prompts = load_prompts(prompt_path, device=device)

In [7]:
# NOTE: this rebinds the global name `clip`, shadowing the `clip` module
# imported at the top of the notebook. compute_embeddings() depends on `clip`
# being this wrapper instance (it uses `clip.preprocess` and calls
# `clip(images)`), so the name must not be changed here in isolation. After
# this cell the module's own functions are no longer reachable through the
# global name `clip`.
clip = ClipWrapper(device)
print('initialized clip model')

initialized clip model


In [8]:
# Build the similarity classifier from the loaded prompt embeddings.
# classify_images_batches() reads this object through the global name
# `classifier`.
classifier = SimClassifier(trained_prompts, device)
print('initialized classifier')

initialized classifier


In [9]:
image_dir = "../data/images/"
# Sort the listing so the batch order (and therefore the row order of the
# results) is reproducible, and match extensions case-insensitively so files
# such as "photo.JPG" are not silently skipped.
image_files = [
    os.path.join(image_dir, file)
    for file in sorted(os.listdir(image_dir))
    if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.gif'))
]

In [None]:
# Classify every listed image with the default batch size.
df = classify_images_batches(image_files)

[1;31m2024-12-04 17:43:13.734993195 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 3560405, index: 3, mask: {4, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2024-12-04 17:43:13.735023281 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 3560404, index: 2, mask: {3, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2024-12-04 17:43:13.734955883 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 3560403, index: 1, mask: {2, }, error code: 22 error msg: Invalid argument. Specify the number of threads explicitly so the affinity is not set.[m
[1;31m2024-12-04 17:43:13.734965640 [E:onnxruntime:Default, env.cc:234 ThreadMain] pthread_setaffinity_np failed for thread: 3560402, index: 0, mask: {1, }, error code: 22 error

In [None]:
# Display the per-image results table.
df