## Load data

In [None]:
import sys

# Prepend the directory two levels above the *current working directory* to the
# import path so project-local packages (e.g. `dataAnalysis`) can be imported.
# NOTE(review): this is relative to the CWD, not to the notebook file itself.
sys.path[:0] = ["../../"]

In [None]:
from dataAnalysis.DataAnalysis import DataAnalysis
import pandas as pd

# Read the dataset CSV; row 0 supplies the column names.
data = pd.read_csv(
    r"../../extdata/sbcdata.csv",
    header=0,
)
data_analysis = DataAnalysis(data, None)

In [None]:
import torch

# Labels become int64 class indices (torch.long); features become float32.
y_train = torch.tensor(data_analysis.get_y_train(), dtype=torch.int64)
X_train = torch.tensor(data_analysis.get_X_train(), dtype=torch.float32)

y_test = torch.tensor(data_analysis.get_y_test(), dtype=torch.int64)
X_test = torch.tensor(data_analysis.get_X_test(), dtype=torch.float32)

# The "gw" split is used further down purely as an additional test set.
y_gw_test = torch.tensor(data_analysis.get_y_gw(), dtype=torch.int64)
X_gw_test = torch.tensor(data_analysis.get_X_gw(), dtype=torch.float32)

## Normalize and concatenate data

In [None]:
def normalize(tensor, mean=None, std=None):
    """Standardize ``tensor`` column-wise: ``(tensor - mean) / std`` over dim 0.

    Args:
        tensor: Floating-point tensor whose dim 0 indexes samples (assumed
            samples x features).
        mean: Optional per-column mean to use instead of the tensor's own,
            e.g. training-set statistics when scaling held-out data.
        std: Optional per-column standard deviation, analogous to ``mean``.

    Returns:
        A new standardized tensor. Columns whose standard deviation is zero or
        undefined (constant column, single row) are divided by 1 instead, so
        they come out as 0 rather than NaN.
    """
    if mean is None:
        mean = torch.mean(tensor, dim=0)
    if std is None:
        std = torch.std(tensor, dim=0)
    # A zero/NaN std would turn the whole feature column into NaN and poison
    # every downstream computation; `NaN > 0` is False, so NaN is replaced too.
    safe_std = torch.where(std > 0, std, torch.ones_like(std))
    return (tensor - mean) / safe_std

# Standardize every split independently, each with its own column statistics.
# NOTE(review): X_test and X_gw_test are NOT scaled with the training mean/std;
# confirm this per-split normalization is intentional.
X_train, X_test, X_gw_test = (
    normalize(split_tensor) for split_tensor in (X_train, X_test, X_gw_test)
)

In [None]:
# Node order in the graph is fixed here: [train | test | gw]; masks below rely on it.
y_all = torch.cat([y_train, y_test, y_gw_test])
X_all = torch.cat([X_train, X_test, X_gw_test])

## Train/Validation/Test splits

In [None]:
def true_indices_like(tensor):
    """Return an all-True boolean mask with one entry per row (dim 0) of ``tensor``."""
    row_count = tensor.shape[0]
    return torch.ones(row_count, dtype=torch.bool)

def false_indices_like(tensor):
    """Return an all-False boolean mask with one entry per row (dim 0) of ``tensor``."""
    row_count = tensor.shape[0]
    return torch.zeros(row_count, dtype=torch.bool)

def split(train_features, train_fraction=0.8):
    """Split the rows of ``train_features`` into contiguous train/validation masks.

    The first ``round(n * train_fraction)`` rows go to training and the rest to
    validation. Rows are NOT shuffled, so the split follows the input's row order.

    Args:
        train_features: Tensor whose dim 0 indexes samples.
        train_fraction: Fraction of rows, in [0, 1], assigned to training.

    Returns:
        Dict with boolean masks ``"train"`` and ``"val"`` of length n that are
        mutually exclusive and together cover every row.

    Raises:
        ValueError: if ``train_fraction`` is outside [0, 1].
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction}")
    n_rows = train_features.shape[0]
    cutoff = round(n_rows * train_fraction)
    train = torch.zeros(n_rows, dtype=torch.bool)
    train[:cutoff] = True
    # Validation is exactly the complement of the training rows.
    return {"train": train, "val": ~train}
train_data = split(X_train)

# Node masks over the concatenated node order [X_train | X_test | X_gw_test].
# Train and validation nodes both come from the X_train portion.
train_mask = torch.cat(
    (train_data["train"], false_indices_like(X_test), false_indices_like(X_gw_test))
)
val_mask = torch.cat(
    (train_data["val"], false_indices_like(X_test), false_indices_like(X_gw_test))
)
test_l_mask = torch.cat(
    (false_indices_like(X_train), true_indices_like(X_test), false_indices_like(X_gw_test))
)
test_gw_mask = torch.cat(
    (false_indices_like(X_train), false_indices_like(X_test), true_indices_like(X_gw_test))
)

## Construct edges and define graph

In [None]:
from torch_geometric.nn import knn_graph
from torch_geometric.data import Data

# k-nearest-neighbour graph (k=4, self-loops kept) over feature columns 0-6 only.
# NOTE(review): why the first 7 columns? Confirm against the FEATURES ordering.
# Because X_all stacks every split, edges may link train, test and gw nodes.
edge_index = knn_graph(
    X_all[:, :7],
    k=4,
    loop=True,
    num_workers=-1,
)
graph = Data(x=X_all, edge_index=edge_index, y=y_all)

## Define model

## GraphSAGE model used in the paper

In [None]:
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
import torch
from dataAnalysis.Constants import FEATURES

class GraphNeuralNetwork(torch.nn.Module):
    """Two-layer GraphSAGE network producing ``out_channels`` logits per node.

    Layer 1: SAGEConv(len(FEATURES) -> hidden_dim, mean aggregation,
    normalize=True, project=True) followed by ReLU.
    Layer 2: SAGEConv(hidden_dim -> out_channels, mean aggregation).
    """

    def __init__(self, hidden_dim=128, out_channels=1):
        super().__init__()
        # Input width is tied to the project's feature list; graph.x must have
        # exactly len(FEATURES) columns.
        input_dim = len(FEATURES)
        self.conv1 = SAGEConv(input_dim, hidden_dim, normalize=True, project=True,
                              aggr="mean", root_weight=True, dropout=0.0)
        self.conv_end = SAGEConv(hidden_dim, out_channels, aggr="mean", root_weight=True)

    def forward(self, graph):
        """Return raw (pre-sigmoid) logits for every node in ``graph``."""
        x, edge_index = graph.x, graph.edge_index
        x = torch.relu(self.conv1(x, edge_index))
        return self.conv_end(x, edge_index)

    def predict_proba(self, graph, mask):
        """Return ``[P(y=0), P(y=1)]`` for the nodes selected by ``mask``.

        Puts the module in eval mode (and leaves it there) and runs under
        ``torch.inference_mode``. With the default single output logit the
        result has shape (n_masked, 2).
        """
        with torch.inference_mode():
            self.eval()
            logits = self.forward(graph)
            # squeeze(-1) drops only the channel axis. The previous full squeeze()
            # collapsed a one-node selection to a 0-d tensor and crashed below.
            scores = torch.sigmoid(logits[mask].squeeze(-1))
            return torch.stack((1 - scores, scores), dim=1)

    def predict(self, graph, mask):
        """Return hard 0./1. predictions by rounding P(y=1) with ``torch.round``.

        Exactly 0.5 rounds to 0 (round-half-to-even).
        """
        return torch.round(self.predict_proba(graph, mask)[:, 1])

## Shift data to device

In [None]:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

graph = graph.to(device)
# Positive-class weight (`pos_weight`) for BCE-with-logits, created directly on
# the target device. It is stored as float to match the logits' dtype instead of
# relying on implicit int -> float promotion.
# NOTE(review): 664 is presumably the negative/positive ratio of the training
# labels; confirm, and consider deriving it from y_train instead of hard-coding.
WEIGHT = torch.tensor([664.0], device=device)

print("Data shifted to the device " + str(device))

## Model wrapper class (training loop with early stopping)

In [None]:
import torch 

class ModelWrapper:
    """Train a fresh GraphNeuralNetwork on ``graph`` with Adam and early stopping.

    Training stops once the validation loss has risen for ``BREAKING_THRESHOLD``
    consecutive epochs, or after ``MAX_EPOCHS``. The model keeps the weights of
    the final epoch; there is no best-checkpoint restore.

    Relies on notebook-level globals: ``device`` (model placement),
    ``train_mask`` / ``val_mask`` (node masks over ``graph``), ``WEIGHT``
    (``pos_weight`` of the loss) and ``F`` (torch.nn.functional).
    ``plot_loss`` additionally needs ``plt`` (matplotlib.pyplot) in scope.
    """

    def __init__(self, graph, learning_rate=3e-4, max_epochs=40000, breaking_threshold=10):
        self.LEARNING_RATE = learning_rate
        self.MAX_EPOCHS = max_epochs

        self.model = GraphNeuralNetwork(hidden_dim=128, out_channels=1).to(device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.LEARNING_RATE,
                                          betas=(0.9, 0.999), eps=1e-08)
        self.graph = graph

        # Start at +inf so the very first validation loss is never counted as an
        # "increase". Starting at 0 made the patience counter begin at 1.
        self.last_loss = float("inf")
        self.increased_loss = 0  # consecutive epochs with a rising validation loss
        self.BREAKING_THRESHOLD = breaking_threshold
        self.val_loss = []
        self.train_loss = []

    def _track_val_loss(self, value):
        """Record one validation loss and update the consecutive-increase counter."""
        self.val_loss.append(value)
        if value > self.last_loss:
            self.increased_loss += 1
        else:
            self.increased_loss = 0
        self.last_loss = value

    def validate(self):
        """Compute the weighted BCE loss on ``val_mask`` nodes and update early-stopping state."""
        with torch.inference_mode():
            self.model.eval()
            out = self.model(self.graph)
            loss = F.binary_cross_entropy_with_logits(
                out[val_mask].squeeze(-1),
                self.graph.y[val_mask].type(torch.float32),
                pos_weight=WEIGHT,
            )
            self._track_val_loss(loss.item())

    def train(self):
        """Run full-batch training on ``train_mask`` nodes with per-epoch validation."""
        for epoch in range(self.MAX_EPOCHS):
            self.model.train()
            self.optimizer.zero_grad()
            out = self.model(self.graph)
            # squeeze(-1) keeps a 1-D prediction vector even for a one-node mask.
            loss = F.binary_cross_entropy_with_logits(
                out[train_mask].squeeze(-1),
                self.graph.y[train_mask].type(torch.float32),
                pos_weight=WEIGHT,
            )
            self.train_loss.append(loss.item())
            loss.backward()
            self.optimizer.step()
            self.validate()

            if self.increased_loss >= self.BREAKING_THRESHOLD:
                break

    def get_model(self):
        """Return the (trained) underlying GraphNeuralNetwork."""
        return self.model

    def plot_loss(self):
        """Plot the per-epoch training and validation loss curves.

        NOTE(review): ``plt`` must be matplotlib.pyplot available as a global;
        it is not imported by this class.
        """
        # train() records exactly one train and one validation loss per epoch,
        # so both curves share this x-axis. The old code used a missing self.epochs.
        epochs = range(1, len(self.train_loss) + 1)
        plt.plot(epochs, self.train_loss, 'g', label='Training loss')
        plt.plot(epochs, self.val_loss, 'y', label='Validation loss')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.show()

In [None]:
import time

model_wrapper = ModelWrapper(graph)
# perf_counter is monotonic and high-resolution, so it is the right clock for
# durations; time.time() can jump if the system clock is adjusted.
# Only train() is timed here, not model construction.
start = time.perf_counter()
model_wrapper.train()
print(time.perf_counter() - start)  # elapsed training time in seconds
model = model_wrapper.get_model()

## Shift data and model back to CPU for evaluation

In [None]:
# Evaluation runs on CPU: move both the graph and the trained model back.
graph, model = graph.cpu(), model.cpu()

## Evaluation

In [None]:
from dataAnalysis.Metrics import Evaluation

# Ground-truth labels and features for the internal test split and the gw split.
evaluation = Evaluation(
    y_test.cpu(),
    y_gw_test.cpu(),
    X_test.cpu(),
    X_gw_test.cpu(),
)
# Extra arguments per split, presumably forwarded to the model's
# predict/predict_proba(graph, mask). Confirm in dataAnalysis.Metrics.
evaluation.set_test_args([graph, test_l_mask])
evaluation.set_gw_args([graph, test_gw_mask])

In [None]:
# Plot the confusion matrix for the trained model, then show its metrics.
# The bare final expression is what the notebook renders as this cell's output.
evaluation.plot_confusion_matrix(model)
evaluation.get_df_metrics(model)

## Error evaluation: repeated training runs to measure the spread of the metrics

In [None]:
import time

# Train and evaluate many independent models to see how the metrics vary run to run.
# NOTE(review): 100 - 13 = 87 runs; the "- 13" is unexplained here (perhaps runs
# already completed elsewhere). Confirm.
number_of_iter = 100-13
dataframes = []
gnn_models = []
times = []
for _run in range(number_of_iter):
    graph = graph.to(device)
    # NOTE(review): this timed span includes model construction, whereas the
    # single run earlier in the notebook times train() only.
    start = time.perf_counter()
    model_wrapper = ModelWrapper(graph)
    model_wrapper.train()
    # Measure once so the printed value is exactly the recorded one.
    elapsed = time.perf_counter() - start
    times.append(elapsed)
    print(elapsed)
    model = model_wrapper.get_model().cpu()
    # Move back to CPU before evaluating, matching the CPU evaluation setup.
    graph = graph.cpu()
    df = evaluation.get_df_metrics(model)
    print(df)
    dataframes.append(df)
    gnn_models.append(model)

In [None]:
# Print the recorded duration (seconds) of every repeated run, one per line.
for t in times:
    print(t)

In [None]:
# Print the metrics returned by get_df_metrics for every repeated run.
for df in dataframes:
    print(df)