# Download images
---
### This script downloads cat and background images from https://battlecatsinfo.github.io, saving them to [cat](data/cat) and [background](data/background) respectively.

In [None]:
import os
import logging
from urllib.error import HTTPError
from concurrent.futures import ThreadPoolExecutor
import requests
from furl import furl
from itertools import count

# Output directories. Built with os.path.join so the notebook works on both
# Windows and POSIX: a literal 'data\\cat' is a nested path only on Windows and
# a single directory whose name contains a backslash everywhere else.
cats_path = os.path.join('data', 'cat')
bg_path = os.path.join('data', 'background')
# Site that hosts the images to download.
base_url = furl('https://battlecatsinfo.github.io')

# Root logger, set to INFO so the per-request progress messages are not filtered out.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [None]:
from itertools import takewhile


def find_max_index(url: str):
    """Return the largest index for which ``url`` answers a HEAD request with 200.

    Indices 1, 2, 4, 8, ... are probed until one fails, then the range between
    the last hit and the first miss is binary-searched. The search is only
    valid if the available indices are contiguous (no gaps below the maximum).

    Args:
        url: Format string containing an ``{id}`` placeholder.

    Returns:
        The highest available index; ``0`` if only index 0 is available and
        ``-1`` if not even index 0 is, so that ``range(start, result + 1)``
        is empty when there is nothing to download.
    """
    def exists(index: int):
        target = url.format(id=index)
        res = requests.head(target)
        logger.info(f'{target} returned {res.status_code}')
        return res.status_code == 200

    # Exponential probing: collect the powers of two that exist, stop at the first miss.
    hits = list(takewhile(exists, (2 ** i for i in count())))
    if not hits:
        # Index 1 is missing. Probing starts at 1, so index 0 is still untested.
        return 0 if exists(0) else -1

    # hits[-1] is known to exist and 2 * hits[-1] is known to be missing, so
    # the first missing index lies in [hits[-1] + 1, 2 * hits[-1]].
    lo = hits[-1] + 1
    hi = 2 * hits[-1]
    while lo < hi:
        mid = (lo + hi) // 2
        if exists(mid):
            lo = mid + 1
        else:
            hi = mid
    # lo is now the first missing index.
    return lo - 1


In [None]:
# Download cat images

# URL template for a single cat image. It is filled in later with str.format:
# `id` is the cat index (zero-padded to three digits) and `fcs` is one of the
# variant letters 'f', 'c' or 's'.
_cats_url = f'{base_url.url}/img/u/{{id:03}}/{{fcs}}/uni{{id:03}}_{{fcs}}00.png'


def download_cats(start: int = 0, force_download=False):
    """Download the cat images for every index from ``start`` to the site's maximum.

    For each cat the 'f', 'c' and 's' variants are requested and every variant
    that answers with HTTP 200 is saved as ``<cats_path>/<id>_<variant>.png``.
    A cat whose 'f' variant is unavailable is skipped entirely.

    Args:
        start: First cat index to download.
        force_download: When false, the function returns immediately if
            ``cats_path`` already contains at least one ``.png`` file.
    """
    # makedirs also creates a missing parent directory and tolerates an existing one.
    os.makedirs(cats_path, exist_ok=True)
    if not force_download and any(file.endswith('.png') for file in os.listdir(cats_path)):
        return

    def download(cat_id: int):
        """Fetch and store every available variant of one cat."""
        try:
            for fcs in ('f', 'c', 's'):
                src = _cats_url.format(id=cat_id, fcs=fcs)
                logging.debug(f'Downloading {src}')
                image = requests.get(src)
                if image.status_code != 200:
                    if fcs == 'f':
                        logging.warning(f'failed to download cat[{cat_id}]')
                        return
                    # requests does not raise on 4xx/5xx by itself; skip the
                    # variant so an error page is never saved as a .png file.
                    logging.debug(f'cat[{cat_id}] has no {fcs!r} image (HTTP {image.status_code})')
                    continue
                output_path = os.path.join(cats_path, f'{cat_id:03}_{fcs}.png')
                with open(output_path, "wb") as f:
                    f.write(image.content)
        except (requests.exceptions.RequestException, OSError) as err:
            # Runs inside a worker thread whose future is never inspected, so
            # the error has to be reported here or it would be lost.
            logging.warning(f'failed to download cat[{cat_id}]: {err}')
            return
        logging.info(f'downloaded cat[{cat_id}]')

    f_url = f'{base_url.url}/img/u/{{id:03}}/f/uni{{id:03}}_f00.png'
    max_index = find_max_index(f_url)
    logging.info(f'maximum index is {max_index}')

    with ThreadPoolExecutor(max_workers=10) as executor:
        for i in range(start, max_index + 1):
            executor.submit(download, i)


# Run the cat download with defaults; it returns without downloading when
# cats_path already contains .png files.
download_cats()

In [None]:
# download backgrounds

def download_backgrounds(start=0, force_download=False):
    """Download the background images for every index from ``start`` to the site's maximum.

    Each background that answers with HTTP 200 is saved as
    ``<bg_path>/bg<id>.png`` (id zero-padded to three digits).

    Args:
        start: First background index to download.
        force_download: When false, the function returns immediately if
            ``bg_path`` already contains at least one ``.png`` file.
    """
    # makedirs also creates a missing parent directory and tolerates an existing one.
    os.makedirs(bg_path, exist_ok=True)
    if not force_download and any(file.endswith('.png') for file in os.listdir(bg_path)):
        return

    bg_url = f'{base_url.url}/img/bg/bg{{id:03}}.png'
    max_index = find_max_index(bg_url)
    logging.info(f'maximum index is {max_index}')

    def download(bg_id: int):
        """Fetch and store one background image."""
        # Same template as the probe above, so both always hit the same URL.
        src = bg_url.format(id=bg_id)
        output_path = os.path.join(bg_path, f'bg{bg_id:03}.png')
        try:
            image = requests.get(src)
            if image.status_code != 200:
                logging.warning(f'failed to download bg[{bg_id}]')
                return
            with open(output_path, "wb") as f:
                f.write(image.content)
        except (requests.exceptions.RequestException, OSError) as err:
            # Runs inside a worker thread whose future is never inspected, so
            # the error has to be reported here or it would be lost.
            logging.warning(f'failed to download bg[{bg_id}]: {err}')
            return
        logging.info(f'downloaded bg[{bg_id}]')

    with ThreadPoolExecutor(max_workers=10) as executor:
        for i in range(start, max_index + 1):
            executor.submit(download, i)


# Run the background download with defaults; it returns without downloading
# when bg_path already contains .png files.
download_backgrounds()

In [None]:
import cv2
import matplotlib.pyplot as plt
import os
import numpy as np
import time

from random import sample,random, choice
from tqdm import tqdm

In [None]:
def _list_pngs(folder):
    """Return the ``.png`` directory entries of ``folder`` sorted by file name."""
    # os.scandir yields entries in arbitrary order; sorting keeps list positions
    # stable between runs. The with-block closes the directory iterator.
    with os.scandir(folder) as entries:
        return sorted((entry for entry in entries if entry.name.endswith('.png')),
                      key=lambda entry: entry.name)


# The position of an entry in `cats` is used as its class id when labels and
# the class-name list are written, so the order must be reproducible.
cats = _list_pngs(cats_path)
backgrounds = _list_pngs(bg_path)

def get_img_and_mask(img_path: str):
    """Load an image as RGB together with the shared binary mask.

    The mask is always read from ``data/cat_mask.png``, independent of
    ``img_path``.

    Args:
        img_path: Path of the image file to load.

    Returns:
        ``(img, mask)`` where ``img`` is the image converted from BGR to RGB
        and ``mask`` is a ``uint8`` array that is 1 where the first channel of
        the mask file is 0 and 0 elsewhere.

    Raises:
        FileNotFoundError: If the image or the mask file cannot be read.
    """
    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError(f'cannot read image: {img_path}')
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    mask = cv2.imread('data/cat_mask.png')
    if mask is None:
        raise FileNotFoundError('cannot read mask: data/cat_mask.png')
    mask = (mask[:, :, 0] == 0).astype(np.uint8)
    return img, mask


def get_background(bg_path: str):
    """Load a background image as RGB, cropped to its first 770 columns.

    Args:
        bg_path: Path of the background image file.

    Returns:
        The image converted from BGR to RGB, keeping at most 770 columns
        counted from the left edge.

    Raises:
        FileNotFoundError: If the file cannot be read.
    """
    bg = cv2.imread(bg_path, cv2.IMREAD_UNCHANGED)
    # cv2.imread signals failure by returning None instead of raising.
    if bg is None:
        raise FileNotFoundError(f'cannot read background: {bg_path}')
    bg = cv2.cvtColor(bg, cv2.COLOR_BGR2RGB)
    return bg[:, :770, :]


# Output width and height in pixels that resize_bg scales its crop to.
ow, oh = 720, 360


def resize_bg(bg: np.ndarray):
    """Cut a randomly sized, centred crop from ``bg`` and resize it to ``(ow, oh)``.

    The crop's half-width is 50%-100% of half the image width and its height
    follows the ``ow``:``oh`` aspect ratio.

    Args:
        bg: Background image with shape ``(height, width, channels)``.

    Returns:
        The crop resized to ``ow`` x ``oh`` pixels.
    """
    sc = 0.5 + random() / 2
    cy, cx = bg.shape[0] // 2, bg.shape[1] // 2
    half_w = int(cx * sc)
    half_h = int(half_w / ow * oh)
    if half_h > cy:
        # The image is too short for this crop. A negative slice start would
        # silently wrap around, so shrink the crop and keep the aspect ratio.
        half_h = cy
        half_w = int(half_h / oh * ow)
    crop = bg[cy - half_h:cy + half_h, cx - half_w:cx + half_w, :]
    return cv2.resize(crop, (ow, oh))

In [None]:
# Let's look at a sample object (index 20) and its binary mask

img_path = cats[20].path
# Use a dedicated name for the sample background file: rebinding `bg_path`
# here would overwrite the background *directory* path used by the download
# and listing cells and break them when they are re-run.
sample_bg_path = backgrounds[20].path
img, mask = get_img_and_mask(img_path)
bg = resize_bg(get_background(sample_bg_path))
print("Image file:", img_path)
print()
print("Shape of the image of the object:", img.shape)
print("Shape of the mask of the object:", mask.shape)
print("Shape of the background:", bg.shape)

# The background is drawn on the current figure, the object and its mask side
# by side on a second one.
plt.imshow(bg)
fig, ax = plt.subplots(1, 2, figsize=(16, 7))
ax[0].imshow(img)
ax[0].set_title('Object', fontsize=18)
ax[1].imshow(mask)
ax[1].set_title('Mask', fontsize=18)
plt.show()


In [None]:
def add_img_to_background(im: np.ndarray, mask: np.ndarray, bg: np.ndarray, x: int, y: int):
    """Paste the masked part of ``im`` onto a copy of ``bg`` at offset ``(x, y)``.

    Pixels where ``mask`` is non-zero are taken from ``im``; everywhere else
    the background is kept. The image may stick out of the background on any
    side (including negative offsets); only the overlapping part is pasted.

    Args:
        im: Image to paste, shape ``(h, w, channels)``.
        mask: Array of shape ``(h, w)``; non-zero marks pixels to copy.
        bg: Background image. It is not modified.
        x: Column of the background where the left edge of ``im`` goes.
        y: Row of the background where the top edge of ``im`` goes.

    Returns:
        A new array: ``bg`` with the masked pixels of ``im`` pasted in.
    """
    bg = bg.copy()
    bg_h, bg_w = bg.shape[:2]
    im_h, im_w = im.shape[:2]

    # Overlap rectangle in background coordinates.
    b_sx, b_sy = max(0, x), max(0, y)
    b_ex, b_ey = min(bg_w, x + im_w), min(bg_h, y + im_h)
    if b_ex <= b_sx or b_ey <= b_sy:
        # The image lies completely outside the background.
        return bg

    # The same rectangle in image coordinates (subtract the paste offset).
    i_sx, i_sy = b_sx - x, b_sy - y
    i_ex, i_ey = b_ex - x, b_ey - y

    paste = mask[i_sy:i_ey, i_sx:i_ex] != 0
    region = bg[b_sy:b_ey, b_sx:b_ex]  # view into the copy, edited in place
    region[paste] = im[i_sy:i_ey, i_sx:i_ex][paste]
    return bg


In [None]:
def gen_data():
    """Compose one synthetic image: ten distinct random cats on a random background.

    The cats are laid out in a grid of 5 columns and 2 rows. Columns are
    ``dx`` pixels apart and rows ``dy`` pixels apart, and the grid is centred
    horizontally on the background.

    Returns:
        ``(bg, cat_loc)`` where ``bg`` is the composed image and ``cat_loc``
        maps each cat's index in ``cats`` to the ``(x, y)`` pixel centre of
        its pasted image.
    """
    cat_id = sample(range(len(cats)), 10)
    cat_id = np.reshape(cat_id, (5, 2))
    background_file = choice(backgrounds).path
    bg = get_background(background_file)
    im_w = 128
    sep_x = 3
    x = (bg.shape[1] - im_w * 5 - sep_x * 4) // 2
    y = 200
    dx = im_w + sep_x
    dy = 100
    cat_loc = {}
    for grid_pos, cid in np.ndenumerate(cat_id):
        cat = cats[cid]
        # Top-left corner of this cell: (column, row) scaled by the grid pitch.
        loc = (x, y) + np.array(grid_pos) * (dx, dy)
        im, imm = get_img_and_mask(cat.path)
        bg = add_img_to_background(im, imm, bg, *loc)
        # Centre of the pasted image itself. The grid pitch (dx, dy) is not the
        # image size, so half the pitch does not land on the image centre.
        # NOTE(review): assumes the visible icon is centred in the image, as
        # the symmetric margins in create_yolo_annotations suggest - confirm.
        centre = loc + np.array([im.shape[1], im.shape[0]]) // 2
        # Plain int keys so the class id is written as a bare number.
        cat_loc[int(cid)] = centre

    return bg, cat_loc

In [None]:
# Generate one sample composition, print its cat locations and image shape,
# then display it.
mix, cat_loc = gen_data()
print(cat_loc)
print(mix.shape)
plt.imshow(mix)
plt.show()

In [None]:
def create_yolo_annotations(comp_img, cat_loc: dict[int, np.ndarray],
                            box_w: int = 128 - 9 * 2, box_h: int = 128 - 21 * 2):
    """Build one annotation row per cat, normalised by the image size.

    Args:
        comp_img: Composed image; only ``comp_img.shape[:2]`` (height, width)
            is used.
        cat_loc: Maps a class id to the ``(x, y)`` pixel centre of its box.
            The centres are not modified.
        box_w: Box width in pixels. The default is 128 minus 9 pixels on each side.
        box_h: Box height in pixels. The default is 128 minus 21 pixels on each side.

    Returns:
        A list of ``[class_id, x_centre, y_centre, width, height]`` rows with
        all four coordinates divided by the image width/height and rounded to
        5 decimals.
    """
    comp_h, comp_w = comp_img.shape[:2]
    # Every box has the same size, so normalise it once.
    rel_w = round(box_w / comp_w, 5)
    rel_h = round(box_h / comp_h, 5)

    annotations = []
    for cat, loc in cat_loc.items():
        # Dividing creates a new array, leaving the caller's centre untouched.
        rel_x, rel_y = np.asarray(loc, dtype=np.float64) / (comp_w, comp_h)
        # float() turns numpy scalars into plain Python floats.
        annotations.append([cat,
                            float(round(rel_x, 5)),
                            float(round(rel_y, 5)),
                            rel_w,
                            rel_h])
    return annotations

In [None]:
def generate_dataset(number, folder, split='train'):
    """Generate ``number`` synthetic images with label files.

    Image ``j`` is written to ``<folder>/<split>/images/<j>.jpg`` and its
    annotations to ``<folder>/<split>/labels/<j>.txt``, one space-separated
    annotation per line. A short timing summary is printed at the end.

    Args:
        number: How many images to generate.
        folder: Root folder of the dataset.
        split: Sub-folder of ``folder`` to write into.

    Raises:
        OSError: If an image cannot be written.
    """
    images_dir = os.path.join(folder, split, 'images')
    labels_dir = os.path.join(folder, split, 'labels')
    # Create the output folders here so the function does not depend on a
    # separate setup step having been run.
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(labels_dir, exist_ok=True)

    time_start = time.time()
    for j in tqdm(range(number)):
        mix, cat_loc = gen_data()
        # gen_data returns RGB; cv2.imwrite expects BGR channel order.
        mix = cv2.cvtColor(mix, cv2.COLOR_RGB2BGR)
        img_path = os.path.join(images_dir, f'{j}.jpg')
        label_path = os.path.join(labels_dir, f'{j}.txt')
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(img_path, mix):
            raise OSError(f'failed to write image: {img_path}')
        annotations = create_yolo_annotations(mix, cat_loc)

        with open(label_path, "w") as f:
            f.writelines(' '.join(map(str, annotation)) + '\n' for annotation in annotations)

    elapsed = time.time() - time_start
    time_total = round(elapsed)
    # Avoid a ZeroDivisionError when no images were requested.
    time_per_img = round(elapsed / number, 1) if number else 0.0

    print(
        f"Generation of {number} synthetic images is completed. It took {time_total} seconds, or {time_per_img} seconds per image")
    print(f"Images are stored in '{images_dir}'")
    print(f"Annotations are stored in '{labels_dir}'")

In [None]:
# Create datasets/{train,valid}/{images,labels} and fill both splits.
data_set_folder = 'datasets'
train_folder = os.path.join(data_set_folder, 'train')
valid_folder = os.path.join(data_set_folder, 'valid')
for folder in [train_folder, valid_folder]:
    for child in ['images', 'labels']:
        # makedirs creates all missing parents and, with exist_ok, is a no-op
        # for folders that already exist - no separate existence checks needed.
        os.makedirs(os.path.join(folder, child), exist_ok=True)

generate_dataset(1000, folder=data_set_folder, split='train')
generate_dataset(200, folder=data_set_folder, split='valid')

Create YAML File

In [None]:
# Header of the dataset description file; the class names are appended below.
text = """
path: ''
train: 'train/images'
val: 'valid/images'

# class names
names: 
"""
# One "index: 'name'" entry per cat, in the same order as `cats`. YAML needs a
# space after the colon for the line to be a mapping entry; without it the
# whole line is read as a plain string.
names = [f"    {i}: '{os.path.splitext(entry.name)[0]}'\n" for i, entry in enumerate(cats)]

with open('cats_data.yaml', 'w') as f:
    f.write(text)
    f.writelines(names)

Training

In [None]:
from ultralytics import YOLO

# Load the model from the 'yolov8n.pt' weights file.
model = YOLO('yolov8n.pt')

# Train on the dataset described by cats_data.yaml.
# NOTE(review): imgsz=720 is not a multiple of 32; Ultralytics may adjust it
# to a stride multiple - confirm the effective size in the training log.
results = model.train(
    data='cats_data.yaml',
    imgsz=720,
    epochs=10,
    batch=8,
    name='yolov8n_custom')