# Pairwise MLP approach

Get tracksters from a certain neighbourhood.

Train a NN to decide whether two tracksters should be joined.

Neighbourhood:
- get links from ticlNtuplizer/graph
    - figure out how these links are formed
- convert the tracksters into some latent space and predict a link between them
- later extend this using edgeconv or sageconf to add information from the neighbourhood

Graph:
- linked_inners
    - nodes linked to the given tracksters within its cone


## MLP

In [1]:
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import random_split, DataLoader


from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score, precision_score, recall_score

from reco.data_utils import TracksterPairs

In [29]:
# Apple silicon setup
# this ensures that the current MacOS version is at least 12.3+
print(torch.backends.mps.is_available())

True


In [38]:
# device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
device = "cpu"    # torch mps implementation sucks
print(f"Using device: {device}")

Using device: cpu


In [39]:
ds = TracksterPairs("data", N_FILES=5)
ds[0]

(tensor([  80.8828, -108.8683,  380.0638,   60.2691,   75.7790, -109.8341,
          337.5883,  353.5663]),
 tensor(1.))

In [40]:
ds.x = torch.nn.functional.normalize(ds.x, dim=0)

In [41]:
ds.x[0]

tensor([ 0.0099, -0.0140,  0.0093,  0.0140,  0.0097, -0.0147,  0.0086,  0.0460])

In [42]:
class PairWiseMLP(torch.nn.Module):
    def __init__(self, num_inputs, num_hidden=10):
        super(PairWiseMLP, self).__init__()

        self.W1 = nn.Linear(num_inputs, num_hidden)
        self.activation = nn.Sigmoid()
        self.W2 = nn.Linear(num_hidden, num_hidden)
        self.W3 = nn.Linear(num_hidden, 1)
        self.output = nn.Sigmoid()

    def forward(self, data):
        x = self.W1(data)
        x = self.activation(x)
        x = self.W2(x)
        x = self.activation(x)
        x = self.W3(x)
        return self.output(x)

In [43]:
loss_obj = torch.nn.BCELoss()

def train(model, opt, loader):
    epoch_loss = 0
    for batch, labels in loader:
        model.train()
        batch = batch.to(device)
        labels = labels.to(device)
        opt.zero_grad()
        z = model(batch).reshape(-1)
        loss = loss_obj(z, labels)
        epoch_loss += loss
        loss.backward()
        opt.step()
    return float(epoch_loss)

@torch.no_grad()
def test(model, data):
    total = 0
    correct = 0
    for batch, labels in data:
        model.eval()
        batch = batch.to(device)
        labels = labels.to(device)
        z = model(batch).reshape(-1)
        prediction = (z > 0.5).type(torch.int)
        total += len(prediction) 
        correct += sum(prediction == labels.type(torch.int))
    return (correct / total)

In [44]:
ds_size = len(ds)
test_set_size = ds_size // 10
train_set_size = ds_size - test_set_size
train_set, test_set = random_split(ds, [train_set_size, test_set_size])
print(f"Train samples: {len(train_set)}, Test samples: {len(test_set)}")

train_dl = DataLoader(train_set, batch_size=32, shuffle=True)
test_dl = DataLoader(test_set, batch_size=32, shuffle=True)

Train samples: 9470, Test samples: 1052


In [45]:
model = PairWiseMLP(ds.x.shape[1], 128)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=50, gamma=0.5)
test_acc = test(model, test_dl)
print(f"Initial acc: {test_acc:.4f}")

for epoch in range(1, 1000):
    loss = train(model, optimizer, train_dl)
    scheduler.step()
    if epoch % 10 == 0:
        train_acc = test(model, train_dl)
        test_acc = test(model, test_dl)
        print(f'Epoch: {epoch}, loss: {loss:.4f}, train acc: {train_acc:.4f}, test acc: {test_acc:.4f}')

Initial acc: 0.5048
Epoch: 10, loss: 219.9350, train acc: 0.4995, test acc: 0.5048
Epoch: 20, loss: 229.2296, train acc: 0.5005, test acc: 0.4952
Epoch: 30, loss: 231.3980, train acc: 0.4995, test acc: 0.5048
Epoch: 40, loss: 230.3350, train acc: 0.5005, test acc: 0.4952
Epoch: 50, loss: 234.3919, train acc: 0.4995, test acc: 0.5048
Epoch: 60, loss: 219.2107, train acc: 0.4995, test acc: 0.5048
Epoch: 70, loss: 220.5267, train acc: 0.4995, test acc: 0.5048
Epoch: 80, loss: 216.2666, train acc: 0.5005, test acc: 0.4952
Epoch: 90, loss: 217.2358, train acc: 0.4995, test acc: 0.5048
Epoch: 100, loss: 213.6661, train acc: 0.5005, test acc: 0.4952
Epoch: 110, loss: 214.4244, train acc: 0.5005, test acc: 0.4952
Epoch: 120, loss: 212.5400, train acc: 0.5005, test acc: 0.4952
Epoch: 130, loss: 212.4735, train acc: 0.4995, test acc: 0.5048
Epoch: 140, loss: 214.9870, train acc: 0.4995, test acc: 0.5048
Epoch: 150, loss: 210.6201, train acc: 0.4995, test acc: 0.5048
Epoch: 160, loss: 208.4183, t

In [24]:
pred = []
lab = []
for b in test_dl:
    pred += (model(b) > 0.5).type(torch.int).tolist()
    lab += b.y.tolist()

tn, fp, fn, tp = confusion_matrix(lab, pred).ravel()
print(f"TP: {tp}, TN: {tn}, FP: {fp}, FN: {fn}")
print(f'Accuracy: {accuracy_score(lab, pred):.4f}')
print(f'Precision: {precision_score(lab, pred):.4f}')
print(f'Recall: {recall_score(lab, pred):.4f}')

TP: 44, TN: 81, FP: 31, FN: 68
Accuracy: 0.5580
Precision: 0.5867
Recall: 0.3929
