In [None]:
using Pkg;
#=
Pkg.add("Genie")
Pkg.add("Images")
Pkg.add("Flux")
Pkg.add("MLUtils")
Pkg.add("ArgParse")
Pkg.add("LoggingExtras")
Pkg.add("Dates")
Pkg.add("Printf")
Pkg.add("TensorBoardLogger")
Pkg.add("JLD2");
Pkg.add("BSON");
Pkg.add("Pickle")
Pkg.add("ProgressBars")
Pkg.add("Distributions")
Pkg.add("Revise")
Pkg.add("Zygote")
=#

In [1]:
using Pickle;
using ArgParse;
using Random;
using Logging, LoggingExtras, TensorBoardLogger;
using Dates;
using Printf;
using ProgressBars;
using Flux;
using MLUtils;
using Revise;

In [2]:
# Work from the project root so that relative paths such as "dataset/..." and "src/..."
# resolve. NOTE(review): machine-specific absolute path — adjust when running elsewhere.
cd("/sata/sdb5/julia/pro_kgreasoning")

# Dataset directory / model folder and the pickle file names used for each split.
f_dir = "dataset"
f_model = "FB15k-betae"
f_train_queries = "train-queries.pkl"
f_train_answers = "train-answers.pkl"
f_valid_queries = "valid-queries.pkl"
f_valid_hard_answers = "valid-hard-answers.pkl"
f_valid_easy_answers = "valid-easy-answers.pkl"
f_test_queries = "test-queries.pkl"
f_test_hard_answers = "test-hard-answers.pkl"
f_test_easy_answers = "test-easy-answers.pkl"

# Query structure (nested tuples of "e"/"r"/"n"/"u" markers) => short task name.
# The recurring sub-structures are named locally inside the `let` so every entry is
# easier to read; the resulting keys are the same nested tuples as when spelled out.
query_name_dict = let
    proj1 = ("e", ("r",))            # anchor followed by one relation
    proj2 = ("e", ("r", "r"))        # anchor followed by two relations
    proj3 = ("e", ("r", "r", "r"))   # anchor followed by three relations
    neg1 = ("e", ("r", "n"))         # one relation, then "n"
    neg2 = ("e", ("r", "r", "n"))    # two relations, then "n"

    Dict{Tuple, String}(
        proj1 => "1p",
        proj2 => "2p",
        proj3 => "3p",
        (proj1, proj1) => "2i",
        (proj1, proj1, proj1) => "3i",
        ((proj1, proj1), ("r",)) => "ip",
        (proj2, proj1) => "pi",
        (proj1, neg1) => "2in",
        (proj1, proj1, neg1) => "3in",
        ((proj1, neg1), ("r",)) => "inp",
        (proj2, neg1) => "pin",
        (neg2, proj1) => "pni",
        (proj1, proj1, ("u",)) => "2u-DNF",
        ((proj1, proj1, ("u",)), ("r",)) => "up-DNF",
        ((neg1, neg1), ("n",)) => "2u-DM",
        ((neg1, neg1), ("n", "r")) => "up-DM",
    )
end

# Reverse lookup (task name => query structure) and the list of every known task name.
name_query_dict = Dict{String, Tuple}(
    task_name => structure for (structure, task_name) in query_name_dict
)
all_tasks = collect(keys(name_query_dict));



In [3]:
# Project sources: shared helpers and the KGDataset module.
include("src/utils.jl")
include("src/dataloader.jl")

using .KGDataset

# Run configuration, one option per line.
args = Dict{String, Any}(
    "geo" => "beta",
    "test_log_steps" => 1000,
    "tasks" => "1p.2p.3p.2i.3i.ip.pi.2in.3in.inp.pin.pni.2u.up",
    "batch_size" => 2,
    "evaluate_union" => "DNF",
    "nentity" => 0,
    "nrelation" => 0,
    "print_on_screen" => true,
    "cpu" => 1,
    "valid" => false,
    "valid_steps" => 15000,
    "train" => true,
    "negative_sample_size" => 32,
    "checkpoint_path" => nothing,
    "prefix" => nothing,
    "cuda" => false,
    "warm_up_steps" => nothing,
    "hidden_dim" => 800,
    "beta_mode" => "(1600,2)",
    "learning_rate" => 0.0001,
    "box_mode" => "(nothing,0.02)",
    "data_path" => "dataset/FB15k-betae",
    "max_steps" => 450001,
    "save_checkpoint_steps" => 50000,
    "save_path" => ".",
    "test" => false,
    "gamma" => 24.0,
    "log_steps" => 100,
    "seed" => 0,
    "test_batch_size" => 1,
)

# Load every split; the eight results are bound in the order load_data returns them.
(
    train_queries,
    train_answers,
    valid_queries,
    valid_hard_answers,
    valid_easy_answers,
    test_queries,
    test_hard_answers,
    test_easy_answers,
) = KGDataset.load_data(args, name_query_dict);

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mloading data....


load_data: deleteing structure...:
 ((("e", ("r", "n")), ("e", ("r", "n"))), ("n", "r"))
load_data: deleteing structure...:
 ((("e", ("r", "n")), ("e", ("r", "n"))), ("n",))


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mload data....Done


In [5]:
# Flatten the per-structure training queries into one list of (query, structure) pairs.
flatten_queries = flatten_query(train_queries)
t = length(flatten_queries)
println(" flatten_queries length: $(t)")

# NOTE(review): the query count `t` is passed as both the 3rd and the 4th argument. If
# those positions are the entity / relation counts, this is probably not what is wanted
# (the negative samples would then be drawn from the wrong range) — confirm against the
# TrainDataset constructor in src/dataloader.jl.
train_dataset = KGDataset.TrainDataset(flatten_queries, train_answers, t, t, 32)
train_data_loader = MLUtils.DataLoader(train_dataset, batchsize = 4, collate = true, shuffle = false);

# Preview the first four batches only. `Iterators.take` replaces the hand-maintained
# counter: reassigning a global counter inside a top-level `for` only works under the
# notebook/REPL soft-scope rule and fails once this cell is run as a plain script.
for d in Iterators.take(train_data_loader, 4)
    println(d)
    println("-----------------")
end

 flatten_queries length: 1505405


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset [1] -> ((2674, (218, 383, 347)), ("e", ("r", "r", "r"))) answer: Set(Any[7460, 6039, 1830, 3877, 6746, 3140, 6318, 783, 2979, 3284, 2461, 182, 2996, 9228])
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset tail 9228 subsampling_weight: 18


getobs set mask at 24 int  24 max 64
value: 1830
getobs set mask at 45 int  45 max 64
value: 2461
getobs set mask at 28 int  28 max 64
value: 2996
getobs set mask at 59 int  59 max 64
value: 9228
getobs set mask at 23 int  23 max 64
value: 783
getobs set mask at 2 int  2 max 64
value: 7460
getobs set mask at 37 int  37 max 64
value: 6746
getobs set mask at 56 int  56 max 64
value: 6318
getobs set mask at 25 int  25 max 64
value: 182
getobs set mask at 41 int  41 max 64
value: 9228
getobs set mask at 53 int  53 max 64
value: 2461
getobs set mask at 22 int  22 max 64
value: 3284
getobs set mask at 8 int  8 max 64
value: 3140
getobs set mask at 12 int  12 max 64
value: 783
getobs set mask at 34 int  34 max 64
value: 2979
getobs set mask at 47 int  47 max 64
value: 783
getobs set mask at 22 int  22 max 64
value: 3284
getobs set mask at 30 int  30 max 64
value: 6318
getobs set mask at 26 int  26 max 64
value: 3877
getobs set mask at 55 int  55 max 64
value: 3877
getobs set mask at 24 int  2

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mgetobs one item -------------------------------------------------
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset [2] -> ((253, (1219, 1074, 157)), ("e", ("r", "r", "r"))) answer: Set(Any[5553, 3927, 5520, 2014, 1954, 9625, 3049, 13006, 4865, 1659, 5767, 2775, 7077, 11998, 11729, 7831, 7726, 3295, 4495, 13524])
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset tail 4495 subsampling_weight: 24


getobs set mask at 37 int  37 max 64
value: 5553
getobs set mask at 35 int  35 max 64
value: 5553
getobs set mask at 28 int  28 max 64
value: 1659
getobs set mask at 14 int  14 max 64
value: 7726
getobs set mask at 29 int  29 max 64
value: 3295
getobs set mask at 40 int  40 max 64
value: 11998
getobs set mask at 61 int  61 max 64
value: 7726
getobs set mask at 47 int  47 max 64
value: 3049
getobs set mask at 4 int  4 max 64
value: 7077
getobs set mask at 22 int  22 max 64
value: 4865
getobs set mask at 59 int  59 max 64
value: 13006
getobs set mask at 21 int  21 max 64
value: 3927
getobs set mask at 45 int  45 max 64
value: 11998
getobs set mask at 47 int  47 max 64
value: 2775
getobs set mask at 44 int  44 max 64
value: 9625
getobs set mask at 30 int  30 max 64
value: 13006
getobs set mask at 41 int  41 max 64
value: 7726
getobs set mask at 60 int  60 max 64
value: 5553
getobs set mask at 42 int  42 max 64
value: 13006
getobs set mask at 56 int  56 max 64
value: 5767
getobs set mask a

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mgetobs one item -------------------------------------------------
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset [3] -> ((7206, (65, 1509, 1650)), ("e", ("r", "r", "r"))) answer: Set(Any[6037, 3531, 10444, 8179])
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset tail 6037 subsampling_weight: 8


getobs set mask at 43 int  43 max 64
value: 10444
getobs set mask at 42 int  42 max 64
value: 10444
getobs set mask at 32 int  32 max 64
value: 10444
getobs set mask at 11 int  11 max 64
value: 3531
getobs set mask at 55 int  55 max 64
value: 3531
getobs set mask at 3 int  3 max 64
value: 8179
getobs set mask at 4 int  4 max 64
value: 10444
getobs set mask at 42 int  42 max 64
value: 8179
getobs set mask at 59 int  59 max 64
value: 3531
getobs set mask at 45 int  45 max 64
value: 8179
getobs set mask at 14 int  14 max 64
value: 8179
getobs set mask at 44 int  44 max 64
value: 8179
getobs set mask at 51 int  51 max 64
value: 8179
getobs set mask at 36 int  36 max 64
value: 3531
getobs set mask at 2 int  2 max 64
value: 6037
getobs set mask at 32 int  32 max 64
value: 8179
getobs set mask at 28 int  28 max 64
value: 10444
getobs set mask at 1 int  1 max 64
value: 6037
getobs set mask at 42 int  42 max 64
value: 3531
getobs set mask at 23 int  23 max 64
value: 10444
getobs set mask at 27 

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mgetobs one item -------------------------------------------------
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset [4] -> ((9713, (205, 613, 109)), ("e", ("r", "r", "r"))) answer: Set(Any[697, 1308, 1659, 7941])
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset tail 1308 subsampling_weight: 8


getobs set mask at 57 int  57 max 64
value: 7941
getobs set mask at 37 int  37 max 64
value: 1308
getobs set mask at 61 int  61 max 64
value: 1659
getobs set mask at 25 int  25 max 64
value: 7941
getobs set mask at 20 int  20 max 64
value: 7941
getobs set mask at 20 int  20 max 64
value: 1308
getobs set mask at 31 int  31 max 64
value: 1659
getobs set mask at 45 int  45 max 64
value: 1659
getobs set mask at 35 int  35 max 64
value: 1308
getobs set mask at 12 int  12 max 64
value: 1659
getobs set mask at 40 int  40 max 64
value: 697
getobs set mask at 16 int  16 max 64
value: 7941
getobs set mask at 11 int  11 max 64
value: 7941
getobs set mask at 50 int  50 max 64
value: 1659
getobs set mask at 16 int  16 max 64
value: 1308
getobs set mask at 5 int  5 max 64
value: 1659
getobs set mask at 15 int  15 max 64
value: 1308
getobs set mask at 21 int  21 max 64
value: 7941
getobs set mask at 31 int  31 max 64
value: 1308
getobs set mask at 13 int  13 max 64
value: 1308
getobs set mask at 3 in

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mgetobs one item -------------------------------------------------


([9228.0 4495.0 6037.0 1308.0], [965653 480458 273664 25168; 1178975 1365843 619868 1088278; 1337604 696182 654479 186166; 443539 705501 1098602 1283478; 395299 184725 1305820 1222434; 229330 1142227 1065141 1080549; 816757 976524 1168553 310292; 1088249 90436 967398 949089; 1328058 32492 1123873 890133; 333547 868970 166899 1315969; 500689 103834 692535 771453; 320349 87735 161967 318634; 2167 687086 696853 808506; 745036 1014506 1074453 937721; 1154488 1323650 33505 160412; 74993 886751 863601 1137579; 425462 1176962 1470113 617559; 1414316 100417 781567 195934; 71982 1148382 1406915 305586; 502436 368545 554920 435478; 95728 696587 600681 1178552; 522957 326491 288654 188909; 1001501 850622 775914 971016; 9371 752811 488182 1465378; 981742 317860 773213 1079195; 328994 821568 1036272 586168; 1363535 377234 502640 1294739; 17814 1265809 1323464 1458487; 497610 1026423 35772 697429; 1217278 1092085 1074311 1126622; 125377 317728 1032631 150756; 534495 777030 235538 949983], [0.2357022

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset [5] -> ((3, (1310, 1195, 1315)), ("e", ("r", "r", "r"))) answer: Set(Any[4267, 11764, 631, 677, 13535, 6931, 13925, 2029, 2892, 5247, 2065, 8726, 2683, 6047, 5253, 3532, 5980])
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset tail 4267 subsampling_weight: 21


getobs set mask at 15 int  15 max 64
value: 6931
getobs set mask at 30 int  30 max 64
value: 5253
getobs set mask at 6 int  6 max 64
value: 13535
getobs set mask at 41 int  41 max 64
value: 677
getobs set mask at 22 int  22 max 64
value: 677
getobs set mask at 50 int  50 max 64
value: 8726
getobs set mask at 1 int  1 max 64
value: 631
getobs set mask at 28 int  28 max 64
value: 2029
getobs set mask at 5 int  5 max 64
value: 5247
getobs set mask at 63 int  63 max 64
value: 11764
getobs set mask at 62 int  62 max 64
value: 5980
getobs set mask at 30 int  30 max 64
value: 2065
getobs set mask at 13 int  13 max 64
value: 13535
getobs set mask at 11 int  11 max 64
value: 4267
getobs set mask at 58 int  58 max 64
value: 13925
getobs set mask at 27 int  27 max 64
value: 6047
getobs set mask at 3 int  3 max 64
value: 2683
getobs set mask at 23 int  23 max 64
value: 4267
getobs set mask at 19 int  19 max 64
value: 4267
getobs set mask at 22 int  22 max 64
value: 5980
getobs set mask at 39 int  

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mgetobs one item -------------------------------------------------
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset [6] -> ((2535, (161, 336, 700)), ("e", ("r", "r", "r"))) answer: Set(Any[13133, 8095, 6459, 719, 13501, 13106, 2688, 11269, 2915, 9696, 10882, 10989, 14673, 5674, 4115, 14570, 6806, 8053, 2486, 8340, 9603, 11670, 10162, 4634, 177, 10916, 2550, 3655, 111, 10467, 8552, 13184, 12240, 8663, 8382, 6444, 9439, 14111, 10406, 4464, 2506, 4741, 11475, 13890, 4346, 8902, 3341, 3319, 10525, 2767, 920, 278, 1574, 5115, 10081, 7169, 13767, 10714, 14096, 1139, 11635, 14319, 5698, 247, 3955, 8723, 397, 10143, 13538, 11261, 14043, 12643, 14735, 6581, 2766, 6997, 10127, 5717, 154, 6972, 13692, 10870, 8041, 3116, 10826, 2433, 245, 1138, 13858, 12824, 6420, 6925, 13917, 14730, 11524, 14148, 14593, 1967, 4051, 7126, 12220, 9710, 4334, 6637, 11958, 13265, 758, 11406, 12478, 5645, 2044, 11743, 10083, 14295, 706, 176, 6885, 6460, 13259, 179

getobs set mask at 25 int  25 max 64
value: 14673
getobs set mask at 39 int  39 max 64
value: 14155
getobs set mask at 33 int  33 max 64
value: 3700
getobs set mask at 17 int  17 max 64
value: 3639
getobs set mask at 31 int  31 max 64
value: 111
getobs set mask at 59 int  59 max 64
value: 11448
getobs set mask at 59 int  59 max 64
value: 10162
getobs set mask at 53 int  53 max 64
value: 154
getobs set mask at 8 int  8 max 64
value: 658
getobs set mask at 53 int  53 max 64
value: 8165
getobs set mask at 26 int  26 max 64
value: 13089
getobs set mask at 20 int  20 max 64
value: 8616
getobs set mask at 28 int  28 max 64
value: 10882
getobs set mask at 29 int  29 max 64
value: 6237
getobs set mask at 63 int  63 max 64
value: 13636
getobs set mask at 35 int  35 max 64
value: 2687
getobs set mask at 31 int  31 max 64
value: 12849
getobs set mask at 62 int  62 max 64
value: 1018
getobs set mask at 63 int  63 max 64
value: 12643
getobs set mask at 6 int  6 max 64
value: 1713
getobs set mask at

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mgetobs one item -------------------------------------------------
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset [7] -> ((2828, (1471, 1233, 385)), ("e", ("r", "r", "r"))) answer: Set(Any[2143])
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset tail 2143 subsampling_weight: 5


getobs set mask at 48 int  48 max 64
value: 2143
getobs set mask at 28 int  28 max 64
value: 2143
getobs set mask at 38 int  38 max 64
value: 2143
getobs set mask at 9 int  9 max 64
value: 2143
getobs set mask at 39 int  39 max 64
value: 2143
getobs set mask at 11 int  11 max 64
value: 2143
getobs set mask at 32 int  32 max 64
value: 2143
getobs set mask at 16 int  16 max 64
value: 2143
getobs set mask at 6 int  6 max 64
value: 2143
getobs set mask at 26 int  26 max 64
value: 2143
getobs set mask at 52 int  52 max 64
value: 2143
getobs set mask at 44 int  44 max 64
value: 2143
getobs set mask at 55 int  55 max 64
value: 2143
getobs set mask at 1 int  1 max 64
value: 2143
getobs set mask at 2 int  2 max 64
value: 2143
getobs set mask at 58 int  58 max 64
value: 2143
getobs set mask at 10 int  10 max 64
value: 2143
getobs set mask at 11 int  11 max 64
value: 2143
getobs set mask at 30 int  30 max 64
value: 2143
getobs set mask at 5 int  5 max 64
value: 2143
getobs set mask at 2 int  2 ma

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mgetobs one item -------------------------------------------------
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset [8] -> ((4542, (94, 1166, 330)), ("e", ("r", "r", "r"))) answer: Set(Any[7141, 4259, 9332])
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset tail 9332 subsampling_weight: 7


getobs set mask at 52 int  52 max 64
value: 4259
getobs set mask at 41 int  41 max 64
value: 9332
getobs set mask at 27 int  27 max 64
value: 7141
getobs set mask at 22 int  22 max 64
value: 9332
getobs set mask at 63 int  63 max 64
value: 9332
getobs set mask at 12 int  12 max 64
value: 7141
getobs set mask at 8 int  8 max 64
value: 9332
getobs set mask at 8 int  8 max 64
value: 4259
getobs set mask at 44 int  44 max 64
value: 4259
getobs set mask at 49 int  49 max 64
value: 9332
getobs set mask at 20 int  20 max 64
value: 7141
getobs set mask at 31 int  31 max 64
value: 4259
getobs set mask at 7 int  7 max 64
value: 9332
getobs set mask at 5 int  5 max 64
value: 7141
getobs set mask at 18 int  18 max 64
value: 4259
getobs set mask at 26 int  26 max 64
value: 9332
getobs set mask at 35 int  35 max 64
value: 7141
getobs set mask at 50 int  50 max 64
value: 9332
getobs set mask at 33 int  33 max 64
value: 9332
getobs set mask at 25 int  25 max 64
value: 4259
getobs set mask at 37 int  3

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mgetobs one item -------------------------------------------------
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset [9] -> ((10531, (37, 460, 684)), ("e", ("r", "r", "r"))) answer: Set(Any[12704, 1144, 6459, 10157, 8232, 3424, 983, 9062, 6836, 1882, 11594, 4001, 4115, 12071, 1381, 14017, 12255, 4414, 8758, 12098, 5917, 14401, 10593, 14832, 6505, 14835, 2667, 12, 10916, 9082, 75, 3484, 3655, 111, 715, 41, 11647, 10483, 11803, 12483, 11193, 13661, 13176, 14309, 3572, 4535, 5324, 414, 4185, 3341, 3521, 7383, 2767, 9252, 2670, 7611, 3651, 278, 14185, 7128, 4024, 655, 9277, 13399, 14065, 3303, 247, 8639, 1688, 8723, 397, 2695, 13149, 13431, 8502, 6581, 258, 14817, 4844, 5570, 2766, 3338, 5219, 4713, 14563, 3788, 2114, 13095, 354, 7391, 9698, 10114, 11342, 3116, 3628, 1034, 5874, 482, 416, 6420, 4216, 11253, 3590, 1071, 8110, 5749, 2738, 4488, 5312, 9139, 13871, 7905, 14361, 4080, 4280, 5031, 7642, 6637, 13156, 9276, 3381, 12333, 8467, 1

getobs set mask at 40 int  40 max 64
value: 14156
getobs set mask at 50 int  50 max 64
value: 4488
getobs set mask at 13 int  13 max 64
value: 3942
getobs set mask at 63 int  63 max 64
value: 10483
getobs set mask at 58 int  58 max 64
value: 10593
getobs set mask at 5 int  5 max 64
value: 5942
getobs set mask at 25 int  25 max 64
value: 715
getobs set mask at 12 int  12 max 64
value: 4089
getobs set mask at 48 int  48 max 64
value: 8036
getobs set mask at 7 int  7 max 64
value: 5804


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mgetobs one item -------------------------------------------------
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset [10] -> ((12842, (356, 289, 697)), ("e", ("r", "r", "r"))) answer: Set(Any[4089, 8036, 9701, 2554, 541, 6931, 2214, 5804])
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset tail 2554 subsampling_weight: 12


getobs set mask at 63 int  63 max 64
value: 2214
getobs set mask at 47 int  47 max 64
value: 4089
getobs set mask at 33 int  33 max 64
value: 9701
getobs set mask at 31 int  31 max 64
value: 2554
getobs set mask at 9 int  9 max 64
value: 5804
getobs set mask at 56 int  56 max 64
value: 2214
getobs set mask at 48 int  48 max 64
value: 5804
getobs set mask at 17 int  17 max 64
value: 5804
getobs set mask at 59 int  59 max 64
value: 2554
getobs set mask at 25 int  25 max 64
value: 541
getobs set mask at 41 int  41 max 64
value: 6931
getobs set mask at 14 int  14 max 64
value: 4089
getobs set mask at 36 int  36 max 64
value: 2554
getobs set mask at 47 int  47 max 64
value: 4089
getobs set mask at 58 int  58 max 64
value: 4089
getobs set mask at 49 int  49 max 64
value: 5804
getobs set mask at 43 int  43 max 64
value: 9701
getobs set mask at 47 int  47 max 64
value: 2554
getobs set mask at 58 int  58 max 64
value: 2554
getobs set mask at 44 int  44 max 64
value: 4089
getobs set mask at 19 i

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mgetobs one item -------------------------------------------------
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset [11] -> ((9077, (280, 313, 272)), ("e", ("r", "r", "r"))) answer: Set(Any[2102, 1194, 1251, 4589, 4774, 4937, 2392, 30, 4245, 5306, 1312, 12185, 6832, 1598, 6436, 185, 8096, 5766, 1166, 9483, 877, 11888, 9039, 6914, 333, 8145, 1899, 4293, 4813, 4927, 3200, 3970, 3314, 3129, 6082, 2539, 4294, 850, 7017, 188, 3884, 8011, 8015, 7734, 3793, 2275, 1597, 1049, 1602, 1001, 3440, 1212, 6223, 2053, 1171, 3564, 211, 6215, 1322, 4203, 3336, 9188, 3261, 1281, 4077, 5291, 3128, 3199, 4144, 31, 1728, 3335, 7710, 4059, 653, 2112, 10243, 814, 902, 3563, 2414, 212, 4222, 1791, 2990, 11334, 815, 4669, 2748, 994, 163, 1790, 2121, 7001, 6671, 2404, 4529, 335, 509, 5626, 1709, 6828, 2111, 2413, 11505, 8652, 2675, 6477, 334, 7544, 508, 2054, 2059, 4703, 1723, 741, 10381])
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset tail 9039

getobs set mask at 37 int  37 max 64
value: 741
getobs set mask at 7 int  7 max 64
value: 741
getobs set mask at 63 int  63 max 64
value: 11888
getobs set mask at 46 int  46 max 64
value: 3129
getobs set mask at 17 int  17 max 64
value: 1166
getobs set mask at 46 int  46 max 64
value: 653
getobs set mask at 48 int  48 max 64
value: 5766
getobs set mask at 8 int  8 max 64
value: 334
getobs set mask at 34 int  34 max 64
value: 2102
getobs set mask at 34 int  34 max 64
value: 814
getobs set mask at 10 int  10 max 64
value: 1001
getobs set mask at 52 int  52 max 64
value: 5306
getobs set mask at 18 int  18 max 64
value: 1001
getobs set mask at 12 int  12 max 64
value: 508
getobs set mask at 27 int  27 max 64
value: 3261
getobs set mask at 41 int  41 max 64
value: 2112
getobs set mask at 51 int  51 max 64
value: 163
getobs set mask at 34 int  34 max 64
value: 1049
getobs set mask at 24 int  24 max 64
value: 6671
getobs set mask at 13 int  13 max 64
value: 3563
getobs set mask at 40 int  40 

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mgetobs one item -------------------------------------------------
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset [12] -> ((2965, (270, 508, 1352)), ("e", ("r", "r", "r"))) answer: Set(Any[4185])
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset tail 4185 subsampling_weight: 5


getobs set mask at 14 int  14 max 64
value: 4185
getobs set mask at 21 int  21 max 64
value: 4185
getobs set mask at 35 int  35 max 64
value: 4185
getobs set mask at 41 int  41 max 64
value: 4185
getobs set mask at 30 int  30 max 64
value: 4185
getobs set mask at 52 int  52 max 64
value: 4185
getobs set mask at 48 int  48 max 64
value: 4185
getobs set mask at 56 int  56 max 64
value: 4185
getobs set mask at 27 int  27 max 64
value: 4185
getobs set mask at 25 int  25 max 64
value: 4185
getobs set mask at 16 int  16 max 64
value: 4185
getobs set mask at 1 int  1 max 64
value: 4185
getobs set mask at 17 int  17 max 64
value: 4185
getobs set mask at 46 int  46 max 64
value: 4185
getobs set mask at 64 int  64 max 64
value: 4185
getobs set mask at 55 int  55 max 64
value: 4185
getobs set mask at 42 int  42 max 64
value: 4185
getobs set mask at 49 int  49 max 64
value: 4185
getobs set mask at 23 int  23 max 64
value: 4185
getobs set mask at 62 int  62 max 64
value: 4185
getobs set mask at 46 

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mgetobs one item -------------------------------------------------
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset [13] -> ((8114, (1121, 61, 15)), ("e", ("r", "r", "r"))) answer: Set(Any[2099, 114, 4805, 5126, 10330, 5794, 10963, 3357, 5485, 10453, 9182, 1106, 6522, 6114, 3074, 7475, 8708, 7959, 10112, 7073, 4953, 7612, 8585, 365, 2952, 5348, 3989, 3986, 557, 3035, 970, 2424, 8990, 7414, 1844, 6214, 4505, 3305, 3719, 2710, 9614, 7445, 1340, 2638, 4210, 2859, 12028, 11527, 3777, 802, 1654, 6464, 6253, 3127, 4223, 4205, 6258, 873, 3451, 10397, 66, 6003, 9528, 2733, 2140, 5971, 4836, 5474, 4170, 788, 1779, 6791, 2243, 6757, 4314, 1272, 7202, 5178, 7332, 4859, 11785, 754, 4055, 11492, 7745, 19, 7248, 4074, 1436, 11170, 8920, 5915, 2400, 7334, 445, 6001, 3643, 2116, 10288, 12863, 8329, 10436, 3918, 8018, 209, 4800, 772, 3431, 3968, 649, 3678, 6309, 1185, 9286, 7986, 1055, 252, 1187, 9234, 9208, 4914, 8517, 10201, 13246, 7493, 2135, 80

getobs set mask at 38 int  38 max 64
value: 3989
getobs set mask at 44 int  44 max 64
value: 1844


[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mgetobs one item -------------------------------------------------
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset [14] -> ((7269, (74, 738, 183)), ("e", ("r", "r", "r"))) answer: Set(Any[403, 101])
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset tail 101 subsampling_weight: 6


getobs set mask at 2 int  2 max 64
value: 403
getobs set mask at 52 int  52 max 64
value: 101
getobs set mask at 63 int  63 max 64
value: 101
getobs set mask at 55 int  55 max 64
value: 101
getobs set mask at 14 int  14 max 64
value: 101
getobs set mask at 23 int  23 max 64
value: 101
getobs set mask at 14 int  14 max 64
value: 403
getobs set mask at 11 int  11 max 64
value: 101
getobs set mask at 39 int  39 max 64
value: 101
getobs set mask at 5 int  5 max 64
value: 403
getobs set mask at 19 int  19 max 64
value: 101
getobs set mask at 61 int  61 max 64
value: 101
getobs set mask at 47 int  47 max 64
value: 403
getobs set mask at 32 int  32 max 64
value: 101
getobs set mask at 37 int  37 max 64
value: 101
getobs set mask at 16 int  16 max 64
value: 101
getobs set mask at 1 int  1 max 64
value: 403
getobs set mask at 35 int  35 max 64
value: 101
getobs set mask at 55 int  55 max 64
value: 403
getobs set mask at 13 int  13 max 64
value: 403
getobs set mask at 21 int  21 max 64
value: 10

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mgetobs one item -------------------------------------------------
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset [15] -> ((11432, (1163, 732, 679)), ("e", ("r", "r", "r"))) answer: Set(Any[6534])
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset tail 6534 subsampling_weight: 5


getobs set mask at 20 int  20 max 64
value: 6534
getobs set mask at 58 int  58 max 64
value: 6534
getobs set mask at 25 int  25 max 64
value: 6534
getobs set mask at 62 int  62 max 64
value: 6534
getobs set mask at 4 int  4 max 64
value: 6534
getobs set mask at 44 int  44 max 64
value: 6534
getobs set mask at 37 int  37 max 64
value: 6534
getobs set mask at 55 int  55 max 64
value: 6534
getobs set mask at 1 int  1 max 64
value: 6534
getobs set mask at 38 int  38 max 64
value: 6534
getobs set mask at 32 int  32 max 64
value: 6534
getobs set mask at 29 int  29 max 64
value: 6534
getobs set mask at 31 int  31 max 64
value: 6534
getobs set mask at 37 int  37 max 64
value: 6534
getobs set mask at 56 int  56 max 64
value: 6534
getobs set mask at 38 int  38 max 64
value: 6534
getobs set mask at 48 int  48 max 64
value: 6534
getobs set mask at 49 int  49 max 64
value: 6534
getobs set mask at 54 int  54 max 64
value: 6534
getobs set mask at 62 int  62 max 64
value: 6534
getobs set mask at 7 int

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mgetobs one item -------------------------------------------------
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset [16] -> ((8760, (57, 228, 81)), ("e", ("r", "r", "r"))) answer: Set(Any[9967])
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mTrainDataset tail 9967 subsampling_weight: 5


getobs set mask at 55 int  55 max 64
value: 9967
getobs set mask at 23 int  23 max 64
value: 9967
getobs set mask at 62 int  62 max 64
value: 9967
getobs set mask at 3 int  3 max 64
value: 9967
getobs set mask at 25 int  25 max 64
value: 9967
getobs set mask at 63 int  63 max 64
value: 9967
getobs set mask at 36 int  36 max 64
value: 9967
getobs set mask at 6 int  6 max 64
value: 9967
getobs set mask at 17 int  17 max 64
value: 9967
getobs set mask at 42 int  42 max 64
value: 9967
getobs set mask at 54 int  54 max 64
value: 9967
getobs set mask at 51 int  51 max 64
value: 9967
getobs set mask at 50 int  50 max 64
value: 9967
getobs set mask at 49 int  49 max 64
value: 9967
getobs set mask at 15 int  15 max 64
value: 9967
getobs set mask at 41 int  41 max 64
value: 9967
getobs set mask at 56 int  56 max 64
value: 9967
getobs set mask at 27 int  27 max 64
value: 9967
getobs set mask at 11 int  11 max 64
value: 9967
getobs set mask at 19 int  19 max 64
value: 9967
getobs set mask at 43 in

[36m[1m[ [22m[39m[36m[1mInfo: [22m[39mgetobs one item -------------------------------------------------


In [3]:
"""
    jupy_load_data(args, tasks, all_tasks, query_dict)

Load the pickled train/valid/test queries and answers from `args["data_path"]` and remove
the query structures that were not requested.

# Arguments
- `args`: settings dictionary; `"data_path"` and `"evaluate_union"` are read.
- `tasks`: collection of requested task names without the union suffix (e.g. `"2u"`).
- `all_tasks`: every known task name; names containing `'u'` are expected to look like
  `"<name>-<union mode>"` (e.g. `"2u-DNF"`).
- `query_dict`: maps each entry of `all_tasks` to its query structure.

# Returns
The tuple `(train_queries, train_answers, valid_queries, valid_hard_answers,
valid_easy_answers, test_queries, test_hard_answers, test_easy_answers)`.
"""
function jupy_load_data(args, tasks, all_tasks, query_dict)
    @info "loading data...."

    data_path = args["data_path"]
    # `open(f, path)` closes the file handle even when `Pickle.load` throws; the bare
    # `Pickle.load(open(path))` form left all eight handles open.
    load_pickle(file_name) = open(Pickle.load, joinpath(data_path, file_name))

    train_queries = load_pickle(f_train_queries)
    train_answers = load_pickle(f_train_answers)
    valid_queries = load_pickle(f_valid_queries)
    valid_hard_answers = load_pickle(f_valid_hard_answers)
    valid_easy_answers = load_pickle(f_valid_easy_answers)
    test_queries = load_pickle(f_test_queries)
    test_hard_answers = load_pickle(f_test_hard_answers)
    test_easy_answers = load_pickle(f_test_easy_answers)

    # Remove tasks that are not requested, or whose union mode differs from the one
    # selected in `args["evaluate_union"]`.
    for task in all_tasks
        if 'u' in task
            name, evaluate_union = split(task, "-")
        else
            name, evaluate_union = task, args["evaluate_union"]
        end
        if !(name in tasks) || evaluate_union != args["evaluate_union"]
            # `task` is already the dictionary key; it does not need to be rebuilt from
            # its parts (and `eval` on a String is a no-op anyway).
            query_structure = query_dict[task]
            for queries in (train_queries, valid_queries, test_queries)
                if haskey(queries, query_structure)
                    delete!(queries, query_structure)
                end
            end
        end
    end
    @info "load data....Done"
    return train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, test_queries, test_hard_answers, test_easy_answers
end

jupy_load_data (generic function with 1 method)

In [9]:
"""
    log_metrics(mode, step, metrics)

Write one `@info` line per entry of `metrics`, formatted as
`"<mode> <metric name> at step <step>: <value>"`.

# Arguments
- `mode`: label prepended to every line (e.g. the evaluation split).
- `step`: step number reported in the message.
- `metrics`: key/value collection (such as a `Dict`) mapping metric names to values.
"""
function log_metrics(mode, step, metrics)
    # Destructure each entry: interpolating the pair itself would print `name => value`
    # where only the metric name is wanted.
    for (name, value) in pairs(metrics)
        @info "$mode $name at step $(step): $value"
    end
    return nothing
end

"""
    evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, mode, step, writer)

Evaluate the queries in `dataloader`, log and record every per-structure metric, and
return all of them together with their averages across query structures.

# Arguments
- `model`: object exposing `test_step`, called as
  `model.test_step(model, tp_answers, fn_answers, args, dataloader, query_name_dict)`.
  Its result is iterated as `query_structure => metrics` pairs, where `metrics` maps
  metric names to numbers.
- `tp_answers`, `fn_answers`, `args`, `dataloader`: forwarded unchanged to `test_step`.
- `query_name_dict`: maps a query structure to its short name.
- `mode`: label used as a prefix for log lines and scalar tags.
- `step`: step number attached to every logged value.
- `writer`: object exposing `add_scalar(tag, value, step)`.

# Returns
A `Dict{String, Float64}` with one `"<query name>_<metric>"` entry per reported metric
and one `"average_<metric>"` entry per metric other than `"num_queries"`.
"""
function evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, mode, step, writer)
    # NOTE(review): assumes metric names are Strings and metric values are real numbers.
    average_metrics = Dict{String, Float64}()
    all_metrics = Dict{String, Float64}()

    metrics = model.test_step(model, tp_answers, fn_answers, args, dataloader, query_name_dict)
    num_query_structures = 0
    for (query_structure, structure_metrics) in pairs(metrics)
        query_name = query_name_dict[query_structure]
        log_metrics(mode * " " * query_name, step, structure_metrics)

        for (metric, value) in pairs(structure_metrics)
            writer.add_scalar(join([mode, query_name, metric], "_"), value, step)
            all_metrics[join([query_name, metric], "_")] = value
            if metric != "num_queries"
                # Start from zero the first time a metric is seen.
                average_metrics[metric] = get(average_metrics, metric, 0.0) + value
            end
        end
        num_query_structures += 1
    end

    # Snapshot the keys so the dictionary is not iterated while its values are replaced.
    for metric in collect(keys(average_metrics))
        average = average_metrics[metric] / num_query_structures
        average_metrics[metric] = average
        writer.add_scalar(join([mode, "average", metric], "_"), average, step)
        all_metrics[join(["average", metric], "_")] = average
    end

    log_metrics("$mode average", step, average_metrics)
    return all_metrics
end

evaluate

In [30]:
"""
    format_time()

Return the current local date rendered with the `Dates` pattern `"YY.mm.dd"`.
"""
function format_time()
    timestamp = Dates.now()
    return Dates.format(timestamp, "YY.mm.dd")
end

# Parse one textual tuple element: integer first, then float, otherwise the
# (unquoted) text itself. Avoids calling `eval` on user-supplied data.
function _parse_tuple_item(item::AbstractString)
    text = strip(item)
    as_int = tryparse(Int, text)
    as_int === nothing || return as_int
    as_float = tryparse(Float64, text)
    as_float === nothing || return as_float
    if length(text) >= 2 && first(text) == last(text) && first(text) in ('"', '\'')
        return String(chop(text; head=1, tail=1))
    end
    return String(text)
end

"""
    eval_tuple(arg_return)

Evaluate a tuple string into a tuple.

- A `Tuple` is returned unchanged; any other non-string value is also returned as-is.
- A string starting with `(` or `[` has its first and last characters removed and
  is split on commas; empty items are skipped and every remaining item is parsed
  as `Int`, then `Float64`, otherwise kept as a `String`
  (e.g. `"(none, 0.02)"` gives `("none", 0.02)`).
- Any other string is parsed as a single item with the same rules.
"""
function eval_tuple(arg_return)
    arg_return isa Tuple && return arg_return
    arg_return isa AbstractString || return arg_return

    text = strip(arg_return)
    if isempty(text) || !(first(text) in ('(', '['))
        return _parse_tuple_item(text)
    end

    items = Any[]
    for piece in split(chop(text; head=1, tail=1), ',')
        piece = strip(piece)
        isempty(piece) && continue
        push!(items, _parse_tuple_item(piece))
    end
    return Tuple(items)
end

"""
    flatten_query(queries)

Turn a mapping `query_structure => collection of queries` into one flat vector
of `(query, query_structure)` tuples, visiting the structures in the iteration
order of `queries`.
"""
function flatten_query(queries)
    flattened = []
    for (structure, structure_queries) in pairs(queries)
        for query in structure_queries
            push!(flattened, (query, structure))
        end
    end
    return flattened
end

# Append the leaves of `x` to `out`, descending into tuples and arrays.
function _flatten_into!(out, x::Union{Tuple, AbstractArray})
    foreach(element -> _flatten_into!(out, element), x)
    return out
end
_flatten_into!(out, x) = (push!(out, x); out)

"""
    flatten(l)

Return a flat vector with all leaves of the arbitrarily nested tuples/arrays in
`l`, in depth-first order. Nested query tuples such as `((1, (2,)), (3, (4,)))`
become `[1, 2, 3, 4]`; a single level of flattening would leave the inner
relation tuples in place.
"""
function flatten(l)
    leaves = _flatten_into!(Any[], l)
    # `map(identity, ...)` narrows the element type (e.g. to `Int`) when possible.
    return isempty(leaves) ? leaves : map(identity, leaves)
end

flatten (generic function with 1 method)

In [46]:
"""
    KGDataset

Training data container for knowledge-graph query answering, together with a
simple one-pass iterator over the observations of a `DataLoader`.
"""
module KGDataset

using Main
# `numobs`/`getobs` are imported so the methods below EXTEND the MLUtils data
# container interface instead of creating unrelated functions with the same name.
import MLUtils: DataLoader, numobs, getobs

export numobs, getobs, TrainDataset, TestDataset, SingleDirectionalOneShotIterator

abstract type Dataset end

"""
    TrainDataset(queries, answers, nentity, nrelation, negative_sample_size)

Training set made of `(query, query_structure)` pairs in `queries` and the
answers of every query in `answer`. `count[query]` stores the frequency used for
the sub-sampling weight (see `count_frequency`).
"""
struct TrainDataset <: Dataset
    queries::Vector{Any}
    answer::Dict{Any, Any}
    nentity::Int
    nrelation::Int
    negative_sample_size::Int
    count::Dict{Tuple, Int}
end
#=
function collate_fn(data::TrainDataset)
    positive_sample = cat([_[0] for _ in data], dim=0)
    negative_sample = stack([_[1] for _ in data], dim=0)
    subsample_weight = cat([_[2] for _ in data], dim=0)
    query = [_[3] for _ in data]
    query_structure = [_[4] for _ in data]
    return positive_sample, negative_sample, subsample_weight, query, query_structure
end
=#

"""
    count_frequency(queries, answer, start=4)

Return a `Dict` mapping every query in `queries` to `start` plus its number of
answers.
"""
function count_frequency(queries, answer, start=4)
    count = Dict{Tuple, Int}()
    for (query, _) in queries
        count[query] = start + length(answer[query])
    end
    return count
end

function TrainDataset(queries, answers, nentity, nrelation, negative_sample_size)
    count = count_frequency(queries, answers)
    return TrainDataset(queries, answers, nentity, nrelation, negative_sample_size, count)
end

"""
    numobs(data::TrainDataset)

Number of `(query, query_structure)` observations in `data`.
"""
function numobs(data::TrainDataset)
    return length(data.queries)
end

"""
    getobs(data::TrainDataset, idx::Integer)

Build the training sample for observation `idx` and return the tuple
`(positive_sample, negative_sample, subsampling_weight, flat_query, query_structure)`:

- `positive_sample`: one randomly chosen answer, as a 1-element `Float64` vector;
- `negative_sample`: `data.negative_sample_size` entities drawn from
  `1:data.nentity` that are not answers of the query;
- `subsampling_weight`: `[sqrt(1 / data.count[query])]`;
- `flat_query`: `Main.flatten(query)`.

NOTE(review): the sampling loop does not terminate if every entity in
`1:data.nentity` is an answer of the query.
"""
function getobs(data::TrainDataset, idx::Integer)
    query, query_structure = data.queries[idx]
    answers = data.answer[query]
    @debug "TrainDataset [$(idx)] -> $(data.queries[idx]) answer: $(answers)"

    tail = rand(collect(answers))
    subsampling_weight = [sqrt(1 / data.count[query])]
    @debug "TrainDataset tail $(tail) subsampling_weight: $(subsampling_weight)"

    # Rejection sampling: draw candidates in chunks and keep only the entities
    # that are NOT answers, until enough negatives have been collected.
    # NOTE(review): candidates are 1-based ids; confirm the ids in `answer` use
    # the same base.
    negative_sample = Int[]
    while length(negative_sample) < data.negative_sample_size
        candidates = rand(1:data.nentity, data.negative_sample_size * 2)
        append!(negative_sample, filter(entity -> !(entity in answers), candidates))
    end
    resize!(negative_sample, data.negative_sample_size)

    positive_sample = Float64[tail]

    return positive_sample, negative_sample, subsampling_weight, Main.flatten(query), query_structure
end

# Batched access used by `DataLoader`: one sample per requested index.
function getobs(data::TrainDataset, idxs::AbstractVector{<:Integer})
    return [getobs(data, idx) for idx in idxs]
end

"""
    SingleDirectionalOneShotIterator(data_loader)

Iterate once, observation by observation, over `data_loader.data`. The batching
options of the loader are not used by this iterator.
"""
struct SingleDirectionalOneShotIterator
    data_loader::DataLoader
end

# A single method with a default state covers both `iterate(iter)` and
# `iterate(iter, state)`.
function Base.iterate(iter::SingleDirectionalOneShotIterator, state = 1)
    if length(iter.data_loader.data.queries) < state
        return nothing
    end

    return ( getobs(iter.data_loader.data, state), state + 1 )
end

Base.length(iter::SingleDirectionalOneShotIterator) = length(iter.data_loader.data.queries)


end

Main.KGDataset13

In [None]:
# Smoke test for the training dataset and its one-shot iterator.
using Revise
using MLUtils;

# The module defined above is named `KGDataset`.
using .KGDataset

train_flatten_queries = flatten_query(train_queries);
println("length of train_flatten_queries: $(length(train_flatten_queries))")
println("length of train_answers: $(length(train_answers))")
negative_sample_size = 32
batch_size = 1
# NOTE(review): placeholders. Every element of `train_flatten_queries` is a
# `(query, query_structure)` tuple, so these sums are 2 * number of queries,
# not the real entity/relation counts of the dataset.
nentity=sum([length(q) for q in train_flatten_queries])
nrelation=sum([length(q) for q in train_flatten_queries])
data_set = KGDataset.TrainDataset(train_flatten_queries, train_answers, nentity, nrelation, negative_sample_size);
#DataLoader(data; [batchsize, buffer, collate, parallel, partial, rng, shuffle])
data_loader = MLUtils.DataLoader(data_set, batchsize = batch_size, collate = true, shuffle=false);
train_path_iterator = KGDataset.SingleDirectionalOneShotIterator(data_loader);

# Print only the first sample produced by the `for` protocol.
data_index = 1
for item in train_path_iterator
    println("for loop: --$(item)")
    data_index += 1
    if data_index >= 2
        break
    end
end

# Drive the iterator manually from the state reached above.
(item, next)  = iterate(train_path_iterator,data_index)
println(item)

(item, next)  = iterate(train_path_iterator, next)
println(item)

In [None]:
module KGModule

using Zygote: AbstractFFTs
export Identity, normDims, BoxOffsetIntersection, CenterIntersection, BetaIntersection,
    BetaProjection, Regularizer, KGReasoning, train_step

using Flux;
using Zygote;
using Statistics;
using Distributions;
using MLUtils;

include("src/utils.jl")

"""
    Identity(x)

Return `x` unchanged (the "none" activation).
"""
Identity(x) = x

"""
    normDims(itr, p::Real=2; dim=:, dims=dim)

Compute the `p`-norm of `itr` along the given dimension(s):
`sum(abs.(itr) .^ p; dims) .^ (1 / p)`.

The reduced dimensions are kept with size 1, as with `sum`. `dims` is accepted
as an alias of `dim`; when neither is given the norm is taken over all elements
and a scalar is returned.
"""
function normDims(itr, p::Real=2; dim=:, dims=dim)
    # `abs` is required for a proper norm when `p` is odd (e.g. the 1-norm).
    return sum(abs.(itr) .^ p; dims=dims) .^ (1 / p)
end

"""
    BoxOffsetIntersection(dim)

Intersection operator for box offsets: two `dim => dim` dense layers produce a
gate that scales the element-wise minimum of the input offsets.
"""
struct BoxOffsetIntersection
    dim::Int
    layer1::Flux.Dense
    layer2::Flux.Dense
end

function BoxOffsetIntersection(dim::Int)
    layer1 = Flux.Dense(dim => dim)
    layer2 = Flux.Dense(dim => dim)

    return BoxOffsetIntersection(dim, layer1, layer2)
end

"""
    (m::BoxOffsetIntersection)(embeddings)

Combine the offsets stacked along the LAST dimension of `embeddings`.
The features must be on the first dimension (the one `Dense` acts on); the last
dimension is reduced and dropped from the result.
NOTE(review): presumably `embeddings` is `(dim, batch, num_conj)` — confirm with callers.
"""
function (m::BoxOffsetIntersection)(embeddings)
    conj_dim = ndims(embeddings)
    # Activations are scalar functions and must be broadcast over the array.
    layer1_act = Flux.relu.(m.layer1(embeddings))
    layer1_mean = dropdims(mean(layer1_act; dims=conj_dim); dims=conj_dim)
    gate = Flux.sigmoid.(m.layer2(layer1_mean))
    offset = dropdims(minimum(embeddings; dims=conj_dim); dims=conj_dim)

    return offset .* gate
end

Flux.@functor BoxOffsetIntersection

"""
    CenterIntersection(dim)

Attention-based intersection operator for embedding centers, built from two
`dim => dim` dense layers (Flux's default initialisation).
"""
struct CenterIntersection
    dim::Int
    layer1::Flux.Dense
    layer2::Flux.Dense
end

function CenterIntersection(dim::Int)
    layer1 = Flux.Dense(dim => dim)
    layer2 = Flux.Dense(dim => dim)

    return CenterIntersection(dim, layer1, layer2)
end

"""
    (m::CenterIntersection)(embeddings)

Return the attention-weighted sum of the embeddings stacked along the LAST
dimension of `embeddings`; that dimension is dropped from the result.
NOTE(review): presumably `embeddings` is `(dim, batch, num_conj)` — confirm with callers.
"""
function (m::CenterIntersection)(embeddings)
    conj_dim = ndims(embeddings)
    layer1_act = Flux.relu.(m.layer1(embeddings))
    # Softmax over the stacked branches, so the weights of each feature sum to one.
    attention = Flux.softmax(m.layer2(layer1_act); dims=conj_dim)
    # Element-wise weighting (`.*`), not a matrix product.
    embedding = dropdims(sum(attention .* embeddings; dims=conj_dim); dims=conj_dim)

    return embedding
end

Flux.@functor CenterIntersection

"""
    BetaIntersection(dim)

Attention-based intersection operator for Beta embeddings. `layer1` maps the
concatenated `(alpha, beta)` features `2dim => 2dim`, `layer2` maps `2dim => dim`.
"""
struct BetaIntersection
    dim::Int
    layer1::Flux.Dense
    layer2::Flux.Dense
end

function BetaIntersection(dim::Int)
    layer1 = Flux.Dense(2 * dim => 2 * dim)
    layer2 = Flux.Dense(2 * dim => dim)

    return BetaIntersection(dim, layer1, layer2)
end

"""
    (m::BetaIntersection)(alpha_embeddings, beta_embeddings)

Return `(alpha_embedding, beta_embedding)`, the attention-weighted sums of the
inputs over their LAST dimension, which is dropped from the results.
NOTE(review): presumably both inputs are `(dim, batch, num_conj)` — confirm with callers.
"""
function (m::BetaIntersection)(alpha_embeddings, beta_embeddings)
    conj_dim = ndims(alpha_embeddings)
    # Concatenate along the feature (first) dimension so `layer1` sees 2*dim features.
    all_embeddings = vcat(alpha_embeddings, beta_embeddings)
    layer1_act = Flux.relu.(m.layer1(all_embeddings))
    attention = Flux.softmax(m.layer2(layer1_act); dims=conj_dim)

    alpha_embedding = dropdims(sum(attention .* alpha_embeddings; dims=conj_dim); dims=conj_dim)
    beta_embedding = dropdims(sum(attention .* beta_embeddings; dims=conj_dim); dims=conj_dim)

    return alpha_embedding, beta_embedding
end

Flux.@functor BetaIntersection

"""
    Regularizer(base_add, min_val, max_val)

Shift-and-clamp regularizer: adds `base_add` to every element of its input and
clamps the result to `[min_val, max_val]`.
"""
struct Regularizer
    base_add::AbstractFloat
    min_val::AbstractFloat
    max_val::AbstractFloat
end

"""
    (m::Regularizer)(entity_embedding)

Apply the regularizer element-wise to `entity_embedding` (array or scalar).
"""
function (m::Regularizer)(entity_embedding)
    # Broadcast explicitly: `array + scalar` and `clamp(array, lo, hi)` are not
    # defined for arrays.
    return clamp.(entity_embedding .+ m.base_add, m.min_val, m.max_val)
end

Flux.@functor Regularizer

"""
    BetaProjection(entity_dim, relation_dim, hidden_dim, projection_regularizer, num_layers)

MLP that projects an entity embedding through a relation. The dense layers are
stored in `layers` under `:layer1 … :layer<num_layers>` (hidden layers) and
`:layer0` (output layer), and are also reachable as properties (`m.layer1`).
"""
struct BetaProjection
    entity_dim::Int
    relation_dim::Int
    hidden_dim::Int
    num_layers::Int

    layers::Dict{Symbol, Flux.Dense}
    projection_regularizer
end

# Property writes are stored in the `layers` dictionary.
function Base.setproperty!(m::BetaProjection, property::Symbol, value)
    getfield(m, :layers)[property] = value
end

# Real struct fields take precedence; any other name is looked up in `layers`.
function Base.getproperty(m::BetaProjection, property::Symbol)
    hasfield(BetaProjection, property) && return getfield(m, property)
    return getfield(m, :layers)[property]
end

function Base.propertynames(m::BetaProjection, private::Bool = false)
    return keys(getfield(m, :layers))
end

function BetaProjection(entity_dim, relation_dim, hidden_dim, projection_regularizer, num_layers)
    layers = Dict{Symbol, Flux.Dense}()
    layers[:layer1] = Flux.Dense((entity_dim + relation_dim) => hidden_dim) # 1st layer
    layers[:layer0] = Flux.Dense(hidden_dim => entity_dim) # final layer
    for nl in 2:num_layers
        layers[Symbol("layer$(nl)")] = Flux.Dense(hidden_dim => hidden_dim)
    end

    return BetaProjection(entity_dim, relation_dim, hidden_dim, num_layers,
                          layers, projection_regularizer)

end

"""
    (m::BetaProjection)(e_embedding, r_embedding)

Concatenate entity and relation embeddings along the feature (first) dimension,
apply `layer1 … layer<num_layers>` with ReLU, then `layer0` and the projection
regularizer.
"""
function (m::BetaProjection)(e_embedding, r_embedding)
    layers = getfield(m, :layers)
    x = vcat(e_embedding, r_embedding)
    for nl in 1:getfield(m, :num_layers)
        x = Flux.relu.(layers[Symbol("layer$(nl)")](x))
    end
    x = layers[:layer0](x)
    x = getfield(m, :projection_regularizer)(x)

    return x
end

Flux.@functor BetaProjection

"""
    KGReasoning

Query-embedding model. `geo` selects the geometry (`"box"`, `"vec"` or `"beta"`).
Embedding tables are stored feature-first: `entity_embedding` is
`(entity_dim [* 2 for beta], nentity)`, `relation_embedding` is
`(relation_dim, nrelation)` and `offset_embedding` (box only) is
`(entity_dim, nrelation)`, so entity/relation `i` is column `i`.
Fields that do not apply to the chosen geometry are `nothing`.
"""
struct KGReasoning
    nentity::Int
    nrelation::Int
    hidden_dim::Int
    epsilon::AbstractFloat
    geo::String
    use_cuda::Bool
    batch_entity_range # (nentity, test_batch_size) Float32 matrix of 0-based entity ids
    query_name_dict::Union{Nothing, Dict{Tuple, String}}
    ############################################
    gamma # 1-element array holding the margin
    embedding_range # 1-element array: (gamma + epsilon) / hidden_dim

    entity_dim::Int
    relation_dim::Int

    entity_embedding
    relation_embedding
    cen
    func
    entity_regularizer
    projection_regularizer

    offset_embedding
    center_net
    offset_net
    num_layers
    projection_net
end

Flux.@functor KGReasoning

"""
    KGReasoning(nentity, nrelation, hidden_dim, gamma, geo, test_batch_size=1,
                box_mode=nothing, beta_mode=nothing, query_name_dict=nothing, use_cuda=false)

Build a model for geometry `geo`.

- `box_mode`: `(activation, cen)` tuple or its string form; `activation` is one
  of `"none"`, `"relu"`, `"softplus"` (used when `geo == "box"`).
- `beta_mode`: `(hidden_dim, num_layers)` of the projection MLP, tuple or string
  form (used when `geo == "beta"`).

NOTE(review): the uniform `[-embedding_range, embedding_range]` initialisation
mentioned in the original TODO comments is still not applied; the box entity
table starts at zero and the other tables use `Flux.glorot_uniform`.
"""
function KGReasoning(nentity, nrelation, hidden_dim, gamma, geo, test_batch_size=1,
                     box_mode=nothing, beta_mode=nothing, query_name_dict=nothing, use_cuda=false)
    epsilon = 2.0

    batch_entity_range = repeat(Float32.(0:(nentity - 1)), 1, test_batch_size)

    # Plain arrays are stored as fields; `Flux.@functor` exposes them as parameters.
    gamma_param = Float32[gamma]
    embedding_range = (gamma_param .+ Float32(epsilon)) ./ hidden_dim

    entity_dim = hidden_dim
    relation_dim = hidden_dim

    cen, func = nothing, nothing
    entity_embedding, entity_regularizer, projection_regularizer = nothing, nothing, nothing
    if geo == "box"
        entity_embedding = zeros(Float32, entity_dim, nentity) # center for entities
        # `cen` balances the in-box distance and the out-box distance.
        activation, cen = eval_tuple(box_mode)
        if activation == "none"
            func = Identity
        elseif activation == "relu"
            func = Flux.relu
        elseif activation == "softplus"
            func = Flux.softplus
        end
    elseif geo == "vec"
        entity_embedding = Flux.glorot_uniform(entity_dim, nentity)
    elseif geo == "beta"
        entity_embedding = Flux.glorot_uniform(entity_dim * 2, nentity) # alpha and beta
        # Keep the parameters of the beta embeddings positive, before and after projection.
        entity_regularizer = Regularizer(1, 0.05, 1e9)
        projection_regularizer = Regularizer(1, 0.05, 1e9)
    end
    relation_embedding = Flux.glorot_uniform(relation_dim, nrelation)

    num_layers, offset_embedding, center_net, offset_net, projection_net =
        nothing, nothing, nothing, nothing, nothing
    if geo == "box"
        offset_embedding = Flux.glorot_uniform(entity_dim, nrelation)
        center_net = CenterIntersection(entity_dim)
        offset_net = BoxOffsetIntersection(entity_dim)
    elseif geo == "vec"
        center_net = CenterIntersection(entity_dim)
    elseif geo == "beta"
        # Separate name so the model's own `hidden_dim` is not overwritten.
        projection_hidden_dim, num_layers = eval_tuple(beta_mode)

        center_net = BetaIntersection(entity_dim)
        projection_net = BetaProjection(entity_dim * 2,
                                        relation_dim,
                                        projection_hidden_dim,
                                        projection_regularizer,
                                        num_layers)
    end

    return KGReasoning(nentity, nrelation, hidden_dim, epsilon, geo, use_cuda, batch_entity_range,
                       query_name_dict, gamma_param, embedding_range, entity_dim, relation_dim,
                       entity_embedding, relation_embedding, cen, func, entity_regularizer,
                       projection_regularizer, offset_embedding, center_net, offset_net,
                       num_layers, projection_net)
end

"""
    forward(m, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict)

Dispatch to `forward_box`, `forward_vec` or `forward_beta` according to `m.geo`.
Returns `nothing` for any other geometry.
"""
function forward(m::KGReasoning, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict)
    geometry = m.geo
    # The chained ternary short-circuits, so only the selected function name is resolved.
    forward_fn = geometry == "box"  ? forward_box  :
                 geometry == "vec"  ? forward_vec  :
                 geometry == "beta" ? forward_beta : nothing
    forward_fn === nothing && return nothing
    return forward_fn(m, positive_sample, negative_sample, subsampling_weight,
                      batch_queries_dict, batch_idxs_dict)
end

"""
    embed_query_box(m, queries, query_structure, idx)

Embed a batch of queries with the same structure using Query2box.

`queries` is a flattened batch of queries; column `idx` (1-based) is the next
unread entity/relation id of every query. Returns
`(embedding, offset_embedding, idx)` where `idx` points past the consumed columns.

NOTE(review): assumes `queries` is `(batch, flat_length)`, that ids are valid
1-based columns of the embedding tables, and that the tables are
`(dim, count)` — confirm against the data loader and the constructor.
"""
function embed_query_box(m::KGReasoning, queries, query_structure, idx)
    relations = last(query_structure)
    # Whether the current query tree has merged into one branch and only relation
    # traversal is left, e.g. path queries or conjunctive queries after the intersection.
    all_relation_flag = all(r -> r in ("r", "n"), relations)

    if all_relation_flag
        if query_structure[1] == "e"
            embedding = m.entity_embedding[:, queries[:, idx]]
            # Same element type and device as `embedding`, so no explicit GPU branch is needed.
            offset_embedding = zero(embedding)
            idx += 1
        else
            embedding, offset_embedding, idx = embed_query_box(m, queries, query_structure[1], idx)
        end

        for relation in relations
            if relation == "n"
                throw(AssertionError("box cannot handle queries with negation"))
            end
            relation_ids = queries[:, idx]
            embedding = embedding .+ m.relation_embedding[:, relation_ids]
            offset_embedding = offset_embedding .+ m.func.(m.offset_embedding[:, relation_ids])
            idx += 1
        end
    else
        embedding_list = []
        offset_embedding_list = []
        for sub_structure in query_structure
            embedding, offset_embedding, idx = embed_query_box(m, queries, sub_structure, idx)
            push!(embedding_list, embedding)
            push!(offset_embedding_list, offset_embedding)
        end
        # Stack the branches along a NEW trailing dimension for the intersection nets.
        stack_dim = ndims(first(embedding_list)) + 1
        embedding = m.center_net(cat(embedding_list...; dims=stack_dim))
        offset_embedding = m.offset_net(cat(offset_embedding_list...; dims=stack_dim))
    end
    return embedding, offset_embedding, idx
end

"""
    embed_query_vec(m, queries, query_structure, idx)

Iteratively embed a batch of queries with the same structure using GQE.

`queries` is a flattened batch of queries; column `idx` (1-based) is the next
unread entity/relation id of every query. Returns `(embedding, idx)` where `idx`
points past the consumed columns.

NOTE(review): assumes `queries` is `(batch, flat_length)` and that ids are valid
1-based columns of the `(dim, count)` embedding tables — confirm.
"""
function embed_query_vec(m::KGReasoning, queries, query_structure, idx)
    relations = last(query_structure)
    # Whether the current query tree has merged into one branch and only relation
    # traversal is left, e.g. path queries or conjunctive queries after the intersection.
    all_relation_flag = all(ele -> ele in ("r", "n"), relations)

    if all_relation_flag
        if query_structure[1] == "e"
            embedding = m.entity_embedding[:, queries[:, idx]]
            idx += 1
        else
            embedding, idx = embed_query_vec(m, queries, query_structure[1], idx)
        end

        for relation in relations
            if relation == "n"
                throw(AssertionError("vec cannot handle queries with negation"))
            end
            embedding = embedding .+ m.relation_embedding[:, queries[:, idx]]
            idx += 1
        end
    else
        embedding_list = []
        for sub_structure in query_structure
            embedding, idx = embed_query_vec(m, queries, sub_structure, idx)
            push!(embedding_list, embedding)
        end
        # Stack the branches along a NEW trailing dimension for the intersection net.
        embedding = m.center_net(cat(embedding_list...; dims=ndims(first(embedding_list)) + 1))
    end
    return embedding, idx
end

"""
    embed_query_beta(m, queries, query_structure, idx)

Iteratively embed a batch of queries with the same structure using BetaE.

`queries` is a flattened batch of queries; column `idx` (1-based) is the next
unread entity/relation id of every query. Returns
`(alpha_embedding, beta_embedding, idx)` where `idx` points past the consumed
columns. Negation (`"n"`) takes the element-wise reciprocal of the embedding and
requires the corresponding query column to be `-2`.

NOTE(review): assumes `queries` is `(batch, flat_length)`, ids are valid 1-based
columns of the embedding tables, and the entity table is `(2 * dim, nentity)`
with alpha in the first half of the rows — confirm.
"""
function embed_query_beta(m::KGReasoning, queries, query_structure, idx)
    relations = last(query_structure)
    # Whether the current query tree has merged into one branch and only relation
    # traversal is left, e.g. path queries or conjunctive queries after the intersection.
    all_relation_flag = all(ele -> ele in ("r", "n"), relations)

    if all_relation_flag
        if query_structure[1] == "e"
            embedding = m.entity_regularizer(m.entity_embedding[:, queries[:, idx]])
            idx += 1
        else
            alpha_embedding, beta_embedding, idx = embed_query_beta(m, queries, query_structure[1], idx)
            # Re-join alpha and beta along the feature (first) dimension.
            embedding = vcat(alpha_embedding, beta_embedding)
        end
        for relation in relations
            if relation == "n"
                all(queries[:, idx] .== -2) ||
                    throw(AssertionError("negation columns must hold the marker -2"))
                embedding = 1 ./ embedding
            else
                r_embedding = m.relation_embedding[:, queries[:, idx]]
                embedding = m.projection_net(embedding, r_embedding)
            end
            idx += 1
        end
        # Split the features back into the alpha half and the beta half.
        alpha_embedding, beta_embedding = chunk(embedding, 2; dims=1)
    else
        alpha_embedding_list = []
        beta_embedding_list = []
        for sub_structure in query_structure
            alpha_embedding, beta_embedding, idx = embed_query_beta(m, queries, sub_structure, idx)
            push!(alpha_embedding_list, alpha_embedding)
            push!(beta_embedding_list, beta_embedding)
        end
        # Stack the branches along a NEW trailing dimension for the intersection net.
        stack_dim = ndims(first(alpha_embedding_list)) + 1
        alpha_embedding, beta_embedding = m.center_net(cat(alpha_embedding_list...; dims=stack_dim),
                                                       cat(beta_embedding_list...; dims=stack_dim))
    end
    return alpha_embedding, beta_embedding, idx
end

"""
    transform_union_query(m, queries, query_structure)

Rewrite a batch of union queries (one query per row of `queries`) as twice as
many union-free queries:

- `2u-DNF` `(e1, r1, e2, r2, -1)` becomes the two 1p queries `(e1, r1)`, `(e2, r2)`;
- `up-DNF` `(e1, r1, e2, r2, -1, r)` becomes the two 2p queries `(e1, r1, r)`, `(e2, r2, r)`.

The result has `2 * batch` rows; the two branches of query `i` are rows
`2i - 1` and `2i`.
"""
function transform_union_query(m::KGReasoning, queries, query_structure)
    query_name = m.query_name_dict[query_structure]
    if query_name == "2u-DNF"
        queries = queries[:, 1:(size(queries, 2) - 1)] # remove union -1
    elseif query_name == "up-DNF"
        shared_relation = queries[:, 6:6]
        queries = hcat(queries[:, 1:2], shared_relation, queries[:, 3:4], shared_relation)
    end
    # Split every row into its two halves, keeping the halves of one query adjacent.
    # Arrays are column-major, so the split is done on the transpose.
    branch_length = size(queries, 2) ÷ 2
    return permutedims(reshape(permutedims(queries), branch_length, :))
end

"""
    transform_union_structure(m, query_structure)

Return the union-free structure matching a DNF union query: the 1p structure for
`"2u-DNF"`, the 2p structure for `"up-DNF"`, and `nothing` for any other name.
"""
function transform_union_structure(m::KGReasoning, query_structure)
    union_name = m.query_name_dict[query_structure]
    union_name == "2u-DNF" && return ("e", ("r",))
    union_name == "up-DNF" && return ("e", ("r", "r"))
    return nothing
end

"""
    cal_logit_beta(m, entity_embedding, query_dist)

Return `m.gamma` minus the KL divergence between the entity Beta distributions
and `query_dist`, summed over the feature (first) dimension, which is dropped
from the result.

`entity_embedding` holds alpha in the first half and beta in the second half of
its first dimension. `query_dist` must be an array of `Distributions.Beta`
broadcast-compatible with one half of `entity_embedding`.
NOTE(review): both parameters must be strictly positive for `Beta` — confirm
that callers always pass regularized embeddings.
"""
function cal_logit_beta(m::KGReasoning, entity_embedding, query_dist)
    alpha_embedding, beta_embedding = chunk(entity_embedding, 2; dims=1)
    # `Beta` takes scalar parameters, so build one distribution per element.
    entity_dist = Distributions.Beta.(alpha_embedding, beta_embedding)
    divergence = Distributions.kldivergence.(entity_dist, query_dist)
    logit = m.gamma .- dropdims(normDims(divergence, 1; dim=1); dims=1)
    return logit
end

"""
    forward_beta(m, positive_sample, negative_sample, subsampling_weight,
                 batch_queries_dict, batch_idxs_dict)

BetaE forward pass. Embeds every group of queries in `batch_queries_dict`
(keyed by query structure), builds the query Beta distributions and scores the
positive and negative samples against them. DNF union queries are embedded as
two union-free branches and the larger of the two branch logits is kept.

Returns `(positive_logit, negative_logit, subsampling_weight, idxs)` where `idxs`
lists the batch positions in the order used by the logits (regular queries
first, then union queries). A logit is `nothing` when the corresponding sample
argument is `nothing`.

NOTE(review): layout assumptions to confirm —
embeddings are feature-first `(dim, batch)`; `negative_sample` is
`(negative_size, batch)`; sample ids are valid 1-based columns of
`m.entity_embedding`; `cal_logit_beta` reduces over and drops the leading
feature dimension.
"""
function forward_beta(m::KGReasoning, positive_sample, negative_sample, subsampling_weight,
                      batch_queries_dict, batch_idxs_dict)
    all_idxs, all_alpha_embeddings, all_beta_embeddings = Int[], [], []
    all_union_idxs, all_union_alpha_embeddings, all_union_beta_embeddings = Int[], [], []
    for query_structure in keys(batch_queries_dict)
        query_name = m.query_name_dict[query_structure]
        queries = batch_queries_dict[query_structure]
        if occursin("u", query_name) && occursin("DNF", query_name)
            alpha_embedding, beta_embedding, _ = embed_query_beta(
                m,
                transform_union_query(m, queries, query_structure),
                transform_union_structure(m, query_structure),
                1)
            append!(all_union_idxs, batch_idxs_dict[query_structure])
            push!(all_union_alpha_embeddings, alpha_embedding)
            push!(all_union_beta_embeddings, beta_embedding)
        else
            alpha_embedding, beta_embedding, _ = embed_query_beta(m, queries, query_structure, 1)
            append!(all_idxs, batch_idxs_dict[query_structure])
            push!(all_alpha_embeddings, alpha_embedding)
            push!(all_beta_embeddings, beta_embedding)
        end
    end

    has_regular = !isempty(all_alpha_embeddings)
    has_union = !isempty(all_union_alpha_embeddings)

    # Concatenate per-structure embeddings along their last (batch) dimension.
    batchcat(parts) = reduce((x, y) -> cat(x, y; dims=ndims(x)), parts)

    all_dists = nothing
    if has_regular
        alpha = batchcat(all_alpha_embeddings)
        beta = batchcat(all_beta_embeddings)
        # (dim, batch) -> (dim, 1, batch): singleton axis for the sample dimension.
        all_dists = Distributions.Beta.(unsqueeze(alpha; dims=ndims(alpha)),
                                        unsqueeze(beta; dims=ndims(beta)))
    end

    all_union_dists = nothing
    if has_union
        alpha = batchcat(all_union_alpha_embeddings)
        beta = batchcat(all_union_beta_embeddings)
        # (dim, 2 * batch) -> (dim, 1, 2, batch): the two branches of a union
        # query are adjacent columns and end up on the third axis.
        num_union = size(alpha, ndims(alpha)) ÷ 2
        all_union_dists = Distributions.Beta.(reshape(alpha, :, 1, 2, num_union),
                                              reshape(beta, :, 1, 2, num_union))
    end

    idxs = vcat(all_idxs, all_union_idxs)
    if !isnothing(subsampling_weight)
        subsampling_weight = subsampling_weight[idxs]
    end

    positive_logit = nothing
    if !isnothing(positive_sample)
        parts = []
        if has_regular
            # Positive samples for non-union queries in this batch.
            entity_ids = Int.(positive_sample[all_idxs])
            positive_embedding = m.entity_regularizer(
                unsqueeze(m.entity_embedding[:, entity_ids]; dims=2))
            push!(parts, cal_logit_beta(m, positive_embedding, all_dists))
        end
        if has_union
            # Positive samples for union queries in this batch.
            entity_ids = Int.(positive_sample[all_union_idxs])
            positive_embedding = m.entity_regularizer(
                reshape(m.entity_embedding[:, entity_ids], :, 1, 1, length(entity_ids)))
            union_logit = cal_logit_beta(m, positive_embedding, all_union_dists)
            # Keep the better of the two union branches.
            push!(parts, dropdims(maximum(union_logit; dims=2); dims=2))
        end
        positive_logit = isempty(parts) ? Float32[] : reduce(hcat, parts)
    end

    negative_logit = nothing
    if !isnothing(negative_sample)
        parts = []
        if has_regular
            negative_ids = Int.(negative_sample[:, all_idxs])
            negative_size, batch_size = size(negative_ids)
            negative_embedding = m.entity_regularizer(
                reshape(m.entity_embedding[:, vec(negative_ids)], :, negative_size, batch_size))
            push!(parts, cal_logit_beta(m, negative_embedding, all_dists))
        end
        if has_union
            negative_ids = Int.(negative_sample[:, all_union_idxs])
            negative_size, batch_size = size(negative_ids)
            negative_embedding = m.entity_regularizer(
                reshape(m.entity_embedding[:, vec(negative_ids)], :, negative_size, 1, batch_size))
            union_logit = cal_logit_beta(m, negative_embedding, all_union_dists)
            push!(parts, dropdims(maximum(union_logit; dims=2); dims=2))
        end
        negative_logit = isempty(parts) ? Float32[] : reduce(hcat, parts)
    end

    return positive_logit, negative_logit, subsampling_weight, idxs
end

"""
    cal_logit_box(m, entity_embedding, query_center_embedding, query_offset_embedding)

Box logit: `m.gamma` minus the 1-norm of the distance outside the box minus
`m.cen` times the 1-norm of the distance inside the box. The norms are taken
over the feature (first) dimension, which is dropped from the result.
NOTE(review): assumes feature-first embeddings that broadcast against each
other — confirm with callers.
"""
function cal_logit_box(m::KGReasoning, entity_embedding, query_center_embedding, query_offset_embedding)
    delta = abs.(entity_embedding .- query_center_embedding)
    distance_out = Flux.relu.(delta .- query_offset_embedding)
    distance_in = min.(delta, query_offset_embedding)
    norm_out = dropdims(normDims(distance_out, 1; dim=1); dims=1)
    norm_in = dropdims(normDims(distance_in, 1; dim=1); dims=1)
    logit = m.gamma .- norm_out .- m.cen .* norm_in
    return logit
end

"""
    forward_box(m, positive_sample, negative_sample, subsampling_weight,
                batch_queries_dict, batch_idxs_dict)

Embed a batch of queries as boxes and score positive / negative entities against them.

Queries are grouped by structure in `batch_queries_dict`; `batch_idxs_dict` holds, for each
structure, the positions of those queries in the original batch. Structures whose task name
contains `"u"` (union) are embedded through `transform_union_query` /
`transform_union_structure` and scored as the maximum over their two branches.

Returns `(positive_logit, negative_logit, subsampling_weight, idxs)`. `idxs` lists the batch
positions in the order used for the logits (regular queries first, then union queries);
`subsampling_weight` is re-ordered the same way. A logit is `nothing` when the matching
sample argument is `nothing`.

Assumed layout (TODO confirm against the data loader and `embed_query_box`):
embeddings are `(dim, batch)` with the batch on the last axis, `m.entity_embedding` is
indexed by entity on its last axis, and `negative_sample` is `(batch, negative)`.
"""
function forward_box(m::KGReasoning, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict)
    all_center_embeddings, all_offset_embeddings, all_idxs = [], [], Int[]
    all_union_center_embeddings, all_union_offset_embeddings, all_union_idxs = [], [], Int[]
    # Iterate over keys: iterating a Dict directly yields `key => value` pairs.
    for query_structure in keys(batch_queries_dict)
        # Substring test; `in` on a String only accepts characters.
        if occursin("u", m.query_name_dict[query_structure])
            # NOTE(review): `forward_vec` calls `transform_union_structure` without `m`;
            # confirm which signature the helper actually has.
            center_embedding, offset_embedding, _ = embed_query_box(
                m,
                transform_union_query(m, batch_queries_dict[query_structure], query_structure),
                transform_union_structure(m, query_structure),
                0,
            )
            push!(all_union_center_embeddings, center_embedding)
            push!(all_union_offset_embeddings, offset_embedding)
            # `append!` flattens the indices (Python `extend`), `push!` would nest them.
            append!(all_union_idxs, batch_idxs_dict[query_structure])
        else
            center_embedding, offset_embedding, _ = embed_query_box(
                m, batch_queries_dict[query_structure], query_structure, 0
            )
            push!(all_center_embeddings, center_embedding)
            push!(all_offset_embeddings, offset_embedding)
            append!(all_idxs, batch_idxs_dict[query_structure])
        end
    end

    has_regular = !isempty(all_center_embeddings) && !isempty(all_offset_embeddings)
    has_union = !isempty(all_union_center_embeddings) && !isempty(all_union_offset_embeddings)

    # Concatenate along the last (batch) axis.
    cat_last(xs) = reduce((x, y) -> cat(x, y; dims=ndims(x)), xs)
    lastsize(x) = size(x, ndims(x))
    # Entities are indexed along the last axis of the embedding table.
    select_entities(idxs) = selectdim(m.entity_embedding, ndims(m.entity_embedding), idxs)
    # Maximum over the second-to-last axis (the two union branches), dropping that axis.
    function max_over_branches(x)
        d = ndims(x) - 1
        return dropdims(maximum(x; dims=d); dims=d)
    end
    # An empty array stands in when there is nothing to score; empty placeholders cannot
    # be concatenated with real logits in Julia, so only existing parts are joined.
    join_logits(parts) = isempty(parts) ? similar(m.entity_embedding, 0) : cat_last(parts)
    # Rows `idxs` of `samples`, flattened so that negatives vary fastest (column-major),
    # which is what the `(dim, negative, batch)` reshapes below rely on.
    function negative_entities(samples, idxs)
        selected = samples[idxs, :]
        batch_size, negative_size = size(selected)
        return select_entities(vec(permutedims(selected))), negative_size, batch_size
    end

    if has_regular
        centers = cat_last(all_center_embeddings)
        offsets = cat_last(all_offset_embeddings)
        # Insert a singleton axis before the batch axis: (dim, batch) -> (dim, 1, batch).
        regular_centers = unsqueeze(centers; dims=ndims(centers))
        regular_offsets = unsqueeze(offsets; dims=ndims(offsets))
    end

    if has_union
        union_centers_cat = cat_last(all_union_center_embeddings)
        union_offsets_cat = cat_last(all_union_offset_embeddings)
        # Consecutive columns are the two branches of one union query:
        # (dim, 2 * batch) -> (dim, 1, 2, batch).
        union_centers = reshape(union_centers_cat, :, 1, 2, div(lastsize(union_centers_cat), 2))
        union_offsets = reshape(union_offsets_cat, :, 1, 2, div(lastsize(union_offsets_cat), 2))
    end

    # Vector concatenation; `+` would add the index vectors element-wise.
    all_query_idxs = vcat(all_idxs, all_union_idxs)

    if !isnothing(subsampling_weight)
        subsampling_weight = subsampling_weight[all_query_idxs]
    end

    if !isnothing(positive_sample)
        positive_parts = []
        if has_regular
            selected = select_entities(positive_sample[all_idxs])
            positive_embedding = unsqueeze(selected; dims=ndims(selected))
            push!(positive_parts, cal_logit_box(m, positive_embedding, regular_centers, regular_offsets))
        end
        if has_union
            selected = select_entities(positive_sample[all_union_idxs])
            # (dim, batch) -> (dim, 1, 1, batch) so it broadcasts against both branches.
            positive_embedding = reshape(selected, :, 1, 1, lastsize(selected))
            positive_union_logit = cal_logit_box(m, positive_embedding, union_centers, union_offsets)
            push!(positive_parts, max_over_branches(positive_union_logit))
        end
        positive_logit = join_logits(positive_parts)
    else
        positive_logit = nothing
    end

    if !isnothing(negative_sample)
        negative_parts = []
        if has_regular
            selected, negative_size, batch_size = negative_entities(negative_sample, all_idxs)
            negative_embedding = reshape(selected, :, negative_size, batch_size)
            push!(negative_parts, cal_logit_box(m, negative_embedding, regular_centers, regular_offsets))
        end
        if has_union
            selected, negative_size, batch_size = negative_entities(negative_sample, all_union_idxs)
            negative_embedding = reshape(selected, :, negative_size, 1, batch_size)
            negative_union_logit = cal_logit_box(m, negative_embedding, union_centers, union_offsets)
            push!(negative_parts, max_over_branches(negative_union_logit))
        end
        negative_logit = join_logits(negative_parts)
    else
        negative_logit = nothing
    end

    return positive_logit, negative_logit, subsampling_weight, all_query_idxs
end

"""
    cal_logit_vec(m, entity_embedding, query_embedding)

Score entities against vector query embeddings: `m.gamma` minus the norm of the
element-wise difference between entity and query embedding.

The subtraction broadcasts, so the two arrays may differ in singleton dimensions.
"""
function cal_logit_vec(m::KGReasoning, entity_embedding, query_embedding)
    # Broadcast: callers pass arrays whose shapes differ in singleton dimensions.
    distance = entity_embedding .- query_embedding
    # NOTE(review): `normDims` is defined outside this block, so the call is kept exactly
    # as written. `cal_logit_box` passes the keyword as `dims` rather than `dim` -- confirm
    # which one `normDims` accepts and which axis holds the embedding dimension.
    logit = m.gamma .- normDims(distance, 1, dim=2)
    return logit
end

"""
    forward_vec(m, positive_sample, negative_sample, subsampling_weight,
                batch_queries_dict, batch_idxs_dict)

Embed a batch of queries as vectors and score positive / negative entities against them.

Queries are grouped by structure in `batch_queries_dict`; `batch_idxs_dict` holds, for each
structure, the positions of those queries in the original batch. Structures whose task name
contains `"u"` (union) are embedded through `transform_union_query` /
`transform_union_structure` and scored as the maximum over their two branches.

Returns `(positive_logit, negative_logit, subsampling_weight, idxs)`. `idxs` lists the batch
positions in the order used for the logits (regular queries first, then union queries);
`subsampling_weight` is re-ordered the same way. A logit is `nothing` when the matching
sample argument is `nothing`.

Assumed layout (TODO confirm against the data loader and `embed_query_vec`):
embeddings are `(dim, batch)` with the batch on the last axis, `m.entity_embedding` is
indexed by entity on its last axis, and `negative_sample` is `(batch, negative)`.
"""
function forward_vec(m::KGReasoning, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict)
    all_center_embeddings, all_idxs = [], Int[]
    all_union_center_embeddings, all_union_idxs = [], Int[]
    # Iterate over keys: iterating a Dict directly yields `key => value` pairs.
    for query_structure in keys(batch_queries_dict)
        # Substring test; `in` on a String only accepts characters.
        if occursin("u", m.query_name_dict[query_structure])
            # NOTE(review): `forward_box` calls `transform_union_structure` with `m` as
            # first argument; confirm which signature the helper actually has.
            center_embedding, _ = embed_query_vec(
                m,
                transform_union_query(m, batch_queries_dict[query_structure], query_structure),
                transform_union_structure(query_structure),
                0,
            )
            push!(all_union_center_embeddings, center_embedding)
            append!(all_union_idxs, batch_idxs_dict[query_structure])
        else
            center_embedding, _ = embed_query_vec(m, batch_queries_dict[query_structure], query_structure, 0)
            push!(all_center_embeddings, center_embedding)
            append!(all_idxs, batch_idxs_dict[query_structure])
        end
    end

    has_regular = !isempty(all_center_embeddings)
    has_union = !isempty(all_union_center_embeddings)

    # Concatenate along the last (batch) axis.
    cat_last(xs) = reduce((x, y) -> cat(x, y; dims=ndims(x)), xs)
    lastsize(x) = size(x, ndims(x))
    # Entities are indexed along the last axis of the embedding table.
    select_entities(idxs) = selectdim(m.entity_embedding, ndims(m.entity_embedding), idxs)
    # Maximum over the second-to-last axis (the two union branches), dropping that axis.
    function max_over_branches(x)
        d = ndims(x) - 1
        return dropdims(maximum(x; dims=d); dims=d)
    end
    # An empty array stands in when there is nothing to score; empty placeholders cannot
    # be concatenated with real logits in Julia, so only existing parts are joined.
    join_logits(parts) = isempty(parts) ? similar(m.entity_embedding, 0) : cat_last(parts)
    # Rows `idxs` of `samples`, flattened so that negatives vary fastest (column-major),
    # which is what the `(dim, negative, batch)` reshapes below rely on.
    function negative_entities(samples, idxs)
        selected = samples[idxs, :]
        batch_size, negative_size = size(selected)
        return select_entities(vec(permutedims(selected))), negative_size, batch_size
    end

    if has_regular
        centers = cat_last(all_center_embeddings)
        # Insert a singleton axis before the batch axis: (dim, batch) -> (dim, 1, batch).
        regular_centers = unsqueeze(centers; dims=ndims(centers))
    end

    if has_union
        union_centers_cat = cat_last(all_union_center_embeddings)
        # Consecutive columns are the two branches of one union query:
        # (dim, 2 * batch) -> (dim, 1, 2, batch).
        union_centers = reshape(union_centers_cat, :, 1, 2, div(lastsize(union_centers_cat), 2))
    end

    # Vector concatenation; `+` would add the index vectors element-wise.
    all_query_idxs = vcat(all_idxs, all_union_idxs)

    if !isnothing(subsampling_weight)
        subsampling_weight = subsampling_weight[all_query_idxs]
    end

    if !isnothing(positive_sample)
        positive_parts = []
        if has_regular
            selected = select_entities(positive_sample[all_idxs])
            positive_embedding = unsqueeze(selected; dims=ndims(selected))
            push!(positive_parts, cal_logit_vec(m, positive_embedding, regular_centers))
        end
        if has_union
            # Union queries use their own positives (not the regular ones).
            selected = select_entities(positive_sample[all_union_idxs])
            # (dim, batch) -> (dim, 1, 1, batch) so it broadcasts against both branches.
            positive_embedding = reshape(selected, :, 1, 1, lastsize(selected))
            positive_union_logit = cal_logit_vec(m, positive_embedding, union_centers)
            push!(positive_parts, max_over_branches(positive_union_logit))
        end
        positive_logit = join_logits(positive_parts)
    else
        positive_logit = nothing
    end

    if !isnothing(negative_sample)
        negative_parts = []
        if has_regular
            selected, negative_size, batch_size = negative_entities(negative_sample, all_idxs)
            negative_embedding = reshape(selected, :, negative_size, batch_size)
            push!(negative_parts, cal_logit_vec(m, negative_embedding, regular_centers))
        end
        if has_union
            selected, negative_size, batch_size = negative_entities(negative_sample, all_union_idxs)
            negative_embedding = reshape(selected, :, negative_size, 1, batch_size)
            negative_union_logit = cal_logit_vec(m, negative_embedding, union_centers)
            push!(negative_parts, max_over_branches(negative_union_logit))
        end
        negative_logit = join_logits(negative_parts)
    else
        negative_logit = nothing
    end

    return positive_logit, negative_logit, subsampling_weight, all_query_idxs
end
#=================================================================================
function mean_loss(y_bar)
    negative_logsigmoid = Flux.logsigmoid(-y_bar[:negative_logit])
    negative_score = mean.(negative_logsigmoid, dims=ndims(negative_logsigmoid))
    positive_logsigmoid =Flux.logsigmoid(-y_bar[:positive_logit])
    positive_score = squeeze(positive_logsigmoid, dim=ndims(positive_logsigmoid))
    positive_sample_loss = - sum(y_bar[:subsampling_weight] * positive_score)
    negative_sample_loss = - sum(y_bar[:subsampling_weight] * negative_score)
    positive_sample_loss /= sum(y_bar[:subsampling_weight])
    negative_sample_loss /= sum(y_bar[:subsampling_weight])

    loss = (positive_sample_loss + negative_sample_loss)/2
end
===============================================================================#

"""
    loss()

Placeholder loss; always evaluates to the integer `0`.
"""
loss() = 0

# @staticmethod
"""
    train_step(model, opt_state, data, args, step)

Run one optimisation step and return a `Dict` with the keys `"positive_sample_loss"`,
`"negative_sample_loss"` and `"loss"`.

One batch `(positive_sample, negative_sample, subsampling_weight, batch_queries,
query_structures)` is taken from `data`. Queries are grouped by structure, the model is
evaluated, and the weighted negative-sampling loss is differentiated with respect to the
model. The gradient is applied through `opt_state` (as produced by `Flux.setup`).

`args["cuda"]` selects whether the batch is moved to the GPU. `step` is accepted for
interface compatibility and is not used.
"""
function train_step(model::KGReasoning, opt_state, data, args, step)
    # NOTE(review): assumes `data` is an iterator of batches and that `first` yields the
    # next one (true for a stateful iterator) -- confirm against the caller.
    positive_sample, negative_sample, subsampling_weight, batch_queries, query_structures = first(data)
    print("train_step chunk :$(batch_queries)\n state: $(query_structures)")

    # Group queries with the same structure, remembering their batch positions.
    batch_queries_dict = Dict{Any, Any}()
    batch_idxs_dict = Dict{Any, Any}()
    for (i, query) in enumerate(batch_queries)
        push!(get!(batch_queries_dict, query_structures[i], []), query)
        push!(get!(batch_idxs_dict, query_structures[i], []), i)
    end

    # Iterate over a snapshot of the keys; iterating the Dict itself yields pairs.
    for query_structure in collect(keys(batch_queries_dict))
        queries = Int64.(batch_queries_dict[query_structure])
        # `|> gpu` moves the whole array; `.|> gpu` would be applied per element.
        batch_queries_dict[query_structure] = args["cuda"] ? gpu(queries) : queries
    end

    if args["cuda"]
        positive_sample = positive_sample |> gpu
        negative_sample = negative_sample |> gpu
        subsampling_weight = subsampling_weight |> gpu
    end

    # The closure must use its argument `m`: differentiating a captured `model` yields no
    # gradient. The first field of the returned NamedTuple is the differentiated loss; the
    # other fields are carried along for logging.
    losses, grads = Flux.withgradient(model) do m
        positive_logit, negative_logit, weight, _ = m(
            positive_sample, negative_sample, subsampling_weight,
            batch_queries_dict, batch_idxs_dict,
        )
        weight = vec(weight)
        # Negatives are scored with the negated logit and averaged per query.
        # NOTE(review): assumes the negative-sample axis is the second-to-last one, as in
        # `test_step`'s sort -- confirm.
        negative_logsigmoid = Flux.logsigmoid.(-negative_logit)
        negative_score = vec(mean(negative_logsigmoid; dims=ndims(negative_logsigmoid) - 1))
        positive_score = vec(Flux.logsigmoid.(positive_logit))
        weight_total = sum(weight)
        positive_sample_loss = -sum(weight .* positive_score) / weight_total
        negative_sample_loss = -sum(weight .* negative_score) / weight_total
        (
            loss = (positive_sample_loss + negative_sample_loss) / 2,
            positive_sample_loss = positive_sample_loss,
            negative_sample_loss = negative_sample_loss,
        )
    end
    Flux.update!(opt_state, model, grads[1])

    return Dict(
        "positive_sample_loss" => losses.positive_sample_loss,
        "negative_sample_loss" => losses.negative_sample_loss,
        "loss" => losses.loss,
    )
end

#@staticmethod
"""
    test_step(model, easy_answers, hard_answers, args, test_dataloader, query_name_dict,
              save_result=false, save_str="", save_empty=false)

Evaluate `model` on every batch of `test_dataloader` and return, per query structure, a
`Dict` with the averaged `"MRR"`, `"HITS1"`, `"HITS3"`, `"HITS10"` and the number of
evaluated queries under `"num_queries"`.

For each query all entities are ranked by descending logit. Ranks of the hard answers are
computed in the filtered setting: every other answer (easy or hard) that ranks above a
hard answer is discounted from its rank.

`args["cuda"]` moves the inputs to the GPU; progress is logged every
`args["test_log_steps"]` batches. `query_name_dict`, `save_result`, `save_str` and
`save_empty` are accepted for interface compatibility and are not used.
"""
function test_step(model, easy_answers, hard_answers, args, test_dataloader, query_name_dict, save_result=false, save_str="", save_empty=false)
    step = 0
    total_steps = length(test_dataloader)
    logs = Dict{Any, Vector{Dict{String, Any}}}()

    for (negative_sample, queries, queries_unflatten, query_structures) in tqdm(test_dataloader)
        # Group queries with the same structure, remembering their batch positions.
        batch_queries_dict = Dict{Any, Any}()
        batch_idxs_dict = Dict{Any, Any}()
        for (i, query) in enumerate(queries)
            push!(get!(batch_queries_dict, query_structures[i], []), query)
            push!(get!(batch_idxs_dict, query_structures[i], []), i)
        end

        # Iterate over a snapshot of the keys; iterating the Dict itself yields pairs.
        for query_structure in collect(keys(batch_queries_dict))
            grouped = Int64.(batch_queries_dict[query_structure])
            batch_queries_dict[query_structure] = args["cuda"] ? gpu(grouped) : grouped
        end

        if args["cuda"]
            negative_sample = negative_sample |> gpu
        end

        _, negative_logit, _, idxs = model(nothing, negative_sample, nothing, batch_queries_dict, batch_idxs_dict)
        # The model returns logits ordered by `idxs`; re-order the bookkeeping to match.
        queries_unflatten = [queries_unflatten[i] for i in idxs]
        query_structures = [query_structures[i] for i in idxs]

        # Ranking is done on the host; `Array` copies device arrays back.
        # NOTE(review): assumes queries lie along the last axis of `negative_logit` and
        # entities along the one before it -- confirm against the model's forward pass.
        scores = Array(negative_logit)

        for (idx, (query, query_structure)) in enumerate(zip(queries_unflatten, query_structures))
            hard_answer = hard_answers[query]
            easy_answer = easy_answers[query]
            num_hard = length(hard_answer)
            num_easy = length(easy_answer)
            @assert isempty(intersect(hard_answer, easy_answer))

            query_scores = vec(selectdim(scores, ndims(scores), idx))
            # ranking[e] is the 1-based rank of entity e (1 = highest logit).
            ranking = invperm(sortperm(query_scores; rev=true))

            # Easy answers first, then hard ones, so positions > num_easy are hard.
            # NOTE(review): assumes answer ids are valid 1-based entity indices -- confirm.
            answers = Int[a for a in Iterators.flatten((easy_answer, hard_answer))]
            cur_ranking = ranking[answers]
            order = sortperm(cur_ranking)
            sorted_ranking = cur_ranking[order]
            is_hard = order .> num_easy
            # Filtered setting: the k-th best answer has k-1 other answers above it.
            filtered_ranking = sorted_ranking .- (0:(num_hard + num_easy - 1))
            hard_ranking = filtered_ranking[is_hard]

            push!(get!(() -> Dict{String, Any}[], logs, query_structure), Dict{String, Any}(
                "MRR" => mean(1 ./ hard_ranking),
                "HITS1" => mean(hard_ranking .<= 1),
                "HITS3" => mean(hard_ranking .<= 3),
                "HITS10" => mean(hard_ranking .<= 10),
                "num_hard_answer" => num_hard,
            ))
        end

        if step % args["test_log_steps"] == 0
            @info("Evaluating the model... ($step/$total_steps)")
        end
        step += 1
    end

    # Average every metric over the queries of each structure.
    metrics = Dict{Any, Dict{String, Any}}()
    for (query_structure, entries) in logs
        structure_metrics = Dict{String, Any}()
        for metric in keys(first(entries))
            metric == "num_hard_answer" && continue
            structure_metrics[metric] = sum(entry[metric] for entry in entries) / length(entries)
        end
        structure_metrics["num_queries"] = length(entries)
        metrics[query_structure] = structure_metrics
    end

    return metrics
end

end #end module


In [None]:
using Revise

using MLUtils
using Random
using ArgParse
using LoggingExtras, TensorBoardLogger
using Dates
using Flux, JLD2

#println("working directory: {$(pwd())}")

include("src/dataloader.jl")
include("src/model.jl")
include("src/utils.jl")

using .KGDataset
using .KGModel

# Dataset directory and dataset name.
f_dir, f_model = "dataset", "FB15k-betae";

# Nested query-structure tuple => short task name.
query_name_dict = Dict{Tuple, String}(
    ("e", ("r",)) => "1p",
    ("e", ("r", "r")) => "2p",
    ("e", ("r", "r", "r")) => "3p",
    (("e", ("r",)), ("e", ("r",))) => "2i",
    (("e", ("r",)), ("e", ("r",)), ("e", ("r",))) => "3i",
    ((("e", ("r",)), ("e", ("r",))), ("r",)) => "ip",
    (("e", ("r", "r")), ("e", ("r",))) => "pi",
    (("e", ("r",)), ("e", ("r", "n"))) => "2in",
    (("e", ("r",)), ("e", ("r",)), ("e", ("r", "n"))) => "3in",
    ((("e", ("r",)), ("e", ("r", "n"))), ("r",)) => "inp",
    (("e", ("r", "r")), ("e", ("r", "n"))) => "pin",
    (("e", ("r", "r", "n")), ("e", ("r",))) => "pni",
    (("e", ("r",)), ("e", ("r",)), ("u",)) => "2u-DNF",
    ((("e", ("r",)), ("e", ("r",)), ("u",)), ("r",)) => "up-DNF",
    ((("e", ("r", "n")), ("e", ("r", "n"))), ("n",)) => "2u-DM",
    ((("e", ("r", "n")), ("e", ("r", "n"))), ("n", "r")) => "up-DM",
);

# Reverse lookup: task name => query structure.
name_query_dict = Dict{String, Tuple}(
    task => structure for (structure, task) in query_name_dict
);

# Every known task name.
all_tasks = [task for task in keys(name_query_dict)];

"""
    parse_cmdargs(args::Vector{String})

Parse the command-line arguments `args` and return the result of `parse_args`: a
dictionary keyed by option name (for example `"cuda"`, `"data_path"`, `"geo"`).

Flags (`--cuda`, `--train`, `--valid`, `--test`) default to `false`. `--print_on_screen`
is a `store_false` flag, so it defaults to `true` and passing it turns it off.
"""
function parse_cmdargs(args::Vector{String})
    s = ArgParseSettings(
        description = "Training and Testing Knowledge Graph Embedding Models",
        usage = "julia --project=[/path/to/project] src/$(@__FILE__) [<args>] [-h | --help]"
    )

    # `@add_arg_table!` is the current name of the macro; the variant without `!` is
    # deprecated.
    @add_arg_table! s begin
        "--cuda"
            action = :store_true
            help = "use GPU"
        "--train"
            action = :store_true
            help = "do train"
        "--valid"
            action = :store_true
            help = "do valid"
        "--test"
            action = :store_true
            help = "do test"
        "--data_path"
            arg_type = String
            default = nothing
            help = "KG data path"
        "-n", "--negative_sample_size"
            default = 128
            arg_type = Int
            help = "negative entities sampled per query"
        "-d", "--hidden_dim"
            default = 500
            arg_type = Int
            help = "embedding dimension"
        "-g", "--gamma"
            default = 12.0
            arg_type = Float64
            help = "margin in the loss"
        "-b", "--batch_size"
            default = 1024
            arg_type = Int
            help = "batch size of queries"
        "--test_batch_size"
            default = 1
            arg_type = Int
            help = "valid/test batch size"
        "--learning_rate"
            default = 0.0001
            arg_type = Float64
        "--cpu"
            default = 10
            arg_type = Int
            help = "used to speed up torch.dataloader"
        "--save_path"
            default = "."
            arg_type = String
            help = "no need to set manually, will configure automatically"
        "--max_steps"
            default = 100000
            arg_type = Int
            help = "maximum iterations to train"
        "--warm_up_steps"
            default = nothing
            arg_type = Int
            help = "no need to set manually, will configure automatically"
        "--save_checkpoint_steps"
            default = 50000
            arg_type = Int
            help = "save checkpoints every xx steps"
        "--valid_steps"
            default = 10000
            arg_type = Int
            help = "evaluate validation queries every xx steps"
        "--log_steps"
            default = 100
            arg_type = Int
            help = "train log every xx steps"
        "--test_log_steps"
            default = 1000
            arg_type = Int
            help = "valid/test log every xx steps"
        "--nentity"
            arg_type = Int
            default = 0
            help = "DO NOT MANUALLY SET"
        "--nrelation"
            arg_type = Int
            default = 0
            help = "DO NOT MANUALLY SET"
        "--geo"
            default = "vec"
            arg_type = String
            help = "the reasoning model, vec for GQE, box for Query2box, beta for BetaE"
        "--print_on_screen"
            action = :store_false
        "--tasks"
            default = "1p.2p.3p.2i.3i.ip.pi.2in.3in.inp.pin.pni.2u.up"
            arg_type = String
            help = "tasks connected by dot, refer to the BetaE paper for detailed meaning and structure of each task"
        "--seed"
            default = 0
            arg_type = Int
            help = "random seed"
        "--beta_mode"
            default = "(1600,2)"
            arg_type = String
            help = "(hidden_dim,num_layer) for BetaE relational projection"
        "--box_mode"
            default = "(nothing,0.02)"
            arg_type = String
            help = "(offset activation,center_reg) for Query2box, center_reg balances the in_box dist and out_box dist"
        "--prefix"
            default = nothing
            arg_type = String
            help = "prefix of the log path"
        "--checkpoint_path"
            default = nothing
            arg_type = String
            help = "path for loading the checkpoints"
        "--evaluate_union"
            default = "DNF"
            arg_type = String
            help = "the way to evaluate union queries, transform it to disjunctive normal form (DNF) or use the De Morgan's laws (DM)"
    end

    return parse_args(args, s)
end

"""
    set_logger(args)

Install the global logger: every record is prefixed with a timestamp and written to
`train.log` (when `args["train"]` is `true`) or `test.log` inside `args["save_path"]`.
When `args["print_on_screen"]` is set, records of level `Info` and above are also echoed
to `stdout`.
"""
function set_logger(args)
    log_name = args["train"] == true ? "train.log" : "test.log"
    log_file = joinpath(args["save_path"], log_name)

    datefmt = DateFormat("YY-mm-dd HH:MM:SS")

    # Wrap a logger so that each message is prefixed with the current time.
    timestamp_logger(logger) = TransformerLogger(logger) do log
        merge(log, (; message = "$(Dates.format(now(), datefmt)) $(log.message)"))
    end

    # `FileLogger` opens the log file itself; no separate handle is opened here, so none
    # is left dangling.
    file_logger = timestamp_logger(FileLogger(log_file))

    if args["print_on_screen"]
        console_logger = timestamp_logger(ConsoleLogger(stdout, Logging.Info))
        global_logger(TeeLogger(file_logger, console_logger))
    else
        global_logger(file_logger)
    end
    return nothing
end

"""
    log_metrics(mode, step, metrics)

Log every entry of `metrics` (a name => value collection) at `Info` level, one record per
metric, in the form `"<mode> <name> at step <step>: <value>"`.
"""
function log_metrics(mode, step, metrics)
    # Destructure each pair so the message shows the metric name, not `name => value`.
    for (name, value) in pairs(metrics)
        @info "$mode $name at step $(step): $value"
    end
end

"""
    evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, mode, step, writer)

Evaluate the queries in `dataloader` and report the metrics.

The per-structure metrics returned by `test_step` are logged through `log_metrics` and
written to `writer` under `"<mode>_<task>_<metric>"`. Every metric except `"num_queries"`
is also averaged over the query structures and reported under
`"<mode>_average_<metric>"`.

Returns a `Dict` mapping `"<task>_<metric>"` and `"average_<metric>"` to their values.
"""
function evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, mode, step, writer)
    average_metrics = Dict{String, Float64}()
    all_metrics = Dict{String, Float64}()

    # NOTE(review): assumes `test_step` is defined in the `KGModel` module loaded by
    # `using .KGModel` -- confirm the module name.
    metrics = KGModel.test_step(model, tp_answers, fn_answers, args, dataloader, query_name_dict)
    num_query_structures = 0
    num_queries = 0
    for (query_structure, structure_metrics) in metrics
        task = query_name_dict[query_structure]
        log_metrics(mode * " " * task, step, structure_metrics)

        for (metric, value) in structure_metrics
            # NOTE(review): assumes `writer` is a TensorBoardLogger `TBLogger` -- confirm.
            log_value(writer, join([mode, task, metric], "_"), value; step=step)
            all_metrics[join([task, metric], "_")] = value
            if metric != "num_queries"
                # Start the running sum at zero when the metric is seen for the first time.
                average_metrics[metric] = get(average_metrics, metric, 0.0) + value
            end
        end
        num_queries += structure_metrics["num_queries"]
        num_query_structures += 1
    end

    # Snapshot the keys because the values are rewritten inside the loop.
    for metric in collect(keys(average_metrics))
        average_metrics[metric] /= num_query_structures
        log_value(writer, join([mode, "average", metric], "_"), average_metrics[metric]; step=step)
        all_metrics[join(["average", metric], "_")] = average_metrics[metric]
    end

    log_metrics("$mode average", step, average_metrics)
    return all_metrics
end

# Write each `name => value` entry of `values` to TensorBoard as the scalar `prefix * name`.
function _log_scalars(writer, prefix, values, step)
    for (name, value) in pairs(values)
        TensorBoardLogger.log_value(writer, prefix * string(name), value; step = step)
    end
    return nothing
end

# Store the model weights, optimiser state and training progress in
# `<save_path>/checkpoint.jld2`. Requires JLD2 to be loaded by the file.
function _save_checkpoint(model, opt_state, save_variable_list, args)
    JLD2.jldsave(joinpath(args["save_path"], "checkpoint.jld2");
                 model_state = Flux.state(Flux.cpu(model)),
                 opt_state = Flux.cpu(opt_state),
                 step = save_variable_list["step"],
                 current_learning_rate = save_variable_list["current_learning_rate"],
                 warm_up_steps = save_variable_list["warm_up_steps"])
    return nothing
end

"""
    main(args)

Run a knowledge-graph reasoning experiment described by the option dictionary `args`.

The function validates the task/geometry combination, derives the run directory
(stored back into `args["save_path"]`), installs the loggers, reads the entity and
relation counts from `stats.txt` in `args["data_path"]`, builds the data iterators
and the `KGModel.KGReasoning` model, optionally restores a checkpoint, and, when
`args["train"]` is set, runs the training loop with periodic TensorBoard logging,
checkpointing and validation/test evaluation.

The query/answer collections are read from the globals `train_queries`,
`train_answers`, `valid_queries`, `valid_hard_answers`, `valid_easy_answers`,
`test_queries`, `test_hard_answers` and `test_easy_answers`, which must be
populated before calling `main`.

# Arguments
- `args`: mutable dictionary of options. `main` updates `"save_path"`, `"nentity"`,
  `"nrelation"` and, late in training, `"valid_steps"`.

# Throws
- `AssertionError` when negation queries are requested for the `"box"` or `"vec"`
  geometry, or when De Morgan union evaluation is requested for a geometry other
  than `"beta"`.
- `ArgumentError` when `args["geo"]` is not `"box"`, `"vec"` or `"beta"`.
"""
function main(args)
    global train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, test_queries, test_hard_answers, test_easy_answers

    Random.seed!(args["seed"])
    geo = args["geo"]
    tasks = split(args["tasks"], ".")
    # Thrown explicitly (rather than via @assert) so the checks can never be compiled out.
    if geo in ("box", "vec") && any(task -> 'n' in task, tasks)
        throw(AssertionError("Q2B and GQE cannot handle queries with negation"))
    end
    if args["evaluate_union"] == "DM" && geo != "beta"
        throw(AssertionError("only BetaE supports modeling union using De Morgan's Laws"))
    end

    cur_time = format_time()
    prefix = something(args["prefix"], "logs")

    @info ("overwritting saving path: $(args["save_path"])")
    if geo == "box"
        save_str = "g-$(args["gamma"])-mode-$(args["box_mode"])"
    elseif geo == "vec"
        save_str = "g-$(args["gamma"])"
    elseif geo == "beta"
        save_str = "g-$(args["gamma"])-mode-$(args["beta_mode"])"
    else
        throw(ArgumentError("unsupported geo \"$geo\"; expected \"box\", \"vec\" or \"beta\""))
    end

    if args["checkpoint_path"] !== nothing
        args["save_path"] = args["checkpoint_path"]
    else
        # Strip a trailing separator so the dataset name is not lost for "path/to/data/".
        dataset_name = basename(String(rstrip(args["data_path"], '/')))
        args["save_path"] = joinpath(prefix, dataset_name, args["tasks"], geo, save_str, cur_time)
    end
    mkpath(args["save_path"])

    @info ("logging to $(args["save_path"])")
    if ! args["train"] # if not training, then create tensorboard files in some tmp location
        writer = TBLogger("./logs-debug/unused-tb")
    else
        writer = TBLogger(args["save_path"])
    end
    set_logger(args)

    # The last space-separated token of the first two lines of stats.txt holds
    # the entity and relation counts.
    nentity, nrelation = open(joinpath(args["data_path"], "stats.txt")) do f
        stats_lines = readlines(f)
        entity_count = parse(Int, last(split(stats_lines[1], " ")))
        relation_count = parse(Int, last(split(stats_lines[2], " ")))
        (entity_count, relation_count)
    end

    args["nentity"] = nentity
    args["nrelation"] = nrelation

    @info(repeat("-------------------------------", 2))
    @info("Geo: $(args["geo"])")
    @info("Data Path: $(args["data_path"])")
    @info("#entity: $(nentity)")
    @info("#relation: $(nrelation)")
    @info("#max steps: $(args["max_steps"])")
    @info("Evaluate unoins using: $(args["evaluate_union"])")

    local train_path_iterator, train_other_iterator
    if args["train"]
        @info("Train asked...")
        train_path_queries = Dict{Any, Set}()
        train_other_queries = Dict{Any, Set}()
        query_path_list = ["1p", "2p", "3p"]
        # Path queries (1p/2p/3p) and all remaining structures get separate iterators.
        for query_structure in keys(train_queries)
            @debug "train query structure" query_structure
            if query_name_dict[query_structure] in query_path_list
                train_path_queries[query_structure] = train_queries[query_structure]
            else
                train_other_queries[query_structure] = train_queries[query_structure]
            end
        end

        flat_path_queries = flatten_query(train_path_queries)
        @info "Flatten query length: $(length(flat_path_queries)) typeof(query) $(typeof(flat_path_queries))"

        # NOTE(review): MLUtils.DataLoader indexes its data source, so
        # KGDataset.TrainDataset has to implement the MLUtils observation
        # interface (numobs/getobs or length/getindex); the recorded run fails
        # with "no method matching getindex(::TrainDataset, ::Int64)".
        # NOTE(review): this loader uses collate = true, shuffle = false while
        # the loader for the other queries shuffles and does not collate —
        # confirm which configuration is intended.
        train_dataset = KGDataset.TrainDataset(flat_path_queries, train_answers, nentity, nrelation,
                                               args["negative_sample_size"])
        data_loader = MLUtils.DataLoader(train_dataset, batchsize = args["batch_size"],
                                         collate = true, shuffle = false)
        train_path_iterator = KGDataset.SingleDirectionalOneShotIterator(data_loader)

        if length(train_other_queries) > 0
            flat_other_queries = flatten_query(train_other_queries)
            other_dataset = KGDataset.TrainDataset(flat_other_queries, train_answers, nentity, nrelation,
                                                   args["negative_sample_size"])
            train_other_iterator = KGDataset.SingleDirectionalOneShotIterator(
                MLUtils.DataLoader(other_dataset, batchsize = args["batch_size"], shuffle = true))
        else
            train_other_iterator = nothing
        end
    end

    if args["valid"]
        @info("Validation asked...")
        flat_valid_queries = flatten_query(valid_queries)
        valid_dataloader = KGDataset.DataLoader(KGDataset.TestDataset(flat_valid_queries, nentity, nrelation),
                                                batchsize = args["test_batch_size"])
    end

    if args["test"]
        @info("Test ...")
        # Keep the flattened queries local so the global `test_queries` is not
        # overwritten (a second call to main would otherwise flatten it twice).
        flat_test_queries = flatten_query(test_queries)
        test_dataloader = KGDataset.DataLoader(KGDataset.TestDataset(flat_test_queries, nentity, nrelation),
                                               batchsize = args["test_batch_size"])
    end

    # Accept both a Bool flag and the legacy "Yes" string.
    use_cuda = args["cuda"] == true || args["cuda"] == "Yes"

    model = KGModel.KGReasoning(nentity,
                                nrelation,
                                args["hidden_dim"],
                                args["gamma"],
                                args["geo"],
                                args["test_batch_size"],
                                eval_tuple(args["box_mode"]),
                                eval_tuple(args["beta_mode"]),
                                query_name_dict,
                                use_cuda)

    @info("Model Parameter Configuration:")
    num_params = 0
    for (index, param) in enumerate(Flux.params(model))
        @info("Parameter $index: $(size(param))")
        num_params += length(param)
    end
    @info("Parameter Number: $num_params")

    if use_cuda
        model = Flux.gpu(model)
    end

    local init_step, step, current_learning_rate, warn_up_steps, opt_state
    if args["train"]
        current_learning_rate = args["learning_rate"]
        opt_state = Flux.setup(Flux.Adam(current_learning_rate), model)
        # Honour an explicit warm-up setting; otherwise decay after half the run.
        warn_up_steps = something(get(args, "warm_up_steps", nothing), args["max_steps"] ÷ 2)
    end

    if args["checkpoint_path"] !== nothing
        @info("Loading checkpoint $(args["checkpoint_path"])...")
        checkpoint = JLD2.load(joinpath(args["checkpoint_path"], "checkpoint.jld2"))
        init_step = checkpoint["step"]
        Flux.loadmodel!(model, checkpoint["model_state"])

        if args["train"]
            current_learning_rate = checkpoint["current_learning_rate"]
            warn_up_steps = checkpoint["warm_up_steps"]
            # Rebuild the optimiser so it uses the restored learning rate.
            # NOTE(review): the Adam moments stored under "opt_state" are not
            # restored here — confirm whether resuming should reuse them.
            opt_state = Flux.setup(Flux.Adam(current_learning_rate), model)
        end
    else
        @info("Ramdomly Initializing $(args["geo"]) Model...")
        init_step = 0
    end

    step = init_step
    if args["geo"] == "box"
        @info("box mode = $(args["box_mode"])")
    elseif args["geo"] == "beta"
        @info("beta mode = $(args["beta_mode"])")
    end
    @info("tasks = $(args["tasks"])")
    @info("init_step = $init_step")
    if args["train"]
        @info("learning_rate = $current_learning_rate")
    end
    @info("batch_size = $(args["batch_size"])")
    @info("hidden_dim = $(args["hidden_dim"])")
    @info("gamma = $(args["gamma"])")

    if args["train"]
        @info("Start Training...")
        training_logs = []

        # Iteration states of the two one-shot iterators.
        path_next = 1
        other_next = 1
        # Steps are counted from zero, so the last step is max_steps - 1.
        for current_step in init_step:(args["max_steps"] - 1)
            step = current_step
            # Validate less often during the last third of training.
            if step == 2 * (args["max_steps"] ÷ 3)
                args["valid_steps"] *= 4
            end

            (path_data, path_next) = iterate(train_path_iterator, path_next)
            @debug "training batch" path_data path_next
            step_log = KGModel.train_step(model, opt_state, path_data, args, step)
            _log_scalars(writer, "path_", step_log, step)

            if train_other_iterator !== nothing
                (other_data, other_next) = iterate(train_other_iterator, other_next)
                step_log = KGModel.train_step(model, opt_state, other_data, args, step)
                _log_scalars(writer, "other_", step_log, step)

                # Second update on path queries, using a fresh batch rather than
                # repeating the one consumed above.
                (path_data, path_next) = iterate(train_path_iterator, path_next)
                step_log = KGModel.train_step(model, opt_state, path_data, args, step)
            end

            push!(training_logs, step_log)

            if step >= warn_up_steps
                current_learning_rate = current_learning_rate / 5
                @info("Change learning_rate to $(current_learning_rate) at step $(step)")

                opt_state = Flux.setup(Flux.Adam(current_learning_rate), model)
                warn_up_steps = warn_up_steps * 1.5
            end

            if step % args["save_checkpoint_steps"] == 0
                save_variable_list = Dict{String, Any}(
                    "step" => step,
                    "current_learning_rate" => current_learning_rate,
                    "warm_up_steps" => warn_up_steps,
                )
                _save_checkpoint(model, opt_state, save_variable_list, args)
            end

            if step % args["valid_steps"] == 0 && step > 0
                if args["valid"]
                    @info("Evaluating on Valid Dataset...")
                    valid_all_metrics = evaluate(model, valid_easy_answers, valid_hard_answers, args,
                                                 valid_dataloader, query_name_dict, "Valid", step, writer)
                end

                if args["test"]
                    @info("Evaluating on Test Dataset...")
                    test_all_metrics = evaluate(model, test_easy_answers, test_hard_answers, args,
                                                test_dataloader, query_name_dict, "Test", step, writer)
                end
            end

            if step % args["log_steps"] == 0
                # Average every logged quantity over the steps since the last report.
                metrics = Dict(name => sum(entry[name] for entry in training_logs) / length(training_logs)
                               for name in keys(first(training_logs)))

                log_metrics("Training average", step, metrics)
                training_logs = []
            end
        end

        # Final checkpoint once the loop is done.
        save_variable_list = Dict{String, Any}(
            "step" => step,
            "current_learning_rate" => current_learning_rate,
            "warm_up_steps" => warn_up_steps,
        )
        _save_checkpoint(model, opt_state, save_variable_list, args)
        @info("Training finished!!")
    end

    #    if args["test"]
    #        @info("Evaluating on Test Dataset...")
    #        test_all_metrics = evaluate(model, test_easy_answers, test_hard_answers, args, test_dataloader, query_name_dict, "Test", step, writer)
    #    end
    #
    return nothing
end



In [None]:
#if abspath(PROGRAM_FILE) == @__FILE__
    # Command line for a training run with the "beta" geometry on the
    # FB15k-betae data; every row pairs a flag with its value.
    args = String[
        "--train",
        "--data_path", "dataset/FB15k-betae",
        "-n", "128",
        "-b", "1",
        "-d", "800",
        "-g", "24",
        "--learning_rate", "0.0001",
        "--max_steps", "450001",
        "--cpu", "1",
        "--geo", "beta",
        "--valid_steps", "15000",
    ]
    structed_args = parse_cmdargs(args)
    println(structed_args)

    # Install the logger before main runs, so the messages main emits ahead of
    # its own set_logger call already go through it.
    set_logger(structed_args)

    main(structed_args)

#end


In [34]:
# Trigger a garbage collection pass right away.
GC.gc()

In [11]:
using Revise

# Load the script version of the code from src/main.jl.
include("src/main.jl")

# Same run as a flag/value table, with 32 negative samples and batch size 2.
args = String[
    "--train",
    "--data_path", "dataset/FB15k-betae",
    "-n", "32",
    "-b", "2",
    "-d", "800",
    "-g", "24",
    "--learning_rate", "0.0001",
    "--max_steps", "450001",
    "--cpu", "1",
    "--geo", "beta",
    "--valid_steps", "15000",
]
structed_args = parse_cmdargs(args)
println(structed_args)
set_logger(structed_args)

main(structed_args)

Dict{String, Any}("geo" => "beta", "test_log_steps" => 1000, "tasks" => "1p.2p.3p.2i.3i.ip.pi.2in.3in.inp.pin.pni.2u.up", "batch_size" => 2, "evaluate_union" => "DNF", "nentity" => 0, "nrelation" => 0, "print_on_screen" => true, "cpu" => 1, "valid" => false, "valid_steps" => 15000, "train" => true, "negative_sample_size" => 32, "checkpoint_path" => nothing, "prefix" => nothing, "cuda" => false, "warm_up_steps" => nothing, "hidden_dim" => 800, "beta_mode" => "(1600,2)", "learning_rate" => 0.0001, "box_mode" => "(nothing,0.02)", "data_path" => "dataset/FB15k-betae", "max_steps" => 450001, "save_checkpoint_steps" => 50000, "save_path" => ".", "test" => false, "gamma" => 24.0, "log_steps" => 100, "seed" => 0, "test_batch_size" => 1)
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m2024-02-15 10:19:28 overwritting saving path: .




[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m2024-02-15 10:19:28 logging to logs/FB15k-betae/1p.2p.3p.2i.3i.ip.pi.2in.3in.inp.pin.pni.2u.up/beta/g-24.0-mode-(1600,2)/2024.02.15
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m2024-02-15 10:19:28 --------------------------------------------------------------
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m2024-02-15 10:19:28 Geo: beta
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m2024-02-15 10:19:28 Data Path: dataset/FB15k-betae
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m2024-02-15 10:19:28 #entity: 14951
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m2024-02-15 10:19:28 #relation: 2690
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m2024-02-15 10:19:28 #max steps: 450001
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m2024-02-15 10:19:28 Evaluate unoins using: DNF
[36m[1m[ [22m[39m[36m[1mInfo: [22m[39m2024-02-15 10:19:28 Train required...
("e", ("r", "r", "r"))(("e", ("r",)), ("e", ("r",)))(("e", ("r",)), ("e", ("r",)), ("e", ("r",)

LoadError: MethodError: no method matching getindex(::Main.KGDataset.TrainDataset, ::Int64)