# Scraper

In [None]:
!pip install gallery-dl tqdm -U
!apt update -yqq
!apt install aria2 -yqq

## Grab JSON files by tags

In [None]:
import os
import urllib.parse
import threading
from queue import Queue
from tqdm import tqdm

def scrape_tags_images(tags_name, queue, pbar):
    """Fetch Danbooru post metadata (no images) for one tag query with gallery-dl.

    JSON metadata files are written to
    /workspace/train_data_json/animagine-xl-3.1/<url-encoded tags>/.
    The queue item is always acknowledged and the progress bar advanced, even
    if gallery-dl cannot be started, so ``queue.join()`` in ``main`` cannot hang.

    Credentials are read from the DANBOORU_USERNAME / DANBOORU_PASSWORD
    environment variables; when either is unset gallery-dl runs anonymously.
    NOTE(review): the password previously hard-coded here should be treated as
    leaked and rotated.

    Args:
        tags_name: Raw Danbooru tag query (URL-encoded before use).
        queue: Work queue the item came from; ``task_done()`` is called on it.
        pbar: Progress bar whose ``update(1)`` is called once.
    """
    import subprocess

    try:
        encoded_tags_name = urllib.parse.quote_plus(tags_name)
        url = f"https://danbooru.donmai.us/posts?tags={encoded_tags_name}+&z=5"
        dir_path = f"/workspace/train_data_json/animagine-xl-3.1/{encoded_tags_name}"
        # exist_ok avoids the check-then-create race if two workers get the same tag.
        os.makedirs(dir_path, exist_ok=True)

        # An argument list (no shell) sidesteps quoting/injection issues entirely.
        command = [
            "gallery-dl", url,
            "-D", dir_path,
            "--write-metadata", "--no-download",
            "--user-agent", "gdl/1.26.7",
        ]
        username = os.environ.get("DANBOORU_USERNAME")
        password = os.environ.get("DANBOORU_PASSWORD")
        if username and password:
            command += ["--username", username, "--password", password]

        subprocess.run(command, check=False)
    except OSError as e:
        print(f"Failed to scrape {tags_name!r}: {e}")
    finally:
        queue.task_done()
        pbar.update(1)

def worker(queue, pbar):
    """Feed tag queries from ``queue`` to scrape_tags_images until a None sentinel arrives."""
    # iter(callable, sentinel) keeps calling queue.get() and stops once it returns None.
    for tags_name in iter(queue.get, None):
        scrape_tags_images(tags_name, queue, pbar)

def main(tags_names=None, num_worker_threads=4):
    """Scrape metadata for each tag query using a pool of worker threads.

    Args:
        tags_names: Iterable of Danbooru tag queries to scrape. Defaults to an
            empty list, so ``main()`` with no arguments scrapes nothing (same as
            the original hard-coded placeholder list).
        num_worker_threads: Number of concurrent worker threads; at least one
            is always started.
    """
    tags_names = list(tags_names or [])
    # With zero workers nobody would drain the queue and queue.join() would block forever.
    num_worker_threads = max(1, num_worker_threads)

    queue = Queue()
    threads = []

    with tqdm(total=len(tags_names), desc="Scraping Images") as pbar:
        for _ in range(num_worker_threads):
            t = threading.Thread(target=worker, args=(queue, pbar))
            t.start()
            threads.append(t)

        for tags_name in tags_names:
            queue.put(tags_name)

        # Wait until every tag query has been acknowledged with task_done().
        queue.join()

        # One None sentinel per worker tells each thread to exit.
        for _ in range(num_worker_threads):
            queue.put(None)
        for t in threads:
            t.join()

    print("Scraping complete.")


if __name__ == "__main__":
    main()


In [None]:
import os
import json
import requests
import threading
import shutil
import subprocess
import dateutil.parser
from queue import Queue
from tqdm import tqdm

# Shared counters for the processing workers; downloaded_files is only
# modified while holding download_lock.
total_files = 0
downloaded_files = 0
download_lock = threading.Lock()
# User-Agent string handed to aria2c.
user_agent = 'gdl/1.25.8'

# Any meta tag containing one of these substrings is dropped from captions.
meta_keywords_black_list = [
    "(medium)",
    "commentary",
    "bad",
    "translat",
    "request",
    "mismatch",
    "revision",
    "audio",
    "video",
]

# Subject-count / focus tags that generate_tags moves to the front of a caption.
special_tags = (
    ["1girl", "2girls", "3girls", "4girls", "5girls", "6+girls", "multiple girls"]
    + ["1boy", "2boys", "3boys", "4boys", "5boys", "6+boys", "multiple boys", "male focus"]
    + ["1other", "2others", "3others", "4others", "5others", "6+others", "multiple others", "other focus"]
)

# Per-rating score thresholds keyed 5, 10, ..., 95 (percentiles, judging by the
# name); a post's score is compared against these to pick its quality tag.
score_percentile_full = {
    rating: dict(zip(range(5, 100, 5), thresholds))
    for rating, thresholds in (
        ("g", (0, 1, 2, 3, 3, 4, 5, 5, 6, 7, 8, 9, 11, 12, 14, 16, 19, 24, 33)),
        ("s", (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 15, 17, 20, 25, 31, 41, 62)),
        ("q", (2, 4, 5, 7, 9, 11, 14, 16, 19, 23, 26, 31, 36, 42, 49, 59, 73, 93, 134)),
        ("e", (2, 4, 7, 10, 13, 17, 20, 25, 29, 35, 41, 48, 56, 66, 78, 94, 115, 148, 211)),
    )
}

def aria2_download(dir, filename, url):
    """Download ``url`` to ``<dir>/<filename>`` with aria2c and wait for it to finish.

    Args:
        dir: Target directory (aria2c ``--dir``).
        filename: Output file name inside ``dir`` (aria2c ``--out``).
        url: Source URL.

    Returns:
        aria2c's exit status (0 on success).

    Raises:
        OSError: if the aria2c executable cannot be started.
    """
    aria2_config = {
        "console-log-level"         : "error",
        "summary-interval"          : 10,
        "continue"                  : True,
        "max-connection-per-server" : 16,
        "min-split-size"            : "1M",
        "split"                     : 16,
        "dir"                       : str(dir),
        "out"                       : filename,
        "user-agent"                : user_agent,
        "_url"                      : url,
    }
    aria2_args = parse_args(aria2_config, aria=True)
    # Run synchronously: the previous fire-and-forget Popen returned before the
    # file existed, so download_image's post-download check almost always
    # failed, and each worker could spawn an unbounded number of aria2c processes.
    return subprocess.run(["aria2c", *aria2_args], check=False).returncode
    
def parse_args(config, aria=False):
    """Render a config dict as a command-line argument list.

    Keys starting with "_" contribute their bare value (positional argument).
    Strings become ``--key=value`` (aria mode) or ``--key='value'``; True
    becomes a bare ``--key`` flag; ints/floats become ``--key=value``.
    False, None and any other type are omitted.
    """
    rendered = []
    for key, value in config.items():
        if key.startswith("_"):
            rendered.append(f"{value}")
            continue
        if isinstance(value, bool):
            # bool is checked before int because bool is an int subclass.
            if value:
                rendered.append(f"--{key}")
            continue
        if isinstance(value, str):
            rendered.append(f"--{key}={value}" if aria else f"--{key}='{value}'")
        elif isinstance(value, (int, float)):
            rendered.append(f"--{key}={value}")
    return rendered
    
def download_image(folder, url, filename, extension):
    """Download one image into ``folder`` unless a finished copy is already there.

    Args:
        folder: Destination directory (a rating folder such as "general").
        url: Image URL.
        filename: File name without extension.
        extension: File extension without the dot.

    Returns:
        True if a non-empty, finished file exists after downloading (the shared
        downloaded_files counter is then incremented); False if the file was
        already present, the download did not finish, or an error occurred.
    """
    global downloaded_files
    full_path = f"{folder}/{filename}.{extension}"
    # aria2c keeps a "<file>.aria2" control file beside an unfinished download.
    # Only a file without it counts as done, so interrupted downloads are
    # resumed (aria2 "continue" is enabled) instead of being skipped forever.
    control_path = f"{full_path}.aria2"
    if os.path.exists(full_path) and not os.path.exists(control_path):
        return False
    try:
        aria2_download(folder, f"{filename}.{extension}", url)

        finished = (
            os.path.exists(full_path)
            and os.path.getsize(full_path) > 0
            and not os.path.exists(control_path)
        )
        if finished:
            with download_lock:
                downloaded_files += 1
        return finished

    except Exception as e:
        print(f"Error downloading {url}: {e}")
        return False
        
def write_metadata(folder, filename, tags):
    """Write the caption ``tags`` to ``<folder>/<filename>.txt``, replacing any existing file.

    The file is always UTF-8: tags may contain non-ASCII characters and the
    platform default encoding is not guaranteed to handle them.
    """
    with open(f"{folder}/{filename}.txt", "w", encoding="utf-8") as f:
        f.write(tags)

def filter_blacklisted_tags(tags, blacklist):
    """Drop every ", "-separated tag that contains any blacklisted substring."""
    def is_clean(tag):
        return all(keyword not in tag for keyword in blacklist)

    return ", ".join(filter(is_clean, tags.split(", ")))
    
def _year_tag(year):
    """Return the era tag for an upload ``year``, or '' when unknown/out of range."""
    if 2005 <= year <= 2010:
        return "oldest"
    if 2011 <= year <= 2014:
        return "early"
    if 2015 <= year <= 2017:
        return "mid"
    if 2018 <= year <= 2020:
        return "recent"
    if 2021 <= year <= 2024:
        return "newest"
    # year is 0 when created_at could not be parsed; that used to fall into "early".
    return ""


def _quality_tag(rating, score):
    """Pick the quality tag for ``score`` from the per-rating thresholds in score_percentile_full.

    Returns '' for a rating missing from the table instead of raising KeyError.
    """
    percentile = score_percentile_full.get(rating)
    if percentile is None:
        return ""
    for cutoff, tag in (
        (95, "masterpiece"),
        (85, "best quality"),
        (75, "great quality"),
        (50, "good quality"),
        (25, "normal quality"),
        (10, "low quality"),
    ):
        if score > percentile[cutoff]:
            return tag
    return "worst quality"


def generate_tags(data):
    """Build a training caption from one Danbooru post's metadata dict.

    The caption is ``<pre>, ||| <post>``. <pre> holds the special_tags found
    among the general tags, then character, copyright and artist tags. <post>
    holds the rating tag(s), the remaining general tags, the era tag, the
    filtered meta tags and the quality tag. When <pre> is empty the caption is
    just ``||| <post>``.

    Args:
        data: Parsed post JSON; reads ``rating``, ``score``,
            ``media_asset.created_at`` and the space-separated
            ``tag_string_*`` fields.

    Returns:
        The caption string.
    """
    def process_tags(tag_str):
        # Space-separated tags -> ", "-separated; underscores become spaces except
        # in tags of 3 chars or fewer (presumably to keep emoticons like "^_^").
        processed_tags = []
        for tag_name in (tag_str or "").split():
            if len(tag_name) > 3:
                tag_name = tag_name.replace("_", " ")
            processed_tags.append(tag_name)
        return ", ".join(processed_tags)

    created_at = (data.get("media_asset") or {}).get("created_at", "")
    year = 0
    try:
        year = dateutil.parser.isoparse(created_at).year
    except (ValueError, TypeError, OverflowError):
        pass  # unparseable or missing date -> no era tag
    year_tag = _year_tag(year)

    rating = data.get("rating")
    score = data.get("score") or 0

    tags_general = process_tags(data.get("tag_string_general", ""))
    tags_character = process_tags(data.get("tag_string_character", ""))
    tags_copyright = process_tags(data.get("tag_string_copyright", ""))
    tags_artist = process_tags(data.get("tag_string_artist", ""))

    # Convert to ", "-separated form BEFORE filtering: filter_blacklisted_tags
    # splits on ", ", so the raw space-separated string was treated as a single
    # tag and one blacklisted meta tag dropped every meta tag of the post.
    tags_meta = filter_blacklisted_tags(
        process_tags(data.get("tag_string_meta", "")), meta_keywords_black_list
    )

    quality_tag = _quality_tag(rating, score)
    nsfw_tags = {"e": "explicit, nsfw", "q": "nsfw", "s": "sensitive"}.get(rating, "safe")

    # Move special (subject-count / focus) tags to the front, keeping their order.
    special = set(special_tags)
    general_list = tags_general.split(", ")
    first_general_tag = ", ".join(t for t in general_list if t in special)
    rest_general_tags = ", ".join(t for t in general_list if t not in special)

    tags_separator = "|||"
    pre_separator_str = ", ".join(
        t for t in (first_general_tag, tags_character, tags_copyright, tags_artist) if t
    )
    post_separator_str = ", ".join(
        t for t in (nsfw_tags, rest_general_tags, year_tag, tags_meta, quality_tag) if t
    )

    if pre_separator_str:
        return f"{pre_separator_str}, {tags_separator} {post_separator_str}"
    # No leading tags: avoid a caption that starts with ", |||".
    return f"{tags_separator} {post_separator_str}"

def process_file(json_folder, json_file):
    """Download the image for one post JSON and write its caption file.

    Writes ``<rating folder>/<stem>.<ext>`` and ``<rating folder>/<stem>.txt``,
    where <stem> is ``json_file`` up to its first dot. After a successful
    download the JSON is moved into ``<json_folder>_processed``. Posts with a
    non-image extension, an unknown rating, or no file_url are skipped.

    Args:
        json_folder: Directory containing ``json_file``.
        json_file: Name of the gallery-dl metadata JSON file.
    """
    # JSON text is UTF-8; don't depend on the locale's default encoding.
    with open(f"{json_folder}/{json_file}", "r", encoding="utf-8") as f:
        data = json.load(f)

    extension = data.get("file_ext")
    rating_map = {'g': 'general', 's': 'sensitive', 'q': 'questionable', 'e': 'explicit'}
    rating = rating_map.get(data.get("rating"), "")
    file_url = data.get("file_url")

    if extension not in ["png", "jpg", "jpeg", "webp", "bmp"]:
        return
    # An unknown rating maps to "", which would turn "<folder>/<file>" into
    # "/<file>" (filesystem root); a missing file_url would reach aria2c as "None".
    if not rating or not file_url:
        return

    stem = json_file.split(".")[0]
    tags = generate_tags(data)

    if download_image(rating, file_url, stem, extension):
        processed_folder = f"{json_folder}_processed"
        os.makedirs(processed_folder, exist_ok=True)
        shutil.move(f"{json_folder}/{json_file}", f"{processed_folder}/{json_file}")

    write_metadata(rating, stem, tags)

def worker(queue, pbar):
    """Process ``(json_file_path, json_folder)`` items from ``queue`` until a None sentinel.

    Every dequeued item is acknowledged with ``task_done()`` and counted on
    ``pbar`` even if processing raises, so ``queue.join()`` in ``main`` cannot
    hang because of one bad file.
    """
    while True:
        item = queue.get()
        if item is None:
            # Signal to terminate this worker thread
            queue.task_done()
            break
        try:
            json_file_path, json_folder = item
            process_file(json_folder, os.path.basename(json_file_path))
        except Exception as e:
            # An uncaught exception would kill this thread and leave the item
            # unacknowledged; log it and move on to the next file.
            print(f"Error processing {item}: {e}")
        finally:
            queue.task_done()
            pbar.update(1)

def main(json_folder):
    """Download images and write captions for every ``.json`` file under ``json_folder``.

    The four rating folders are created in the current working directory if
    missing, and the files are spread over 8 worker threads.
    """
    global total_files

    files_to_process = []
    for root, _dirs, names in os.walk(json_folder):
        files_to_process.extend(
            (os.path.join(root, name), root) for name in names if name.endswith('.json')
        )
    total_files = len(files_to_process)

    for folder in ('general', 'sensitive', 'questionable', 'explicit'):
        if not os.path.exists(folder):
            os.mkdir(folder)

    num_worker_threads = 8
    queue = Queue()

    with tqdm(total=total_files, desc="Processing Files") as pbar:
        threads = [
            threading.Thread(target=worker, args=(queue, pbar))
            for _ in range(num_worker_threads)
        ]
        for thread in threads:
            thread.start()

        for entry in files_to_process:
            queue.put(entry)
        # Block until every file has been acknowledged by a worker.
        queue.join()

        # One None per worker shuts the pool down.
        for _ in threads:
            queue.put(None)
        for thread in threads:
            thread.join()

    print("Scraping complete.")


if __name__ == "__main__":
    project_path = "/workspace/train_data/animagine-xl-3.1"
    os.makedirs(project_path, exist_ok=True)
    os.chdir(project_path)
    main("/workspace/train_data_json/animagine-xl-3.1")

In [None]:
import os
import shutil

def remove_ipynb_checkpoints(root_dir='', dir_names=('__pycache__', '.ipynb_checkpoints')):
    """Recursively delete every directory whose name is in ``dir_names`` under ``root_dir``.

    Args:
        root_dir: Directory to scan; '' (the default) means the current
            working directory.
        dir_names: Directory names to delete. Defaults to ``__pycache__``
            (what the code removed before) plus ``.ipynb_checkpoints`` (what
            the function name promises).
    """
    # os.walk('') yields nothing, so the old default silently removed nothing.
    start = root_dir or '.'
    targets = set(dir_names)
    # Bottom-up walk so nested matches are removed before their parents.
    for dirpath, dirnames, _filenames in os.walk(start, topdown=False):
        for dirname in dirnames:
            if dirname in targets:
                full_path = os.path.join(dirpath, dirname)
                shutil.rmtree(full_path)  # Remove the directory and all its contents
                print(f"Removed: {full_path}")


# Call the function to start the removal process (only when run as a script/notebook cell)
if __name__ == "__main__":
    remove_ipynb_checkpoints()