In [None]:
import os
#sys.path.append('./align_gnn_toolkit')
os.chdir('./align_gnn_toolkit')
import argparse
import warnings
from engine import EngineFactory
from data_set.data_set_factory import DataSetFactory
from utils import config_const
# Show the directory that the relative paths below are resolved against.
print(f'Working directory: {os.getcwd()}')

# Default experiment files, used when -temp / -conf are not given.
# NOTE(review): these start with 'align_gnn_toolkit/' although the working
# directory was already changed to './align_gnn_toolkit' - confirm the layout.
EXPERIMENT_TEMPLATE = 'align_gnn_toolkit/experiments_repository/template_default.yaml'
EXPERIMENT_CONFIG = 'align_gnn_toolkit/experiments_repository/sick_all_spd_hetero.yaml'

# Options shared by both optional string switches.
shared_arg_options = dict(required=False, nargs='?', const='1', type=str)

parser = argparse.ArgumentParser()
parser.add_argument(
    '-temp',
    help='Path to the config file',
    default=EXPERIMENT_TEMPLATE,
    **shared_arg_options,
)
parser.add_argument(
    '-conf',
    help='Nbr of experiment from config',
    default=EXPERIMENT_CONFIG,
    **shared_arg_options,
)

# parse_known_args returns arguments the parser does not declare in
# ``unknown`` instead of reporting them as an error.
args, unknown = parser.parse_known_args()

# Load the configuration, set its dataset name entry to "sick", then derive
# the engine, its processing parameters and the data holder from it.
config_utils = EngineFactory().getConfigurationUtils(args)
config_utils.setValue(
    config_const.CONF_SEC_DATASET,
    config_const.CONF_DATASET_NAME,
    "sick",
)
engine = EngineFactory().getEngineType(config_utils)
params = engine.getProcessingParameters()
data_holder = DataSetFactory.get_data_holder(params)

In [None]:
from torch_geometric.data import HeteroData
from torch_geometric.utils import to_undirected
import os
import torch
from graph_builder import GraphBuilderFactory
# Create the graph builder described by ``params`` and initialise it on the
# training split.
graph_builder = GraphBuilderFactory.getBuilder(params=params)
graph_builder.initialize(data_holder.train_data_set)

# One plain-dict summary per sub-builder: its name, how many node/edge feature
# columns it owns, where those columns start, and the overall column totals.
builders = [
    {
        "name": sub_builder.builder_name,
        "node_feats_nbr": sub_builder.getNodeFeatsNbr(),
        "edge_feats_nbr": sub_builder.getEdgeFeatsNbr(),
        "offset_edge_feats": sub_builder.offset_edge_feats,
        "offset_node_feats": sub_builder.offset_node_feats,
        "total_edge_feature_number": sub_builder.total_edge_feature_number,
        "total_node_feature_number": sub_builder.total_node_feature_number,
    }
    for sub_builder in graph_builder.builders
]
    

    
def to_hetero_graph(data, graph, builders_meta, prefix):
    """Copy one homogeneous graph into ``data`` as ``prefix``-qualified hetero types.

    Each entry of ``builders_meta`` owns a run of columns in ``graph.x`` and/or
    ``graph.edge_attr``. For every builder:

    * node features (if any) are stored as ``data[<prefix>_<name>].x``;
    * a builder other than the first that has node features but no edge
      features is linked from the base type (the first builder) with
      unit-weight edges, one per positive entry of its ``x``;
    * edge features (if any) are reduced to the rows with a positive entry and
      stored under ``(base, <prefix>_<name>)`` when the builder also has node
      features, otherwise under ``(base, <prefix>_<name>, base)``.

    A builder that has neither node nor edge features contributes nothing.

    Args:
        data: Container indexed by a node-type string or an edge-type tuple
            (``HeteroData`` in this notebook). It is modified in place.
        graph: Object exposing ``x``, ``edge_attr``, ``edge_index``, ``y`` and
            ``node_labels``.
        builders_meta: Sequence of dicts with the keys ``name``,
            ``node_feats_nbr``, ``edge_feats_nbr``, ``offset_node_feats`` and
            ``offset_edge_feats``. The first entry names the base node type.
        prefix: String joined with ``_`` in front of every type name.

    Returns:
        The same ``data`` object. ``data["y"].y`` carries no prefix, so a later
        call with another prefix overwrites it. ``data`` is returned untouched
        when ``builders_meta`` is empty.
    """
    def add_prefix(name, prefix):
        return prefix + "_" + name

    # Nothing was ever written for an empty builder list; keep that contract.
    if not builders_meta:
        return data

    # Graph-level attributes do not depend on the builder, so write them once
    # instead of on every loop iteration.
    data["y"].y = graph.y
    data[add_prefix("node_labels", prefix)].node_labels = graph.node_labels

    base_type = add_prefix(builders_meta[0]["name"], prefix)

    for index, builder in enumerate(builders_meta):
        node_type = add_prefix(builder["name"], prefix)
        node_feats_nbr = builder["node_feats_nbr"]
        edge_feats_nbr = builder["edge_feats_nbr"]

        if node_feats_nbr > 0:
            # Index-tensor selection (not a slice) so the stored tensor is a
            # copy rather than a view of the full feature matrix.
            node_cols = torch.arange(
                builder["offset_node_feats"],
                builder["offset_node_feats"] + node_feats_nbr,
            )
            data[node_type].x = graph.x[:, node_cols]

        # Requiring node features here avoids reading ``.x`` from a store that
        # never received it (an AttributeError for feature-less builders).
        if index > 0 and edge_feats_nbr == 0 and node_feats_nbr > 0:
            # Row index of every positive entry of this builder's features.
            trg_index = data[node_type].x.gt(0).nonzero(as_tuple=True)[0]
            # NOTE(review): sources are simply 0..k-1, not the matching row
            # ids - confirm this pairing is intended.
            src_index = torch.arange(0, trg_index.shape[0])
            link_store = data[base_type, node_type]
            link_store.edge_attr = torch.ones(trg_index.shape[0], 1)
            link_store.edge_index = torch.stack([src_index, trg_index], dim=0).long()

        if edge_feats_nbr > 0:
            edge_cols = torch.arange(
                builder["offset_edge_feats"],
                builder["offset_edge_feats"] + edge_feats_nbr,
            )
            # Kept in a local: nothing temporary is attached to ``data``, so
            # no scratch attribute or empty node store has to be cleaned up.
            builder_edge_attr = graph.edge_attr[:, edge_cols]
            # Row index of every positive entry; a row with several positive
            # columns appears once per such column.
            kept = builder_edge_attr.gt(0).nonzero(as_tuple=True)[0]

            if node_feats_nbr > 0:
                edge_key = (base_type, node_type)
            else:
                # Edge-only builder: its name becomes the relation between
                # base-type nodes.
                edge_key = (base_type, node_type, base_type)

            edge_store = data[edge_key]
            edge_store.edge_attr = builder_edge_attr[kept]
            edge_store.edge_index = graph.edge_index[:, kept]
    return data

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.decoder.weight', 'lm_head.dense.weight', 'lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
from data_set.impl.data_set_sick_hetero import SickHeteroDataset
import torch
from tqdm import tqdm
import torch_geometric.transforms as T

def convert_to_hetero(data_set):
    """Convert every sample of ``data_set`` to a hetero graph and save it to disk.

    A ``SickHeteroDataset`` is created with the same settings as ``data_set``,
    except that every ``"sick"`` in its root path is replaced by
    ``"sick_hetero"``. For each sample, the source and target graphs are merged
    into one ``HeteroData`` object (node/edge types prefixed with ``src`` and
    ``trg``), passed through ``T.ToUndirected`` and written to
    ``<processed_dir>/data_<type>_<index>.pt`` of the new dataset.

    Relies on the notebook-level ``builders`` list and ``to_hetero_graph``.

    Args:
        data_set: Dataset supporting ``len()`` and integer indexing, whose
            items expose ``get_source()`` / ``get_target()``, and which carries
            the attributes ``root``, ``transform``, ``pre_transform``,
            ``pre_filter``, ``type``, ``params``, ``graph_builder`` and
            ``data_set_processor``.

    Returns:
        None. The converted graphs are only written to disk.
    """
    sick_hetero = SickHeteroDataset(
        root=data_set.root.replace("sick", "sick_hetero"),
        transform=data_set.transform,
        pre_transform=data_set.pre_transform,
        pre_filter=data_set.pre_filter,
        type=data_set.type,
        params=data_set.params,
        graph_builder=data_set.graph_builder,
        data_set_processor=data_set.data_set_processor,
    )
    # The transform is built once and reused for every sample.
    make_undirected = T.ToUndirected()
    for index in tqdm(range(len(data_set))):
        # Fetch the sample once: each ``data_set[index]`` repeats the dataset's
        # item lookup (presumably a disk load plus transform - verify), and a
        # single fetch ensures source and target come from the same object.
        pair = data_set[index]
        src = pair.get_source()
        trg = pair.get_target()
        hg = to_hetero_graph(HeteroData(), src, builders, "src")
        hg = to_hetero_graph(hg, trg, builders, "trg")
        hg = make_undirected(hg)
        torch.save(hg, os.path.join(sick_hetero.processed_dir, f'data_{sick_hetero.type}_{index}.pt'))

# Convert the test split first, then the training split. Each split is looked
# up on the holder right before its own conversion.
for split_attr in ("test_data_set", "train_data_set"):
    convert_to_hetero(getattr(data_holder, split_attr))


Processing...
Done!
100%|██████████| 4906/4906 [00:16<00:00, 300.55it/s]
Processing...
Done!
100%|██████████| 4439/4439 [00:15<00:00, 291.64it/s]


In [5]:
# Convert the validation split as well (the test and train splits are
# converted in the previous cell).
convert_to_hetero(data_holder.validation_data_set)

Processing...
Done!
100%|██████████| 495/495 [00:01<00:00, 275.79it/s]
