In [1]:
import os
import pickle
import torch
import torch_geometric
from tqdm import tqdm
from operator import itemgetter
from sklearn.utils import compute_class_weight
from torch_geometric.loader import DataLoader
from predict import HGT
from sklearn.model_selection import KFold
from finetuning import *
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Hyper-parameters read by the cross-validation cell further down.
CONFIG = dict(
    num_epochs=20,           # forwarded to train_eval_model(num_epochs=...)
    batch_size=8,            # batch size of the train and test DataLoader
    learning_rate=0.001,     # lr of the Adam optimizer
    num_training_graphs=50,  # training graphs kept per fold; None keeps the whole split
    lr_scheduler=False,      # forwarded to train_eval_model(use_lr_scheduler=...)
    hidden_channels=64,      # hidden size handed to initialize_model
)

In [2]:
def load_graph_by_path(file_name, dataset, setup):
    """Read a single pickled graph from disk.

    The file is looked up at
    ``../static_graphs/<dataset>/<setup>/<file_name>.pickle`` relative to the
    current working directory. The unpickled object is indexed with
    ``'graph'`` and that entry is returned.

    NOTE: ``pickle.load`` can execute arbitrary code while loading; only use
    this on files from a trusted source.
    """
    pickle_path = "../static_graphs/" + "/".join((dataset, setup, file_name)) + ".pickle"
    with open(pickle_path, 'rb') as stream:
        payload = pickle.load(stream)
    return payload['graph']

def load_graph_ids(dataset):
    """Return the ``'graph_ids'`` entry of ``../<dataset>_train_test.pickle``.

    The path is relative to the current working directory.

    NOTE: ``pickle.load`` can execute arbitrary code while loading; only use
    this on files from a trusted source.
    """
    with open('../' + dataset + '_train_test.pickle', 'rb') as stream:
        return pickle.load(stream)['graph_ids']

def load_graphs_by_setup(setup, num_graphs=-1, dataset='politifact'):
    """Load the graphs of one dataset/setup into a list.

    Args:
        setup: Sub-directory name passed on to ``load_graph_by_path``.
        num_graphs: Number of leading graph ids to load; ``-1`` loads all ids
            returned by ``load_graph_ids``.
        dataset: Dataset name used for both the id file and the graph files.

    Returns:
        A list of loaded graphs, in id order. For each graph the ``x``
        attribute of the ``'tweet'`` and ``'user'`` stores is cut down to
        its first 768 columns; other stores are left as loaded.
    """
    ids = load_graph_ids(dataset=dataset)
    selected_ids = ids if num_graphs == -1 else ids[:num_graphs]

    graphs = []
    for gid in tqdm(selected_ids):
        loaded = load_graph_by_path(file_name=gid, dataset=dataset, setup=setup)
        # Same column truncation for both node types.
        for node_type in ('tweet', 'user'):
            loaded[node_type].x = loaded[node_type].x[:, :768]
        graphs.append(loaded)
    return graphs

In [3]:
def initialize_model(metadata, models_folder, pretrained_name, use_pretrained=False, hidden_channels=64):
    """Build an ``HGT`` encoder wrapped with a linear classification head.

    Args:
        metadata: Passed straight to ``HGT(metadata=...)``; the caller in this
            notebook supplies ``graph.metadata()``.
        models_folder: Directory that holds the pretrained checkpoint.
        pretrained_name: Checkpoint file name without the ``.pth`` suffix.
        use_pretrained: When true, the state dict stored at
            ``<models_folder>/<pretrained_name>.pth`` is loaded into the
            encoder before it is wrapped.
        hidden_channels: Hidden size given to ``HGT``; the head's input width
            is ``3 * hidden_channels``.

    Returns:
        A ``torch.nn.Module`` with ``encoder`` and ``decoder`` attributes whose
        forward pass yields 2 output channels per pooled row.
    """
    model_encoder = HGT(hidden_channels=hidden_channels, metadata=metadata)

    if use_pretrained:
        model_path = os.path.join(models_folder, pretrained_name + '.pth')
        # Without map_location, a checkpoint saved from a GPU cannot be
        # restored on a machine that has no CUDA device. Loading to CPU is
        # safe here: load_state_dict copies the values into the encoder's
        # existing parameters, so the encoder's own device is unaffected.
        state = torch.load(model_path, map_location='cpu')
        model_encoder.load_state_dict(state)

    class PretrainedModel(torch.nn.Module):
        """Encoder followed by per-type mean pooling and a linear decoder."""

        def __init__(self, encoder, hidden_channels=hidden_channels, out_channels=2):
            super().__init__()
            # torch.manual_seed(42)

            self.encoder = encoder
            # Three pooled vectors (article, tweet, user) are concatenated,
            # hence the factor of 3 on the input width.
            self.decoder = torch.nn.Linear(hidden_channels * 3, out_channels)

        def forward(self, x_dict, edge_index_dict, batch_dict):
            # NOTE(review): `global_mean_pool` and `F` are not imported
            # explicitly in this notebook; presumably they arrive through
            # `from finetuning import *` -- confirm.
            x_dict = self.encoder(x_dict, edge_index_dict, batch_dict)
            x_dict = {key: global_mean_pool(x, batch_dict[key]) for key, x in x_dict.items()}
            x = torch.cat([x_dict['article'], x_dict['tweet'], x_dict['user']], dim=1)
            # Dropout is only active in training mode.
            x = F.dropout(x, p=0.2, training=self.training)
            return self.decoder(x)

    return PretrainedModel(model_encoder)

In [22]:
# Dataset selection: swap the two lines below to run on politifact instead.
# all_graphs = load_graphs_by_setup('all_data', -1, dataset='politifact')
all_graphs = load_graphs_by_setup('all_data', -1, dataset='gossipcop')

# NOTE(review): `np`, `DEVICE` and `train_eval_model` are not imported
# explicitly in this notebook; presumably they come from
# `from finetuning import *` -- confirm.

# NOTE(review): KFold does not shuffle by default, so folds follow the order of
# the graph ids. If that order correlates with the label, consider
# KFold(n_splits=5, shuffle=True, random_state=<seed>).
kf = KFold(n_splits=5)

train_splits = []
test_splits = []

for train_index, test_index in kf.split(all_graphs):
    print(30*"*")
    # Explicit tuples instead of itemgetter(*idx): itemgetter returns a bare
    # item (not a tuple) when given a single index.
    X_train = tuple(all_graphs[i] for i in train_index)
    X_test = tuple(all_graphs[i] for i in test_index)
    print("Num train:", len(X_train), "Num test:", len(X_test))
    train_splits.append(X_train)
    test_splits.append(X_test)

# Per-fold metrics, averaged after the loop.
acc_all = []
p_all = []
r_all = []
f1_all = []

for idx, val in enumerate(train_splits):
    # Optionally restrict training to the first N graphs of the fold.
    if CONFIG['num_training_graphs'] is None:
        X_train = val
    else:
        X_train = val[:CONFIG['num_training_graphs']]

    # Evaluate on the held-out part of THIS fold. Previously `X_test` was the
    # variable left over from the split loop above, so every fold was scored
    # on the last fold's test set and `test_splits` was never used.
    X_test = test_splits[idx]

    y_tensors = [graph['article'].y for graph in X_train]

    # Balanced class weights computed from the labels of the training subset.
    class_weights = torch.tensor(compute_class_weight(class_weight='balanced', classes=np.asarray([0, 1]),
                                                      y=torch.cat(y_tensors).cpu().detach().numpy()),
                                 dtype=torch.float32)
    # Tensor.to returns a new tensor; the result has to be re-bound.
    class_weights = class_weights.to(DEVICE)

    train_loader = DataLoader(X_train, batch_size=CONFIG['batch_size'], shuffle=True)
    test_loader = DataLoader(X_test, batch_size=CONFIG['batch_size'], shuffle=False)

    # A fresh model (with pretrained encoder weights) for every fold.
    model = initialize_model(metadata=all_graphs[0].metadata(),
                             use_pretrained=True,
                             models_folder='models_separate_pureGNN', #'models_full_gos',
                             pretrained_name='pretrained_nodes_1',
                             hidden_channels=CONFIG['hidden_channels'])

    optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['learning_rate'])
    criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    criterion.to(DEVICE)
    model, acc, precision, recall, f1 = train_eval_model(model=model, train_loader=train_loader, test_loader=test_loader,
                                                         loss_fct=criterion, optimizer=optimizer, num_epochs=CONFIG['num_epochs'],
                                                         verbose=1, use_lr_scheduler=CONFIG['lr_scheduler'])
    acc_all.append(acc)
    p_all.append(precision)
    r_all.append(recall)
    f1_all.append(f1)

# Per-fold values followed by their mean.
print("ACC", acc_all, sum(acc_all) / len(acc_all))
print("P", p_all, sum(p_all) / len(p_all))
print("R", r_all, sum(r_all) / len(r_all))
print("F1", f1_all, sum(f1_all) / len(f1_all))

100%|██████████| 12214/12214 [00:39<00:00, 312.48it/s]


******************************
Num train: 9771 Num test: 2443
******************************
Num train: 9771 Num test: 2443
******************************
Num train: 9771 Num test: 2443
******************************
Num train: 9771 Num test: 2443
******************************
Num train: 9772 Num test: 2442
Epoch: 001, Loss: 0.6949, Train Acc: 0.7600, Train F1: 0.4318, Test Acc: 0.8195, Test F1: 0.4504
Epoch: 002, Loss: 0.6944, Train Acc: 0.7600, Train F1: 0.4318, Test Acc: 0.8195, Test F1: 0.4504
Epoch: 003, Loss: 0.6981, Train Acc: 0.7600, Train F1: 0.4318, Test Acc: 0.8195, Test F1: 0.4504
Epoch: 004, Loss: 0.6613, Train Acc: 0.7600, Train F1: 0.4318, Test Acc: 0.8207, Test F1: 0.4661
Epoch: 005, Loss: 0.7008, Train Acc: 0.7600, Train F1: 0.4318, Test Acc: 0.8195, Test F1: 0.4526
Epoch: 006, Loss: 0.6853, Train Acc: 0.8200, Train F1: 0.7717, Test Acc: 0.8133, Test F1: 0.7339
Epoch: 007, Loss: 0.6898, Train Acc: 0.8000, Train F1: 0.7520, Test Acc: 0.8060, Test F1: 0.7397
Epoch: 008,