In [1]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from transformers import BertTokenizerFast, BertModel
import torch_scatter
import inspect
from transformers import get_cosine_schedule_with_warmup
from torch.cuda import amp
import ast
import time

In [2]:
#!pip install torch-scatter

## Helpers for creating basic learnable parameters (weight matrices or basis vectors)

In [3]:
def get_param(shape):
    """Create a trainable parameter of the given shape with Xavier-normal values.

    Args:
        shape: iterable of ints with the dimensions of the parameter.

    Returns:
        ``nn.Parameter`` filled in place by ``nn.init.xavier_normal_``.
    """
    dims = tuple(shape)
    storage = torch.Tensor(*dims)
    # Fill the raw tensor first, then wrap it; the parameter shares this storage.
    nn.init.xavier_normal_(storage)
    return nn.Parameter(storage)
def com_mult(a, b):
    """Element-wise complex product of two tensors in (real, imag) layout.

    The last dimension of both inputs stores the real part at index 0 and the
    imaginary part at index 1. Indexing uses ``...`` so that any number of
    leading dimensions is supported; the previous ``[:, k]`` form only selected
    the real/imag pair correctly for 2-D inputs.

    Args:
        a: tensor whose last dimension holds (real, imag).
        b: tensor broadcast-compatible with ``a`` in the same layout.

    Returns:
        Tensor with the product's real and imaginary parts stacked on a new
        last dimension.
    """
    r1, i1 = a[..., 0], a[..., 1]
    r2, i2 = b[..., 0], b[..., 1]
    return torch.stack([r1 * r2 - i1 * i2, r1 * i2 + i1 * r2], dim = -1)

def conj(a):
    """Complex conjugate of a tensor in (real, imag) layout.

    The last dimension stores the real part at index 0 and the imaginary part
    at index 1. A new tensor is returned; the argument is left untouched (the
    previous version negated the caller's tensor in place and used ``[:, 1]``,
    which addresses the wrong axis for inputs with more than two dimensions).

    Args:
        a: tensor whose last dimension holds (real, imag).

    Returns:
        Tensor of the same shape with the imaginary part negated.
    """
    return torch.stack([a[..., 0], -a[..., 1]], dim = -1)
def ccorr(a, b):
    """Circular cross-correlation of ``a`` and ``b`` along the last dimension.

    Computed in the frequency domain as ``irfft(conj(rfft(a)) * rfft(b))``, so
    ``out[..., k] = sum_i a[..., i] * b[..., (i + k) % n]`` with ``n`` the size
    of the last dimension.

    Uses the ``torch.fft`` module; the function-style ``torch.rfft`` /
    ``torch.irfft`` API that was used here before is no longer available in
    current PyTorch releases.

    Args:
        a: real tensor.
        b: real tensor whose last dimension matches ``a``.

    Returns:
        Real tensor with the same last-dimension size as ``a``.
    """
    n = a.shape[-1]
    spec_a = torch.fft.rfft(a, dim=-1)
    spec_b = torch.fft.rfft(b, dim=-1)
    # ``n`` must be passed explicitly, otherwise odd-length signals come back
    # one element short.
    return torch.fft.irfft(torch.conj(spec_a) * spec_b, n=n, dim=-1)

In [4]:
class CompGcnBasis(nn.Module):
    """CompGCN layer whose relation embeddings are mixed from learnable basis vectors.

    Each relation (and its inverse) is a learned linear combination of
    ``num_basis_vector`` shared basis vectors; one extra learnable row is
    appended for the self-loop relation. Messages are aggregated separately
    over the "in" edges, the "out" edges and the self-loops, each with its own
    weight matrix, and the three results are averaged.

    NOTE(review): ``forward`` treats the first half of the edge columns as "in"
    edges and the second half as "out" edges; presumably the caller concatenates
    the original edges with their inverses in that order -- confirm.
    """
    # Axis of the node-feature matrix that enumerates nodes.
    nodes_dim = 0
    # Row positions of head / tail inside ``edge_index`` (not read in this class).
    head_dim = 0
    tail_dim = 1
    def __init__(self, in_channels, out_channels, num_relations, num_basis_vector,act = torch.tanh,cache = True,dropout = 0.2):
        """
        Args:
            in_channels: size of the incoming node / relation features.
            out_channels: size of the produced node / relation features.
            num_relations: number of relation types (inverses are added internally,
                giving ``2 * num_relations`` rows).
            num_basis_vector: number of shared basis vectors.
            act: activation applied to the aggregated node features.
            cache: when True, the edge split, self-loop indices and edge norms
                computed on the first ``forward`` call are reused afterwards.
            dropout: dropout probability applied to the in/out messages.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_basis_vector = num_basis_vector
        self.act = act
        self.device = None
        self.cache = cache

        # Learnable basis vectors, shape (num_basis_vector, in_channels).
        self.basis_vector = get_param((num_basis_vector, in_channels))
        # Mixing coefficients per relation (inverses included),
        # shape (2 * num_relations, num_basis_vector).
        self.rel_weight = get_param((num_relations*2, self.num_basis_vector))
        # Projects relation embeddings from in_channels to out_channels.
        self.weight_rel = get_param((in_channels,out_channels))
        # Embedding of the extra self-loop relation.
        self.loop_rel = get_param((1,in_channels))
        # One weight matrix per message direction: in, out, self-loop.
        self.w_in = get_param((in_channels,out_channels))
        self.w_out = get_param((in_channels,out_channels))
        self.w_loop = get_param((in_channels,out_channels))

        # Filled on the first forward pass (and kept when ``cache`` is True).
        self.in_norm, self.out_norm = None, None
        self.in_index, self.out_index = None, None
        self.in_type, self.out_type = None, None
        self.loop_index, self.loop_type =None, None

        self.drop = nn.Dropout(dropout)
        # Not applied anywhere in this class; kept so the module's parameters
        # and state_dict keys stay unchanged.
        self.bn = nn.BatchNorm1d(out_channels)
    def relation_transform(self, entity_embedding, relation_embedding,type_):
        '''
        Compose entity and relation embeddings with a parameter-free operator.

        type_ selects the operator: "mul" (element-wise product), "sub"
        (subtraction) or "corr" (circular correlation via ``ccorr``).
        Any other value fails the assertion.
        '''
        assert type_ in ["mul","sub","corr"], "not implemented now"
        if type_ == "mul":
            out = entity_embedding*relation_embedding
        elif type_ == "sub":
            out = entity_embedding - relation_embedding
        else:
            out = ccorr(entity_embedding,relation_embedding)
        return out

    def normalization(self, edge_index, num_entity):
        '''
        Per-edge normalisation weight ``deg^-1/2[head] * deg^-1/2[tail]``.

        The degree of a node is the number of edges in which it appears in the
        head row of ``edge_index``; that same degree vector is used for both
        endpoints. Nodes with degree zero get weight 0 instead of inf.
        '''
        head, tail = edge_index
        edge_weight = torch.ones_like(head).float()
        degree = torch_scatter.scatter_add(edge_weight,head,dim_size=num_entity,dim = self.nodes_dim)
        degree_inv = degree.pow(-0.5)
        # 0 ** -0.5 is inf; zero it so it cannot turn into NaN downstream.
        degree_inv[degree_inv == float("inf")] = 0
        norm = degree_inv[head] * edge_weight * degree_inv[tail]
        return norm
    def scatter_function(self,type_, src, index, dim_size = None):
        '''
        Reduce the rows of ``src`` that share the same ``index`` value.

        type_ is "sum", "mean" or "max" (case-insensitive); dim_size is the
        number of output rows.
        '''
        reduce_name = type_.lower()
        assert reduce_name in ["sum","mean","max"]
        # Pass the normalised name on, so the value that was validated is the
        # value that is used.
        return torch_scatter.scatter(src, index, dim=0,out=None,dim_size = dim_size, reduce= reduce_name)

    def propogating_message(self, method, node_features,edge_index,edge_type, rel_embedding, edge_norm,mode,type_):
        '''
        Build and aggregate the messages for one edge direction.

        method: scatter reduction ("sum", "mean" or "max").
        mode: "in", "out" or "loop"; selects the weight matrix ``w_<mode>``.
        type_: composition operator forwarded to ``relation_transform``.
        edge_norm: optional per-edge weight; skipped when None.
        Returns one aggregated row per node.
        '''
        assert method in ["sum", "mean", "max"]
        assert mode in ["in","out","loop"]
        size = node_features.shape[0]
        coresponding_weight = getattr(self, 'w_{}'.format(mode))
        # Relation embedding of every edge.
        relation_embedding = torch.index_select(rel_embedding,dim = 0, index = edge_type)
        # Messages originate at the tail node (row 1) ...
        node_features = node_features[edge_index[1]]
        out = self.relation_transform(node_features, relation_embedding,type_)
        out = torch.matmul(out,coresponding_weight)
        out = out if edge_norm is None else out * edge_norm.view(-1, 1)
        # ... and are accumulated at the head node (row 0).
        out = self.scatter_function(method,out,edge_index[0],  size)
        return out
    def forward(self, nodes_features, edge_index,edge_type):
        '''
        Run one CompGCN step.

        nodes_features: node feature matrix, one row per node.
        edge_index: tensor with a head row and a tail row; the first half of
            its columns is used as "in" edges, the second half as "out" edges.
        edge_type: relation id per edge column, split the same way.

        Returns (activated node features, projected relation embeddings). The
        relation output includes the appended self-loop row.
        '''
        with amp.autocast():
            if self.device is None:
                self.device = edge_index.device
            # Relation embeddings as mixtures of the basis vectors.
            relation_embedding = torch.mm(self.rel_weight,self.basis_vector)
            # Append the self-loop relation as the last row.
            relation_embedding = torch.cat([relation_embedding,self.loop_rel],dim = 0)
            num_edges = edge_index.shape[1]//2
            num_nodes = nodes_features.shape[self.nodes_dim]
            # ``is None`` rather than ``== None``: in_norm becomes a tensor
            # after the first call and must not be compared element-wise.
            if not self.cache or self.in_norm is None:
                self.in_index, self.out_index = edge_index[:,:num_edges], edge_index[:,num_edges:]
                self.in_type, self.out_type = edge_type[:num_edges], edge_type[num_edges:]
                # Self-loops: every node points to itself with the last relation id.
                node_ids = torch.arange(num_nodes, device = self.device)
                self.loop_index = torch.stack([node_ids, node_ids])
                self.loop_type = torch.full((num_nodes,), relation_embedding.shape[0]-1, dtype = torch.long, device = self.device)
                self.in_norm = self.normalization(self.in_index, num_nodes)
                self.out_norm = self.normalization(self.out_index, num_nodes)
            in_res = self.propogating_message('sum',nodes_features,self.in_index,self.in_type, relation_embedding,self.in_norm,"in","sub")
            loop_res = self.propogating_message('sum',nodes_features,self.loop_index,self.loop_type, relation_embedding,None,"loop","sub")
            out_res = self.propogating_message('sum',nodes_features,self.out_index,self.out_type, relation_embedding,self.out_norm,"out","sub")
            # Equal-weight average of the three directions; dropout is applied
            # to the in/out messages only. Per the original author's note this
            # mirrors the reference implementation.
            out = self.drop(in_res)*(1/3) + self.drop(out_res)*(1/3) + loop_res*(1/3)
            # Project the relation embeddings for the next layer.
            out_2 = torch.matmul(relation_embedding,self.weight_rel)
            return self.act(out),out_2

## Test layer

## check some stats


In [5]:
class CompGcn_non_first_layer(nn.Module):
    """CompGCN layer that receives its relation embeddings from the previous layer.

    Unlike ``CompGcnBasis`` there are no basis vectors: ``forward`` takes the
    relation embeddings as an argument, appends a learnable self-loop relation,
    aggregates messages over "in" edges, "out" edges and self-loops, and
    returns the averaged node features plus the projected relation embeddings
    (without the self-loop row).

    NOTE(review): ``forward`` treats the first half of the edge columns as "in"
    edges and the second half as "out" edges; presumably the caller concatenates
    the original edges with their inverses in that order -- confirm.
    """
    # Axis of the node-feature matrix that enumerates nodes.
    nodes_dim = 0
    # Row positions of head / tail inside ``edge_index`` (not read in this class).
    head_dim = 0
    tail_dim = 1
    def __init__(self, in_channels, out_channels, num_relations,act = torch.tanh,dropout = 0.2):
        """
        Args:
            in_channels: size of the incoming node / relation features.
            out_channels: size of the produced node / relation features.
            num_relations: number of relation types (stored, not otherwise used here).
            act: activation applied to the aggregated node features.
            dropout: dropout probability applied to the in/out messages.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.act = act
        self.device = None

        # Projects relation embeddings from in_channels to out_channels.
        self.weight_rel = get_param((in_channels,out_channels))
        # Embedding of the extra self-loop relation.
        self.loop_rel = get_param((1,in_channels))
        # One weight matrix per message direction: in, out, self-loop.
        self.w_in = get_param((in_channels,out_channels))
        self.w_out = get_param((in_channels,out_channels))
        self.w_loop = get_param((in_channels,out_channels))
        self.drop = nn.Dropout(dropout)
        # Not applied anywhere in this class; kept so the module's parameters
        # and state_dict keys stay unchanged.
        self.bn = nn.BatchNorm1d(out_channels)
    def relation_transform(self, entity_embedding, relation_embedding,type_):
        '''
        Compose entity and relation embeddings with a parameter-free operator.

        type_ selects the operator: "mul" (element-wise product), "sub"
        (subtraction) or "corr" (circular correlation via ``ccorr``).
        Any other value fails the assertion.
        '''
        assert type_ in ["mul","sub","corr"], "not implemented now"
        if type_ == "mul":
            out = entity_embedding*relation_embedding
        elif type_ == "sub":
            out = entity_embedding - relation_embedding
        else:
            out = ccorr(entity_embedding,relation_embedding)
        return out

    def normalization(self, edge_index, num_entity):
        '''
        Per-edge normalisation weight ``deg^-1/2[head] * deg^-1/2[tail]``.

        The degree of a node is the number of edges in which it appears in the
        head row of ``edge_index``; that same degree vector is used for both
        endpoints. Nodes with degree zero get weight 0 instead of inf.
        '''
        head, tail = edge_index
        edge_weight = torch.ones_like(head).float()
        degree = torch_scatter.scatter_add(edge_weight,head,dim_size=num_entity,dim = self.nodes_dim)
        degree_inv = degree.pow(-0.5)
        # 0 ** -0.5 is inf; zero it so it cannot turn into NaN downstream.
        degree_inv[degree_inv == float("inf")] = 0
        norm = degree_inv[head] * edge_weight * degree_inv[tail]
        return norm
    def scatter_function(self,type_, src, index, dim_size = None):
        '''
        Reduce the rows of ``src`` that share the same ``index`` value.

        type_ is "sum", "mean" or "max" (case-insensitive); dim_size is the
        number of output rows.
        '''
        reduce_name = type_.lower()
        assert reduce_name in ["sum","mean","max"]
        # Pass the normalised name on, so the value that was validated is the
        # value that is used.
        return torch_scatter.scatter(src, index, dim=0,out=None,dim_size = dim_size, reduce= reduce_name)

    def propogating_message(self, method, node_features,edge_index,edge_type, rel_embedding, edge_norm,mode,type_):
        '''
        Build and aggregate the messages for one edge direction.

        method: scatter reduction ("sum", "mean" or "max").
        mode: "in", "out" or "loop"; selects the weight matrix ``w_<mode>``.
        type_: composition operator forwarded to ``relation_transform``.
        edge_norm: optional per-edge weight; skipped when None.
        Returns one aggregated row per node.
        '''
        assert method in ["sum", "mean", "max"]
        assert mode in ["in","out","loop"]
        size = node_features.shape[0]
        coresponding_weight = getattr(self, 'w_{}'.format(mode))
        # Relation embedding of every edge.
        relation_embedding = torch.index_select(rel_embedding,dim = 0, index = edge_type)
        # Messages originate at the tail node (row 1) ...
        node_features = node_features[edge_index[1]]
        out = self.relation_transform(node_features, relation_embedding,type_)
        out = torch.matmul(out,coresponding_weight)
        out = out if edge_norm is None else out * edge_norm.view(-1, 1)
        # ... and are accumulated at the head node (row 0).
        out = self.scatter_function(method,out,edge_index[0],  size)
        return out
    def forward(self, nodes_features, edge_index,edge_type,relation_embedding):
        '''
        Run one CompGCN step with externally supplied relation embeddings.

        nodes_features: node feature matrix, one row per node.
        edge_index: tensor with a head row and a tail row; the first half of
            its columns is used as "in" edges, the second half as "out" edges.
        edge_type: relation id per edge column, split the same way.
        relation_embedding: relation embeddings from the previous layer; a
            self-loop row is appended internally.

        Returns (activated node features, projected relation embeddings with
        the appended self-loop row removed again).
        '''
        with amp.autocast():
            if self.device is None:
                self.device = edge_index.device
            # Append the self-loop relation as the last row.
            relation_embedding = torch.cat([relation_embedding,self.loop_rel],dim = 0)
            num_edges = edge_index.shape[1]//2
            num_nodes = nodes_features.shape[self.nodes_dim]
            self.in_index, self.out_index = edge_index[:,:num_edges], edge_index[:,num_edges:]
            self.in_type, self.out_type = edge_type[:num_edges], edge_type[num_edges:]
            # Self-loops: every node points to itself with the last relation id.
            node_ids = torch.arange(num_nodes, device = self.device)
            self.loop_index = torch.stack([node_ids, node_ids])
            self.loop_type = torch.full((num_nodes,), relation_embedding.shape[0]-1, dtype = torch.long, device = self.device)
            self.in_norm = self.normalization(self.in_index, num_nodes)
            self.out_norm = self.normalization(self.out_index, num_nodes)
            in_res = self.propogating_message('sum',nodes_features,self.in_index,self.in_type, relation_embedding,self.in_norm,"in","sub")
            loop_res = self.propogating_message('sum',nodes_features,self.loop_index,self.loop_type, relation_embedding,None,"loop","sub")
            out_res = self.propogating_message('sum',nodes_features,self.out_index,self.out_type, relation_embedding,self.out_norm,"out","sub")
            # Equal-weight average of the three directions; dropout is applied
            # to the in/out messages only. Per the original author's note this
            # mirrors the reference implementation.
            out = self.drop(in_res)*(1/3) + self.drop(out_res)*(1/3) + loop_res*(1/3)
            # Project the relation embeddings for the next layer.
            out_2 = torch.matmul(relation_embedding,self.weight_rel)
            # Drop the self-loop row that was appended above.
            return self.act(out),out_2[:-1]

## Test layer

## CompGCN total + score function

In [6]:
class CompGcn_total(nn.Module):
    """Stack of CompGCN layers that all run over one fixed graph.

    The graph (``edge_idx`` / ``edge_type``) is stored at construction time and
    passed to every layer in ``forward``. When ``basis`` is True the first layer
    is a ``CompGcnBasis`` (which creates the relation embeddings itself); every
    other layer is a ``CompGcn_non_first_layer``.
    """
    def __init__(self, channel_ls ,num_relation, num_basis_vector, edge_idx, edge_type,num_layers = 2, basis = False):
        '''
        Notice that in preprocessing, we assume that the node number will not be changed on the graph, only change relation.
        input params:
            1. channel_ls: channel sizes, length ``num_layers + 1``; layer ``i`` maps
               ``channel_ls[i]`` to ``channel_ls[i + 1]``
            2. num_relation: number of relation types
            3. num_basis_vector: number of basis vectors for the basis layer
               (only used when ``basis`` is True)
            4. edge_idx: edge index tensor of the graph
            5. edge_type: relation id per edge
            6. num_layers: number of CompGCN layers
            7. basis: whether the first layer is a basis layer
        '''
        assert len(channel_ls) == num_layers + 1 , "channel number should be layer numbers + 1 , got length "+str(len(channel_ls))+" with number of layers "+str(num_layers)
        super(CompGcn_total, self).__init__()
        # Plain attributes (not buffers): they do not follow ``module.to(...)``
        # and are moved to the working device inside ``forward``.
        self.edge_idx = edge_idx
        self.edge_type = edge_type
        self.basis = basis
        # Used only as a named container; the layers take several arguments, so
        # they are called one by one in ``forward``.
        self.GCN_block = nn.Sequential()
        for i in range(num_layers):
            if basis and i == 0:
                self.GCN_block.add_module("Basis_conv_layer",  CompGcnBasis(in_channels = channel_ls[0], out_channels= channel_ls[1],
                                                                            num_relations=num_relation,
                                                                            num_basis_vector= num_basis_vector))
            else:
                self.GCN_block.add_module("Conv_layer"+str(i),CompGcn_non_first_layer(channel_ls[i], channel_ls[i+1], num_relation))
    def forward(self, init_features = None, node_embd = None,rel_embd = None,device = None):
        """Run all layers in order.

        Args:
            init_features: node features for the basis layer (used when
                ``basis`` is True).
            node_embd: node embeddings fed to the first layer when ``basis`` is
                False.
            rel_embd: relation embeddings fed to the first layer when ``basis``
                is False.
            device: device the stored edge tensors are moved to. When None, the
                device of the supplied feature tensor is used.

        Returns:
            ``(node_embd, rel_embd)`` produced by the last layer.
        """
        with amp.autocast():
            if device is None:
                reference = init_features if init_features is not None else node_embd
                if reference is not None:
                    device = reference.device
            # Move the graph once instead of once per layer.
            edge_idx, edge_type = self.edge_idx, self.edge_type
            if device is not None:
                edge_idx = edge_idx.to(device)
                edge_type = edge_type.to(device)
            for i, blk in enumerate(self.GCN_block):
                if self.basis and i == 0:
                    node_embd, rel_embd = blk(init_features, edge_idx, edge_type)
                else:
                    node_embd, rel_embd = blk(node_embd, edge_idx, edge_type, rel_embd)
            return node_embd, rel_embd

## Test layer

In [7]:
class CompGcn_with_temporal(nn.Module):
    """CompGCN over a sequence of graph snapshots, combined with a text encoder.

    One ``CompGcn_total`` block is built per time step; the first one uses a
    basis layer and starts from the learnable ``node_features``, each later one
    consumes the previous block's node and relation embeddings. The final node
    and relation embeddings are indexed per sample, scored with ``Score_func``,
    concatenated with the text encoder's pooled output and classified by a
    linear layer.
    """
    def __init__(self, conv_dim, num_relation, num_entity,node_dim, num_hiddens ,num_basis_vector, edge_idx, edge_type, model,num_class,time_stamp = 2,
                 num_layers = 2,score_func="TransE"):
        '''
        Notice that in preprocessing, we assume that the node number will not be changed on the graph, only change relation.
        input params:
            1. conv_dim: one channel list per time step, e.g.
               [(channel 1, channel 2, channel 3), (channel 3, channel 4, channel 5), ...]
            2. num_relation: number of relation types (same for every graph)
            3. num_entity: number of entities (same for every graph)
            4. node_dim: dimension of the learnable node features; must equal conv_dim[0][0]
            5. num_hiddens: width of the GRU-style recurrent state
            6. num_basis_vector: number of basis vectors of the first block
            7. edge_idx: list with one edge index tensor per time step
            8. edge_type: list with one edge type tensor per time step
            9. model: text encoder; called with input_ids / attention_mask /
               token_type_ids and read through its "pooler_output"
            10. num_class: number of output classes
            11. time_stamp: number of time steps
            12. num_layers: CompGCN layers per time step
            13. score_func: name of the score function handed to ``Score_func``
        '''
        assert len(conv_dim) == time_stamp, "time stamp length should be the same as number of convloution dimension list!, got time stamp "+str(time_stamp)+" with conv_dim "+str(len(conv_dim))
        assert len(edge_idx) == len(edge_type) == len(conv_dim) == time_stamp, "Number of KG mismatched with time stamp!"
        super(CompGcn_with_temporal,self).__init__()
        self.model = model
        self.num_relation = num_relation
        # Kept so ``init_state`` can size the recurrent state without relying
        # on a global variable.
        self.num_hiddens = num_hiddens
        self.node_features = get_param(shape= (num_entity,node_dim))
        assert node_dim == conv_dim[0][0]
        self.temporal_blk = nn.Sequential()
        self.drop_bert = nn.Dropout(0.2)
        for i in range(time_stamp):
            if i == 0:
                self.temporal_blk.add_module("Temporal block Basis" , CompGcn_total(conv_dim[i], num_relation, num_basis_vector,
                                                                                 edge_idx[i], edge_type[i], num_layers,True))
            else:
                self.temporal_blk.add_module("Temporal block" + str(i), CompGcn_total(conv_dim[i], num_relation, num_basis_vector,
                                                                                 edge_idx[i], edge_type[i], num_layers,False))
        # Width of the encoder's pooled output; falls back to 768 (the value
        # that used to be hard-coded) when the encoder exposes no
        # ``config.hidden_size``.
        text_dim = getattr(getattr(model, "config", None), "hidden_size", 768)
        self.ln1 = nn.Linear(text_dim + conv_dim[-1][-1], num_class)
        self.dropout_node = nn.Dropout(0.2)
        self.dropout_rel = nn.Dropout(0.2)
        self.ln1 = self.func_init(self.ln1)
        self.score = score_func
        # GRU-style gates over the relation embeddings.
        # NOTE(review): the input width is conv_dim[0][-1] for every time step,
        # so all blocks are presumably expected to share the same output width.
        self.W_xr = nn.Linear(conv_dim[0][-1],num_hiddens)
        self.W_xz = nn.Linear(conv_dim[0][-1],num_hiddens)
        self.W_xh = nn.Linear(conv_dim[0][-1],num_hiddens)
        self.W_hr = nn.Linear(num_hiddens,num_hiddens, bias = True)
        self.W_hz = nn.Linear(num_hiddens,num_hiddens, bias = True)
        self.W_hh = nn.Linear(num_hiddens,num_hiddens, bias = True)
        self.act_update = nn.Sigmoid()
        self.act_hidden = nn.Tanh()
        self.act_reset = nn.Sigmoid()
    def func_init(self,m):
        """Xavier-uniform initialise the weight of ``m`` if it is an ``nn.Linear``; return ``m``."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
        return m
    def init_state(self, device):
        """Zero recurrent state of shape ``(2 * num_relation + 1, num_hiddens)``.

        The row count matches the relation embeddings produced by the blocks
        (relations, inverse relations and the self-loop); the width is
        ``num_hiddens`` because the state is fed through ``W_hr`` / ``W_hz`` /
        ``W_hh``.
        """
        return torch.zeros(self.num_relation * 2 + 1, self.num_hiddens, device = device)
    def forward(self, input_ids,segment_ids, attention_mask  ,head_index, tail_index, rel_index,state):
        '''
        Node index and rel index are corresponding information in a batch for bert part, we only care about the node, edge relation in a batch.
        Since the embedding is tail - relation to head, the source will be tail, target will be head

        input_ids / segment_ids / attention_mask: inputs of the text encoder.
        head_index / tail_index: row indices into the node embeddings.
        rel_index: row indices into the relation embeddings.
        state: recurrent state, e.g. from ``init_state``.

        Returns the output of the final linear layer.
        '''
        with amp.autocast():
            device = self.node_features.device
            bert_out = self.model(input_ids = input_ids, attention_mask = attention_mask, token_type_ids = segment_ids)["pooler_output"]
            bert_out = self.drop_bert(bert_out)
            for i, blk in enumerate(self.temporal_blk):
                if i == 0:
                    node_embd, rel_embd = blk(self.node_features, device = device)
                else:
                    node_embd, rel_embd = blk(node_embd = node_embd, rel_embd = rel_embd,device = device)
                node_embd, rel_embd = self.dropout_node(node_embd), self.dropout_rel(rel_embd)
                # GRU-style update driven by the relation embeddings.
                # NOTE(review): ``state`` is updated here but is neither
                # returned nor used for the score below -- confirm whether it
                # was meant to replace ``rel_embd``.
                Z = self.act_update(self.W_xz(rel_embd) + self.W_hz(state))
                R = self.act_reset(self.W_xr(rel_embd) + self.W_hr(state))
                H_candidate = self.act_hidden(self.W_xh(rel_embd) + self.W_hh(R * state))
                state = Z * state + (1 - Z) * H_candidate
            # Pick the rows of this batch; each has shape (len(index), hidden_out).
            hidden_node_state = node_embd[tail_index,:]
            hidden_rel_state  = rel_embd[rel_index,:]
            hidden_target_state = node_embd[head_index,:]
            # The tail-side embedding acts as the source ("head" of the score
            # function) and the head-side embedding as its target.
            head, rel, tail      = (
                                        hidden_node_state,
                                        hidden_rel_state,
                                        hidden_target_state
                                   )
            score                = Score_func(head, rel, tail, func_type=self.score)
            score                = score.forward_score()
            score = torch.cat([score, bert_out], dim = 1)
            score = self.ln1(score)
        return score

In [8]:
class Score_func(nn.Module):
    """
        Func:
            Score functions over (subject, relation, object) embeddings.
            "transe" and "distmult" are usable as written. The "transh" and
            "conve" branches depend on configuration that this constructor does
            not receive (see the NOTE(review) comments) and cannot be
            constructed until that is supplied.

        Args:
            sub_emb: the head embedding (subject)
            rel_emb: the relation embedding (relation)
            obj_emb: the tail embedding (object)
            kernel_size: a tuple. Only when the score function is ConvE, we need it to do the
                        convolutional computation. i.e. kernel_size = (hight, width)
            func_type: a string indicating the score function we wanna use; matched
                        case-insensitively. Default "transE".
            conv_drop: a list containing floats, indicating the dropout rate we will use in the ConvE.
                        One value is repeated three times; two values are expanded to
                        (first, first, second).
            conv_bias: whether the ConvE convolution uses a bias. Default to be True
            gamma: a float - margin hyperparameter. Stored on the instance; not read by
                    any score function in this class.

        Raises:
            ValueError: if func_type is not one of transe / transh / distmult / conve.
    """
    SUPPORTED_FUNCS = ("transe", "transh", "distmult", "conve")

    def __init__(self, sub_emb, rel_emb, obj_emb, func_type="transE",
                 kernel_size = None, conv_drop=(0.2, 0.3, 0.2),
                 conv_bias=True, gamma=40.0):
        # we can't use self.__class__, because it may cause a recursive problem
        super(Score_func, self).__init__()

        self.func_type = func_type.lower()
        # Fail here instead of with an unbound local inside forward_score.
        if self.func_type not in self.SUPPORTED_FUNCS:
            raise ValueError("unknown score function '" + str(func_type) + "', expected one of " + ", ".join(self.SUPPORTED_FUNCS))
        self.gamma     = gamma
        self.sub_emb   = sub_emb
        self.rel_emb   = rel_emb
        self.obj_emb   = obj_emb

        if self.func_type == "distmult":
            # One bias per candidate object (row of obj_emb), starting at zero.
            self.bias = nn.Parameter(torch.zeros(obj_emb.shape[0], device=obj_emb.device))

        if self.func_type == "transh":
            # NOTE(review): relation_num, entity_num and self.dimension are not
            # defined anywhere in the visible code, so this branch cannot be
            # constructed yet.
            self.relation_norm_embedding  = torch.nn.Embedding(num_embeddings=relation_num,
                                                              embedding_dim=self.dimension)
            self.relation_hyper_embedding = torch.nn.Embedding(num_embeddings=relation_num,
                                                               embedding_dim=self.dimension)
            self.entity_embedding         = torch.nn.Embedding(num_embeddings=entity_num,
                                                               embedding_dim=self.dimension)

        if self.func_type == "conve":
            assert not kernel_size is None  # to ensure that the kernel size is defined

            if not conv_drop:
                # NOTE(review): `dropout` is not defined in the visible code;
                # supply a fallback rate before using this path.
                self.conv_drop = [dropout, dropout, dropout]
            else:
                l = len(conv_drop)
                assert l <= 3  # ensure the length of conv_drop smaller equal to 3
                if l == 1:
                    self.conv_drop = [conv_drop[0], conv_drop[0], conv_drop[0]]
                elif l == 2:
                    self.conv_drop = [conv_drop[0], conv_drop[0], conv_drop[1]]
                else:
                    self.conv_drop = conv_drop


            # NOTE(review): kernel_size is used below both as an int (layer
            # sizes) and as a tuple (indexing), and self.out_channels is never
            # set -- the ConvE configuration still needs to be defined.
            self.kernel_size    = kernel_size
            # Flag for the convolution only; kept apart from self.bias, which
            # forward_score adds to the scores.
            self.conv_bias      = conv_bias

            self.bn0            = torch.nn.BatchNorm2d(1)
            self.bn1            = torch.nn.BatchNorm2d(self.out_channels)
            self.bn2            = torch.nn.BatchNorm1d(self.kernel_size)

            self.hidden_drop    = torch.nn.Dropout(self.conv_drop[0])
            self.hidden_drop2   = torch.nn.Dropout(self.conv_drop[1])
            self.feature_drop   = torch.nn.Dropout(self.conv_drop[2])
            self.m_conv1        = torch.nn.Conv2d(1, out_channels=self.out_channels,
                                                  kernel_size=(self.kernel_size, self.kernel_size),
                                                  stride=1, padding=0, bias=self.conv_bias)

            flat_sz_h           = int(2*self.kernel_size[1]) - self.kernel_size + 1
            flat_sz_w           = self.kernel_size[0] - self.kernel_size + 1
            self.flat_sz        = flat_sz_h * flat_sz_w * self.out_channels
            self.fc             = torch.nn.Linear(self.flat_sz, self.kernel_size)
            # One bias per candidate object (row of obj_emb), starting at zero.
            self.bias           = nn.Parameter(torch.zeros(obj_emb.shape[0], device=obj_emb.device))

    def concat(self, e1_embed, rel_embed):
        """Stack entity and relation embeddings into a single-channel 2-D map for ConvE.

        NOTE(review): reads self.p.embed_dim / self.p.k_w / self.p.k_h, but no
        ``p`` attribute is set in this class.
        """
        e1_embed    = e1_embed. view(-1, 1, self.p.embed_dim)
        rel_embed   = rel_embed.view(-1, 1, self.p.embed_dim)
        stack_inp   = torch.cat([e1_embed, rel_embed], 1)
        stack_inp   = torch.transpose(stack_inp, 2, 1).reshape((-1, 1, 2*self.p.k_w, self.p.k_h))
        return stack_inp

    def projected(self, ent, norm):
        """Remove from ``ent`` its component along the L2-normalised ``norm``.

        NOTE(review): the dot product is summed over dim 1 while the "transh"
        score takes its norm over dim 2 -- confirm the intended input rank.
        """
        norm = torch.nn.functional.normalize(norm, p=2, dim=-1)
        return ent - torch.sum(ent * norm, dim = 1, keepdim=True) * norm

    def forward_score(self):
        """Evaluate the configured score function on the stored embeddings.

        Returns:
            transe: the element-wise residual ``sub + rel - obj`` (same shape
                as the inputs, not reduced to a scalar).
            distmult: ``(sub * rel) @ obj^T + bias``.
            transh / conve: see the NOTE(review) comments in ``__init__``.
        """
        with amp.autocast():
            if   self.func_type == "transe":
                x        = self.sub_emb + self.rel_emb - self.obj_emb
            elif self.func_type == "transh":
                head       = self.entity_embedding(self.sub_emb)
                tail       = self.entity_embedding(self.obj_emb)
                r_norm     = self.relation_norm_embedding(self.rel_emb)
                r_hyper    = self.relation_hyper_embedding(self.rel_emb)
                head_hyper = self.projected(head, r_norm)
                tail_hyper = self.projected(tail, r_norm)
                x          = torch.norm(head_hyper + r_hyper - tail_hyper, p=2, dim=2)
            elif self.func_type == "distmult":
                # DistMult composes subject and relation multiplicatively.
                x        =   torch.mm(self.sub_emb * self.rel_emb, self.obj_emb.transpose(1, 0))
                x        =   x + self.bias.expand_as(x)
            elif self.func_type == "conve":
                stk_inp  = self.concat(self.sub_emb, self.rel_emb)
                x        = self.bn0(stk_inp)
                x        = self.m_conv1(x)
                x        = self.bn1(x)
                x        = torch.nn.functional.relu(x)
                x        = self.feature_drop(x)
                x        = x.view(-1, self.flat_sz)
                x        = self.fc(x)
                x        = self.hidden_drop2(x)
                x        = self.bn2(x)
                x        = torch.nn.functional.relu(x)

                x = torch.mm(x, self.obj_emb.transpose(1,0))
                x = x + self.bias.expand_as(x)
            else:
                # Only reachable if func_type was changed after construction.
                raise ValueError("unknown score function '" + str(self.func_type) + "'")
            return x

In [9]:
# test = ('de/Franz_Tobisch',
#  'Prabhas',
#  'Igor_Strelbin',
#  'Quentin_N._Burdick',
#  'Bamir_Topi',
#  'Federal_University_of_Amazonas',
#  'Emmanuel_Baffour',
#  'R.E.M.')

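## Two ways to combine the graph and the text information
1. Without subgraphs: zero-pad the whole graph, then select the rows for the relevant indices just before combining with the semantic (text) information. This is very expensive and may not be feasible for training.
2. With subgraphs: zero-pad only the subgraph, which avoids the computation becoming infeasible. Note that the learnable parameters of CompGCN are the relation basis, the relation embeddings and the node feature embeddings. (TODO: complete this rationale.)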

## SubGraph implementation

In [10]:
# class CompGcn_subgraph(nn.Module):
#     def __init__(self, num_entity,node_dim ):
#         '''
#         In a subgraph, we first need to choose the index, take all the index containing in it out:
       
#         '''
#         super().__init__()
#         self.node_features_whole = get_param(shape = (num_entity, node_dim))
#     def select_index(self,input_):
#         '''
#         这个函数需要通过给定的input string，去找出整图中对应的edge_idx 和 edge_type(当然可以在外面做)， 然后根据对应的edge idx 的value 去原始的Node_features 里去找
#         对应的子图的node feature，其维度应该为(num_nodes_sub, node_dim), 对于relation，则不需要改动，仍用整张图即可
#         '''
#         raise NotImplementedError
#     def forward(self, input_):
#         batch_node_feature, batch_edge_idx, batch_edge_type = self.select_index(*input_)
#         return batch_node_feature

In [11]:
# class CompGcn_subgraph_temporal(CompGcn_subgraph):
#     def __init__(self, conv_dim,num_relation,node_dim, num_basis_vector, edge_idx, edge_type):
#         '''
#         1. conv_dim: (in_channel, out_channel)
#         2. num_relation: number of relation for whole graph
#         4. node_dim: dimension for each node
#         5. basis_vector
#         6. a list of edge_idx for sub_graph
#         6. a list of edge_type for sub_graph
#         '''
#         super().__init__(num_entity, node_dim)
#         self.conv1 = CompGcn_total(conv_dim[0], conv_dim[1], num_relation, num_basis_vector, edge_idx[0],edge_type[0],True)
#         self.conv2 = CompGcn_total(conv_dim[0], conv_dim[1], num_relation, num_basis_vector, edge_idx[1],edge_type[1],False)
#         self.conv3 = CompGcn_total(conv_dim[0], conv_dim[1], num_relation, num_basis_vector, edge_idx[2],edge_type[2],False)
#         self.RNN_nodes = nn.RNN(input_size = conv_dim[0], hidden_size = conv_dim[1])
#         self.RNN_relembd = nn.RNN(input_size = conv_dim[0], hidden_size = conv_dim[1])
#         self.dropout_node = nn.Dropout(0.2)
#         self.dropout_rel = nn.Dropout(0.4)
#     def forward2(self, input_string):
#         batch_node_features, head_index, rel_index, tail_index =self.forward(input_string)
#         node_embd1, rel_embd1 = self.conv1(batch_node_features) 
#         node_embd2, rel_embd2 = self.conv2(node_embeding = node_embd1, rel_embeding = rel_embd1)
#         node_embd3, rel_embd3 = self.conv3(node_embeding = node_embd2, rel_embeding = rel_embd2)
#         hidden_node_state = self.dropout_node(self.RNN_nodes(torch.cat([node_embd1[tail_index,:].unsqueeze(0), node_embd2[tail_index,:].unsqueeze(0), node_embd3[tail_index,:].unsqueeze(0)]))[0])
#         hidden_rel_state = self.dropout_rel(self.RNN_relembd(torch.cat([rel_embd1[rel_index,:].unsqueeze(0), rel_embd2[rel_index,:].unsqueeze(0),rel_embd3[rel_index,:].unsqueeze(0)]))[0])
#         hidden_node_state = self.dropout_node(self.RNN_nodes(torch.cat([node_embd1[head_index,:].unsqueeze(0), node_embd2[head_index,:].unsqueeze(0), node_embd3[head_index,:].unsqueeze(0)]))[0])
#         return hidden_node_state[-1,:,:].squeeze(0), hidden_rel_state[-1,:,:].squeeze(0), hidden_target_state[-1,:,:].squeeze(0)

## Dataset and KG-Bert part:

In [12]:
def change_input(tokenizer, text1, text2=None, text3=None, labels = None,max_length=512):
    '''
    Encode up to three text segments as one padded, fixed-length token sequence.

    The layout is ``[CLS] text1 [SEP] text2 [SEP] text3 [SEP]``. ``text2`` and
    ``text3`` are optional: when one is None, that segment and its ``[SEP]`` are
    left out. If the sequence is too long, tokens are popped one at a time from
    the end of the strictly longest segment; on a tie the third segment is
    shortened first, and when it is already empty the second one is.

    Args:
        tokenizer: object providing ``tokenize(str) -> list`` and
            ``convert_tokens_to_ids(list) -> list``.
        text1: first segment (required).
        text2: optional second segment; its positions get segment id 1.
        text3: optional third segment; its positions get segment id 0.
        labels: passed through unchanged under the "labels" key.
        max_length: exact length of every returned list.

    Returns:
        dict with keys "input_ids", "segment_ids", "attention_mask" (each a list
        of length ``max_length``, zero-padded on the right) and "labels".

    Raises:
        ValueError: if ``max_length`` cannot even hold the special tokens.
    '''
    # Tokenize without converting to ids; an absent segment is an empty list so
    # the truncation loop below can always measure all three.
    tokens_1 = tokenizer.tokenize(text1)
    tokens_2 = tokenizer.tokenize(text2) if text2 is not None else []
    tokens_3 = tokenizer.tokenize(text3) if text3 is not None else []

    # [CLS] + [SEP] after text1, plus one [SEP] per optional segment present
    # (4 when all three segments are given).
    num_special = 2 + (text2 is not None) + (text3 is not None)
    if max_length < num_special:
        raise ValueError("max_length=%d is too small for %d special tokens" % (max_length, num_special))
    budget = max_length - num_special

    # Truncation as in KG-BERT: shorten the longest segment first.
    while len(tokens_1) + len(tokens_2) + len(tokens_3) > budget:
        if len(tokens_1) > len(tokens_2) and len(tokens_1) > len(tokens_3):
            tokens_1.pop()
        elif len(tokens_2) > len(tokens_1) and len(tokens_2) > len(tokens_3):
            tokens_2.pop()
        elif tokens_3:
            # third segment is strictly longest, or there is a tie
            tokens_3.pop()
        else:
            # third segment is empty and the first two are tied; both are
            # non-empty here because the total still exceeds a budget >= 0
            tokens_2.pop()

    # Assemble tokens and segment ids.
    final_token = ["[CLS]"] + tokens_1 + ["[SEP]"]
    segment_ids = [0] * len(final_token)
    if text2 is not None:
        final_token += tokens_2 + ["[SEP]"]
        segment_ids += [1] * (len(tokens_2) + 1)
    if text3 is not None:
        final_token += tokens_3 + ["[SEP]"]
        segment_ids += [0] * (len(tokens_3) + 1)

    # Convert to ids, then right-pad everything with zeros up to max_length.
    input_ids = tokenizer.convert_tokens_to_ids(final_token)
    padding = [0] * (max_length - len(input_ids))
    attention_mask = [1] * len(input_ids)
    input_ids += padding
    attention_mask += padding
    segment_ids += padding
    assert len(input_ids) == max_length
    assert len(attention_mask) == max_length
    assert len(segment_ids) == max_length
    return {"input_ids": input_ids,
            "segment_ids": segment_ids,
            "attention_mask": attention_mask,
            "labels":labels,
    }

In [13]:
class language_Dataset(torch.utils.data.Dataset):
    '''
    Map-style dataset that encodes one DataFrame row per item with ``change_input``.

    Each row is read through the columns "head", "relation", "tail", "labels"
    and "index_where".
    '''
    def __init__(self, df):
        '''
        Load the 'bert-base-uncased' fast tokenizer and keep a reference to ``df``.
        '''
        self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
        self.df = df

    def __len__(self):
        '''Number of rows in the wrapped DataFrame.'''
        return len(self.df)

    def __getitem__(self, idx):
        '''
        Encode row ``idx`` and return
        ``(input_ids, segment_ids, attention_mask, labels, index_where)``;
        the first four are tensors, the last is the raw "index_where" cell.
        '''
        row = self.df.iloc[idx]
        encoded = change_input(self.tokenizer, row["head"], row["relation"], row["tail"], row["labels"])
        input_ids = torch.tensor(encoded["input_ids"])
        segment_ids = torch.tensor(encoded["segment_ids"])
        attention_mask = torch.tensor(encoded["attention_mask"])
        labels = torch.tensor(encoded["labels"])
        return input_ids, segment_ids, attention_mask, labels, row["index_where"]

In [14]:
def select_index(head, relation,tail,head2idx,rel2idx,edge_idx,edge_type):
    '''
    Look up, for every (head, relation, tail) triple in the batch, the first
    edge column whose head index, tail index and edge type all match.

    ``head2idx`` maps both head and tail names to node indices and ``rel2idx``
    maps relation names to edge types. ``edge_idx[0]`` / ``edge_idx[1]`` are
    compared against the head / tail index, ``edge_type`` against the relation.

    Returns three lists (head indices, tail indices, relation types) read back
    from ``edge_idx`` / ``edge_type`` at the matched columns.

    Raises KeyError for an unknown name and IndexError when a triple matches no
    edge.
    '''
    node_pairs = [(head2idx[h], head2idx[t]) for h, t in zip(head, tail)]
    rel_ids = [rel2idx[r] for r in relation]

    first_hits = []
    for (src, dst), rel in zip(node_pairs, rel_ids):
        same_head = np.isin(edge_idx[0], src)
        same_tail = np.isin(edge_idx[1], dst)
        same_type = np.isin(edge_type, rel)
        # argwhere yields one row per matching column; keep only the first
        first_hits.append(np.argwhere(same_head & same_tail & same_type)[0])

    columns = np.concatenate(first_hits)
    rel_value = edge_type[columns]
    head_value, tail_value = edge_idx[:, columns]
    assert len(head_value) == len(tail_value) == len(rel_value)
    return head_value.tolist(), tail_value.tolist(), rel_value.tolist()
# head_idx, tail_idx, rel_idx = select_index(head, relation,tail, head2idx, rel2idx,edge_idx_3, edge_type_2)
# hidden_node, hidden_rel, hidden_target = tmp(input_ids, seg_ids, att_mask,head_idx,tail_idx, rel_idx)
# hidden_node.shape, hidden_rel.shape, hidden_target.shape

In [15]:
# def train_with_amp(net, train_set, criterion, optimizer, epochs,batch_size, scheduler, gradient_accumulate_step, max_grad_norm ,device):
#     net.train()
#     # instantiate a scalar object  
#     print("train on " + str(device))
#     enable_amp = True if "cuda" in device.type else False
#     scaler = amp.GradScaler(enabled= enable_amp)
#     net.to(device)
#     global_step = 0
#     train_iter = torch.utils.data.DataLoader(train_set, batch_size = batch_size)
#     for epoch in range(epochs):
#         for idx, value in enumerate(train_iter):
#             input_ids, seg_ids, att_mask, labels, index = value
#             input_ids = input_ids.to(device)
#             att_mask =att_mask.to(device)
#             labels = labels.to(device)
#             seg_ids = seg_ids.to(device)
#             head_values = torch.tensor(index[0]).to(device)
#             tail_values = torch.tensor(index[2]).to(device)
#             rel_values = torch.tensor(index[1]).to(device)
#             # when forward process, use amp
#             with amp.autocast(enabled= enable_amp):
#                 output = net(input_ids, seg_ids, att_mask,head_values,tail_values, rel_values)  
#             loss = criterion(output, labels.view(-1,1).float())
#             # prevent gradient to 0
#             if gradient_accumulate_step > 1:
#                 # 如果显存不足，通过 gradient_accumulate 来解决
#                 loss = loss/gradient_accumulate_step
            
#             # 放大梯度，避免其消失
#             scaler.scale(loss).backward()
#             # do the gradient clip
#             gradient_norm = nn.utils.clip_grad_norm_(net.parameters(),max_grad_norm)
#             if (idx + 1) % gradient_accumulate_step == 0:
#                 # 多少 step 更新一次梯度
#                 # 通过 scaler.step 来unscale 回梯度值， 如果气结果不是infs 和Nans， 调用optimizer.step()来更新权重
#                 # 否则忽略step调用， 保证权重不更新
#                 scaler.step(optimizer)
#                 scaler.update()
#                 optimizer.zero_grad()
#                 global_step += 1
#                 scheduler.step()
#             # 每100次计算 print 出一次loss
#             if idx % 1000 == 0 or idx == len(train_iter) -1:
#                 with torch.no_grad():
#                     print("==============Epochs "+ str(epoch) + " ======================")
#                     print("loss: " + str(loss) + "; grad_norm: " + str(gradient_norm))
#                 torch.save({'epoch': epoch,
#                 'model_state_dict': net.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(),
#                 'loss': loss},"./checkpoint.params")
#             print("successfully done one train")

In [18]:
def try_gpu(i=0):
    """Pick ``cuda:i`` when at least ``i + 1`` CUDA devices are visible, else the CPU device."""
    enough_devices = torch.cuda.device_count() >= i + 1
    return torch.device(f'cuda:{i}') if enough_devices else torch.device('cpu')

# In[ ]:
def train_with_amp(net, train_set, criterion, optimizer, epochs,batch_size, scheduler, gradient_accumulate_step, max_grad_norm , num_gpu):
    """
    Train ``net`` with automatic mixed precision and gradient accumulation.

    Args:
        net: model. It must provide ``init_state(device)`` and is called as
            ``net(input_ids, seg_ids, att_mask, head_values, tail_values, rel_values, init_state)``.
        train_set: dataset whose batches unpack as
            ``(input_ids, seg_ids, att_mask, labels, index)``. ``index[0]``,
            ``index[1]`` and ``index[2]`` are used as head / relation / tail
            indices (presumably the collated "index_where" column; confirm
            against the dataset).
        criterion: called as ``criterion(output, labels.view(-1, 1).float())``.
        optimizer: stepped once every ``gradient_accumulate_step`` batches.
        epochs: number of passes over ``train_set``.
        batch_size: DataLoader batch size.
        scheduler: stepped together with the optimizer.
        gradient_accumulate_step: number of batches whose gradients are summed
            before one optimizer step; the loss is divided by it when > 1.
        max_grad_norm: max norm passed to ``clip_grad_norm_``.
        num_gpu: number of devices probed with ``try_gpu``; only the first one
            is used for training.

    Side effects:
        Overwrites "./checkpoint.params" at every logging point and appends one
        timing line per batch to "train_log".
    """
    net.train()

    ls          = []   # losses recorded at each logging point; saved in the checkpoint
    device_ids  = [try_gpu(i) for i in range(num_gpu)]
    # Use ONE device for the model and for the batches. Previously the model
    # went to torch.device("cuda") while batches went to device_ids[0], and
    # num_gpu == 0 crashed on device_ids[0].
    if device_ids:
        device  = device_ids[0]
    else:
        device  = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print("\ntrain on %s\n"%str(device_ids))
    enable_amp  = "cuda" in device.type
    scaler      = amp.GradScaler(enabled= enable_amp)
    net.to(device)
    train_iter  = torch.utils.data.DataLoader(train_set, batch_size = batch_size)
    # Norm from the most recent optimizer step; stays None until the first one.
    gradient_norm = None
    for epoch in range(epochs):
        for idx, value in enumerate(train_iter):
            ini_time    = time.time()
            input_ids, seg_ids, att_mask, labels, index = value
            input_ids   = input_ids.to(device)
            att_mask    = att_mask.to(device)
            labels      = labels.to(device)
            seg_ids     = seg_ids.to(device)
            # as_tensor accepts lists and tensors alike, and avoids the
            # copy-construct warning torch.tensor() emits for tensor input
            head_values = torch.as_tensor(index[0]).to(device)
            tail_values = torch.as_tensor(index[2]).to(device)
            rel_values  = torch.as_tensor(index[1]).to(device)
            # run the forward pass under autocast
            init_state = net.init_state(device)
            with amp.autocast(enabled= enable_amp):
                output  = net(input_ids, seg_ids, att_mask,head_values,tail_values, rel_values, init_state)
            loss        = criterion(output, labels.view(-1,1).float())
            if gradient_accumulate_step > 1:
                # gradient accumulation works around limited GPU memory;
                # divide so the accumulated gradient matches one large batch
                loss    = loss/gradient_accumulate_step

            # scale the loss so small gradients do not underflow in half precision
            scaler.scale(loss).mean().backward()
            if (idx + 1) % gradient_accumulate_step == 0:
                # One weight update every `gradient_accumulate_step` batches.
                # Unscale first so clipping sees the true gradient magnitudes;
                # clipping still-scaled gradients makes max_grad_norm meaningless.
                scaler.unscale_(optimizer)
                gradient_norm = nn.utils.clip_grad_norm_(net.parameters(),max_grad_norm)
                # scaler.step skips optimizer.step() when the gradients contain
                # infs or NaNs, so the weights are left untouched in that case
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()
                print("train 1 times")
            # log and checkpoint every 1000 batches and on the last batch of the epoch
            if idx % 1000 == 0 or idx == len(train_iter) -1:
                with torch.no_grad():
                    print("==============Epochs "+ str(epoch) + " ======================")
                    print("loss: " + str(loss) + "; grad_norm: " + str(gradient_norm))
                ls.append(loss.item())
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': net.state_dict(),
                    'param_groups': optimizer.state_dict()["param_groups"],
                    'loss': ls
                },"./checkpoint.params")
            with open("train_log", "a") as f:
                f.write("Epoch %s, Batch %s: %.4f sec\n"%(epoch, idx, time.time() - ini_time))
    

In [None]:
if __name__ == "__main__":
    # Edge index arrays, one per graph file (edge_idx0..2).
    edge_idx_1 = torch.tensor(np.load("../input/3-graphs-info/edge_idx0.npz")["arr_0"])
    edge_idx_2 = torch.tensor(np.load("../input/3-graphs-info/edge_idx1.npz")["arr_0"])
    edge_idx_3 = torch.tensor(np.load("../input/3-graphs-info/edge_idx2.npz")["arr_0"])

    # Node counts per graph.
    num_nodes_0 = torch.tensor(np.load("../input/3-graphs-info/graph_0_num_nodes.npz")["arr_0"])
    num_nodes_1 = torch.tensor(np.load("../input/3-graphs-info/graph_1_num_nodes.npz")["arr_0"])
    num_nodes_2 = torch.tensor(np.load("../input/3-graphs-info/graph_2_num_nodes.npz")["arr_0"])

    # Values stored in the *_num_edges files, used below as the relation count.
    num_relation_0 = torch.tensor(np.load("../input/3-graphs-info/graph_0_num_edges.npz")["arr_0"])
    num_relation_1 = torch.tensor(np.load("../input/3-graphs-info/graph_1_num_edges.npz")["arr_0"])
    num_relation_2 = torch.tensor(np.load("../input/3-graphs-info/graph_2_num_edges.npz")["arr_0"])

    # Edge type arrays, one per graph file.
    edge_type_0 = torch.tensor(np.load("../input/3-graphs-info/edge_type0.npz")["arr_0"])
    edge_type_1 = torch.tensor(np.load("../input/3-graphs-info/edge_type1.npz")["arr_0"])
    edge_type_2 = torch.tensor(np.load("../input/3-graphs-info/edge_type2.npz")["arr_0"])

    # Pickled name -> index dictionaries. NOTE: allow_pickle=True executes
    # pickle on load; only use with trusted files.
    head2idx = np.load('../input/3-graphs-info/graph_2entity2index.npy', allow_pickle=True).item()
    rel2idx =  np.load('../input/3-graphs-info/graph_2rel2index.npy', allow_pickle=True).item()

    # Training table; "index_where" is stored as text and parsed back to Python objects.
    train = pd.read_csv("../input/train-valid-test-dataset/train.csv").drop("Unnamed: 0", axis = 1)
    train["index_where"] = train["index_where"].apply(ast.literal_eval)
    train_set = language_Dataset(train)

    # Model configuration.
    model = BertModel.from_pretrained('bert-base-uncased')
    num_hiddens = 10
    conv_dim = [[10, 20, num_hiddens], [10,20, num_hiddens], [10,20, num_hiddens]]
    num_layer = 2
    node_dim = 10
    num_basis = 37
    edge_idx = [edge_idx_1,edge_idx_2,edge_idx_3]
    edge_type = [edge_type_0,edge_type_1,edge_type_2]
    tmp = CompGcn_with_temporal(conv_dim,num_relation_2, num_nodes_2+1, node_dim, num_hiddens,num_basis,edge_idx, edge_type,model,1,3)

    # Optimisation setup: the cosine schedule spans one pass over the data.
    batch_size = 2
    loss = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(tmp.parameters(), lr = 2e-4)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer= optimizer,
        num_warmup_steps = 0,
        num_training_steps= len(torch.utils.data.DataLoader(train_set, batch_size = batch_size)),
        num_cycles = 0.5,
    )
    train_with_amp(
        net = tmp,
        train_set = train_set,
        criterion = loss,
        optimizer = optimizer,
        epochs = 1,
        batch_size = batch_size,
        scheduler = scheduler,
        gradient_accumulate_step = 1,
        max_grad_norm = 1000,
        num_gpu = 1,
    )

In [None]:
# def train(net, training_dataset, testing_dataset = None, lr = 0.01, loss_func=nn.CrossEntropyLoss ,
#           num_epoches = 30, batch_size = 2, finetuning=False, plot=False):
#     """
#         Func:
#             To train the model
        
#         Args:
#             net: the neuronal network
#             training_dataset: a tensor form of dataset
#             testing_dataset: a tensor form of dataset
#             lr: learning rate
#             loss_func: a callable loss function defined in torch.nn (for example, nn.CrossEntropyLoss())
#             num_epoches: the number of epoches
#             batch_size: the size of a batch
#             finetuning: whether we need to initialize the weights. If true, pass a list 
#                         to indicate some specific layers, so that we can keep the weights fixed 
#                         within the specified layers (include the start and ending layer). For
#                         example, [2, 4] means keep 2 to 4 layers fixed: [start_layer, end_layer].
#                         We count layers begin from 0. Default to be False(initialize for every layer).
#                         If it's True, we just use the pre-trained weights, and continue to train the 
#                         model.
#     """
#     # to ensure we have a loss function
#     assert loss_func != None, "Must pass a loss function"
    
#     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
#     print("Training on ", device)
    
#     # move all the tensor to GPU
#     def init_weights(m):
#         if type(m) == nn.Linear or type(m) == nn.Conv2d:
#             nn.init.xavier_uniform_(m.weight)#初始化
#     if not finetuning and not type(finetuning) == list:
#         net.apply(init_weights)#将其应用在每层网络
#     elif type(finetuning) == list:
#         # ensure the number is positive
#         for i in range(len(finetuning)):
#             if finetuning[i] < 0:
#                 finetuning[i] += len(list(net.children()))
#         if finetuning[0] > finetuning[1]:
#             finetuning[0], finetuning[1] = finetuning[1], finetuning[0]
#         count = 0
#         para_optim = []
#         for k in net.children():
#             count += 1
#             # finetuning layers should be changed properly
#             if count > finetuning[1]:
#                 for param in k.parameters():
#                     para_optim.append(param)
#             elif count < finetuning[0]:
#                 for param in k.parameters():
#                     para_optim.append(param)
#             else:
#                 for param in k.parameters():
#                     param.requires_grad = False
#     net.to(device)
    
#     training_iter = torch.utils.data.DataLoader(training_dataset, batch_size = batch_size, shuffle =True)
#     if testing_dataset is not None:
#         testing_iter = torch.utils.data.DataLoader(testing_dataset, shuffle = False, batch_size = batch_size)
    
#     # config the optimizer and the loss function
#     # ------------------------------------------------   Change the optimizer here -------------------------------------------------- #
#     optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr = lr)
#     # ------------------------------------------------------------------------------------------------------------------------------- #
    
#     loss = loss_func
#     if plot:
#         animator = d2l.Animator(xlabel = "epoch", xlim=[1,num_epoches], legend=["train_loss","train_Acc", "test_acc"])
#     num_batches = len(training_iter)

#     # plot the loss dynamically
#     train_loss = []
#     train_acc = []
#     net.train()    # turn on the training model
    
#     for epoch in range(num_epoches):
#         # save a model every 2 epoches
#         torch.save(model, "./model_v01_epoch(%s).pt"%(epoch+1))
#         metric = d2l.Accumulator(3)
#         # for every step
#         for i,value in enumerate(training_iter):
#             optimizer.zero_grad()
#         # ------------------------------------------------   Change the training input here -------------------------------------------------- #
#             input_ids, seg_ids, att_mask, labels, head, tail = value
#             head_values, tail_values, rel_values = select_index(head, tail, head2idx, edge_idx_3, edge_type_2)
#             input_ids = input_ids.to(device).long()
#             att_mask =att_mask.to(device).long()
#             labels = labels.to(device).long()
#             seg_ids = seg_ids.to(device).long()
#             output = net(input_ids, seg_ids, att_mask,head_values,tail_values, rel_values)
#             print(output.shape)
#         # ------------------------------------------------------------------------------------------------------------------------------------ #
#             l = loss(output.float(), labels[0].long())
#             l.backward()
#             optimizer.step()
#             with torch.no_grad():
#                 metric.add(l * input_ids.shape[0], d2l.accuracy(output,labels), input_ids.shape[0])
#                 train_loss.append(metric[0]/metric[2])
#                 train_acc.append( metric[1]/metric[2])
# #                 train_loss = metric[0]/metric[2]
# #                 train_acc  = metric[1]/metric[2]
# #                 if (i+1) % (num_batches//5) ==0 or i == num_batches-1:
# #                     animator.add(epoch+(i+1)/num_batches, (train_loss,train_acc, None))
#             #print("successfully train for period ",i)
#             if plot and i % 1000 == 0:
#                 with torch.no_grad():
#                     dic = {"epoch":epoch,"iteration":i,"optimizer":optimizer.state_dict(),"net":net.state_dict()}
#                     torch.save(dic,"./model_params")
#                     plt.figure()
#                     x_ = range(len(train_loss))
#                     plt.plot(x_,train_loss,color="yellow")
#                     plt.plot(x_, train_acc,color = "red")
#         if testing_dataset is not None:
#             test_acc = evaluate_accuracy(net, graph_info, testing_iter)
#             if plot:
#                 animator.add(epoch+1,(None,None,test_acc))
#         print("epoch %d, loss: %f"%(epoch, train_loss[-1]))
    
#     return train_loss, train_acc

In [None]:
# from d2l import torch as d2l

In [None]:
# train(net,test_set)

In [None]:
# conv_dim, num_layer, node_dim, num_basis, edge_idx, edge_type = [10, 20], 2, 10, 38, [edge_idx_1,edge_idx_2,edge_idx_3],[edge_type_0,edge_type_1,edge_type_2]
# net = CompGcn_with_temporal(conv_dim,num_layer, num_relation_2+1, num_nodes_2+1, node_dim, num_basis,edge_idx, edge_type, model)
