Task: For each frame of a soccer match video, return the cropped images of the 10 players and the jersey number of each of those 10 players.

In [None]:
# Mount Google Drive so the project folder (code and data) is reachable from Colab.
from google.colab import drive

drive.mount('/content/gdrive')
# Project root on Drive; the IPython magic below makes it the working directory.
path = '/content/gdrive/My Drive/MyCode/Real-time-football-match-tracking/'
%cd {path}

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
/content/gdrive/My Drive/MyCode/Real-time-football-match-tracking


In [None]:
import os
import json
import cv2
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, ToTensor, Resize

In [None]:
class FootballDataset(Dataset):
    """Frame-level dataset of player crops, jersey numbers and team colours.

    ``root`` must contain one sub-folder per match. Each folder holds a
    ``<name>.json`` annotation file and a ``<name>.mp4`` video sharing the
    same stem. Every entry of a JSON file's ``"images"`` list counts as one
    dataset item. Items are numbered video by video, in sorted folder order.

    Args:
        root: Directory containing the per-match folders. Non-directory
            entries directly under ``root`` are ignored.
        transform: Optional callable applied to every cropped player image.

    Each item is a tuple ``(crops, jerseys, colors)`` of equal-length lists,
    with one entry per player annotation (``category_id == 4``) in the frame:
        crops   -- image crops taken from the OpenCV frame (``transform``
                   applied if given)
        jerseys -- ``int(attributes["jersey_number"])``
        colors  -- 0 if ``attributes["team_jersey_color"] == "black"``, else 1
    """

    # Annotation category_id used for humans/players.
    HUMAN_CATEGORY_ID = 4

    def __init__(self, root, transform=None):
        # Not populated anywhere in this class; kept so existing attribute
        # access does not break.
        self.images = []
        self.labels = []
        # Per video: path without extension, number of frames, and the player
        # annotations grouped by 0-based frame position.
        self.file_names = []
        self.num_frames = []
        self._annotations = []

        # Sort so the global index -> (video, frame) mapping is reproducible.
        for match in sorted(os.listdir(root)):
            folder_path = os.path.join(root, match)
            if not os.path.isdir(folder_path):
                continue
            json_names = sorted(name for name in os.listdir(folder_path)
                                if name.endswith(".json"))
            if not json_names:
                raise FileNotFoundError(
                    "No .json annotation file found in {}".format(folder_path))
            json_name = json_names[0]
            self.file_names.append(
                os.path.join(folder_path, os.path.splitext(json_name)[0]))
            with open(os.path.join(folder_path, json_name), "r",
                      encoding="utf-8") as json_file:
                json_data = json.load(json_file)

            # count number of frames
            self.num_frames.append(len(json_data["images"]))

            # Parse the annotations once here instead of re-reading the whole
            # JSON file on every __getitem__ call.
            players_by_frame = {}
            for anno in json_data["annotations"]:
                if anno["category_id"] == self.HUMAN_CATEGORY_ID:
                    # image_id is one greater than the 0-based frame position
                    # handed to OpenCV's CAP_PROP_POS_FRAMES.
                    players_by_frame.setdefault(anno["image_id"] - 1, []).append(anno)
            self._annotations.append(players_by_frame)

        self.transform = transform

    def __len__(self):
        """Return the total number of frames across all videos."""
        return sum(self.num_frames)

    def _locate(self, index):
        """Map a global dataset index to ``(video_id, frame_id)``.

        Negative indices count from the end, like a Python sequence.

        Raises:
            IndexError: if ``index`` is outside the dataset.
        """
        total = len(self)
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError(
                "index {} out of range for dataset of size {}".format(index, total))
        for video_id, count in enumerate(self.num_frames):
            if index < count:
                return video_id, index
            index -= count
        # Unreachable: the range check above guarantees a match.
        raise IndexError(index)

    @staticmethod
    def _read_frame(video_path, frame_id):
        """Read one frame from ``video_path`` with OpenCV, always releasing the capture.

        Raises:
            RuntimeError: if OpenCV cannot decode the requested frame.
        """
        cap = cv2.VideoCapture(video_path)
        try:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
            flag, image = cap.read()
        finally:
            cap.release()
        if not flag:
            raise RuntimeError(
                "Could not read frame {} from {}".format(frame_id, video_path))
        return image

    def __getitem__(self, index):
        """Return ``(crops, jerseys, colors)`` for the frame at ``index``."""
        video_id, frame_id = self._locate(index)
        video_path = "{}.mp4".format(self.file_names[video_id])
        # NOTE(review): OpenCV frames are BGR; crops are not converted to RGB.
        image = self._read_frame(video_path, frame_id)

        cropped_images, jerseys, colors = [], [], []
        for anno in self._annotations[video_id].get(frame_id, []):
            x_min, y_min, width, height = anno["bbox"]
            # Clamp the start to 0: a negative index would wrap around in
            # numpy slicing and give a wrong or empty crop. Slicing past the
            # end is already clipped by numpy.
            x0 = max(int(x_min), 0)
            y0 = max(int(y_min), 0)
            cropped_images.append(
                image[y0: int(y_min + height), x0: int(x_min + width), :])
            jerseys.append(int(anno["attributes"]["jersey_number"]))
            colors.append(
                0 if anno["attributes"]["team_jersey_color"] == "black" else 1)

        if self.transform:
            cropped_images = [self.transform(crop) for crop in cropped_images]
        return cropped_images, jerseys, colors

In [None]:
def collate_fn(batch):
    """Flatten a batch of per-frame samples into player-level tensors.

    Each sample is ``(crops, jerseys, colors)``, three lists with one entry per
    player in that frame. Frames can hold different numbers of players, so all
    players from all frames in the batch are concatenated instead of being
    stacked per frame.

    Args:
        batch: Sequence of ``(crops, jerseys, colors)`` tuples. Every crop
            must be a tensor of the same shape.

    Returns:
        ``(images, labels, colors)``, where:
        - ``images`` is the crops stacked along a new first dimension of size P
          (total players in the batch);
        - ``labels`` and ``colors`` are 1-D int64 tensors of length P.

    Raises:
        RuntimeError: raised by ``torch.stack`` when the batch contains no crops.
    """
    images, labels, colors = zip(*batch)

    final_images = torch.stack([crop for crops in images for crop in crops])
    # int64 is the dtype nn.CrossEntropyLoss requires for class-index targets;
    # torch.IntTensor would produce int32.
    final_labels = torch.tensor(
        [label for frame_labels in labels for label in frame_labels],
        dtype=torch.long)
    final_colors = torch.tensor(
        [color for frame_colors in colors for color in frame_colors],
        dtype=torch.long)

    return final_images, final_labels, final_colors

In [None]:
class ResNet_two_header2(nn.Module):
    """torchvision ResNet-50 backbone with two linear heads on shared features.

    ``forward`` runs the backbone up to global average pooling, flattens the
    result, and feeds it to:
      * ``model.fc1`` -> ``num_jerseys`` logits (jersey-number class)
      * ``model.fc2`` -> ``num_colors`` logits (team colour class)

    Args:
        num_jerseys: Output size of the jersey head.
        num_colors: Output size of the colour head.
        pretrained: Passed to ``torchvision.models.resnet50``. Set to False to
            skip downloading ImageNet weights (e.g. when a checkpoint will be
            loaded afterwards). Defaults to True, as before.
    """

    def __init__(self, num_jerseys=10, num_colors=2, pretrained=True):
        super().__init__()
        self.model = models.resnet50(pretrained=pretrained)
        # Width of the pooled feature vector, read from the backbone's own head
        # rather than hard-coded.
        in_features = self.model.fc.in_features
        # Heads are attached to the backbone under these names so the
        # state_dict keys (model.fc1.*, model.fc2.*) stay unchanged. The
        # original model.fc is kept (unused in forward) for the same reason.
        self.model.fc1 = nn.Linear(in_features=in_features, out_features=num_jerseys)
        self.model.fc2 = nn.Linear(in_features=in_features, out_features=num_colors)

    def forward(self, x):
        """Return ``(jersey_logits, color_logits)`` for an image batch ``x``."""
        # Stem
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        # Residual stages
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        # Pool to one feature vector per image, then apply both heads.
        x = self.model.avgpool(x)
        x = torch.flatten(x, 1)
        x1 = self.model.fc1(x)
        x2 = self.model.fc2(x)

        return x1, x2

In [1]:
if __name__ == '__main__':
    # Smoke test: build the data pipeline and the two-head model, push every
    # batch through the network and print the output shapes.
    # NOTE(review): crops come from OpenCV (BGR) and are not normalised with
    # ImageNet statistics; confirm this is intended for the pretrained backbone.
    transform = Compose([
        ToTensor(),
        Resize((224, 224))
    ])
    path = "/content/gdrive/My Drive/MyCode/Real-time-football-match-tracking/Data/football_train"
    dataset = FootballDataset(root=path, transform=transform)
    params = {
        "batch_size": 2,
        "shuffle": True,
        "drop_last": True,
        "num_workers": 6,
        "collate_fn": collate_fn
    }
    dataloader = DataLoader(dataset, **params)
    model = ResNet_two_header2(10, 2)
    # Only output shapes are inspected, so skip autograd bookkeeping; otherwise
    # each forward pass would keep a full computation graph in memory.
    with torch.no_grad():
        for images, labels, colors in dataloader:
            jersey_prediction, color_prediction = model(images)
            print(jersey_prediction.shape, color_prediction.shape)