In [2]:
from itertools import product

from torch_geometric.nn import GAT, GCN, global_mean_pool

from config_utils import CONFIG_FILE, Config, read_config_from_file
from models.gat_v4 import GATv4
from models.readout import Readout

In [18]:
# Hyperparameter grids for the parameter-count sweeps in the following cells.
# Each list enumerates the values tried for one hyperparameter; the sweep cells
# build every combination and record the combined parameter count.

# Readout (fully connected head) options. fc_dim_choices / fc_dropout_choices are
# used by the GATv4, GAT and GCN sweeps; fc_act_choices is used by all sweeps.
fc_dim_choices = [[128, 256, 256, 64], [64, 128, 128, 32]]
fc_dropout_choices = [0.1]
fc_act_choices = ['tanh', 'elu']

# GATv4 backbone: lists passed as `hidden_channels` / `heads`
# (presumably one entry per layer -- confirm against models.gat_v4).
gat_v4_hidden_channels = [[8, 16], [64, 128]]
gat_v4_heads = [[2, 2], [4, 4]]

# torch_geometric GAT backbone.
gat_num_layers = [2, 4, 6]  # only for GAT and GCN
gat_hidden_channels = [8, 32, 64, 128]
gat_heads = [2, 4]

# torch_geometric GCN backbone.
gcn_num_layers = [2, 4, 6]  # only for GAT and GCN
gcn_hidden_channels = [8, 32, 128]

# Readout-only (MLP baseline) options.
# Previously considered, not part of the sweep: [512, 128, 64, 32], [512, 256, 128, 32].
mlp_channel_lists = [[128, 256, 256, 64], [64, 128, 128, 32], [416, 256, 128, 32]]
dropout_choices = [0.1]

config = read_config_from_file(CONFIG_FILE)


In [19]:
# Parameter-count sweep: GATv4 backbone + Readout head (with feature encoder).
# For every backbone/readout combination, record the total number of parameters
# of both modules, then report the range.
config_model = Config.parse_obj(getattr(config, "gat-v4"))
fc_input_dim = config.num_nodes * len(config.which_layer)
total_params_list_combined = []
for hidden_channels, heads in product(gat_v4_hidden_channels, gat_v4_heads):
    gat_v4_model = GATv4(
        in_channels=1,
        hidden_channels=hidden_channels,
        out_channels=1,
        heads=heads,
        dropout=config.dropout,
        act=config.act,
        which_layer=config.which_layer,
        use_layer_norm=config_model.use_layer_norm,
        num_nodes=config.num_nodes,
        weight_initializer=config_model.weight_initializer,
    )
    # The backbone is built once per (hidden_channels, heads) pair and does not
    # change across readout settings, so count its parameters only once.
    backbone_params = sum(p.numel() for p in gat_v4_model.parameters())
    for fc_dim, fc_dropout, fc_act in product(fc_dim_choices, fc_dropout_choices, fc_act_choices):
        readout_model = Readout(
            feature_output_dim=config.num_nodes,
            which_layer=config.which_layer,
            fc_dim=fc_dim,
            fc_dropout=fc_dropout,
            fc_act=fc_act,
            out_channels=1,
            fc_input_dim=fc_input_dim,
            use_feature_encoder=True,
        )
        total_params_combined = backbone_params + sum(p.numel() for p in readout_model.parameters())
        total_params_list_combined.append(total_params_combined)

print("Max number of parameters: ", max(total_params_list_combined))
print("Min number of parameters: ", min(total_params_list_combined))

Max number of parameters:  3187403
Min number of parameters:  1440347


In [20]:
# Parameter-count sweep: torch_geometric GAT backbone + Readout head.
# NOTE(review): config_model is parsed but not read in this cell; the GAT is
# built from the shared config only -- confirm whether the "gat" section's
# settings were meant to be applied here.
config_model = Config.parse_obj(getattr(config, "gat"))
# NOTE(review): the meaning of `2 * num_nodes - 1` (readout input size) and
# `num_nodes // 3` (feature_output_dim) is not visible here -- TODO confirm.
fc_input_dim = (config.num_nodes * 2) - 1
total_params_list_combined = []
for num_layers, hidden_channels, heads in product(gat_num_layers, gat_hidden_channels, gat_heads):
    gat_model = GAT(
        in_channels=1,
        num_layers=num_layers,
        hidden_channels=hidden_channels,
        out_channels=1,
        heads=heads,
        dropout=config.dropout,
        act=config.act,
    )
    # The backbone is fixed across readout settings; count its parameters once.
    backbone_params = sum(p.numel() for p in gat_model.parameters())
    for fc_dim, fc_dropout, fc_act in product(fc_dim_choices, fc_dropout_choices, fc_act_choices):
        readout_model = Readout(
            feature_output_dim=config.num_nodes // 3,
            which_layer=config.which_layer,
            fc_dim=fc_dim,
            fc_dropout=fc_dropout,
            fc_act=fc_act,
            out_channels=1,
            fc_input_dim=fc_input_dim,
        )
        total_params_combined = backbone_params + sum(p.numel() for p in readout_model.parameters())
        total_params_list_combined.append(total_params_combined)

print("Max number of parameters: ", max(total_params_list_combined))
print("Min number of parameters: ", min(total_params_list_combined))


Max number of parameters:  1116160
Min number of parameters:  496620


In [21]:
# Parameter-count sweep: torch_geometric GCN backbone + Readout head.
# NOTE(review): config_model is parsed but not read in this cell; the GCN is
# built from the shared config only -- confirm whether the "gcn" section's
# settings were meant to be applied here.
config_model = Config.parse_obj(getattr(config, "gcn"))
# NOTE(review): the meaning of `2 * num_nodes - 1` (readout input size) and
# `num_nodes // 3` (feature_output_dim) is not visible here -- TODO confirm.
fc_input_dim = (config.num_nodes * 2) - 1
total_params_list_combined = []
for num_layers, hidden_channels in product(gcn_num_layers, gcn_hidden_channels):
    gcn_model = GCN(
        in_channels=1,
        num_layers=num_layers,
        hidden_channels=hidden_channels,
        out_channels=1,
        dropout=config.dropout,
        act=config.act,
    )
    # The backbone is fixed across readout settings; count its parameters once.
    backbone_params = sum(p.numel() for p in gcn_model.parameters())
    for fc_dim, fc_dropout, fc_act in product(fc_dim_choices, fc_dropout_choices, fc_act_choices):
        readout_model = Readout(
            feature_output_dim=config.num_nodes // 3,
            which_layer=config.which_layer,
            fc_dim=fc_dim,
            fc_dropout=fc_dropout,
            fc_act=fc_act,
            out_channels=1,
            fc_input_dim=fc_input_dim,
        )
        total_params_combined = backbone_params + sum(p.numel() for p in readout_model.parameters())
        total_params_list_combined.append(total_params_combined)

print("Max number of parameters: ", max(total_params_list_combined))
print("Min number of parameters: ", min(total_params_list_combined))

Max number of parameters:  1114488
Min number of parameters:  496592


In [22]:
# Parameter-count sweep: Readout head alone (MLP baseline, no graph backbone).
# NOTE(review): `dropout` is not read in this cell -- confirm whether a later
# cell depends on it before removing.
dropout = config.dropout
fc_input_dim = (config.num_nodes * 2) - 1

# Every (channel list, dropout, activation) combination, in nested-loop order.
readout_combos = [
    (channels, drop, activation)
    for channels in mlp_channel_lists
    for drop in dropout_choices
    for activation in fc_act_choices
]

total_params_list_combined = []
for fc_dim, fc_dropout, fc_act in readout_combos:
    readout_model = Readout(
        feature_output_dim=config.num_nodes // 3,
        which_layer=config.which_layer,
        fc_dim=fc_dim,
        fc_dropout=fc_dropout,
        fc_act=fc_act,
        out_channels=1,
        fc_input_dim=fc_input_dim,
    )
    total_params_combined = sum(param.numel() for param in readout_model.parameters())
    total_params_list_combined.append(total_params_combined)

param_max = max(total_params_list_combined)
param_min = min(total_params_list_combined)
print("Max number of parameters: ", param_max)
print("Min number of parameters: ", param_min)

Max number of parameters:  3169719
Min number of parameters:  496567
