In [1]:
import os
import sys
import glob
import gzip
import pickle
import pathlib
import argparse
import numpy as np

In [2]:
import torch
import torch.nn.functional as F
import torch_geometric
from torch_geometric.nn import GCNConv
from torch_geometric.utils.convert import from_networkx

In [3]:
# Width of the raw per-node input features in data.x (input size of `project`).
NODE_FEATURE_NUM = 2

class GraphConvolution(torch.nn.Module):
    """GCN policy that assigns a score to every edge of one graph or of a batch of graphs.

    Nodes are embedded by a linear projection followed by six GCN layers; each edge
    (u, v) is then scored by an MLP applied to concat(h[u], h[v]). Raw edge scores are
    softmax-normalised within each graph, so the scores of one graph sum to 1.
    """

    def __init__(self):
        super(GraphConvolution, self).__init__()
        self.project = torch.nn.Linear(NODE_FEATURE_NUM, 32)
        self.conv1 = GCNConv(32, 32)
        self.conv2 = GCNConv(32, 32)
        self.conv3 = GCNConv(32, 32)
        self.conv4 = GCNConv(32, 32)
        self.conv5 = GCNConv(32, 32)
        self.conv6 = GCNConv(32, 32)
        # Edge scorer: concat of two 32-d node embeddings -> one raw score.
        self.seq = torch.nn.Sequential(
            torch.nn.Linear(32 * 2, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1),
        )

    def _embed(self, data):
        """Return 32-d node embeddings: projection + ReLU, then six GCN layers.

        ReLU and dropout follow every GCN layer except the last one.
        """
        x, edge_index, edge_weight = data.x.float(), data.edge_index, data.weight.float()
        x = F.relu(self.project(x))
        convs = (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6)
        for depth, conv in enumerate(convs):
            x = conv(x, edge_index, edge_weight)
            if depth < len(convs) - 1:
                x = F.relu(x)
                x = F.dropout(x, training=self.training)
        return x

    @staticmethod
    def _split_graphs(data):
        """Return (edges_per_graph, nodes_per_graph) for a single Data or a batch.

        A collated batch carries a `batch` vector and stores `edges`/`nodes` as one list
        per graph; a single Data (e.g. from `networkx2torch`) has no `batch` and stores
        flat lists, so it is wrapped as a one-graph batch.
        """
        if getattr(data, 'batch', None) is None:
            return [data.edges], [data.nodes]
        return data.edges, data.nodes

    def forward(self, data):
        """Score every edge of every graph in `data`.

        :param data: a single graph from `networkx2torch` or a collated batch of them.
            Rows of data.x are ordered graph by graph, and within a graph in the order
            of that graph's node list.
        :return: 1-D tensor with one score per edge (graph-then-edge order), on the same
            device as the node embeddings; the scores of each graph sum to 1.
        """
        x = self._embed(data)
        edges_per_graph, nodes_per_graph = self._split_graphs(data)

        src, dst, counts = [], [], []
        offset = 0  # rows of x that belong to earlier graphs in the batch
        for edges, nodes in zip(edges_per_graph, nodes_per_graph):
            # dict lookup instead of list.index: O(1) per endpoint
            row_of = {node: offset + k for k, node in enumerate(nodes)}
            for edge in edges:
                src.append(row_of[edge[0]])
                dst.append(row_of[edge[1]])
            counts.append(len(edges))
            offset += len(nodes)

        if not src:
            return x.new_zeros(0)

        src = torch.tensor(src, dtype=torch.long, device=x.device)
        dst = torch.tensor(dst, dtype=torch.long, device=x.device)
        # One MLP call over all edges; Linear layers act row-wise, so each row equals
        # scoring that edge on its own.
        raw = self.seq(torch.cat((x[src], x[dst]), dim=1)).squeeze(-1)
        # Normalise within each graph, never across graphs of the batch.
        return torch.cat([F.softmax(chunk, dim=0) for chunk in torch.split(raw, counts)])

    def eval(self, data=None):
        """Switch to evaluation mode, or score a graph without tracking gradients.

        ``policy.eval()`` keeps the standard ``torch.nn.Module.eval`` behaviour: it sets
        ``training=False`` on the module tree and returns ``self``.

        ``policy.eval(data)`` returns ``forward(data)`` computed inside
        ``torch.no_grad()``, so the same six-layer network that is trained is used. It
        neither switches autograd off globally nor changes the module's training mode.
        """
        if data is None:
            return super().eval()
        with torch.no_grad():
            return self.forward(data)

In [4]:
def networkx2torch(graph):
    """Convert a networkx graph into the Data object consumed by the GCN policy.

    On top of what `from_networkx` builds, the returned object gets:
      - ``edges``: ``graph.edges`` as a list, in networkx iteration order;
      - ``nodes``: ``graph.nodes`` as a list, in networkx iteration order;
      - ``x``: the node feature matrix, taken from the ``pos`` node attribute.
    NOTE(review): the policy also reads ``data.weight``, presumably built from a
    ``weight`` edge attribute by `from_networkx` — confirm the sample graphs carry it.
    """
    data = from_networkx(graph)
    data.edges = list(graph.edges)
    data.nodes = list(graph.nodes)
    data.x = data.pos
    return data

def Guided_Search(policy, graph):
    '''
    Greedily build a matching by repeatedly taking the edge the policy scores highest.

    After each pick, both endpoints (and therefore every edge touching them) are removed
    from a working copy of the graph; the search stops when no edges remain.

    :param policy: GCN model (a torch.nn.Module returning one score per edge of a
        `networkx2torch` graph)
    :param graph: networkx graph (not modified)
    :return: a match — float np.ndarray with one entry per edge of `graph`, in
        `graph.edges` order; 1 marks a selected edge, 0 otherwise
    '''
    edges = list(graph.edges)
    edge_pos = {edge: k for k, edge in enumerate(edges)}
    undirected = not graph.is_directed()
    selection = np.zeros(len(edges))
    # Feed inputs on the policy's device; scores are brought back to CPU for numpy.
    device = next(policy.parameters()).device

    temp_graph = graph.copy()
    # Checking before the first iteration avoids scoring (and argmax over) zero edges.
    while temp_graph.number_of_edges() > 0:
        GCN_input = networkx2torch(temp_graph).to(device)
        with torch.no_grad():
            score = policy(GCN_input).detach().cpu().numpy()
        edge_chosen = GCN_input.edges[int(np.argmax(score))]

        k = edge_pos.get(edge_chosen)
        if k is None and undirected:
            # (u, v) and (v, u) are the same edge in an undirected graph.
            k = edge_pos.get((edge_chosen[1], edge_chosen[0]))
        if k is None:
            raise ValueError(f"edge {edge_chosen!r} not found in the original graph")
        selection[k] = 1

        # Remove both endpoints and all adjacent edges; remove_nodes_from ignores a
        # repeated endpoint (self-loop) instead of raising.
        temp_graph.remove_nodes_from(edge_chosen)

    return selection

In [5]:
def get_target(MWM, G):
    """Build the imitation-learning target: one 0/1 label per edge of ``G``.

    :param MWM: collection of matched edges, each an (u, v) pair (tuple or list)
    :param G: networkx graph whose ``G.edges`` order defines the label order
    :return: float np.ndarray of length ``len(G.edges)``; entry i is 1.0 iff the i-th
        edge of ``G.edges`` is in ``MWM``. For an undirected ``G`` an edge also matches
        in reversed orientation, since (u, v) and (v, u) denote the same edge.
    """
    # Normalise to a set of tuples: O(1) membership, and list-shaped pairs compare equal.
    matched = {tuple(edge) for edge in MWM}
    undirected = not G.is_directed()
    labels = [
        1.0 if edge in matched or (undirected and (edge[1], edge[0]) in matched) else 0.0
        for edge in G.edges
    ]
    return np.array(labels, dtype=float)

In [6]:
class GraphDataset(torch_geometric.data.Dataset):
    """
    Dataset that lazily reads one gzip-compressed pickle sample per index.

    Each sample is a dict with a networkx graph under 'graph' and its matched edges
    under 'MWM' (presumably a maximum weight matching, per the key name).

    Parameters
    ----------
    sample_files : list
        List containing the path to the sample files.
    """

    def __init__(self, sample_files):
        super().__init__(root=None, transform=None, pre_transform=None)
        self.sample_files = sample_files

    def len(self):
        """Number of samples, i.e. number of sample files."""
        return len(self.sample_files)

    def get(self, index):
        """
        Load the sample at position <index> and return it as GCN input.

        The returned Data comes from `networkx2torch` and additionally carries
        ``mwm``: a 0/1 numpy array with one label per edge, in ``graph.edges`` order.
        """
        sample_path = self.sample_files[index]
        # NOTE(security): pickle.load can execute arbitrary code; only read trusted files.
        with gzip.open(sample_path, 'rb') as handle:
            sample = pickle.load(handle)

        graph, matching = sample['graph'], sample['MWM']

        data = networkx2torch(graph)
        data.mwm = get_target(matching, graph)
        return data

In [7]:
def process(policy, data_loader, device, optimizer=None):
    """Run `policy` over every batch of `data_loader` and return the mean loss.

    With an `optimizer`, gradients are enabled and the policy is updated after each
    batch (training); without one, the pass runs with gradients disabled (evaluation).

    :param policy: model mapping a batch to one score in [0, 1] per edge, in
        graph-then-edge order
    :param data_loader: iterable of batches exposing `.to(device)`, `.edges` (one edge
        list per graph) and `.mwm` (one 0/1 label array per graph)
    :param device: device the batches are moved to
    :param optimizer: optional torch optimizer over the policy's parameters
    :return: binary cross-entropy averaged over batches, weighted by the actual number
        of graphs in each batch; NaN if the loader yields no batches
    """
    total_loss = 0.0
    n_graphs_processed = 0

    with torch.set_grad_enabled(optimizer is not None):
        for batch in data_loader:
            batch = batch.to(device)
            output = policy(batch)

            # Concatenate the per-graph labels in the same order as `output`, on the
            # device and dtype of `output` (the policy may return its scores elsewhere).
            n_graphs = len(batch.edges)
            target = torch.cat([
                torch.as_tensor(batch.mwm[i], dtype=output.dtype, device=output.device)
                for i in range(n_graphs)
            ])
            cross_entropy_loss = F.binary_cross_entropy(output, target)

            # if an optimizer is provided, update parameters
            if optimizer is not None:
                optimizer.zero_grad()
                cross_entropy_loss.backward()
                optimizer.step()

            # Weight by the real graph count: the last batch may be smaller than
            # data_loader.batch_size.
            total_loss += cross_entropy_loss.item() * n_graphs
            n_graphs_processed += n_graphs

    if n_graphs_processed == 0:
        return float('nan')
    return total_loss / n_graphs_processed


## Toy Example

In [18]:
# Root of the behaviour-cloning data; override with the MWM_BC_DIR environment variable.
DIR = os.environ.get('MWM_BC_DIR', '/home/zwt/MWM/bc')
train_files_path = os.path.join(DIR, 'samples/train/sample_*.pkl')
valid_files_path = os.path.join(DIR, 'samples/valid/sample_*.pkl')
trained_model_dir = os.path.join(DIR, 'trained_models')

# glob's result order depends on the file system; sort so the same subset is used on
# every run and every machine.
train_files = sorted(glob.glob(train_files_path))[:100]
valid_files = sorted(glob.glob(valid_files_path))[:20]
# Fall back to CPU when no GPU is present instead of failing later in .to(device).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
EPOCH_SAMPLE_NUM = 100
batch_size = 5
valid_batch_size = 5

In [19]:
# Seed torch's global RNG so the policy's weight initialisation is reproducible.
torch.manual_seed(0)

valid_data = GraphDataset(valid_files)
valid_loader = torch_geometric.data.DataLoader(valid_data, valid_batch_size, shuffle=False)

LEARNING_RATE = 0.0002
policy = GraphConvolution().to(device)
optimizer = torch.optim.Adam(policy.parameters(), lr=LEARNING_RATE)

In [21]:
def node2edge(data):
    """
    Build edge features by concatenating the features of each edge's two endpoints.

    :param data: torch_geometric.data.Data (one graph) or a collated batch of them
        data.x: node feature matrix; rows ordered graph by graph and, within a graph,
            in the order of that graph's node list
        data.edges: edge list (one graph) or list of per-graph edge lists (batch)
        data.nodes: node list (one graph) or list of per-graph node lists (batch)
        data.batch: present only on a batch; used to tell the two layouts apart
    :return:
        output: float32 tensor [total_edges, 2 * num_node_features] on data.x's device;
            row k is concat(x[u], x[v]) for the k-th edge (u, v), graph by graph
    """
    # A single Data has flat edge/node lists; wrap them so both layouts share one path.
    if getattr(data, 'batch', None) is None:
        edges_per_graph, nodes_per_graph = [data.edges], [data.nodes]
    else:
        edges_per_graph, nodes_per_graph = data.edges, data.nodes

    x = data.x
    src, dst = [], []
    offset = 0  # rows of data.x that belong to earlier graphs
    for edges, nodes in zip(edges_per_graph, nodes_per_graph):
        # dict lookup instead of list.index: O(1) per endpoint
        row_of = {node: offset + k for k, node in enumerate(nodes)}
        for edge in edges:
            src.append(row_of[edge[0]])
            dst.append(row_of[edge[1]])
        offset += len(nodes)

    if not src:
        return torch.zeros(0, 2 * x.shape[1], device=x.device)

    src = torch.tensor(src, dtype=torch.long, device=x.device)
    dst = torch.tensor(dst, dtype=torch.long, device=x.device)
    return torch.cat((x[src], x[dst]), dim=1).float()

In [22]:
rng = np.random.RandomState(0)
# Draw (with replacement) the largest multiple of batch_size that fits in
# EPOCH_SAMPLE_NUM, so every batch of the epoch is full.
epoch_train_files = rng.choice(
    train_files,
    (EPOCH_SAMPLE_NUM // batch_size) * batch_size,
    replace=True,
)
train_data = GraphDataset(epoch_train_files)
train_loader = torch_geometric.data.DataLoader(train_data, batch_size, shuffle=True)

# Sanity check: shape of the edge-feature matrix produced for each batch.
for batch in train_loader:
    batch = batch.to(device)
    output = node2edge(batch)
    print(output.shape)


torch.Size([328, 4])
torch.Size([322, 4])
torch.Size([323, 4])
torch.Size([340, 4])
torch.Size([306, 4])
torch.Size([343, 4])
torch.Size([305, 4])
torch.Size([323, 4])
torch.Size([337, 4])
torch.Size([309, 4])
torch.Size([318, 4])
torch.Size([305, 4])
torch.Size([337, 4])
torch.Size([352, 4])
torch.Size([337, 4])
torch.Size([314, 4])
torch.Size([333, 4])
torch.Size([301, 4])
torch.Size([309, 4])
torch.Size([340, 4])


In [None]:
# rng = np.random.RandomState(0)

# for epoch in range(10000):
#     # train
#     epoch_train_files = rng.choice(train_files, int(np.floor(EPOCH_SAMPLE_NUM / batch_size)) * batch_size,
#                                    replace=True)
#     train_data = GraphDataset(epoch_train_files)
#     train_loader = torch_geometric.data.DataLoader(train_data, batch_size, shuffle=True)
#     train_loss = process(policy, train_loader, device, optimizer)

#     # validate
#     valid_loss = process(policy, valid_loader, device, None)
    
#     if epoch%100==0:
#         print(f'Epoch: {epoch}, Train Loss: {train_loss:0.4f}, Valid Loss: {valid_loss:0.4f}. ')