In [16]:
import os
import math
from sklearn.model_selection import train_test_split
from sklearn.neighbors import kneighbors_graph
import pandas as pd
import jieba
import re
from tqdm import tqdm
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import style
sns.set(style='white')
style.use("fivethirtyeight")

os.environ["CUDA_VISIBLE_DEVICES"] = '3' 

In [2]:
# pytorch
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import TensorDataset,DataLoader 
from torch.optim import Adam,SGD,RMSprop
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch.nn.functional as F
from torch.nn.init import xavier_uniform_, kaiming_normal_
import time
import gc

In [3]:
from sklearn import metrics
from utils import * 
import pickle as pkl
datasets = ['mr','ohsumed','R8','R52','weibo_yiqing']
d = 'mr'
if d not in datasets:
    print("error dataset")
else:
    train_adj, train_feature, train_y, val_adj, val_feature, val_y, test_adj, test_feature, test_y = load_data(d)

In [4]:
print('loading training set')
train_adj, train_mask = preprocess_adj(train_adj)
train_feature = preprocess_features(train_feature)
print('loading validation set')
val_adj, val_mask = preprocess_adj(val_adj)
val_feature = preprocess_features(val_feature)
print('loading test set')
test_adj, test_mask = preprocess_adj(test_adj)
test_feature = preprocess_features(test_feature)

  6%|▌         | 361/6398 [00:00<00:01, 3604.70it/s]

loading training set


100%|██████████| 6398/6398 [00:01<00:00, 3940.55it/s]
100%|██████████| 6398/6398 [00:01<00:00, 4264.19it/s]
100%|██████████| 710/710 [00:00<00:00, 4083.72it/s]
  0%|          | 0/710 [00:00<?, ?it/s]

loading validation set


100%|██████████| 710/710 [00:00<00:00, 4680.75it/s]
 10%|█         | 356/3554 [00:00<00:00, 3211.89it/s]

loading test set


100%|██████████| 3554/3554 [00:01<00:00, 3391.38it/s]
100%|██████████| 3554/3554 [00:00<00:00, 4825.65it/s]


In [5]:
train_y = np.where(train_y)[1]
val_y = np.where(val_y)[1]
test_y = np.where(test_y)[1]

In [6]:
train_adj.shape,train_feature.shape,train_y.shape

((6398, 44, 44), (6398, 44, 300), (6398,))

In [7]:
train_adj = torch.Tensor(train_adj)
train_feature = torch.Tensor(train_feature)
train_y = torch.LongTensor(train_y)

val_adj = torch.Tensor(val_adj)
val_feature = torch.Tensor(val_feature)
val_y = torch.LongTensor(val_y)

test_adj = torch.Tensor(test_adj)
test_feature = torch.Tensor(test_feature)
test_y = torch.LongTensor(test_y)

train_mask = torch.Tensor(train_mask)
val_mask =torch.Tensor(val_mask)
test_mask =torch.Tensor(test_mask)

In [8]:
# split mini-batch
def getBatch(i, bs, A, X, Y,mask):
    return A[i*bs:(i+1)*bs],X[i*bs:(i+1)*bs],Y[i*bs:(i+1)*bs],mask[i*bs:(i+1)*bs]

In [9]:
# parameters
lr = 0.01
batch_size =32
epochs = 200
weight_decay = 1e-6

num_class = 2
train_samples = train_y.shape[0]
test_samples = test_y.shape[0]
val_samples = val_y.shape[0]

In [20]:
# GFM-GC model 
from layers import GraphConvolution
class GFMGC(nn.Module):
    def __init__(self, num_class, input_dim):
        super(GFMGC,self).__init__()
        
        self.num_class = num_class
        self.input_dim = input_dim
        
        self.fc1 = nn.Linear(input_dim,num_class)
    
    def cal_gfm(self,x,adj):# x-[bs,seq,emb_size]  adj:[bs,seq,seq]
        # get_edges 
        bs = x.shape[0]
        emb = x.shape[2]
        pairs = torch.nonzero(adj) # [  [b,  source,target]...]
        res = torch.zeros((bs,emb)).cuda()
        for pair in pairs:
            res[pair[0]] += adj[pair[0]][pair[1]][pair[2]] * x[pair[0]][pair[1]] * x[pair[0]][pair[2]]
        return res
                                                
                                                                                      
    def reset_parameters(self):
        pass
        
    def forward(self,x,adj,mask):# x:[bs,seq,emb_size]  adj:[bs,seq,seq]
        
        adj = adj * mask
        gfm = self.cal_gfm(x,adj)
        logit = self.fc1(gfm)
        return logit

In [21]:
model = GFMGC(num_class = num_class, input_dim = 300).cuda()

optimizer = Adam(model.parameters(),lr = lr,weight_decay = weight_decay)
lossfunc = nn.CrossEntropyLoss()

In [1]:
def train():
    for epoch in range(epochs):
        model.train() 
        print('epoch {}'.format(epoch + 1))
        train_loss = []
        train_acc = 0.
        for i in tqdm(range(train_samples // batch_size + 1)):
            adj_batch,feature_batch, y_batch, mask_batch = getBatch(i, batch_size, train_adj, train_feature, train_y,train_mask)
            optimizer.zero_grad()
            
            logits = model(feature_batch.cuda(),adj_batch.cuda(),mask_batch.cuda())
            loss = lossfunc(logits, y_batch.cuda())
            
            train_loss.append(loss.item())
            
            pred = torch.max(logits,1)[1]
            
            train_correct = (pred.cpu() == y_batch).sum()
            
            train_acc += train_correct
            
            loss.backward()
            optimizer.step()
        print(train_acc)
        print('train_loss = {:0.4f}, train_acc = {:0.4f}'.format(np.mean(train_loss), train_acc / train_samples))
        
        model.eval()
        
        val_loss = []
        val_acc = 0.
        with torch.no_grad():
            for i in tqdm(range(test_samples // batch_size + 1)):
                adj_batch,feature_batch, y_batch, mask_batch = getBatch(i, batch_size, test_adj, test_feature, test_y,test_mask)
                logits = model(feature_batch.cuda(),adj_batch.cuda(),mask_batch.cuda())
                loss = lossfunc(logits, y_batch.cuda())
                val_loss.append(loss.item())

                pred = torch.max(logits,1)[1]

                val_correct = (pred.cpu() == y_batch).sum()

                val_acc += val_correct
            # best_acc
            if best_acc < val_acc / test_samples:
                best_acc = val_acc / test_samples
            print('test_loss = {:0.4f},  test_acc = {:0.4f}, best_acc = {:0.4f}'.format(np.mean(val_loss), val_acc / test_samples,\
                                                                                     best_acc))

In [23]:
train()


  0%|          | 0/200 [00:00<?, ?it/s][A

epoch 1



  0%|          | 1/200 [00:00<02:55,  1.13it/s][A
  1%|          | 2/200 [00:01<02:56,  1.12it/s][A
  2%|▏         | 3/200 [00:02<02:55,  1.13it/s][A
  2%|▏         | 4/200 [00:03<02:52,  1.14it/s][A
  2%|▎         | 5/200 [00:04<02:45,  1.18it/s][A
  3%|▎         | 6/200 [00:05<02:47,  1.16it/s][A
  4%|▎         | 7/200 [00:06<02:56,  1.09it/s][A
  4%|▍         | 8/200 [00:07<03:09,  1.01it/s][A
  4%|▍         | 9/200 [00:08<03:09,  1.01it/s][A
  5%|▌         | 10/200 [00:09<03:14,  1.02s/it][A
  6%|▌         | 11/200 [00:10<03:12,  1.02s/it][A
  6%|▌         | 12/200 [00:11<03:10,  1.01s/it][A
  6%|▋         | 13/200 [00:12<03:07,  1.00s/it][A
  7%|▋         | 14/200 [00:13<03:06,  1.00s/it][A
  8%|▊         | 15/200 [00:14<02:56,  1.05it/s][A
  8%|▊         | 16/200 [00:15<02:57,  1.04it/s][A
  8%|▊         | 17/200 [00:16<02:53,  1.06it/s][A
  9%|▉         | 18/200 [00:17<02:52,  1.06it/s][A
 10%|▉         | 19/200 [00:18<02:50,  1.06it/s][A
 10%|█         | 20/

 78%|███████▊  | 157/200 [02:23<00:41,  1.04it/s][A
 79%|███████▉  | 158/200 [02:24<00:40,  1.03it/s][A
 80%|███████▉  | 159/200 [02:25<00:38,  1.07it/s][A
 80%|████████  | 160/200 [02:26<00:38,  1.05it/s][A
 80%|████████  | 161/200 [02:27<00:36,  1.08it/s][A
 81%|████████  | 162/200 [02:28<00:35,  1.06it/s][A
 82%|████████▏ | 163/200 [02:29<00:36,  1.02it/s][A
 82%|████████▏ | 164/200 [02:30<00:35,  1.01it/s][A
 82%|████████▎ | 165/200 [02:31<00:33,  1.06it/s][A
 83%|████████▎ | 166/200 [02:32<00:31,  1.09it/s][A
 84%|████████▎ | 167/200 [02:32<00:28,  1.15it/s][A
 84%|████████▍ | 168/200 [02:33<00:27,  1.17it/s][A
 84%|████████▍ | 169/200 [02:34<00:27,  1.13it/s][A
 85%|████████▌ | 170/200 [02:35<00:25,  1.16it/s][A
 86%|████████▌ | 171/200 [02:36<00:25,  1.14it/s][A
 86%|████████▌ | 172/200 [02:37<00:24,  1.13it/s][A
 86%|████████▋ | 173/200 [02:38<00:24,  1.10it/s][A
 87%|████████▋ | 174/200 [02:39<00:23,  1.09it/s][A
 88%|████████▊ | 175/200 [02:40<00:24,  1.03it

train_loss = 0.3977, train_acc = 0.0000



  1%|          | 1/112 [00:00<01:47,  1.03it/s][A
  2%|▏         | 2/112 [00:01<01:47,  1.02it/s][A
  3%|▎         | 3/112 [00:03<01:48,  1.00it/s][A
  4%|▎         | 4/112 [00:03<01:41,  1.06it/s][A
  4%|▍         | 5/112 [00:04<01:43,  1.03it/s][A
  5%|▌         | 6/112 [00:05<01:38,  1.07it/s][A
  6%|▋         | 7/112 [00:06<01:33,  1.12it/s][A
  7%|▋         | 8/112 [00:07<01:34,  1.11it/s][A
  8%|▊         | 9/112 [00:08<01:30,  1.14it/s][A
  9%|▉         | 10/112 [00:09<01:33,  1.10it/s][A
 10%|▉         | 11/112 [00:10<01:32,  1.09it/s][A
 11%|█         | 12/112 [00:10<01:28,  1.13it/s][A
 12%|█▏        | 13/112 [00:12<01:31,  1.08it/s][A
 12%|█▎        | 14/112 [00:12<01:22,  1.19it/s][A
 13%|█▎        | 15/112 [00:13<01:23,  1.16it/s][A
 14%|█▍        | 16/112 [00:14<01:27,  1.10it/s][A
 15%|█▌        | 17/112 [00:15<01:27,  1.08it/s][A
 16%|█▌        | 18/112 [00:16<01:25,  1.10it/s][A
 17%|█▋        | 19/112 [00:17<01:21,  1.14it/s][A
 18%|█▊        | 20/

KeyboardInterrupt: 