In [35]:
import os
import math
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from utilities import import_data_files
import torch
import torch.nn as nn
import torch.optim as optim
import awkward as ak


class KernelSVM(nn.Module):
    """Pairwise scorer: an affine map applied to a squared Euclidean distance.

    ``forward`` reduces a pair of tensors to the single scalar
    ``sum((x1 - x2) ** 2)`` and passes it through ``nn.Linear(1, output_size)``.

    Args:
        input_size: Accepted by the constructor but not used to size any
            layer; the only learnable layer always receives one scalar.
        output_size: Number of outputs produced by the linear layer.
    """

    def __init__(self, input_size, output_size=1):
        super().__init__()
        # One input feature: the scalar distance computed in ``forward``.
        self.linear = nn.Linear(in_features=1, out_features=output_size)

    def forward(self, x1, x2):
        """Score one pair and return a tensor of shape ``(1, output_size)``.

        The sum runs over every element of ``x1 - x2``, so the two inputs are
        always treated as one pair, whatever their rank.
        """
        squared_distance = (x1 - x2).pow(2).sum()
        # nn.Linear needs a (batch, features) input; here that is (1, 1).
        return self.linear(squared_distance.reshape(1, 1))


# Data preparation
def prepare_data(DFs, train_ratio=0.7, val_ratio=0.15, seed=42, batch_size=64):
    """Build train/validation/test DataLoaders from two event samples.

    Args:
        DFs: Indexable container; ``DFs[0]['SuperCell_ET']`` and
            ``DFs[1]['SuperCell_ET']`` must be convertible with
            ``ak.to_numpy``. Index 0 is labelled 0.0 ("accepted") and
            index 1 is labelled 1.0 ("rejected").
        train_ratio: Fraction of all samples placed in the training split.
        val_ratio: Fraction placed in the validation split; whatever is left
            after the training and validation splits becomes the test split.
        seed: Seed passed to ``torch.manual_seed`` before any shuffling.
        batch_size: Batch size used by all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        reshuffles on every pass.

    Raises:
        ValueError: If ``batch_size`` is not positive or the ratios produce a
            negative split size.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    torch.manual_seed(seed)  # Ensure reproducibility

    print("Data loaded from files")
    print("Converting data to numpy array...")

    # Extracting data from awkward arrays
    accepted_numpy = ak.to_numpy(DFs[0]['SuperCell_ET'])
    rejected_numpy = ak.to_numpy(DFs[1]['SuperCell_ET'])

    print("Converting data to torch tensors...")
    accepted_tensor = torch.from_numpy(accepted_numpy).float()
    rejected_tensor = torch.from_numpy(rejected_numpy).float()

    print("Generating labels...")
    accepted_labels = torch.zeros(len(accepted_tensor), dtype=torch.float32)
    rejected_labels = torch.ones(len(rejected_tensor), dtype=torch.float32)

    # Concatenating data and labels
    print("Concatenating datasets...")
    data = torch.cat([accepted_tensor, rejected_tensor], dim=0)
    labels = torch.cat([accepted_labels, rejected_labels], dim=0)

    print("Shuffling data...")
    indices = torch.randperm(len(data))
    data = data[indices]
    labels = labels[indices]

    # Splitting data
    print("Splitting data...")
    dataset = torch.utils.data.TensorDataset(data, labels)
    train_size = int(train_ratio * len(dataset))
    val_size = int(val_ratio * len(dataset))
    test_size = len(dataset) - train_size - val_size

    # Checked on the integer sizes rather than the float ratios so that
    # rounding in e.g. 0.7 + 0.3 cannot trigger a spurious failure.
    if min(train_size, val_size, test_size) < 0:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative and sum to at most 1, "
            f"got train_ratio={train_ratio}, val_ratio={val_ratio}"
        )

    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size])

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader


# Hinge loss function
def hinge_loss(output, target):
    """Return the mean hinge loss ``max(0, 1 - output * target)``.

    ``target`` is reshaped to ``output``'s shape before the product, so the
    two tensors must hold the same number of elements.
    """
    signed_margin = output * target.view_as(output)
    per_element = (1 - signed_margin).clamp(min=0)
    return per_element.mean()


#Training model
def train_model(train_loader, val_loader, input_size, learning_rate=0.01, num_epochs=10, pair_sample_size=100):
    """Train a ``KernelSVM`` on randomly sampled pairs with per-pair SGD steps.

    For every training batch, ``pair_sample_size`` index pairs are drawn with
    replacement (a sample may be paired with itself). Each pair is scored,
    compared against a +1 (same label) / -1 (different label) target with the
    hinge loss, and followed by one optimizer step.

    Args:
        train_loader: Iterable of ``(x_batch, y_batch)`` training batches.
        val_loader: Loader handed to ``validate_model`` after every epoch.
        input_size: Forwarded to ``KernelSVM`` and embedded in the checkpoint
            file name.
        learning_rate: SGD learning rate.
        num_epochs: Number of passes over ``train_loader``.
        pair_sample_size: Pairs sampled from each batch.

    Returns:
        ``(model, epoch_losses)`` where ``epoch_losses`` holds one
        ``(train_loss, val_loss)`` tuple per epoch. ``train_loss`` is the mean
        hinge loss per sampled pair.

    Side effects:
        Prints one progress line per epoch and saves the model's
        ``state_dict`` to
        ``svm_model_test_3_in{input_size}_lr{learning_rate}_ep{num_epochs}.pth``
        in the current working directory.
    """
    model = KernelSVM(input_size=input_size, output_size=1)
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    epoch_losses = []

    for epoch in range(num_epochs):
        # validate_model() switches the model to eval mode at the end of each
        # epoch, so training mode has to be re-entered here, not just once.
        model.train()
        running_loss = 0.0
        pairs_seen = 0

        for x_batch, y_batch in train_loader:
            batch_size = x_batch.size(0)

            # Subsample pairs randomly for efficiency
            pair_indices = torch.randint(0, batch_size, (pair_sample_size, 2))

            for idx1, idx2 in pair_indices.tolist():
                x1, y1 = x_batch[idx1], y_batch[idx1]
                x2, y2 = x_batch[idx2], y_batch[idx2]

                optimizer.zero_grad()
                output = model(x1, x2)  # Forward pass for pair
                target = 2 * (y1 == y2).float() - 1  # +1 for same class, -1 for different class

                loss = hinge_loss(output, target)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                pairs_seen += 1

        # Normalise by the number of pairs actually trained on. Dividing by
        # the number of batches would inflate the value by pair_sample_size
        # and fail outright on an empty loader.
        train_loss = running_loss / max(pairs_seen, 1)
        val_loss = validate_model(model, val_loader)
        epoch_losses.append((train_loss, val_loss))

        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

    model_save_path = f'svm_model_test_3_in{input_size}_lr{learning_rate}_ep{num_epochs}.pth'
    torch.save(model.state_dict(), model_save_path)
    print(f'Model saved to {model_save_path}')
    return model, epoch_losses


# Validation function
def validate_model(model, val_loader):
    """Return the mean pairwise hinge loss over a validation loader.

    Every unordered pair ``(i, j)`` with ``i < j`` inside each batch is scored
    by ``model`` and compared against a +1 (same label) / -1 (different label)
    target. Pairs are never formed across batches.

    Args:
        model: Callable module taking ``(x1, x2)``; switched to eval mode and
            left in that mode on return.
        val_loader: Iterable of ``(x_batch, y_batch)`` batches.

    Returns:
        Mean hinge loss per evaluated pair as a float, or ``0.0`` when no pair
        could be formed (empty loader, or only single-sample batches).
    """
    model.eval()
    total_loss = 0.0
    pair_count = 0

    with torch.no_grad():
        for x_batch, y_batch in val_loader:
            batch_size = x_batch.size(0)

            # Explicit loop over all unique in-batch pairs: quadratic in the
            # batch size and executed one pair at a time.
            for i in range(batch_size):
                for j in range(i + 1, batch_size):
                    output = model(x_batch[i], x_batch[j])
                    target = 2 * (y_batch[i] == y_batch[j]).float() - 1  # +1 same class, -1 different
                    total_loss += hinge_loss(output, target).item()
                    pair_count += 1

    # Average over pairs, not batches, so the value does not scale with the
    # batch size and an empty loader no longer divides by zero.
    return total_loss / pair_count if pair_count else 0.0


def test_model(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        z=0
        for x_batch, y_batch in test_loader:
            batch_size = x_batch.size(0)
            z+=1
            # Initialize lists to store the pairwise outputs
            pairwise_outputs = []
            pairwise_targets = []

            # Generate pairwise comparisons
            for i in range(batch_size):
                for j in range(i + 1, batch_size):  # Ensure unique pairs (i, j)
                    x1, y1 = x_batch[i], y_batch[i]
                    x2, y2 = x_batch[j], y_batch[j]

                    # Forward pass on the pair
                    output = model(x1, x2)  # Get the model output for the pair
                    
                    # Define the target as +1 if the labels are the same, -1 if different
                    target = 2 * (y1 == y2).float() - 1
                    
                    # Make sure output and target are 1D tensors
                    pairwise_outputs.append(output.unsqueeze(0))  # Unsqueeze to make them 1D
                    pairwise_targets.append(target.unsqueeze(0))  # Unsqueeze to make them 1D
            # Convert lists to tensors
            pairwise_outputs = torch.cat(pairwise_outputs)
            pairwise_targets = torch.cat(pairwise_targets)

            # Make predictions (output > 0 is predicted as 1, else -1)
            predicted = (pairwise_outputs > 0).float()

            # Compute the number of correct predictions
            correct += (predicted == (pairwise_targets > 0)).sum().item()
            total += len(pairwise_targets)
            if z%10==0:
                print(f'Batch number [{z}/{len(test_loader)}]')
    accuracy = correct / total * 100
    print(f'Accuracy: {accuracy:.2f}%')



# Plot loss
def plot_loss(epoch_losses):
    """Plot training and validation loss against the epoch number and show it.

    Args:
        epoch_losses: Sequence of ``(train_loss, val_loss)`` pairs, one per
            epoch, in epoch order.
    """
    train_losses, val_losses = zip(*epoch_losses)

    # (values, legend label, marker) for each curve, drawn in this order.
    curves = (
        (train_losses, 'Train Loss', 'o'),
        (val_losses, 'Val Loss', 'x'),
    )
    for values, legend_label, marker_style in curves:
        epochs = range(1, len(values) + 1)  # epochs are numbered from 1
        plt.plot(epochs, values, label=legend_label, marker=marker_style)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss Over Epochs')
    plt.legend()
    plt.grid()
    plt.show()

In [9]:
# Load both input files with the project helper. Order matters: prepare_data()
# labels DFs[0] as "accepted" (0) and DFs[1] as "rejected" (1).
DFs = import_data_files(["l1calo_hist_ZMUMU_extended.root", "l1calo_hist_EGZ_extended.root"])

In [None]:
# Build the three DataLoaders with the defaults: train_ratio=0.7, val_ratio=0.15,
# the remainder as the test split, seed=42.
train_loader, val_loader, test_loader = prepare_data(DFs)

In [16]:
# Total number of elements in the feature tensor of the first training sample.
input_size = train_loader.dataset[0][0].numel()  # Auto-determine input size from dataset

In [None]:
# Train with the default hyperparameters. train_model() also saves the model's
# state_dict to a .pth file in the working directory.
model, epoch_losses = train_model(train_loader, val_loader, input_size=input_size)

In [None]:
# Show the per-epoch training and validation loss curves.
plot_loss(epoch_losses)

In [36]:
# Report pairwise same-class / different-class accuracy on the test split.
test_model(model, test_loader)

Batch number [10/542]
Batch number [20/542]
Batch number [30/542]
Batch number [40/542]
Batch number [50/542]
Batch number [60/542]
Batch number [70/542]
Batch number [80/542]
Batch number [90/542]
Batch number [100/542]
Batch number [110/542]
Batch number [120/542]
Batch number [130/542]
Batch number [140/542]
Batch number [150/542]
Batch number [160/542]
Batch number [170/542]
Batch number [180/542]
Batch number [190/542]
Batch number [200/542]
Batch number [210/542]
Batch number [220/542]
Batch number [230/542]
Batch number [240/542]
Batch number [250/542]
Batch number [260/542]
Batch number [270/542]
Batch number [280/542]
Batch number [290/542]
Batch number [300/542]
Batch number [310/542]
Batch number [320/542]
Batch number [330/542]
Batch number [340/542]
Batch number [350/542]
Batch number [360/542]
Batch number [370/542]
Batch number [380/542]
Batch number [390/542]
Batch number [400/542]
Batch number [410/542]
Batch number [420/542]
Batch number [430/542]
Batch number [440/54