Week 8 Notebook: Extending the Model
===============================================================

This week, we will look at graph neural networks using the PyTorch Geometric library: <https://pytorch-geometric.readthedocs.io/>. See {cite}`PyTorchGeometric` for more details.

In [1]:
import torch_geometric

In [2]:
# Track-level input features (one entry per input column).
# Each name must appear exactly once: a repeated name is presumably loaded by
# GraphDataset as an identical extra column, which only inflates `nfeatures`.
# The count is derived below rather than hard-coded in this comment.
features = ['track_pt',
            'track_ptrel',
            'trackBTag_Eta',
            'trackBTag_DeltaR',
            'trackBTag_EtaRel',
            'trackBTag_JetDistVal',
            'trackBTag_Momentum',
            'trackBTag_PPar',
            'trackBTag_PParRatio',
            'trackBTag_PtRatio',
            'trackBTag_PtRel',
            'trackBTag_Sip2dSig',
            'trackBTag_Sip2dVal',
            'trackBTag_Sip3dSig',
            'trackBTag_Sip3dVal',
            'track_VTX_ass',
            'track_charge',
            'track_deltaR',
            'track_detadeta',
            'track_dlambdadz',
            'track_dphidphi',
            'track_dphidxy',
            'track_dptdpt',
            'track_drminsv',
            'track_drsubjet1',
            'track_drsubjet2',
            'track_dxy',
            'track_dxydxy',
            'track_dxydz',
            'track_dxysig',
            'track_dz',
            'track_dzdz',
            'track_dzsig',
            'track_erel',
            'track_etarel',
            'track_fromPV',
            'track_isChargedHad',
            'track_isEl',
            'track_isMu',
            'track_lostInnerHits',
            'track_mass',
            'track_normchi2',
            'track_phirel',
            'track_puppiw',
            'track_quality']

# Guard against re-introducing duplicate feature names.
assert len(features) == len(set(features)), "duplicate entries in `features`"

# spectators to define mass/pT window
spectators = ['fj_sdmass',
              'fj_pt']

# Truth labels; these are later reduced to 2 classes: QCD or Hbb
labels =  ['label_QCD_b',
           'label_QCD_bb',
           'label_QCD_c',
           'label_QCD_cc',
           'label_QCD_others',
           'sample_isQCD',
           'label_H_bb']

nfeatures = len(features)
nspectators = len(spectators)

In [3]:
from GraphDataset import GraphDataset

In [4]:
# Build the graph dataset rooted at 'data' from the feature/label/spectator
# lists defined above; n_events and n_events_merge are passed straight through
# to GraphDataset (their exact meaning is defined in that module).
# NOTE(review): the saved run of this cell ends with
#   ValueError: need at least one array to stack
# immediately after a `0` is printed, which looks like an entry with no tracks
# reaching an array-stacking step inside GraphDataset — TODO confirm and skip or
# pad such entries there.
graph_dataset = GraphDataset(
    'data',
    features,
    labels,
    spectators,
    n_events=10000,
    n_events_merge=1000,
)

Processing...


  0%|          | 15/10000 [00:00<01:09, 144.56it/s]

21
35
24
32
28
24
25
32
49
31
40
42
16
34
25
42
25
21
32
39
12
31
31
29
15
18
37
16
17
15
16
32
32
26
19


  1%|          | 54/10000 [00:00<01:00, 164.06it/s]

10
19
10
27
30
27
26
37
11
29
11
14
43
39
27
42
41
31
27
12
21
31
40
16
15
39
25
24
33
25
29
31
36
12
40
23
42


  1%|          | 72/10000 [00:00<00:59, 166.51it/s]

15
30
22
25
18
20
45
36
18
84


  1%|          | 102/10000 [00:00<01:26, 114.17it/s]

9
9
23
32
25
28
21
27
21
27
22
49
31
35
29
24
21
24
57
16
62
21
37
17
31
35
22
55
42
35


  1%|▏         | 129/10000 [00:00<01:22, 119.12it/s]

25
51
9
26
32
36
22
33
21
13
75
33
15
26
12
27
41
13
31
14
32
24
15
21
24
23
19
26
42
21
32
27
45


  2%|▏         | 162/10000 [00:01<01:17, 126.20it/s]

12
15
70
49
35
18
31
19
28
18
29
73
22
33
55
35
19
11
54
17
22
39
48
28


  2%|▏         | 195/10000 [00:01<01:11, 137.02it/s]

29
23
32
16
18
12
20
32
21
26
28
15
11
66
32
21
15
49
38
38
30
22
33
26
27
38
17
22
12
25
26
21
25
38
22
29


  2%|▏         | 229/10000 [00:01<01:04, 151.47it/s]

54
20
30
60
19
22
35
12
9
26
34
22
29
31
7
30
34
32
38
32
18
16
32
13
38
29
18
37
27
16
21
81


  3%|▎         | 264/10000 [00:01<01:02, 155.83it/s]

26
39
25
27
23
18
26
39
22
23
21
52
16
9
25
24
19
30
54
22
22
12
29
23
18
23
23
33
23
29
11
23
20
63
28
41


  3%|▎         | 297/10000 [00:02<01:03, 153.17it/s]

17
55
19
32
25
30
23
19
43
18
30
22
22
35
26
29
28
20
48
17
4
51
33
26
20
11
36
51
23
21
26
33


  3%|▎         | 316/10000 [00:02<01:09, 139.32it/s]

23
9
21
16
42
19
5
23
14
33
118
26
14
16
19
27
26
37
23
40
72
17
30


  3%|▎         | 347/10000 [00:02<01:06, 144.25it/s]

34
21
9
39
31
35
24
26
16
32
22
17
37
4
56
42
13
17
20
39
57
21
43
35
7
68
24
25
28


  4%|▍         | 377/10000 [00:02<01:12, 132.86it/s]

19
25
21
34
27
18
26
13
46
28
11
7
24
25
43
26
41
6
17
91
46
27
19
41
25
16
15
33
11
27


  4%|▍         | 417/10000 [00:02<01:02, 153.49it/s]

30
25
32
36
26
20
12
73
18
34
30
20
43
18
28
20
5
25
31
20
24
39
25
11
25
24
36
34
15
26
15
23
53
31
17
21
43


  5%|▍         | 458/10000 [00:03<00:55, 173.26it/s]

23
18
33
26
36
25
28
16
18
30
27
27
31
9
37
32
27
16
22
75
26
10
28
32
21
18
28
17
15
7
27
15
14
28
15
17
37
22
31
43
15
12
24


  5%|▌         | 507/10000 [00:03<00:47, 200.74it/s]

10
17
19
17
26
24
17
14
32
29
21
21
24
20
27
25
37
22
8
26
33
10
31
26
21
37
17
28
20
18
20
10
25
36
30
37
30
18
39
15
13
6
32
27
35
28
30
23
18


  5%|▌         | 548/10000 [00:03<00:50, 188.50it/s]

56
34
19
32
21
30
43
32
27
24
36
36
11
7
43
23
24
35
22
23
23
18
33
21
19
17
69
38
7
20
16
46
22
26
35


  6%|▌         | 594/10000 [00:03<00:47, 196.30it/s]

12
22
17
20
29
7
26
23
27
11
23
19
38
15
39
28
12
28
32
14
20
33
25
12
23
41
57
37
29
15
25
23
17
46
23
33
26
23
17
27
20
33
23


  6%|▌         | 615/10000 [00:03<00:51, 181.29it/s]

19
16
24
32
33
18
57
29
22
18
32
30
39
14
30
47
14
38
21
52
46
26
50
27
25
19
17
30
61
20
55


  7%|▋         | 652/10000 [00:04<00:56, 165.07it/s]

25
18
46
20
38
17
14
30
16
15
25
12
49
28
25
46
46
32
23
35
27
20
38
19
30
19
56
66
27
23
22
26


  7%|▋         | 669/10000 [00:04<01:06, 139.94it/s]

23
22
8
32
58
23
26
14
25
43
27
104
23
10
32
18
41
31
26
38
32
29
20
22
25
31
46


  7%|▋         | 708/10000 [00:04<00:57, 161.51it/s]

21
27
14
16
35
28
24
34
24
27
10
45
23
8
45
28
37
10
18
10
35
11
17
56
14
25
21
60
17
51
25
19
25
19
28
30
28
25
16


  7%|▋         | 748/10000 [00:04<00:52, 177.56it/s]

24
44
13
21
22
16
26
47
30
22
24
38
44
25
25
30
16
19
28
21
32
16
27
9
28
23
23
80
32
28
21
16
19
14
39
26
17
12


  8%|▊         | 787/10000 [00:04<00:51, 180.53it/s]

19
8
17
27
14
46
13
32
52
15
34
12
14
21
29
20
16
9
25
29
45
13
26
14
44
35
22
37
20
29
44
28
23
15
21
22
24
29
19
29
32


  8%|▊         | 827/10000 [00:05<00:50, 181.40it/s]

13
22
29
8
50
41
32
32
32
41
19
10
37
30
31
26
33
14
24
37
26
24
54
30
9
10
36
17
22
27
23
23
11
21
15
22
23
40
16
31


  9%|▊         | 872/10000 [00:05<00:47, 191.75it/s]

32
30
11
30
16
30
31
24
19
36
46
25
24
33
31
28
20
16
17
19
46
16
25
22
13
14
70
17
23
6
28
20
13
14
25
15
61
26
12
33


  9%|▉         | 911/10000 [00:05<00:53, 170.81it/s]

26
19
19
23
20
40
50
60
12
12
38
55
44
27
48
22
6
10
34
22
19
45
17
19
32
34
29
54
16
26
15
37
44
13
55
8
14
37
21
47
16
11
11
13
16
34


  9%|▉         | 948/10000 [00:05<01:05, 138.19it/s]

25
21
35
15
42
27
17
37
18
28
15
25
14
15
19
36
8
67
26
31
19
58
24
21
22
15
14
10
33
29
66
19
26
23
21
21
38


 10%|▉         | 986/10000 [00:06<00:57, 156.81it/s]

48
35
12
30
34
24
26
35
22
25
18
19
40
30
34
24
14
25
38
35
23
20
16
22
29
43
11
22
50
21
17
28
30
24
22


 10%|█         | 1024/10000 [00:06<01:11, 125.11it/s]

26
18
13
20
27
27
14
18
26
48
30
19
32
25
14
37
37
24
23
19
8
50
14
50
41
17
16
24
15
31
36
27
17
33
22
16
77
56


 11%|█         | 1058/10000 [00:06<01:03, 140.39it/s]

11
39
22
25
23
31
31
21
18
46
37
28
12
37
39
34
30
30
26
34
15
35
39
31
15
17
19
25
14
19
12
26
27
21
35
34
36
33
39


 11%|█         | 1100/10000 [00:06<00:52, 168.52it/s]

30
24
13
43
19
23
26
22
25
26
16
20
26
38
22
63
21
12
6
10
19
30
36
19
67
12
31
76
26
32
37
19
34


 11%|█▏        | 1143/10000 [00:07<00:50, 176.82it/s]

19
28
9
19
18
31
17
36
4
16
13
50
23
7
23
36
20
23
25
20
9
19
23
33
21
20
15
26
26
29
21
29
57
38
34
36
26
16
15
12
30
50
19
16
22


 12%|█▏        | 1189/10000 [00:07<00:44, 199.81it/s]

32
21
12
16
18
48
15
24
21
44
41
11
15
18
23
28
39
12
19
12
21
20
35
14
17
20
18
30
21
16
31
27
12
18
14
28
41
35
28
37
37
35
19
22
29
33


 12%|█▏        | 1230/10000 [00:07<00:45, 190.97it/s]

14
24
20
24
48
32
26
25
50
22
19
21
21
60
27
40
14
25
22
43
14
28
35
23
22
30
20
17
24
26
26
28
37
21
23
22
40
27


 13%|█▎        | 1271/10000 [00:07<00:48, 178.27it/s]

19
25
13
52
32
15
24
19
24
12
37
42
18
21
27
40
42
27
62
22
17
25
17
25
12
13
50
26
18
33
27
31
23
51


 13%|█▎        | 1290/10000 [00:08<00:54, 159.99it/s]

25
33
35
23
37
45
31
19
45
37
35
21
31
11
57
10
33
18
56
20
52
45
24
25
12
24
30
51


 13%|█▎        | 1325/10000 [00:08<00:58, 148.18it/s]

30
21
66
21
5
41
15
24
24
25
29
30
11
40
20
34
48
25
17
11
10
34
6
62
24
33
20
32
21
19
18
25


 14%|█▎        | 1357/10000 [00:08<00:57, 149.11it/s]

37
45
25
49
25
21
32
14
36
22
27
32
39
8
23
27
24
22
26
15
17
39
16
69
35
37
17
19
18
21
19
38


 14%|█▍        | 1398/10000 [00:08<00:50, 171.88it/s]

19
17
22
24
32
24
12
12
23
29
12
40
24
14
37
23
53
11
32
13
40
23
7
21
33
19
7
15
26
24
10
24
29
6
14
15
13
15
15
36
27
23
22
51


 14%|█▍        | 1437/10000 [00:08<00:52, 162.12it/s]

29
23
16
23
22
17
34
28
38
52
15
26
20
33
27
26
18
29
10
23
31
33
30
25
67
48
14
37
44
25


 15%|█▍        | 1454/10000 [00:09<00:56, 151.74it/s]

17
32
32
31
49
55
16
35
35
38
23
27
23
19
31
17
13
32
10
39
6
13
41
14
34
44
29
18
41
18
64


 15%|█▍        | 1489/10000 [00:09<00:55, 154.52it/s]

21
10
39
14
30
22
50
30
23
9
30
25
24
29
14
20
43
16
51
24
23
46
17
31
25
23
11
18
30
30
22
48
46


 15%|█▌        | 1522/10000 [00:09<00:54, 156.11it/s]

12
19
20
17
23
24
9
21
25
30
33
22
18
25
22
42
33
72
24
16
32
44
23
30
22
21
11
16
20
11
30
25
18
16
52
29


 16%|█▌        | 1562/10000 [00:09<00:50, 167.15it/s]

28
30
25
23
8
23
8
38
19
29
25
15
19
28
15
41
30
12
26
18
15
30
69
40
23
7
36
30
34
32
27
41
18
8
24
26


 16%|█▌        | 1596/10000 [00:09<00:51, 161.89it/s]

15
43
21
27
17
22
47
22
54
36
30
22
19
33
28
13
21
30
13
32
36
15
40
8
26
26
38
13
35
22
24
24
27
10


 16%|█▋        | 1634/10000 [00:10<00:52, 158.89it/s]

35
43
6
14
18
8
38
23
15
16
27
31
16
14
22
33
28
50
25
66
17
60
22
29
36
28
28
30
32
19
22
62
23
93
22
17
30
17
73


 17%|█▋        | 1670/10000 [00:10<01:14, 111.27it/s]

10
32
27
18
27
36
35
29
33
29
19
17
12
24
21
20
12
25
20
15
53
32
21
27
20
39
21
28
28
21
38
25
19
26
28
42


 17%|█▋        | 1703/10000 [00:10<01:05, 126.33it/s]

24
23
21
34
28
19
29
35
28
16
16
74
41
16
25
34
22
39
12
28
28
23
34
17
39
27
25
42
23
8
31


 17%|█▋        | 1741/10000 [00:11<00:56, 145.60it/s]

14
38
16
39
24
25
16
26
17
25
8
26
7
30
18
13
18
37
25
31
12
29
33
6
18
77
38
19
23
32
26
35
43


 18%|█▊        | 1772/10000 [00:11<01:00, 135.07it/s]

26
45
17
33
14
21
25
91
17
33
25
15
33
40
36
22
12
41
17
26
42
26
42
41
24
21


 18%|█▊        | 1791/10000 [00:11<00:57, 142.33it/s]

23
27
44
15
35
24
33
21
11
42
4
10
37
25
21
21
49
37
43
38
26
26
32
12
27
24
23
20
21
13
41
27


 18%|█▊        | 1824/10000 [00:11<00:55, 146.14it/s]

60
15
11
28
33
36
30
30
53
22
33
18
25
40
11
14
10
37
33
29
43
5
32
30
42
14
28
20
39
27
33


 19%|█▊        | 1857/10000 [00:11<00:57, 140.76it/s]

7
20
29
28
31
25
21
14
15
13
65
28
24
30
40
59
43
24
33
35
57
22
14
31
31
21
13
29
31


 19%|█▉        | 1887/10000 [00:12<01:02, 129.38it/s]

17
55
21
39
18
25
60
47
21
53
19
28
23
26
5
23
58
39
42
15
33
24
32


 19%|█▉        | 1915/10000 [00:12<01:02, 128.71it/s]

69
13
32
26
71
34
23
25
13
33
31
26
43
21
24
12
12
19
15
50
45
21
59
28
57
21


 20%|█▉        | 1952/10000 [00:12<00:53, 149.33it/s]

18
35
52
17
19
27
15
29
35
16
37
31
31
21
11
8
19
26
55
53
34
31
8
22
23
23
19
35
34
35
29
12
32
21
32
23
33


 20%|█▉        | 1986/10000 [00:12<00:50, 158.92it/s]

21
18
30
9
21
38
23
16
33
30
60
52
27
7
22
48
34
40
34
20
36
17
10
17
21
32
17
29
26
29
53
24
24
36
4
65
21
17
40
12
53
43
31
54
25
19
14
35


 20%|██        | 2026/10000 [00:13<01:05, 122.38it/s]

17
16
5
24
35
18
49
36
21
27
28
18
20
29
28
21
23
28
26
18
12
20
26
23
27
24
27
23
20
14
26
23
16
14
17
9
51
26
14
46
33
15
15
48
31
40


 21%|██        | 2064/10000 [00:13<00:55, 143.79it/s]

36
39
11
25
22
24
39
63
11
32
23
16
37
29
50
32
23
26
25
27
29
27
31
27
43
37
4
59
23
24
13
42
21
10
27


 21%|██        | 2099/10000 [00:13<00:53, 148.59it/s]

29
25
52
18
21
29
76
54
41
14
18
30
36
19
34
17
25
12
33
11
19
20
48
20
25
35
16
37
28
21
34
14
68


 21%|██▏       | 2135/10000 [00:13<00:49, 159.46it/s]

22
34
23
12
24
66
16
15
30
19
25
22
43
43
38
29
11
20
23
24
38
36
21
40
21
18
18
29
52
34
7
38
24
30
25
19
22


 22%|██▏       | 2175/10000 [00:14<00:44, 175.63it/s]

16
31
35
12
18
15
35
42
16
20
28
15
31
13
42
26
11
31
30
30
46
27
33
27
13
16
12
41
34
27
31
32
92


 22%|██▏       | 2217/10000 [00:14<00:42, 181.06it/s]

19
23
21
26
20
20
28
32
17
38
29
11
33
29
28
19
34
27
36
33
9
18
22
24
32
33
10
15
27
30
15
22
39
13
23
30
17
24
40
19
28
29
29
11
31
38


 23%|██▎       | 2261/10000 [00:14<00:39, 198.23it/s]

34
34
48
14
19
17
33
29
16
23
31
26
9
23
26
28
8
26
22
20
16
24
12
16
24
29
56
14
35
22
14
27
17
23
16
16
23
14
17
32
13
34
39
32
14
26
25


 23%|██▎       | 2304/10000 [00:14<00:38, 197.48it/s]

44
29
35
13
38
26
27
34
25
22
22
17
9
31
32
17
11
23
22
16
30
28
43
13
27
63
40
35
29
26
22
28
19
21
22
23
28
52
35


 23%|██▎       | 2345/10000 [00:14<00:42, 182.04it/s]

27
34
16
16
16
32
12
54
30
31
52
19
40
31
35
19
37
20
22
28
40
24
24
13
36
40
14
24
31
26
35
20
10
42
25


 24%|██▍       | 2388/10000 [00:15<00:39, 191.41it/s]

19
19
21
13
41
35
36
17
11
35
18
25
12
12
21
9
37
22
17
15
28
30
34
17
15
14
23
18
32
5
11
41
36
11
83
16
18
15
11
6
25
40
16
35


 24%|██▍       | 2413/10000 [00:15<00:36, 205.28it/s]

31
20
26
24
34
22
22
34
33
26
16
31
13
20
14
17
28
19
31
15
55
10
17
70
35
37
48
32
38
46
42
28
12
19
17
21


 25%|██▍       | 2455/10000 [00:15<00:40, 184.37it/s]

10
28
35
5
35
12
22
24
37
13
30
21
35
13
46
26
41
11
15
36
13
26
33
44
77
20
15
24
23
16
31
33
44
39


 25%|██▍       | 2498/10000 [00:15<00:38, 192.79it/s]

41
15
19
23
27
21
8
19
27
22
15
17
34
11
40
31
30
35
32
12
5
13
21
11
18
45
18
20
32
15
29
22
21
21
26
10
31
22
35
28
25
9
29
26
14
35
15
17
18
17
32
20


 25%|██▌       | 2545/10000 [00:15<00:38, 193.12it/s]

45
23
11
21
21
12
48
23
24
13
24
29
36
23
41
18
43
24
28
34
31
37
44
17
29
27
20
42
32
27
68


 26%|██▌       | 2565/10000 [00:16<00:58, 126.59it/s]

21
68
35
23
14
15
19
29
7
24
8
34
20
58
44
31
20
34
14
16
29
10
20
11
26
28
14
30
23
45
28
18
7
42
14
20
46
16
40


 26%|██▌       | 2605/10000 [00:16<00:50, 147.44it/s]

19
17
29
19
37
38
42
24
20
28
42
31
18
33
59
47
10
32
16
37
26
32
44
35
27
36
45
21
23
11
38
24
29


 26%|██▋       | 2646/10000 [00:16<00:44, 165.64it/s]

30
24
43
17
52
14
23
23
23
25
15
28
24
21
19
12
21
34
28
12
23
13
31
25
40
40
22
56
36
35
25
22
20
27
22
45
34


 27%|██▋       | 2683/10000 [00:16<00:44, 163.09it/s]

46
31
10
27
20
11
18
31
39
37
43
28
32
14
26
36
37
31
35
21
26
41
16
32
24
38
14
28
29
33
25
70
35


 27%|██▋       | 2720/10000 [00:17<00:45, 161.75it/s]

60
9
28
11
20
28
22
33
46
62
14
30
28
15
6
55
14
17
21
26
25
12
37
22
15
15
9
51
19
31
23
16
37
20
26
14
22
33


 28%|██▊       | 2757/10000 [00:17<00:44, 161.92it/s]

21
19
27
13
33
39
27
20
15
13
47
11
15
27
14
28
34
8
44
32
31
23
46
14
71
33
22
18
10
55


 28%|██▊       | 2774/10000 [00:17<00:44, 160.70it/s]

12
32
36
23
14
58
23
31
14
20
19
19
19
16
33
18
24
21
27
28
40
41
29
24
68
45
19
27
14
51


 28%|██▊       | 2806/10000 [00:17<00:50, 142.10it/s]

11
23
58
28
18
33
22
32
63
19
35
36
19
23
18
40
26
29
26
37
14
14
31
16
27
17
19
47
14
22
18
21


 28%|██▊       | 2828/10000 [00:17<00:45, 158.46it/s]


41
24
22
10
24
6
35
19
0


ValueError: need at least one array to stack