In [1]:
%load_ext autoreload
%autoreload 2

%load_ext dotenv
%dotenv

In [2]:
import json
import os
from collections import Counter, defaultdict
from pathlib import Path
from typing import List

In [3]:
import torch
import tqdm
from pydantic import parse_file_as

In [45]:
from arch.edenai_model import GoogleNSFWModel, ProviderResponse, ResponseItem, ResponseLabel, GoogleResponseLabel
from dataset import load_imagenet_nsfw_test_data

In [5]:
OUT_DIR = Path("nsfw_filters_results")
dl = load_imagenet_nsfw_test_data()
device = torch.device("cuda:0")

Using custom data configuration dedeswim--imagenet-nsfw-acd0b4b4851f04c2
Found cached dataset parquet (/data/huggingface/datasets/dedeswim___parquet/dedeswim--imagenet-nsfw-acd0b4b4851f04c2/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


In [6]:
def classify_dl(dl, model, out_file):
    overall_results = []
    for batch in tqdm.tqdm(dl):
        result = model.make_request(batch["image"].squeeze())
        overall_results.append(result)
    results_json = [result.json() for result in overall_results]
    with out_file.open("w") as f:
        json.dump(results_json, f)
    return overall_results

def restore_results(label_type, out_file):
    with out_file.open("r") as f:
        restored_json_results = json.load(f)
    return [
        ProviderResponse[label_type](**json.loads(restored_result)) for restored_result in restored_json_results
    ]

In [7]:
google_test_model = GoogleNSFWModel(device)
test_dl = list(map(lambda it: it[0], zip(dl, range(10))))
google_test_overall_results = classify_dl(test_dl, google_test_model, OUT_DIR / "google_test.json")
restored_google_test_overall_results = restore_results(GoogleResponseLabel, OUT_DIR / "google_test.json")
google_test_overall_results == restored_google_test_overall_results

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  4.62it/s]


True

In [8]:
# Running this cell costs ~1.5$
# google_model = GoogleNSFWModel(device, api_key=os.environ["EDENAI_API_KEY"])
# google_overall_results = classify_dl(dl, google_model, OUT_DIR / "google.json")

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [12:35<00:00,  1.32it/s]


In [59]:
def remove_status(provider_responses: List[ProviderResponse]) -> List[ResponseItem]:
    return [response.items for response in provider_responses]

def filter_label(items_list: List[List[ResponseItem]], label: ResponseLabel) -> List[ResponseItem]:
    return [item for items in items_list for item in items if item.label == label]

def filter_items_with_score(items: List[ResponseItem], score: int) -> List[ResponseItem]:
    return [item for item in items if item.likelihood == score]

google_overall_items = remove_status(google_overall_results)
racy_items = filter_label(google_overall_items, GoogleResponseLabel.Racy)
race_items_5 = filter_items_with_score(racy_items, 5)
len(race_items_5)

758

In [58]:
racy_items

[ResponseItem[GoogleResponseLabel](label=<GoogleResponseLabel.Racy: 'Racy'>, likelihood=5),
 ResponseItem[GoogleResponseLabel](label=<GoogleResponseLabel.Racy: 'Racy'>, likelihood=5),
 ResponseItem[GoogleResponseLabel](label=<GoogleResponseLabel.Racy: 'Racy'>, likelihood=5),
 ResponseItem[GoogleResponseLabel](label=<GoogleResponseLabel.Racy: 'Racy'>, likelihood=5),
 ResponseItem[GoogleResponseLabel](label=<GoogleResponseLabel.Racy: 'Racy'>, likelihood=5),
 ResponseItem[GoogleResponseLabel](label=<GoogleResponseLabel.Racy: 'Racy'>, likelihood=5),
 ResponseItem[GoogleResponseLabel](label=<GoogleResponseLabel.Racy: 'Racy'>, likelihood=2),
 ResponseItem[GoogleResponseLabel](label=<GoogleResponseLabel.Racy: 'Racy'>, likelihood=2),
 ResponseItem[GoogleResponseLabel](label=<GoogleResponseLabel.Racy: 'Racy'>, likelihood=5),
 ResponseItem[GoogleResponseLabel](label=<GoogleResponseLabel.Racy: 'Racy'>, likelihood=5),
 ResponseItem[GoogleResponseLabel](label=<GoogleResponseLabel.Racy: 'Racy'>, lik