# Pairwise MLP approach

Get tracksters from a certain neighbourhood.

Train a NN to decide whether two tracksters should be joined.

Neighbourhood:
- get links from ticlNtuplizer/graph
    - figure out how these links are formed
- convert the tracksters into some latent space and predict a link between them
- later extend this using edgeconv or sageconf to add information from the neighbourhood

Graph:
- linked_inners
    - nodes linked to the given tracksters within its cone


## MLP

In [1]:
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import random_split, DataLoader


from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score, precision_score, recall_score

from reco.dataset import TracksterPairs

In [11]:
# Apple silicon setup
# this ensures that the current MacOS version is at least 12.3+
print(torch.backends.mps.is_available())

True


In [12]:
# device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
device = "cpu"    # torch mps implementation sucks
print(f"Using device: {device}")

Using device: cpu


In [13]:
ds = TracksterPairs("data", N_FILES=10)
ds[0]

Processing: /Users/ecuba/data/multiparticle_complet/new_ntuples_14992862_2834.root
Processing: /Users/ecuba/data/multiparticle_complet/new_ntuples_14992862_542.root
Processing: /Users/ecuba/data/multiparticle_complet/new_ntuples_14992862_112.root
Processing: /Users/ecuba/data/multiparticle_complet/new_ntuples_14992862_2137.root
Processing: /Users/ecuba/data/multiparticle_complet/new_ntuples_14992862_10.root
Processing: /Users/ecuba/data/multiparticle_complet/new_ntuples_14992862_2567.root
Processing: /Users/ecuba/data/multiparticle_complet/new_ntuples_14992862_2588.root
Processing: /Users/ecuba/data/multiparticle_complet/new_ntuples_14992862_2072.root
Processing: /Users/ecuba/data/multiparticle_complet/new_ntuples_14992862_954.root
Processing: /Users/ecuba/data/multiparticle_complet/new_ntuples_14992862_2422.root


(tensor([ 4.8522e+01, -8.2125e+01,  3.6825e+02,  1.5845e+00,  3.9181e-01,
          3.6443e+01,  2.6332e+00,  2.1716e-01,  1.0894e-01, -1.3928e-01,
          9.8424e-01,  4.2209e+01, -8.8453e+01,  3.6038e+02,  2.7677e+00,
          2.7677e+00,  1.3071e+01,  1.9458e+00,  1.0452e+00,  3.4449e-01,
          8.5402e-02,  9.3490e-01,  9.4797e+00]),
 tensor(1.))

In [14]:
ds.x = torch.nn.functional.normalize(ds.x, p=torch.inf, dim=0)
ds.x

tensor([[ 2.2209e-01, -4.2084e-01,  7.1810e-01,  ...,  8.6727e-02,
          9.3490e-01,  9.4800e-01],
        [ 1.8108e-01, -4.2039e-01,  7.5715e-01,  ..., -2.0122e-01,
          9.8001e-01,  3.8386e-01],
        [ 2.5283e-01, -5.1010e-01,  7.9072e-01,  ..., -1.0279e-01,
          9.7088e-01,  8.0178e-01],
        ...,
        [ 1.7773e-02,  4.1718e-01,  8.2109e-01,  ...,  1.4606e-01,
          9.8944e-01,  9.7887e-01],
        [-9.7574e-03,  4.2095e-01,  7.4703e-01,  ..., -6.9531e-02,
          9.4099e-01,  3.7493e-01],
        [-6.8441e-04,  3.8074e-01,  6.9604e-01,  ...,  3.1176e-01,
          8.6479e-01,  5.7305e-01]])

In [15]:
print("dataset balance:", float(sum(ds.y) / len(ds.y))) 

dataset balance: 0.5


In [16]:
class PairWiseMLP(torch.nn.Module):
    def __init__(self, num_inputs, num_hidden=10):
        super(PairWiseMLP, self).__init__()

        self.W1 = nn.Linear(num_inputs, num_hidden)
        self.activation = nn.ReLU()
        self.W2 = nn.Linear(num_hidden, num_hidden)
        self.W3 = nn.Linear(num_hidden, 1)
        self.output = nn.Sigmoid()

    def forward(self, data):
        x = self.W1(data)
        x = self.activation(x)
        x = self.W2(x)
        x = self.activation(x)
        x = self.W3(x)
        return self.output(x)

In [17]:
loss_obj = torch.nn.BCELoss()

def train(model, opt, loader):
    epoch_loss = 0
    for batch, labels in loader:
        model.train()
        batch = batch.to(device)
        labels = labels.to(device)
        opt.zero_grad()
        z = model(batch).reshape(-1)
        loss = loss_obj(z, labels)
        epoch_loss += loss
        loss.backward()
        opt.step()
    return float(epoch_loss)

@torch.no_grad()
def test(model, data):
    total = 0
    correct = 0
    for batch, labels in data:
        model.eval()
        batch = batch.to(device)
        labels = labels.to(device)
        z = model(batch).reshape(-1)
        prediction = (z > 0.5).type(torch.int)
        total += len(prediction) 
        correct += sum(prediction == labels.type(torch.int))
    return (correct / total)

In [18]:
ds_size = len(ds)
test_set_size = ds_size // 10
train_set_size = ds_size - test_set_size
train_set, test_set = random_split(ds, [train_set_size, test_set_size])
print(f"Train samples: {len(train_set)}, Test samples: {len(test_set)}")

train_dl = DataLoader(train_set, batch_size=32, shuffle=True)
test_dl = DataLoader(test_set, batch_size=32, shuffle=True)

Train samples: 18845, Test samples: 2093


In [21]:
model = PairWiseMLP(ds.x.shape[1], 10)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=50, gamma=0.5)
test_acc = test(model, test_dl)
print(f"Initial acc: {test_acc:.4f}")

for epoch in range(1, 1001):
    loss = train(model, optimizer, train_dl)
    # scheduler.step()
    if epoch % 50 == 0:
        train_acc = test(model, train_dl)
        test_acc = test(model, test_dl)
        print(f'Epoch: {epoch}, loss: {loss:.4f}, train acc: {train_acc:.4f}, test acc: {test_acc:.4f}')

Initial acc: 0.5088
Epoch: 50, loss: 407.5300, train acc: 0.5147, test acc: 0.4864
Epoch: 100, loss: 406.4890, train acc: 0.5219, test acc: 0.4849
Epoch: 150, loss: 406.0963, train acc: 0.5296, test acc: 0.5021
Epoch: 200, loss: 405.3450, train acc: 0.5301, test acc: 0.4921
Epoch: 250, loss: 404.9729, train acc: 0.5356, test acc: 0.4950


In [51]:
pred = []
lab = []
for b, l in test_dl:
    pred += (model(b) > 0.5).type(torch.int).tolist()
    lab += l.tolist()

tn, fp, fn, tp = confusion_matrix(lab, pred).ravel()
print(f"TP: {tp}, TN: {tn}, FP: {fp}, FN: {fn}")
print(f'Accuracy: {accuracy_score(lab, pred):.4f}')
print(f'Precision: {precision_score(lab, pred):.4f}')
print(f'Recall: {recall_score(lab, pred):.4f}')

TP: 2, TN: 512, FP: 2, FN: 536
Accuracy: 0.4886
Precision: 0.5000
Recall: 0.0037
