In [15]:
from azure.cognitiveservices.vision.customvision.training import CustomVisionTrainingClient
from azure.cognitiveservices.vision.customvision.prediction import CustomVisionPredictionClient
from azure.cognitiveservices.vision.customvision.training.models import ImageFileCreateBatch, ImageFileCreateEntry, Region
from msrest.authentication import ApiKeyCredentials
import os, time, uuid

In [16]:
import yaml

# Load notebook settings (Custom Vision credentials, project id, upload batch size).
# The file is plain data, so safe_load is sufficient and, unlike FullLoader,
# never constructs Python objects from tags in the file.
with open("../config.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

In [17]:
# Connection settings, all taken from the `custom_vision` section of config.yaml.
# The configured endpoint is a bare host name; scheme and trailing slash are added here.
ENDPOINT = "https://{}/".format(config["custom_vision"]["endpoint"])
training_key, prediction_key, prediction_resource_id = (
    config["custom_vision"][setting]
    for setting in ("training_key", "prediction_key", "prediction_resource_id")
)


In [18]:
# The training client authenticates by sending the training key as a request header.
credentials = ApiKeyCredentials(
    in_headers={"Training-key": training_key}
)
trainer = CustomVisionTrainingClient(ENDPOINT, credentials)

# Name for the published iteration (presumably used by a later publish step).
publish_iteration_name = "FireNet"

In [19]:
# Reuse the project named by `project_id` in config.yaml; when it is missing or
# empty, create a new project in the "General" ObjectDetection domain instead.
existing_project_id = config.get("project_id")
if existing_project_id:
    project = trainer.get_project(existing_project_id)
else:
    print("Creating project...")
    obj_detection_domain = next(
        (
            domain
            for domain in trainer.get_domains()
            if domain.type == "ObjectDetection" and domain.name == "General"
        ),
        None,
    )
    if obj_detection_domain is None:
        raise RuntimeError(
            "No 'General' ObjectDetection domain is available on this Custom Vision resource."
        )

    # Use uuid to avoid project name collisions.
    project = trainer.create_project(
        str(uuid.uuid4()),
        domain_id=obj_detection_domain.id,
        classification_type="Multiclass",
    )
    # Without persisting this id, every Restart & Run All creates (and uploads to)
    # yet another project.
    print(
        f"Created project {project.id}; set project_id to this value in "
        "config.yaml to reuse it on re-runs."
    )


## Upload and tag images

In [20]:
import pandas as pd
from pathlib import Path

local_path = Path("../data")
# Every sub-directory of the data folder is one dataset. iterdir() order is
# filesystem-dependent, so sort to upload (and create tags) in a reproducible order.
datasets = sorted(entry for entry in local_path.iterdir() if entry.is_dir())

### Add tags to the project and upload the tagged images

In [21]:
from tqdm import tqdm

# CreateImagesFromFiles accepts at most 64 images per request.
MAX_IMAGES_PER_BATCH = 64


def _upload_batch(entries, label):
    """Upload one batch of tagged images to the Custom Vision project.

    Args:
        entries: list of ImageFileCreateEntry objects sent in a single request.
        label: prefix for the failure message, e.g. "Image batch 64".

    Returns:
        True if the batch succeeded or every problem image is only a duplicate
        ("OKDuplicate"); False if any image has another non-OK status. Each such
        image is printed, followed by "<label> upload failed.".
    """
    upload_result = trainer.create_images_from_files(
        project.id, ImageFileCreateBatch(images=entries)
    )
    if upload_result.is_batch_successful:
        return True
    # Duplicates are tolerated; only other non-OK statuses count as failures.
    failed = [
        image
        for image in upload_result.images
        if image.status not in ("OK", "OKDuplicate")
    ]
    for image in failed:
        print(f"Image {image.source_url}: {image.status}")
    if failed:
        print(f"{label} upload failed.")
    return not failed


batch_size = min(config["upload_batch_size"], MAX_IMAGES_PER_BATCH)
if batch_size < 1:
    raise ValueError(
        f"upload_batch_size must be >= 1, got {config['upload_batch_size']!r}"
    )

# Start from the tags already on the project; missing ones are created per dataset.
tags = {az_tag.name: az_tag for az_tag in trainer.get_tags(project.id)}

for data_path in datasets:
    images_path = data_path / "Images"
    # NOTE: read_pickle can execute arbitrary code -- only load trusted, locally produced files.
    df = pd.read_pickle(data_path / "object_detection.pkl")

    for tag_name in df.tag.unique():
        if tag_name not in tags:
            tags[tag_name] = trainer.create_tag(project.id, tag_name)

    tagged_images_with_regions = []

    print("Loading images and labels...")

    # Group annotation rows by image once instead of re-filtering the whole frame
    # per file; sort=False keeps first-appearance order, like df.filename.unique().
    images_by_name = df.groupby("filename", sort=False)
    i = 0
    for i, (file_name, rows) in enumerate(
        tqdm(images_by_name, total=images_by_name.ngroups), start=1
    ):
        regions = [
            Region(
                tag_id=tags[row["tag"]].id,
                left=row["left"],
                top=row["top"],
                width=row["width"],
                height=row["height"],
            )
            for _, row in rows.iterrows()
        ]
        tagged_images_with_regions.append(
            ImageFileCreateEntry(
                name=file_name,
                contents=(images_path / file_name).read_bytes(),
                regions=regions,
            )
        )

        if i % batch_size == 0:
            print(f"Uploading image batch {i} ...")
            batch, tagged_images_with_regions = tagged_images_with_regions, []
            if not _upload_batch(batch, f"Image batch {i}"):
                # Stop this dataset; the buffer is already empty, so nothing
                # from it is uploaded by the final-batch step below.
                break

    if tagged_images_with_regions:
        print(f"Uploading latest batch {i}...")
        batch, tagged_images_with_regions = tagged_images_with_regions, []
        _upload_batch(batch, "Latest Image batch")

Loading images and labels...


100%|██████████| 22/22 [00:00<00:00, 334.43it/s]

Uploading latest batch 22...





### Adding images