# Using Coco Dataset with Rikai

In [None]:
from pyspark.sql.functions import udf, size, col
from pyspark.sql.types import FloatType, StructField, StructType, IntegerType, ArrayType, StringType
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import numpy as np
from rikai.spark.utils import get_default_jar_version

# Ask Rikai for its default jar version (snapshot builds allowed) and turn it
# into the Maven coordinate of the Rikai Spark package.
version = get_default_jar_version(use_snapshot=True)
rikai_package = f"ai.eto:rikai_2.12:{version}"

# Create (or reuse) a Spark session in local mode, with the Rikai package
# registered through `spark.jars.packages`.
builder = SparkSession.builder.appName('rikai-quickstart')
builder = builder.config('spark.jars.packages', rikai_package)
spark = builder.master('local[*]').getOrCreate()

# Preparing Coco Dataset

This step downloads the [Fast.ai subset of the Coco dataset](https://course.fast.ai/datasets#coco). It might take some time.

In [None]:
# Download Coco Sample Dataset from Fast.ai datasets
import os
import subprocess

# Skip the download when the extracted `coco_sample` directory already exists.
if os.path.exists("coco_sample"):
    print("Coco sample already downloaded...")
else:
    # Stream the archive from wget straight into tar; the pipe between the two
    # commands is why the command has to run through a shell.
    download_cmd = "wget https://s3.amazonaws.com/fast-ai-coco/coco_sample.tgz -O - | tar -xz"
    subprocess.check_call(download_cmd, shell=True)

In [None]:
# Convert coco dataset into Rikai format
import json
from rikai.spark.functions import box2d_from_top_left

# Load the annotation file into a plain Python dict. The encoding is given
# explicitly so that decoding does not depend on the platform's locale default
# (JSON interchange files are UTF-8).
with open("coco_sample/annotations/train_sample.json", encoding="utf-8") as fobj:
    coco = json.load(fobj)
    
# print(coco.keys())
# print(coco["categories"])
# print(coco["annotations"][:10])

In [None]:
# Category table built from the `categories` section of the annotation file;
# it is joined against the annotations further down.
categories_df = spark.createDataFrame(coco["categories"])

# Keep only the fields used below and coerce every bbox coordinate to float,
# so that all values in a `bbox` list share one type.
anno_array = [
    dict(
        image_id=entry["image_id"],
        bbox=list(map(float, entry["bbox"])),
        category_id=entry["category_id"],
    )
    for entry in coco["annotations"]
]

# Add a `box2d` column computed from the raw `bbox` column.
anno_df = spark.createDataFrame(anno_array).withColumn(
    "box2d", box2d_from_top_left("bbox")
)

# We could use JOIN to replace pycocotools.COCO
# Attach the category name to every annotation, pack the per-annotation fields
# into an `anno` struct, then collect the structs of each image into one list.
#
# No explicit `drop` is needed before aggregating: `groupBy(...).agg(...)`
# only emits the grouping key (`image_id`) and the aggregated `annotations`
# column, so every other column is discarded by the aggregation itself.
annotations_df = (
    anno_df.join(categories_df, anno_df.category_id == categories_df.id)
    .withColumn("anno", F.struct([col("box2d"), col("name"), col("category_id")]))
    .groupBy(anno_df.image_id)
    .agg(F.collect_list("anno").alias("annotations"))
)

# Show the resulting schema and a few rows as a sanity check.
annotations_df.printSchema()
annotations_df.show(5)


## Build Coco dataset with image and annotations in Rikai format.

In [None]:
from pyspark.sql.functions import col, lit, concat, udf
from rikai.types.vision import Image
from rikai.types.geometry import Box2d
from rikai.spark.functions import to_image, box2d
from rikai.spark.types import ImageType, Box2dType

# One row per entry of coco["images"], plus an `image` column built from the
# file's path under coco_sample/train_sample/.
images_df = (
    spark
    .createDataFrame(spark.sparkContext.parallelize(coco["images"]))
    .withColumn(
        "image",
        to_image(concat(lit("coco_sample/train_sample/"), col("file_name")))
    )
)

# Attach the per-image annotation lists. `image_id` (from annotations_df) is
# kept as the row key; `id` carries the same value after the equi-join and
# `file_name` has already been folded into `image`, so both are dropped.
# Note: `drop` takes plain column names -- a "dataframe.column" style string
# would match nothing and be silently ignored.
images_df = (
    images_df
    .join(annotations_df, images_df.id == annotations_df.image_id)
    .drop("file_name", "id")
)
images_df.show(5)
images_df.printSchema()

In [None]:
# Inspect Bounding Boxes on an Image

# `images_df` has no `id` column any more (it is dropped right after the join
# with the annotations), so the image is looked up through `image_id`.
row = images_df.where("image_id = 32954").first()

# Combine the image with the bounding box of each of its annotations; as the
# last expression of the cell, the result is what the notebook displays.
row.image | [anno.box2d for anno in row.annotations]


In [None]:
# Write Spark DataFrame into the rikai format.
# Repartitioning to 4 controls the number of files that get written; any
# existing output at the target path is overwritten.
images_df.repartition(4).write.save(
    "/tmp/rikaicoco/out",
    format="rikai",
    mode="overwrite",
)

# This dataset can be loaded directly into PyTorch

In [None]:
from rikai.pytorch.data import Dataset
from torch.utils.data import DataLoader

# Open the Rikai dataset stored under /tmp/rikaicoco/out, reading only the
# `image_id` and `image` columns.
coco_dataset = Dataset("/tmp/rikaicoco/out", columns=["image_id", "image"])

# Wrap it in a standard PyTorch DataLoader that yields one sample per batch.
data_loader = DataLoader(coco_dataset, batch_size=1)

In [None]:
# Pull the first batch out of the loader.
batch = next(iter(data_loader))
# Last expression of the cell: the notebook displays the length of the batch.
len(batch)

# Data is appropriately converted into PyTorch tensors (`torch.Tensor`)

In [None]:
# Last expression of the cell: the notebook displays the batch contents.
batch