In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches

from PIL import Image
from PIL import ImageDraw

import numpy as np

import json

from os import listdir
from os.path import isfile, join

from tqdm import tqdm

In [None]:
# Root folder under which every dataset directory lives (kept with its trailing slash,
# because other cells build paths by plain string concatenation).
datasets_root = "/datasets/"

In [None]:
# Path of the annotation file for the CODA validation split.
json_cornercases = datasets_root + 'CODA/val/annotations.json'

# Parse the whole annotation file into a dict.
with open(json_cornercases, 'r') as f:
  data_cornercases = json.load(f)

# Pull out the three top-level entries that the following cells work with.
annotations, images, categories = (
  data_cornercases[key] for key in ("annotations", "images", "categories")
)

In [None]:
# Build one anomaly mask (.npy) and one annotated preview image per CODA image.
#
# For every image the corner-case bounding boxes are
#   * outlined in green on the image and blended in as orange rectangles, and
#   * accumulated into an integer mask (each box adds +1 to the pixels it covers).
# The source dataset (ONCE / SODA10M) is inferred from the image width.

# Number of processed images per source dataset
soda10m_counter = 0
once_counter = 0

# Distinct corner-case category ids seen per source dataset (in first-seen order)
soda10m_categories = []
once_categories = []

# Index the annotations by image id once, instead of re-scanning the complete
# annotation list for every single image.
annotations_by_image = {}
for annotation in annotations:
  annotations_by_image.setdefault(annotation['image_id'], []).append(annotation)

for image in tqdm(images):
  image_id = image['id']
  file_name = image['file_name']

  # Bounding box information of this image; bbox is [x, y, w, h]
  bbox_items = annotations_by_image.get(image_id, [])

  # Load file. The file handle is released right away; convert() yields an
  # independent in-memory RGB image that matches the RGB overlay created below.
  img_path = join(datasets_root, 'CODA/val/images', file_name)
  with Image.open(img_path) as img_file:
    img = img_file.convert('RGB')

  # Get image size
  img_w, img_h = img.size

  # Dataset: (1355,720) is scaled-down ONCE; widths 1281, 1280 and 958 are treated as SODA10M
  if img_w == 1355:
    once_counter += 1
    mask_path_base = './data/CODA_masks/2022once/'
    dataset_name = '2022-ONCE'
  elif img_w in (1281, 1280, 958):
    soda10m_counter += 1
    mask_path_base = './data/CODA_masks/2022soda10m/'
    dataset_name = "2022-SODA10M"
    if img_w == 1281:
      # Only the 1281 px wide images are resized.
      # NOTE(review): the bounding boxes are not rescaled after this resize - confirm that is intended.
      newsize = (1280, 720)
      img = img.resize(newsize)
      img_w, img_h = img.size
  else:
    # Unknown width: there is no dataset (and hence no output folder) to assign,
    # so skip the image instead of silently reusing the previous image's settings.
    print("Skipping " + file_name + ": unexpected image width " + str(img_w))
    continue

  # Anomaly overlay: white means "no anomaly here"
  img_overlay = Image.new('RGB', (img_w, img_h), (255, 255, 255))

  # Create mask
  mask = np.zeros((img_h, img_w), dtype=int)

  draw = ImageDraw.Draw(img)
  draw_overlay = ImageDraw.Draw(img_overlay)

  for bbox_item in bbox_items:
    # Only corner cases are drawn and counted
    if not bbox_item['corner_case']:
      continue

    bb_x, bb_y, bb_w, bb_h = bbox_item['bbox'][:4]

    # Green outline on the image, filled orange rectangle on the overlay
    shape = [(bb_x, bb_y), (bb_x + bb_w, bb_y + bb_h)]
    draw.rectangle(shape, outline=(0, 255, 0), width=3)
    draw_overlay.rectangle(shape, fill=(255, 102, 0), outline=(255, 255, 255), width=3)

    # Add +1 to every mask entry covered by the rectangle. The bounds are
    # clamped to the image, because a negative slice bound would wrap around
    # and mark the wrong region.
    x_min, x_max = (min(max(int(v), 0), img_w) for v in (bb_x, bb_x + bb_w))
    y_min, y_max = (min(max(int(v), 0), img_h) for v in (bb_y, bb_y + bb_h))
    mask[y_min:y_max, x_min:x_max] += 1

    # Collect category
    categorie = bbox_item["category_id"]
    if (dataset_name == "2022-ONCE" and categorie not in once_categories):
      once_categories.append(categorie)
    elif (dataset_name == "2022-SODA10M" and categorie not in soda10m_categories):
      soda10m_categories.append(categorie)

  # Blend the overlay into the image once, after all boxes have been drawn:
  # 30 % image + 70 % overlay wherever the overlay is not white. An image
  # without corner cases is kept unchanged, so result_img is always defined.
  img_arr = np.array(img)
  overlay_arr = np.array(img_overlay)
  highlighted = ~np.all(overlay_arr == 255, axis=-1)
  result = img_arr.copy()
  result[highlighted] = (0.3*img_arr[highlighted] + 0.7*overlay_arr[highlighted]).astype(result.dtype)
  result_img = Image.fromarray(result.astype("uint8"))

  # Save mask: the last four characters of the file name (e.g. ".jpg") are dropped.
  # The target folder is expected to exist already.
  filename_np = file_name[:-4]
  mask_path = mask_path_base + filename_np
  np.save(mask_path, mask)

  # Save image with anomaly bounding boxes
  img_bb_path = '../output/coda2022/' + file_name
  result_img.save(img_bb_path)

  # Visualization of every 50th processed image
  total_count = soda10m_counter + once_counter
  if total_count % 50 == 0:
    fig, ax = plt.subplots(1, 2)
    ax[0].imshow(result_img)
    ax[0].set_title(dataset_name + " " + file_name)
    ax[1].set_title('W,H: ' + str(img_w) + "," + str(img_h))
    mask_plot = ax[1].imshow(mask, interpolation='none')
    plt.colorbar(mask_plot, fraction=0.025, pad=0.04, ticks=np.linspace(0, 10, 11, endpoint=True))
    plt.show()
    # Release the figure so that figures do not pile up over the loop
    plt.close(fig)

#Data completeness (https://coda-dataset.github.io/documentation.html#data_usage)
print("# SODA10M: " + str(soda10m_counter) + " Categories: " + str(len(soda10m_categories)))
print("# ONCE: " + str(once_counter) + " Categories: " + str(len(once_categories)))