In [None]:
import sys
import os

# Make the project's `src/` directory importable (model.*, labels_printer).
# The path is resolved against the current working directory, so the notebook
# is expected to be launched from the directory that contains `src/`.
# Guarded so that re-running this cell does not keep appending duplicates.
src_dir = os.path.abspath('src')
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [None]:
from model.labels_manager import LabelsManager

# Single label-manager instance; later cells pass it to LabelsPrinter,
# DataLoading (train/valid/test) and EventSpotter.
labels_manager = LabelsManager()

In [None]:
%%capture
# labels_manager.save_labels_to_csv()

In [None]:
from labels_printer import LabelsPrinter

# Helper bound to `labels_manager` that renders label statistics; the next two
# cells call its print_* methods.
labels_printer = LabelsPrinter(labels_manager)

In [None]:
# Overall label distribution. fig_size is presumably (width, height) in
# matplotlib inches — TODO confirm in LabelsPrinter.
labels_printer.print_labels_distribution(fig_size=(10,4))

In [None]:
# Label distribution broken down over game-time intervals (interval size is
# decided inside LabelsPrinter).
labels_printer.print_labels_distribution_over_game_intervals(fig_size=(10,5))

In [None]:
from model.data_loading import DataLoading

# Loader settings shared by the training and validation splits.
# batch_size and fps are also read by later cells (Trainer, EventSpotter).
batch_size = 32
fps = 2
chunk_length = 60


def _make_split_loading(split):
    """Return the non-context-aware SoccerNet DataLoading for `split` using the settings above."""
    return DataLoading(
        labels_manager, "SoccerNet", fps, chunk_length, batch_size,
        split_type=split, context_aware=False,
    )


# Build the train split first, then valid, then fetch their loaders in the same order.
train_dataloading, val_dataloading = (_make_split_loading(split) for split in ("train", "valid"))
train_loader, val_loader = (loading.get_dataloader() for loading in (train_dataloading, val_dataloading))


In [None]:
import torch
from torch.utils.data import DataLoader
from model.neuron_network import NeuronNetwork

# Seed torch's global RNG before the model is built so that weight
# initialisation is reproducible across Restart & Run All. Loaders without an
# explicit generator presumably draw their shuffling from this RNG as well.
seed = 42
torch.manual_seed(seed)

# input_dim presumably has to match the feature width returned by DataLoading,
# and num_classes the number of label classes in LabelsManager — TODO confirm both.
classifier = NeuronNetwork(input_dim=512, num_classes=17)

In [None]:
from model.training import Trainer

# Training hyper-parameters (batch_size comes from the data-loading cell).
epochs = 200
prediction_threshold = 0.5
# Set to a checkpoint path (e.g. "weights/model_0_1.pth") to load saved weights
# into the trainer before training; None trains from the current weights.
resume_checkpoint = None

trainer = Trainer(
    classifier, train_loader, val_loader,
    epochs=epochs, batch_size=batch_size,
    prediction_threshold=prediction_threshold, context_aware=False,
)
if resume_checkpoint is not None:
    trainer.load_checkpoint(resume_checkpoint)

In [None]:
trainer.train()

In [None]:
# Plot the loss history recorded by the Trainer during trainer.train().
trainer.plot_training_loss()

In [None]:
# Persist the trained model. Create the target directory first in case
# save_checkpoint does not, e.g. on a fresh checkout without `weights/`.
checkpoint_path = "weights/model_0_1.pth"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
trainer.save_checkpoint(checkpoint_path)

In [None]:
from model.event_spotting import EventSpotter

# Test split used for event spotting. Load it at the same frame rate and with
# the same context setting as training: the EventSpotter built in the next cell
# uses fps=fps, and a different rate here (this used to be a hard-coded 1)
# could misalign frame indices between features, labels and predictions.
test_dataloading = DataLoading(
    labels_manager, "SoccerNet", fps,
    chunk_length=1, batch_size=1,
    split_type="test", context_aware=False,
)

In [None]:
# Spot events in one half of the first test video with the trained classifier,
# then evaluate them against that half's labels.
half = 1                   # which half of the match to load
detection_threshold = 0.8  # passed to EventSpotter as detection_threshold
nms_window = 60            # passed to EventSpotter as nms_window (units: TODO confirm frames vs seconds)
delta = 360                # passed to EventSpotter as delta (units: TODO confirm)

video_name = test_dataloading.video_names[0]
features, labels = test_dataloading.load_features_labels(video_name, half=half)

# Sanity-check what was loaded: print shapes and a short preview only.
print(f"Video: {video_name} (half {half})")
print(features.shape)
print(labels.shape)
print(features[:5])
print(labels[:5])

# NOTE(review): EventSpotter assumes the features are sampled at `fps`; this
# must match the rate test_dataloading was created with.
event_spotter = EventSpotter(
    labels_manager,
    model=classifier,
    fps=fps,
    detection_threshold=detection_threshold,
    nms_window=nms_window,
    delta=delta,
)
event_spotter.detect_events(features)

raw_predictions = event_spotter.get_predictions()
final_events = event_spotter.get_events()

# Ground truth keyed by row index of `labels`, the format evaluate_predictions
# is called with here. Only its size is printed: one entry per row of a whole
# half is far too long to display.
ground_truth = dict(enumerate(labels))
print(f"Ground truth: {len(ground_truth)} entries")
evaluation = event_spotter.evaluate_predictions(ground_truth)

print("Raw Predictions:", raw_predictions)
print("Final Detected Events:", final_events)
print("Evaluation Metrics:", evaluation)

event_spotter.show_predictions_summary(ground_truth, True)