In [None]:
import os
from functools import partial
from PIL import Image, ImageDraw

import tensorflow as tf

from src.model.destr_model import build_model
from src.utils.bbox_utils import from_cxcyhw_to_xyxy

In [None]:
# Build the network, then load the most recent checkpoint saved for it.
model = build_model()
checkpoint = tf.train.Checkpoint(model)

# latest_checkpoint() returns None when the directory contains no checkpoint,
# and restore(None) does not raise: the model would silently keep whatever
# weights build_model() gave it. Fail loudly instead.
latest_ckpt_path = tf.train.latest_checkpoint("/workspace/models/checkpoints_6")
if latest_ckpt_path is None:
    raise FileNotFoundError(
        "No checkpoint found in /workspace/models/checkpoints_6"
    )
status = checkpoint.restore(latest_ckpt_path)

In [None]:
def _parse(proto, ft_desc):
    parsed_ft = tf.io.parse_single_example(proto, ft_desc)
    coord = parsed_ft["coord"]
    label = parsed_ft["label"]
    oh_label = parsed_ft["oh_label"]
    image = parsed_ft["logit"]

    return image, coord, label, oh_label

def load_data_tfrecord(
    path_to_tfrecord="/workspace/data/tfrecords", class_num: int = 8
):
    """Build a ``tf.data`` dataset over every ``.tfrecord`` file in a directory.

    Args:
        path_to_tfrecord: Directory to scan for files whose name ends in
            ``.tfrecord`` (only the directory itself, not sub-directories).
        class_num: Length of the ``oh_label`` vector stored in each record.

    Returns:
        A dataset whose elements are the ``(image, coord, label, oh_label)``
        tuples produced by ``_parse``.
    """
    # os.listdir() returns names in arbitrary order. Sort them so the record
    # order, and therefore anything positional such as ``.skip(n)``, is the
    # same on every run and on every machine.
    tfrecord_files = [
        os.path.join(path_to_tfrecord, f)
        for f in sorted(os.listdir(path_to_tfrecord))
        if f.endswith(".tfrecord")
    ]

    feature_description = {
        "logit": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.float32),
        "oh_label": tf.io.FixedLenFeature([class_num], tf.float32),
        "coord": tf.io.FixedLenFeature([4], tf.float32),  # min_x, max_x, min_y, max_y
    }
    # Bind the schema so the mapped function only takes the serialized record.
    parse_fn = partial(_parse, ft_desc=feature_description)

    raw_dataset = tf.data.TFRecordDataset(tfrecord_files)
    parsed_dataset = raw_dataset.map(parse_fn)

    return parsed_dataset

In [None]:
# NOTE(review): presumably the skipped records are ones not meant for this
# visual check (e.g. training data) — TODO confirm.
ds = (
    load_data_tfrecord()
    .skip(200000)  # drop the first 200000 records
    .batch(8)      # group the remaining records into batches of 8
)

In [None]:
# Run the model on the first batch only; the loop exits after one pass.
for batch in ds:
    # Each element is (image bytes, coord, label, oh_label); only the first
    # two are used here.
    logits, coord = batch[:2]

    # Serialized bytes -> uint8 -> float32, then restore the
    # (batch, 224, 224, 3) image layout.
    images = tf.io.decode_raw(logits, tf.uint8)
    images = tf.cast(images, tf.float32)
    images = tf.reshape(images, shape=(-1, 224, 224, 3))

    # Only the first two model outputs are kept.
    pred_cls, pred_boxes, *_ = model(images)

    break

In [None]:
# Position within the batch of the sample to visualise.
idx = 0

# The pixels were decoded from uint8 and PIL's RGB images hold unsigned 8-bit
# channels, so cast back to uint8. A signed int8 cast cannot represent the
# values 128-255. With a (224, 224, 3) uint8 array PIL infers RGB on its own,
# so no explicit mode argument is needed.
img = Image.fromarray(tf.cast(images[idx], tf.uint8).numpy())

# Convert the predicted box with from_cxcyhw_to_xyxy and scale it by 224, the
# image side length used when the batch was reshaped.
# assumes the model emits boxes normalised to [0, 1] — TODO confirm
# Note: this rebinds `coord`, replacing the value unpacked from the batch.
coord = tf.cast(from_cxcyhw_to_xyxy(pred_boxes[idx]) * 224, tf.int32).numpy()

In [None]:
# Draw the predicted box onto `img` in place with a green outline.
# assumes coord is a flat 4-element [x0, y0, x1, y1] array — TODO confirm
img_draw = ImageDraw.Draw(img)
img_draw.rectangle(xy=[*coord], outline="green")

# Hand the annotated image to PIL's viewer.
img.show()