In [None]:
!pip install -q jupyter_black

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.2/79.2 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m16.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Load the jupyter_black IPython extension installed above (intended to auto-format code cells with black).
%load_ext jupyter_black

In [None]:
# Mount Google Drive so the dataset and checkpoint paths under MyDrive are reachable.
from google.colab import drive

DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
import os
import cv2
import json
import numpy as np

import torch
from torch.utils.data import DataLoader, Dataset
from torch import nn, optim
from torchvision import models
import torchvision.transforms as tfs

In [None]:
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [None]:
path_to_data = "/content/drive/MyDrive/Tennis_Ball_Tracker/data/"


def drop_entries_without_images(data_dir, json_name):
    """One-off cleanup: remove annotation entries whose image file is missing.

    Reads ``<data_dir>/<json_name>``, keeps only entries whose
    ``<id>.png`` exists in ``<data_dir>/images``, and rewrites the JSON
    file in place.

    Args:
        data_dir: dataset root containing ``images/`` and the JSON file.
        json_name: annotation file name relative to ``data_dir``.

    Returns:
        ``(count_before, count_after)``.
    """
    available = set(os.listdir(os.path.join(data_dir, "images")))
    json_path = os.path.join(data_dir, json_name)
    with open(json_path, "r") as file:
        entries = json.load(file)

    # Build a new list instead of calling list.remove() while iterating,
    # which skips the element following every removed one.
    kept = [item for item in entries if item["id"] + ".png" in available]

    with open(json_path, "w") as file:
        json.dump(kept, file)
    return len(entries), len(kept)


# Not run automatically (it rewrites the annotation file). To clean a split:
# drop_entries_without_images(path_to_data, "data_val.json")

In [None]:
class KeypointsDataset(Dataset):
    """Keypoint-regression dataset backed by a JSON annotation list.

    Each JSON entry must provide ``"id"`` (image loaded from
    ``<image_dir>/<id>.png``) and ``"kps"`` (keypoint coordinates, flattened
    so that x values sit at even indices and y values at odd indices).
    Images are resized to ``image_size`` x ``image_size`` and normalised with
    the ImageNet mean/std; keypoints are rescaled by ``image_size / width``
    and ``image_size / height`` so they match the resized image.
    """

    def __init__(self, image_dir, data_file, ignore_string=None, image_size=224):
        """
        Args:
            image_dir: directory containing the ``<id>.png`` images.
            data_file: path to the JSON annotation list.
            ignore_string: stored on the instance but not used by this class;
                kept for interface compatibility.
            image_size: side length of the square resize applied to images
                and used to rescale keypoints (default 224).
        """
        self.image_dir = image_dir
        self.ignore_string = ignore_string
        self.image_size = image_size
        with open(data_file, "r") as file:
            self.data = json.load(file)

        self.transforms = tfs.Compose(
            [
                tfs.ToPILImage(),
                tfs.Resize((image_size, image_size)),
                tfs.ToTensor(),
                tfs.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image_tensor, keypoints)`` for entry ``index``.

        Raises:
            FileNotFoundError: if the image cannot be read.
        """
        item = self.data[index]
        image_path = f"{self.image_dir}/{item['id']}.png"
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None instead of raising on a missing or
            # unreadable file; fail here with the offending path.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        height, width = image.shape[:2]
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image)
        labels = self._scale_keypoints(item["kps"], width, height, self.image_size)
        return image, labels

    @staticmethod
    def _scale_keypoints(kps, width, height, size=224):
        """Flatten ``kps`` to float32 and rescale it from a ``width`` x ``height``
        image to a ``size`` x ``size`` image (x at even, y at odd indices)."""
        labels = np.array(kps, dtype=np.float32).flatten()
        labels[::2] *= size / width  # x coordinates
        labels[1::2] *= size / height  # y coordinates
        return labels

In [None]:
train_dataset = KeypointsDataset(path_to_data + "images", path_to_data + "data_train.json")
valid_dataset = KeypointsDataset(path_to_data + "images", path_to_data + "data_val.json")

batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Validation order does not affect the averaged loss; keep it deterministic.
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Number of keypoints regressed per image; each contributes an (x, y) pair.
num_keypoints = 14

# ImageNet-pretrained ResNet-50. `weights=IMAGENET1K_V1` is the documented
# replacement for the deprecated `pretrained=True` and loads the same weights.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Replace the classifier head with a regression head over all coordinates.
model.fc = nn.Linear(model.fc.in_features, num_keypoints * 2)
model = model.to(device)



In [None]:
# Adam step size and number of full passes over train_loader.
lr, num_epochs = 1e-4, 20

In [None]:
# Adam over every model parameter; mean-squared error on the keypoint coordinates.
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()

In [None]:
for epoch in range(num_epochs):
    # Re-enter training mode every epoch: the validation pass below switches
    # the model to eval mode (ResNet-50 contains BatchNorm layers).
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if i % 20 == 0:
            print(
                f"Epoch: {epoch + 1}/{num_epochs} \t Iteration: {i + 1}/{len(train_loader)} \t Loss: {loss.item()}"
            )

    # Validation: eval mode, no gradient tracking; batch losses are weighted
    # by batch size so a short final batch does not skew the mean.
    model.eval()
    val_loss_sum, val_count = 0.0, 0
    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            batch_loss = criterion(model(images), labels)
            val_loss_sum += batch_loss.item() * images.size(0)
            val_count += images.size(0)
    if val_count:
        print(f"Epoch: {epoch + 1}/{num_epochs} \t Validation loss: {val_loss_sum / val_count}")

Epoch: 1/20 	 Iteration: 1/625 	 Loss: 14525.5078125
Epoch: 1/20 	 Iteration: 21/625 	 Loss: 13444.9482421875
Epoch: 1/20 	 Iteration: 41/625 	 Loss: 13632.123046875
Epoch: 1/20 	 Iteration: 61/625 	 Loss: 12419.525390625
Epoch: 1/20 	 Iteration: 81/625 	 Loss: 11954.6259765625
Epoch: 1/20 	 Iteration: 101/625 	 Loss: 11118.2490234375
Epoch: 1/20 	 Iteration: 121/625 	 Loss: 10034.4912109375
Epoch: 1/20 	 Iteration: 141/625 	 Loss: 9881.201171875
Epoch: 1/20 	 Iteration: 161/625 	 Loss: 9071.607421875
Epoch: 1/20 	 Iteration: 181/625 	 Loss: 8588.1640625
Epoch: 1/20 	 Iteration: 201/625 	 Loss: 7370.486328125
Epoch: 1/20 	 Iteration: 221/625 	 Loss: 7581.9052734375
Epoch: 1/20 	 Iteration: 241/625 	 Loss: 6855.38330078125
Epoch: 1/20 	 Iteration: 261/625 	 Loss: 5475.25048828125
Epoch: 1/20 	 Iteration: 281/625 	 Loss: 5568.65771484375
Epoch: 1/20 	 Iteration: 301/625 	 Loss: 5148.66552734375
Epoch: 1/20 	 Iteration: 321/625 	 Loss: 4505.23876953125
Epoch: 1/20 	 Iteration: 341/625 	 L

KeyboardInterrupt: 

In [None]:
# Persist only the learned parameters (state_dict), next to the dataset on Drive.
weights_path = path_to_data + "keypoints_weights.pth"
torch.save(model.state_dict(), weights_path)