In [309]:
import os, json, cv2, pickle, xmltodict, shutil, yaml, random
from pathlib import Path
import numpy as np
from chitra.image import Chitra
from matplotlib import pyplot as plt
from PIL import Image
from tqdm import tqdm
from copy import deepcopy
import pandas as pd

ROOT = "root_3"

# UAV wildfire dataset

## Для составления датасета были использованы следующие источники:

1. [Aerial Imagery dataset for fire detection: classification and segmentation using Unmanned Aerial Vehicle (UAV)](https://github.com/AlirezaShamsoshoara/Fire-Detection-UAV-Aerial-Image-Classification-Segmentation-UnmannedAerialVehicle)
2. [UAV Fire Detection](https://github.com/andre3racks/UAV-fire-detection)
3. [Forest Fire Detection through UAV imagery using CNNs](https://github.com/LeadingIndiaAI/Forest-Fire-Detection-through-UAV-imagery-using-CNNs)
4. [UAV Thermal Imaginary - Fire Dataset](https://www.kaggle.com/datasets/adiyeceran/uav-thermal-imaginary-fire-dataset)

Различные датасеты были предназначены для различных задач: классификации, локализации, сегментации.<br>
В датасете содержатся снимки полученные как с оптических, так и с тепловизионных камер.<br>
В этом ноутбуке показано каким образом собирался общий датасет.<br>
<p>Cтруктура датасета:</p>

`./uav_wildfire_dataset/`+<br>
&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;|---`images/`+<br>
&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;|&nbsp;&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;|---`<hex_id>.jpg`<br>
&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;|&nbsp;&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;|---`...`<br>
&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;|<br>
&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;|<br>
&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;|---`masks/`+<br>
&ensp;&ensp;&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;|&ensp;&ensp;&nbsp;&emsp;&emsp;&emsp;&emsp;|---`<hex_id>.jpg`<br>
&ensp;&ensp;&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;|&ensp;&ensp;&nbsp;&emsp;&emsp;&emsp;&emsp;|---`...`<br>
&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;|<br>
&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;|<br>
&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;|---`bboxes/`+<br>
&ensp;&ensp;&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;|&nbsp;&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;|---`<hex_id>.csv`<br>
&ensp;&ensp;&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;|&nbsp;&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;|---`...`<br>
&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;|<br>
&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;|<br>
&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;|---`labels.csv`<br>
&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;|<br>
&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;|<br>
&ensp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;&emsp;+---`idx.json`<br>

В папке `images` содержатся все изображения размера 640 x 640 со своими hex_id.<br>
В папке `masks` содержатся маски, соответствующие hex_id изображения.<br>
В папке `bboxes` содержатся координаты всех прямоугольников. ID описания соответствует hex_id изображения.<br>
Файл `labels.csv` содержит две колонки: hex_id и label (1 - есть огонь, 0 - нет огня)<br>
Файл `idx.json` содержит данные, необходимые для построения датасета под нужную задачу.

In [310]:
class ImageDataset:
    def __init__(self, root_posix_path: str, ext: tuple=(640, 640), restart=True, idx_b=None):
        self.root_path = Path(root_posix_path)
        self.root_path.mkdir(exist_ok=True, parents=True)
        for sub_path in ["images", "masks", "bboxes"]:
            path_ = self.root_path / sub_path
            path_.mkdir(exist_ok=True, parents=True)
        self.idx_keys = "dataset_name", "label", "bboxes", "mask", "thermal"
        self.dataset_name, self.label, self.bboxes, self.mask, self.thermal = [None] * 5
        self.ext = 640, 640
        self.idx = {}
        if idx_b:
            self.idx = idx_b
            self.current_id = int(list(idx_b.keys())[-1], 16) + 1
            self.hex_id = self.generate_next_id()
        if restart:
            self.current_id = 0
            self.hex_id = "0"
            for file in ["labels.csv", "idx.json"]:
                file_path = self.root_path / Path(file)
                open(file_path, mode="w").close()
        
    def generate_next_id(self):
        h = hex(self.current_id)[2:]
        while len(h) < 8:
            h = "0" + h
        self.current_id += 1
        return h
    
    def import_images(self,
                      export_dir: str,
                      dataset_name: str,
                      thermal: bool,
                      get_label_function=None,
                      get_bboxes_function=None,
                      get_mask_function=None):
        
        self.thermal = thermal
        self.mask = bool(get_mask_function or None)
        self.bboxes = bool(get_bboxes_function or None)
        self.label = bool(get_label_function or None)
        self.dataset_name = dataset_name
        
        label = None
        for image_name in tqdm(os.listdir(export_dir)):
            resized = False
            image_path = Path(export_dir) / Path(image_name)
            self.hex_id, self.exp_path = self.generate_next_id(), image_path.as_posix()
            
            if get_label_function:
                label = get_label_function(image_path)
                with open(self.root_path / Path("labels.csv"), "a") as labels:
                    labels.write(f"{self.hex_id};{str(label)};\n")
            if get_bboxes_function:
                get_bboxes_ = get_bboxes_function(self.exp_path)
                if get_bboxes_:
                    labels_ = [[label or "1"]] * len(get_bboxes_)
                    image, bboxes = Chitra(self.exp_path,
                                           get_bboxes_,
                                           labels_).resize_image_with_bbox(self.ext)
                    resized = True
                    with open(self.root_path / Path("bboxes") / Path(".".join([self.hex_id, "csv"])), "a") as bb_csv:
                        for bbox in bboxes:
                            bb_csv.write("\t".join([label[0] or str(0) if type(label) is list else label or str(0)] +  [str(int(i_)) for i_ in bbox.coords.flatten()]))
                            bb_csv.write("\n")

            if get_mask_function:
                mask = get_mask_function(self.exp_path)
                mask_arr = np.asarray(Image.open(mask))
                mask_arr[mask_arr > 0] = 255
                mask = Image.fromarray(mask_arr).resize(self.ext)
                mask.save(self.root_path / Path("masks") / Path(".".join([self.hex_id, "jpg"])))
            
            if not resized:
                image = Chitra(self.exp_path).resize(self.ext)
            image = image.convert('RGB')
            image.save(self.root_path / Path("images") / Path(".".join([self.hex_id, "jpg"])))

            self.idx[self.hex_id] = {k:v for k,v in zip(self.idx_keys, (self.dataset_name,
                                                                        self.label,
                                                                        self.bboxes,
                                                                        self.mask,
                                                                        self.thermal))}
            
# with open(f'./{ROOT}/imgdts.idx.part_4.pkl', 'rb') as f:
#     idx_backup = pickle.load(f)
            
# imgdts = ImageDataset(ROOT, restart=False, idx_b=idx_backup)
imgdts = ImageDataset(ROOT)

## 1. Aerial Imagery dataset for fire detection: classification and segmentation using Unmanned Aerial Vehicle (UAV)

В данном датасете содержатся фото и видеозаписи.<br>
Фото получены разложением видеозаписей на кадры.<br>
Не все видеозаписи переведены в кадры.<br>
Задачи: сегментация и классификация. <br>
Данные получены с камер: оптических и тепловизионных.<br>
Все изображения будут сводиться к размеру 640 x 640

In [311]:
print(*os.listdir("alireza_shamsoshoara/video/"), sep="\n") # Всего в датасете 6 видео

1-Zenmuse_X4S_1.mp4
2-Zenmuse_X4S_2.mp4
3-WhiteHot.mov
4-GreenHot.mov
5-Thermal_Fusion.mov
6-phantom.mov


In [312]:
DS_NAME = "alireza_shamsoshoara"

1. Сегментационные данные взяты из `6-phantom.mov`

In [None]:
EXP_DIR = 'alireza_shamsoshoara/frames/Segmentation/Data/Images'
THERM = False
get_label_function = lambda x: "1"
get_mask_function = lambda img_path: Path(img_path).parent.parent / Path("Masks") / ".".join(Path(img_path).name.split(".")[:-1] + ["png"])

imgdts.import_images(export_dir=EXP_DIR,
                     dataset_name=DS_NAME,
                     thermal=THERM,
                     get_label_function=get_label_function,
                     get_mask_function=get_mask_function)

2. Классификационные данные: 
* Тренировочные данные, лейбл "Fire": взяты из видео `1-Zenmuse_X4S_1.mp4`
* Тренировочные данные, лейбл "No_Fire": взяты из `неизвестного источника`
* Тестовые данные взяты из видео `6-phantom.mov`

In [None]:
EXP_DIR = 'alireza_shamsoshoara/frames/Training/Fire/'
THERM = False
get_label_function = lambda x: "1"

imgdts.import_images(export_dir=EXP_DIR,
                     dataset_name=DS_NAME,
                     thermal=THERM,
                     get_label_function=get_label_function)

In [None]:
EXP_DIR = 'alireza_shamsoshoara/frames/Training/No_Fire/'
THERM = False
get_label_function = lambda x: "0"

imgdts.import_images(export_dir=EXP_DIR,
                     dataset_name=DS_NAME,
                     thermal=THERM,
                     get_label_function=get_label_function)

EXP_DIR = 'alireza_shamsoshoara/frames/Test/Fire'
THERM = False
get_label_function = lambda x: "1"

imgdts.import_images(export_dir=EXP_DIR,
                     dataset_name=DS_NAME,
                     thermal=THERM,
                     get_label_function=get_label_function)

EXP_DIR = 'alireza_shamsoshoara/frames/Test/No_Fire'
THERM = False
get_label_function = lambda x: "0"

imgdts.import_images(export_dir=EXP_DIR,
                     dataset_name=DS_NAME,
                     thermal=THERM,
                     get_label_function=get_label_function)

3. В будущем, можно будет также получить данные из оставшихся видео:
* 2-Zenmuse_X4S_2.mp4
* 3-WhiteHot.mov
* 4-GreenHot.mov
* 5-Thermal_Fusion.mov

## 2. UAV Fire Detection

In [316]:
# Preparing dataset for import
DS_NAME = "andre3racks"
ann_list = os.listdir(DS_NAME + "/annotations/")


def f_bndbox(obj):
        ann_ob_l = obj if type(obj) == list else [obj]
        return [{x[0]:x[1] if x[0] != 'bndbox' else {x_[0]:int(x_[1]) for x_ in x[1].items()} for x in o.items()} for o in ann_ob_l]


total_ext = {}
for ann_list_index, ann_list_element in enumerate(ann_list):
    ann_path = DS_NAME + "/annotations/" + ann_list_element
    with open(ann_path, "r") as xml_file:
        xml_text = xml_file.read()
    ann = xmltodict.parse(xml_text)["annotation"]


    ectracted = {}
    for key in ann.keys():
        if key == 'object':
            ectracted[key] = f_bndbox(ann[key])
        elif key == 'size':
            ectracted[key] = tuple(int(ann["size"][a]) for a in ann["size"])
        else:
            ectracted[key] = ann[key]
    image_path = "/".join([ectracted['folder'], ectracted['filename']])
    ectracted["image_path"] = image_path
    total_ext[ectracted['filename']] = ectracted

In [317]:
# Creating temp dir
Path.mkdir(Path('tmp/'), exist_ok=True)

In [318]:
for image_path, data in total_ext.items():
    shutil.copy('./andre3racks/' + data["image_path"], 'tmp/')

In [None]:
# Exporting dataset
def prepare_bboxes(object_list):
    bboxes = []
    if type(object_list) is list:
        for bbox in object_list:
            xmin, ymin, xmax, ymax = bbox['bndbox']['xmin'],\
                                     bbox['bndbox']['ymin'],\
                                     bbox['bndbox']['xmax'],\
                                     bbox['bndbox']['ymax']
                        
            bboxes.append([xmin, ymin, xmax, ymax])
    else:
        xmin, ymin, xmax, ymax = object_list['bndbox']['xmin'],\
                                     object_list['bndbox']['ymin'],\
                                     object_list['bndbox']['xmax'],\
                                     object_list['bndbox']['ymax']
                        
        bboxes.append([xmin, ymin, xmax, ymax])
#  
    if len(bboxes) == 4:
        bboxes.append([0]*4)
    return bboxes


EXP_DIR = 'tmp'
THERM = False
get_label_function = lambda ext_path_: ["1"] if "object" in total_ext[Path(ext_path_).name].keys() else ["0"]
get_bboxes_function = lambda ext_path_: prepare_bboxes(total_ext[Path(ext_path_).name]["object"]) if "object" in total_ext[Path(ext_path_).name].keys() else None

imgdts.import_images(export_dir=EXP_DIR,
                     dataset_name=DS_NAME,
                     thermal=THERM,
                     get_label_function=get_label_function,
                     get_bboxes_function=get_bboxes_function)

In [320]:
# Delete tmp dir
[os.remove('tmp/' + f) for f in os.listdir('tmp')]
os.rmdir('tmp')

## 3. Forest Fire Detection through UAV imagery using CNNs

Простой датасет для классификации. Тепловизионных снимков нет, датасет разделён на папки.

In [321]:
DS_NAME = "leading_india_ai"
THERM = False

In [None]:
EXP_DIR = "leading_india_ai/data/train/Fire/"
get_label_function = lambda x: "1"
imgdts.import_images(export_dir=EXP_DIR,
                     dataset_name=DS_NAME,
                     thermal=THERM,
                     get_label_function=get_label_function)

EXP_DIR = "leading_india_ai/data/train/No Fire/"
get_label_function = lambda x: "0"
imgdts.import_images(export_dir=EXP_DIR,
                     dataset_name=DS_NAME,
                     thermal=THERM,
                     get_label_function=get_label_function)

EXP_DIR = "leading_india_ai/data/validation/Fire/"
get_label_function = lambda x: "1"
imgdts.import_images(export_dir=EXP_DIR,
                     dataset_name=DS_NAME,
                     thermal=THERM,
                     get_label_function=get_label_function)

EXP_DIR = "leading_india_ai/data/validation/No Fire/"
get_label_function = lambda x: "0"
imgdts.import_images(export_dir=EXP_DIR,
                     dataset_name=DS_NAME,
                     thermal=THERM,
                     get_label_function=get_label_function)

## 4. UAV Thermal Imaginary - Fire Dataset

Тепловизионный датасет

In [323]:
DS_NAME = "uav_thermal_imaginary_fire_dataset"
THERM = True

split_list = os.listdir(DS_NAME)
class_list = np.unique([os.listdir("/".join([DS_NAME, split])) for split in split_list]).tolist()

# Creating temp dir
Path.mkdir(Path('tmp'), exist_ok=True)

# Function for binarise thermal image
def from_color_to_binary(im_p_):
    im_ = np.asarray(Image.open(im_p_))
    gray_ = cv2.cvtColor(im_, cv2.COLOR_RGB2GRAY)
    (_, bw_) = cv2.threshold(gray_, 127, 255, cv2.THRESH_BINARY)
    Image.fromarray(bw_)
    return bw_

In [None]:
for split_ in split_list:
    for class_ in class_list:
        folder_ = '/'.join([DS_NAME, split_, class_])
        Path.mkdir(Path('/'.join(["tmp", split_, class_])), exist_ok=True, parents=True)
        for file_ in os.listdir(folder_):
            bw_ = from_color_to_binary('/'.join([folder_, file_]))
            Image.fromarray(bw_).save('/'.join(["tmp", split_, class_, file_]))
        EXP_DIR = "/".join(["tmp", split_, class_])
        get_label_function = lambda x: "0" if class_ == 'no_fire' else "1" 
        imgdts.import_images(export_dir=EXP_DIR,
                     dataset_name=DS_NAME,
                     thermal=THERM,
                     get_label_function=get_label_function)

In [325]:
# Delete tmp dir
shutil.rmtree('tmp')

## 5. Создание индекс-файла

Сначала нужно собрать воедино всю информацию о файлах

In [41]:
index = imgdts.idx

df = pd.DataFrame(pd.read_csv(F"./{ROOT}/labels.csv", sep=";").values[:,:2], columns=["hex", "label"])
df["label"] = df["label"].apply(lambda x: int(x) if len(x) == 1 else int(x[2]))
df.drop_duplicates(inplace=True)

t_ls = []
for key, values in sorted(index.items()):
    a = {"hex":key}
    if len(values) > 0:
        t_ls.append(a | values)
    else:
        print("err")


tot_df = pd.DataFrame(columns=df.columns)
for t in tqdm(t_ls):
    tot_df = pd.concat([tot_df, pd.DataFrame.from_records([t])])

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 57105/57105 [01:43<00:00, 551.82it/s]


In [42]:
for hex_id, val_ in tqdm(df.set_index("hex").iterrows()):
    val_ = val_.values[0]
    tot_df.loc[tot_df["hex"] == hex_id, "cls"] = val_

57104it [03:06, 306.95it/s]


In [48]:
# Dataframe check
if tot_df["bboxes"].sum() != len(os.listdir(f"{ROOT}/bboxes/")):
    print("проверить данные для локализации")
if tot_df.shape[0] != len(os.listdir(f"{ROOT}/images/")):
    print("проверить изображения")
if tot_df["mask"].sum() != len(os.listdir(f"{ROOT}/masks/")):
    print("проверить данные для сегментации")
if tot_df["label"].sum() != pd.read_csv(f"{ROOT}/labels.csv", delimiter=";", header=None).shape[0]:
    print("проверить данные для сегментации")

проверить данные для локализации


In [49]:
def masks_count():
    src_path = "./andre3racks/annotations/"
    filenames = os.listdir(src_path)
    anns = ["".join([a,b]) for a,b in zip([src_path] * len(filenames), [*filenames])]

    def read_path(path):
        with open(path, "r") as f:
            text = f.read()
        return text

    texts = [read_path(an) for an in anns]

    return sum(1 if "bndbox" in t else 0 for t in texts)

In [50]:
print("In dataframe {}; in folder {}; real {}".format(tot_df["bboxes"].sum(),
                                                           len(os.listdir(f"{ROOT}/bboxes/")),
                                                           masks_count()))

In dataframe 1034; in folder 725; real 725


In [51]:
tot_df.set_index("hex", inplace=True)
real_bboxes = [i.split(".")[0] for i in os.listdir(f"{ROOT}/bboxes/")]
bad_inxs = tot_df[~tot_df.index.isin(real_bboxes) & tot_df.bboxes].index
tot_df.loc[bad_inxs, "bboxes"] = tot_df.loc[bad_inxs, "bboxes"].replace(True, False)

assert tot_df["bboxes"].sum() == len(os.listdir(f"{ROOT}/bboxes/")), ("проверить данные для локализации")

In [61]:
tot_df = tot_df[["bboxes", "mask", "cls"]]
tot_df.loc[tot_df.index, "thermal"] = False
thermal_index = [i for i, v in index.items() if v["thermal"]]
tot_df.loc[thermal_index, "thermal"] = True
tot_df.loc[tot_df.index, "thermal"] = False
thermal_index = [i for i, v in index.items() if v["thermal"]]
tot_df.loc[thermal_index, "thermal"] = True
d = tot_df
d.to_csv("index.csv", index="hex")

# Интерфейс создания датасета:

In [162]:
def copy_resourses(id_, task_, split_, labels_index):
    
    src_img = f'{ROOT}/images/{id_}.jpg'

    if task_ == "cls":

        label = "fire" if int(labels_index.iloc[int(id_, 16),0]) == 1 else "no_fire"
        dst_img = f'{PARAMS["output_dir"]}/{split_}/images/{label}/{id_}.jpg'
        Path.mkdir(Path(dst_img).parent, parents=True, exist_ok=True)
        

    if task_ == "loc":
        dst_img = f'{PARAMS["output_dir"]}/{split_}/images/{id_}.jpg'
        
        src_bb = f'{ROOT}/bboxes/{id_}.jpg'
        dst_bb = f"{PARAMS['output_dir']}/{split_}/images/{id_}.txt"
        Path.mkdir(Path(dst_img).parent, parents=True, exist_ok=True)
        shutil.copyfile(src_bb, dst_bb)

    if task_ == "seg":
        dst_img = f'{PARAMS["output_dir"]}/{split_}/images/{id_}.jpg'
        
        src_mask = f'{ROOT}/masks/{id_}.jpg'
        dst_mask = f"{PARAMS['output_dir']}/{split_}/masks/{id_}.jpg"
        Path.mkdir(Path(dst_img).parent, parents=True, exist_ok=True)
        shutil.copyfile(src_mask, dst_mask)
        
    Path.mkdir(Path(dst_img).parent, parents=True, exist_ok=True)
    shutil.copyfile(src_img, dst_img)

def split_data(index_list, split_ration):
    
    assert 0.99 < sum((0.7, 0.2, 0.1)) < 1.01, (f"Неверные доли разбивки {split_ration}")
    
    n = len(index_list)
    tr_n = int(split_ration[0] * n)
    te_n = int(split_ration[1] * n)
    va_n = n - te_n - tr_n

    tr_l, te_l, va_l = [], [], []

    for _ in range(tr_n):
        ch_ = random.choice(index_list)
        tr_l.append(ch_)
        index_list.remove(ch_)
    
    for _ in range(te_n):
        ch_ = random.choice(index_list)
        te_l.append(ch_)
        index_list.remove(ch_)

    va_l = index_list
    
    return {key:val for key, val in zip(("train", "test", "val"), (tr_l, te_l, va_l))} 

# task_filter
def task_filter(task, labels_index):
    assert task in ["cls", "loc", "seg"], ("Wrong task param!")
    if task == "cls":
        return True
    if task == "loc":
        return d.bboxes
    return d["mask"]


cam_filter = lambda c_: d.thermal if c_ == "thermal" else ~d.thermal

def gen_dataset(task:str,
                cam:str = "optic",
                split_ration:tuple = (0.8, 0.15, 0.5),
                output_dir:str = "./default/",
               ):
    """
    Generates dataset with suitable parameters
    
    :param task: "cls" for classification,
                 "loc" for localisation,
                 "seg" for segmentation;
    :param cam:  "therm" for thermal,
                 "optic" for optic cam;
    :param split_ration: tuple of three floats: train_ratio, test_ratio, val_ratio
                         sum(split_ration) must be 1
    :param output_dir:   relative path for dataset creation
    
    """
    
    labels_index = pd.read_csv(f"{ROOT}/labels.csv", sep=";", header=None, index_col=0)
    
    # Labels fix
    t = labels_index.iloc[:,:2]
    t_v = labels_index[1].values
    t_v = [int(i) if len(i) == 1 else int(i[2]) for i in t_v]
    labels_index[1] = t_v
    
    # Filtering images
    df_filter = task_filter(task, labels_index) & cam_filter(cam)
    index_list = d[df_filter].index
    
    if len(index_list):
        print(f"Found {len(index_list)} images for current parameters.\nCreating dataset")
    else:
        print("No images found for current parameters")
        
    
    
    # Copying data
    for split, i_val in split_data(list(index_list), split_ration).items():
        for hex_id in i_val:
            copy_resourses(hex_id, PARAMS["task"], split, labels_index)
            
    print("Dataset created!")

In [307]:
d = pd.read_csv("index.csv", index_col="hex")
gen_dataset("cls", "optic", (0.7, 0.2, 0.1), "cls_test")

Found 53125 images for current parameters.
Creating dataset
Dataset created!


# Скрипт для создания датасета

Окончательные правки были добавлены в скрипт

Использование (запускать из папки проекта):

`python create_dataset.py "task" "cam" "split_ration" "output_dir"`

* task - тип задачи для которой нужно сделать датасет ("cls" for classification, "loc" for localisation, "seg" for segmentation)
* cam - тип камеры с котрой делали сники ("therm" for thermal, "optic" for optic cam)
* split_ration - доли тренировочной, тестовой и валидационной частей через пробел
`"0.75 0.1 0.05"`
* output_dir - относительный путь папки, в которой требуется создать датасет

Например:

`python create_dataset.py "loc" "optic" "0.7 0.2 0.1" "final_test"`