In [32]:
import itertools
import json
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import v2
from torchvision.transforms import functional

import torchvision.utils as utils
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from PIL import Image

%matplotlib inline
import matplotlib.pyplot as plt

In [None]:
import os
import json
import natsort

# Sanity check: draw the annotated boxes of one frame to confirm that the JSON
# annotations line up with the images on disk.
DATA_ROOT = 'team_classification_data'
# The preview shows the first frame whose key sorts (naturally) after this one.
PREVIEW_AFTER_FRAME = '2000'

# load json: {frame key: {player id: {'box': [...], ...}}}
with open(os.path.join(DATA_ROOT, 'bboxes.json')) as bboxes_json:
    bboxes = json.load(bboxes_json)

# load images
data_dir = os.path.join(DATA_ROOT, 'frames')
filenames = [name for name in os.listdir(data_dir) if os.path.splitext(name)[-1] == '.jpeg']
sorted_filenames = natsort.natsorted(filenames)

image_tensors = [torchvision.io.read_image(os.path.join(data_dir, name)) for name in sorted_filenames]
# NOTE: torch.stack requires every frame to have the same size.
batch = torch.stack(image_tensors)

# Natural-sort key: compares '300' < '2000' numerically, unlike plain string comparison.
frame_key = natsort.natsort_keygen()

# NOTE(review): pairing frames with files relies on the JSON key order matching
# the natural sort order of the file names; the printed "<key> - <file>" lines
# are there to verify that.
for (frame_id, players), fn, tens in zip(bboxes.items(), sorted_filenames, image_tensors):
    print(f'{frame_id} - {fn}')

    height, width = tens.shape[1], tens.shape[2]
    # Box values are scaled by the image size (presumably normalized [x, y, w, h]).
    bboxes_list = [
        [player['box'][0] * width, player['box'][1] * height,
         player['box'][2] * width, player['box'][3] * height]
        for player in players.values()
    ]

    # reshape(-1, 4) keeps a (0, 4) shape when a frame has no players.
    bboxes_tensor = torch.tensor(bboxes_list, dtype=torch.float32).reshape(-1, 4)
    print(bboxes_tensor.shape)
    print(bboxes_tensor)
    # Convert *all* boxes to the corner format expected by draw_bounding_boxes.
    bboxes_xyxy = torchvision.ops.box_convert(bboxes_tensor, 'xywh', 'xyxy')

    if frame_key(frame_id) > frame_key(PREVIEW_AFTER_FRAME):
        img = utils.draw_bounding_boxes(tens, bboxes_xyxy, width=2, colors='green')
        transforms.ToPILImage()(img).show()
        break



In [81]:
class TeamClassificationDataset(Dataset):
    """Match frames paired with per-player bounding boxes read from a JSON file.

    The JSON file maps a frame key (the image file name without its ``.jpeg``
    extension) to a dict ``{player_id: annotation}``. Each annotation has a
    ``'box'`` whose first four values are multiplied by the image width/height.
    They are presumably normalized ``[x, y, w, h]``; confirm against the
    annotation tool. When ``debug_teams_train`` is set, every annotation must
    also carry an ``int``-convertible ``'team'`` value.

    Args:
        img_folder_path: folder containing the ``<frame key>.jpeg`` images.
        bboxes_file_path: path of the JSON annotation file.
        transform: optional callable applied to the loaded PIL image.
        debug_img_show: if true, open each loaded frame with its boxes drawn.
        debug_teams_train: if true, also collect team ids and player-id labels,
            and colour boxes green (non-zero team) or red (team 0).

    Each item is a dict with:
        ``'image'``: the (optionally transformed) image.
        ``'bboxes_tensor'``: float32 tensor of shape (num_players, 4) holding
            ``[x, y, w, h]`` boxes in pixels of the *untransformed* image.
        ``'teams_list'``: team ids (empty unless ``debug_teams_train``).
    """

    def __init__(self, img_folder_path, bboxes_file_path, transform = None, debug_img_show = False, debug_teams_train = False):
        super().__init__()
        self.img_folder_path = img_folder_path
        self.transform = transform
        self.debug_img_show = debug_img_show
        self.debug_teams_train = debug_teams_train

        with open(bboxes_file_path) as bboxes_json:
            # List of (frame key, {player id: annotation}) pairs, in file order.
            self.bboxes = list(json.load(bboxes_json).items())

    def __len__(self):
        return len(self.bboxes)

    def load_image(self, imgage_path):
        """Open an image file and return it as an RGB PIL image."""
        # The context manager closes the file; convert() returns a new, loaded image.
        with Image.open(imgage_path) as image:
            return image.convert('RGB')

    def _show_debug_image(self, image, bboxes_xywh, colors, labels):
        """Draw the pixel xywh boxes on the untransformed PIL image and open it."""
        boxes_xyxy = torchvision.ops.box_convert(bboxes_xywh, 'xywh', 'xyxy')
        debug_img = functional.pil_to_tensor(image)
        debug_img = utils.draw_bounding_boxes(debug_img,
                                              boxes_xyxy,
                                              width=2,
                                              colors=colors,
                                              # A labels list must match the box count,
                                              # so pass None when none were collected.
                                              labels=labels or None,
                                              # NOTE: the font file must be resolvable by PIL.
                                              font='verdana.ttf',
                                              font_size=20)
        transforms.ToPILImage()(debug_img).show()

    def __getitem__(self, index):
        frame_key, players = self.bboxes[index]
        # os.path.join instead of a hard-coded '\\' keeps this portable.
        img_path = os.path.join(self.img_folder_path, f'{frame_key}.jpeg')
        img = self.load_image(img_path)
        # Boxes are scaled with the size of the image *before* any transform.
        width, height = img.width, img.height

        orig_img = img
        if self.transform:
            img = self.transform(img)

        bboxes_list = []
        label_list = []
        team_list = []
        colors_list = []
        for player_id, player in players.items():
            box = player['box']
            bboxes_list.append([box[0] * width, box[1] * height,
                                box[2] * width, box[3] * height])

            if self.debug_teams_train:
                label_list.append(player_id)
                team = int(player['team'])
                team_list.append(team)
                colors_list.append('green' if team else 'red')
            else:
                colors_list.append('blue')

        # reshape(-1, 4) keeps a (0, 4) shape for frames without players;
        # a bare torch.tensor([]) would be 1-D and break box_convert.
        bboxes_tensor = torch.tensor(bboxes_list, dtype=torch.float32).reshape(-1, 4)

        if self.debug_img_show:
            self._show_debug_image(orig_img, bboxes_tensor, colors_list, label_list)

        return {'image': img,
                'bboxes_tensor': bboxes_tensor,
                'teams_list': team_list}

# PIL image -> float32 tensor in [0, 1]. Normalization is currently disabled.
transform = v2.Compose([
    v2.PILToTensor(),                             # uint8 [0, 255]
    v2.ToDtype(dtype=torch.float32, scale=True),  # float32 [0, 1]
    # v2.Normalize((0.5,), (0.5,)),               # (img - mean) / std -> [-1, 1]
])

# os.path.join keeps the paths valid outside Windows.
DATA_ROOT = 'team_classification_data'
train_dataset = TeamClassificationDataset(os.path.join(DATA_ROOT, 'frames'),
                                          os.path.join(DATA_ROOT, 'bboxes.json'),
                                          transform,
                                          debug_img_show=False,
                                          debug_teams_train=True)

# Keep the default batch_size of 1: frames carry different numbers of boxes,
# and the default collate function cannot stack tensors of different shapes.
train_dataloader = DataLoader(train_dataset)

In [98]:
from torchvision.transforms.functional import crop, center_crop
import numpy as np

# Preview: per-channel mean colour of a frame and a crop of every annotated player.
# The original counter + unconditional `break` only ever processed one batch,
# so one batch is the default here.
N_PREVIEW_BATCHES = 1

for data in itertools.islice(train_dataloader, N_PREVIEW_BATCHES):
    image = data['image']
    print(image.shape)
    # Drop the leading batch dimension added by the DataLoader (batch_size is 1).
    image = image.squeeze(0)
    print(image.shape)
    # Mean over height and width -> one value per channel.
    r, g, b = torch.mean(image, dim=[1, 2])
    print(r, g, b)

    # Boxes of the single frame in the batch: [x, y, w, h] pixels, truncated to ints.
    frame_bboxes = data['bboxes_tensor'][0]
    for x, y, w, h in frame_bboxes.int().tolist():
        if w <= 0 or h <= 0:
            continue  # skip degenerate boxes; they would yield an empty crop
        cropped = crop(image, top=y, left=x, height=h, width=w)
        transforms.ToPILImage()(cropped).show()


torch.Size([1, 3, 720, 1280])
torch.Size([3, 720, 1280])
tensor(0.5117) tensor(0.4237) tensor(0.3785)
