In [1]:
from libs.basic import *
import torch
import torch_geometric
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.nn import MessagePassing
from torch.nn import Sequential as Seq, Linear, ReLU, Sigmoid
import torch.optim as optim
import joblib
from scipy.optimize import root_scalar

pd.set_option("display.max_columns", 100)
PATH_DATA0 = './data/00.01'
PATH_DATA = './data/00.02'
RANDOM_SEED =0
np.random.seed(RANDOM_SEED)  
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


Using device: cuda


# Hyperparameters

In [2]:
CRITERION = nn.BCEWithLogitsLoss()
LR = 0.001
TOLERANCE = 20
LR_TOLERANCE= 5
MAX_EPOCHS = 200
BATCH_SIZE = 4

# Loaders

In [3]:
loader_train = pyg.loader.DataLoader(
    pd.read_pickle(os.path.join(PATH_DATA0, 'graphs','max_prob_10_subsample_0.1','graphs_train.pkl')).tolist(),
    batch_size = BATCH_SIZE,shuffle = True)
loader_val = pyg.loader.DataLoader(
    pd.read_pickle(os.path.join(PATH_DATA0, 'graphs','max_prob_10_subsample_0.1','graphs_val.pkl')).tolist(),batch_size = BATCH_SIZE
    ,shuffle = False)
loader_test = pyg.loader.DataLoader(
    pd.read_pickle(os.path.join(PATH_DATA0, 'graphs','max_prob_10_subsample_0.1','graphs_test.pkl')).tolist(),batch_size = BATCH_SIZE
    ,shuffle = False)

# Model Architecture

In [4]:
class RelationalModel(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, n_layers):
        super(RelationalModel, self).__init__()

        layers = [nn.Linear(input_size, hidden_size), 
                 nn.ReLU()]
        if n_layers>=3:
            for _ in range(n_layers - 2):
                layers.append(nn.Linear(hidden_size, hidden_size))
                layers.append(nn.ReLU())

        layers.append(nn.Linear(hidden_size, output_size))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
class ObjectModel(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, n_layers):
        super(ObjectModel, self).__init__()

        layers = [nn.Linear(input_size, hidden_size), 
                 nn.ReLU()]
        if n_layers>=3:
            for _ in range(n_layers - 2):
                layers.append(nn.Linear(hidden_size, hidden_size))
                layers.append(nn.ReLU())

        layers.append(nn.Linear(hidden_size, output_size))

        self.layers = nn.Sequential(*layers)

    def forward(self, C):
        return self.layers(C)
class InteractionNetwork(MessagePassing):
    def __init__(self, hidden_size, n_layers):
        super(InteractionNetwork, self).__init__(aggr='add', 
                                                 flow='source_to_target')
        self.R1 = RelationalModel(10, 4, hidden_size, n_layers)
        self.O = ObjectModel(7, 3, hidden_size, n_layers)
        self.R2 = RelationalModel(10, 1, hidden_size, n_layers)
        self.E: Tensor = Tensor()

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor) -> Tensor:

        # propagate_type: (x: Tensor, edge_attr: Tensor)
        x_tilde = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)

        m2 = torch.cat([x_tilde[edge_index[1]],
                        x_tilde[edge_index[0]],
                        self.E], dim=1)
        return self.R2(m2)

    def message(self, x_i, x_j, edge_attr):
        # x_i --> incoming
        # x_j --> outgoing        
        m1 = torch.cat([x_i, x_j, edge_attr], dim=1)
        self.E = self.R1(m1)
        return self.E

    def update(self, aggr_out, x):
        c = torch.cat([x, aggr_out], dim=1)
        return self.O(c) 

# Functions

In [5]:
def count_parameters(model):
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return trainable_params
def evaluate(model, loader):
    model.eval()
    preds, actuals = [],[]
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            preds.append(torch.sigmoid(model(batch.x, batch.edge_index, batch.edge_attr)))
            actuals.append(batch.y)
        preds = torch.cat(preds)
        actuals = torch.cat(actuals)
        acc = ((preds>0.5)==(actuals>0.5)).type(torch.float).mean().item()
        entropy = CRITERION(preds, actuals.float()).item()
    model.train()
    return preds.cpu().numpy(), actuals.cpu().numpy(), acc, entropy
def train_epoch(model, loader_train):
    model.train()
    train_loss = 0.0
    for batch in tqdm(loader_train, leave = False):
        batch = batch.to(device)
        optimizer.zero_grad()
        output = model(batch.x, batch.edge_index, batch.edge_attr)
        loss = CRITERION(output, batch.y.float())
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * batch.num_graphs
    train_loss /= len(loader_train)    
    return train_loss


# Train

In [None]:


for n_layers in tqdm([3,2,4]): 
    for target_params in tqdm([100_000,500_000,1_000_000]): 
        print(n_layers, target_params)
        # Find out the hyperparameteres yielding #params = target_params
        def objective(h):
            return count_parameters(InteractionNetwork(int(h),n_layers)) - target_params
        optimal_h = int(root_scalar(objective, bracket=[1, 3000], method='bisect').root)
        optimal_h= pd.Series({optimal_h:target_params-count_parameters(InteractionNetwork(optimal_h, n_layers)),
                    optimal_h-1:target_params-count_parameters(InteractionNetwork(optimal_h-1,n_layers)),
                    optimal_h+1:target_params-count_parameters(InteractionNetwork(optimal_h+1,n_layers))}).abs().idxmin()
        
        model = InteractionNetwork(optimal_h,n_layers).to(device)
        lr = LR
        optimizer = optim.Adam(model.parameters(), lr=LR)
        best_val_loss = float('inf')
        epochs_no_improve, epochs_no_improve2 = 0,0
        best_model_state = None
        stats = []
        best = None
        # Print header once
        print(f"{'Epoch':>5} | {'Train Loss':>10} | {'Val Loss':>9} | {'Val Acc':>8} | {'Test Acc':>9}")
        print("-" * 50)
        for epoch in trange(MAX_EPOCHS):
            train_loss = train_epoch(model, loader_train)   
            preds_val, actuals_val, acc_val, val_loss = evaluate(model,loader_val)
            preds_test, actuals_test, acc_test, test_loss = evaluate(model,loader_test)
            
            stats.append({'train_loss':train_loss, 'val_loss':val_loss, 'acc_val':acc_val, 'acc_test':acc_test})
            if val_loss < best_val_loss: 
                print(f"{epoch+1:5d} | {train_loss:10.4f} | {val_loss:9.4f} | {acc_val:8.4f} | {acc_test:9.4f} *")
                best_val_loss = val_loss
                epochs_no_improve = 0
                epochs_no_improve2 = 0
                best = {'model_state': {k: v.cpu() for k, v in model.state_dict().items()},
                        'preds_test':preds_test, 'preds_val':preds_val}        
            else:
                print(f"{epoch+1:5d} | {train_loss:10.4f} | {val_loss:9.4f} | {acc_val:8.4f} | {acc_test:9.4f}")
                epochs_no_improve += 1
                epochs_no_improve2 += 1
        
            if epochs_no_improve >= TOLERANCE:
                print(f"Early stopping at epoch {epoch+1}")
                break
            if epochs_no_improve2 >= LR_TOLERANCE:
                lr/=10
                print(f"LR reduction to {lr}")
        joblib.dump(best, os.path.join(PATH_DATA, f"{n_layers}_{target_params}.pkl"))
        
        stats = pd.DataFrame(stats)
        stats[['train_loss','val_loss']].plot(figsize = (15,4))
        plt.show()
        stats[['acc_val','acc_test']].plot(figsize = (15,4))
        plt.show()

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

3 100000
Epoch | Train Loss |  Val Loss |  Val Acc |  Test Acc
--------------------------------------------------


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/142 [00:00<?, ?it/s]

    1 |     1.1808 |    0.4255 |   0.9169 |    0.9200 *


  0%|          | 0/142 [00:00<?, ?it/s]

    2 |     0.6986 |    0.4098 |   0.9399 |    0.9423 *


  0%|          | 0/142 [00:00<?, ?it/s]

    3 |     0.5559 |    0.4035 |   0.9506 |    0.9523 *


  0%|          | 0/142 [00:00<?, ?it/s]

    4 |     0.4930 |    0.4012 |   0.9523 |    0.9540 *


  0%|          | 0/142 [00:00<?, ?it/s]

    5 |     0.4515 |    0.3976 |   0.9590 |    0.9600 *


  0%|          | 0/142 [00:00<?, ?it/s]

    6 |     0.4240 |    0.3997 |   0.9565 |    0.9571


  0%|          | 0/142 [00:00<?, ?it/s]

    7 |     0.3920 |    0.3957 |   0.9623 |    0.9634 *


  0%|          | 0/142 [00:00<?, ?it/s]

    8 |     0.3737 |    0.3949 |   0.9630 |    0.9637 *


  0%|          | 0/142 [00:00<?, ?it/s]

    9 |     0.3559 |    0.3935 |   0.9654 |    0.9662 *


  0%|          | 0/142 [00:00<?, ?it/s]

   10 |     0.3428 |    0.3925 |   0.9668 |    0.9675 *


  0%|          | 0/142 [00:00<?, ?it/s]

   11 |     0.3351 |    0.3933 |   0.9646 |    0.9658


  0%|          | 0/142 [00:00<?, ?it/s]

   12 |     0.3242 |    0.3913 |   0.9687 |    0.9696 *


  0%|          | 0/142 [00:00<?, ?it/s]

   13 |     0.3116 |    0.3901 |   0.9703 |    0.9710 *


  0%|          | 0/142 [00:00<?, ?it/s]

   14 |     0.3085 |    0.3906 |   0.9699 |    0.9708


  0%|          | 0/142 [00:00<?, ?it/s]

   15 |     0.3002 |    0.3910 |   0.9684 |    0.9690


  0%|          | 0/142 [00:00<?, ?it/s]

   16 |     0.2948 |    0.3892 |   0.9716 |    0.9722 *


  0%|          | 0/142 [00:00<?, ?it/s]

   17 |     0.2862 |    0.3904 |   0.9705 |    0.9715


  0%|          | 0/142 [00:00<?, ?it/s]

   18 |     0.2820 |    0.3885 |   0.9728 |    0.9736 *


  0%|          | 0/142 [00:00<?, ?it/s]

   19 |     0.2733 |    0.3886 |   0.9728 |    0.9733


  0%|          | 0/142 [00:00<?, ?it/s]

   20 |     0.2707 |    0.3874 |   0.9743 |    0.9751 *


  0%|          | 0/142 [00:00<?, ?it/s]

   21 |     0.2674 |    0.3880 |   0.9733 |    0.9742


  0%|          | 0/142 [00:00<?, ?it/s]

   22 |     0.2592 |    0.3879 |   0.9750 |    0.9755


  0%|          | 0/142 [00:00<?, ?it/s]

   23 |     0.2596 |    0.3871 |   0.9746 |    0.9750 *


  0%|          | 0/142 [00:00<?, ?it/s]

   24 |     0.2511 |    0.3872 |   0.9745 |    0.9752


  0%|          | 0/142 [00:00<?, ?it/s]

   25 |     0.2537 |    0.3867 |   0.9756 |    0.9764 *


  0%|          | 0/142 [00:00<?, ?it/s]

   26 |     0.2426 |    0.3865 |   0.9756 |    0.9766 *


  0%|          | 0/142 [00:00<?, ?it/s]

   27 |     0.2421 |    0.3856 |   0.9772 |    0.9780 *


  0%|          | 0/142 [00:00<?, ?it/s]

   28 |     0.2404 |    0.3859 |   0.9764 |    0.9768


  0%|          | 0/142 [00:00<?, ?it/s]

   29 |     0.2343 |    0.3858 |   0.9764 |    0.9773


  0%|          | 0/142 [00:00<?, ?it/s]