In [1]:
import os
import pickle
import torch
import torch_geometric
from tqdm import tqdm
from separate_pretraining import *
from models import *
from torch_geometric.loader import DataLoader
import torch.optim as optim
from predict import *

In [None]:
CONFIG = dict(
    dataset='politifact',          # corpus to load: 'politifact' or 'gossipcop'
    node_level_objective=None,     # None, 'node_masking' or 'context_prediction'
    graph_level_objective=None,    # None or 'rt'
    pretrained_model=None,         # None, or the file name (without '.pth') of a previously saved model
    save_model=True,               # write trained weights to disk after pretraining
)

In [2]:
def load_graph_by_path(file_name, dataset, setup):
    """Load one pickled graph from ``../static_graphs/<dataset>/<setup>/<file_name>.pickle``.

    The pickle must hold a mapping; only its ``'graph'`` entry is returned.

    NOTE(review): ``pickle.load`` can execute arbitrary code stored in the file,
    so only use this on trusted data.
    """
    graph_path = "/".join(["..", "static_graphs", dataset, setup, file_name + ".pickle"])
    with open(graph_path, 'rb') as graph_file:
        payload = pickle.load(graph_file)
    return payload['graph']

def load_graph_ids(dataset):
    """Return the ``'graph_ids'`` entry stored in ``../<dataset>_train_test.pickle``.

    NOTE(review): ``pickle.load`` can execute arbitrary code stored in the file,
    so only use this on trusted data.
    """
    split_path = ''.join(('../', dataset, '_train_test.pickle'))
    with open(split_path, 'rb') as split_file:
        return pickle.load(split_file)['graph_ids']

def load_graphs_by_setup(setup, num_graphs=-1, dataset='politifact', feature_dim=768,
                         node_types=('tweet', 'user')):
    """Load the graphs listed in the dataset's train/test id file for one setup.

    Args:
        setup: sub-directory under ``../static_graphs/<dataset>/`` to read from.
        num_graphs: keep only the first ``num_graphs`` ids; ``-1`` (or ``None``)
            loads all of them.
        dataset: dataset name, e.g. ``'politifact'`` or ``'gossipcop'``.
        feature_dim: number of leading feature columns kept in ``x`` for each
            node type in ``node_types``. The default of 768 matches the previously
            hard-coded slice. ``None`` keeps every column.
        node_types: node types whose ``x`` matrix is truncated to ``feature_dim``.

    Returns:
        A list of graphs, in the order of the ids in the id file.
    """
    graph_ids = load_graph_ids(dataset=dataset)
    # -1 is the "no limit" sentinel; None behaves the same (graph_ids[:None] is everything).
    if num_graphs is not None and num_graphs != -1:
        graph_ids = graph_ids[:num_graphs]
    all_graphs = []
    for graph_id in tqdm(graph_ids):
        graph = load_graph_by_path(file_name=graph_id, dataset=dataset, setup=setup)
        if feature_dim is not None:
            # Keep only the leading feature columns so every graph shares one input width.
            for node_type in node_types:
                graph[node_type].x = graph[node_type].x[:, :feature_dim]
        all_graphs.append(graph)
    return all_graphs

In [3]:
# Load every graph of the configured dataset from the 'all_data' setup (-1 = no limit),
# then print the first one to sanity-check its node/edge schema.
all_graphs = load_graphs_by_setup(setup='all_data', num_graphs=-1, dataset=CONFIG['dataset'])
print(all_graphs[0])

100%|██████████| 12214/12214 [00:42<00:00, 285.00it/s]

HeteroData(
  [1marticle[0m={
    x=[1, 768],
    y=[1]
  },
  [1mtweet[0m={ x=[150, 768] },
  [1muser[0m={ x=[25, 768] },
  [1m(tweet, cites, article)[0m={ edge_index=[2, 25] },
  [1m(user, posts, tweet)[0m={ edge_index=[2, 150] },
  [1m(tweet, retweets, tweet)[0m={ edge_index=[2, 0] },
  [1m(article, rev_cites, tweet)[0m={ edge_index=[2, 25] },
  [1m(tweet, rev_posts, user)[0m={ edge_index=[2, 150] },
  [1m(tweet, rev_retweets, tweet)[0m={ edge_index=[2, 0] }
)





In [6]:
# Context-prediction pretraining: trains a graph encoder alongside a context encoder
# using pretrain_context_prediction (a helper from one of the star imports above).
if CONFIG['node_level_objective'] == 'context_prediction':
    data_loader = DataLoader(all_graphs, batch_size=128, shuffle=True)
    # Referenced through `torch` so this cell does not depend on a name `nn`
    # leaking out of a `from ... import *`; `nn` is never imported explicitly.
    criterion = torch.nn.BCEWithLogitsLoss()

    model_graph = HGT(hidden_channels=64, metadata=all_graphs[0].metadata(), out_cat=False, num_layers=2)
    # The context encoder is built from the metadata of the derived context graph,
    # which may differ from the full graph's metadata.
    model_context = HGT(hidden_channels=64, metadata=get_context_graph(all_graphs[0]).metadata(), out_cat=True, num_layers=2)

    optimizer_graph = optim.Adam(model_graph.parameters(), lr=0.001)
    optimizer_context = optim.Adam(model_context.parameters(), lr=0.001)

    # Returns (loss, accuracy, trained graph encoder); only the graph encoder is kept.
    loss, acc, model_graph = pretrain_context_prediction(model_graph=model_graph,
                                                         model_context=model_context,
                                                         data_loader=data_loader,
                                                         optimizer_graph=optimizer_graph,
                                                         optimizer_context=optimizer_context,
                                                         criterion=criterion,
                                                         epochs=50)

epoch: 001, loss: 1.3514, accuracy: 0.5031
epoch: 002, loss: 1.2893, accuracy: 0.4903
epoch: 003, loss: 1.2573, accuracy: 0.4695
epoch: 004, loss: 1.2500, accuracy: 0.4466
epoch: 005, loss: 1.2509, accuracy: 0.4369
epoch: 006, loss: 1.2461, accuracy: 0.4422
epoch: 007, loss: 1.2453, accuracy: 0.4466
epoch: 008, loss: 1.2413, accuracy: 0.4440
epoch: 009, loss: 1.2344, accuracy: 0.4456
epoch: 010, loss: 1.2166, accuracy: 0.4424
epoch: 011, loss: 1.2197, accuracy: 0.4406
epoch: 012, loss: 1.2083, accuracy: 0.4429
epoch: 013, loss: 1.2080, accuracy: 0.4426
epoch: 014, loss: 1.2067, accuracy: 0.4417
epoch: 015, loss: 1.2114, accuracy: 0.4429
epoch: 016, loss: 1.2032, accuracy: 0.4425
epoch: 017, loss: 1.2054, accuracy: 0.4432
epoch: 018, loss: 1.1999, accuracy: 0.4435
epoch: 019, loss: 1.2032, accuracy: 0.4424
epoch: 020, loss: 1.2057, accuracy: 0.4427
epoch: 021, loss: 1.1982, accuracy: 0.4422
epoch: 022, loss: 1.2020, accuracy: 0.4423
epoch: 023, loss: 1.2014, accuracy: 0.4431
epoch: 024,

In [7]:
# Persist the context-prediction graph encoder. Guarded on the objective: when it did not
# run, `model_graph` is never defined and this cell would raise NameError under Run All.
if CONFIG['save_model'] and CONFIG['node_level_objective'] == 'context_prediction':
    torch.save(model_graph.state_dict(), "pretrained_context.pth")

In [4]:
# Node-masking pretraining: reconstruction of masked 'tweet' nodes via
# pretrain_node_reconstruction (a helper from one of the star imports above).
if CONFIG['node_level_objective'] == 'node_masking':
    encoder_model = HGT(hidden_channels=64, metadata=all_graphs[0].metadata())
    encoder_optimizer = optim.Adam(encoder_model.parameters(), lr=0.001)

    # Store the masked graphs under their own name instead of overwriting `all_graphs`.
    # Later cells (e.g. the 'rt' objective) and re-runs of this cell then see unmasked graphs.
    # NOTE(review): this only isolates the data if mask_nodes returns a new graph rather
    # than mutating `g` in place. Confirm. The value 1.0 is presumably the masking ratio
    # for tweet nodes; the alternative {'tweet': 1.0, 'user': 1.0} would also mask users.
    masked_graphs = [mask_nodes(g, {'tweet': 1.0}) for g in all_graphs]
    data_loader = DataLoader(masked_graphs, batch_size=128, shuffle=True)

    encoder_model = pretrain_node_reconstruction(encoder_model=encoder_model,
                                                 data_loader=data_loader,
                                                 encoder_optimizer=encoder_optimizer,
                                                 epochs=50,
                                                 node_type='tweet',
                                                 decoder_dim=64)

epoch: 001, loss: 0.1593
epoch: 002, loss: 0.1005
epoch: 003, loss: 0.0956
epoch: 004, loss: 0.0938
epoch: 005, loss: 0.0926
epoch: 006, loss: 0.0918
epoch: 007, loss: 0.0912
epoch: 008, loss: 0.0907
epoch: 009, loss: 0.0902
epoch: 010, loss: 0.0898
epoch: 011, loss: 0.0892
epoch: 012, loss: 0.0890
epoch: 013, loss: 0.0887
epoch: 014, loss: 0.0886
epoch: 015, loss: 0.0882
epoch: 016, loss: 0.0879
epoch: 017, loss: 0.0879
epoch: 018, loss: 0.0877
epoch: 019, loss: 0.0875
epoch: 020, loss: 0.0871
epoch: 021, loss: 0.0870
epoch: 022, loss: 0.0869
epoch: 023, loss: 0.0865
epoch: 024, loss: 0.0864
epoch: 025, loss: 0.0863
epoch: 026, loss: 0.0862
epoch: 027, loss: 0.0860
epoch: 028, loss: 0.0858
epoch: 029, loss: 0.0856
epoch: 030, loss: 0.0855
epoch: 031, loss: 0.0853
epoch: 032, loss: 0.0852
epoch: 033, loss: 0.0851
epoch: 034, loss: 0.0849
epoch: 035, loss: 0.0848
epoch: 036, loss: 0.0847
epoch: 037, loss: 0.0846
epoch: 038, loss: 0.0844
epoch: 039, loss: 0.0842
epoch: 040, loss: 0.0843


In [5]:
# Persist the node-masking encoder. Guarded on the objective so this cell neither raises
# NameError when node masking did not run nor saves an encoder trained by another cell.
if CONFIG['save_model'] and CONFIG['node_level_objective'] == 'node_masking':
    torch.save(encoder_model.state_dict(), "pretrained_nodes.pth")

In [4]:
# Graph-level 'rt' pretraining, optionally continuing from previously saved encoder weights.
if CONFIG['graph_level_objective'] == 'rt':
    encoder_model = HGT(hidden_channels=64, metadata=all_graphs[0].metadata())

    if CONFIG['pretrained_model'] is not None:
        # map_location='cpu' lets weights saved from a GPU run load on a CPU-only machine.
        # load_state_dict then copies them into the freshly built model's parameters.
        encoder_model.load_state_dict(torch.load(CONFIG['pretrained_model'] + '.pth', map_location='cpu'))

    encoder_optimizer = optim.Adam(encoder_model.parameters(), lr=0.001)

    # Use a new name rather than overwriting `all_graphs`, so re-running this cell does not
    # apply add_rt_count a second time. NOTE(review): assumes add_rt_count returns the
    # augmented graph without relying on earlier applications. Confirm.
    rt_graphs = [add_rt_count(g) for g in all_graphs]

    data_loader = DataLoader(rt_graphs, batch_size=128, shuffle=True)

    encoder_model = pretrain_graph_level(encoder_model=encoder_model,
                                         data_loader=data_loader,
                                         encoder_optimizer=encoder_optimizer,
                                         epochs=50,
                                         decoder_dim=64)

epoch: 1 loss: 26.466709931691486
epoch: 2 loss: 22.63671286900838
epoch: 3 loss: 20.879405811429024
epoch: 4 loss: 20.182606890797615
epoch: 5 loss: 19.853272005915642
epoch: 6 loss: 19.709533790747326
epoch: 7 loss: 18.914485638340313
epoch: 8 loss: 18.803261399269104
epoch: 9 loss: 18.404570351044338
epoch: 10 loss: 17.67037084698677
epoch: 11 loss: 17.381763632098835
epoch: 12 loss: 17.06446015338103
epoch: 13 loss: 16.750875413417816
epoch: 14 loss: 16.778749987483025
epoch: 15 loss: 15.921978150804838
epoch: 16 loss: 16.170435150464375
epoch: 17 loss: 15.933606204887232
epoch: 18 loss: 15.491712644696236
epoch: 19 loss: 15.30548266073068
epoch: 20 loss: 14.79008573293686
epoch: 21 loss: 15.545309091607729
epoch: 22 loss: 15.23928415775299
epoch: 23 loss: 15.016981825232506
epoch: 24 loss: 14.400368998448053
epoch: 25 loss: 14.315298850337664
epoch: 26 loss: 14.129013513525328
epoch: 27 loss: 14.453448424736658
epoch: 28 loss: 13.927411948641142
epoch: 29 loss: 13.77347075442473
e

In [5]:
# Persist the 'rt'-pretrained encoder. Guarded on the objective: without it, a run that only
# did node masking would save that encoder under the 'rt' file name.
if CONFIG['save_model'] and CONFIG['graph_level_objective'] == 'rt':
    if CONFIG['pretrained_model'] is not None:
        # Continued pretraining: name the output after the starting checkpoint plus '_rt'.
        torch.save(encoder_model.state_dict(), CONFIG['pretrained_model'] + "_rt.pth")
    else:
        torch.save(encoder_model.state_dict(), "rt.pth")