# MAKE TFRECORD FILES
Mathias Ramm Haugland // 28.01.22 // Master thesis // NTNU & OUH

# For Colab

In [None]:
# For Colab: mount Google Drive at /content/drive so files stored on the
# Drive can be accessed through that path.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import pathlib
# Clone the tensorflow models repository if it doesn't already exist
if "models" in pathlib.Path.cwd().parts:
    while "models" in pathlib.Path.cwd().parts:
      os.chdir('..')
elif not pathlib.Path('models').exists():
    !git clone --depth 1 https://github.com/tensorflow/models

In [None]:
%%bash
# Install the Object Detection API
# NOTE: the %%bash cell magic must be the very first line of the cell,
# otherwise IPython does not recognise it as a cell magic.
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
cp object_detection/packages/tf2/setup.py .
python -m pip install .

In [None]:
# For Colab: extract the train/test image archive stored on the mounted Drive
# into the current working directory.
!unzip 'drive/MyDrive/master/NBI_WLI.zip'

In [None]:
# Extract the test_zip_json.zip archive from the mounted Drive into the
# current working directory.
!unzip 'drive/MyDrive/master/test_zip_json.zip'

# Imports

In [None]:
import json
import os
import glob
import pandas as pd
import argparse
import xml.etree.ElementTree as ET
from tqdm import tqdm

# From Segmentation mask to JSON

In [None]:
# Insert the mask-to-bounding-box conversion here, or run it as a separate step.
# An additional script for doing this: mask_to_bbox.py

# From JSON to CSV

In [None]:
#inspired by https://github.com/nazililham11/detection_util_scripts/blob/master/generate_csv.py
def __list_to_csv(annotations, output_file):
    """Write bounding-box annotation rows to a CSV file.

    Args:
        annotations: Iterable of 8-item rows in the order
            (filename, width, height, class, xmin, ymin, xmax, ymax).
        output_file: Path or writable text buffer that receives the CSV.
    """
    column_names = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    frame = pd.DataFrame(annotations, columns=column_names)
    # `index` is documented as a bool; False omits the row-index column.
    frame.to_csv(output_file, index=False)

def json_to_csv(input_json, output_file):
    """Flatten a bounding-box JSON annotation file into a single CSV file.

    The JSON document is read as a mapping from image file name to an object
    with the keys "width", "height" and "bbox", where "bbox" is a list of
    boxes carrying "label", "xmin", "ymin", "xmax" and "ymax". One CSV row is
    written per box, so an image with an empty "bbox" list contributes no
    rows.

    Args:
        input_json: Path of the JSON annotation file to read.
        output_file: Path of the CSV file to write.

    Raises:
        KeyError: If an image entry or a box lacks one of the keys above.
    """
    # JSON is UTF-8 by specification; do not depend on the platform default.
    with open(input_json, encoding='utf-8') as f:
        images = json.load(f)

    annotations = []
    for filename, entry in images.items():
        # For SNBI images the file name has to be rewritten instead, e.g.:
        #     filename = filename.split(".")[0] + "_SNBI.png"
        width = entry["width"]
        height = entry["height"]
        for bbox in entry["bbox"]:
            annotations.append((
                filename, width, height, bbox["label"],
                bbox["xmin"], bbox["ymin"], bbox["xmax"], bbox["ymax"],
            ))

    __list_to_csv(annotations, output_file)

#run
# Convert both set3 annotation files from JSON to CSV.
conversions = [
    ('set3_det.json', 'set3_det.csv'),
    ('set3_class.json', 'set3_class.csv'),
]
for json_path, csv_path in conversions:
    json_to_csv(json_path, csv_path)

# From CSV to TFRecord

In [None]:
#Inspired by https://github.com/douglasrizzo/detection_util_scripts/blob/master/generate_tfrecord.py
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import io
import tensorflow as tf

from PIL import Image
from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDict


In [None]:
def __split(df, group):
    """Group ``df`` by the column ``group`` and return one record per group.

    Each record is a namedtuple ``data(filename, object)`` where ``filename``
    is the group key and ``object`` is the sub-DataFrame holding that group's
    rows.
    """
    record = namedtuple('data', ['filename', 'object'])
    grouped = df.groupby(group)
    records = []
    for key in grouped.groups:
        records.append(record(key, grouped.get_group(key)))
    return records


In [None]:

def create_tf_example(group, path, class_dict):
    """Build a ``tf.train.Example`` for one image and all of its boxes.

    Args:
        group: Record with a ``filename`` attribute (image file name inside
            ``path``) and an ``object`` attribute (DataFrame with one row per
            bounding box). The rows need a ``class`` column and either the
            columns ``xmin_rel``/``xmax_rel``/``ymin_rel``/``ymax_rel``
            (written unchanged) or ``xmin``/``xmax``/``ymin``/``ymax``
            (divided by the image width/height before writing).
        path: Directory that contains the image file.
        class_dict: Mapping from class name ('Hyperplasia', 'Adenoma') to the
            integer label id.

    Returns:
        The populated ``tf.train.Example``.

    Raises:
        ValueError: If the rows have neither set of coordinate columns, or a
            row has a class other than 'Hyper' or 'Adenoma'.
        KeyError: If a class name is missing from ``class_dict``.
    """
    # CSV class label -> class name looked up in the label map.
    label_names = {'Hyper': 'Hyperplasia', 'Adenoma': 'Adenoma'}
    relative_columns = ('xmin_rel', 'xmax_rel', 'ymin_rel', 'ymax_rel')
    absolute_columns = ('xmin', 'xmax', 'ymin', 'ymax')

    with tf.io.gfile.GFile(os.path.join(path, str(group.filename)), 'rb') as fid:
        encoded_image = fid.read()
    # PIL is only used to read the pixel dimensions of the image.
    width, height = Image.open(io.BytesIO(encoded_image)).size

    filename = group.filename.encode('utf8')
    image_format = b'png'  # IMPORTANT

    # The column layout is the same for every row, so decide once. Failing
    # here avoids silently reusing coordinates from a previous row.
    columns = set(group.object.columns)
    if columns.issuperset(relative_columns):
        use_relative = True
    elif columns.issuperset(absolute_columns):
        use_relative = False
    else:
        raise ValueError(
            "annotations for {!r} need either the columns {} or {}".format(
                group.filename, relative_columns, absolute_columns))

    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []

    for _, row in group.object.iterrows():
        if use_relative:
            xmins.append(row['xmin_rel'])
            xmaxs.append(row['xmax_rel'])
            ymins.append(row['ymin_rel'])
            ymaxs.append(row['ymax_rel'])
        else:
            xmins.append(row['xmin'] / width)
            xmaxs.append(row['xmax'] / width)
            ymins.append(row['ymin'] / height)
            ymaxs.append(row['ymax'] / height)

        raw_label = str(row['class'])
        if raw_label not in label_names:
            # Previously this only printed a warning and then reused the class
            # of the previous row (or crashed with NameError on the first row).
            raise ValueError(
                "WRONG Class NAME! {!r} in {!r}; expected one of {}".format(
                    raw_label, group.filename, sorted(label_names)))
        class_name = label_names[raw_label]
        classes_text.append(class_name.encode('utf8'))
        classes.append(class_dict[class_name])

    # Per-box weights ('image/object/weight') were tried and did not work,
    # so that feature is deliberately not written.
    tf_example = tf.train.Example(features=tf.train.Features(
        feature={
            'image/height': dataset_util.int64_feature(height),
            'image/width': dataset_util.int64_feature(width),
            'image/filename': dataset_util.bytes_feature(filename),
            'image/source_id': dataset_util.bytes_feature(filename),
            'image/encoded': dataset_util.bytes_feature(encoded_image),
            'image/format': dataset_util.bytes_feature(image_format),
            'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
            'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
            'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
            'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
            'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
            'image/object/class/label': dataset_util.int64_list_feature(classes),
        }))
    return tf_example


In [None]:
def class_dict_from_pbtxt(pbtxt_path):
    """Read a label map (.pbtxt) file and return a ``{class name: id}`` dict.

    The file is scanned line by line. Lines starting with ``id:`` provide the
    ids; lines starting with the name key provide the class names. The name
    key is ``display_name:`` if it occurs anywhere in the file, otherwise
    ``name:``. Ids and names are paired in order of appearance.

    Args:
        pbtxt_path: Path of the label map file (a UTF-8 BOM is tolerated).

    Returns:
        Dict mapping each class name to its integer id.

    Raises:
        ValueError: If the file has no ``display_name``/``name`` entries, if
            an id is not an integer, or if the number of ids differs from the
            number of names (pairing them would mislabel classes).
    """
    with open(pbtxt_path, 'r', encoding='utf-8-sig') as f:
        lines = [line.strip() for line in f]

    if any('display_name:' in line for line in lines):
        name_key = 'display_name:'
    elif any('name:' in line for line in lines):
        name_key = 'name:'
    else:
        raise ValueError(
            "label map does not have class names, provided by values with the 'display_name' or 'name' keys in the contents of the file"
        )

    ids = [int(line.replace('id:', '')) for line in lines if line.startswith('id:')]
    names = [
        line.replace(name_key, '').replace('"', '').replace("'", '').strip()
        for line in lines
        if line.startswith(name_key)
    ]

    # A count mismatch means ids and names cannot be paired reliably.
    if len(ids) != len(names):
        raise ValueError(
            "label map has {} 'id:' lines but {} '{}' lines".format(
                len(ids), len(names), name_key))

    # join ids and names into a single dictionary
    return dict(zip(names, ids))


In [None]:
# Fix the paths and file names below before running.
# Run first for detection, then for class.
def rec(viddd, mode, base_dir='/home/hemin/mathiasrammhaugland/master/final/'):
    """Write one TFRecord file from ``bbox.csv`` and the images of ``mode``.

    Read from ``base_dir``: ``label_map.pbtxt`` (class name -> id),
    ``bbox.csv`` (one row per bounding box, grouped by its ``filename``
    column) and the image files inside the sub-directory ``mode``.
    Written to ``base_dir``: ``<viddd>_piccolo_<mode>_4.record``.

    Args:
        viddd: Prefix of the output record file name.
        mode: Name of the image sub-directory; also part of the output name.
        base_dir: Directory holding the inputs and receiving the output.
    """
    class_dict = class_dict_from_pbtxt(os.path.join(base_dir, 'label_map.pbtxt'))
    image_dir = os.path.join(base_dir, mode)
    output_path = os.path.join(
        base_dir, '{}_piccolo_{}_4.record'.format(viddd, mode))

    # Load the annotations before opening the writer, so that a missing or
    # broken CSV does not leave an empty record file behind.
    examples = pd.read_csv(os.path.join(base_dir, 'bbox.csv'))
    grouped = __split(examples, 'filename')

    # The context manager closes the writer even if an example fails.
    with tf.io.TFRecordWriter(output_path) as writer:
        for group in tqdm(grouped, desc='groups'):
            tf_example = create_tf_example(group, image_dir, class_dict)
            writer.write(tf_example.SerializeToString())

    output_path = os.path.join(os.getcwd(), output_path)
    print('Successfully created the TFRecords: {}'.format(output_path))


#run
# Build one record file for every (mode, viddd) combination.
modes = ["snbi4"]
viddds = ["class"]
for mode in modes:
    for viddd in viddds:
        rec(viddd, mode)

In [None]:
# Compress the /content/recfiles directory (recursively) into
# /content/recfiles.zip so it can be downloaded as a single file.
!zip -r /content/recfiles.zip /content/recfiles

In [None]:
# Delete the local `masks` directory recursively, without prompting.
!rm -rf masks

In [None]:
# `shutil` is not imported anywhere else in the notebook, so import it here.
import shutil

# Move the archive from the Colab file system to the mounted Google Drive.
shutil.move("/content/recfiles.zip", "/content/drive/MyDrive/master/recfiles.zip")


# Check your TFRecord file

In [None]:
# Parse up to the first 1000 records of the file into tf.train.Example
# messages and print one of them for a visual sanity check.
raw_dataset = tf.data.TFRecordDataset("train_piccolo_wlix.record")
a = []
for serialized in raw_dataset.take(1000):
    parsed = tf.train.Example()
    parsed.ParseFromString(serialized.numpy())
    a.append(parsed)
# Raises IndexError if the file holds fewer than 22 records.
print(a[21])

In [None]:
# Print another parsed example from the list built in the previous cell.
print(a[18])