#### Scraping the dataset from the web

In [4]:
import os
import time
import requests
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.keys import Keys


def fetch_image_urls(query, num_images, max_stalled_scrolls=3):
    """Collect full-size image URLs from a Google Images search.

    Drives a headless Chrome session: clicks each result thumbnail, reads the
    ``src`` of the preview images that appear, and scrolls to the bottom of
    the page to load more results until enough URLs have been collected.

    Args:
        query: Search text. It is URL-encoded before being placed in the
            search URL, so spaces and characters such as ``&`` are safe.
        num_images: Number of distinct URLs to aim for. Values <= 0 return an
            empty list without starting a browser.
        max_stalled_scrolls: How many consecutive scrolls may produce no new
            thumbnails before the search is abandoned. Prevents an endless
            loop when the results run out or the CSS selectors stop matching.

    Returns:
        A list of distinct image URLs in arbitrary order. It can hold a few
        more than ``num_images`` (one click may reveal several images) or
        fewer when the search runs out of results.
    """
    if num_images <= 0:
        return []

    # Imported here so the function does not depend on extra notebook-level imports.
    from urllib.parse import quote_plus

    chrome_options = Options()
    for flag in ("--headless", "--disable-gpu", "--no-sandbox", "--disable-dev-shm-usage"):
        chrome_options.add_argument(flag)
    print("getting chrome web driver...")
    driver = webdriver.Chrome(options=chrome_options)

    image_urls = set()
    try:
        print("getting search results...")
        search_url = f"https://www.google.com/search?tbm=isch&q={quote_plus(query)}"
        driver.get(search_url)

        results_start = 0  # number of thumbnails already visited
        stalled_scrolls = 0  # consecutive passes that found nothing new

        print("starting loop...")
        while len(image_urls) < num_images:
            thumbnail_results = driver.find_elements(
                By.CSS_SELECTOR, "img.YQ4gaf"
            )  # IMPORTANT: Class name for img tag may change in future
            new_thumbnails = thumbnail_results[results_start:]

            if new_thumbnails:
                stalled_scrolls = 0
            else:
                stalled_scrolls += 1
                if stalled_scrolls > max_stalled_scrolls:
                    print(f"No new results after {max_stalled_scrolls} scrolls, stopping with {len(image_urls)} urls")
                    break

            for img in new_thumbnails:
                try:
                    img.click()
                    time.sleep(1)  # give the preview pane time to load
                except Exception:
                    # Best effort: some thumbnails are not clickable; skip them.
                    continue

                actual_images = driver.find_elements(
                    By.CSS_SELECTOR, "img.sFlh5c"
                )  # IMPORTANT: Class name for img tag may change in future
                for actual_image in actual_images:
                    src = actual_image.get_attribute("src")
                    # Keep real http(s) links only; inline data: URIs are skipped.
                    if src and src.startswith("http"):
                        image_urls.add(src)
                        print(f"url found! {len(image_urls)} urls")
                    else:
                        print("url not found")

                if len(image_urls) >= num_images:
                    print(f"Found: {len(image_urls)} image links, done!")
                    break
            else:
                # Every new thumbnail was visited without reaching the target:
                # scroll to the bottom so the page loads another batch.
                print("Found:", len(image_urls), "looking for more ...")
                driver.find_element(By.TAG_NAME, "body").send_keys(Keys.END)
                time.sleep(5)

            results_start = len(thumbnail_results)
    finally:
        # Always shut the browser down, even if scraping raised.
        driver.quit()

    return list(image_urls)


def download_images(img_urls, download_path):
    """Download each URL into ``download_path`` as ``image_<n>.jpg``.

    Files are numbered by the URL's position in ``img_urls`` (starting at 1),
    so a failed download leaves a gap in the numbering. Failures are reported
    on stdout and do not stop the remaining downloads.

    Args:
        img_urls: Iterable of image URLs.
        download_path: Target directory; created (with parents) if missing.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(download_path, exist_ok=True)

    for i, img_url in enumerate(img_urls):
        try:
            # A timeout keeps one unresponsive host from hanging the whole run.
            response = requests.get(img_url, timeout=30)
            # Do not store HTTP error pages (404, 403, ...) as image files.
            response.raise_for_status()
            with open(os.path.join(download_path, f"image_{i + 1}.jpg"), "wb") as handler:
                handler.write(response.content)
        except Exception as e:
            print(f"Could not download {img_url}: {e}")


if __name__ == "__main__":
    # Scrape configuration: every search is "<prefix> <character>".
    download_path = "./images"
    prefix = "one piece anime png"
    num_images = 120
    characters = ["luffy", "zoro", "sanji", "nami", "robin", "shanks", "usopp", "chopper"]

    # Gather the URLs of all characters into one flat list.
    img_urls = []
    for char in characters:
        query = f"{prefix} {char}"
        img_urls += fetch_image_urls(query, num_images)
        print(f"url search for {char} complete! {len(img_urls)} urls total")

    # Download only when the searches produced something.
    if not img_urls:
        print("No images found.")
    else:
        print("Downloading images from url...")
        download_images(img_urls, download_path)
        print(f"Downloaded {len(img_urls)} images to {download_path}")

getting chrome web driver...


SessionNotCreatedException: Message: session not created: Chrome failed to start: exited normally.
  (session not created: DevToolsActivePort file doesn't exist)
  (The process started from chrome location /home/choi/.cache/selenium/chrome/linux64/127.0.6533.72/chrome is no longer running, so ChromeDriver is assuming that Chrome has crashed.)
Stacktrace:
#0 0x55bbe2bb16ba <unknown>
#1 0x55bbe2881730 <unknown>
#2 0x55bbe28b9615 <unknown>
#3 0x55bbe28b5488 <unknown>
#4 0x55bbe28ffe88 <unknown>
#5 0x55bbe28f37f3 <unknown>
#6 0x55bbe28c3ec9 <unknown>
#7 0x55bbe28c491e <unknown>
#8 0x55bbe2b779eb <unknown>
#9 0x55bbe2b7b972 <unknown>
#10 0x55bbe2b64e15 <unknown>
#11 0x55bbe2b7c502 <unknown>
#12 0x55bbe2b49d2f <unknown>
#13 0x55bbe2ba0578 <unknown>
#14 0x55bbe2ba0750 <unknown>
#15 0x55bbe2bb048c <unknown>
#16 0x7fc7e8c0bac3 <unknown>


In [None]:
from datasets import load_dataset

# Build a dataset from the files under ./images with the "imagefolder" loader.
dataset = load_dataset("imagefolder", data_dir="./images")
# Upload it to the Hub repository "picasso"
# (presumably under the logged-in user's namespace; verify before running).
dataset.push_to_hub("picasso")

In [None]:
# Print the "image" entry of every row in the train split (one line per row).
for item in dataset["train"]:
    print(item["image"])

#### Loading the dataset from the Hugging Face Hub

In [2]:
from datasets import load_dataset

# Alternative repository, kept for reference:
# dataset = load_dataset("kusnim1121/picasso")
# Load the `kusnim1121/one-piece-shanks-robin` dataset from the Hub.
# This rebinds the kernel-level name `dataset`.
dataset = load_dataset("kusnim1121/one-piece-shanks-robin")

  from .autonotebook import tqdm as notebook_tqdm
Downloading readme: 100%|██████████| 280/280 [00:00<00:00, 829kB/s]
Downloading data: 100%|██████████| 40.1M/40.1M [00:03<00:00, 11.9MB/s]
Generating train split: 100%|██████████| 593/593 [00:00<00:00, 3047.79 examples/s]


In [3]:
# Display the DatasetDict summary: splits, feature names and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['image'],
        num_rows: 593
    })
})

In [4]:
# Display the summary of the train split only.
dataset["train"]

Dataset({
    features: ['image'],
    num_rows: 593
})

In [5]:
# Persist the loaded dataset into the ./images directory.
# NOTE(review): ./images is also the scraper's download directory and the
# "imagefolder" source used earlier; confirm that sharing the path is intended.
# NOTE(review): the recorded output reports 365 examples while the dataset shown
# above has 593 rows, so this cell was presumably run against a different
# `dataset`; re-run the notebook top to bottom to confirm.
dataset.save_to_disk("./images")

Saving the dataset (1/1 shards): 100%|██████████| 365/365 [00:00<00:00, 2343.98 examples/s]


In [5]:
# Peek at the first row only, to check what a sample looks like.
for sample in dataset["train"]:
    print(sample)
    break  # stop after the first row

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=299x168 at 0x7F80A5586780>}


#### Filter dataset with CLIP

In [1]:
import torch
from transformers import CLIPProcessor, CLIPModel
import datasets
from datasets import load_dataset
from tqdm.auto import trange

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Work directly with the train split of `kusnim1121/one-piece`.
dataset = load_dataset("kusnim1121/one-piece")["train"]

# CLIP model and its matching processor, both from the same checkpoint.
# The model is not moved to `device` here; the scoring cell calls model.to(device).
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Display the summary of the train split that will be scored.
dataset

In [3]:
# IMPORTANT: cannot process all images at once due to limited resource.
# Score the images against one text prompt with CLIP, batch by batch, then keep
# the best-scoring 80% and set the rest aside for inspection.
batch_size = 128
num_batches = (len(dataset) + batch_size - 1) // batch_size  # ceiling division

# Prepare to collect scores
all_scores = []
prompt = "an anime character from the anime one piece"
model.to(device)

# Process each batch
for i in trange(num_batches):
    start_idx = i * batch_size
    end_idx = min((i + 1) * batch_size, len(dataset))

    # Slice the rows first and pick the column second, so only this batch is
    # loaded. dataset["image"][a:b] would materialise the whole image column
    # on every iteration.
    batch_images = dataset[start_idx:end_idx]["image"]

    inputs = processor(text=[prompt], images=batch_images, return_tensors="pt", padding=True).to(device)

    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        outputs = model(**inputs)
    logits_per_image = outputs.logits_per_image  # this is the image-text similarity score

    # reshape(-1) always yields a 1-D tensor. squeeze() would turn a one-image
    # batch into a 0-d tensor, which torch.cat cannot concatenate.
    scores = logits_per_image.reshape(-1).cpu()

    all_scores.append(scores)

# Concatenate all scores
all_scores = torch.cat(all_scores)

# Get indices of top 80% images (ordered from highest to lowest score)
indices = torch.topk(all_scores, int(len(dataset) * 0.8))[1]

# Plain Python ints: usable for dataset indexing and for fast set membership.
kept_indices = indices.tolist()
kept_lookup = set(kept_indices)

# Filter datasets based on indices
# filtered_dataset: list of kept images, best score first.
filtered_dataset = dataset[kept_indices]["image"]
# discarded_dataset: list of row dicts (with an "image" key) in dataset order.
discarded_dataset = [dataset[index] for index in range(len(dataset)) if index not in kept_lookup]

In [7]:
# Example of discarded image
import matplotlib.pyplot as plt
import numpy as np

# Show up to rows * columns discarded images in a grid.
fig = plt.figure(figsize=(8, 8))
columns = 4
rows = 4
# Subplot positions are 1-based while the samples are 0-based, so enumerate
# from 1 over a slice that starts at the first discarded image. Slicing also
# copes with fewer than rows * columns discarded images.
for position, sample in enumerate(discarded_dataset[: columns * rows], start=1):
    fig.add_subplot(rows, columns, position)
    plt.imshow(sample["image"])
    plt.axis("off")
plt.suptitle("Discarded images")
plt.show()

100%|██████████| 8/8 [00:50<00:00,  6.27s/it]


#### Generate captions with BLIP

In [20]:
from transformers import BlipProcessor, BlipForConditionalGeneration

# Load the BLIP processor and model
# NOTE: this rebinds `processor` and `model`, replacing the CLIP objects used
# for filtering. Re-running the CLIP scoring cell after this one would pick up
# the BLIP objects instead. `device` comes from the CLIP setup cell.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)

In [None]:
from tqdm.auto import trange

# IMPORTANT: similarly, not enough resource to compute all captions at once.
# Suffix appended to every generated caption.
postfix = ", in one piece style"

# Split the filtered images into batches
batch_size = 128
num_batches = (len(filtered_dataset) + batch_size - 1) // batch_size  # ceiling division

# Prepare to collect captions
all_captions = []

# Process each batch
for i in trange(num_batches):
    start_idx = i * batch_size
    # Clamp against the list that is actually being sliced (filtered_dataset),
    # not the larger unfiltered dataset.
    end_idx = min((i + 1) * batch_size, len(filtered_dataset))

    batch_images = filtered_dataset[start_idx:end_idx]
    inputs = processor(batch_images, return_tensors="pt").to(device)
    # Generation only: no gradients needed.
    with torch.no_grad():
        outputs = model.generate(**inputs)

    batch_captions = processor.batch_decode(outputs, skip_special_tokens=True)

    all_captions.extend(batch_captions)

# Add postfix to caption
captions = [caption + postfix for caption in all_captions]

In [15]:
# Pair each kept image with its caption; both lists are in the same order
# because the captions were generated by iterating over filtered_dataset.
train_dataset = datasets.Dataset.from_dict({"image": filtered_dataset, "caption": captions})
# Wrap it as a single "train" split and upload it to the Hub.
datasets.DatasetDict({"train": train_dataset}).push_to_hub("kusnim1121/filtered-one-piece-with-caption")