In [1]:
from importlib import reload
from pathlib import Path
from argparse import ArgumentParser
import sys

import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

sys.path.append('..')

from ball_detection.candidate_classifier import model
from ball_detection.candidate_classifier import data_preprocessing
from ball_detection.candidate_classifier import training
from ball_detection.candidate_classifier import augmentations

In [2]:
# reload(data_preprocessing)
# reload(model)
# reload(training)

In [3]:
candidate_dirs_motion = [
    Path('../data/sync/resources/' + d) 
    for d in ('009', '010', '011', '012', '013', '014', '016', '017', '018')
]
candidate_dirs_hough = [
    Path('../data/sync/resources/' + d) 
    for d in ('001', '003', '005', '006', '009', '010', '011', '012', '013', '014', '016', '017', '018')
] + [Path('../data/sync/resources/015/0151/')]
hough_indexes = [data_preprocessing.read_json_dataset_index(d, 'markup_h.json') for d in candidate_dirs_hough]
motion_indexes = [data_preprocessing.read_json_dataset_index(d, 'markup_motion.json') for d in candidate_dirs_motion]

In [4]:
folder_index = data_preprocessing.read_folder_dataset_index(
    Path('../data/sync/dataset_solid_striped_sep/'),
    Path('../data/sync/images_for_dataset/')
)

In [5]:
common_index = data_preprocessing.merge_dataset_indexes((*hough_indexes, *motion_indexes, folder_index))

In [6]:
train_index, val_index = train_test_split(list(common_index.items()), train_size=.8)
train_index, val_index = dict(train_index), dict(val_index)
train_balls_index, train_false_index = data_preprocessing.split_balls_false_detections(train_index)

In [7]:
train_balls_dataset = data_preprocessing.CandidatesDataset(train_balls_index, shuffle=True)
train_false_dataset = data_preprocessing.CandidatesDataset(train_false_index, move_prob=.2, shuffle=True)
train_candidate_dataset = data_preprocessing.MixDataset(train_balls_dataset, train_false_dataset)
augmentations_applier = augmentations.AugmentationApplier()
dataset_train = data_preprocessing.LabeledImageDataset(train_candidate_dataset, augmentations_applier)

val_candidate_dataset = data_preprocessing.CandidatesDataset(val_index)
dataset_val = data_preprocessing.LabeledImageDataset(val_candidate_dataset)

In [8]:
BATCH_SIZE = 1024
data_loader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, num_workers=0)
data_loader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, num_workers=0)

In [9]:
device = 'cuda'
model_saving_path = 'best_model.pt'

In [10]:
# hist = np.histogram(y_train.cpu().numpy(), bins=3)[0]
# class_weights = torch.tensor(hist.sum() / hist).float().to(device)

In [11]:
net = model.Net()
net = net.to(device)
opt = torch.optim.Adam(net.parameters(), lr=0.01)

In [12]:
def run_training(lr):
    opt.lr = lr
    training.train(100, net, data_loader_train, data_loader_val, weight_decay=.01, device=device, optimizer=opt, 
                   save_path=model_saving_path)

In [13]:
y_val = np.array([label for _, label in dataset_val])

def evaluate():
    best_net = model.Net()
    best_net.load_state_dict(torch.load(model_saving_path))

    val_pred = torch.cat([best_net(x_batch).detach() for x_batch, _ in data_loader_val])
    val_pred = val_pred.numpy().argmax(axis=1)
    
    accuracy = np.mean(val_pred == y_val)

    gt_pred_pairs = np.stack([y_val, val_pred], axis=1)
    confusion_matrix = np.zeros((5, 5), dtype=np.int32)
    for gt, pred in gt_pred_pairs:
        confusion_matrix[gt, pred] += 1

    print(confusion_matrix)
    print(accuracy)

In [14]:
run_training(.01)
evaluate()

loss: 1.3963/1.2498, accuracy: 0.6660350561141968
loss: 1.0923/1.1307, accuracy: 0.6660350561141968
loss: 0.9932/1.0115, accuracy: 0.6750355362892151
loss: 0.8506/0.7840, accuracy: 0.7446707487106323
loss: 0.7779/0.7689, accuracy: 0.7437233328819275
loss: 0.6850/0.7163, accuracy: 0.7555660605430603
loss: 0.6405/0.6719, accuracy: 0.7967787384986877
loss: 0.5848/0.5811, accuracy: 0.8185693621635437
loss: 0.5447/0.5724, accuracy: 0.8247275948524475
loss: 0.5150/0.5863, accuracy: 0.8072003722190857
loss: 0.5047/0.5098, accuracy: 0.8484130501747131
loss: 0.5051/0.5457, accuracy: 0.8389388918876648
loss: 0.5287/0.5465, accuracy: 0.83799147605896
loss: 0.5018/0.4883, accuracy: 0.8441497087478638
loss: 0.4636/0.5006, accuracy: 0.8422548174858093
loss: 0.4648/0.4571, accuracy: 0.8526764512062073
loss: 0.4488/0.4703, accuracy: 0.8488867878913879
loss: 0.4358/0.4151, accuracy: 0.8668877482414246
loss: 0.4106/0.4121, accuracy: 0.8754144906997681
loss: 0.4022/0.3951, accuracy: 0.8735196590423584
lo

In [15]:
run_training(.005)
evaluate()

loss: 0.1498/0.2759, accuracy: 0.9289436340332031
loss: 0.1490/0.3002, accuracy: 0.9279962182044983
loss: 0.1591/0.3121, accuracy: 0.9213642477989197
loss: 0.1564/0.2459, accuracy: 0.9327332973480225
loss: 0.1355/0.2874, accuracy: 0.9237328171730042
loss: 0.1340/0.2824, accuracy: 0.9317858815193176
loss: 0.1303/0.2495, accuracy: 0.930364727973938
loss: 0.1211/0.2977, accuracy: 0.9232590794563293
loss: 0.1277/0.2824, accuracy: 0.9218379855155945
loss: 0.1311/0.2843, accuracy: 0.91899573802948
loss: 0.1476/0.2829, accuracy: 0.930364727973938
loss: 0.1160/0.2864, accuracy: 0.924680233001709
loss: 0.1385/0.2764, accuracy: 0.9185220003128052
loss: 0.1316/0.3633, accuracy: 0.9109426736831665
loss: 0.1427/0.3409, accuracy: 0.9223116636276245
loss: 0.1426/0.2757, accuracy: 0.9317858815193176
loss: 0.1293/0.2847, accuracy: 0.9232590794563293
loss: 0.1234/0.2787, accuracy: 0.9171009063720703
loss: 0.1359/0.2794, accuracy: 0.9185220003128052
loss: 0.1265/0.2708, accuracy: 0.9294173121452332
loss:

In [16]:
run_training(.001)
evaluate()

loss: 0.0932/0.3237, accuracy: 0.9218379855155945
loss: 0.1021/0.3049, accuracy: 0.9227854013442993
loss: 0.0884/0.2823, accuracy: 0.9317858815193176
loss: 0.0880/0.3076, accuracy: 0.9294173121452332
loss: 0.0952/0.3134, accuracy: 0.9251539707183838
loss: 0.0939/0.2785, accuracy: 0.924680233001709
loss: 0.1073/0.2992, accuracy: 0.9270488023757935
loss: 0.0946/0.3212, accuracy: 0.9204168319702148
loss: 0.1064/0.3024, accuracy: 0.91946941614151
loss: 0.1339/0.2847, accuracy: 0.9270488023757935
loss: 0.1314/0.3114, accuracy: 0.9237328171730042
loss: 0.1248/0.2660, accuracy: 0.924206554889679
loss: 0.1049/0.2905, accuracy: 0.9284698963165283
loss: 0.0990/0.3143, accuracy: 0.9284698963165283
loss: 0.1048/0.2734, accuracy: 0.9275224804878235
loss: 0.1104/0.2771, accuracy: 0.9275224804878235
loss: 0.0968/0.2697, accuracy: 0.9261013865470886
loss: 0.0900/0.2736, accuracy: 0.9208905696868896
loss: 0.0879/0.2829, accuracy: 0.9332069754600525
loss: 0.0862/0.3049, accuracy: 0.9270488023757935
loss

In [17]:
run_training(.0003)
evaluate()

loss: 0.0732/0.3688, accuracy: 0.9351018071174622
loss: 0.0612/0.3783, accuracy: 0.9308384656906128
loss: 0.0611/0.4077, accuracy: 0.9294173121452332
loss: 0.0624/0.3761, accuracy: 0.9265750646591187
loss: 0.0705/0.3926, accuracy: 0.924206554889679
loss: 0.0711/0.4089, accuracy: 0.9289436340332031
loss: 0.0786/0.3873, accuracy: 0.929891049861908
loss: 0.0784/0.3687, accuracy: 0.9251539707183838
loss: 0.0872/0.3352, accuracy: 0.9327332973480225
loss: 0.0681/0.3543, accuracy: 0.9261013865470886
loss: 0.0662/0.3891, accuracy: 0.9294173121452332
loss: 0.0715/0.3364, accuracy: 0.930364727973938
loss: 0.0738/0.3625, accuracy: 0.930364727973938
loss: 0.0715/0.3185, accuracy: 0.9379441142082214
loss: 0.0624/0.4086, accuracy: 0.9261013865470886
loss: 0.0836/0.3941, accuracy: 0.9327332973480225
loss: 0.0685/0.4314, accuracy: 0.924206554889679
loss: 0.0783/0.4375, accuracy: 0.924680233001709
loss: 0.0874/0.3845, accuracy: 0.924680233001709
loss: 0.0813/0.3529, accuracy: 0.935575544834137
loss: 0.

In [18]:
run_training(.0001)
evaluate()

loss: 0.0523/0.3418, accuracy: 0.929891049861908
loss: 0.0526/0.2841, accuracy: 0.9327332973480225
loss: 0.0593/0.3207, accuracy: 0.924206554889679
loss: 0.0659/0.3269, accuracy: 0.930364727973938
loss: 0.0588/0.3029, accuracy: 0.9317858815193176
loss: 0.0704/0.2870, accuracy: 0.9308384656906128
loss: 0.0655/0.3244, accuracy: 0.9308384656906128
loss: 0.0974/0.3332, accuracy: 0.9204168319702148
loss: 0.0919/0.2927, accuracy: 0.9294173121452332
loss: 0.1024/0.3103, accuracy: 0.9308384656906128
loss: 0.0864/0.2799, accuracy: 0.9251539707183838
loss: 0.0714/0.2891, accuracy: 0.9317858815193176
loss: 0.0728/0.3267, accuracy: 0.9294173121452332
loss: 0.0546/0.3160, accuracy: 0.9313121438026428
loss: 0.0600/0.3655, accuracy: 0.924680233001709
loss: 0.0847/0.3685, accuracy: 0.9265750646591187
loss: 0.0792/0.3087, accuracy: 0.9332069754600525
loss: 0.0814/0.3181, accuracy: 0.9227854013442993
loss: 0.0847/0.3401, accuracy: 0.9256276488304138
loss: 0.0725/0.2966, accuracy: 0.9317858815193176
loss