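In this package (PyTorch Geometric), a graph is described by the following attributes:
1. x: the node feature matrix, with shape [num_nodes, num_node_features]
2. edge_index: the graph connectivity (adjacency) in COO format, with shape [2, num_edges]
3. edge_attr: the edge feature matrix, with shape [num_edges, num_edge_features], if any
4. y: the target
5. pos: the node position matrix, with shape [num_nodes, num_dimensions]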

In [4]:
from transformers import AutoModel
from transformers import AutoTokenizer
import numpy as np
import torch
from torch_geometric.nn import MessagePassing
from torch import nn
from torch_geometric.utils import softmax
from torch_scatter import scatter
import json
import os
from tqdm import tqdm
import pickle
from torch.autograd import Variable
import math
import time

In [5]:
def make_one_hot(labels, C):
    '''
    Return a float one-hot encoding of integer class ids.

    params:
        labels: 1-D integer (long) tensor of class ids, shape [N]
        C: number of classes; every label must lie in [0, C)
    returns:
        float32 tensor of shape [N, C] on the same device as `labels`
    '''
    # Allocate directly on the target device (the previous FloatTensor(...).to()
    # built the buffer on the CPU first) and skip the deprecated Variable wrapper.
    one_hot = torch.zeros(labels.size(0), C, dtype=torch.float, device=labels.device)
    return one_hot.scatter_(1, labels.unsqueeze(1), 1.0)
def freeze_net(module):
    '''Turn off gradient tracking for every parameter of `module`, in place.'''
    for param in module.parameters():
        param.requires_grad_(False)

## Implement the basic QA-GNN layers

In [6]:
class TextEncoder(nn.Module):
    '''
    Text encoder, denoted f_enc in the QA-GNN paper.

    Wraps a Hugging Face transformer and returns a pooled sentence vector together
    with the hidden states of all layers.
    '''

    def __init__(self, model_name, path = None):
        '''
        params:
            model_name: model identifier passed to AutoModel.from_pretrained
                (downloaded when `path` is None)
            path: optional local checkpoint directory; used instead of `model_name`
        '''
        super().__init__()
        # forward() reads the hidden-state tuple from the model output, so the model
        # must be created with output_hidden_states=True in BOTH branches (previously
        # the local-`path` branch omitted it and output[-1] was the pooler output).
        source = model_name if path is None else path
        self.model = AutoModel.from_pretrained(source, output_hidden_states=True)
        self.output_size = self.model.config.hidden_size

    def forward(self, *inputs, layers_id = -1):
        '''
        Encode a batch of tokenized sequences.

        Only transformer models that expose a `pooler` submodule are supported.

        params:
            inputs: (input_ids, attention_mask, token_type_ids, output_mask);
                output_mask is accepted for interface compatibility but unused
            layers_id: index into the hidden-state tuple selecting the layer to pool
        returns:
            (pooled vector from self.model.pooler, tuple of all hidden states)
        '''
        input_ids, att_mask, token_types_ids, output_mask = inputs
        output = self.model(input_ids, attention_mask = att_mask, token_type_ids = token_types_ids)
        # with output_hidden_states=True (and attentions not requested) the last
        # element of the output is the tuple of per-layer hidden states
        all_hidden = output[-1]
        final_hidden = all_hidden[layers_id]
        out_vec = self.model.pooler(final_hidden)
        return out_vec, all_hidden

In [7]:
class Customized_Embd(nn.Module):
    '''
    Return the embedding of each concept, optionally projected to a new width.

    Two modes:
      * lookup mode (forward called without contextualized_emb): rows of an
        nn.Embedding table, initialized from `pretrained_concept_emb` when given
        (otherwise N(0, init_range)) and frozen when `freeze_net_` is True;
      * contextualized mode: rows gathered from a caller-supplied tensor.
    In both modes vectors are multiplied by `scale` and, when
    concept_in_dim != concept_out_dim, passed through Linear + GELU.
    '''
    def __init__(self, num_concept, concept_in_dim, concept_out_dim, 
                 pretrained_concept_emb=None,use_contextualized = False, scale = 1.0, freeze_net_ = True,
                init_range = 0.02):
        super().__init__()
        self.use_contextualized = use_contextualized
        self.scale = scale
        if not use_contextualized:
            # lookup table of concept embeddings
            self.embd = nn.Embedding(num_concept, concept_in_dim)
            if pretrained_concept_emb is not None:
                print(pretrained_concept_emb.shape)
                # as_tensor lets callers pass a numpy array as well as a tensor
                self.embd.weight.data.copy_(torch.as_tensor(pretrained_concept_emb))
            else:
                self.embd.weight.data.normal_(mean = 0.0, std = init_range)
            if freeze_net_:
                freeze_net(self.embd)

        if concept_in_dim != concept_out_dim:
            print("create projection")
            self.ln1 = nn.Linear(concept_in_dim, concept_out_dim)
            # use gelu activation function
            self.act = nn.GELU()

    def forward(self, index, contextualized_emb=None):
        '''
        index shape : (batch_size, a) — concept positions/ids to fetch
        contextualized_emb shape : (batch_size, b, embd_size), optional
        returns: (batch_size, a, out_dim)
        '''
        if contextualized_emb is not None:
            assert index.shape[0] == contextualized_emb.shape[0]
            if hasattr(self, 'ln1'):
                contextualized_emb = self.act(self.ln1(contextualized_emb * self.scale))
            else:
                # out-of-place: the previous `*=` silently modified the caller's tensor
                contextualized_emb = contextualized_emb * self.scale
            embd_dim = contextualized_emb.shape[-1]
            # pick row index[b, j] of contextualized_emb[b] for every (b, j)
            return contextualized_emb.gather(1, index.unsqueeze(-1).expand(-1, -1, embd_dim))
        embedded = self.embd(index) * self.scale
        if hasattr(self, 'ln1'):
            return self.act(self.ln1(embedded))
        return embedded

In [8]:
class QAGNN_message_passing(nn.Module):
    '''
    Stack of k GATConvE layers that updates the node states of a batch of graphs.

    Besides the node states H, every node receives an extra feature vector made of
    a node-type embedding (one-hot -> Linear -> GELU) and a node-score embedding
    (sinusoidal basis -> Linear -> GELU); see forward().
    '''
    def __init__(self, args, k, node_type, edge_type,input_size, hidden_size, output_size,
                dropout_rate =0.1):
        '''
        params:
            1. args, extra args (stored and forwarded to every GATConvE layer)
            2. k, number of GATConvE layers
            3. node_type, number of node types
            4. edge_type, number of edge (relation) types
            5. input_size, width of the initial node embedding
            6. hidden_size, width of the hidden node states
            7. output_size, width of the output node states
            8. dropout_rate, dropout probability
        input_size, hidden_size and output_size must all be equal (asserted below).
        '''
        super().__init__()
        # the GATConvE layers are stacked one after another (each maps
        # hidden_size -> hidden_size), so all widths are required to be equal
        assert input_size == output_size
        self.args = args
        self.node_type = node_type
        self.edge_type = edge_type
        assert input_size == hidden_size
        self.hidden_size = hidden_size
        # node-type embedding: one-hot of width node_type -> hidden_size // 2
        self.embd_node_type = nn.Linear(self.node_type, hidden_size//2)
        # name of the score basis function; stored only — forward() always uses sin
        self.basis_f = 'sin'
        # node-score embedding: sinusoidal features (hidden_size // 2) -> hidden_size // 2
        self.embd_score = nn.Linear(hidden_size//2, hidden_size//2)
        # Edge encoder (relation features in the paper). Its input is the concatenation
        # [one-hot relation incl. the extra self-loop type (edge_type + 1),
        #  one-hot head node type (node_type), one-hot tail node type (node_type)].
        self.edge_encoder = nn.Sequential(nn.Linear(edge_type +1 + node_type *2, hidden_size),nn.BatchNorm1d(hidden_size), nn.ReLU(), torch.nn.Linear(hidden_size, hidden_size))
        self.k = k
        # k GATConvE layers; all of them share the same edge_encoder instance
        self.gnn_layers = nn.ModuleList([GATConvE(args=args, embd_dim= hidden_size,num_node_type= node_type, num_edges= edge_type, edge_encoder= self.edge_encoder) for _ in range(k)])
        # output combination: act(vh(H) + vx(X)) in forward()
        self.vh = nn.Linear(input_size, output_size)
        self.vx = nn.Linear(hidden_size, output_size)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout_rate)
        self.dropout_rate = dropout_rate
    def message_passing_helpler(self, _X, edge_index, edge_type, _node_type, _node_feature_extra):
        '''
        Apply the k GATConvE layers in sequence, each followed by GELU and dropout
        (dropout is active only when self.training is True).

        _X: [N, hidden_size] flattened node states; returns the same shape.
        '''
        for i in range(self.k):
            _X = self.gnn_layers[i](_X, edge_index, edge_type, _node_type, _node_feature_extra)
            _X = self.act(_X)
            _X = torch.nn.functional.dropout(_X, self.dropout_rate, training = self.training)
        return _X
    
    def forward(self, H,A, node_type, node_score, cache_output =False):
        '''
        H : node features from previous layer, shape (batch_size, n_node, dim)
        A : (edge_index, edge_type) tuple; node indices refer to the flattened
            batch_size * n_node node set (presumably already offset per graph,
            e.g. by LM_QAGNN.batch_graph — confirm for other callers)
        node_type: long tensor of shape (batch_size, n_node):
            in this form, which is:
                0 == question, 1 == answer, 2 == others, 3 == context
        node_score: tensor of shape (batch_size, n_node, 1)
        cache_output: accepted but unused
        returns: tensor of shape (batch_size, n_node, output_size)
        '''
        # embed type
        bs, num_node = node_type.size()
        T = make_one_hot(node_type.view(-1).contiguous(),self.node_type).view(bs, num_node, self.node_type)
        # get node type embedding, linearly transform
        node_type_embd = self.act(self.embd_node_type(T)) # shape [bs, num_node, dim/2]
        
        # embed the node score with a sinusoidal basis:
        # js = 1.1 ** j for j in [0, dim/2), shape [1, 1, dim/2]
        js = torch.arange(self.hidden_size//2).unsqueeze(0).unsqueeze(0).float().to(node_type.device)
        js = torch.pow(1.1, js)
        # broadcast against node_score [bs, num_node, 1] -> [bs, num_node, dim/2]
        B = torch.sin(js * node_score)
        # passing linearly transform layer with GELU activation
        node_score_embd = self.act(self.embd_score(B))
        # get adjacency (edge_index) and the relation type of each edge
        edge_index, edge_type = A
        tmp = H.view(-1, H.size(2)).contiguous()# [bs * num_node, dim]
        _node_type = node_type.view(-1).contiguous() #[bs*num_node, ]
        
        # node_feature_extra = [node_type_embd ; node_score_embd], flattened
        # shape [bs*num_node, dim]
        node_feature_extra = torch.cat([node_type_embd, node_score_embd],dim = 2).view(_node_type.size(0), -1).contiguous()
        X = self.message_passing_helpler(tmp, edge_index, edge_type, _node_type, node_feature_extra)
        X = X.view(node_type.size(0),node_type.size(1), -1)# [bs, num_node, dim]
        # combine the input states with the GNN output, then GELU + dropout
        output = self.dropout(self.act(self.vh(H) + self.vx(X)))
        return output
# implement GATConvE
# by using message passing

class GATConvE(MessagePassing):
    '''
    Multi-head graph attention convolution whose attention and messages also
    depend on edge (relation) embeddings.

    args:
        1. args, extra args (stored only)
        2. embd_dim, the dimension of GNN hidden states
        3. num_node_type: number of node types
        4. num_edges, number of edge (relation) types; one extra type
           (index num_edges) is used for the self-loops added in forward()
        5. edge_encoder, module mapping [relation one-hot ; head/tail type one-hots]
           to an embd_dim vector (registered here, so it may be shared between layers)
        6. num_heads, number of attention heads; must divide embd_dim
        7. aggr, MessagePassing aggregation (default "add")
    '''
    def __init__(self, args, embd_dim, num_node_type, num_edges, edge_encoder,num_heads = 4, aggr = "add"):
        super().__init__(aggr=aggr)
        self.args = args
        assert embd_dim %2 == 0
        self.embd_dim = embd_dim
        self.num_node_type = num_node_type
        self.num_edges = num_edges
        self.edge_encoder = edge_encoder
        
        # attention configuration
        self.num_heads = num_heads
        assert embd_dim % num_heads == 0
        # per-head width; num_heads * dim_per_head == embd_dim
        self.dim_per_head = embd_dim // num_heads
        # Linear maps for key, message (value) and query. Node inputs are
        # [hidden state ; extra node features] (2 * embd_dim wide):
        #   key   <- [target node input ; edge embedding]  (3 * embd_dim)
        #   msg   <- [source node input ; edge embedding]  (3 * embd_dim)
        #   query <- source node input                     (2 * embd_dim)
        self.ln_key = nn.Linear(3 * embd_dim, num_heads * self.dim_per_head)
        self.ln_msg = nn.Linear(3 * embd_dim, num_heads * self.dim_per_head)
        self.ln_query = nn.Linear(2 * embd_dim, num_heads * self.dim_per_head)
        
        # not read or written anywhere else in this class
        self._alpha =None
        # post-aggregation MLP applied to the [N, embd_dim] aggregated messages
        self.mlp = nn.Sequential(nn.Linear(embd_dim, embd_dim),nn.BatchNorm1d(embd_dim), nn.ReLU(),
                                nn.Linear(embd_dim, embd_dim))
    def forward(self, x, edge_index, edge_type, node_type, node_feature_extra):
        '''
        x: node states, shape [N, embd_dim] (N = total number of nodes)
        edge_index: [2, E]
        edge_type: [E] relation ids (1-D; make_one_hot unsqueezes it)
        node_type: [N] node-type ids (1-D)
        node_feature_extra: [N, embd_dim] extra node features concatenated to x
        returns: [N, embd_dim]
        '''
        # here use the one-hot vector to represent each edge
        # unlike compgcn use basis vector to represent, here use one-hot
        edge_vec = make_one_hot(edge_type, self.num_edges + 1)# [E, num_edges + 1]
        # relation features for the N self-loops, shape [N, num_edges + 1]
        self_edge_vec = torch.zeros(x.shape[0], self.num_edges + 1).to(edge_vec.device)
        # self-loops use the extra, last relation type
        self_edge_vec[:, self.num_edges] = 1
        # get head, tail embedding
        
        # node_type holds the type id of every node (in QAGNN_message_passing it is
        # the flattened (batch_size, n_node) node_type tensor); indexing it with
        # edge_index gives the type of each edge's source (head) and target (tail)
        head_type = node_type[edge_index[0]]# head is source
        tail_type = node_type[edge_index[1]]# tail is target
        head_vec = make_one_hot(head_type, self.num_node_type) #[E, num_node_type]
        tail_vec = make_one_hot(tail_type, self.num_node_type)# [E, num_node_type]
        headtail_vec = torch.cat([head_vec, tail_vec], dim=1)# [E, 2 * num_node_type]
        # a self-loop's head and tail are the same node
        self_head_vec = make_one_hot(node_type, self.num_node_type)#[N, num_node_type]
        self_head_tail_vec = torch.cat([self_head_vec, self_head_vec], dim = 1)# [N, 2 * num_node_type]
        
        # real edges first, then self-loops (same order as edge_index below)
        edge_vec = torch.cat([edge_vec, self_edge_vec], dim = 0)# [E+N, num_edges + 1]
        headtail_vec = torch.cat([headtail_vec, self_head_tail_vec], dim = 0)# [E+N, 2 * num_node_type]
        edge_embeddings = self.edge_encoder(torch.cat([edge_vec, headtail_vec], dim=1)) #[E+N, emb_dim]
        
        # add self loop
        loop_idx = torch.arange(0, x.shape[0], dtype= torch.long, device=edge_index.device)
        loop_idx = loop_idx.unsqueeze(0).repeat(2,1)
        edge_index = torch.cat([edge_index, loop_idx], dim = 1)
        # node input = [hidden state ; extra node features], shape [N, 2 * embd_dim]
        x = torch.cat([x, node_feature_extra], dim = 1)
        x = (x,x)
        aggr_out = self.propagate(edge_index, x = x, edge_attr = edge_embeddings)# [N, embd_dim]
        output = self.mlp(aggr_out)
        return output
    
    def message(self, edge_index, x_i, x_j, edge_attr):
        '''
        Compute the attention-weighted message carried by every edge.

        With PyG's default flow ("source_to_target"), x_j is gathered from
        edge_index[0] (source) and x_i from edge_index[1] (target).
        edge_attr: [E, embd_dim]; x_i, x_j: [E, 2 * embd_dim].
        returns: [E, embd_dim]
        '''
        assert len(edge_attr.size()) == 2
        assert edge_attr.size(1) == self.embd_dim
        assert x_i.size(1) == x_j.size(1) == 2*self.embd_dim
        assert x_i.size(0) == x_j.size(0) == edge_attr.size(0) == edge_index.size(1)
        # for key, query, value
        # (E, heads, dim)
        key = self.ln_key(torch.cat([x_i,edge_attr], dim = 1)).view(-1, self.num_heads, self.dim_per_head)
        msg = self.ln_msg(torch.cat([x_j, edge_attr], dim = 1)).view(-1, self.num_heads, self.dim_per_head)
        query = self.ln_query(x_j).view(-1, self.num_heads, self.dim_per_head)
        
        # scaled dot-product score per edge and head
        query = query / math.sqrt(self.dim_per_head)
        scores = (query * key).sum(2)# shape is [E, num_heads]
        # select source index
        src_node_index = edge_index[0]
        # softmax over all edges that share the same source node (per head)
        alpha = softmax(scores, src_node_index)
        num_edges = edge_index.shape[1]
        num_nodes = int(src_node_index.max()) + 1
        # Original author's note: "I don't know why the attention score is
        # adjusted by the out-degree here."
        # Effect: after the softmax, the weights of each source's outgoing edges
        # sum to 1 per head; multiplying by the source's out-degree makes them sum
        # to the out-degree (mean weight 1), presumably to keep message magnitudes
        # comparable to an unweighted sum.
        # out-degree of each edge's source node, gathered back to shape (E,)
        ones = torch.full((num_edges,), 1.0, dtype=torch.float).to(edge_index.device)
        # shape (E,)
        src_node_Edge_count = scatter(ones, src_node_index, dim=0, dim_size= num_nodes,
                                     reduce="sum")[src_node_index]
        alpha = alpha * src_node_Edge_count.unsqueeze(1)# [E, num_heads]
        
        out = msg * alpha.view(-1, self.num_heads, 1)
        return out.view(-1, self.num_heads * self.dim_per_head)#[E, embd_dim]

In [9]:
class MatrixVectorScaledDotProductAttention(nn.Module):
    '''
    Scaled dot-product attention of a single query vector per row against a
    sequence of keys: weights = softmax(q . k / temperature) over the sequence
    axis, output = weighted sum of the values.
    '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, q, k, v, mask=None):
        """
        q: tensor of shape (n*b, d_k)
        k: tensor of shape (n*b, l, d_k)
        v: tensor of shape (n*b, l, d_v)
        mask: optional bool tensor of shape (n*b, l); True positions are excluded
        returns: tensor of shape (n*b, d_v), tensor of shape (n*b, l)
        """
        query_rows = q.unsqueeze(1)                       # (n*b, 1, d_k)
        logits = (query_rows * k).sum(2) / self.temperature  # (n*b, l)
        if mask is not None:
            logits = logits.masked_fill(mask, -np.inf)
        weights = self.dropout(self.softmax(logits))
        pooled = (weights.unsqueeze(2) * v).sum(1)        # (n*b, d_v)
        return pooled, weights
class MultiheadAttPoolLayer(nn.Module):
    '''
    Multi-head attention pooling: one query vector per row attends over a
    sequence of key vectors (which also serve as values), and the per-head
    results are concatenated into a d_k_original-wide vector.
    '''

    def __init__(self, n_head, d_q_original, d_k_original, dropout=0.1):
        super().__init__()
        assert d_k_original % n_head == 0  # make sure the output dimension equals d_k_original
        self.n_head = n_head
        self.d_k = d_k_original // n_head
        self.d_v = d_k_original // n_head

        self.w_qs = nn.Linear(d_q_original, n_head * self.d_k)
        self.w_ks = nn.Linear(d_k_original, n_head * self.d_k)
        self.w_vs = nn.Linear(d_k_original, n_head * self.d_v)

        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_q_original + self.d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_k_original + self.d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_k_original + self.d_v)))

        self.attention = MatrixVectorScaledDotProductAttention(temperature=np.power(self.d_k, 0.5))
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, mask=None):
        """
        q: tensor of shape (b, d_q_original)
        k: tensor of shape (b, l, d_k_original)
        mask: tensor of shape (b, l) (optional, default None); True = masked out
        returns: tensor of shape (b, n*d_v), attention tensor of shape (n*b, l)
        """
        heads, dk, dv = self.n_head, self.d_k, self.d_v
        batch, _ = q.size()
        batch, seq_len, _ = k.size()

        # Project, then fold the head axis into the batch axis (head-major order:
        # row h * batch + i holds head h of example i).
        q_proj = self.w_qs(q).view(batch, heads, dk)
        k_proj = self.w_ks(k).view(batch, seq_len, heads, dk)
        v_proj = self.w_vs(k).view(batch, seq_len, heads, dv)
        q_heads = q_proj.permute(1, 0, 2).contiguous().view(heads * batch, dk)
        k_heads = k_proj.permute(2, 0, 1, 3).contiguous().view(heads * batch, seq_len, dk)
        v_heads = v_proj.permute(2, 0, 1, 3).contiguous().view(heads * batch, seq_len, dv)

        # repeating along dim 0 matches the head-major layout above
        head_mask = None if mask is None else mask.repeat(heads, 1)
        pooled, attn = self.attention(q_heads, k_heads, v_heads, mask=head_mask)

        # unfold heads and concatenate them along the feature axis: (b, n*dv)
        pooled = pooled.view(heads, batch, dv).permute(1, 0, 2).contiguous().view(batch, heads * dv)
        return self.dropout(pooled), attn

In [10]:
class MLP(nn.Module):
    """
    Multi-layer perceptron.

    Builds num_layers hidden blocks (Linear -> Dropout -> [BatchNorm1d | LayerNorm]
    -> activation) followed by a final Linear projection to output_size.

    Parameters
    ----------
    num_layers: number of hidden layers
    """
    activation_classes = {'gelu': nn.GELU, 'relu': nn.ReLU, 'tanh': nn.Tanh}

    def __init__(self, input_size, hidden_size, output_size, num_layers, dropout, batch_norm=False,
                 init_last_layer_bias_to_zero=False, layer_norm=False, activation='gelu'):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.batch_norm = batch_norm
        self.layer_norm = layer_norm

        assert not (self.batch_norm and self.layer_norm)

        self.layers = nn.Sequential()
        for idx in range(self.num_layers + 1):
            is_hidden = idx < self.num_layers
            width_in = self.hidden_size if idx else self.input_size
            width_out = self.hidden_size if is_hidden else self.output_size
            self.layers.add_module(f'{idx}-Linear', nn.Linear(width_in, width_out))
            if not is_hidden:
                # the final projection has no dropout / normalization / activation
                continue
            self.layers.add_module(f'{idx}-Dropout', nn.Dropout(self.dropout))
            if self.batch_norm:
                self.layers.add_module(f'{idx}-BatchNorm1d', nn.BatchNorm1d(self.hidden_size))
            if self.layer_norm:
                self.layers.add_module(f'{idx}-LayerNorm', nn.LayerNorm(self.hidden_size))
            act_cls = self.activation_classes[activation.lower()]
            self.layers.add_module(f'{idx}-{activation}', act_cls())
        if init_last_layer_bias_to_zero:
            self.layers[-1].bias.data.zero_()

    def forward(self, input_):
        return self.layers(input_)

In [11]:
class QAGNN(nn.Module):
    '''
    QA-GNN decoder: embeds the concept graph, runs the GNN over it and scores
    each (context, choice) pair with one logit.
    '''
    def __init__(self, args, k, n_node_type, n_edge_type, sent_dim, n_concept, concept_dim, concept_in_dim, 
                n_attention_head, fc_dim, num_fc_layer, drop_embd, drop_gnn,drop_fc,
                 pretrained_concept_emb=None,freeze_ent_emb=True,init_range=0.02):
        '''
        params:
            1. args, extra args
            2. k, num_layers
            3. n_node_type, num of node type
            4. n_edge_type, num of edge type
            5. sent_dim, sentence dimension
            6. n_concept, num of concept
            7. concept_dim, concept out dimension
            8. concept_in_dim, concept in dimension
            9. n_attention_head, num of attention head for multi-head-attention
            10. fc_dim, linear transform dimension
            11. num_fc_layer, linear transform layer number
            12. drop_embd, dropout rate of embd
            13. drop_gnn, dropout rate of gnn
            14. drop_fc, dropout rate of fc
            15. pretrained_concept_emb, optional initial concept embedding table
            16. freeze_ent_emb, whether the concept embedding table is frozen
            17. init_range, std of the normal init applied to Linear/Embedding
                weights when > 0
        '''
        super().__init__()
        print("freeze_ent_emb", freeze_ent_emb)
        self.init_range = init_range
        self.concept_embd = Customized_Embd(num_concept=n_concept, concept_in_dim= concept_in_dim, concept_out_dim = concept_dim,
                                           use_contextualized=False, pretrained_concept_emb = pretrained_concept_emb,
                                            freeze_net_=freeze_ent_emb)
        # project sentence dimension to concept dimension
        self.ln1 = nn.Linear(sent_dim, concept_dim)
        self.act = nn.GELU()
        self.concept_dim = concept_dim
        
        #define gnn layer
        self.gnn = QAGNN_message_passing(args, k = k, node_type= n_node_type, edge_type= n_edge_type,
                                        input_size=concept_dim, hidden_size= concept_dim, output_size= concept_dim, 
                                         dropout_rate= drop_gnn)
        # attention pooler: sentence vector queries the GNN node states
        self.pooler = MultiheadAttPoolLayer(n_attention_head, sent_dim, concept_dim)
        # MLP over [graph_vec ; sent_vec ; z_vec]
        self.mlp = MLP(concept_dim + sent_dim + concept_dim, fc_dim, 1, num_fc_layer, drop_fc,
                      layer_norm= True)
        self.dropout_ebd = nn.Dropout(drop_embd)
        self.dropout_fc = nn.Dropout(drop_fc)
        if init_range > 0 :
            self.apply(self.init_weights)
            # init_weights re-initializes EVERY nn.Embedding, including the concept
            # table, which would replace the pretrained (and possibly frozen) vectors
            # with random noise. Restore them afterwards; doing it after apply() keeps
            # the RNG stream, and hence every other module's init, unchanged.
            if pretrained_concept_emb is not None:
                self.concept_embd.embd.weight.data.copy_(torch.as_tensor(pretrained_concept_emb))
    
    def init_weights(self,module):
        '''Normal(0, init_range) init for Linear/Embedding weights; zero biases; unit LayerNorm.'''
        if isinstance(module, (nn.Embedding, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=self.init_range)
            if hasattr(module, 'bias') and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
    
    def forward(self, sent_vecs, concept_ids, node_type_ids, node_scores, adj_length,
               adj, embd_data = None, cache_output = False):
        '''
        params:
            1. sent_vecs: sentence embedding (batch_size, dim_sent)
            2. concept_ids: (batch_size, num_nodes)
            3. node_type_ids: (batch_size, num_nodes)
            4. node_scores: (batch_size, num_nodes, 1)
            5. adj_length: (batch_size,) number of valid nodes per graph
            6. adj: (edge_index, edge_type)
            7. embd_data: optional contextualized embeddings for Customized_Embd
            8. cache_output: accepted but unused
        return:
            logits (batch_size, 1), pooling attention weights
        '''
        # project sentence vectors to the node dimension; this becomes node 0
        # (the context node). shape (batch_size, 1, node_dim)
        sent2node_vecs = self.act(self.ln1(sent_vecs)).unsqueeze(1)
        # embeddings of the remaining nodes (column 0 is the context node and is
        # skipped); ids are shifted by -1 here — presumably offset by 1 upstream, confirm
        # (bs, num_node-1, dim_node)
        gnn_input1 = self.concept_embd(concept_ids[:,1:]-1, embd_data).to(node_type_ids.device)
        # concat the context embd z and others embd
        gnn_input = self.dropout_ebd(torch.cat([sent2node_vecs, gnn_input1], dim = 1))   
        
        #------------------- normalize the node score ---------------------------
        # 1 = valid node, 0 = padding. [bs, num_node]
        mask_ = (torch.arange(node_scores.size(1), device = node_scores.device) < adj_length.unsqueeze(1)).float()
        node_scores = -node_scores
        # Shift scores relative to node 0 to avoid exploding values / gradients
        # (the original note says the scores were sorted beforehand).
        node_scores = node_scores - node_scores[:,0:1,:]
        # [bs, num_node]
        node_scores = node_scores.squeeze(2)
        # zero out padded nodes
        node_scores *= mask_
        # mean absolute score over valid nodes; clamp avoids 0/0 = NaN for an
        # empty graph (adj_length == 0), and is a no-op otherwise
        mean_norm = (torch.abs(node_scores)).sum(dim = 1)/ adj_length.clamp(min=1)
        # [bs, n_node]
        node_scores = node_scores / (mean_norm.unsqueeze(1) + 1e-05)
        node_scores = node_scores.unsqueeze(2)#[bs, n_node, 1]
        
        gnn_output = self.gnn(gnn_input, adj, node_type_ids, node_scores)
        # updated context-node vector (Z in the paper)
        z_vecs = gnn_output[:,0]
        mask = torch.arange(node_type_ids.size(1), device=node_type_ids.device) >= adj_length.unsqueeze(1) #1 means masked out
        # also mask out context nodes (node type 3)
        mask = mask | (node_type_ids == 3)
        # a fully masked row would make the pooler's softmax all -inf (NaN);
        # keep node 0 visible in that case
        mask[mask.all(1),0] = 0
        sent_vecs_for_pooler = sent_vecs
        # sentence vector as query, gnn_output as key/value, with mask
        graph_vecs, pool_attn = self.pooler(sent_vecs_for_pooler, gnn_output, mask)
        
        # final concatenation [graph ; sentence ; context node]
        concat = self.dropout_fc(torch.cat((graph_vecs, sent_vecs, z_vecs), dim = 1))
        logits = self.mlp(concat)
        return logits, pool_attn

In [12]:
class LM_QAGNN(nn.Module):
    '''
    Full model: a TextEncoder over the question/choice text followed by the
    QAGNN decoder over the concept graph; yields one logit per answer choice.
    '''
    def __init__(self, args, model_name, k, num_node_type, num_edge_type, 
                num_concept, concept_dim, concept_in_dim, n_attention_head, fc_dim,
                num_fc_layer, drop_embd, drop_gnn, drop_fc, pretrained_concept_emb=None,freeze_ent_emb = True,
                init_range = 0.0):
        super().__init__()
        self.encoder = TextEncoder(model_name)
        self.decoder = QAGNN(args, k, num_node_type, num_edge_type, self.encoder.output_size,
                            num_concept,concept_dim, concept_in_dim, n_attention_head, fc_dim,
                            num_fc_layer, drop_embd, drop_gnn, drop_fc, pretrained_concept_emb,
                             freeze_ent_emb, init_range)
        
    def forward(self, *inputs, layer_id=-1, cache_output=False, detail=False):
        """
        sent_vecs: (batch_size, num_choice, d_sent)    -> (batch_size * num_choice, d_sent)
        concept_ids: (batch_size, num_choice, n_node)  -> (batch_size * num_choice, n_node)
        node_type_ids: (batch_size, num_choice, n_node) -> (batch_size * num_choice, n_node)
        adj_lengths: (batch_size, num_choice)          -> (batch_size * num_choice, )
        adj -> edge_index, edge_type
            edge_index: list of (batch_size, num_choice) -> list of (batch_size * num_choice, ); each entry is torch.tensor(2, E(variable))
                                                         -> (2, total E)
            edge_type:  list of (batch_size, num_choice) -> list of (batch_size * num_choice, ); each entry is torch.tensor(E(variable), )
                                                         -> (total E, )
        cache_output, detail: accepted but unused
        returns: logits of shape (batch_size, num_choice), pooling attention
        """
        bs, nc = inputs[0].size(0), inputs[0].size(1)

        def _merge_choice_axis(t):
            # (batch, choice, ...) -> (batch * choice, ...)
            return t.view(t.size(0) * t.size(1), *t.size()[2:])

        # every input but the last two is a tensor; the last two are nested lists
        # of per-example graphs, flattened into one list each
        flat_tensors = [_merge_choice_axis(t) for t in inputs[:-2]]
        flat_graphs = [sum(nested, []) for nested in inputs[-2:]]
        *lm_inputs, concept_ids, node_type_ids, node_scores, adj_lengths = flat_tensors
        edge_index, edge_type = flat_graphs

        edge_index, edge_type = self.batch_graph(edge_index, edge_type, concept_ids.size(1))
        device = node_type_ids.device
        adj = (edge_index.to(device), edge_type.to(device)) #edge_index: [2, total_E]   edge_type: [total_E, ]

        sent_vecs, all_hidden_states = self.encoder(*lm_inputs, layers_id=layer_id)
        logits, attn = self.decoder(sent_vecs.to(device),
                                    concept_ids,
                                    node_type_ids, node_scores, adj_lengths, adj,
                                    embd_data=None)
        return logits.view(bs, nc), attn

    def batch_graph(self, edge_index_init, edge_type_init, n_nodes):
        '''
        Merge per-example graphs into one disjoint graph.

        edge_index_init: list of n_examples tensors, each (2, E_i)
        edge_type_init:  list of n_examples tensors, each (E_i,)
        n_nodes: nodes per example; example i's node ids are shifted by i * n_nodes
        returns: (edge_index [2, total_E], edge_type [total_E])
        '''
        shifted = [edges + offset * n_nodes for offset, edges in enumerate(edge_index_init)]
        merged_index = torch.cat(shifted, dim=1)
        merged_type = torch.cat(edge_type_init, dim=0)
        return merged_index, merged_type

## Create the dataloader
Define the classes that hold the example and feature data.

In [13]:
class InputExample(object):
    '''
    One multiple-choice example.

    contexts and endings are parallel lists with one entry per answer choice;
    label is the index of the correct choice (None when unknown).
    '''
    def __init__(self, example_id, question, contexts, endings, label=None):
        self.example_id, self.question = example_id, question
        self.contexts, self.endings = contexts, endings
        self.label = label

class InputFeatures(object):
    '''
    Tokenized model inputs for one example, as one dict per answer choice.

    choices_features: iterable of 5-tuples
        (_, input_ids, input_mask, segment_ids, output_mask); the first item is ignored.
    '''
    def __init__(self, example_id, choices_features, label):
        self.example_id = example_id
        per_choice = []
        for _ignored, ids, mask, segments, out_mask in choices_features:
            per_choice.append({
                'input_ids': ids,
                'input_mask': mask,
                'segment_ids': segments,
                'output_mask': out_mask,
            })
        self.choices_features = per_choice
        self.label = label

def read_examples(input_files):
    '''
    Read a JSON-lines multiple-choice file into a list of InputExample.

    Each non-blank line is a JSON object with "id" and
    "question" -> {"stem": str, "choices": [{"text": str}, ...]}, and optionally
    "answerKey" (a letter, mapped to a 0-based index; 0 when absent),
    "para" and "fact1" (each prepended to the stem as extra context).

    params:
        input_files: path to a single JSON-lines file (despite the plural name)
    returns:
        list of InputExample; the same context string is repeated for every choice
    '''
    examples = []
    with open(input_files, "r", encoding="utf-8") as f:
        # iterate the file lazily instead of readlines(), and skip blank lines
        # (e.g. a stray empty line at the end), which json.loads would reject
        for line in f:
            if not line.strip():
                continue
            json_dic = json.loads(line)
            label = ord(json_dic["answerKey"]) - ord("A") if 'answerKey' in json_dic else 0
            contexts = json_dic["question"]["stem"]
            if "para" in json_dic:
                contexts = json_dic["para"] + " " + contexts
            if "fact1" in json_dic:
                contexts = json_dic["fact1"] + " " + contexts
            choices = json_dic["question"]["choices"]
            examples.append(
                InputExample(
                    example_id=json_dic["id"],
                    contexts=[contexts] * len(choices),
                    question="",
                    endings=[ending["text"] for ending in choices],
                    label=label,
                ))
    return examples

def _truncate_seq_pair(tokens_a, tokens_b, max_length):
        """Truncates a sequence pair in place to the maximum length."""

        # This is a simple heuristic which will always truncate the longer sequence
        # one token at a time. This makes more sense than truncating an equal percent
        # of tokens from each, since if one sequence is very short then each token
        # that's truncated likely contains more information than a longer sequence.
        while True:
            total_length = len(tokens_a) + len(tokens_b)
            if total_length <= max_length:
                break
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()
    
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer,
                                 cls_token='[CLS]',
                                 cls_token_segment_id=0,
                                 sep_token='[SEP]',
                                 sequence_a_segment_id=0,
                                 sequence_b_segment_id=1,
                                 sep_token_extra=False,
                                 pad_token_segment_id=0,
                                 pad_token=0,
                                 mask_padding_with_zero=True):
    """Tokenize every (context, ending) pair of every example into fixed-length features.

    Each choice becomes ``[cls] context [sep] ([sep]) question+ending [sep]``,
    truncated with ``_truncate_seq_pair`` and right-padded to ``max_seq_length``.

    Args:
        examples: objects with ``contexts``, ``endings``, ``question``,
            ``example_id`` and ``label`` attributes.
        label_list: possible labels; ``example.label`` is mapped to its index here.
        max_seq_length: length every id/mask list is padded to.
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.
        sep_token_extra: add a second separator between the segments (RoBERTa style).
        pad_token: id used for padding positions.
        mask_padding_with_zero: if True, real tokens get mask 1 and padding 0.

    Returns:
        list of ``InputFeatures``; each choice is stored as the tuple
        ``(tokens, input_ids, input_mask, segment_ids, output_mask)`` where
        ``output_mask`` is 1 for cls/sep tokens and for padding.
    """
    label_map = {label: i for i, label in enumerate(label_list)}
    print(label_map)
    # These values do not depend on the example, so compute them once.
    special_token_ids = set(tokenizer.convert_tokens_to_ids([cls_token, sep_token]))
    special_tokens_count = 4 if sep_token_extra else 3
    real_mask_value = 1 if mask_padding_with_zero else 0
    pad_mask_value = 0 if mask_padding_with_zero else 1

    features = []
    for example in tqdm(examples):
        choices_features = []
        for context, ending in zip(example.contexts, example.endings):
            tokens_a = tokenizer.tokenize(context)
            tokens_b = tokenizer.tokenize(example.question + " " + ending)
            _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)

            tokens = tokens_a + [sep_token]
            if sep_token_extra:
                # roberta uses an extra separator b/w pairs of sentences
                tokens += [sep_token]
            segment_ids = [sequence_a_segment_id] * len(tokens)
            if tokens_b:
                tokens += tokens_b + [sep_token]
                segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)
            tokens = [cls_token] + tokens
            segment_ids = [cls_token_segment_id] + segment_ids

            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            input_mask = [real_mask_value] * len(input_ids)
            output_mask = [1 if token_id in special_token_ids else 0 for token_id in input_ids]  # 1 for mask

            # Right-pad every list to max_seq_length.
            padding_length = max_seq_length - len(input_ids)
            input_ids = input_ids + [pad_token] * padding_length
            input_mask = input_mask + [pad_mask_value] * padding_length
            output_mask = output_mask + [1] * padding_length
            segment_ids = segment_ids + [pad_token_segment_id] * padding_length

            assert len(input_ids) == max_seq_length
            assert len(output_mask) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
            choices_features.append((tokens, input_ids, input_mask, segment_ids, output_mask))
        label = label_map[example.label]
        features.append(InputFeatures(example_id=example.example_id, choices_features=choices_features, label=label))

    return features
    
def select_field(features, field):
    """Collect ``choice[field]`` for every choice of every feature.

    Returns a list with one inner list per feature, holding the requested
    field of each entry in ``feature.choices_features``, in order.
    """
    nested = []
    for feature in features:
        nested.append([choice[field] for choice in feature.choices_features])
    return nested

def convert_features_to_tensors(features):
    """Turn a list of features into batched tensors and print their shapes.

    Returns ``(input_ids, input_mask, segment_ids, output_mask, label)``.
    The first four are built from nested per-feature/per-choice lists, so
    ``torch.tensor`` requires every feature to have the same number of choices
    and sequence length. Ids and masks are ``long``; ``output_mask`` is ``bool``.
    ``label`` holds one ``long`` per feature.
    """
    def _stack(field, dtype):
        rows = [[choice[field] for choice in feature.choices_features] for feature in features]
        return torch.tensor(rows, dtype=dtype)

    input_ids = _stack('input_ids', torch.long)
    input_mask = _stack('input_mask', torch.long)
    segment_ids = _stack('segment_ids', torch.long)
    output_mask = _stack('output_mask', torch.bool)
    labels = torch.tensor([feature.label for feature in features], dtype=torch.long)
    for tensor in (input_ids, input_mask, segment_ids, output_mask, labels):
        print(tensor.shape)
    return input_ids, input_mask, segment_ids, output_mask, labels

In [14]:
def load_tensor(statement_jsonl_path, model_name = "roberta-large", max_seq_lenth = 128):
    """Read a statement ``.jsonl`` file and tokenize it into LM input tensors.

    Args:
        statement_jsonl_path: path to the statement file (see ``read_examples``).
        model_name: tokenizer name or local path passed to ``AutoTokenizer``.
        max_seq_lenth: padded sequence length (name kept for compatibility).

    Returns:
        ``(example_ids, labels, input_ids, input_mask, segment_ids, output_mask)``.

    Raises:
        ValueError: if the statement file contains no examples.
    """
    examples = read_examples(statement_jsonl_path)
    if not examples:
        # examples[0] below would otherwise fail with an opaque IndexError.
        raise ValueError(f"no examples found in {statement_jsonl_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Pad with the tokenizer's own pad id. For RoBERTa, id 0 is <s> (the CLS
    # token) and the pad id is 1, so the old default of 0 padded with CLS tokens.
    # Fall back to 0 for tokenizers that define no pad token.
    pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    features = convert_examples_to_features(examples, list(range(len(examples[0].endings))),
                                            max_seq_lenth, tokenizer=tokenizer,
                                            cls_token=tokenizer.cls_token,
                                            sep_token=tokenizer.sep_token,
                                            sep_token_extra=True,
                                            sequence_b_segment_id=0,
                                            pad_token=pad_token_id,
                                            )
    # convert feature to tensor
    example_ids = [f.example_id for f in features]
    *data_tensors, all_label = convert_features_to_tensors(features)
    return (example_ids, all_label, *data_tensors)

In [15]:
def load_sparse_adj_data_with_contextnode(adj_pk_path, max_node_num, num_choice, args):
    """Load the pre-processed sparse graph cache of one split and group it per question.

    Reads ``adj_pk_path + '.loaded_cache'``, a pickle holding
    ``(adj_lengths_ori, concept_ids, node_type_ids, node_scores, adj_lengths,
    edge_index, edge_type, half_n_rel)`` with one row per (question, choice)
    graph. It prints graph-size statistics and regroups everything by question.

    Args:
        adj_pk_path: path of the raw adjacency pickle; only its sibling
            ``.loaded_cache`` file is read here.
        max_node_num: unused; kept for interface compatibility.
        num_choice: number of answer choices per question.
        args: unused; kept for interface compatibility.

    Returns:
        ``concept_ids, node_type_ids, node_scores, adj_lengths, (edge_index, edge_type)``.
        The tensors are viewed as ``(n_questions, num_choice, ...)``, and
        ``edge_index`` / ``edge_type`` become nested lists of shape
        ``(n_questions, num_choice)``.

    Raises:
        FileNotFoundError: if the ``.loaded_cache`` file does not exist.
            Building it from the raw pickle is not implemented here.
    """
    cache_path = adj_pk_path + '.loaded_cache'
    if not os.path.exists(cache_path):
        # Previously this fell through and crashed with UnboundLocalError.
        raise FileNotFoundError(
            f"graph cache {cache_path!r} not found; building it from {adj_pk_path!r} is not implemented")

    # NOTE: pickle.load can execute arbitrary code -- only load caches you generated yourself.
    with open(cache_path, 'rb') as f:
        (adj_lengths_ori, concept_ids, node_type_ids, node_scores,
         adj_lengths, edge_index, edge_type, _half_n_rel) = pickle.load(f)

    ori_adj_mean = adj_lengths_ori.float().mean().item()
    ori_adj_sigma = np.sqrt(((adj_lengths_ori.float() - ori_adj_mean)**2).mean().item())
    print('| ori_adj_len: mu {:.2f} sigma {:.2f} | adj_len: {:.2f} |'.format(ori_adj_mean, ori_adj_sigma, adj_lengths.float().mean().item()) +
          ' prune_rate： {:.2f} |'.format((adj_lengths_ori > adj_lengths).float().mean().item()) +
          ' qc_num: {:.2f} | ac_num: {:.2f} |'.format((node_type_ids == 0).float().sum(1).mean().item(),
                                                      (node_type_ids == 1).float().sum(1).mean().item()))

    # Group consecutive runs of num_choice graphs into one list per question
    # (the list equivalent of .view(n_questions, num_choice)).
    # Each edge_index entry is tensor[2, E]; each edge_type entry is tensor[E].
    edge_index = list(map(list, zip(*(iter(edge_index),) * num_choice)))
    edge_type = list(map(list, zip(*(iter(edge_type),) * num_choice)))

    concept_ids, node_type_ids, node_scores, adj_lengths = [
        x.view(-1, num_choice, *x.size()[1:]) for x in (concept_ids, node_type_ids, node_scores, adj_lengths)]
    # concept_ids, node_type_ids: (n_questions, num_choice, <remaining dims of the cached tensor>)
    # node_scores:               (n_questions, num_choice, <remaining dims of the cached tensor>)
    # adj_lengths:               (n_questions, num_choice)
    return concept_ids, node_type_ids, node_scores, adj_lengths, (edge_index, edge_type)


In [16]:
# *train_decoder,adj = load_sparse_adj_data_with_contextnode("./data/csqa/graph/train.graph.adj.pk",200,5,None)

In [17]:
# input_file = "./data/csqa/statement/train.statement.jsonl"
# ids,labels,*encoder_data = load_tensor(input_file)

In [18]:
class LM_QAGNN_DataLoader(object):
    """Hold tokenized statements and graph tensors for the train/dev splits.

    ``train()`` and ``dev()`` return fresh ``MultiGPUSparseAdjDataBatchGenerator``
    iterators over the loaded data.
    """

    def __init__(self, args, train_statement_path, train_adj_path,
                 dev_statement_path, dev_adj_path,
                 test_statement_path, test_adj_path,
                 batch_size, eval_batch_size, device, model_name, max_node_num=200, max_seq_length=128,
                 is_inhouse=False, inhouse_train_qids_path=None,
                 subsample=1.0, use_cache=True):
        '''
        params:
            args: config object forwarded to the graph loader and the batch generators
            train_statement_path: path to the train statement .jsonl file
            train_adj_path: path of the train graph pickle (its ``.loaded_cache`` sibling is read)
            dev_statement_path, dev_adj_path: same as train, for the dev split
            test_statement_path, test_adj_path: accepted but currently unused
            batch_size: batch size used by ``train()``
            eval_batch_size: batch size used by ``dev()``
            device: pair (device0, device1); LM inputs go to device0, graph inputs and labels to device1
            model_name: tokenizer name or local path
            max_node_num: forwarded to the graph loader
            max_seq_length: padded token length of the LM inputs
            is_inhouse, inhouse_train_qids_path: split train into in-house train/test using a qid list file
            subsample: fraction in (0, 1] of the training questions to keep
            use_cache: accepted but currently unused
        '''
        super().__init__()
        self.args = args
        self.batch_size = batch_size
        self.eval_batch_size = eval_batch_size
        self.device0, self.device1 = device
        self.is_inhouse = is_inhouse

        self.model_name = model_name
        print ('train_statement_path', train_statement_path)
        self.train_qids, self.train_labels, *self.train_encoder_data = load_tensor(train_statement_path,
                                                                                  model_name, max_seq_length)
        self.dev_qids, self.dev_labels, *self.dev_encoder_data = load_tensor(dev_statement_path,
                                                                            model_name, max_seq_length)
        # encoder tensors are (num_questions, num_choices, seq_len)
        self.num_choice = self.train_encoder_data[0].shape[1]

        # graph inputs: decoder tensors plus an (edge_index, edge_type) pair of per-question lists
        *self.train_decoder_data, self.train_adj_data = load_sparse_adj_data_with_contextnode(train_adj_path, max_node_num, self.num_choice, args)
        *self.dev_decoder_data, self.dev_adj_data = load_sparse_adj_data_with_contextnode(dev_adj_path, max_node_num, self.num_choice, args)
        assert all(len(self.train_qids) == len(self.train_adj_data[0]) == x.size(0) for x in [self.train_labels] + self.train_encoder_data + self.train_decoder_data)
        assert all(len(self.dev_qids) == len(self.dev_adj_data[0]) == x.size(0) for x in [self.dev_labels] + self.dev_encoder_data + self.dev_decoder_data)

        # whether use in house test
        if self.is_inhouse:
            with open(inhouse_train_qids_path, 'r') as fin:
                inhouse_qids = set(line.strip() for line in fin)
            self.inhouse_train_indexes = torch.tensor([i for i, qid in enumerate(self.train_qids) if qid in inhouse_qids])
            self.inhouse_test_indexes = torch.tensor([i for i, qid in enumerate(self.train_qids) if qid not in inhouse_qids])

        self._apply_subsample(subsample)

    def _apply_subsample(self, subsample):
        """Keep only the first ``subsample`` fraction (0 < subsample <= 1) of the training questions."""
        assert 0. < subsample <= 1.
        if subsample >= 1.:
            return
        n_train = int(self.train_size() * subsample)
        assert n_train > 0
        if self.is_inhouse:
            self.inhouse_train_indexes = self.inhouse_train_indexes[:n_train]
        else:
            self.train_qids = self.train_qids[:n_train]
            self.train_labels = self.train_labels[:n_train]
            self.train_encoder_data = [x[:n_train] for x in self.train_encoder_data]
            self.train_decoder_data = [x[:n_train] for x in self.train_decoder_data]
            # train_adj_data is an (edge_index, edge_type) pair: slice each
            # per-question list, not the pair itself.
            self.train_adj_data = tuple(x[:n_train] for x in self.train_adj_data)
            assert all(len(self.train_qids) == len(self.train_adj_data[0]) == x.size(0) for x in [self.train_labels] + self.train_encoder_data + self.train_decoder_data)
        assert self.train_size() == n_train

    def train_size(self):
        """Number of training questions (in-house train split when ``is_inhouse``)."""
        return self.inhouse_train_indexes.size(0) if self.is_inhouse else len(self.train_qids)

    def dev_size(self):
        """Number of dev questions."""
        return len(self.dev_qids)

    def train(self):
        """Return a shuffled training batch generator."""
        if self.is_inhouse:
            n_train = self.inhouse_train_indexes.size(0)
            train_indexes = self.inhouse_train_indexes[torch.randperm(n_train)]
        else:
            train_indexes = torch.randperm(len(self.train_qids))
        return MultiGPUSparseAdjDataBatchGenerator(self.args, 'train', self.device0, self.device1, self.batch_size, train_indexes, self.train_qids, self.train_labels, tensors0=self.train_encoder_data, tensors1=self.train_decoder_data, adj_data=self.train_adj_data)

    def dev(self):
        """Return an in-order dev batch generator using ``eval_batch_size``."""
        return MultiGPUSparseAdjDataBatchGenerator(self.args, 'eval', self.device0, self.device1, self.eval_batch_size, torch.arange(len(self.dev_qids)), self.dev_qids, self.dev_labels, tensors0=self.dev_encoder_data, tensors1=self.dev_decoder_data, adj_data=self.dev_adj_data)

In [19]:
class MultiGPUSparseAdjDataBatchGenerator(object):
    '''
    Mini-batch iterator over pre-loaded LM inputs, graph tensors and sparse adjacency lists.

    Each yielded batch is the tuple
    ``(qids, labels, *tensors0, *lists0, *tensors1, *lists1, edge_index, edge_type)``.
    ``tensors0`` / ``lists0`` are moved to ``device0``; labels, ``tensors1`` /
    ``lists1`` and the adjacency lists go to ``device1``. In ``'train'`` mode
    the last partial batch is filled with randomly repeated samples so that
    every batch has exactly ``batch_size`` rows. In other modes every sample
    is yielded exactly once.
    '''
    def __init__(self, args, mode, device0, device1, batch_size, indexes, qids, labels,
                 tensors0=None, lists0=None, tensors1=None, lists1=None, adj_data=None):
        super().__init__()
        self.args = args
        self.mode = mode
        self.device0 = device0
        self.device1 = device1
        self.batch_size = batch_size
        self.indexes = indexes
        self.qids = qids
        self.labels = labels
        # None defaults avoid sharing one mutable default list between instances.
        self.tensors0 = [] if tensors0 is None else tensors0
        self.lists0 = [] if lists0 is None else lists0
        self.tensors1 = [] if tensors1 is None else tensors1
        self.lists1 = [] if lists1 is None else lists1
        # (edge_index_all, edge_type_all): nested lists indexed [sample][choice]
        self.adj_data = adj_data

    def __len__(self):
        """Number of batches, i.e. ceil(len(indexes) / batch_size)."""
        return (self.indexes.size(0) - 1) // self.batch_size + 1

    def _padded_indexes(self):
        """Return the indexes to iterate over, padded to a multiple of batch_size in train mode only."""
        indexes = self.indexes
        bs = self.batch_size
        remain = indexes.size(0) % bs
        if self.mode != 'train' or remain == 0:
            # Evaluation must see every sample exactly once: duplicating random
            # samples would skew accuracy.
            return indexes
        need = bs - remain
        pool = indexes[:-remain]
        replace = False
        if pool.size(0) < need:
            # Too few samples in complete batches to draw without replacement
            # (e.g. fewer samples than batch_size): sample from everything with replacement.
            pool, replace = indexes, True
        extra = np.random.choice(pool.cpu().numpy(), size=need, replace=replace)
        return torch.cat([indexes, torch.as_tensor(extra, dtype=indexes.dtype)])

    def __iter__(self):
        '''
        Yield batches in the order given by ``indexes``.
        '''
        bs = self.batch_size
        indexes = self._padded_indexes()
        n = indexes.size(0)
        #edge_index_all: nested list of shape (n_samples, num_choice), where each entry is tensor[2, E]
        #edge_type_all:  nested list of shape (n_samples, num_choice), where each entry is tensor[E, ]
        edge_index_all, edge_type_all = self.adj_data
        for a in range(0, n, bs):
            batch_indexes = indexes[a:a + bs]
            batch_qids = [self.qids[idx] for idx in batch_indexes]
            batch_labels = self._to_device(self.labels[batch_indexes], self.device1)
            batch_tensors0 = [self._to_device(x[batch_indexes], self.device0) for x in self.tensors0]
            batch_tensors1 = [self._to_device(x[batch_indexes], self.device1) for x in self.tensors1]
            batch_lists0 = [self._to_device([x[i] for i in batch_indexes], self.device0) for x in self.lists0]
            batch_lists1 = [self._to_device([x[i] for i in batch_indexes], self.device1) for x in self.lists1]
            edge_index = self._to_device([edge_index_all[i] for i in batch_indexes], self.device1)
            edge_type = self._to_device([edge_type_all[i] for i in batch_indexes], self.device1)
            yield tuple([batch_qids, batch_labels, *batch_tensors0, *batch_lists0, *batch_tensors1, *batch_lists1, edge_index, edge_type])

    def _to_device(self, obj, device):
        """Move a tensor, or a (nested) list/tuple of tensors, to ``device``; containers become lists."""
        if isinstance(obj, (tuple, list)):
            return [self._to_device(item, device) for item in obj]
        else:
            return obj.to(device)

In [20]:
class Config(object):
    """Plain container of paths and hyper-parameters, passed around as ``args``.

    Every constructor argument is stored verbatim under the same attribute
    name. The remaining training settings are fixed defaults; override them
    by assigning to the attribute after construction.
    """

    def __init__(self, mode, num_relation, train_adj_path, dev_adj_path, train_statement_path, dev_statement_path,
                 num_attention_head,
                 k, gnn_dim, fc_dim, num_fc_layer, freeze_, max_node_num, init_range, dropout_embd,
                 dropout_gnn, dropout_fc, model_name):
        # ---- data -------------------------------------------------------
        self.mode, self.num_relation = mode, num_relation
        self.train_statement_path, self.train_adj_path = train_statement_path, train_adj_path
        self.dev_statement_path, self.dev_adj_path = dev_statement_path, dev_adj_path

        # ---- model (language model + graph network) ---------------------
        self.model_name = model_name
        self.freeze_ = freeze_
        self.k = k
        self.gnn_dim, self.fc_dim = gnn_dim, fc_dim
        self.num_attention_head = num_attention_head
        self.num_fc_layer = num_fc_layer
        self.max_node_num = max_node_num
        self.init_range = init_range

        # ---- regularization ---------------------------------------------
        self.dropout_embd, self.dropout_gnn, self.dropout_fc = dropout_embd, dropout_gnn, dropout_fc

        # ---- fixed training settings ------------------------------------
        self.max_seq_length = 128
        self.batch_size = 4
        self.n_epochs = 15
        self.encoder_lr, self.decoder_lr = 1e-5, 1e-3
        self.weight_decay = 0.1
        # load_model_path is a boolean flag; load_model_path_ is the checkpoint file.
        self.load_model_path = True
        self.load_model_path_ = "./data/checkpoint.pt"
        self.cuda = True

In [None]:
# Experiment configuration (keyword order follows Config's signature).
config = Config(
    mode='train',
    num_relation=38,
    train_adj_path="./data/needed_data/train.graph.adj.pk",
    dev_adj_path="./data/needed_data/dev.graph.adj.pk",
    train_statement_path="./data/needed_data/train.statement.jsonl",
    dev_statement_path="./data/needed_data/dev.statement.jsonl",
    num_attention_head=2,
    k=5,
    gnn_dim=100,
    fc_dim=200,
    num_fc_layer=0,
    freeze_=True,
    max_node_num=200,
    init_range=0.02,
    dropout_embd=0.2,
    dropout_gnn=0.2,
    dropout_fc=0.2,
    model_name="./data/needed_data/pretrained_data/",
)

In [None]:
# Load train/dev data on CPU for interactive inspection (train() builds its own loader).
dataset = LM_QAGNN_DataLoader(
    config,
    config.train_statement_path,
    config.train_adj_path,
    config.dev_statement_path,
    config.dev_adj_path,
    test_statement_path=None,
    test_adj_path=None,
    batch_size=config.batch_size,
    eval_batch_size=config.batch_size,
    model_name=config.model_name,
    max_node_num=config.max_node_num,
    max_seq_length=config.max_seq_length,
    device=(torch.device("cpu"), torch.device("cpu")),
)

In [None]:
# # load pretrained data
# cp_embd = [np.load(path) for path in ["./data/tzw.ent.npy"]]
# cp_embd= torch.tensor(np.concatenate(cp_embd, 1), dtype=torch.float)
# concept_num, concept_dim = cp_embd.shape[0], cp_embd.shape[1]
# print('| num_concepts: {} |'.format(concept_num))

In [None]:
# model = LM_QAGNN(args=config, model_name= config.model_name, k = config.k, num_node_type= 4,
#                 num_edge_type=config.num_relation, num_concept=concept_num, concept_dim=config.gnn_dim,
#                 concept_in_dim=1024, n_attention_head=config.num_attention_head, fc_dim=config.fc_dim,
#                 num_fc_layer=config.num_fc_layer,pretrained_concept_emb=cp_embd, drop_fc= config.dropout_fc,
#                 drop_gnn = config.dropout_gnn, drop_embd = config.dropout_embd,freeze_ent_emb=False,
#                 init_range=config.init_range)

In [None]:
# for qids, labels, *input_data in dataset.train():
#     for a in range(0, config.batch_size, 2):
#         print("a:", a)
#         b = min(a + 2, config.batch_size)
#         logits, _ = model(*[x[a:b] for x in input_data])
#         print(logits)
#         print(labels[a:b])
#         break
#     break

In [None]:
def _select_devices(use_cuda):
    """Return (encoder_device, decoder_device): cuda:0/cuda:1 with >=2 GPUs, cuda:0 twice with one, else CPU."""
    n_gpu = torch.cuda.device_count() if use_cuda else 0
    if n_gpu >= 2:
        return torch.device("cuda:0"), torch.device("cuda:1")
    if n_gpu == 1:
        return torch.device("cuda:0"), torch.device("cuda:0")
    return torch.device("cpu"), torch.device("cpu")


def _build_optimizer(model, args):
    """Adam over encoder and decoder with per-part learning rates and no weight decay on biases/LayerNorm."""
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    def _param_groups(module, lr):
        named = list(module.named_parameters())
        return [
            {'params': [p for n, p in named if not any(nd in n for nd in no_decay)],
             'weight_decay': args.weight_decay, 'lr': lr},
            {'params': [p for n, p in named if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0, 'lr': lr},
        ]

    # NOTE(review): torch.optim.Adam applies weight_decay as coupled L2 on the
    # gradient; AdamW (decoupled decay) may be what was intended -- confirm.
    return torch.optim.Adam(_param_groups(model.encoder, args.encoder_lr)
                            + _param_groups(model.decoder, args.decoder_lr))


def _evaluate(model, dataset):
    """Return the accuracy of ``model`` on ``dataset.dev()`` (NaN if the dev set is empty)."""
    model.eval()
    n_samples, n_correct = 0, 0
    with torch.no_grad():  # no autograd graph is needed for validation
        for _qids, labels, *input_data in tqdm(dataset.dev()):
            logits, _ = model(*input_data)
            n_correct += (logits.argmax(1) == labels).sum().item()
            n_samples += labels.size(0)
    return n_correct / n_samples if n_samples else float('nan')


def train(args):
    """Train LM_QAGNN with the paths and hyper-parameters in ``args`` (a Config).

    Steps:
    - Loads the concept embeddings and the train/dev data.
    - Restores ``args.load_model_path_`` when ``args.load_model_path`` is truthy.
    - Freezes the LM encoder when ``args.freeze_`` is set.
    - Reports dev accuracy once, then trains for ``args.n_epochs`` epochs.

    Every 30th batch, and the last batch of each epoch, the loss is appended
    to ``train_log``. After each epoch ``[state_dict, args, losses]`` is saved
    to ``./checkpoint.pt``. Set ``args.cuda = False`` to force CPU training.
    """
    model_path = "./checkpoint.pt"
    # load concept embedding
    cp_embd = [np.load(path) for path in ["./data/needed_data/tzw.ent.npy"]]
    cp_embd = torch.tensor(np.concatenate(cp_embd, 1), dtype=torch.float)
    concept_num = cp_embd.shape[0]
    print('| num_concepts: {} |'.format(concept_num))

    # Previously a hard-coded CPU override followed this selection, so the
    # printed device was not the one actually used; args.cuda is the switch now.
    device0, device1 = _select_devices(args.cuda)
    print("train on device " + str(device0) + " " + str(device1))

    # Use the passed-in args everywhere (the original read the global `config`).
    dataset = LM_QAGNN_DataLoader(args, args.train_statement_path, args.train_adj_path,
                                  args.dev_statement_path, args.dev_adj_path, None, None,
                                  batch_size=args.batch_size, eval_batch_size=args.batch_size,
                                  model_name=args.model_name, max_node_num=args.max_node_num,
                                  max_seq_length=args.max_seq_length,
                                  device=(device0, device1))
    print(dataset.train_size() // args.batch_size)
    model = LM_QAGNN(args=args, model_name=args.model_name, k=args.k, num_node_type=4,
                     num_edge_type=args.num_relation, num_concept=concept_num, concept_dim=args.gnn_dim,
                     concept_in_dim=1024, n_attention_head=args.num_attention_head, fc_dim=args.fc_dim,
                     num_fc_layer=args.num_fc_layer, pretrained_concept_emb=cp_embd, drop_fc=args.dropout_fc,
                     drop_gnn=args.dropout_gnn, drop_embd=args.dropout_embd, freeze_ent_emb=False,
                     init_range=args.init_range)
    if args.load_model_path:
        print(f'loading and initializing model from {args.load_model_path_}')
        # NOTE: torch.load unpickles arbitrary objects (the saved args is a Config);
        # only load checkpoints you trust.
        model_state_dict, _old_args, _loss_ls = torch.load(args.load_model_path_, map_location=torch.device('cpu'))
        model.load_state_dict(model_state_dict)
    model.encoder.to(device0)
    model.decoder.to(device1)
    if args.freeze_:
        freeze_net(model.encoder)

    optimizer = _build_optimizer(model, args)
    # linear warm-up followed by linear decay over all training steps
    from transformers import get_linear_schedule_with_warmup
    steps_per_epoch = math.ceil(dataset.train_size() / args.batch_size)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=15,
                                                num_training_steps=args.n_epochs * steps_per_epoch)

    print('parameters:')
    for name, param in model.decoder.named_parameters():
        if param.requires_grad:
            print('\t{:45}\ttrainable\t{}\tdevice:{}'.format(name, param.size(), param.device))
        else:
            print('\t{:45}\tfixed\t{}\tdevice:{}'.format(name, param.size(), param.device))
    num_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
    print('\ttotal:', num_params)

    loss_fn = nn.CrossEntropyLoss(reduction="mean")
    max_norm = 1000
    ls = []

    print("validation accuracy:", _evaluate(model, dataset))

    for epoch_id in range(args.n_epochs):
        model.train()
        train_loader = dataset.train()
        n_batches = len(train_loader)
        for batch_id, (_qids, labels, *input_data) in enumerate(train_loader):
            start_time = time.time()
            optimizer.zero_grad()
            logits, _ = model(*input_data)
            batch_loss = loss_fn(logits, labels)
            batch_loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            # Advance the warm-up/decay schedule. Without this the LR stays at
            # its step-0 value, which is 0 for a linear warm-up schedule.
            scheduler.step()
            # batch_id counts batches, so compare it with the batch count
            # (not the sample count) to detect the epoch's last batch.
            if batch_id % 30 == 0 or batch_id == n_batches - 1:
                ls.append(batch_loss.item())
                with open("train_log", "a") as f:
                    f.write("Epoch %s, Loss %s, Batch %s: %.4f sec\n" % (epoch_id, ls[-1], batch_id, time.time() - start_time))
        print("done train epochs ", epoch_id)
        # save model once per epoch (previously this ran after every batch)
        torch.save([model.state_dict(), args, ls], f"{model_path}")
        print("save to ", model_path)

In [None]:
# Run the full training procedure with the configuration defined above.
train(args=config)

In [21]:
# Load the state dict (element 0 of the [state_dict, args, losses] list saved by train()).
# map_location="cpu" lets a checkpoint saved from GPU be inspected on any machine.
dic = torch.load("./checkpoint.pt", map_location="cpu")[0]

In [22]:
# Display the loaded state dict (the [0] element of the checkpoint saved by train()).
dic

OrderedDict([('encoder.model.embeddings.position_ids',
              tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
                        14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
                        28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
                        42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
                        56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
                        70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
                        84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
                        98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
                       112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
                       126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
               