# Notebook for training and generating rules

In [None]:
import os
import random
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch import optim
%matplotlib inline

# Pin the cuBLAS workspace configuration. PyTorch's reproducibility notes list
# ":16:8" as one of the accepted values for deterministic cuBLAS behaviour.
# Presumably set here, ahead of the project imports, so that it is in effect
# before any CUDA work starts.
# NOTE(review): nothing visible in this notebook calls
# torch.use_deterministic_algorithms — confirm whether determinism is enabled
# elsewhere.
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":16:8"

from rule_miner import RuleMiner
from framework import RTFramework
from framework import RTLoss
from dataloader import RTDataLoader

# Paths
here = "."
data_dir = os.path.join(here, "../datasets/")

dataset = "family"

# Input files of the selected dataset.
dataset_dir = os.path.join(data_dir, dataset)
facts_file = os.path.join(dataset_dir, "facts.txt")
train_file = os.path.join(dataset_dir, "train.txt")
valid_file = os.path.join(dataset_dir, "valid.txt")
test_file = os.path.join(dataset_dir, "test.txt")
entities_file = os.path.join(dataset_dir, "entities.txt")
relations_file = os.path.join(dataset_dir, "relations.txt")

# Saved paths
experiment_dir = os.path.join(here, "../saved", dataset)
# Model checkpoint for continuing training.
checkpoint_dir = os.path.join(experiment_dir, "checkpoint/")
# Directory to save trained model.
model_save_dir = os.path.join(experiment_dir, "model/")
# Options file.
option_file = os.path.join(experiment_dir, "option.txt")
# Model prediction file.
prediction_file = os.path.join(experiment_dir, "prediction.txt")

# Create each output directory on its own. Checking only `experiment_dir`
# would skip creation when that directory already exists but one of its
# subdirectories does not (e.g. after a partial cleanup).
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(model_save_dir, exist_ok=True)

# Other configurations
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
"""Hyperparameters"""
# --- Model settings (handed to RuleMiner in the model cell) ---
rank = 3
num_steps = 2
num_rnn_layers = 1
entity_embedding_dim = 128
query_embedding_dim = 128
entity_rnn_hidden_size = 128
query_rnn_hidden_size = 128

# --- Data settings (handed to both RTDataLoader and RuleMiner) ---
query_include_reverse = True

# --- Optimisation / evaluation settings ---
lr = 0.001             # learning rate given to Adam
batch_size = 128
train_epochs = 20
top_k = 10
threshold = 1e-20      # constructor argument of RTLoss
num_sample_batches = 0  # not referenced by any cell visible in this notebook

# --- Reproducibility ---
seed = 210224
# Seed every random number generator used here with the same value, in the
# order: Python, NumPy, PyTorch, PyTorch CUDA.
for _seed_fn in (random.seed, np.random.seed,
                 torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed)

In [None]:
# Build the data loader from the vocabulary files and the
# facts/train/valid/test splits of the selected dataset.
dataloader = RTDataLoader(
    relations_file,
    entities_file,
    facts_file,
    train_file,
    valid_file,
    test_file,
    query_include_reverse,
)

# Register an additional relation named "self" in both lookup tables, using
# the id `dataloader.num_operators` (i.e. the one right after the existing
# operator ids) so it can be printed by name later on.
dataloader.id2rel[dataloader.num_operators] = "self"
dataloader.rel2id["self"] = dataloader.num_operators

## Train

In [None]:
# Build the rule miner and move it to the target device.
miner = RuleMiner(
    rank, num_steps,
    dataloader.num_entities, dataloader.num_operators,
    dataloader.entity_degrees,
    query_include_reverse,
    entity_embedding_dim, query_embedding_dim,
    num_rnn_layers,
    query_rnn_hidden_size, entity_rnn_hidden_size,
).to(device)

# Loss function, parameterised by the numeric threshold set with the
# hyperparameters.
loss_fn = RTLoss(threshold).to(device)

# Adam over all miner parameters; no learning-rate scheduler is configured.
scheduler = None
optimizer = optim.Adam(miner.parameters(), lr=lr)

# Training/evaluation driver tying the pieces together.
# NOTE(review): the sixth positional argument (None) is presumably the
# scheduler slot — confirm against RTFramework's signature.
framework = RTFramework(
    miner,
    optimizer,
    dataloader,
    loss_fn,
    device,
    None,
    checkpoint_dir,
)

In [None]:
# Run training for `train_epochs` epochs.
# NOTE(review): the literal 0 presumably fills the slot that the
# `num_sample_batches` hyperparameter (also 0) was meant for — confirm against
# RTFramework.train before replacing it.
framework.train(top_k, batch_size, 0, train_epochs)

## Evaluation

In [None]:
# Restore the model weights stored in the checkpoint directory and evaluate
# them on the test split.
ckpt_file = os.path.join(checkpoint_dir, "checkpoint.pth.tar")
# map_location remaps the stored tensors onto the device this notebook is
# using, so a checkpoint saved on one device can still be loaded when the
# notebook runs on another.
checkpoint = torch.load(ckpt_file, map_location=device)
miner.load_state_dict(checkpoint['model'])
framework.eval("test", batch_size, top_k)

## Generate rules

In [None]:
# Take the first batch of a shuffled pass over the test split.
# next(iter(...)) fails loudly if the split yields nothing, instead of
# silently reusing values left in the notebook namespace by an earlier run.
# NOTE(review): the 10 is presumably the batch size — confirm against
# RTDataLoader.one_epoch.
qq, hh, tt, trips = next(iter(dataloader.one_epoch("test", 10, shuffle=True)))
qq = torch.from_numpy(qq).to(device)
hh = torch.from_numpy(hh).to(device)
tt = torch.from_numpy(tt).to(device)

# Forward pass for inspection only: no gradients are needed, so skip building
# the autograd graph.
# NOTE(review): presumably this call is what fills
# `miner.query_attn_ops_list`, which the following cells read — confirm in
# RuleMiner.
with torch.no_grad():
    logits = miner(qq, hh, trips)

# Names of the query relations in the sampled batch.
print([dataloader.id2rel[rel.item()] for rel in qq])

In [None]:
# Index used to select a slice along the third axis of the attention tensors;
# the id -> name table printed below shows which relation names exist.
# NOTE(review): presumably this is a query-relation id — confirm in RuleMiner.
relation = 6

# Attention slice of the first rank for that index (last entry of the final
# axis): print its shape, then its values.
first_rank_attention = miner.query_attn_ops_list[0][:, :, relation, -1]
print(first_rank_attention.size())
print(first_rank_attention)

print(dataloader.id2rel)

In [None]:
import itertools

# One [step][relation] attention table per rank for the selected `relation`,
# converted to plain Python floats once instead of calling .item() per path.
# NOTE(review): the trailing -1 picks the last entry along the final axis —
# confirm what that axis represents in RuleMiner.
attention_tables = [
    miner.query_attn_ops_list[r][:, :, relation, -1].detach().cpu().tolist()
    for r in range(rank)
]

# Candidate relation ids at every step: all operators plus the extra id that
# was registered as "self" right after them.
candidate_ids = range(dataloader.num_operators + 1)

# Score every length-`num_steps` relation path: for each rank multiply the
# attention of the chosen relation at each step, then sum over ranks.
path_rank = []
for comb in itertools.product(candidate_ids, repeat=num_steps):
    path_score = 0.
    for table in attention_tables:
        rank_score = 1.
        for step, rel in enumerate(comb):
            rank_score *= table[step][rel]
        path_score += rank_score
    path_rank.append([[dataloader.id2rel[rel] for rel in comb], path_score])
path_rank.sort(key=lambda x: x[1], reverse=True)

# Normalise by the best score. The divisor must be captured before the loop:
# dividing by `path_rank[0][1]` in place turns it into 1.0 on the first
# iteration, which leaves every other score unnormalised. A top score of zero
# is left as is to avoid a division by zero.
best_score = path_rank[0][1] if path_rank else 0.
if best_score:
    for item in path_rank:
        item[1] /= best_score

In [None]:
# Display the first ten entries of `path_rank`, i.e. the ten highest-scoring
# paths, since the list was sorted by score in descending order.
path_rank[:10]