# Pairwise MLP approach

Get tracksters from a certain neighbourhood.

Train a NN to decide whether two tracksters should be joined.

Neighbourhood:
- get links from ticlNtuplizer/graph
    - figure out how these links are formed
- convert the tracksters into some latent space and predict a link between them
- later extend this using EdgeConv or SAGEConv to add information from the neighbourhood

Graph:
- linked_inners
    - nodes linked to the given trackster within its cone


## MLP

In [1]:
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import random_split, DataLoader


from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score, precision_score, recall_score

from reco.dataset import TracksterPairs

In [2]:
# Apple silicon setup
# Print whether the MPS (Metal) backend is usable. This call only checks and
# reports; it does not enforce anything. Per the PyTorch MPS notes, it is
# False when macOS is older than 12.3 or no MPS-enabled device is present.
print(torch.backends.mps.is_available())

True


In [3]:
# Global device read by train()/test() and used to place the model.
# Automatic MPS selection is kept for reference but deliberately disabled:
# device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
device = "cpu"    # author's note: the MPS backend was unsatisfactory, so stay on CPU
print(f"Using device: {device}")

Using device: cpu


In [27]:
# Build the pair dataset (presumably from up to 5 files under ./data --
# TODO confirm the meaning of N_FILES in reco.dataset.TracksterPairs).
ds = TracksterPairs("data", N_FILES=5)
# Peek at one sample: a (features, label) tuple; the output below shows a
# 22-value feature tensor and a scalar float label.
ds[0]

(tensor([ 3.4465e+01, -8.5131e+01,  3.6889e+02,  5.3535e+00,  1.3882e+00,
          1.4811e+02,  9.5811e-01,  1.4442e-01,  1.0263e-01, -1.9299e-01,
          9.7582e-01,  4.1557e+01, -7.9505e+01,  3.5487e+02,  2.9358e+00,
          2.9358e+00,  3.3207e+01,  2.7202e+00,  5.3345e-01,  1.2542e-01,
         -2.0203e-01,  9.7132e-01]),
 tensor(1.))

In [45]:
# Divide every feature column (dim=0 runs over samples) by its infinity norm,
# i.e. its largest absolute value, so each column ends up within [-1, 1].
# NOTE(review): the scale is computed on the full dataset before the
# train/test split and the features are not centred; consider standardising
# with train-set statistics only -- confirm whether this hurts training.
ds.x = torch.nn.functional.normalize(ds.x, p=torch.inf, dim=0)
ds.x

tensor([[ 0.1578, -0.4362,  0.7193,  ...,  0.1288, -0.2052,  0.9713],
        [ 0.1449, -0.4608,  0.8003,  ...,  0.1619, -0.1848,  0.9706],
        [ 0.4275, -0.5957,  0.8044,  ...,  0.2237, -0.1173,  0.9691],
        ...,
        [-0.1268, -0.2916,  0.7328,  ...,  0.3227, -0.0717,  0.9467],
        [-0.1150, -0.3350,  0.8049,  ...,  0.0397, -0.2202,  0.9754],
        [-0.6361, -0.0812,  0.7855,  ..., -0.3802,  0.0235,  0.9286]])

In [46]:
# Mean of the labels; assuming ds.y holds 0/1 values, this is the fraction of
# positive pairs (0.5 means the two classes are balanced).
print("dataset balance:", float(sum(ds.y) / len(ds.y))) 

dataset balance: 0.5


In [47]:
class PairWiseMLP(torch.nn.Module):
    """Small MLP that scores whether a pair of tracksters should be joined.

    The network is ``Linear -> ReLU -> Linear -> ReLU -> Linear -> Sigmoid``.
    Both hidden layers have ``num_hidden`` units and the final layer has a
    single unit, so the output is one value in (0, 1) per input row.

    Args:
        num_inputs: size of the last dimension of the input tensor.
        num_hidden: width of both hidden layers.
    """

    def __init__(self, num_inputs, num_hidden=10):
        super().__init__()

        # Attribute names and registration order are kept as-is so that
        # state_dict keys and parameter initialisation order do not change.
        self.W1 = nn.Linear(in_features=num_inputs, out_features=num_hidden)
        self.activation = nn.ReLU()
        self.W2 = nn.Linear(in_features=num_hidden, out_features=num_hidden)
        self.W3 = nn.Linear(in_features=num_hidden, out_features=1)
        self.output = nn.Sigmoid()

    def forward(self, data):
        """Map ``data`` of shape ``(..., num_inputs)`` to scores of shape ``(..., 1)``."""
        hidden = data
        # Both hidden layers share the same ReLU activation.
        for layer in (self.W1, self.W2):
            hidden = self.activation(layer(hidden))
        return self.output(self.W3(hidden))

In [48]:
# Binary cross-entropy on probabilities; the model must output values in [0, 1].
loss_obj = torch.nn.BCELoss()

def train(model, opt, loader):
    """Run one optimisation pass over ``loader`` and return the summed loss.

    Args:
        model: module mapping a feature batch to probabilities; its output is
            flattened to 1-D before being compared with the labels.
        opt: optimizer updating ``model``'s parameters.
        loader: iterable yielding ``(batch, labels)`` tensor pairs.

    Returns:
        float: sum of the per-batch losses over the whole pass. This is not
        averaged over batches, so it scales with the number of batches.

    Tensors are moved to the module-level ``device`` before use.
    """
    model.train()
    epoch_loss = 0.0
    for batch, labels in loader:
        batch = batch.to(device)
        labels = labels.to(device)
        opt.zero_grad()
        z = model(batch).reshape(-1)
        loss = loss_obj(z, labels)
        loss.backward()
        opt.step()
        # .item() yields a plain Python float; adding the loss tensor itself
        # would keep every batch's autograd history alive for the whole epoch.
        epoch_loss += loss.item()
    return epoch_loss

@torch.no_grad()
def test(model, data):
    """Return the classification accuracy of ``model`` over ``data``.

    Args:
        model: module mapping a feature batch to scores; its output is
            flattened to 1-D and thresholded at 0.5.
        data: iterable yielding ``(batch, labels)`` tensor pairs.

    Returns:
        0-dim tensor holding ``correct / total`` over all samples seen.

    Raises:
        ZeroDivisionError: if ``data`` yields no batches.

    Tensors are moved to the module-level ``device`` before use. Gradient
    tracking is disabled by the decorator.
    """
    model.eval()
    total = 0
    correct = 0
    for batch, labels in data:
        batch = batch.to(device)
        labels = labels.to(device)
        z = model(batch).reshape(-1)
        prediction = (z > 0.5).type(torch.int)
        total += len(prediction)
        # Tensor-level reduction; the builtin sum() would loop over the
        # elements one by one in Python.
        correct += (prediction == labels.type(torch.int)).sum()
    return (correct / total)

In [49]:
# Hold out one tenth of the samples (integer division) for testing; the
# remainder is used for training.
ds_size = len(ds)
test_set_size = ds_size // 10
train_set_size = ds_size - test_set_size
# NOTE(review): no generator/seed is passed, so the split differs between
# runs -- confirm whether reproducibility matters here.
train_set, test_set = random_split(ds, [train_set_size, test_set_size])
print(f"Train samples: {len(train_set)}, Test samples: {len(test_set)}")

train_dl = DataLoader(train_set, batch_size=32, shuffle=True)
# NOTE(review): shuffling the evaluation loader does not affect accuracy and
# could be turned off.
test_dl = DataLoader(test_set, batch_size=32, shuffle=True)

Train samples: 9470, Test samples: 1052


In [50]:
# One network input per feature column of ds.x, 10 units per hidden layer.
model = PairWiseMLP(ds.x.shape[1], 10)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# NOTE(review): the scheduler is created but never stepped (the call in the
# loop is commented out), so the learning rate stays at 0.01 throughout.
scheduler = StepLR(optimizer, step_size=50, gamma=0.5)
# Accuracy of the untrained model, as a baseline.
test_acc = test(model, test_dl)
print(f"Initial acc: {test_acc:.4f}")

# Train for 200 epochs, reporting train/test accuracy every 50 epochs.
for epoch in range(1, 201):
    loss = train(model, optimizer, train_dl)
    # scheduler.step()
    if epoch % 50 == 0:
        train_acc = test(model, train_dl)
        test_acc = test(model, test_dl)
        print(f'Epoch: {epoch}, loss: {loss:.4f}, train acc: {train_acc:.4f}, test acc: {test_acc:.4f}')

Initial acc: 0.5019
Epoch: 50, loss: 205.2773, train acc: 0.5020, test acc: 0.4876
Epoch: 100, loss: 204.9308, train acc: 0.5051, test acc: 0.5095
Epoch: 150, loss: 204.6276, train acc: 0.5045, test acc: 0.5067
Epoch: 200, loss: 204.3797, train acc: 0.5026, test acc: 0.4886


In [51]:
# Collect thresholded predictions and integer labels over the whole test set.
pred = []
lab = []
model.eval()
with torch.no_grad():  # inference only: no autograd graph is needed
    for batch, labels in test_dl:
        # Move inputs to the model's device and flatten the (B, 1) output so
        # that predictions form a flat list, matching the labels.
        scores = model(batch.to(device)).reshape(-1)
        pred += (scores > 0.5).type(torch.int).tolist()
        lab += labels.type(torch.int).tolist()

# labels=[0, 1] forces a 2x2 matrix even if only one class is present or
# predicted; otherwise the four-way unpacking below would fail.
tn, fp, fn, tp = confusion_matrix(lab, pred, labels=[0, 1]).ravel()
print(f"TP: {tp}, TN: {tn}, FP: {fp}, FN: {fn}")
print(f'Accuracy: {accuracy_score(lab, pred):.4f}')
print(f'Precision: {precision_score(lab, pred):.4f}')
print(f'Recall: {recall_score(lab, pred):.4f}')

TP: 2, TN: 512, FP: 2, FN: 536
Accuracy: 0.4886
Precision: 0.5000
Recall: 0.0037
