In [1]:
from pycocotools.coco import COCO
from PIL import Image
import numpy as np
import os
import pandas as pd
import torchvision.transforms as tvt
from torchvision.io import read_image
import argparse
from tqdm import tqdm
import itertools
from pprint import pprint

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Global Variables
# All inputs (COCO 2014 images + annotations) and the resized outputs live under this root.
base_directory = r"/scratch/gilbreth/dfarache/ece60146/Nikita/hw06"
train_directory = os.path.join(base_directory, "train2014")
val_directory = os.path.join(base_directory, "val2014")
train_annotations = os.path.join(base_directory, "annotations", "instances_train2014.json")
val_annotations = os.path.join(base_directory, "annotations", "instances_val2014.json")


# Target COCO categories: name -> COCO category id. This dict is the single source
# of truth; the id list and the inverse map are derived from it so they cannot drift.
categories = {"bus": 6, "cat": 17, "pizza": 59}
catIds = list(categories.values())
inverse_categories = {cat_id: name for name, cat_id in categories.items()}


# CSV layout: id, path, then one bbox column per category (same order as `categories`).
columns = ["id", "path_to_image", *categories]
new_image_size = 256  # images are resized to new_image_size x new_image_size
resize_image = tvt.Compose([tvt.Resize((new_image_size, new_image_size))])


# Split to process: True -> train2014, False -> val2014.
train = False
coco = COCO(train_annotations if train else val_annotations)
# Output folder for the resized images of the selected split.
path_to_dir = os.path.join(base_directory, "Train" if train else "Val")

loading annotations into memory...
Done (t=5.22s)
creating index...
index created!


In [3]:
def get_all_image_ids(coco_api=None, category_names=None):
    """Return the ids of every image containing at least one target category.

    ``COCO.getImgIds`` called with several category ids returns the
    *intersection* of their image sets. Any multi-category query is therefore a
    subset of the single-category queries. Querying each category once yields the
    same ids, in the same order, as iterating over every category combination.

    Args:
        coco_api: COCO API object to query; defaults to the global ``coco``.
        category_names: iterable of category names; defaults to the global
            ``categories`` (iterating a dict yields its keys).

    Returns:
        List of unique image ids in first-seen order: category order, then the
        order returned by ``getImgIds``. The total count is also printed.
    """
    if coco_api is None:
        coco_api = coco
    if category_names is None:
        category_names = categories

    all_image_ids = []
    seen = set()  # O(1) membership test instead of scanning the growing list
    for name in category_names:
        category_ids = coco_api.getCatIds(catNms=[name])
        if not category_ids:
            # Unknown name: getImgIds(catIds=[]) would return *every* image in the split.
            continue
        for image_id in coco_api.getImgIds(catIds=category_ids):
            if image_id not in seen:
                seen.add(image_id)
                all_image_ids.append(image_id)

    print(f"Total number of images for all the combinations of the categories: {len(all_image_ids)}")
    return all_image_ids

# Collect the unique ids of images containing at least one target category (the count is printed).
all_image_ids = get_all_image_ids() 

Total number of images for all the combinations of the categories: 3940


In [4]:
def resize_bbox(bbox, original_width, original_height, target_size=None):
    """Scale a COCO ``[x, y, width, height]`` box to the resized image, as corners.

    Args:
        bbox: box as ``[x_left, y_top, width, height]`` in original-image pixels
            (the layout of a COCO ``annotation["bbox"]``).
        original_width: width of the original image in pixels.
        original_height: height of the original image in pixels.
        target_size: size of the resized image. A scalar means a square
            ``target_size x target_size`` image. A ``(height, width)`` pair follows
            the ``tvt.Resize`` size convention. Defaults to the global
            ``new_image_size``.

    Returns:
        ``[x1, y1, x2, y2]`` (top-left and bottom-right corners) in resized-image
        pixels.
    """
    if target_size is None:
        target_size = new_image_size
    if isinstance(target_size, (tuple, list)):
        new_height, new_width = target_size
    else:
        new_height = new_width = target_size

    # Independent scale per axis, because the resize does not preserve aspect ratio.
    x_scale = new_width / original_width
    y_scale = new_height / original_height

    x_left, y_top, box_width, box_height = bbox[0], bbox[1], bbox[2], bbox[3]

    new_x1 = x_scale * x_left
    new_y1 = y_scale * y_top
    new_x2 = new_x1 + x_scale * box_width
    new_y2 = new_y1 + y_scale * box_height

    return [new_x1, new_y1, new_x2, new_y2]

In [5]:
def resize_image_bbox(image, path_to_image, Bboxs):
    """Resize one COCO image, save it as RGB, and rescale its bounding boxes.

    Args:
        image: COCO image record; its ``"file_name"``, ``"width"`` and
            ``"height"`` entries are read.
        path_to_image: destination path for the resized RGB image. Its parent
            directory is created if missing.
        Bboxs: ``{label: [...]}`` of per-label lists. Each non ``-1`` entry is a
            COCO ``[x, y, w, h]`` box belonging to that label.

    Returns:
        ``{"bus": [...], "cat": [...], "pizza": [...]}`` of parallel lists. Each
        position holds one resized ``[x1, y1, x2, y2]`` box under its own label and
        ``-1`` under the others. Boxes are grouped by label in ``Bboxs`` iteration
        order, and labels outside these three are ignored. Returns ``False`` (after
        printing the error) if reading, resizing or saving fails.
    """
    directory = train_directory if train else val_directory
    new_bboxs = {"bus": [], "cat": [], "pizza": []}

    try:
        image_filename = image["file_name"]
        original_width, original_height = image["width"], image["height"]

        resized = resize_image(read_image(os.path.join(directory, image_filename)))
        if resized.size()[0] == 1:
            # Grayscale image: replicate the single channel to get three channels.
            resized = resized.repeat(3, 1, 1)
        rgb_image = tvt.functional.to_pil_image(resized).convert("RGB")

        for label, bboxs in Bboxs.items():
            if label not in new_bboxs:
                continue
            for bbox in bboxs:
                if bbox == -1:
                    continue  # placeholder for a box that belongs to another label
                resized_bbox = resize_bbox(bbox, original_width, original_height)
                for name, column in new_bboxs.items():
                    column.append(resized_bbox if name == label else -1)

        # The Train/Val output folder may not exist yet on a fresh run.
        os.makedirs(os.path.dirname(path_to_image) or ".", exist_ok=True)
        rgb_image.save(path_to_image)
        return new_bboxs

    except Exception as e:
        # Best effort: skip this image and report which one failed.
        print(f"Exception while processing {image.get('file_name', '?')}: {e}")
        return False

In [6]:
# Only keep objects larger than this area (COCO "area" of the annotation).
MIN_ANNOTATION_AREA = 64 * 64

dataset = {}
for image_id in tqdm(all_image_ids):
    image_info = {"Id": None,
                  "Path to Image": None,
                  "Bbox": {"bus": [], "cat": [], "pizza": []}}

    image = coco.loadImgs(ids=image_id)[0]
    annotation_ids = coco.getAnnIds(imgIds=image["id"], iscrowd=False)
    annotations = coco.loadAnns(annotation_ids)

    for annotation in annotations:
        # Skip annotations that are not a target category or are too small.
        if annotation["category_id"] not in catIds or annotation["area"] <= MIN_ANNOTATION_AREA:
            continue

        image_info["Id"] = image_id
        image_info["Path to Image"] = os.path.join(path_to_dir, image["file_name"])

        # One entry per annotation: the raw box under its own label, -1 under the others.
        label_name = inverse_categories[annotation["category_id"]]
        for name, column in image_info["Bbox"].items():
            column.append(annotation["bbox"] if name == label_name else -1)

    # Resize/save exactly once, after all of this image's annotations are collected.
    # Running it per annotation would re-save the image and re-scale boxes that
    # were already resized (the resized dict replaces image_info["Bbox"]).
    if image_info["Path to Image"] is not None:
        resized_bboxs = resize_image_bbox(image, image_info["Path to Image"], image_info["Bbox"])
        if resized_bboxs:
            image_info["Bbox"] = resized_bboxs
            dataset[image["file_name"]] = image_info

print(f'Total Images downloaded {len(dataset)}')

100%|██████████| 3940/3940 [00:28<00:00, 136.38it/s]

Total Images downloaded 3491





In [7]:
def create_dataframe(dataset, filename=None):
    """Write the collected dataset to a CSV file with one row per image.

    Args:
        dataset: mapping of image file name -> image_info dict with keys
            ``"Id"``, ``"Path to Image"`` and ``"Bbox"`` (a dict of per-label
            lists under ``"bus"``, ``"cat"`` and ``"pizza"``).
        filename: output CSV path. Defaults to ``"train_data.csv"`` when the
            global ``train`` flag is set, otherwise ``"test_data.csv"``.

    The bbox columns hold Python lists, which ``to_csv`` writes as their string
    form. The DataFrame index is written as the first CSV column.
    """
    if filename is None:
        filename = "train_data.csv" if train else "test_data.csv"

    # Build every row first, then create the frame in one go (column order from `columns`).
    records = [
        {
            "id": image_info["Id"],
            "path_to_image": image_info["Path to Image"],
            "bus": image_info["Bbox"]["bus"],
            "cat": image_info["Bbox"]["cat"],
            "pizza": image_info["Bbox"]["pizza"],
        }
        for image_info in dataset.values()
    ]
    df = pd.DataFrame(records, columns=columns)
    df.to_csv(filename)

# Write the collected dataset to CSV ("train_data.csv" if train else "test_data.csv").
create_dataframe(dataset)