In [None]:
import copy
import pickle
import datetime

import numpy as np
import pandas as pd

import torch
import deepsnap
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

from deepsnap.hetero_gnn import forward_op
from deepsnap.hetero_graph import HeteroGraph
from torch_sparse import SparseTensor, matmul

import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings('ignore')
from collections import defaultdict

In [None]:
# Directory holding the pickled graph artifacts; `load_pickle` prepends it to
# every file name it is given.
base_path = './data/graph/'

In [None]:
def load_pickle(fname):
    """Unpickle and return the object stored at ``base_path + fname``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    the object has been read.

    NOTE: ``pickle.load`` can execute arbitrary code -- only use this on
    trusted files.
    """
    with open(base_path + fname, 'rb') as f:
        return pickle.load(f)

def save_tensor(data, fname):
    """Pickle ``data`` to the path ``fname`` (used as given, not joined with ``base_path``).

    The file is written inside a ``with`` block so it is flushed and closed
    before the function returns.
    """
    with open(fname, 'wb') as f:
        pickle.dump(data, f)

def compare_date(source, target):
    """Return 1 if the date string ``source`` is strictly earlier than ``target``, else 0.

    Both arguments are ``-``-separated year/month/day strings; every part is
    parsed with ``int`` and validated by ``datetime.date``.
    """
    def _to_date(text):
        year, month, day = map(int, text.split("-"))
        return datetime.date(year, month, day)

    return 1 if _to_date(source) < _to_date(target) else 0

# Load Data

In [None]:
# Bill table; its VOTE_DATE column drives the temporal split of bills.
bill = pd.read_csv(filepath_or_buffer='./data/final_bill.csv')

In [None]:
# Raw tweets (ids kept as strings). NOTE(review): "proceseed" looks like a typo
# but must match the file name on disk.
bh_tweet = pd.read_csv('./data/proceseed_bh_tweets.tsv', sep='\t', dtype={'id': str})

# One row per date with columns ['date', 'id'], where 'id' is every tweet id
# from that date joined by single spaces. (Computed once: the earlier
# comma-joined variant was immediately overwritten and never used.)
new_bh_tweet = (
    bh_tweet.groupby('date')['id']
    .apply(' '.join)
    .reset_index()
)

In [None]:
# Member table (its row count sizes the member masks) and the vote table.
member, votes = (
    pd.read_csv(csv_path)
    for csv_path in ('./data/final_cgppl.csv', './data/final_votes_21.csv')
)

# Load Graphs

In [None]:
# Edge lists, one pickle per relation.
(bill_edges, mem_mem_edges, bh_edges, bh_bill_edges,
 bh_mem_edges, mem_bill_edges, bill_bh_edges) = map(load_pickle, (
    'bill_edges.edge_index',
    'mem_mem_edges.edge_index',
    'bh_edges.edge_index',
    'bh_bill_edges.edge_index',
    'bh_mem_edges.edge_index',
    'mem_bill_edges.edge_index',
    'bill_bh_edges.edge_index',
))

# Edge weights, in the same relation order as the edge lists above.
(bill_weights, mem_mem_weights, bh_weights, bh_bill_weights,
 bh_mem_weights, mem_bill_weights, bill_bh_weights) = map(load_pickle, (
    'bill_weights.edge_attr',
    'mem_mem_weights.edge_attr',
    'bh_weights.edge_attr',
    'bh_bill_weights.edge_attr',
    'bh_mem_weights.edge_attr',
    'mem_bill_weights.edge_attr',
    'bill_bh_weights.edge_attr',
))

# Node features per node type (note: member features live in 'mem_feat.x').
bill_emb, mem_emb, bh_tweet_feat = map(load_pickle, ('bill_emb.x', 'mem_feat.x', 'bh_tweet_feat.x'))

# Bill labels.
bill_mem_vote = load_pickle('bill_mem_vote.y')

In [None]:
# Please do not change the following parameters
# (Only the device adapts to the host: 'cuda:1' is used when a second GPU
#  exists, otherwise 'cuda:0'; without CUDA, the CPU. A hard-coded 'cuda:1'
#  fails with "invalid device ordinal" on single-GPU machines.)
args = {
    'device': torch.device(
        'cpu' if not torch.cuda.is_available()
        else 'cuda:1' if torch.cuda.device_count() > 1
        else 'cuda:0'
    ),
    'hidden_size': 64,
    'epochs': 100,
    'weight_decay': 1e-5,
    'lr': 0.003,
    'attn_size': 64,
    'batch_size': 128,  # not referenced elsewhere in this notebook
    'output_dim': 2,    # number of output (vote) classes
}

# Date split (train: before 2021-06-01, valid: before 2021-08-01, test: before 2021-12-01)

## Bill mask

In [None]:
# 0/1 mask over bills: 1 for bills voted before 2021-06-01 (training split).
target_date = '2021-06-01'
train_mask = torch.tensor(
    [compare_date(vote_date, target_date) for vote_date in bill['VOTE_DATE'].values],
    dtype=torch.long,
)
print("# train", train_mask.sum())

In [None]:
# 0/1 mask over bills: voted before 2021-08-01 but not in the training split.
target_date = '2021-08-01'
valid_mask = torch.tensor(
    [compare_date(vote_date, target_date) for vote_date in bill['VOTE_DATE'].values],
    dtype=torch.long,
) - train_mask
print("# valid", valid_mask.sum())

In [None]:
# 0/1 mask over bills: voted before 2021-12-01 and in neither earlier split.
# Bills voted on/after 2021-12-01 belong to no split.
target_date = '2021-12-01'
test_mask = torch.tensor(
    [compare_date(vote_date, target_date) for vote_date in bill['VOTE_DATE'].values],
    dtype=torch.long,
) - valid_mask - train_mask
print("# test", test_mask.sum(), torch.unique(test_mask))

In [None]:
# Turn the 0/1 split masks into 1-D int64 tensors of selected bill indices.
bill_train_mask, bill_valid_mask, bill_test_mask = (
    split_mask.nonzero().flatten() for split_mask in (train_mask, valid_mask, test_mask)
)

In [None]:
# Persist the bill split indices as numpy arrays, one pickle per split.
for split_indices, pickle_path in (
    (bill_train_mask, './data/graph/bill_train_mask.y'),
    (bill_valid_mask, './data/graph/bill_valid_mask.y'),
    (bill_test_mask, './data/graph/bill_test_mask.y'),
):
    save_tensor(split_indices.numpy(), pickle_path)

In [None]:
tuple(split.numpy().shape for split in (bill_train_mask, bill_valid_mask, bill_test_mask))

## Tweet mask

In [None]:
# 0/1 mask over the per-date tweet rows of new_bh_tweet: dates before 2021-06-01.
target_date = '2021-06-01'
values = new_bh_tweet['date'].values
train_mask = torch.tensor(
    [compare_date(tweet_date, target_date) for tweet_date in values],
    dtype=torch.long,
)
print("# train", train_mask.sum())

In [None]:
# 0/1 mask over tweet rows: dates before 2021-08-01 that are not in the training split.
target_date = '2021-08-01'
values = new_bh_tweet['date'].values
valid_mask = torch.tensor(
    [compare_date(tweet_date, target_date) for tweet_date in values],
    dtype=torch.long,
) - train_mask
print("# valid", valid_mask.sum(), torch.unique(valid_mask))

In [None]:
# 0/1 mask over tweet rows: dates before 2021-12-01 in neither earlier split.
target_date = '2021-12-01'
values = new_bh_tweet['date'].values
test_mask = torch.tensor(
    [compare_date(tweet_date, target_date) for tweet_date in values],
    dtype=torch.long,
) - valid_mask - train_mask
print("# test", test_mask.sum(), torch.unique(test_mask))

In [None]:
# Turn the tweet-row 0/1 masks into 1-D int64 tensors of selected row indices.
bh_train_mask, bh_valid_mask, bh_test_mask = (
    split_mask.nonzero().flatten() for split_mask in (train_mask, valid_mask, test_mask)
)

In [None]:
# Persist the tweet split indices, one pickle per split. Unlike the bill
# splits these are saved as torch tensors, not numpy arrays.
# NOTE(review): confirm downstream loaders expect tensors here.
for split_indices, pickle_path in (
    (bh_train_mask, './data/graph/train_mask.bh_tweet.x'),
    (bh_valid_mask, './data/graph/valid_mask.bh_tweet.x'),
    (bh_test_mask, './data/graph/test_mask.bh_tweet.x'),
):
    save_tensor(split_indices, pickle_path)

## Member mask

In [None]:
# Members are not split: every mask lists all member indices (three separate tensors).
mem_train_mask, mem_valid_mask, mem_test_mask = (
    torch.arange(len(member), dtype=torch.long) for _split in range(3)
)

# Graph

In [None]:
# Relation (message) types, each written as
# (source node type, edge type, destination node type).
message_type = [
    ("bill", "keyword", "bill"),
    ("president", "keyword", "bill"),
    ("member", "propose", "bill"),

    ("member", "party", "member"),
    ("president", "party", "member"),

    ("president", "keyword", "president"),
    ("bill", "keyword", "president")
]

# Edge lists and edge weights, positionally aligned with `message_type`.
message_type_index = [
    bill_edges, bh_bill_edges, mem_bill_edges,
    mem_mem_edges, bh_mem_edges,
    bh_edges, bill_bh_edges,
]
message_type_attr = [
    bill_weights, bh_bill_weights, mem_bill_weights,
    mem_mem_weights, bh_mem_weights,
    bh_weights, bill_bh_weights,
]

# Per-relation dictionaries keyed by message type:
#   edge_index[m] -> the stored edge list transposed, as int64
#                    (presumably (E, 2) -> (2, E); confirm against the pickles)
#   edge_attr[m]  -> the edge weights with a new axis at dim 1, as float32
edge_index = {}
edge_attr = {}
for mtyp, minx, mattr in zip(message_type, message_type_index, message_type_attr):
    edge_index[mtyp] = minx.t().to(torch.long)
    edge_attr[mtyp] = mattr[:, None].to(torch.float)
    print(mtyp, minx.size(), mattr.size())

# Node features per node type, cast to float32.
node_feature = {
    "bill": bill_emb.float(),
    "member": mem_emb.float(),
    "president": bh_tweet_feat.float(),
}

# Node labels: only bills are labelled (from bill_mem_vote.y), cast to int64.
node_label = {"bill": bill_mem_vote.long()}

In [None]:
# Split indices per node type, moved to the training device.
train_idx = {
    node_type: split.to(args['device'])
    for node_type, split in (("bill", bill_train_mask), ("member", mem_train_mask), ("president", bh_train_mask))
}
val_idx = {
    node_type: split.to(args['device'])
    for node_type, split in (("bill", bill_valid_mask), ("member", mem_valid_mask), ("president", bh_valid_mask))
}
test_idx = {
    node_type: split.to(args['device'])
    for node_type, split in (("bill", bill_test_mask), ("member", mem_test_mask), ("president", bh_test_mask))
}

In [None]:
# Training bills: (sum of labels, number of label entries).
(bill_mem_vote[bill_train_mask].sum(), bill_mem_vote[bill_train_mask].numel())

In [None]:
# Validation bills: (sum of labels, number of label entries).
(bill_mem_vote[bill_valid_mask].sum(), bill_mem_vote[bill_valid_mask].numel())

In [None]:
# Test bills: (sum of labels, number of label entries).
(bill_mem_vote[bill_test_mask].sum(), bill_mem_vote[bill_test_mask].numel())

In [None]:
# Build the deepsnap (tensor backend) HeteroGraph from the per-type node
# dictionaries and the per-relation edge dictionaries; edges are directed.
hetero_graph = HeteroGraph(
    directed=True,
    node_feature=node_feature,
    node_label=node_label,
    edge_index=edge_index,
    edge_attr=edge_attr,
)

In [None]:
# Print whatever HeteroGraph.num_nodes()/num_edges() report when called without
# a type argument (presumably per-type counts -- confirm with deepsnap).
print(f"KO-VOTE heterogeneous graph: {hetero_graph.num_nodes()} nodes, {hetero_graph.num_edges()} edges")

In [None]:
# Move node features, then node labels, to the training device (dicts updated in place).
for store in (hetero_graph.node_feature, hetero_graph.node_label):
    for key in list(store):
        store[key] = store[key].to(args['device'])

In [None]:
# Edge_index to sparse tensor and to device.
# Each relation's dense edge index (row 0 = source ids, row 1 = destination ids)
# becomes a SparseTensor of shape (num_src, num_dst). Its transpose
# (num_dst, num_src) is stored, which is what HeteroGNNConv's
# matmul(edge_index, node_feature_src) expects.
# - The key already *is* the (src, relation, dst) message type, so the node
#   types are read from it rather than zipped with `message_type` (which
#   silently relied on both having the same order).
# - A local name is used so the global `edge_index` dict passed to HeteroGraph
#   is not clobbered.
# - Already-converted relations are skipped, so the cell is safe to re-run.
for key in list(hetero_graph.edge_index):
    source_node_type, target_node_type = key[0], key[-1]
    dense_edge_index = hetero_graph.edge_index[key]
    if isinstance(dense_edge_index, SparseTensor):
        continue  # converted by an earlier run of this cell
    print(source_node_type, target_node_type, tuple(dense_edge_index.shape))
    adj = SparseTensor(
        row=dense_edge_index[0],
        col=dense_edge_index[1],
        sparse_sizes=(
            hetero_graph.num_nodes(source_node_type),
            hetero_graph.num_nodes(target_node_type),
        ),
    )
    hetero_graph.edge_index[key] = adj.t().to(args['device'])

In [None]:
# Move each relation's edge attributes to the training device (dict updated in place).
for key, attr in list(hetero_graph.edge_attr.items()):
    hetero_graph.edge_attr[key] = attr.to(args['device'])
    

In [None]:
# Show each relation's adjacency (expected to be the on-device SparseTensor
# after the conversion cell).
for mtyp in message_type:
    print(mtyp, hetero_graph.edge_index[mtyp], sep="\n")

# Model

In [None]:
def _init_weights(module):
    if isinstance(module,nn.Linear):
        nn.init.xavier_uniform_(module.weight.data)

In [None]:
import copy
import torch
import deepsnap
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as pyg_nn

from sklearn.metrics import f1_score
from deepsnap.hetero_gnn import forward_op
from deepsnap.hetero_graph import HeteroGraph
from torch_sparse import SparseTensor, matmul


class HeteroGNNConv(pyg_nn.MessagePassing):
    """Graph convolution for a single (src, relation, dst) message type.

    Source features are mean-aggregated onto each destination node, projected
    by ``lin_src``, concatenated with the projected destination features
    (``lin_dst``) and mixed by ``lin_update`` into ``out_channels`` features.

    NOTE: PyG's MessagePassing routes the keyword arguments given to
    ``propagate`` to ``message_and_aggregate``/``update`` by parameter name,
    so those parameter names must stay in sync.
    """
    def __init__(self, in_channels_src, in_channels_dst, out_channels):
        """
        Args:
            in_channels_src: feature width of the source node type.
            in_channels_dst: feature width of the destination node type.
            out_channels: output feature width.
        """
        super(HeteroGNNConv, self).__init__(aggr="mean")

        self.in_channels_src = in_channels_src
        self.in_channels_dst = in_channels_dst
        self.out_channels = out_channels

        self.lin_dst = nn.Linear(self.in_channels_dst, self.out_channels)
        self.lin_src = nn.Linear(self.in_channels_src, self.out_channels)
        # Consumes [lin_dst(x_dst) || lin_src(aggregated x_src)].
        self.lin_update = nn.Linear(self.out_channels * 2, self.out_channels)
        # Re-initialise every Linear weight with Xavier-uniform.
        self.apply(_init_weights)

    def forward(
        self,
        node_feature_src,
        node_feature_dst,
        edge_index,
        size=None,
        res_n_id=None,
    ):
        """Run message passing and return the updated destination features.

        ``edge_index`` is handed straight to ``matmul`` in
        ``message_and_aggregate``; presumably the transposed (num_dst, num_src)
        SparseTensor adjacency built in the notebook -- confirm at call sites.
        ``res_n_id`` is forwarded to ``update`` but not used there.
        """
        return self.propagate(edge_index, size=size,
                              node_feature_src=node_feature_src,
                              node_feature_dst=node_feature_dst,
                              res_n_id=res_n_id)

    def message_and_aggregate(self, edge_index, node_feature_src):
        """Fused message + aggregation: per-row mean of source features via sparse matmul."""
        return matmul(edge_index, node_feature_src, reduce='mean')

    def update(self, aggr_out, node_feature_dst, res_n_id):
        """Combine aggregated messages with the destination's own features.

        Returns ``lin_update(cat(lin_dst(node_feature_dst), lin_src(aggr_out)))``.
        ``res_n_id`` is accepted for signature compatibility and ignored.
        """
        node_feature_dst = self.lin_dst(node_feature_dst)
        aggr_out = self.lin_src(aggr_out)
        concat_out = torch.cat((node_feature_dst, aggr_out), dim=-1)
        aggr_out = self.lin_update(concat_out)
        return aggr_out


class HeteroGNNWrapperConv(deepsnap.hetero_gnn.HeteroConv):
    """Runs one convolution per message type and merges the results per node type.

    For a destination node type that receives several message types, the
    per-relation embeddings are combined according to ``aggr``:

    * ``"mean"``: element-wise mean over relations.
    * ``"attn"``: semantic-level attention. Each relation's embeddings are
      scored by a per-node-type projection (Linear -> Tanh -> Linear without
      bias). The scores are averaged over nodes, soft-maxed over relations and
      used as mixing weights.
    """

    def __init__(self, convs, args, hetero_graph, aggr):
        super(HeteroGNNWrapperConv, self).__init__(convs, None)
        self.aggr = aggr

        # Position in a destination type's embedding list -> message type,
        # refreshed on every forward pass.
        # NOTE(review): keys are only list positions, so entries for different
        # destination node types overwrite each other.
        self.mapping = {}

        # Latest attention weights per node type (detached copies, for
        # inspection only); shape (num_relations_into_type, 1).
        self.alpha = {}
        self.attn_proj = {}

        if self.aggr == "attn":
            self.attn_proj = nn.ModuleDict()
            for node_type in hetero_graph.node_types:
                self.attn_proj[node_type] = nn.Sequential(
                    nn.Linear(args['hidden_size'], args['attn_size']),
                    nn.Tanh(),
                    # q_semantic_attention
                    nn.Linear(args['attn_size'], 1, bias=False),
                )
                self.alpha[node_type] = None
        self.apply(_init_weights)

    def reset_parameters(self):
        """Reset the wrapped convolutions and, for ``"attn"``, the attention projections."""
        super(HeteroGNNWrapperConv, self).reset_parameters()
        if self.aggr == "attn":
            # attn_proj holds nn.Sequential containers, which have no
            # reset_parameters(); reset the Linear layers inside them and
            # re-apply the same Xavier init used in __init__.
            for module in self.attn_proj.modules():
                if isinstance(module, nn.Linear):
                    module.reset_parameters()
                    _init_weights(module)

    def forward(self, node_features, edge_indices):
        """Apply every per-relation conv, then merge the outputs per destination type.

        Args:
            node_features: dict node_type -> feature tensor.
            edge_indices: dict (src, relation, dst) -> adjacency for that relation.

        Returns:
            dict destination node_type -> merged embedding. Node types that are
            never a message destination are absent from the result.
        """
        # Run the convolution of each message type.
        message_type_emb = {}
        for message_key, edge_index in edge_indices.items():
            src_type, _, dst_type = message_key
            message_type_emb[message_key] = self.convs[message_key](
                node_features[src_type],
                node_features[dst_type],
                edge_index,
            )

        # Group the per-relation embeddings by destination node type,
        # e.g. ('bill', 'keyword', 'bill') contributes to 'bill'.
        node_emb = {dst: [] for _, _, dst in message_type_emb}
        mapping = {}
        for (src, edge_type, dst), item in message_type_emb.items():
            mapping[len(node_emb[dst])] = (src, edge_type, dst)
            node_emb[dst].append(item)
        self.mapping = mapping

        for node_type, embs in node_emb.items():
            if len(embs) == 1:
                node_emb[node_type] = embs[0]
            else:
                node_emb[node_type] = self.aggregate(embs, node_type)
        return node_emb

    def aggregate(self, xs, node_type):
        """Combine the list ``xs`` of per-relation embeddings for ``node_type``.

        Returns a tensor shaped like each element of ``xs``, or None for an
        unrecognised ``aggr``.
        """
        if self.aggr == "mean":
            return torch.mean(torch.stack(xs), dim=0)

        elif self.aggr == "attn":
            stacked = torch.stack(xs, dim=0)             # (R, N, H), assuming each x is (N, H)
            scores = self.attn_proj[node_type](stacked)  # (R, N, 1)
            scores = torch.mean(scores, dim=1)           # (R, 1): one score per relation
            # alpha must stay attached to the graph: applying a detached alpha
            # would block all gradients to attn_proj and leave it at its random
            # initialisation. Only the stored copy is detached.
            alpha = torch.softmax(scores, dim=0)
            self.alpha[node_type] = alpha.detach()
            return torch.sum(alpha.unsqueeze(-1) * stacked, dim=0)


def generate_convs(hetero_graph, conv, hidden_size, first_layer=False):
    """Build one convolution per message type of ``hetero_graph``.

    Args:
        hetero_graph: object exposing ``message_types`` and
            ``num_node_features(node_type)``.
        conv: constructor called as ``conv(in_src, in_dst, hidden_size)``.
        hidden_size: output width of every convolution, and also its input
            width when ``first_layer`` is False.
        first_layer: if True, input widths are the raw node-feature sizes of
            the message's source (``m[0]``) and destination (``m[-1]``) types.

    Returns:
        dict mapping each message type to its convolution.
    """
    def _input_sizes(message):
        if not first_layer:
            return hidden_size, hidden_size
        return (hetero_graph.num_node_features(message[0]),
                hetero_graph.num_node_features(message[-1]))

    return {m: conv(*_input_sizes(m), hidden_size) for m in hetero_graph.message_types}


class HeteroGNN(torch.nn.Module):
    """Two-layer heterogeneous GNN that scores every (bill, member) pair.

    Each layer is a HeteroGNNWrapperConv followed by per-node-type BatchNorm
    and LeakyReLU. The final bill and member embeddings are combined with an
    element-wise (Hadamard) product for every bill/member pair, and a linear
    classifier maps each pair to ``args['output_dim']`` classes (default 2).
    """

    def __init__(self, hetero_graph, args, aggr="mean"):
        super(HeteroGNN, self).__init__()
        self.aggr = aggr
        self.hidden_size = args['hidden_size']
        # Number of output classes, taken from args instead of being hard-coded.
        num_labels = args.get('output_dim', 2)

        convs1 = generate_convs(
            hetero_graph, HeteroGNNConv, self.hidden_size, first_layer=True)
        convs2 = generate_convs(
            hetero_graph, HeteroGNNConv, self.hidden_size, first_layer=False)

        self.convs1 = HeteroGNNWrapperConv(convs1, args, hetero_graph, self.aggr)
        self.convs2 = HeteroGNNWrapperConv(convs2, args, hetero_graph, self.aggr)

        self.bns1 = nn.ModuleDict()
        self.bns2 = nn.ModuleDict()
        self.relus1 = nn.ModuleDict()
        self.relus2 = nn.ModuleDict()

        for node_type in hetero_graph.node_types:
            self.bns1[node_type] = torch.nn.BatchNorm1d(
                self.hidden_size, eps=1.0)
            self.bns2[node_type] = torch.nn.BatchNorm1d(
                self.hidden_size, eps=1.0)
            self.relus1[node_type] = nn.LeakyReLU()
            self.relus2[node_type] = nn.LeakyReLU()

        self.clf = nn.Linear(self.hidden_size, num_labels)

        self.apply(_init_weights)

    def forward(self, node_feature, edge_index):
        """Compute vote logits for every (bill, member) pair.

        Returns:
            vote: logits of shape (num_bills, num_members, num_labels).
            mem_emb: final member embeddings laid out as
                (hidden, 1, num_members) -- the layout this method has always
                returned, kept for backward compatibility.
        """
        x = node_feature

        x = self.convs1(x, edge_index)
        x = forward_op(x, self.bns1)
        x = forward_op(x, self.relus1)

        x = self.convs2(x, edge_index)
        x = forward_op(x, self.bns2)
        x = forward_op(x, self.relus2)

        leg_emb = x['bill']    # (num_bills, hidden)
        mem_emb = x['member']  # (num_members, hidden)

        # Pairwise Hadamard product via broadcasting -> (num_bills, num_members, hidden).
        # Same values as a batched matmul over the hidden axis with inner
        # dimension 1, but without the permutes and already contiguous.
        leg_mem = leg_emb.unsqueeze(1) * mem_emb.unsqueeze(0)

        vote = self.clf(leg_mem)  # (num_bills, num_members, num_labels)

        return vote, mem_emb.unsqueeze(-1).permute(1, 2, 0)

# Training!

In [None]:
def loss_fn(preds, y, indices, node_type):
    """Cross-entropy of ``preds`` against ``y[node_type]``, restricted to ``indices[node_type]``.

    The same index tensor selects rows of ``preds`` (first dimension) and of
    the targets. Returns a scalar loss tensor.
    """
    selected = indices[node_type]
    return F.cross_entropy(preds[selected], y[node_type][selected])

In [None]:
def train(model, optimizer, hetero_graph, train_idx):
    """Run one optimisation step on the training bills and return the loss as a float.

    The model's vote logits (num_bills, num_members, num_classes) are permuted
    to (num_bills, num_classes, num_members), the (N, C, d) layout
    ``F.cross_entropy`` expects for per-member targets.
    """
    model.train()
    optimizer.zero_grad()
    logits, _ = model(hetero_graph.node_feature, hetero_graph.edge_index)
    # Raw logits only: cross_entropy applies log-softmax itself, so an extra
    # softmax here would double-normalise the scores and flatten the gradients.
    logits = logits.permute(0, 2, 1)
    loss = loss_fn(logits, hetero_graph.node_label, train_idx, node_type='bill')

    loss.backward()
    optimizer.step()
    return loss.item()

In [None]:
def test(model, graph, index, node_type='bill'):
    model.eval()
    idx = index[node_type]

    preds, embed = model(graph.node_feature, graph.edge_index)
    preds = preds[idx]
    preds = torch.softmax(preds, dim=2)
    preds = preds.permute(0, 2, 1)
    
    target = graph.node_label[node_type][idx]
    
    loss = F.cross_entropy(preds, target)
    
    label_np = target.cpu().numpy()
    pred_np = torch.argmax(preds, dim=1).cpu().numpy()
    
    return loss.item(), pred_np, label_np, embed

In [None]:
def eval(y_true, y_pred, index, node_type='bill'):
    """Mean per-member classification metrics.

    ``y_true`` and ``y_pred`` are 2-D arrays indexed [bill, member]. For every
    member column, accuracy and support-weighted precision, recall and F1
    (``average='weighted'``) are computed over that member's bills. The four
    returned values are the means of those per-member scores:
    ``(accuracy, precision, recall, f1)``.

    ``index`` and ``node_type`` are kept for backward compatibility and are not used.
    NOTE: this name shadows the built-in ``eval``.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    accs, pres, recs, f1s = [], [], [], []
    # Iterate over the actual member columns rather than a hard-coded count.
    for i in range(y_true.shape[1]):
        pred = y_pred[:, i]
        true = y_true[:, i]

        accs.append(accuracy_score(y_pred=pred, y_true=true))
        pres.append(precision_score(y_pred=pred, y_true=true, average='weighted'))
        recs.append(recall_score(y_pred=pred, y_true=true, average='weighted'))
        f1s.append(f1_score(y_pred=pred, y_true=true, average='weighted'))

    return np.mean(accs), np.mean(pres), np.mean(recs), np.mean(f1s)

# Mean aggregation

In [None]:
model = HeteroGNN(hetero_graph, args, aggr="mean").to(args['device'])
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])

best_model = None
# Start at -inf so the first epoch always produces a checkpoint; with 0 a run
# whose validation accuracy never exceeds 0 would leave best_model as None.
best_val = float('-inf')

history = defaultdict(list)
# Progress bar advanced manually once per epoch.
pbar = tqdm(total=args['epochs'])
for epoch in range(args['epochs']):
    loss = train(model, optimizer, hetero_graph, train_idx)
    history["loss"].append(loss)

    _, y_pred, y_true, _ = test(model, hetero_graph, train_idx)
    train_acc, train_pre, train_rec, train_f1 = eval(y_pred=y_pred, y_true=y_true, index=train_idx)
    history["f1_score"].append(train_f1)

    val_loss, y_pred, y_true, _ = test(model, hetero_graph, val_idx)
    history["val_loss"].append(val_loss)
    val_acc, val_pre, val_rec, val_f1 = eval(y_pred=y_pred, y_true=y_true, index=val_idx)
    history["val_f1_score"].append(val_f1)

    # Test metrics are only displayed; model selection uses validation accuracy.
    _, y_pred, y_true, _ = test(model, hetero_graph, test_idx)
    test_acc, test_pre, test_rec, test_f1 = eval(y_pred=y_pred, y_true=y_true, index=test_idx)

    if val_acc > best_val:
        best_val = val_acc
        best_model = copy.deepcopy(model)

    # eval() returns fractions in [0, 1], so they are formatted as percentages.
    # Its F1 is a mean of per-member *weighted* F1, not a macro F1.
    pbar.set_description(
        f"Epoch {epoch + 1}: loss {loss:.5f}, "
        f"train acc {train_acc:.2%}, train f1 {train_f1:.2%}, "
        f"valid acc {val_acc:.2%}, valid f1 {val_f1:.2%}, "
        f"test acc {test_acc:.2%}, test f1 {test_f1:.2%}"
    )
    pbar.update()
pbar.close()

_, y_pred, y_true, embed = test(best_model, hetero_graph, test_idx)
print(eval(y_pred=y_pred, y_true=y_true, index=test_idx))

In [None]:
fig, ax = plt.subplots()
ax.plot(history["loss"])
ax.plot(history["val_loss"])
ax.set_title("Loss")
ax.legend(["train", "val"])
plt.show()

In [None]:
# Per-epoch F1 curves. eval() averages per-member F1 computed with
# average='weighted', so the title says "weighted", not "macro".
fig, ax = plt.subplots()
ax.plot(history["f1_score"], label="train")
ax.plot(history["val_f1_score"], label="val")
ax.set(title="Mean per-member weighted F1 score", xlabel="epoch", ylabel="F1")
ax.legend()
plt.show()

# Training! - Attention aggregation

In [None]:
model = HeteroGNN(hetero_graph, args, aggr="attn").to(args['device'])
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])

best_model = None
# Start at -inf so the first epoch always produces a checkpoint; with 0 a run
# whose validation accuracy never exceeds 0 would leave best_model as None.
best_val = float('-inf')

history = defaultdict(list)
# Progress bar advanced manually once per epoch.
pbar = tqdm(total=args['epochs'])
for epoch in range(args['epochs']):
    loss = train(model, optimizer, hetero_graph, train_idx)
    history["loss"].append(loss)

    _, y_pred, y_true, _ = test(model, hetero_graph, train_idx)
    train_acc, train_pre, train_rec, train_f1 = eval(y_pred=y_pred, y_true=y_true, index=train_idx)
    history["f1_score"].append(train_f1)

    val_loss, y_pred, y_true, _ = test(model, hetero_graph, val_idx)
    history["val_loss"].append(val_loss)
    val_acc, val_pre, val_rec, val_f1 = eval(y_pred=y_pred, y_true=y_true, index=val_idx)
    history["val_f1_score"].append(val_f1)

    # Test metrics are only displayed; model selection uses validation accuracy.
    _, y_pred, y_true, _ = test(model, hetero_graph, test_idx)
    test_acc, test_pre, test_rec, test_f1 = eval(y_pred=y_pred, y_true=y_true, index=test_idx)

    if val_acc > best_val:
        best_val = val_acc
        best_model = copy.deepcopy(model)

    # eval() returns fractions in [0, 1], so they are formatted as percentages.
    # Its F1 is a mean of per-member *weighted* F1, not a macro F1.
    pbar.set_description(
        f"Epoch {epoch + 1}: loss {loss:.5f}, "
        f"train acc {train_acc:.2%}, train f1 {train_f1:.2%}, "
        f"valid acc {val_acc:.2%}, valid f1 {val_f1:.2%}, "
        f"test acc {test_acc:.2%}, test f1 {test_f1:.2%}"
    )
    pbar.update()
pbar.close()

_, y_pred, y_true, embed = test(best_model, hetero_graph, test_idx)
print(eval(y_pred=y_pred, y_true=y_true, index=test_idx))

In [None]:
fig, ax = plt.subplots()
ax.plot(history["loss"])
ax.plot(history["val_loss"])
ax.set_title("Loss")
ax.legend(["train", "val"])
plt.show()

In [None]:
# Per-epoch F1 curves. eval() averages per-member F1 computed with
# average='weighted', so the title says "weighted", not "macro".
fig, ax = plt.subplots()
ax.plot(history["f1_score"], label="train")
ax.plot(history["val_f1_score"], label="val")
ax.set(title="Mean per-member weighted F1 score", xlabel="epoch", ylabel="F1")
ax.legend()
plt.show()

# Baseline: All Approval

In [None]:
# Baseline: predict class 1 (the "approval" class, per the section title) for
# every bill/member pair of the last evaluated split.
all_yes = np.full_like(y_pred, fill_value=1)
print(eval(y_pred=all_yes, y_true=y_true, index=test_idx))

# Baseline: Random

In [None]:
# Baseline: an independent, uniformly random 0/1 prediction for every
# bill/member pair. A dedicated seeded generator makes the reported numbers
# reproducible across runs.
RANDOM_BASELINE_SEED = 0
rng = np.random.default_rng(RANDOM_BASELINE_SEED)
random_yes = rng.integers(0, 2, size=y_pred.shape)
print(eval(y_pred=random_yes, y_true=y_true, index=test_idx))