In [1]:
# Automatically reload imported project modules (e.g. data_utils, vae) before each cell runs,
# so edits to their .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import sys
import os

# Make the parent directory (the event_detection_project root) importable.
# Only append it once, so re-running this cell does not keep growing sys.path.
project_root = os.path.abspath("..")
if project_root not in sys.path:
    sys.path.append(project_root)

In [4]:
from data_utils import get_data
from vae import VAE
import numpy as np
import torch

In [None]:
# Extract the training images and labels for the selected event classes.
# NOTE(review): `images` and `labels` are not used in the visible cells; kept in case
# a later cell relies on them.
images, labels = [], []

# Event classes used for training. These strings are presumably matched by get_data
# against the dataset's sub-folder names, so their spelling (e.g. 'To-Subtitue') is
# kept exactly as-is — verify against the dataset layout.
# 'Cards' and 'Center' are deliberately excluded from training.
events = [
    'Free-Kick',
    'To-Subtitue',
    'Corner',
    'Penalty',
    'Red-Cards',
    'Tackle',
    'Yellow-Cards',
]

# os.path.join avoids reusing double quotes inside a double-quoted f-string,
# a form that only parses on Python >= 3.12 (PEP 701).
train_folder = os.path.join(os.path.abspath(".."), "dataset", "train")
train_images, train_labels = get_data(folder=train_folder, events=events)

In [None]:
# Sanity check: display the shape of the loaded training images
# (layout is whatever get_data returns — presumably (n_samples, ...); confirm).
train_images.shape

In [None]:
import torch.optim as optim
import torch.utils.data as data
import torch
import matplotlib.pyplot as plt
from vae_model.vae import VAE, recon_loss, vae_loss

# --- Training configuration -------------------------------------------------
SEED = 42  # fixes weight init and DataLoader shuffling so runs are reproducible
BATCH_SIZE = 32
LATENT_DIM = 400
LEARNING_RATE = 1e-3
num_epochs = 100
model_path = "vae_model.pth"

torch.manual_seed(SEED)

# --- Data -------------------------------------------------------------------
# train_images is used directly as a map-style dataset; DataLoader batches it.
train_dataset = train_images
n_train_samples = len(train_dataset)
if n_train_samples == 0:
    # Fail immediately with a clear message rather than a ZeroDivisionError after an epoch.
    raise ValueError("train_images is empty: check the dataset folder and the selected events")
train_loader = data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# --- Model and optimizer ----------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae_instance = VAE(latent_dim=LATENT_DIM).to(device)
optimizer = optim.Adam(vae_instance.parameters(), lr=LEARNING_RATE)

# --- Training loop ----------------------------------------------------------
loss_history = []        # average total (vae_loss) value per epoch
recon_loss_history = []  # average reconstruction loss per epoch

vae_instance.train()
for epoch in range(num_epochs):
    total_loss = 0.0
    total_recon_loss = 0.0

    for imgs in train_loader:
        imgs = imgs.to(device)
        optimizer.zero_grad()

        # Forward pass
        recon_imgs, mu, logvar = vae_instance(imgs)

        # Optimised objective
        loss = vae_loss(recon_imgs, imgs, mu, logvar)
        # The reconstruction term is only logged, never back-propagated, so compute
        # it without recording an autograd graph.
        with torch.no_grad():
            r_loss = recon_loss(recon_imgs, imgs)

        # Backpropagation
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_recon_loss += r_loss.item()

    # Normalise by the number of samples (not the number of batches). If vae_loss and
    # recon_loss sum over the batch, this is a per-image loss — the scale the 328 loss
    # threshold noted later in this notebook presumably refers to. TODO confirm their reduction.
    avg_loss = total_loss / n_train_samples
    avg_recon_loss = total_recon_loss / n_train_samples

    loss_history.append(avg_loss)
    recon_loss_history.append(avg_recon_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Recon Loss: {avg_recon_loss:.4f}")

torch.save(vae_instance.state_dict(), model_path)
print(f"Modèle sauvegardé sous {model_path}")

# --- Loss curves ------------------------------------------------------------
epochs_axis = range(1, num_epochs + 1)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs_axis, loss_history, label="Reconstruction + KL Loss", marker="o")
ax.plot(epochs_axis, recon_loss_history, label="Reconstruction Loss", marker="s")
ax.set(xlabel="Epochs", ylabel="Loss", title="VAE Training Loss Evolution")
ax.legend()
ax.grid()
plt.show()

In [None]:
# VAE loss threshold: using 328 as the loss threshold gives the best separation between categories.

# The images of the seven events defined in the SEV dataset (corner ok, penalty ok,
# free kick ok, red card, yellow card, tackle ok, substitute) are selected and fed to
# the VAE network as training data.