In [10]:
import torch
from models import SpKBGATModified, SpKBGATConvOnly
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from copy import deepcopy
from torchsummary import summary

import random
import argparse
import os
import sys
import logging
import time
import pickle

from collections import defaultdict
import queue
from config import Config
from utils import load_object

In [5]:
# Build the run configuration object and let it populate itself.
args = Config()
args.load_config()

# String tag for the compute device; the checkpoint cells below embed it in
# the checkpoint file names.
if args.cuda:
    device = "cuda"
else:
    device = "cpu"

In [8]:
print("Loading corpus")
# Fetch the three stored artifacts from args.data_folder, in this order:
# the corpus object, then the entity and relation embeddings.
Corpus_, entity_embeddings, relation_embeddings = (
    load_object(output=args.data_folder, name=artifact)
    for artifact in ("corpus", "entity_embeddings", "relation_embeddings")
)
# Keep a direct handle on the corpus' 2-hop neighbourhood attribute.
node_neighbors_2hop = Corpus_.node_neighbors_2hop

Loading corpus


In [11]:
# Instantiate the GAT model from the loaded embeddings and the configured
# hyper-parameters.
# NOTE(review): args.entity_out_dim is passed twice in both constructors --
# presumably once for the entity and once for the relation output size;
# confirm against the constructor signatures in models.py.
model_gat = SpKBGATModified(
    entity_embeddings,
    relation_embeddings,
    args.entity_out_dim,
    args.entity_out_dim,
    args.drop_GAT,
    args.alpha,
    args.nheads_GAT,
)

# Instantiate the convolution-only model; it additionally takes the conv
# dropout, conv alpha and output-channel settings.
model_conv = SpKBGATConvOnly(
    entity_embeddings,
    relation_embeddings,
    args.entity_out_dim,
    args.entity_out_dim,
    args.drop_GAT,
    args.drop_conv,
    args.alpha,
    args.alpha_conv,
    args.nheads_GAT,
    args.out_channels,
)

In [13]:
# Resolve the checkpoint directory: "<output_folder>/<dataset>" by default,
# replaced by args.drive_folder when args.save_gdrive is set.
folder = "{output}/{dataset}".format(output=args.output_folder, dataset=args.dataset)
if args.save_gdrive:
    folder = args.drive_folder

# The file name encodes dataset, device tag, model kind and an epoch index
# (args.epochs_conv - 1 -- presumably the last zero-based training epoch).
model_name = "{folder}/{dataset}_{device}_{name}_{epoch}.pt".format(
    folder=folder, dataset=args.dataset, device=device, name="conv", epoch=args.epochs_conv - 1)

# map_location remaps the stored tensors onto `device`; without it torch.load
# restores them to the exact device they were saved from and fails when that
# device (e.g. a specific GPU index) does not exist on this machine.
# strict=False tolerates missing/unexpected keys; the call is kept as the
# cell's last expression so the key-matching report is displayed.
model_conv.load_state_dict(torch.load(model_name, map_location=device), strict=False)

<All keys matched successfully>

In [14]:
print("Loading GAT encoder")
# Resolve the checkpoint directory: "<output_folder>/<dataset>" by default,
# replaced by args.drive_folder when args.save_gdrive is set.
folder = "{output}/{dataset}".format(output=args.output_folder, dataset=args.dataset)
if args.save_gdrive:
    folder = args.drive_folder

# The file name encodes dataset, device tag, model kind and an epoch index
# (args.epochs_gat - 1 -- presumably the last zero-based training epoch).
model_name = "{folder}/{dataset}_{device}_{name}_{epoch}.pt".format(
    folder=folder, dataset=args.dataset, device=device, name="gat", epoch=args.epochs_gat - 1)

# map_location remaps the stored tensors onto `device`; without it torch.load
# restores them to the exact device they were saved from and fails when that
# device (e.g. a specific GPU index) does not exist on this machine.
# strict=False tolerates missing/unexpected keys; the call is kept as the
# cell's last expression so the key-matching report is displayed.
model_gat.load_state_dict(torch.load(model_name, map_location=device), strict=False)

Loading GAT encoder


<All keys matched successfully>

In [21]:
# Grab the state dicts of both models so their entries can be inspected below.
gat_w, conv_w = model_gat.state_dict(), model_conv.state_dict()

Model GAT

In [16]:
# Print one "<name> : <shape>" line per entry of the GAT state dict.
for para_name, weights in gat_w.items():
    print("{0} : {1}".format(para_name, weights.shape))

final_entity_embeddings : torch.Size([8, 200])
final_relation_embeddings : torch.Size([9, 200])
entity_embeddings : torch.Size([8, 50])
relation_embeddings : torch.Size([9, 50])
W_entities : torch.Size([50, 200])
sparse_gat_1.W : torch.Size([50, 200])
sparse_gat_1.attention_0.a : torch.Size([100, 150])
sparse_gat_1.attention_0.a_2 : torch.Size([1, 100])
sparse_gat_1.attention_1.a : torch.Size([100, 150])
sparse_gat_1.attention_1.a_2 : torch.Size([1, 100])
sparse_gat_1.out_att.a : torch.Size([200, 600])
sparse_gat_1.out_att.a_2 : torch.Size([1, 200])


In [19]:
# Put the GAT model in evaluation mode. As the last expression of the cell,
# the call's return value is displayed, which yields the module summary below.
model_gat.eval()

SpKBGATModified(
  (sparse_gat_1): SpGAT(
    (dropout_layer): Dropout(p=0.3, inplace=False)
    (attention_0): SpGraphAttentionLayer (50 -> 100)
    (attention_1): SpGraphAttentionLayer (50 -> 100)
    (out_att): SpGraphAttentionLayer (200 -> 200)
  )
)

Model ConvKB

In [22]:
# Print one "<name> : <shape>" line per entry of the ConvKB state dict.
for para_name, weights in conv_w.items():
    print("{0} : {1}".format(para_name, weights.shape))

final_entity_embeddings : torch.Size([8, 200])
final_relation_embeddings : torch.Size([9, 200])
convKB.conv_layer.weight : torch.Size([500, 1, 1, 3])
convKB.conv_layer.bias : torch.Size([500])
convKB.fc_layer.weight : torch.Size([1, 100000])
convKB.fc_layer.bias : torch.Size([1])


In [23]:
# Put the ConvKB model in evaluation mode. As the last expression of the cell,
# the call's return value is displayed, which yields the module summary below.
model_conv.eval()

SpKBGATConvOnly(
  (convKB): ConvKB(
    (conv_layer): Conv2d(1, 500, kernel_size=(1, 3), stride=(1, 1))
    (dropout): Dropout(p=0.0, inplace=False)
    (non_linearity): ReLU()
    (fc_layer): Linear(in_features=100000, out_features=1, bias=True)
  )
)