In [25]:
import torch

from models import SpKBGATModified, SpKBGATConvOnly
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from copy import deepcopy

import flwr as fl
from flwr.common import Context
from flwr.common.typing import NDArrays, Scalar

from preprocess import (
    read_entity_from_id,
    read_relation_from_id,
    init_embeddings,
    build_data,
)
from create_batch import Corpus
from utils import save_model

import random
from datetime import datetime
import argparse
import os
import sys
import logging
import time
import pickle
from typing import Callable, Dict, Tuple, List
from collections import OrderedDict
from logging import ERROR

data = "FB15k-237"


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    because every non-empty string is truthy; the flag could therefore never
    be switched off from the command line.  This converter recognises the
    usual spellings instead.

    Args:
        value: The raw token handed over by argparse (or an actual bool).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if token in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Build the experiment's argument parser and parse ``sys.argv``.

    The options are grouped into general network settings, GAT settings,
    convolution settings and federated-learning settings.  Defaults that
    depend on the dataset are derived from the module-level ``data`` name.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser()

    # network arguments
    parser.add_argument(
        "-data",
        "--data",
        default=f"./data/{data}/",
        help="data directory",
    )
    parser.add_argument(
        "-e_g", "--epochs_gat", type=int, default=45, help="Number of epochs"
    )
    parser.add_argument(
        "-e_c", "--epochs_conv", type=int, default=3, help="Number of epochs"
    )
    parser.add_argument(
        "-w_gat",
        "--weight_decay_gat",
        type=float,
        default=0.00001,
        help="L2 reglarization for gat",
    )
    parser.add_argument(
        "-w_conv",
        "--weight_decay_conv",
        type=float,
        default=0.000001,
        help="L2 reglarization for conv",
    )
    parser.add_argument(
        "-pre_emb",
        "--pretrained_emb",
        type=_str2bool,  # type=bool would turn "False" into True
        default=True,
        help="Use pretrained embeddings",
    )
    parser.add_argument(
        "-emb_size",
        "--embedding_size",
        type=int,
        default=50,
        help="Size of embeddings (if pretrained not used)",
    )
    parser.add_argument("-l", "--lr", type=float, default=1e-3)
    parser.add_argument("-g2hop", "--get_2hop", type=_str2bool, default=True)
    parser.add_argument("-u2hop", "--use_2hop", type=_str2bool, default=True)
    parser.add_argument("-p2hop", "--partial_2hop", type=_str2bool, default=True)

    # Computed outside the f-string: reusing double quotes inside a
    # double-quoted f-string is a SyntaxError before Python 3.12.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    parser.add_argument(
        "-outfolder",
        "--output_folder",
        default=f"../models/KBGAT/{data}/{timestamp}",
        help="Folder name to save the models.",
    )

    # arguments for GAT
    parser.add_argument(
        "-b_gat", "--batch_size_gat", type=int, default=1360, help="Batch size for GAT"
    )
    parser.add_argument(
        "-neg_s_gat",
        "--valid_invalid_ratio_gat",
        type=int,
        default=2,
        help="Ratio of valid to invalid triples for GAT training",
    )
    parser.add_argument(
        "-drop_GAT",
        "--drop_GAT",
        type=float,
        default=0.3,
        help="Dropout probability for SpGAT layer",
    )
    parser.add_argument(
        "-alpha",
        "--alpha",
        type=float,
        default=0.2,
        help="LeakyRelu alphs for SpGAT layer",
    )
    parser.add_argument(
        "-out_dim",
        "--entity_out_dim",
        type=int,
        nargs="+",
        default=[100, 200],
        help="Entity output embedding dimensions",
    )
    parser.add_argument(
        "-h_gat",
        "--nheads_GAT",
        type=int,
        nargs="+",
        default=[2, 2],
        help="Multihead attention SpGAT",
    )
    parser.add_argument(
        "-margin", "--margin", type=float, default=1, help="Margin used in hinge loss"
    )

    # arguments for convolution network
    parser.add_argument(
        "-b_conv",
        "--batch_size_conv",
        type=int,
        default=16,
        help="Batch size for conv",
    )
    parser.add_argument(
        "-alpha_conv",
        "--alpha_conv",
        type=float,
        default=0.2,
        help="LeakyRelu alphas for conv layer",
    )
    parser.add_argument(
        "-neg_s_conv",
        "--valid_invalid_ratio_conv",
        type=int,
        default=40,
        help="Ratio of valid to invalid triples for convolution training",
    )
    parser.add_argument(
        "-o",
        "--out_channels",
        type=int,
        default=50,
        help="Number of output channels in conv layer",
    )
    parser.add_argument(
        "-drop_conv",
        "--drop_conv",
        type=float,
        default=0.3,
        help="Dropout probability for convolution layer",
    )

    # fed args -- these are counts, so they are parsed as int (matching the
    # integer defaults) instead of float, and each gets its own help text
    # rather than the copy-pasted dropout description.
    parser.add_argument(
        "--num_rounds",
        type=int,
        default=50,
        help="Number of federated learning rounds",
    )
    parser.add_argument(
        "--sample_clients",
        type=int,
        default=10,
        help="Number of clients sampled in each round",
    )
    parser.add_argument(
        "--num_clients",
        type=int,
        default=20,
        help="Total number of federated clients",
    )

    # Placeholder that absorbs the kernel connection-file option
    # (``--f=<path>``) some notebook front ends put on sys.argv.
    parser.add_argument("--f")

    return parser.parse_args()


# Parse the options once at module level; this `args` namespace is what the
# later cells hand to load_data() and evaluate_conv().
args = parse_args()


def load_data(args):
    """Read the dataset under ``args.data`` and assemble the training corpus.

    Embeddings come from ``entity2vec.txt`` / ``relation2vec.txt`` inside the
    data directory when ``args.pretrained_emb`` is truthy; otherwise they are
    drawn from a standard normal with ``args.embedding_size`` columns.

    Args:
        args: Parsed options; reads ``data``, ``pretrained_emb``,
            ``embedding_size``, ``batch_size_gat``,
            ``valid_invalid_ratio_gat`` and ``get_2hop``.

    Returns:
        tuple: ``(corpus, entity_embeddings, relation_embeddings)`` where the
        two embedding objects are ``torch.FloatTensor`` instances.
    """
    dataset = build_data(args.data, is_unweigted=False, directed=True)
    (
        train_split,
        valid_split,
        test_split,
        ent2id,
        rel2id,
        head_tail_selector,
        train_entities,
    ) = dataset

    if not args.pretrained_emb:
        # Random initialisation, one row per known entity / relation.
        ent_emb = np.random.randn(len(ent2id), args.embedding_size)
        rel_emb = np.random.randn(len(rel2id), args.embedding_size)
        print("Initialised relations and entities randomly")
    else:
        entity_file = os.path.join(args.data, "entity2vec.txt")
        relation_file = os.path.join(args.data, "relation2vec.txt")
        ent_emb, rel_emb = init_embeddings(entity_file, relation_file)
        print("Initialised relations and entities from TransE")

    # Corpus takes its arguments positionally, so the order below matters.
    corpus_inputs = (
        args,
        train_split,
        valid_split,
        test_split,
        ent2id,
        rel2id,
        head_tail_selector,
        args.batch_size_gat,
        args.valid_invalid_ratio_gat,
        train_entities,
        args.get_2hop,
    )
    corpus = Corpus(*corpus_inputs)

    ent_tensor = torch.FloatTensor(ent_emb)
    rel_tensor = torch.FloatTensor(rel_emb)
    return corpus, ent_tensor, rel_tensor


# Build the corpus and the initial embedding tensors from the parsed options.
Corpus_, entity_embeddings, relation_embeddings = load_data(args)

# Single location of the cached 2-hop neighbourhood.  os.path.join is used (as
# in load_data) so a trailing separator in args.data does not yield "//".
file = os.path.join(args.data, "2hop.pickle")

if args.get_2hop:
    # Persist the 2-hop neighbours held by the corpus.
    with open(file, "wb") as handle:
        pickle.dump(
            Corpus_.node_neighbors_2hop, handle, protocol=pickle.HIGHEST_PROTOCOL
        )


if args.use_2hop:
    print("Opening node_neighbors pickle object")
    # NOTE(review): pickle.load can execute arbitrary code while
    # deserialising -- only point args.data at directories you trust.
    with open(file, "rb") as handle:
        node_neighbors_2hop = pickle.load(handle)

# Independent copies of the initial embeddings, taken before anything else
# touches the originals.
entity_embeddings_copied = deepcopy(entity_embeddings)
relation_embeddings_copied = deepcopy(relation_embeddings)

print(
    "Initial entity dimensions {} , relation dimensions {}".format(
        entity_embeddings.size(), relation_embeddings.size()
    )
)
# %%

# True when PyTorch can see a CUDA device.
CUDA = torch.cuda.is_available()

number of unique_entities -> 14505
number of unique_entities -> 9809
number of unique_entities -> 10348
Initialised relations and entities from TransE
Graph created
length of graph keys is  13781
time taken  440.8909635543823
length of neighbors dict is  13222
Total triples count 310116, training triples 272115, validation_triples 17535, test_triples 20466


  np.float32
  


Opening node_neighbors pickle object


KeyboardInterrupt: 

In [None]:
def evaluate_conv(args, unique_entities, checkpoint_dir="./checkpoints/fb/out/"):
    """Restore a trained conv-only model and run the corpus validation on it.

    The weights are read from
    ``<checkpoint_dir>/conv/trained_<args.epochs_conv - 1>.pth``.  The model
    is placed on the GPU when one is available and on the CPU otherwise; the
    checkpoint is mapped to that same device, so a file saved on a GPU can
    also be opened on a CPU-only machine.

    Args:
        args: Parsed options; reads ``entity_out_dim``, ``drop_GAT``,
            ``drop_conv``, ``alpha``, ``alpha_conv``, ``nheads_GAT``,
            ``out_channels`` and ``epochs_conv``.
        unique_entities: Forwarded unchanged to
            ``Corpus_.get_validation_pred``.
        checkpoint_dir: Directory that contains the ``conv`` checkpoint
            folder.  Defaults to the location that was previously hard-coded.

    Returns:
        None.  Whatever ``Corpus_.get_validation_pred`` reports is its own
        side effect.
    """
    # Previously the model was moved with an unconditional ``.cuda()``, which
    # raises on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Relies on the module-level ``entity_embeddings`` / ``relation_embeddings``.
    # args.entity_out_dim is deliberately passed twice, as in the original call.
    model_conv = SpKBGATConvOnly(
        entity_embeddings,
        relation_embeddings,
        args.entity_out_dim,
        args.entity_out_dim,
        args.drop_GAT,
        args.drop_conv,
        args.alpha,
        args.alpha_conv,
        args.nheads_GAT,
        args.out_channels,
    )

    checkpoint_path = os.path.join(
        checkpoint_dir, "conv", f"trained_{args.epochs_conv - 1}.pth"
    )
    state_dict = torch.load(checkpoint_path, map_location=device)
    # strict=False: keys that do not line up between checkpoint and model are
    # skipped instead of raising.
    model_conv.load_state_dict(state_dict, strict=False)

    model_conv.to(device)
    model_conv.eval()

    # NOTE(review): get_validation_pred may itself assume CUDA tensors --
    # confirm in create_batch.Corpus before relying on the CPU path.
    with torch.no_grad():
        Corpus_.get_validation_pred(args, model_conv, unique_entities)


# Run the evaluation, passing the corpus's `unique_entities_train` collection
# as the entity set.
evaluate_conv(args, Corpus_.unique_entities_train)

Sampled indices
test set length  20466
0
sample -  1 1
1
sample -  45 71
2
sample -  2 1
3
sample -  26 572
4
sample -  7 254
5
sample -  1938 2490
6
sample -  15 590
7
sample -  1 47
8
sample -  384 7
9
sample -  60 209
10
sample -  8 2
11
sample -  1029 518
12
sample -  29 10
13
sample -  45 1
14
sample -  2998 562
15
sample -  777 28
16
sample -  21 1
17
sample -  11 390
18
sample -  1 41
19
sample -  1 44
20
sample -  31 9
21
sample -  4 2
22
sample -  25 43
23
sample -  1 2
24
sample -  2 3
25
sample -  100 2
26
sample -  1078 1414
27
sample -  182 21
28
sample -  1997 145
29
sample -  224 6035
30
sample -  17 9
31
sample -  394 8
32
sample -  2 2
33
sample -  4 11
34
sample -  2 36
35
sample -  243 17
36
sample -  7 43
37
sample -  70 1359
38
sample -  43 47
39
sample -  48 1
40
sample -  4 2
41
sample -  214 122
42
sample -  14 1
43
sample -  10 967
44
sample -  22 650
45
sample -  24 1
46
sample -  114 5
47
sample -  15 2
48
sample -  385 41
49
sample -  6 106
50
sample -  2 12

KeyboardInterrupt: 