In [38]:
import torch
# from torch_geometric.nn import TGNMemory, TransformerConv
torch.set_printoptions(threshold=10_000)
torch.cuda.is_available()

True

In [39]:
import torch
print(torch.__version__)

In [40]:
from torch_geometric.nn import  TransformerConv
from torch_geometric.utils.convert import from_scipy_sparse_matrix

In [41]:
printflag = True
printflag1 = True
torch.cuda.device_count()

1

In [42]:
torch.cuda.current_device()

0

In [43]:
torch.cuda.get_device_name(0)

'NVIDIA GeForce GTX 1660 Ti with Max-Q Design'

In [44]:
##Layer

import math
import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import torch.nn.functional as F
from torch import nn


class GraphConvolution(Module):
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, inputs, adj, global_W = None):
        if len(adj._values()) == 0:
            zeros = torch.zeros(adj.shape[0], self.out_features, device=inputs.device, dtype=self.weight.dtype)
            return zeros

        support = torch.spmm(inputs, self.weight)
        if global_W is not None:
            support = torch.spmm(support, global_W)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'


class SelfAttention(Module):
    """docstring for SelfAttention"""
    def __init__(self, in_features, idx, hidden_dim):
        super(SelfAttention, self).__init__()
        self.idx = idx
        self.linear = torch.nn.Linear(in_features, hidden_dim)
        # self.leakyrelu = nn.LeakyReLU(0.2)
        self.leakyrelu = F.leaky_relu
        self.a = Parameter(torch.FloatTensor(2 * hidden_dim, 1))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.a.size(1))
        self.a.data.uniform_(-stdv, stdv)

    def forward(self, inputs):
        # inputs size:  node_num * 3 * in_features
        x = self.linear(inputs).transpose(0, 1)
        self.n = x.size()[0]
        x = torch.cat([x, torch.stack([x[self.idx]] * self.n, dim=0)], dim=2)
        U = torch.matmul(x, self.a).transpose(0, 1)
        U = self.leakyrelu(U)
        weights = F.softmax(U, dim=1)
        outputs = torch.matmul(weights.transpose(1, 2), inputs).squeeze(1) * 3
        return outputs, weights
    

    
class TransforerBlock(torch.nn.Module):
    def __init__(self, in_features_list, out_features,  bias=True, gamma = 0.1):
        super().__init__()
        self.ntype = len(in_features_list)
        self.in_features_list = in_features_list
        self.out_features = out_features
        self.weights: nn.ParameterList = nn.ParameterList()
        
        for i in range(self.ntype):
            cache = Parameter(torch.FloatTensor(in_features_list[i], out_features))
            nn.init.xavier_normal_(cache.data, gain=1.414)
            self.weights.append( cache )
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
            stdv = 1. / math.sqrt(out_features)
            self.bias.data.uniform_(-stdv, stdv)
        else:
            self.register_parameter('bias', None)

        self.att_list: nn.ModuleList = nn.ModuleList()
        for i in range(self.ntype):
            self.att_list.append(Transformer_InfLevel(out_features, out_features))


        #self.conv = TransformerConv(in_channels, out_channels // 2, heads=2,
        #                            dropout=0.1, edge_dim=edge_dim)

    def forward(self, inputs_list, adj_list, global_W = None):
        
        print("inputs_list: ", inputs_list[0].size(), "Length ", len(inputs_list))
        print("adj List: ",len(adj_list) )
        print("self.ntype: ", self.ntype)
        print("self.in_features_list: ",  self.in_features_list)
        print("self.out_features: ", self.out_features)
        print("self.weights: ", self.weights) 
        
        
        h = []
        for i in range(self.ntype):
            h.append( torch.spmm(inputs_list[i], self.weights[i]) )
        if global_W is not None:
            for i in range(self.ntype):
                h[i] = (torch.spmm(h[i], global_W))
        outputs = []
        for t1 in range(self.ntype):
            x_t1 = []
            for t2 in range(self.ntype):
                if len(adj_list[t1][t2]._values()) == 0:
                    zeros = torch.zeros(adj_list[t1][t2].shape[0], self.out_features, device=self.bias.device, dtype=self.weights[0].dtype)
                    x_t1.append( zeros )
                    continue
                if self.bias is not None:
                    x_t1.append( self.att_list[t1](h[t1], h[t2], adj_list[t1][t2]) + self.bias )
                else:
                    print("h_size: ", h.size(), adj.size())
                    x_t1.append( self.att_list[t1](h[t1], h[t2], adj_list[t1][t2]) )
            outputs.append(x_t1)
            
        return outputs

class Transformer_InfLevel(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Transformer_InfLevel, self).__init__()

        self.a1 = nn.Parameter(torch.zeros(size=(out_channels, 1)))
        self.a2 = nn.Parameter(torch.zeros(size=(out_channels, 1)))
        nn.init.xavier_normal_(self.a1.data, gain=1.414)
        nn.init.xavier_normal_(self.a2.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(0.2, )
        self.conv = TransformerConv(in_channels, out_channels)#.cuda()
        # self.conv.cuda()


    def forward(self, input1, input2, adj):

        h = input1
        g = input2
        N = h.size()[0]
        M = g.size()[0]
        
        print("h_size: ", h.size())
        print("g_size: ", g.size())
        print("adj: ", adj.size())
        

        e1 = torch.matmul(h, self.a1).repeat(1, 32)
        # e2 = torch.matmul(g, self.a2).repeat(1, N).t()
        
        
        
        #e1 = torch.matmul(h, self.a1)
        #e2 = torch.matmul(g, self.a2)
        
        print("e1_size: ", e1.size())
        #print("e2_size: ", e2.size())
        
        #e = e1 + e2    
        #e = torch.matmul(e, g)
        e = self.leakyrelu(e1)
        a = adj.to_dense().nonzero().t().contiguous().long()
        
        
        
        #a = a.cuda()
        #e = e.cuda()
        print(a.is_cuda, e.is_cuda)     
        print("e_size", e.size(), "a_size", a.size())
        
        h_prime = self.conv(e, a)

            

        # print("e1_size: ", e1.size())
        # print("e2_size: ", e2.size())
        #print("e_size: ", e.size())
        #print("self.a1_size" , self.a1.size())
        #print("self.a2_size", self.a2.size())
        #print("adj: ", adj.size())     
        print("h_prime: ", h_prime.size())
        #print("self.in_channels ", self.in_channels)
        #print("self.out_channels ", self.out_channels)
            
        return h_prime


class GraphAttentionConvolution(Module):
    def __init__(self, in_features_list, out_features, bias=True, gamma = 0.1):
        super(GraphAttentionConvolution, self).__init__()
        self.ntype = len(in_features_list)
        self.in_features_list = in_features_list
        self.out_features = out_features
        self.weights: nn.ParameterList = nn.ParameterList()
        
        
        for i in range(self.ntype):
            cache = Parameter(torch.FloatTensor(in_features_list[i], out_features))
            nn.init.xavier_normal_(cache.data, gain=1.414)
            self.weights.append( cache )
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
            stdv = 1. / math.sqrt(out_features)
            self.bias.data.uniform_(-stdv, stdv)
        else:
            self.register_parameter('bias', None)
        
        self.att_list: nn.ModuleList = nn.ModuleList()
        for i in range(self.ntype):
            self.att_list.append( Attention_InfLevel(out_features, gamma) )


    def forward(self, inputs_list, adj_list, global_W = None):
        global printflag1
               
        # print("inputs_list: ", inputs_list[0].size(), "Length ", len(inputs_list))
        # print("adj List: ",len(adj_list) )
        # print("self.ntype: ", self.ntype)
        # print("self.in_features_list: ",  self.in_features_list)
        # print("self.out_features: ", self.out_features)
        # print("self.weights: ", self.weights) 
        
        h = []
        for i in range(self.ntype):
            h.append( torch.spmm(inputs_list[i], self.weights[i]) )
        if global_W is not None:
            for i in range(self.ntype):
                h[i] = ( torch.spmm(h[i], global_W) )
        outputs = []
        for t1 in range(self.ntype):
            x_t1 = []
            for t2 in range(self.ntype):
                if len(adj_list[t1][t2]._values()) == 0:
                    zeros = torch.zeros(adj_list[t1][t2].shape[0], self.out_features, device=self.bias.device, dtype=self.weights[0].dtype)
                    x_t1.append( zeros )
                    continue
                if self.bias is not None:
                    x_t1.append( self.att_list[t1](h[t1], h[t2], adj_list[t1][t2]) + self.bias )
                else:
                    x_t1.append( self.att_list[t1](h[t1], h[t2], adj_list[t1][t2]) )
       
            outputs.append(x_t1)
        return outputs

class Attention_InfLevel(nn.Module):
    def __init__(self, dim_features, gamma):
        super(Attention_InfLevel, self).__init__()

        self.dim_features = dim_features
        self.a1 = nn.Parameter(torch.zeros(size=(dim_features, 1)))
        self.a2 = nn.Parameter(torch.zeros(size=(dim_features, 1)))
        nn.init.xavier_normal_(self.a1.data, gain=1.414)
        nn.init.xavier_normal_(self.a2.data, gain=1.414)        

        self.leakyrelu = nn.LeakyReLU(0.2, )
        self.gamma = gamma

    
    def forward(self, input1, input2, adj):
        global printflag
        # adj = adj.coalesce()
        h = input1
        g = input2
        N = h.size()[0]
        M = g.size()[0]
        
       

        e1 = torch.matmul(h, self.a1).repeat(1, M)
        e2 = torch.matmul(g, self.a2).repeat(1, N).t()
        e = e1 + e2  
        e = self.leakyrelu(e)
        a = adj.to_dense()
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(a > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = torch.mul(attention, a.sum(1).repeat(M, 1).t())
        attention = torch.add(attention * self.gamma, a * (1 - self.gamma))
        del zero_vec

        h_prime = torch.matmul(attention, g)

        #print("h_size: ", h.size())
        #print("g_size: ", g.size())
        #print("e1_size:", e1.size())
        #print("e2_size: ", e2.size())
        #print("e_size: ", e.size())
        #print("self.a1_size" , self.a1.size())
        #print("self.a2_size", self.a2.size())
        #print("adj: ", adj.size())     
        #print("h_prime: ", h_prime.size())
        #print("self.dim_features ", self.dim_features)
    
        
        return h_prime



from transformers import BertModel

class Bert_Model(nn.Module):
    def __init__(self, hidden_dimension, embedding_dimension):
        super(Bert_Model, self).__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.out = nn.Linear(hidden_dimension, embedding_dimension)
    def forward(self, input):
        _, output = self.bert(**input)
        out = self.out(output)
        return out

    
class Transformer_Encoder(nn.Module):
    def __init__(self, hidden_dimension, embedding_dimension):
        super(Transformer_Encoder, self).__init__()
        
        encoder_layers = nn.TransformerEncoderLayer(embedding_dimension, 2)
        # print("hidden_dimension: ", hidden_dimension, ", embedding_dimension: ", embedding_dimension, ", encoder_layers: ", encoder_layers)
        self.encoder = nn.TransformerEncoder(encoder_layers, 4)
        
    def forward(self, input, x):
        out = self.encoder(input)
        # print("out: ", out.size())
        out = out.mean(1)
        return out
    
    




class LstmEncoder(Module):
    def __init__(self, hidden_dimension, embedding_dimension):
        super(LstmEncoder, self).__init__()
        self.hidden_dim = hidden_dimension
        self.lstm = nn.LSTM(embedding_dimension, hidden_dimension, batch_first=True)

    def forward(self, embeds, seq_lens):
        print("embeds: ", embeds.size(), "seq_lens", seq_lens)
        _, idx_sort = torch.sort(seq_lens, dim=0, descending=True)
        _, idx_unsort = torch.sort(idx_sort, dim=0)
        lens = list(seq_lens[idx_sort])
        selected_dim = 0
        x = embeds.index_select(selected_dim, idx_sort)
        rnn_input = nn.utils.rnn.pack_padded_sequence(x, lens, batch_first=True)
        rnn_output, (ht, ct) = self.lstm(rnn_input)
        ht = ht[-1].index_select(selected_dim, idx_unsort)
        print(rnn_input, ht.size())
        return ht  # bs * hidden_dim

class AttentionPooling(nn.Module):
    def __init__(self, params):
        super(AttentionPooling, self).__init__()
        self.params = params
        hidden_dimension = self.params.node_emb_dim // 2
        self.w = nn.Linear(self.params.node_emb_dim, hidden_dimension)
        self.a = nn.Linear(hidden_dimension, 1)
        self.leakyrelu = nn.LeakyReLU(0.2)

    def forward(self, X, dim=0, keepdim=True):
        '''
        :param X:           A tensor with shape:  D * H
        :return:            A tensor with shape:  1 * H (dim = 0)
        '''
        a = self.w(X)
        a = self.leakyrelu(a)
        a = self.a(a)         # D * 1
        a = torch.softmax(a, dim=dim)
        return torch.matmul(a.t(), X)



In [45]:
## model

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
#from models.layer import *
from torch.nn.parameter import Parameter
# from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from functools import reduce
import pickle as pkl
from torch.nn.modules.module import Module

class HGAT(nn.Module):
    def __init__(self, params):
        super(HGAT, self).__init__()
        self.para_init()
        self.attention = True
        self.lower_attention = True
        
        
        # self.nonlinear = nn.LeakyReLU(0.2)
        self.nonlinear = F.relu_
        nfeat_list = [params.hidden_dim] * ({0: 1, 1: 2, 2: 2, 3: 3}[params.node_type])
        
        self.ntype = len(nfeat_list)
        nhid = params.node_emb_dim
        # print("nfeat_list:", nfeat_list, " nhid: " , nhid)   

        self.gc2: nn.ModuleList = nn.ModuleList()
        if not self.lower_attention:
            # print("Inside If Not lower_attention------------1")
            self.gc1: nn.ModuleList = nn.ModuleList()
            for t in range(self.ntype):
                self.gc1.append( GraphConvolution(nfeat_list[t], nhid, bias=False) )
                self.bias1 = Parameter( torch.FloatTensor(nhid) )
                stdv = 1. / math.sqrt(nhid)
                self.bias1.data.uniform_(-stdv, stdv)
        else:
            # print("Executing GraphAttentionConvolution------------1", nfeat_list, nhid)
            self.gc1 = GraphAttentionConvolution(nfeat_list, nhid, gamma=0.1)
            # self.gc1 = TransforerBlock(nfeat_list, nhid, gamma=0.1)  
            
        if self.attention:
            
            self.at1: nn.ModuleList = nn.ModuleList()
            for t in range(self.ntype):
                # print("Executing GraphAttention------------1", " t: " , t, " nhid: " , nhid )
                self.at1.append( SelfAttention(nhid, t, nhid // 2) )
        self.dropout = nn.Dropout(params.dropout)

    def para_init(self):
        print("Para Init Invoked")
        self.attention = False
        self.lower_attention = False

    def forward(self, x_list, adj_list, adj_all = None):
        # print("ntype: ", self.ntype)
        
        x0 = x_list

        if not self.lower_attention:
            # print("Inside If Not lower_attention------------2")
            x1 = [None for _ in range(self.ntype)]

            for t1 in range(self.ntype):
                x_t1 = []
                for t2 in range(self.ntype):
                    idx = t2
                    print("t1: ", t1, " t2: ", idx)
                    
                    x_t1.append( self.gc1[idx](x0[t2], adj_list[t1][t2]) + self.bias1 )
                if self.attention:
                    x_t1, weights = self.at1[t1]( torch.stack(x_t1, dim=1) )
                else:
                    x_t1 = reduce(torch.add, x_t1)
                x_t1 = self.dropout(self.nonlinear(x_t1))
                x1[t1] = x_t1
        else:
            # print("Executing GraphAttentionConvolution------------2")
            x1 = [None for _ in range(self.ntype)]
            x1_in = self.gc1(x0, adj_list)
            for t1 in range(len(x1_in)):
                x_t1 = x1_in[t1]
                if self.attention:
                    # print("Executing GraphAttention------------2")
                    x_t1, weights = self.at1[t1]( torch.stack(x_t1, dim=1) )
                else:
                    x_t1 = reduce(torch.add, x_t1)
                x_t1 = self.dropout(self.nonlinear(x_t1))
                x1[t1] = x_t1

        return x1


    def inference(self, x_list, adj_list, adj_all = None):
        return self.forward(x_list, adj_list, adj_all)


class TextEncoder(Module):
    def __init__(self, params):
        super(TextEncoder, self).__init__()
        
        if(params.encoder == 0):
            self.lstm = LstmEncoder(params.hidden_dim, params.emb_dim)
        if(params.encoder == 1):
            self.lstm = Transformer_Encoder(params.hidden_dim, params.emb_dim)

    def forward(self, embeds, seq_lens):
        return self.lstm(embeds, seq_lens)

class EntityEncoder(Module):
    def __init__(self, params):
        super(EntityEncoder, self).__init__()
        
        if(params.encoder == 0):
            self.lstm = LstmEncoder(params.hidden_dim, params.emb_dim)
        if(params.encoder == 1):
            self.lstm = Transformer_Encoder(params.hidden_dim, params.emb_dim)

        self.gating = GatingMechanism(params)

    def forward(self, embeds, seq_lens, Y):
        X = self.lstm(embeds, seq_lens)
        return self.gating(X, Y)

class Pooling(nn.Module):
    def __init__(self, params):
        super(Pooling, self).__init__()
        self.mode = params.pooling
        self.params = params
        if self.mode == 'max':
            self.pooling = torch.max
        elif self.mode == 'sum':
            self.pooling = torch.sum
        elif self.mode == 'mean':
            self.pooling = torch.mean
        elif self.mode == 'att':
            self.pooling = AttentionPooling(self.params)
        else:
            raise Exception("Unknown pooling mode: {}. (Supported: max, sum, mean, att)".format(self.mode))

    def forward(self, X, sentPerDoc):
        '''
        :param X:           A tensor with shape:  (D1 + D2 + ... + Dn) * H
        :param sentPerDoc:  A tensor with values: [D1, D2, ..., Dn]
        :return:            A tensor with shape:  n * H
        '''
        # weight = [torch.ones((1, i.item()), device=sentPerDoc.device) for i in sentPerDoc]
        # weight = block_diag([m.to_sparse() for m in weight]).to_dense()
        sentPerDoc = sentPerDoc.cpu().numpy().tolist()
        sents = [X[sum(sentPerDoc[: i]): sum(sentPerDoc[: i+1])] for i in range(len(sentPerDoc))]
        output = []
        for s in sents:
            if s.shape[0] == 0:
                output.append(torch.zeros((1, s.shape[1]), device=s.device, dtype=X.dtype))
            else:
                cache = self.pooling(s, dim=0, keepdim=True)
                output.append(cache[0] if isinstance(cache, tuple) else cache)
        output = torch.cat(output, dim=0)
        return output

class ConcatTransform(nn.Module):
    def __init__(self, params):
        super(ConcatTransform, self).__init__()
        self.params = params
        self.preW = nn.Linear(self.params.hidden_dim, self.params.node_emb_dim)
        self.postW = nn.Linear(self.params.node_emb_dim * 2, self.params.node_emb_dim)
        self.dropout = nn.Dropout(self.params.dropout, )

    def forward(self, X: torch.FloatTensor, Y: torch.FloatTensor):
        '''
        :param X:   shape: (N, node_emb_dim)
        :param Y:   shape: (N, hidden_dim)
        :return:    shape: (N, node_emb_dim)
        '''
        Y = self.preW(self.dropout(Y))                # (N, node_emb_dim)
        concatVector = torch.cat([X, Y], dim=1)            # (N, 2 * node_emb_dim)
        concatVector = self.postW(self.dropout(concatVector))
        return concatVector   # (N, node_emb_dim)

class MatchingTransform(nn.Module):
    def __init__(self, params):
        super(MatchingTransform, self).__init__()
        self.params = params
        self.SIMPLE = True
        self.preW = nn.Linear(self.params.hidden_dim, self.params.node_emb_dim)
        self.postW = nn.Linear(self.params.node_emb_dim * (2 if self.SIMPLE else 4), self.params.node_emb_dim)
        # self.nonlinear = nn.LeakyReLU(0.2)
        self.dropout = nn.Dropout(self.params.dropout, )

    def forward(self, X: torch.FloatTensor, Y: torch.FloatTensor):
        '''
        :param X:   shape: (N, node_emb_dim)
        :param Y:   shape: (N, hidden_dim)
        :return:    shape: (N, node_emb_dim)
        '''
        Y = self.preW(self.dropout(Y))                # (N, node_emb_dim)
        # if self.SIMPLE:     matchingVector = torch.cat([X - Y, X.mul(Y)], dim=1)            # (N, 2 * node_emb_dim)
        if self.SIMPLE:     matchingVector = torch.cat([X - Y, X.mul(Y)], dim=1)            # (N, 2 * node_emb_dim)
        else:               matchingVector = torch.cat([X, Y, X - Y, X.mul(Y)], dim=1)      # (N, 4 * node_emb_dim)
        matchingVector = self.postW(self.dropout(matchingVector))
        return matchingVector   # (N, node_emb_dim)

class GatingMechanism(nn.Module):
    def __init__(self, params):
        super(GatingMechanism, self).__init__()
        self.params = params
        with open(self.params.entity_tran, 'rb') as f:
            transE_embedding = pkl.load(f)
        self.enti_tran = nn.Embedding.from_pretrained(torch.from_numpy(transE_embedding).float())
        entity_num = transE_embedding.shape[0]


        self.gate_theta = Parameter(torch.empty(entity_num, self.params.hidden_dim))
        nn.init.xavier_uniform_(self.gate_theta)

        # self.dropout = nn.Dropout(self.params.dropout)

    def forward(self, X: torch.FloatTensor, Y: torch.LongTensor):

        gate = torch.sigmoid(self.gate_theta[Y])
        Y = self.enti_tran(Y)
        output = torch.mul(gate, X) + torch.mul(-gate + 1, Y)
        return output


## Commented below by sumeet

# if __name__ == '__main__':
#     from main import parse_arguments
#     GatingMechanism(parse_arguments())
    

In [46]:
## Classifier

#!/user/bin/env python
# -*- coding: utf-8 -*-
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from functools import reduce
# from models.model import HGAT, TextEncoder, EntityEncoder, Pooling, MatchingTransform, GatingMechanism
import pickle as pkl
from torch.nn.modules.module import Module

class Classifier(nn.Module):
    def __init__(self, params, vocab_size, pte=None):
        super(Classifier, self).__init__()
        self.params = params
        self.vocab_size = vocab_size
        self.pte = False if pte is None else True

        self.text_encoder = TextEncoder(params)
        self.enti_encoder = EntityEncoder(params)
        # numOfEntity = 100000
        # self.enti_encoder = nn.Embedding(numOfEntity, params.hidden_dim)
        # nn.init.xavier_uniform_(self.enti_encoder.weight)
        self.topi_encoder = nn.Embedding(100, 100)
        self.topi_encoder.from_pretrained(torch.eye(100))
        self.match_encoder = MatchingTransform(params)
        # self.match_encoder = ConcatTransform(params)   # 参数试验用的
        self.word_embeddings = nn.Embedding(vocab_size, params.emb_dim)
        if pte is None:
            nn.init.xavier_uniform_(self.word_embeddings.weight)
        else:
            self.word_embeddings.weight.data.copy_(torch.from_numpy(pte))
        # KB Field

        # with open(self.params.entity_tran, 'rb') as f:
        #     transE_embedding = pkl.load(f)
        # self.enti_tran = nn.Embedding.from_pretrained(torch.from_numpy(transE_embedding))

        self.model = HGAT(params)
        self.pooling = Pooling(params)
        self.classifier_sen = nn.Linear(params.node_emb_dim, params.ntags)
        self.classifier_ent = nn.Linear(params.node_emb_dim, params.ntags)

        self.dropout = nn.Dropout(params.dropout, )

        # entity_num = transE_embedding.shape[0]
        # self.gating = GatingMechanism(params) # 这个要放在最后面，尽量少影响随机初始化

    # def forward(self, x_list, adj_list, sentPerDoc, entPerDoc=None):
    def forward(self, documents, ent_desc, doc_lens, ent_lens, adj_lists, feature_lists, sentPerDoc, entiPerDoc=None):
        x_list = []
        embeds_docu = self.word_embeddings(documents)   # sents * max_seq_len * emb
        d = self.text_encoder(embeds_docu, doc_lens)    # sents * max_seq_len * hidden
        d = self.dropout(F.relu_(d))                     # Relu activation and dropout
        x_list.append(d)
        if self.params.node_type == 3 or self.params.node_type == 2:
            embeds_enti = self.word_embeddings(ent_desc)    # sents * max_seq_len * emb
            e = self.enti_encoder(embeds_enti, ent_lens, feature_lists[1])    # sents * max_seq_len * hidden
            e = self.dropout(F.relu_(e))                     # Relu activation and dropout
            x_list.append(e)
        if self.params.node_type == 3 or self.params.node_type == 1:
            t = self.topi_encoder(feature_lists[-1])         # tops * hidden
            x_list.append(t)

        X = self.model(x_list, adj_lists)

        X_s = self.pooling(X[0], sentPerDoc)   # 选择句子的部分
        output = self.classifier_sen(X_s)

        if entiPerDoc is not None:
            # E_trans = self.enti_tran(feature_lists[1])
            E_GCN = X[1]
            # E_KB = self.gating(x_list[1], feature_lists[1])
            E_KB = x_list[1]
            X_e = self.match_encoder(E_GCN, E_KB)  # 选择实体的部分
            X_e = self.pooling(X_e, entiPerDoc)
            X_e = self.classifier_ent(X_e)
            output += X_e
        output = F.softmax(output, dim=1)       # 单分类
        # output = torch.sigmoid(output)        # 多分类
        return output


if __name__ == '__main__':
    pass

In [47]:
## Evaluator

import torch, json
# from models import Classifier
from tqdm import tqdm
import numpy as np
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
# from util import Utils

class Evaluator:
    def __init__(self, params, utils, data_loader):
        self.params = params
        self.utils = utils
        self.data_loader = data_loader

    def get_sentences_from_indices(self, docs):
        actual_sentences = []
        for doc, sent_lens in docs:
            sentences = []
            for i, sent in enumerate(doc):
                sentences.append(' '.join([self.data_loader.i2w[int(wid)] for wid in sent[:sent_lens[i]]]))
            actual_sentences.append(sentences)
        return actual_sentences

    def _evaluate_aux(self, model, data_loader):
        hits = 0
        total = 0
        all_actual = None
        all_predicted = None
        for inputs in tqdm(data_loader):
            with torch.no_grad():
                try:
                    documents, ent_desc, doc_lens, ent_lens, y_batch, adj_lists, feature_lists, sentPerDoc, entiPerDoc = \
                        [self.utils.to_gpu(i, self.params.cuda and torch.cuda.is_available()) for i in inputs]
                    total += sentPerDoc.shape[0]
                    logits = model(documents, ent_desc, doc_lens, ent_lens, adj_lists, feature_lists, sentPerDoc, entiPerDoc)
                    predicted = torch.argmax(logits, dim=1)
                    hits += torch.sum(predicted == y_batch).item()
                    all_predicted = predicted.cpu().data.numpy() if all_predicted is None \
                        else np.concatenate((all_predicted, predicted.cpu().data.numpy()))
                    labels = y_batch.cpu().numpy()
                    all_actual = labels if all_actual is None else np.concatenate((all_actual, labels))
                except RuntimeError as e:
                    if 'out of memory' in str(e).lower():
                        # outOfMemory += 1
                        continue
                    else:
                        print(e)
                        exit()
                except Exception as e:
                    print(e)
                    exit()
        accuracy = hits / total
        return accuracy, all_actual, all_predicted

    def evaluate(self):
        print(json.dumps(vars(self.params), indent=2))

        model: torch.nn.Module = Classifier(self.params, vocab_size=len(self.data_loader.w2i), pte=None)
        if self.utils.HALF:
            model.half()
        if torch.cuda.is_available():
            model = model.cuda()
        # Load the model weights
        
        current_model_dict = model.state_dict()
        loaded_state_dict = torch.load("ckpt/" + params.model_file, map_location=lambda storage, loc: storage)
        new_state_dict={k:v if v.size()==current_model_dict[k].size()  else  current_model_dict[k] for k,v in zip(current_model_dict.keys(), loaded_state_dict.values())}
        model.load_state_dict(new_state_dict, strict=False)


        #model.load_state_dict(torch.load("ckpt/" + self.params.model_file, map_location=lambda storage, loc: storage)) 

        model.eval()

        # This dataset is only available for the binary classifier
        if self.params.ntags == 2:
            accuracy, all_actual, all_predicted = self._evaluate_aux(model, self.data_loader.test_data_loader)
            prec_mac, recall_mac, f1_mac, _ = precision_recall_fscore_support(all_actual, all_predicted, average='macro')
            prec_mic, recall_mic, f1_mic, _ = precision_recall_fscore_support(all_actual, all_predicted, average='micro')
            print("Accuracy on the OOD test set 1: {:.4f}".format(accuracy))
            print("Precision on the OOD test set 1 macro / micro: {:.4f}, {:.4f}".format(prec_mac, prec_mic))
            print("Recall on the OOD test set 1 macro / micro: {:.4f}, {:.4f}".format(recall_mac, recall_mic))
            print("F1 on the OOD test set 1 macro / micro: {:.4f}, {:.4f}".format(f1_mac, f1_mic))
            print("Latex: {:5.2f} & {:5.2f} & {:5.2f} & {:5.2f}".format(accuracy*100, prec_mac*100, recall_mac*100, f1_mac*100))
            print("----------------------------------------------------------------------")

        accuracy, all_actual, all_predicted = self._evaluate_aux(model, self.data_loader.test_data_loader_2)
        prec_mac, recall_mac, f1_mac, _ = precision_recall_fscore_support(all_actual, all_predicted, average='macro')
        prec_mic, recall_mic, f1_mic, _ = precision_recall_fscore_support(all_actual, all_predicted, average='micro')
        print("Accuracy on the OOD test set 2: {:.4f}".format(accuracy))
        print("Precision on the OOD test set 2 macro / micro: {:.4f}, {:.4f}".format(prec_mac, prec_mic))
        print("Recall on the OOD test set 2 macro / micro: {:.4f}, {:.4f}".format(recall_mac, recall_mic))
        print("F1 on the OOD test set 2 macro / micro: {:.4f}, {:.4f}".format(f1_mac, f1_mic))
        print("Latex: {:5.2f} & {:5.2f} & {:5.2f} & {:5.2f}".format(accuracy * 100, prec_mac * 100, recall_mac * 100, f1_mac * 100))

        #ascii if self.params.ntags == 4:
        #     results = confusion_matrix(all_actual, all_predicted)
        #     df_cm = pd.DataFrame(results, index=[i for i in ["Satire", "Hoax", "Propaganda", "Trusted"]],
        #                          columns=[i for i in ["Satire", "Hoax", "Propaganda", "Trusted"]])
        #     sns_plot = sn.heatmap(df_cm, annot=True, fmt='g')
        #     plt.yticks(rotation=45)
        #     sns_plot.get_figure().savefig('plots/cm.png')
        # 
        # print("----------------------------------------------------------------------")
        # accuracy, all_actual, all_predicted = self._evaluate_aux(model, self.data_loader.dev_data_loader)
        # prec_mac, recall_mac, f1_mac, _ = precision_recall_fscore_support(all_actual, all_predicted, average='macro')
        # prec_mic, recall_mic, f1_mic, _ = precision_recall_fscore_support(all_actual, all_predicted, average='micro')
        # print("Accuracy on the dev set: {:.4f}".format(accuracy))
        # print("Precision on the dev set macro / micro: {:.4f}, {:.4f}".format(prec_mac, prec_mic))
        # print("Recall on the dev macro / micro: {:.4f}, {:.4f}".format(recall_mac, recall_mic))
        # print("F1 on the dev macro / micro: {:.4f}, {:.4f}".format(f1_mac, f1_mic))
        # print("Latex: {:5.2f} & {:5.2f} & {:5.2f} & {:5.2f}".format(accuracy * 100, prec_mac * 100, recall_mac * 100, f1_mac * 100))




In [48]:
## Util

from timeit import default_timer as timer
import numpy as np
from tqdm import tqdm
# from models import Classifier
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib

matplotlib.use('Agg')
import matplotlib.pyplot as plt

class Utils:
    def __init__(self, params, dl):
        self.params = params
        self.data_loader = dl
        self.HALF = params.HALF

    @staticmethod
    def to_half(arr):
        if arr is None:
            return arr
        if isinstance(arr, list) or isinstance(arr, tuple):
            return [Utils.to_half(a) for a in arr]
        elif isinstance(arr, torch.FloatTensor) or isinstance(arr, torch.sparse.FloatTensor):
            return arr.half()
        else:
            return arr

    def to_gpu(self, arr, cuda):
        if self.params.HALF:
            arr = Utils.to_half(arr)

        if not cuda or arr is None:
            return arr
        if isinstance(arr, list) or isinstance(arr, tuple):
            return [self.to_gpu(a, cuda) for a in arr]
        else:
            try:
                return arr.cuda()
            except:
                return arr
        # elif isinstance(arr[0], int) or isinstance(arr[0], float):
        #     return arr
        # else:
        #     raise TypeError("Unknown type of input of 'utils.py/Utils/to_gpu': {}.".format(type(arr)))

    def get_dev_loss_and_acc(self, model, loss_fn):
        losses = []; hits = 0; total = 0; outOfMemoryCnt = 0
        model.eval()
        for inputs in tqdm(self.data_loader.dev_data_loader):
            # torch.cuda.empty_cache()
            with torch.no_grad():
                try:
                    documents, ent_desc, doc_lens, ent_lens, y_batch, adj_lists, feature_lists, sentPerDoc, entiPerDoc = \
                        [self.to_gpu(i, self.params.cuda and torch.cuda.is_available()) for i in inputs]
                    logits = model(documents, ent_desc, doc_lens, ent_lens, adj_lists, feature_lists, sentPerDoc, entiPerDoc)
                    loss = loss_fn(logits, y_batch)
                    hits += torch.sum(torch.argmax(logits, dim=1) == y_batch).item()
                    total += sentPerDoc.shape[0]
                    losses.append(loss.item())
                except RuntimeError as e:
                    if 'out of memory' in str(e).lower():
                        # outOfMemory += 1
                        continue
                    else:
                        print(e)
                        exit()
                except Exception as e:
                    print(e)
                    exit()
        if outOfMemoryCnt > 0:
            print("outOfMemoryCnt when validating: ", outOfMemoryCnt)
        return np.asscalar(np.mean(losses)), hits / total

    def train(self, save_plots_as, pretrained_emb=None):
        model: nn.Module = Classifier(self.params, vocab_size=len(self.data_loader.w2i), pte=pretrained_emb)
        if self.params.HALF:
            model.half()
        loss_fn = torch.nn.CrossEntropyLoss()
        if self.params.cuda:
            model = model.cuda()
        # optimizer = optim.Adam(model.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay)
        optimizer = optim.Adam(model.parameters(), lr=self.params.lr, weight_decay=self.params.weight_decay, eps=1e-4)

        # Variables for plotting
        train_losses, dev_losses, train_accs, dev_accs = [], [], [], []
        s_t = timer()
        prev_best = 0
        patience = 0
        outOfMemory = 0
        # Start the training loop
        for epoch in range(1, self.params.max_epochs + 1):
            model.train()
            train_loss, hits, total = 0, 0, 0
            # train_data_loader = self.data_loader.train_data_loader \
            #                     if not self.params.DEBUG else self.data_loader.test2_data_loader
            for inputs in tqdm(self.data_loader.train_data_loader):
                # torch.cuda.empty_cache()
                try:
                    documents, ent_desc, doc_lens, ent_lens, y_batch, adj_lists, feature_lists, sentPerDoc, entiPerDoc = \
                                                        [self.to_gpu(i, self.params.cuda and torch.cuda.is_available()) for i in inputs]
                    total += sentPerDoc.shape[0]
                    logits = model(documents, ent_desc, doc_lens, ent_lens, adj_lists, feature_lists, sentPerDoc, entiPerDoc)
                    if torch.isnan(logits).any():
                        print('stop here')
                        # model(documents, ent_desc, doc_lens, ent_lens, adj_lists, feature_lists, sentPerDoc, entiPerDoc)
                    loss = loss_fn(logits, y_batch)
                    # Book keeping
                    train_loss += loss.item()
                    hits += torch.sum(torch.argmax(logits, dim=1) == y_batch).item()
                    # Back-prop
                    optimizer.zero_grad()  # Reset the gradients
                    loss.backward()  # Back propagate the gradients
                    optimizer.step()  # Update the network

                except RuntimeError as e:
                    if 'out of memory' in str(e).lower():
                        outOfMemory += 1
                        continue
                    else:
                        print(e)
                        exit()
                except Exception as e:
                    print(e)
                    exit()
            print("Times of out of memory: ", outOfMemory)
            # Compute loss and acc for dev set
            dev_loss, dev_acc = self.get_dev_loss_and_acc(model, loss_fn)
            train_loss = train_loss / len(self.data_loader.train_data_loader)
            train_losses.append(train_loss)
            dev_losses.append(dev_loss)
            train_accs.append(hits / total)
            dev_accs.append(dev_acc)
            tqdm.write("Epoch: {}, Train loss: {:.4f}, Train acc: {:.4f}, Dev loss: {:.4f}, Dev acc: {:.4f}".format(
                        epoch, train_loss, hits / total, dev_loss, dev_acc))
            if dev_acc < prev_best:
                patience += 1
                if patience == 3:
                    # Learning rate annealing
                    optim_state = optimizer.state_dict()
                    optim_state['param_groups'][0]['lr'] = optim_state['param_groups'][0]['lr'] / 2
                    optimizer.load_state_dict(optim_state)
                    tqdm.write('Dev accuracy did not increase, reducing the learning rate by 2!!!')
                    patience = 0
            else:
                prev_best = dev_acc
                # Save the model
                torch.save(model.state_dict(), "ckpt/model_{}.t7".format(save_plots_as))

        # Acc vs time plot
        fig = plt.figure()
        plt.plot(range(1, self.params.max_epochs + 1), train_accs, color='b', label='train')
        plt.plot(range(1, self.params.max_epochs + 1), dev_accs, color='r', label='dev')
        plt.ylabel('accuracy')
        plt.xlabel('epochs')
        plt.legend()
        plt.xticks(np.arange(1, self.params.max_epochs + 1, step=4))
        fig.savefig('result/' + '{}_accuracy.png'.format(save_plots_as))

        return timer() - s_t


In [51]:
## Trainer and print_log
import sys
import os

class Trainer:
    def __init__(self, params, utils):
        self.params = params
        self.utils = utils
        self.log_time = {}

    def train(self):
        print('-----------{}-------------'.format(self.params.config))
        training_time = self.utils.train(save_plots_as=self.params.config)
        self.log_time[self.params.config] = training_time
        print('-----------------------------------------')


class Logger(object):
    def __init__(self, filename="Default.log"):
        self.terminal = sys.stdout
        if os.path.exists(filename):
            os.remove(filename)
        self.log = open(filename, "a")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        pass

    def change_file(self, filename="Default.log"):
        self.log.close()
        self.log = open(filename, "a")


if __name__ == '__main__':
    sys.stdout = Logger("yourlogfilename2.txt")
    print('content.')        
        
        
class Trainer:
    def __init__(self, params, utils):
        self.params = params
        self.utils = utils
        self.log_time = {}

    def train(self):
        print('-----------{}-------------'.format(self.params.config))
        training_time = self.utils.train(save_plots_as=self.params.config)
        self.log_time[self.params.config] = training_time
        print('----------------------------------------')

In [52]:
##Data Loader

import os
import csv
import pickle as pkl
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import torch.utils.data
from sklearn.model_selection import train_test_split
from nltk import sent_tokenize, word_tokenize
from multiprocessing import Pool as ProcessPool

ASYMMETRIC = True
DEBUG_NUM = 400
W2I = None


def sentence_tokenize(doc):
    # return doc.split('.')
    return sent_tokenize(doc)

def read_and_unpkl(file):
    with open(file, 'rb') as f:
        return pkl.load(f)

def parseLine(args):
    idx, tag, doc = args
    global W2I
    # sentences = doc.split('.')
    sentences = sentence_tokenize(doc)
    sentences_idx = []
    for sentence in sentences:
        sentence = sentence.lower().strip().split(" ")
        curr_sentence_idx = [W2I[x] for x in sentence]
        sentences_idx.append(curr_sentence_idx if len(curr_sentence_idx) > 0 else [W2I['<unk>']])
    return int(tag), sentences_idx

class DataLoader:
    def __init__(self, params):
        self.params = params
        self.ntags = params.ntags

        train_pkl_path = '{}/train/'.format(params.adjs)
        test_pkl_path = '{}/test/'.format(params.adjs)
        dev_pkl_path = '{}/dev/'.format(params.adjs)
        print('Loading adj: ', train_pkl_path[: -6])

        w2i_pkl_path = params.root + 'w2i.pkl'
        if params.mode == 0:
            w2i = freezable_defaultdict(lambda: len(w2i))
            UNK = w2i["<unk>"]

            self.train, self.adj_train, self.fea_train = self.read_dataset(params.train, w2i, train_pkl_path)
            print("Average train document length: {}".format(np.mean([len(x[0]) for x in self.train])))
            print("Maximum train document length: {}".format(max([len(x[0]) for x in self.train])))

            self.train, self.dev, self.adj_train, self.adj_dev, self.fea_train, self.fea_dev = \
                train_test_split(self.train, self.adj_train, self.fea_train, test_size=0.2, random_state=42)
        else:
            w2i = freezable_defaultdict(lambda: len(w2i))
            UNK = w2i["<unk>"]

            self.train, self.adj_train, self.fea_train = self.read_dataset(params.train, w2i, train_pkl_path)
            print("Average train document length: {}".format(np.mean([len(x[0]) for x in self.train])))
            print("Maximum train document length: {}".format(max([len(x[0]) for x in self.train])))

        w2i = freezable_defaultdict(lambda: UNK, w2i)
        w2i.freeze()
        self.w2i = w2i
        self.i2w = dict(map(reversed, self.w2i.items()))
        self.nwords = len(w2i)


        with open(params.entity_desc, 'rb') as f:
            corpus = pkl.load(f)
        self.entity_description = []
        for row in corpus:
            self.entity_description.append([w2i[x] for x in row.lower().split(" ")])

        if params.mode == 0:
            dataset_train = DataSet(self.train, self.adj_train, self.fea_train, self.params, self.entity_description)
            self.train_data_loader = torch.utils.data.DataLoader(dataset_train,
                                    batch_size=params.batch_size, collate_fn=dataset_train.collate, shuffle=True)
            dataset_dev = DataSet(self.dev, self.adj_dev, self.fea_dev, self.params, self.entity_description)
            self.dev_data_loader = torch.utils.data.DataLoader(dataset_dev,
                                    batch_size=params.batch_size, collate_fn=dataset_dev.collate,   shuffle=False)


        self.test, self.adj_test, self.fea_test = self.read_dataset(params.test, w2i, test_pkl_path)
        self.test_2, self.adj_test_2, self.fea_test_2 = self.read_dataset(params.dev, w2i, dev_pkl_path)

        dataset_test = DataSet(self.test, self.adj_test, self.fea_test, self.params, self.entity_description)
        self.test_data_loader = torch.utils.data.DataLoader(dataset_test,
                                batch_size=params.batch_size, collate_fn=dataset_test.collate,  shuffle=False)
        dataset_test_2 = DataSet(self.test_2, self.adj_test_2, self.fea_test_2, self.params, self.entity_description)
        self.test_data_loader_2 = torch.utils.data.DataLoader(dataset_test_2,
                                batch_size=params.batch_size, collate_fn=dataset_test_2.collate,shuffle=False)


    def load_adj_and_other(self, path):
        print("Loading {}".format(path))
        if path[-1] == '/':
            files = sorted([path + f for f in os.listdir(path) if judge_data(f)],
                                key=lambda x: int(x.split('/')[-1].split('.')[0]))  # 用idx.pkl中的idx排序
            files = files[: DEBUG_NUM] if self.params.DEBUG else files
            data = [read_and_unpkl(file) for file in tqdm(files)]
        else:
            with open(path, 'rb') as f:
                data = pkl.load(f)
        print("Preprocessing {}".format(path))
        res, device = [], 'cuda' if self.params.cuda else 'cpu'
        for piece in tqdm(data):
            d_idx = piece['idx']
            adj_list = [build_spr_coo(a) for a in piece['adj_list']]
            feature_list = [piece['s2i'], piece['e2i'], piece['t2i']]
            res.append([adj_list, feature_list])
        return res

    def read_dataset(self, filename, w2i, adj_file):
        adj = self.load_adj_and_other(adj_file)
        if 'csv' in filename:
            return self.read_dataset_sentence_wise(filename, w2i, adj)
        if 'xlsx' in filename:
            return self.read_testset_sentence_wise(filename, w2i, adj)

    def read_dataset_sentence_wise(self, filename, w2i, adj):
        data, new_adj, new_fea, removed_idx = [], [], [], []
        global W2I
        W2I = w2i
        # count = 0
        adj, fea = zip(*adj)
        with open(filename, "r") as f:
            readCSV = csv.reader(f, delimiter=',')
            csv.field_size_limit(100000000)
            sents = []
            for idx, (tag, doc) in tqdm(enumerate(readCSV)):
                if self.params.DEBUG and idx >= DEBUG_NUM:
                    break
                sents.append([idx, tag, doc])

            sentences_idx_list = []
            p = ProcessPool(10)
            with tqdm(total=len(sents)) as pbar:
                for out in p.imap(parseLine, sents):
                    sentences_idx_list.append(out)
                    pbar.update(1)
            p.close()
            p.join()

            print(len(sentences_idx_list))
            allowed_tags = [1, 4] if self.ntags == 2 else [1, 2, 3, 4]
            for idx, (tag, sentences_idx) in enumerate(sentences_idx_list):
                if tag in allowed_tags:
                    if self.ntags == 2:
                        tag = tag - 1 if tag == 1 else tag - 3   # Adjust the tag to {0: Satire, 1: Trusted}
                    else:
                        tag -= 1                                 # {0: Satire, 1: Hoax, 2: Propaganda, 3: Trusted}
                    if len(sentences_idx) > 1:
                        data.append((sentences_idx[:self.params.max_sents_in_a_doc], tag))
                        new_adj.append(adj[idx])
                        new_fea.append(fea[idx])
                    else:
                        removed_idx.append(idx)
        print('removed_idx of {}: {}'.format(filename, len(removed_idx)))
        print(len(data), len(new_adj))
        return data, new_adj, new_fea

    def read_dataset_sentence_wise(self, filename, w2i, adj):
        data, new_adj, new_fea = [], [], []
        # count = 0
        adj, fea = zip(*adj)
        with open(filename, "r") as f:
            readCSV = csv.reader(f, delimiter=',')
            csv.field_size_limit(100000000)
            removed_idx = []
            for idx, (tag, doc) in tqdm(enumerate(readCSV)):
                if self.params.DEBUG and idx >= DEBUG_NUM:
                    break
                # sentences = doc.split('.')
                sentences = sentence_tokenize(doc)
                tag = int(tag)
                allowed_tags = [1, 4] if self.ntags == 2 else [1, 2, 3, 4]
                if tag in allowed_tags:
                    if self.ntags == 2:
                        tag = tag - 1 if tag == 1 else tag - 3   # Adjust the tag to {0: Satire, 1: Trusted}
                    else:
                        tag -= 1                                 # {0: Satire, 1: Hoax, 2: Propaganda, 3: Trusted}
                    sentences_idx = []
                    for sentence in sentences:
                        sentence = sentence.lower().strip().split(" ")
                        curr_sentence_idx = [w2i[x] for x in sentence]
                        sentences_idx.append(curr_sentence_idx if len(curr_sentence_idx) > 0 else [w2i['<unk>']])

                    if len(sentences_idx) > 1 and len(sentences_idx) < 1000:
                        data.append((sentences_idx[:self.params.max_sents_in_a_doc], tag))
                        new_adj.append(adj[idx])
                        new_fea.append(fea[idx])
                    else:
                        removed_idx.append(idx)
        print('removed_idx of {}: {}'.format(filename, len(removed_idx)))
        return data, new_adj, new_fea

    def read_testset_sentence_wise(self, filename, w2i, adj):
        df = pd.read_excel(filename)
        data, new_adj, new_fea = [], [], []
        count = 0
        adj, fea = zip(*adj)
        removed_idx = []
        for idx, row in tqdm(enumerate(df.values)):
            if self.params.DEBUG and idx >= DEBUG_NUM:
                break
            # sentences = row[2].split('.')
            sentences = sentence_tokenize(row[2])
            tag = int(row[0])
            # Tag id is reversed in this dataset
            tag = tag + 1 if tag == 0 else tag - 1
            sentences_idx = []
            for sentence in sentences:
                sentence = sentence.lower().replace("\n", " ").strip().split(" ")
                curr_sentence_idx = [w2i[x] for x in sentence]
                sentences_idx.append(curr_sentence_idx if len(curr_sentence_idx) > 0 else [w2i['<unk>']])
            if len(sentences_idx) > 1:
                data.append((sentences_idx, tag))
                new_adj.append(adj[count])
                new_fea.append(fea[count])
            else:
                removed_idx.append(idx)
            count += 1

        print('removed_idx of {}: {}'.format(filename, removed_idx))
        return data, new_adj, new_fea

def judge_data(fileName):
    key = fileName.split('.')[0]
    try:
        x = int(key)
        return True
    except:
        return False

def build_spr_coo(spr, device='cpu'):
    # {'indices': spr.indices(), 'value': spr.values(), 'size': spr.size()}
    if not isinstance(spr, dict):
        raise TypeError("Not recognized type of sparse matrix source: {}".format(type(spr)))
    tensor = torch.sparse.FloatTensor(spr['indices'], spr['value'], spr['size']).coalesce()
    return tensor if device == 'cpu' else tensor.to(device)

class DataSet(torch.utils.data.TensorDataset):
    def __init__(self, data, adj, fea, params, entity_description):
        super(DataSet, self).__init__()
        self.params = params
        # data is a list of tuples (sent, label)
        self.sents = [x[0] for x in data]
        self.labels = [x[1] for x in data]
        self.adjs = adj
        self.feas = fea
        self.entity_description = entity_description
        self.num_of_samples = len(self.sents)
        for i, a in enumerate(self.adjs):
            assert a[0].shape[0] == len(self.sents[i]),\
                "dim of adj does not match the num of sent, where the idx is {}".format(i)
            assert a[4].shape[0] == len(self.feas[i][1]), \
                "dim of adj does not match the num of entity, where the idx is {}".format(i)
            assert a[7].shape[0] == len(self.feas[i][2]), \
                "dim of adj does not match the num of topic, where the idx is {}".format(i)

    def __len__(self):
        return self.num_of_samples

    def __getitem__(self, idx):
        return self.sents[idx], len(self.sents[idx]), self.labels[idx], self.adjs[idx], self.feas[idx]

    def collate(self, batch):
        sents, doc_lens_o, labels, adjs, feas = zip(*batch)
        # concatenate & padding
        doc_lens, curr_sents = [], []
        for doc in sents:
            doc_lens += [min(self.params.max_sent_len, len(x)) for x in doc]
            curr_sents += doc
        padded_sents = np.zeros((len(curr_sents), max(doc_lens)))
        for i, sen in enumerate(curr_sents):
            padded_sents[i, :len(sen)] = sen[:doc_lens[i]]
        documents = torch.from_numpy(padded_sents).long()

        new_feas, new_adjs = [], []
        fea_doc, fea_ent, fea_top = zip(*feas)
        for f in [fea_doc, fea_ent, fea_top]:
            fea = torch.from_numpy(np.array(sum([list(i.values()) for i in f], [])))
            new_feas.append(fea.long())
        for a in zip(*adjs):
            new_adjs.append(block_diag(a).float())

        labels = torch.from_numpy(np.array(labels)).long()
        sentPerDoc = torch.from_numpy(np.array([len(fea[0]) for fea in feas])).int()
        entiPerDoc = torch.from_numpy(np.array([len(fea[1]) for fea in feas])).int()
        topiPerDoc = torch.from_numpy(np.array([len(fea[2]) for fea in feas])).int()

        # concatenate & padding
        ent_lens, curr_sents = [], []
        for doc in fea_ent:
            doc = [self.entity_description[doc[idx]] for idx in range(len(doc))]
            ent_lens += [min(self.params.max_sent_len, len(x)) for x in doc]
            curr_sents += doc
        padded_sents = np.zeros((len(curr_sents), max(ent_lens)))
        for i, sen in enumerate(curr_sents):
            padded_sents[i, :len(sen)] = sen[:ent_lens[i]]
        ent_desc = torch.from_numpy(padded_sents).long()

        doc_lens = torch.from_numpy(np.array(doc_lens)).int()
        ent_lens = torch.from_numpy(np.array(ent_lens)).int()

        if self.params.node_type == 3:
            new_adjs = [new_adjs[0:3], new_adjs[3:6], new_adjs[6:9]]
            new_adjs[0][1].zero_()    # (√)text -> entity   (X)entity -> text
        elif self.params.node_type == 2:    # Document&Entiy
            new_adjs = [new_adjs[0:2], new_adjs[3:5]]
            new_feas = new_feas[0: 2]
            new_adjs[0][1].zero_()
        elif self.params.node_type == 1:    # Document&Topic
            new_adjs = [[new_adjs[0], new_adjs[2]], [new_adjs[6], new_adjs[8]]]
            new_feas = [new_feas[0], new_feas[2]]
            ent_desc, ent_lens, entiPerDoc = None, None, None
        elif self.params.node_type == 0:
            new_adjs = [[new_adjs[0]]]
            new_feas = [new_feas[0]]
            ent_desc, ent_lens, entiPerDoc = None, None, None
        else:
            raise Exception("Unknown node_type.")
        return documents, ent_desc, doc_lens, ent_lens, labels, new_adjs, new_feas, sentPerDoc, entiPerDoc


def block_diag(mat_list: list or tuple):
    shape_list = [m.shape for m in mat_list]
    bias = torch.LongTensor([0, 0])
    indices, values = [], []
    for m in mat_list:
        indices.append(m.indices() + bias.unsqueeze(1))
        values.append(m.values())
        bias += torch.LongTensor(list(m.shape))
    indices = torch.cat(indices, dim=1)
    values = torch.cat(values, dim=0)
    res = torch.sparse.FloatTensor(indices, values, size=torch.Size(bias))
    return res

class freezable_defaultdict(dict):
    def __init__(self, default_factory, *args, **kwargs):
        self.frozen = False
        self.default_factory = default_factory
        super(freezable_defaultdict, self).__init__(*args, **kwargs)

    def __missing__(self, key):
        if self.frozen:
            return self.default_factory()
        else:
            self[key] = value = self.default_factory()
            return value

    def freeze(self):
        self.frozen = True


In [54]:
## main

import os, sys, json, torch
import argparse, datetime, time
import random, numpy as np
# from util import Utils
# from data_loader import DataLoader
#from trainer import Trainer
#from evaluator import Evaluator
from timeit import default_timer as timer
# from print_log import Logger
from tqdm import tqdm

'''
node_type:
    '3 represents three types: Document&Entity&Topic; \n'
    '2 represents two types: Document&Entiy; \n'
    '1 represents two types: Document&Topic; \n'
    '0 represents only one type: Document. '
'''
# CUDA_VISIBLE_DEVICES_DICT = {0: '4',    1: '3',     2: '4',     3: '5'}
# MEMORY_DICT =               {0: 4000,   1: 9500,    2: 7600,    3: 8000}


def parse_arguments():
    parser = argparse.ArgumentParser(description='Argument parser for Fake News Detection')
    data_root_path = 'data/fakeNews/'
    parser.add_argument("--root", type=str, default=data_root_path)
    
    parser.add_argument("--train", type=str, default=data_root_path + 'fulltrain.csv')
    parser.add_argument("--dev", type=str, default=data_root_path + 'balancedtest.csv')
    parser.add_argument("--test", type=str, default=data_root_path + 'test.xlsx',
                        help='Out of domain test set')
    parser.add_argument("--pte", type=str, default='', help='Pre-trained embeds')
    parser.add_argument("--entity_desc", type=str, help='entity description path.',
                        default=data_root_path + 'entityDescCorpus.pkl')
    parser.add_argument("--entity_tran", type=str, help='entity transE embedding path.',
                        default=data_root_path + 'entity_feature_transE.pkl')
    parser.add_argument("--adjs", type=str, default=data_root_path + 'adjs/')
    # Hyper-parameters
    parser.add_argument("--emb_dim", type=int, default=100)
    parser.add_argument("--hidden_dim", type=int, default=100)
    parser.add_argument("--node_emb_dim", type=int, default=32)
    parser.add_argument("--max_epochs", type=int, default=5)
    parser.add_argument("--max_sent_len", type=int, default=50)
    parser.add_argument("--max_sents_in_a_doc", type=int, default=10000)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--ntags", type=int, default=4)         # 4 or 2
    parser.add_argument("--weight_decay", type=float, default=1e-6)
    parser.add_argument("--pooling", type=str, default='max',
                        help='Pooling type: "max", "mean", "sum", "att". ')

    # parser.add_argument("--config", type=str, default='config_default',
    #                     help='Name for saving plots')
    parser.add_argument("--model_file", type=str, default='model_CompareNet_Max_DET_1112_1338.t7',
                        help='For evaluating a saved model')
    parser.add_argument("--plot", type=int, default=0, help='set to plot attn')
    parser.add_argument("--mode", type=int, default=0, help='0: train&test, 1:test')
    # parser.add_argument("--cuda", type=bool, default=True, help='use gpu to speed up or not')
    parser.add_argument("--cuda", type=bool, default=True, help='use gpu to speed up or not')
    parser.add_argument("--device", type=int, default=0, help='GPU ID. ')
    parser.add_argument("--HALF", type=bool, default=True, help='Use half tensor to save memory')

    parser.add_argument("--DEBUG", action='store_true', default=False, help='')
    parser.add_argument("--node_type", type=int, default=3,
                        help='3 represents three types: Document&Entity&Topic; \n'
                             '2 represents two types: Document&Entiy; \n'
                             '1 represents two types: Document&Topic; \n'
                             '0 represents only one type: Document. ')
    parser.add_argument('-r', "--repeat", type=int, default=1, help='')
    parser.add_argument('-s', "--seed", type=list, default=[5], help='')
    parser.add_argument('-s', "--encoder", type=int, default=[5], help='0: LSTM encoder for text,'
                                                                         '1: Transformer encoder for text')

    for dir in ["models/", "ckpt/", "plots/", "result/", "log/"]:
        if not os.path.exists(dir):   os.makedirs(dir)
    args = parser.parse_args(args=[])

    TIMENOW = (datetime.datetime.utcnow() + datetime.timedelta(hours=8)).strftime("%m%d_%H%M")
    NODETYPE = {0: "D", 1: "DT", 2: "DE", 3: "DET"}[args.node_type]
    if args.mode == 0:
        MODELNAME = 'CompareNet_{}_{}_{}'.format(args.pooling.capitalize(), NODETYPE, TIMENOW)
        args.model_file = 'model_{}.t7'.format(MODELNAME)
        args.config = MODELNAME
        sys.stdout = Logger("./log/{}_{}.log".format(MODELNAME, TIMENOW))
    else:
        MODELNAME = args.model_file.split(".")[0]
        args.config = MODELNAME
        sys.stdout = Logger("./log/{}_{}.log".format(MODELNAME, TIMENOW))

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device)
    args.cuda = args.cuda and torch.cuda.is_available()
    args.repeat = len(args.seed) if isinstance(args.seed, list) else args.repeat
    print("TimeStamp: {}\n".format(TIMENOW), json.dumps(vars(args), indent=2))
    return args






In [55]:
def main(params = None):
    global dl, test
    if params is None:
        params = parse_arguments()
        
    SEED = params.seed
    t0 = time.time()
    s_t = timer()
    #dl = DataLoader(params)

    u = Utils(params, dl)
    timeDelta = int(time.time()-t0)
    print("PreCost:", datetime.timedelta(seconds=timeDelta))
    for repeat in range(params.repeat):
        print("\n\n\n{0} Repeat: {1} {0}".format('-'*27, repeat))
        set_seed( SEED[repeat] if isinstance(SEED, list) else SEED )
        print("\n\n\n{0}  Seed: {1}  {0}".format('-'*27, SEED[repeat]))
        if params.mode == 0:
            # Start training
            trainer = Trainer(params, u)
            trainer.log_time['data_loading'] = timer() - s_t
            trainer.train()
            print(trainer.log_time)
            print("Total time taken (in seconds): {}".format(timer() - s_t))

            evaluator = Evaluator(params, u, dl)
            evaluator.evaluate()
        elif params.mode == 1:
            # Evaluate on the test set
            evaluator = Evaluator(params, u, dl)
            evaluator.evaluate()
        else:
            raise NotImplementedError("Unknown mode: {}".format(params.mode))
            

def set_seed(seed=9699):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


if __name__ == '__main__':

    params = parse_arguments()
    set_seed(0)
    main(params)

100%|██████████████████████████████████████████████████████████████████████████████| 1130/1130 [24:18<00:00,  1.29s/it]
100%|████████████████████████████████████████████████████████████████████████████████| 283/283 [02:42<00:00,  1.74it/s]
  return np.asscalar(np.mean(losses)), hits / total
100%|██████████████████████████████████████████████████████████████████████████████| 1130/1130 [26:51<00:00,  1.43s/it]
100%|████████████████████████████████████████████████████████████████████████████████| 283/283 [03:39<00:00,  1.29it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1130/1130 [30:53<00:00,  1.64s/it]
100%|████████████████████████████████████████████████████████████████████████████████| 283/283 [03:00<00:00,  1.57it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1130/1130 [25:54<00:00,  1.38s/it]
100%|████████████████████████████████████████████████████████████████████████████████| 283/283 [02:57<00:00,