In [1]:
from models.gat_v4 import GATv4
from torch_geometric.nn import GAT, GCN, global_mean_pool
from torch_geometric.nn.models import MLP
from config_utils import CONFIG_FILE, Config, read_config_from_file

In [10]:
gat_v4_hidden_channels= [[8, 16], [32, 64], [64, 128]]
gat_v4_heads= [[2, 3], [2, 2], [4, 4]]
gat_v4_fc_dim= [[64, 128, 128, 32], [128, 256, 256, 64], [256, 512, 512, 128]]

gat_num_layers= [2, 4, 6, 12, 18]  # only for GAT and GCN
gat_hidden_channels= [8, 32, 128, 256, 512]
gat_heads= [1, 2, 4, 8]

gcn_num_layers= [2, 3, 4, 6, 12, 18]  # only for GAT and GCN
gcn_hidden_channels= [8, 32, 128, 256, 512]
mlp_channel_lists = [[7258, 1], [7258, 1028, 1], [7258, 128, 64, 1], [7258, 1028, 128, 1], [7258, 1028, 256, 64, 1], [7258, 1028, 512, 128, 1], [7258, 1028, 256, 128, 64, 1]]


config = read_config_from_file(CONFIG_FILE)


In [3]:
config_model = Config.parse_obj(getattr(config, "gat-v4"))
total_params_list = []
for hidden_channels in gat_v4_hidden_channels:
    for heads in gat_v4_heads:
        for fc_dim in gat_v4_fc_dim:
                model = GATv4(
                in_channels=1, 
                hidden_channels=hidden_channels,
                out_channels=1,
                heads=heads, 
                dropout=config.dropout,
                act=config.act,
                which_layer=config_model.which_layer,
                use_layer_norm=config_model.use_layer_norm,
                fc_dim=config_model.fc_dim,
                fc_dropout=config_model.fc_dropout,
                fc_act=config_model.fc_act,
                num_nodes=config.num_nodes,
                weight_initializer=config_model.weight_initializer,
                use_master_nodes=config.use_master_nodes,
                master_nodes=config.master_nodes,
            )
                total_params = sum(p.numel() for p in model.parameters())
                total_params_list.append(total_params)
print("Max number of parameters: ", max(total_params_list))
print("Min number of parameters: ", min(total_params_list))

Max number of parameters:  3155207
Min number of parameters:  2890007


In [9]:
config_model = Config.parse_obj(getattr(config, "gat"))
total_params_list = []
for num_layers in gat_num_layers:
    for hidden_channels in gat_hidden_channels:
        for heads in gat_heads:
            model = GAT(
                in_channels=1,
                num_layers=num_layers,
                hidden_channels=hidden_channels,
                out_channels=1,
                heads=heads,
                dropout=config.dropout,
                act=config.act,
            )
            total_params = sum(p.numel() for p in model.parameters())
            total_params_list.append(total_params)
print("Max number of parameters: ", max(total_params_list))
print("Min number of parameters: ", min(total_params_list))


Max number of parameters:  4225041
Min number of parameters:  43


In [5]:
config_model = Config.parse_obj(getattr(config, "gcn"))
total_params_list = []
for layer in gcn_num_layers:
    for hidden_channels in gcn_hidden_channels:
        model = GCN(
            in_channels=1,
            num_layers=layer,
            hidden_channels=hidden_channels,
            out_channels=1,
            dropout=config.dropout,
            act=config.act,
        )
        total_params = sum(p.numel() for p in model.parameters())
        total_params_list.append(total_params)
print("Max number of parameters: ", max(total_params_list))
print("Min number of parameters: ", min(total_params_list))

Max number of parameters:  4204033
Min number of parameters:  25


In [6]:
dropout = config.dropout
config_model = Config.parse_obj(getattr(config, "mlp"))
total_params_list = []
for channel_list in mlp_channel_lists:
    dropout = [dropout] * (len(channel_list) - 1)
    model = MLP(
        channel_list=channel_list,
        dropout=dropout,
        act=config.act,
        norm=config_model.norm,
        plain_last=config_model.plain_last
    )
    total_params = sum(p.numel() for p in model.parameters())
    total_params_list.append(total_params)
print("Max number of parameters: ", max(total_params_list))
print("Min number of parameters: ", min(total_params_list))

Max number of parameters:  8058229
Min number of parameters:  7259
