In [4]:
import torch
import torchvision
from torch.utils.data import Dataset
from PIL import Image
import os
import requests
import json 
from tqdm import tqdm
import time



In [5]:
#TOTAL_POSTS = 7400200 # Total number of posts on Danbooru
IMG_PER_BATCH = 200 # Read limit from Danbooru
TAGS = ['action', 'looking_at_another', 'speech_bubble', 'romance', 'sad', 'crying', 'angry', 'scared', 'surprised', 'fighting', 'chase', 'talking'] # Tags for our categories
#STOP_ID = 7400200

# Creates directory
def create_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)

# Returns the tags from TAGS that are in the tag_string of a post; returns False if none of our TAGS are in tag_string
def check_tags_in_tag_string(tag_string, tags):
    tag_list = []
    for tag in tags:
        if tag in tag_string:
            tag_list.append(tag)
    if len(tag_list) == 0:
        return False
    else:
        return tag_list
    
# Copied from custom_hymenoptera_dataset.py
# Checks for valid image files
def is_valid_image_file(filename):
  # Check file name extension
  valid_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff']
  if os.path.splitext(filename)[1].lower() not in valid_extensions:
    print(f"Invalid image file extension \"{filename}\". Skipping this file...")
  # Verify that image file is intact
  try:
    with Image.open(filename) as img:
      img.verify()  # Verify if it's an image
      return True
  except (IOError, SyntaxError) as e:
    print(f"Invalid image file {filename}: {e}")
    return False

# Create dir for our comic images
images_dir = 'comic_images'
create_dir(images_dir)

# Dict for image file name and list of tags
image_label_dict = {}

B = IMG_PER_BATCH

## TEMP ##
TOTAL_POSTS = 3727400

progress_bar = tqdm(total=TOTAL_POSTS)

# Loops over all post id's to download images from posts that are tagged "comic" and contain at least one of our TAGS
# Also creates dict of image file name and associated tags
while B <= TOTAL_POSTS:
    url = f'https://danbooru.donmai.us/posts.json?page=b{B}&page=a{B-IMG_PER_BATCH}&limit={IMG_PER_BATCH}'
    response_pages = requests.get(url)

    response_pages_json = response_pages.json()

    for page in response_pages_json:
        if 'file_url' in page and 'tag_string' in page:
            tag_string = page['tag_string']
            id = page['id']

            scene_tags = check_tags_in_tag_string(tag_string, TAGS)
            if 'comic' in page['tag_string'] and scene_tags:
                file_url = page['file_url']

                image_path = f'{id}.jpg'
        
                if not os.path.exists(os.path.join(images_dir, image_path)):
                    response_img = requests.get(file_url)

                    # If post contains relevant tags and has valid image file, save the image with id as name
                    if response_img.status_code == 200:
                        
                        with open(os.path.join(images_dir, image_path), 'wb') as file:
                            file.write(response_img.content)
                        
                # Write values to dict
                image_label_dict[image_path] = scene_tags

    B += IMG_PER_BATCH
    progress_bar.update(200)

progress_bar.close()

print(image_label_dict)

  0%|          | 6600/3727400 [00:34<5:20:23, 193.56it/s]
100%|██████████| 3727400/3727400 [2:25:33<00:00, 426.79it/s]  

{'142.jpg': ['speech_bubble'], '137.jpg': ['speech_bubble'], '136.jpg': ['speech_bubble'], '135.jpg': ['speech_bubble'], '134.jpg': ['speech_bubble', 'sad'], '133.jpg': ['speech_bubble', 'surprised'], '132.jpg': ['speech_bubble'], '131.jpg': ['speech_bubble'], '130.jpg': ['speech_bubble'], '100.jpg': ['speech_bubble'], '76.jpg': ['speech_bubble'], '49.jpg': ['speech_bubble'], '377.jpg': ['speech_bubble'], '292.jpg': ['speech_bubble'], '266.jpg': ['speech_bubble'], '264.jpg': ['speech_bubble'], '263.jpg': ['speech_bubble'], '262.jpg': ['speech_bubble'], '252.jpg': ['speech_bubble'], '251.jpg': ['speech_bubble'], '247.jpg': ['speech_bubble'], '246.jpg': ['speech_bubble'], '244.jpg': ['speech_bubble'], '243.jpg': ['speech_bubble'], '628.jpg': ['surprised'], '626.jpg': ['speech_bubble'], '625.jpg': ['speech_bubble', 'surprised'], '624.jpg': ['speech_bubble'], '623.jpg': ['speech_bubble'], '622.jpg': ['speech_bubble'], '621.jpg': ['speech_bubble', 'surprised'], '620.jpg': ['speech_bubble', 




In [6]:
# Save dict to .json file
with open('comic_labels.json', 'w') as f: 
     json.dump(image_label_dict, f)

In [None]:
class ComicDataset(Dataset):
    def __init__(self, images_dir, transform=None, target_transform=None):
        pass