In [276]:
import pandas as pd
import ast
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import dense_to_sparse
from torch_geometric.nn import GCNConv
import networkx as nx
from ast import literal_eval

In [329]:
# Load data
signal_df = pd.read_csv('Dijet_bb_pt10_15_dw.csv')
background_df = pd.read_csv('Dijet_qq_pt10_15_dw.csv')

In [330]:
# Distinguish signal and background
signal_df['IsB']=1
background_df['IsB']=0

In [331]:
# Separate Jet 0 and Jet 1 data
sig_jet0 = signal_df[signal_df.columns[signal_df.columns.str.contains("Jet0|IsB")]]
back_jet0 = background_df[background_df.columns[background_df.columns.str.contains("Jet0|IsB")]]

sig_jet1 = signal_df[signal_df.columns[signal_df.columns.str.contains("Jet1|IsB")]]
back_jet1 = background_df[background_df.columns[background_df.columns.str.contains("Jet1|IsB")]]

# Combine signal and background
train_df = pd.concat([sig_jet0, back_jet0])
test_df = pd.concat([sig_jet1, back_jet1])

In [332]:
train_df

Unnamed: 0,Jet0_ENDVERTEX_X,Jet0_ENDVERTEX_Y,Jet0_ENDVERTEX_Z,Jet0_ENDVERTEX_XERR,Jet0_ENDVERTEX_YERR,Jet0_ENDVERTEX_ZERR,Jet0_ENDVERTEX_CHI2,Jet0_ENDVERTEX_NDOF,Jet0_OWNPV_X,Jet0_OWNPV_Y,...,Jet0_Hlt1Phys_Dec,Jet0_Hlt1Phys_TIS,Jet0_Hlt1Phys_TOS,Jet0_Hlt2Global_Dec,Jet0_Hlt2Global_TIS,Jet0_Hlt2Global_TOS,Jet0_Hlt2Phys_Dec,Jet0_Hlt2Phys_TIS,Jet0_Hlt2Phys_TOS,IsB
0,0.8049,-0.1442,9.6248,0.0167,0.0164,0.1036,11.296007,29,0.805862,-0.145482,...,True,True,True,True,True,True,True,True,True,1
1,0.8049,-0.1442,9.6248,0.0167,0.0164,0.1036,11.296007,29,0.805862,-0.145482,...,True,True,True,True,True,True,True,True,True,1
2,0.8049,-0.1442,9.6248,0.0167,0.0164,0.1036,11.296007,29,0.805862,-0.145482,...,True,True,True,True,True,True,True,True,True,1
3,0.8049,-0.1442,9.6248,0.0167,0.0164,0.1036,11.296007,29,0.866242,-0.195181,...,True,True,True,True,True,True,True,True,True,1
4,0.8049,-0.1442,9.6248,0.0167,0.0164,0.1036,11.296007,29,0.866242,-0.195181,...,True,True,True,True,True,True,True,True,True,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99995,0.8785,-0.2188,-49.0015,0.0079,0.0079,0.0377,43.331200,115,0.878667,-0.219578,...,True,True,True,True,True,True,True,True,True,0
99996,0.8205,-0.2224,-1.1421,0.0301,0.0256,0.2178,3.674259,11,0.758976,-0.186395,...,False,False,False,True,False,False,True,False,False,0
99997,0.8411,-0.1376,-18.3968,0.0143,0.0141,0.0655,12.950173,27,0.848029,-0.184660,...,True,True,True,True,True,True,True,True,True,0
99998,0.8411,-0.1376,-18.3968,0.0143,0.0141,0.0655,12.950173,27,0.847748,-0.185888,...,True,True,True,True,True,True,True,True,False,0


In [452]:
train_df[train_df.columns[train_df.columns.str.contains("Daughters")]]

Unnamed: 0,Jet0_nDaughters,Jet0_Daughters_E,Jet0_Daughters_pT,Jet0_Daughters_ID,Jet0_Daughters_pX,Jet0_Daughters_pY,Jet0_Daughters_pZ,Jet0_Daughters_Eta,Jet0_Daughters_Phi,Jet0_Daughters_Q,...,Jet0_Daughters_trackX,Jet0_Daughters_trackY,Jet0_Daughters_trackZ,Jet0_Daughters_trackVX,Jet0_Daughters_trackVY,Jet0_Daughters_trackVZ,Jet0_Daughters_CaloNeutralEcal,Jet0_Daughters_CaloNeutralHcal2Ecal,Jet0_Daughters_CaloNeutralE49,Jet0_Daughters_CaloNeutralPrs
0,11,"[6641.2001953125, 13513.37109375, 25441.533203...","[226.46542358398438, 235.54550170898438, 526.2...","[22.0, 211.0, 211.0, 22.0, -211.0, -11.0, -211...","[77.6133041381836, 25.076343536376953, 82.6589...","[212.75045776367188, 234.20687866210938, 519.7...","[6637.33740234375, 13510.5966796875, 25435.707...","[4.071311950683594, 4.74254846572876, 4.571417...","[1.2209899425506592, 1.4641335010528564, 1.413...","[0.0, 1.0, 1.0, 0.0, -1.0, 1.0, -1.0, 0.0, 1.0...",...,"[-1000.0, 0.8226000070571899, 1.01010000705719...","[-1000.0, -0.07240000367164612, -0.14920000731...","[-1000.0, 11.181599617004395, -44.063899993896...","[-1000.0, 25.076343536376953, 82.6589279174804...","[-1000.0, 234.20687866210938, 519.701354980468...","[-1000.0, 13510.5966796875, 25435.70703125, -1...","[6934.40380859375, -1000.0, -1000.0, 13473.778...","[1.0405641794204712, -1000.0, -1000.0, 1.81851...","[0.9214292764663696, -1000.0, -1000.0, 0.86976...","[15.359217643737793, -1000.0, -1000.0, 105.046..."
1,11,"[6641.2001953125, 13513.37109375, 25441.533203...","[226.46542358398438, 235.54550170898438, 526.2...","[22.0, 211.0, 211.0, 22.0, -211.0, -11.0, -211...","[77.6133041381836, 25.076343536376953, 82.6589...","[212.75045776367188, 234.20687866210938, 519.7...","[6637.33740234375, 13510.5966796875, 25435.707...","[4.071311950683594, 4.74254846572876, 4.571417...","[1.2209899425506592, 1.4641335010528564, 1.413...","[0.0, 1.0, 1.0, 0.0, -1.0, 1.0, -1.0, 0.0, 1.0...",...,"[-1000.0, 0.8226000070571899, 1.01010000705719...","[-1000.0, -0.07240000367164612, -0.14920000731...","[-1000.0, 11.181599617004395, -44.063899993896...","[-1000.0, 25.076343536376953, 82.6589279174804...","[-1000.0, 234.20687866210938, 519.701354980468...","[-1000.0, 13510.5966796875, 25435.70703125, -1...","[6934.40380859375, -1000.0, -1000.0, 13473.778...","[1.0405641794204712, -1000.0, -1000.0, 1.81851...","[0.9214292764663696, -1000.0, -1000.0, 0.86976...","[15.359217643737793, -1000.0, -1000.0, 105.046..."
2,11,"[6641.2001953125, 13513.37109375, 25441.533203...","[226.46542358398438, 235.54550170898438, 526.2...","[22.0, 211.0, 211.0, 22.0, -211.0, -11.0, -211...","[77.6133041381836, 25.076343536376953, 82.6589...","[212.75045776367188, 234.20687866210938, 519.7...","[6637.33740234375, 13510.5966796875, 25435.707...","[4.071311950683594, 4.74254846572876, 4.571417...","[1.2209899425506592, 1.4641335010528564, 1.413...","[0.0, 1.0, 1.0, 0.0, -1.0, 1.0, -1.0, 0.0, 1.0...",...,"[-1000.0, 0.8226000070571899, 1.01010000705719...","[-1000.0, -0.07240000367164612, -0.14920000731...","[-1000.0, 11.181599617004395, -44.063899993896...","[-1000.0, 25.076343536376953, 82.6589279174804...","[-1000.0, 234.20687866210938, 519.701354980468...","[-1000.0, 13510.5966796875, 25435.70703125, -1...","[6934.40380859375, -1000.0, -1000.0, 13473.778...","[1.0405641794204712, -1000.0, -1000.0, 1.81851...","[0.9214292764663696, -1000.0, -1000.0, 0.86976...","[15.359217643737793, -1000.0, -1000.0, 105.046..."
3,18,"[15737.859375, 8034.22216796875, 13624.5898437...","[723.5347290039062, 376.4151916503906, 631.710...","[211.0, 211.0, 22.0, -211.0, 11.0, -11.0, -211...","[464.20758056640625, 174.55453491210938, 143.1...","[-554.9899291992188, -333.4952697753906, -615....","[15720.5986328125, 8024.18603515625, 13609.937...","[3.7722549438476562, 3.7532196044921875, 3.763...","[-0.8742361068725586, -1.0885971784591675, -1....","[1.0, 1.0, 0.0, -1.0, -1.0, 1.0, -1.0, -1.0, -...",...,"[0.0348999984562397, 0.5769000053405762, -1000...","[0.028599999845027924, 0.1696999967098236, -10...","[-37.759700775146484, -59.09410095214844, -100...","[464.20758056640625, 174.55453491210938, -1000...","[-554.9899291992188, -333.4952697753906, -1000...","[15720.5986328125, 8024.18603515625, -1000.0, ...","[-1000.0, -1000.0, 13334.91015625, -1000.0, -1...","[-1000.0, -1000.0, 0.08267438411712646, -1000....","[-1000.0, -1000.0, 0.9412699341773987, -1000.0...","[-1000.0, -1000.0, 36.75241470336914, -1000.0,..."
4,18,"[15737.859375, 8034.22216796875, 13624.5898437...","[723.5347290039062, 376.4151916503906, 631.710...","[211.0, 211.0, 22.0, -211.0, 11.0, -11.0, -211...","[464.20758056640625, 174.55453491210938, 143.1...","[-554.9899291992188, -333.4952697753906, -615....","[15720.5986328125, 8024.18603515625, 13609.937...","[3.7722549438476562, 3.7532196044921875, 3.763...","[-0.8742361068725586, -1.0885971784591675, -1....","[1.0, 1.0, 0.0, -1.0, -1.0, 1.0, -1.0, -1.0, -...",...,"[0.0348999984562397, 0.5769000053405762, -1000...","[0.028599999845027924, 0.1696999967098236, -10...","[-37.759700775146484, -59.09410095214844, -100...","[464.20758056640625, 174.55453491210938, -1000...","[-554.9899291992188, -333.4952697753906, -1000...","[15720.5986328125, 8024.18603515625, -1000.0, ...","[-1000.0, -1000.0, 13334.91015625, -1000.0, -1...","[-1000.0, -1000.0, 0.08267438411712646, -1000....","[-1000.0, -1000.0, 0.9412699341773987, -1000.0...","[-1000.0, -1000.0, 36.75241470336914, -1000.0,..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
99995,14,"[562.1334228515625, 252.59033203125, 177.37800...","[161.11512756347656, 73.49575805664062, 55.020...","[-22.0, -22.0, 22.0, 211.0, 211.0, 22.0, -11.0...","[-158.51673889160156, -71.37385559082031, -51....","[-28.81884765625, -17.53278350830078, -20.2006...","[538.5498046875, 241.66143798828125, 168.62904...","[1.9215672016143799, 1.9058173894882202, 1.838...","[-2.9617536067962646, -2.900714874267578, -2.7...","[0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, -1.0, 1.0,...",...,"[-1000.0, -1000.0, -1000.0, 0.4255999922752380...","[-1000.0, -1000.0, -1000.0, -0.229100003838539...","[-1000.0, -1000.0, -1000.0, -43.99909973144531...","[-1000.0, -1000.0, -1000.0, -107.1642837524414...","[-1000.0, -1000.0, -1000.0, -36.28998947143555...","[-1000.0, -1000.0, -1000.0, 1016.1304321289062...","[-1000.0, -1000.0, 177.3780059814453, -1000.0,...","[-1000.0, -1000.0, 1.0579288005828857, -1000.0...","[-1000.0, -1000.0, 1.096325397491455, -1000.0,...","[-1000.0, -1000.0, 1.3713587522506714, -1000.0..."
99996,24,"[883.2999877929688, 1688.60791015625, 3160.659...","[248.16073608398438, 321.23675537109375, 425.2...","[22.0, -211.0, 22.0, 211.0, -211.0, 310.0, -22...","[173.06446838378906, 193.11441040039062, 399.8...","[177.8551025390625, 256.709716796875, 144.7807...","[847.7235107421875, 1651.8848876953125, 3131.9...","[1.9423915147781372, 2.3399641513824463, 2.694...","[0.7990489602088928, 0.9258455038070679, 0.347...","[0.0, -1.0, 0.0, 1.0, -1.0, 0.0, 0.0, 0.0, -1....",...,"[-1000.0, 0.48339998722076416, -1000.0, 0.4483...","[-1000.0, -0.3709999918937683, -1000.0, -0.482...","[-1000.0, 0.250900000333786, -1000.0, -0.90630...","[-1000.0, 193.11441040039062, -1000.0, 1177.78...","[-1000.0, 256.709716796875, -1000.0, 1061.9003...","[-1000.0, 1651.8848876953125, -1000.0, 10423.3...","[746.2540283203125, -1000.0, 3245.3056640625, ...","[0.0, -1000.0, 0.0, -1000.0, -1000.0, -1000.0,...","[0.8760475516319275, -1000.0, 0.97900903224945...","[10.970870018005371, -1000.0, 0.0, -1000.0, -1..."
99997,24,"[2899.510009765625, 2247.7099609375, 5763.1201...","[620.09521484375, 347.87799072265625, 893.6690...","[22.0, 22.0, -211.0, 22.0, 22.0, 22.0, 22.0, 2...","[618.8517456054688, 344.859375, 887.1937866210...","[39.25105667114258, -45.728546142578125, -107....","[2832.426513671875, 2220.626220703125, 5691.69...","[2.223935604095459, 2.552919626235962, 2.55068...","[0.06334078311920166, -0.13183148205280304, -0...","[0.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",...,"[-1000.0, -1000.0, -0.0210999995470047, -1000....","[-1000.0, -1000.0, -0.09319999814033508, -1000...","[-1000.0, -1000.0, -35.232398986816406, -1000....","[-1000.0, -1000.0, 887.1937866210938, -1000.0,...","[-1000.0, -1000.0, -107.384765625, -1000.0, -1...","[-1000.0, -1000.0, 5691.6982421875, -1000.0, -...","[-1000.0, 2088.239990234375, -1000.0, 794.9787...","[-1000.0, 0.0, -1000.0, 0.0, 0.0, 0.0, 0.0, 3....","[-1000.0, 0.9859955310821533, -1000.0, 0.88109...","[-1000.0, 22.490283966064453, -1000.0, 22.4902..."
99998,7,"[17476.078125, 14764.8095703125, 172354.03125,...","[417.58673095703125, 457.4699401855469, 3613.6...","[-211.0, -211.0, 211.0, -22.0, 22.0, -211.0, -...","[-11.1947660446167, -105.61746215820312, -2159...","[417.4366455078125, 445.1108703613281, 2897.67...","[17470.529296875, 14757.060546875, 172316.0781...","[4.42706823348999, 4.1671528816223145, 4.55786...","[1.5976077318191528, 1.8037711381912231, 2.211...","[-1.0, -1.0, 1.0, 0.0, 0.0, -1.0, -1.0]",...,"[0.7827000021934509, 0.8055999875068665, 0.434...","[0.01209999993443489, 0.13120000064373016, 0.3...","[-22.350900650024414, -19.80270004272461, 3.52...","[-11.1947660446167, -105.61746215820312, -2159...","[417.4366455078125, 445.1108703613281, 2897.67...","[17470.529296875, 14757.060546875, 172316.0781...","[-1000.0, -1000.0, -1000.0, -1000.0, 14106.536...","[-1000.0, -1000.0, -1000.0, -1000.0, 10.835189...","[-1000.0, -1000.0, -1000.0, -1000.0, 0.8827957...","[-1000.0, -1000.0, -1000.0, -1000.0, 2.1941740..."


In [333]:
for x in train_df['Jet0_Daughters_ID'][0]:
    print(x)

[22.0, 211.0, 211.0, 22.0, -211.0, -11.0, -211.0, -22.0, 211.0, 211.0, -211.0]
[211.0, 22.0, 22.0, 310.0, 321.0, 22.0, 22.0, 321.0, 321.0, -22.0, 22.0, 211.0, 22.0, -211.0]


In [350]:
def convert(df):
    for col in df.columns:
        for i in df.index:
            cell_value = df.at[i, col]
            if isinstance(cell_value, str) and cell_value.startswith('['):
                cell_value = cell_value[1:]
                cell_value = cell_value.split(']')[0]
                list = [float(num) for num in cell_value.strip('[]').split(',')]
                df.at[i, col] = list
    return df

In [426]:
mylist = [[1,2,3,4,5]]*len(train_df)
train_df['list'] = mylist

In [427]:
train_df['list']

0        [1, 2, 3, 4, 5]
1        [1, 2, 3, 4, 5]
2        [1, 2, 3, 4, 5]
3        [1, 2, 3, 4, 5]
4        [1, 2, 3, 4, 5]
              ...       
99995    [1, 2, 3, 4, 5]
99996    [1, 2, 3, 4, 5]
99997    [1, 2, 3, 4, 5]
99998    [1, 2, 3, 4, 5]
99999    [1, 2, 3, 4, 5]
Name: list, Length: 200000, dtype: object

In [444]:
def process_cell(cell_value):
    if isinstance(cell_value, str) and cell_value.startswith('['):
        cell_value = cell_value[1:]
        cell_value = cell_value.strip('[]').split(',')
    return cell_value

def rewrite_data(df, output_file):
    df = df.apply(process_cell)
    df.to_csv(output_file, index=False)

In [445]:
rewrite_data(train_df, 'train_data.csv')
df = pd.read_csv('train_data.csv')

In [447]:
df.Jet0_Daughters_ID[0]

'[22.0, 211.0, 211.0, 22.0, -211.0, -11.0, -211.0, -22.0, 211.0, 211.0, -211.0]'

In [450]:
df.Jet0_Daughters_ID[0][0]

'['

In [454]:
df.Jet0_Daughters_ID[0].strip('[]').split(',')

['22.0',
 ' 211.0',
 ' 211.0',
 ' 22.0',
 ' -211.0',
 ' -11.0',
 ' -211.0',
 ' -22.0',
 ' 211.0',
 ' 211.0',
 ' -211.0']

In [455]:
df.Jet0_Daughters_ID[0].strip('[]').split(',')[0]

'22.0'

In [None]:
df2 = rewrite_data(test_df, 'train_data.csv')

In [431]:
df2 = pd.read_csv('test_data.csv')

In [457]:
df2.Jet1_Daughters_ID[0]

'[211.0, 211.0, 22.0, -211.0, 11.0, -11.0, -211.0, -211.0, -211.0, -211.0, 211.0, 321.0, 3122.0, 211.0, 211.0, 22.0, 22.0, -211.0]'

In [None]:
def load_data(df, nodes, features_cols):
    data_list=[]
    
    for i, row in df.iterrows():
        num_nodes = row[nodes]
        features = [row[col] for col in features_cols]
        num_features = len(features)
        data.x = torch.Tensor([num_nodes, num_features])
        adj = torch.ones((num_nodes, num_nodes))
        edge_index = dense_to_sparse(adj)[0]
        y = torch.ones(num_nodes, dtype=torch.long)
        train_mask = torch.ones(num_nodes, dtype=torch.bool)
        test_mask = torch.ones(num_nodes, dtype=torch.bool)
        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, test_mask=test_mask, num_features=num_features, num_nodes=num_nodes)
        data_list.append(data)
    return data_list

In [88]:
def load_data(df, features_cols):
    data_list = []
    
    for i, row in df.iterrows():
        all_features = []

        for col in features_cols:
            feature_values_str = row[col]
            feature_values = eval(feature_values_str) 
            
            features = [float(value) for value in feature_values]
            all_features.extend(features) 
            
        x = torch.tensor(all_features, dtype=torch.float32)
        features = torch.tensor(all_features, dtype=torch.int64)
        
        num_features = len(features_cols)
        num_nodes = len(all_features) // num_features
        
        adj = torch.ones((num_nodes, num_nodes))
        edge_index = dense_to_sparse(adj)[0]
        
        y = torch.ones(num_nodes, dtype=torch.long)
        train_mask = torch.ones(num_nodes, dtype=torch.bool)
        test_mask = torch.ones(num_nodes, dtype=torch.bool)
        
        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, test_mask=test_mask, num_features=num_features, num_nodes=num_nodes)
        data.num_features = num_features
        data.num_nodes = num_nodes
        data_list.append(data)
    
    return data_list

In [None]:
train_data = load_data(df, ['Jet0_Daughters_Eta', 'Jet0_Daughters_Phi', 'Jet0_Daughters_pT'])
test_data = load_data(df2, ['Jet1_Daughters_Eta', 'Jet1_Daughters_Phi', 'Jet1_Daughters_pT'])

In [119]:
def graph(df, features_cols):
    G = nx.Graph()

    for i, row in df.iterrows():
        nodes = []
        for col in features_cols:
            nodes.extend(ast.literal_eval(row[col]))
        nodes = list(set(nodes))
        for node in nodes:
            G.add_node(node)
    
    node_features_list = []

    for node in G.nodes():
        node_features = []
        for col in features_cols:
            feature_value = df.loc[df[node_col].apply(lambda x: node in ast.literal_eval(x)), col].values[0]
            node_features.append(feature_value)
        node_features_list.append(np.array(node_features).astype(np.float32))

    node_features_np = np.array(node_features_list)
    node_features_tensor = torch.tensor(node_features_np)

    for i, node in enumerate(G.nodes()): 
        G.nodes[node]['x'] = node_features_tensor[i]
    
    data = from_networkx(G)

    data.y = torch.ones(data.num_nodes, dtype=torch.long)
    data.train_mask = torch.ones(data.num_nodes, dtype=torch.bool)
    data.test_mask = torch.ones(data.num_nodes, dtype=torch.bool)

    return data

In [120]:
train_data = graph(train_df, ['Jet0_Daughters_Eta', 'Jet0_Daughters_Phi', 'Jet0_Daughters_pT'])
test_data = graph(test_df, ['Jet1_Daughters_Eta', 'Jet1_Daughters_Phi', 'Jet1_Daughters_pT'])

KeyboardInterrupt: 

In [58]:
train_data

[Data(x=[0], edge_index=[2, 1], y=1, train_mask=True, test_mask=True, num_features=0, num_nodes=1),
 Data(x=[0], edge_index=[2, 1], y=1, train_mask=True, test_mask=True, num_features=0, num_nodes=1),
 Data(x=[0], edge_index=[2, 1], y=1, train_mask=True, test_mask=True, num_features=0, num_nodes=1),
 Data(x=[0], edge_index=[2, 1], y=1, train_mask=True, test_mask=True, num_features=0, num_nodes=1),
 Data(x=[0], edge_index=[2, 1], y=1, train_mask=True, test_mask=True, num_features=0, num_nodes=1),
 Data(x=[0], edge_index=[2, 1], y=1, train_mask=True, test_mask=True, num_features=0, num_nodes=1),
 Data(x=[0], edge_index=[2, 1], y=1, train_mask=True, test_mask=True, num_features=0, num_nodes=1),
 Data(x=[0], edge_index=[2, 1], y=1, train_mask=True, test_mask=True, num_features=0, num_nodes=1),
 Data(x=[0], edge_index=[2, 1], y=1, train_mask=True, test_mask=True, num_features=0, num_nodes=1),
 Data(x=[0], edge_index=[2, 1], y=1, train_mask=True, test_mask=True, num_features=0, num_nodes=1),


In [136]:
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32, shuffle=True)



In [None]:
# Define GNN model
class GNN(torch.nn.Module):
    def __init__(self):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(3, 64)
        self.conv2 = GCNConv(64, 2)
    
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

model = GNN()

In [None]:
# Training parameters
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop
def train(model, data, optimizer, criterion):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

for epoch in range(1):
    loss = train(model, train_data, optimizer, criterion)
    print(f'Epoch {epoch+1}, Loss: {loss}')

# Evaluation
def test(model, data):
    model.eval()
    _, pred = model(data).max(dim=1)
    correct = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()
    acc = correct / data.test_mask.sum().item()
    return acc

accuracy = test(model, test_data)
print(f'Accuracy: {accuracy}')

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got -2)