<a id='table_of_contents'></a>

0. [Import Libraries](#imports)
1. [Custom Dataset](#dataset)
2. [Generate Random Inputs](#inputs)
3. [Hyperparameters](#hparams)
4. [Define the Tripartite Graph Model](#model)<br>
5. [Train the Model](#training)<br>

# 0. Import Necessary Libraries <a id='imports'></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
import numpy as np
import random
import time

# 1. Custom Dataset <a id='dataset'></a>

In [None]:
class TripartiteData(Data):
    """One tripartite graph sample: variable, constraint and objective nodes.

    Attributes stored on the sample:
        vc_edge, vo_edge, co_edge: 2-row edge-index tensors between the three
            node sets.
        x_v, x_c, x_o: feature matrices of the variable, constraint and
            objective nodes.
        y: labels.
        c_ind: per-constraint index tensor, concatenated along dimension 1.

    ``__inc__`` and ``__cat_dim__`` are overridden so that the custom edge
    tensors are shifted and concatenated correctly when several samples are
    collated into one mini-batch.
    """

    def __init__(self, vc_edge=None, vo_edge=None, co_edge=None, x_v=None, x_c=None, x_o=None, y=None, c_ind=None):
        super().__init__()
        self.vc_edge = vc_edge
        self.vo_edge = vo_edge
        self.co_edge = co_edge
        self.x_v = x_v
        self.x_c = x_c
        self.x_o = x_o
        self.y = y
        self.c_ind = c_ind

    def __inc__(self, key, value, *args, **kwargs):
        """Return the per-row index offset applied to ``key`` for the next sample in a batch."""
        # For each edge tensor: (feature matrix sizing row 0, feature matrix sizing row 1).
        # Attribute names are stored instead of the tensors themselves so that a
        # feature matrix is only looked up when its edge key is actually requested.
        offset_sources = {
            'vc_edge': ('x_c', 'x_v'),
            'vo_edge': ('x_o', 'x_v'),
            'co_edge': ('x_o', 'x_c'),
        }
        names = offset_sources.get(key)
        if names is None:
            return super().__inc__(key, value, *args, **kwargs)
        row0_offset, row1_offset = (getattr(self, name).size(0) for name in names)
        return torch.tensor([[row0_offset], [row1_offset]])

    def __cat_dim__(self, key, value, *args, **kwargs):
        """Return the dimension along which ``key`` is concatenated across samples."""
        if key in ('vc_edge', 'vo_edge', 'co_edge', 'c_ind'):
            return 1
        return super().__cat_dim__(key, value, *args, **kwargs)

In [None]:
def normalize(value):
    """Shift ``value`` by its global minimum and scale each row by its sum.

    Row sums smaller than 1 are clamped to 1 before dividing, so such rows are
    only shifted, not rescaled. The input tensor is left untouched; a new
    tensor is returned.
    """
    shifted = value - value.min()  # out-of-place: creates the tensor that is returned
    row_totals = shifted.sum(dim=-1, keepdim=True)
    row_totals.clamp_(min=1.)
    shifted.div_(row_totals)
    return shifted

# 2. Generate Random Inputs <a id='inputs'></a>

In [None]:
# Number of features for variable nodes, constraint nodes and the objective node, respectively.
# The variable and constraint feature counts are drawn at random (the upper bound of
# np.random.randint is exclusive), so they differ between runs unless NumPy's RNG is seeded.
feat_var = np.random.randint(1, 50) 
feat_cons = np.random.randint(1, 30)
feat_obj = 1

num_obj = 1 # there is only one objective node for each MILP

### Graphs with different node numbers are generated and added to the data list.

In [None]:
data_list = []
max_var_nodes = 100
num_graphs = 500

for i in range(num_graphs):
    # Random graph size. The number of constraints is drawn strictly below the
    # number of variables.
    num_var = np.random.randint(2, max_var_nodes)
    num_cons = np.random.randint(1, num_var)
    num_edges = np.random.randint(1, num_var*num_cons + 1)  # variable-constraint edges before de-duplication

    # EDGE INDICES (kept as index pairs for efficient memory usage).
    # Row 0 holds constraint ids, row 1 holds variable ids.
    cons_endpoints = torch.randint(0, num_cons, (1, num_edges))
    var_endpoints = torch.randint(0, num_var, (1, num_edges))
    vc_edge = torch.unique(torch.vstack((cons_endpoints, var_endpoints)), dim=1)  # drop duplicate edges

    # Every variable is connected to the single objective node (id 0).
    vo_edge = torch.vstack((torch.zeros(1, num_var, dtype=torch.long), torch.arange(num_var)))

    # Every constraint is connected to the single objective node (id 0).
    co_edge = torch.vstack((torch.zeros(1, num_cons, dtype=torch.long), torch.arange(num_cons)))

    # RANDOM INPUTS: feature matrices for the variables, the constraints and the objective.
    x_v = torch.rand(num_var, feat_var, dtype=torch.float)
    x_c = torch.rand(num_cons, feat_cons, dtype=torch.float)
    x_o = torch.rand(num_obj, feat_obj, dtype=torch.float)

    # Binary label per variable.
    y = torch.randint(0, 2, (num_var,), dtype=torch.float)

    # One entry per constraint carrying this graph's id; the model uses the run
    # lengths of these ids to copy h_o once per constraint.
    cons_indices = torch.full((1, num_cons), i, dtype=torch.long)

    data = TripartiteData(
        vc_edge=vc_edge,
        vo_edge=vo_edge,
        co_edge=co_edge,
        x_v=normalize(x_v),
        x_c=normalize(x_c),
        x_o=x_o,
        y=y,
        c_ind=cons_indices,
    )
    # Only variable nodes are counted here (not num_var + num_cons + num_obj).
    data.num_nodes = num_var
    data_list.append(data)

In [None]:
# Seed both generators. torch.manual_seed only seeds PyTorch's RNG; random.shuffle
# draws from Python's `random` module, which must be seeded separately for the
# train/validation/test split to be reproducible.
torch.manual_seed(12345)
random.seed(12345)
random.shuffle(data_list)

# Split the graphs into train, validation and test sets
train_ratio = 0.6
val_ratio = 0.2
test_ratio = 0.2  # informational only: the test set takes whatever remains after train + validation

train_size = int(len(data_list)*train_ratio)
val_size = int(len(data_list)*val_ratio)

train_dataset = data_list[:train_size]
val_dataset = data_list[train_size:train_size+val_size]
test_dataset = data_list[train_size+val_size:]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of validation graphs: {len(val_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

# 3. Hyperparameters <a id='hparams'></a>

In [None]:
h = 64   # hidden embedding size (width of the node embeddings inside GNN), not a number of layers
T = 2    # number of transitions (iterations of the update loop in GNN.forward); NOTE: this rebinds `T`, shadowing the `torch_geometric.transforms as T` import
num_epochs = 100
batch_size = 32

In [None]:
# Split each dataset into mini-batches of `batch_size` graphs.
# NOTE(review): the validation loader is shuffled like the training loader while the
# test loader is not; shuffling is not needed for evaluation - confirm whether intended.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 4. Define the Tripartite Graph Model <a id='model'></a>
* By Using Edge List 

In [None]:
class GNN(nn.Module):
    """Attention-based message passing on a variable/constraint/objective tripartite graph.

    Each transition updates, in order: objective from variables, constraints
    from objective and variables, objective from constraints, and variables
    from objective and constraints. The final variable embeddings are
    concatenated with the initial ones and mapped to one logit per variable.

    Args:
        h: hidden embedding size.
        T: number of transitions (update rounds) performed in ``forward``.
        num_feat: tuple ``(var_feat, cons_feat, obj_feat)`` with the input
            feature count of each node type.
    """

    _EDGE_TYPES = ("vo", "oc", "vc", "co", "ov", "cv")

    def __init__(self, h, T, num_feat):
        super().__init__()

        torch.manual_seed(3)  # fixed seed so that parameter initialisation is repeatable
        self.hid = h
        self.T = T  # stored so that forward() does not depend on a module-level global

        # A ParameterDict (not a plain dict) registers the matrices with the
        # module, so that model.parameters(), model.to(device) and state_dict()
        # all include them. In a plain dict the optimizer never sees them.
        self.weights = nn.ParameterDict()
        for i in self._EDGE_TYPES:
            W = nn.init.kaiming_normal_(torch.empty(h, 2*h), nonlinearity='relu')
            self.weights["W_" + i] = nn.Parameter(W, requires_grad=True)

        self.att_weight = nn.Parameter(nn.init.xavier_uniform_(torch.empty(2*h, 1)), requires_grad=True)
        self.fc1 = nn.Linear(2*h, 32)
        self.fc2 = nn.Linear(32, 1)

        var_feat, cons_feat, obj_feat = num_feat
        self.emb_var = nn.Linear(var_feat, h)
        self.emb_cons = nn.Linear(cons_feat, h)
        self.emb_obj = nn.Linear(obj_feat, h)
        self.relu = nn.ReLU()

    def _combine(self, left, right, name):
        """Return ``relu(cat(left, right) @ W_name^T)``, mapping 2*h features back to h."""
        return self.relu(torch.matmul(torch.cat((left, right), dim=1), self.weights[name].t()))

    def forward(self, x, edge_indices, batch, c_ind):
        """Compute one logit per variable node.

        Args:
            x: tuple ``(x_v, x_c, x_o)`` of feature matrices.
            edge_indices: tuple ``(vc_edge, vo_edge, co_edge)`` of 2-row index
                tensors; row 0 indexes the first-named... see ``attentions``
                for the row convention (row 0 = receiving node set).
            batch: graph id of every variable node; its bincount gives the
                number of variables per graph.
            c_ind: graph id of every constraint node; the lengths of its
                consecutive runs give the number of constraints per graph.

        Returns:
            Tensor of shape ``(num_variables, 1)`` holding raw logits (no sigmoid).
        """
        x_v, x_c, x_o = x
        vc_edge, vo_edge, co_edge = edge_indices
        h_v = self.emb_var(x_v)
        h_c = self.emb_cons(x_c)
        h_o = self.emb_obj(x_o)
        h_v_init = torch.clone(h_v)

        # How many times each graph's objective embedding must be repeated to
        # line up with that graph's constraints / variables.
        _, cons_per_graph = torch.unique_consecutive(c_ind.reshape(-1), return_counts=True)
        vars_per_graph = batch.bincount()
        cv_edge = (vc_edge[1], vc_edge[0])  # reversed direction: variables receive from constraints

        for _ in range(self.T):
            # variables -> objective
            alpha_vo = self.attentions(vo_edge, h_v, h_o)
            h_o = self._combine(h_o, torch.matmul(alpha_vo, h_v), "W_vo")

            # objective + variables -> constraints
            alpha_vc = self.attentions(vc_edge, h_v, h_c)
            h_oc = self._combine(torch.repeat_interleave(h_o, cons_per_graph, dim=0), h_c, "W_oc")
            h_c = self._combine(h_oc, torch.matmul(alpha_vc, h_v), "W_vc")

            # constraints -> objective
            alpha_co = self.attentions(co_edge, h_c, h_o)
            h_o = self._combine(h_o, torch.matmul(alpha_co, h_c), "W_co")

            # objective + constraints -> variables
            alpha_cv = self.attentions(cv_edge, h_c, h_v)
            h_ov = self._combine(torch.repeat_interleave(h_o, vars_per_graph, dim=0), h_v, "W_ov")
            h_v = self._combine(h_ov, torch.matmul(alpha_cv, h_c), "W_cv")

        logits = self.fc2(self.relu(self.fc1(torch.cat((h_v_init, h_v), dim=1))))
        return logits

    def attentions(self, edge_index, h_src, h_dest):
        """Return the attention matrix of shape ``(len(h_dest), len(h_src))``.

        Args:
            edge_index: pair of index tensors; the first indexes rows of
                ``h_dest`` (receiving nodes), the second rows of ``h_src``
                (sending nodes).
            h_src: embeddings of the sending nodes.
            h_dest: embeddings of the receiving nodes.

        Returns:
            Matrix whose row ``d`` is a softmax over the senders connected to
            receiver ``d``. Entries without an edge are exactly 0, and a
            receiver with no incoming edge gets an all-zero row, so it
            aggregates nothing instead of averaging over unrelated nodes.
        """
        dest_index, src_index = edge_index
        pair_features = torch.cat((h_src[src_index], h_dest[dest_index]), dim=1)
        scores = torch.sigmoid(torch.matmul(pair_features, self.att_weight)).reshape(-1)

        n_dest, n_src = h_dest.shape[0], h_src.shape[0]
        # Allocate on the same device/dtype as the scores (a bare torch.zeros
        # would live on the CPU and break on CUDA). Non-edges get a large
        # negative value so that they vanish in the softmax.
        e_matr = scores.new_full((n_dest, n_src), -9e20)
        e_matr[dest_index, src_index] = scores

        # Track edges explicitly rather than testing the score values for 0.
        edge_mask = torch.zeros(n_dest, n_src, dtype=torch.bool, device=scores.device)
        edge_mask[dest_index, src_index] = True

        alpha = F.softmax(e_matr, dim=1)
        # Rows without any edge come out of the softmax uniform; zero them.
        return alpha * edge_mask

In [None]:
# for name, param in model.state_dict().items():
    # print(name)
    # print(param.shape)

In [None]:
# model.weights

# 5. Train the Model <a id='training'></a>

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Build the network from the hidden size, the number of transitions and the
# per-node-type feature counts, then move the module to the selected device.
model = GNN(h, T, num_feat=(feat_var, feat_cons, feat_obj))
model.to(device)

In [None]:
# Adam optimises whatever model.parameters() yields. BCEWithLogitsLoss applies the
# sigmoid itself, so it is fed the raw logits returned by the model.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()

def train():
    """Run one optimisation pass over every mini-batch of ``train_loader``."""
    model.train()

    for batch_data in train_loader:
        batch_data = batch_data.to(device)
        node_features = (batch_data.x_v, batch_data.x_c, batch_data.x_o)
        edges = (batch_data.vc_edge, batch_data.vo_edge, batch_data.co_edge)
        out = model(node_features, edge_indices=edges, batch=batch_data.batch, c_ind=batch_data.c_ind)

        # Labels are reshaped to the logits' shape before computing the loss.
        loss = criterion(out, batch_data.y.view(out.shape))
        loss.backward()        # accumulate gradients
        optimizer.step()       # apply the parameter update
        optimizer.zero_grad()  # reset gradients for the next mini-batch

def test(loader):
    """Return the fraction of variable nodes in ``loader`` whose label is predicted correctly.

    Accuracy is not a good metric here, it is just used to set up the
    structure of model training.

    Args:
        loader: iterable of mini-batches exposing ``x_v``, ``x_c``, ``x_o``,
            ``vc_edge``, ``vo_edge``, ``co_edge``, ``batch``, ``c_ind`` and ``y``.

    Returns:
        ``correct / total`` over all labelled nodes, or ``0.0`` if the loader
        yields no labels.
    """
    model.eval()

    correct = 0
    total = 0
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for data in loader:
            data = data.to(device)  # inputs must live on the same device as the model
            out = model((data.x_v, data.x_c, data.x_o), edge_indices=(data.vc_edge, data.vo_edge, data.co_edge), batch=data.batch, c_ind=data.c_ind)
            # Threshold the probabilities and flatten: the logits are (N, 1)
            # while the labels are (N,), so comparing them unflattened would
            # broadcast to an N x N matrix.
            pred = (torch.sigmoid(out) > 0.5).view(-1)
            target = data.y.view(-1) > 0.5
            correct += int((pred == target).sum())
            total += target.numel()
    # One prediction per node, so normalise by the node count, not the graph count.
    return correct / total if total else 0.0


# Epochs are numbered from 1, so the upper bound must be num_epochs + 1 to run
# exactly num_epochs passes (range(1, num_epochs) stops one epoch short).
for epoch in range(1, num_epochs + 1):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')