In [1]:
import os
import numpy as np
import torch.distributed as dist
import pandas as pd
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import matplotlib.pyplot as plt
from typing import List
import cv2
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import accuracy_score, f1_score
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
import pandas as pd
import json
from PIL.Image import Image
import PIL


In [2]:
def parse_json_file(label_path: str) -> pd.DataFrame:
    with open(label_path, "r") as f:
        data = json.load(f)
        
    attributes = pd.Series(data["attributes"])
    boxes = pd.DataFrame(list(map(lambda d: d["box2d"], data["labels"])))
    categories = pd.Series(list(map(lambda d: d["category"], data["labels"])))

    df = pd.DataFrame(
        data={
            "category": categories,
            "x1": boxes.x1,
            "x2": boxes.x2,
            "y1": boxes.y1,
            "y2": boxes.y2,    
        },
    )
    for (key, val) in attributes.items():
        df[key] = val
    df["image_id"] = label_path.split("/")[-1].split(".")[0]
    return df

In [3]:
from PIL.Image import open as pil_open

def label_str_to_num(label: str) -> int:
    return int(label[0])

class CustomDataset(Dataset):
    def __init__(self, data: List[str], labels: List[str] = None, transform: torchvision.transforms.Compose = None, has_label: bool = False) -> None:
        super().__init__()
        self.data = data
        self.labels = labels
        self.transform = transform
        self.has_label = has_label
        
    def __len__(self):
        return len(self.data)
        
    def __getitem__(self, index: int) -> torch.Tensor:
        image_path = self.data[index]
        img = pil_open(image_path).convert("RGB")
        # width_ratio = config.input_size[1] / img.size[1]
        # height_ratio = config.input_size[0] / img.size[0]
        if self.transform is not None:
            img = self.transform(img)
        else:
            img = transforms.ToTensor()(img)
        if not self.has_label:
            return img, {}
        label_info = parse_json_file(self.labels[index])
        label = torch.tensor(np.array(list(map(label_str_to_num, label_info.category.values)))).squeeze(dim=0)
        boxes = [label_info.x1.values, label_info.y1.values, label_info.x2.values, label_info.y2.values]
        # for i in range(len(boxes)):
        #     ratio = width_ratio if i % 2 == 0 else height_ratio
        #     boxes[i] = list(map(lambda d: d * ratio, boxes[i]))
        boxes = torch.tensor(np.array(boxes)).squeeze(dim=0).T        
        image_id = label_info["image_id"].values[0]
        return img, {
            "boxes": boxes,
            "labels": label
        }, image_id

In [4]:
data = [os.path.join(os.getcwd(), "train_images/train_2570037_0045.jpg")]
labels = [os.path.join(os.getcwd(), "train_annotations/train_2570037_0045.json")]

dataset = CustomDataset(data=data, labels=labels, has_label=True)
dataset[0]

(tensor([[[0.6431, 0.6471, 0.6510,  ..., 0.6275, 0.6314, 0.6196],
          [0.6471, 0.6510, 0.6549,  ..., 0.6235, 0.6275, 0.6157],
          [0.6392, 0.6392, 0.6431,  ..., 0.6314, 0.6314, 0.6235],
          ...,
          [0.6353, 0.6353, 0.6275,  ..., 0.6392, 0.6431, 0.6392],
          [0.6392, 0.6314, 0.6353,  ..., 0.6314, 0.6353, 0.6353],
          [0.6353, 0.6275, 0.6392,  ..., 0.6353, 0.6431, 0.6471]],
 
         [[0.6431, 0.6471, 0.6510,  ..., 0.6314, 0.6353, 0.6235],
          [0.6471, 0.6510, 0.6549,  ..., 0.6275, 0.6314, 0.6196],
          [0.6392, 0.6392, 0.6431,  ..., 0.6353, 0.6353, 0.6275],
          ...,
          [0.6431, 0.6431, 0.6353,  ..., 0.6392, 0.6431, 0.6392],
          [0.6471, 0.6392, 0.6431,  ..., 0.6314, 0.6353, 0.6353],
          [0.6431, 0.6353, 0.6471,  ..., 0.6353, 0.6431, 0.6471]],
 
         [[0.6353, 0.6392, 0.6431,  ..., 0.6118, 0.6157, 0.6039],
          [0.6392, 0.6431, 0.6471,  ..., 0.6078, 0.6118, 0.6000],
          [0.6314, 0.6314, 0.6353,  ...,