In [1]:
import networkx as nx
import pandas as pd
import numpy as np
import os
import json
import random
from typing import List, Union, Dict
from tqdm import tqdm

In [46]:
%load_ext blackcellmagic

In [None]:
class Example:
    """
    class for holding the data for our example.
    """
    def __init__(self, path: str):
        self.genres = pd.read_csv(os.path.join(path, "band_genres.csv"))
        self.labels = pd.read_csv(os.path.join(path, "band_labels.csv"))
        self.members = pd.read_csv(os.path.join(path, "band_members.csv"))

    def graph_from_data(self) -> nx.Graph:
        """
        Build a graph from our sparql queried wikidata music dataset
        """
        g = nx.DiGraph()
        for band, band_name, member, member_name in self.members.values:
            g.add_node(band, name=band_name)
            g.add_node(member, name=member_name)
            g.add_edge(band, member, rel="has_member")
            g.add_edge(member, band, rel="member_of")

        for band, band_name, label in self.labels.values:
            g.add_node(band, name=band_name)
            g.add_node(label)
            g.add_edge(label, band, rel="signed")
            g.add_edge(band, label, rel="signed_by")

        for band, band_name, genre, genre_name in self.genres.values:
            g.add_node(band, name=band_name)
            g.add_node(genre, name=genre_name)
            g.add_edge(band, genre, rel="has_genre")
            g.add_edge(genre, band, rel="example_of")

        return g


class Learner:
    """
    Primary class for learning our graph embeddings.
    """
    def __init__(self, graph, rel_name="rel", min_relations=1, max_relations=6):
        self.g: nx.Graph = graph
        self.node_ids: Dict[str, int] = {"": 0}
        self.node_id_inv: Dict[int, str] = {}
        self.rel_ids: Dict[str, int] = {"": 0}
        self.rel_id_inv: Dict[int, str] = {}
        self.node_id_list = []
        self.rel_name = rel_name
        self.min_relations = min_relations
        self.max_relations = max_relations

    def fit_graph(self):
        """
        Learn the mapping between nodes and their integer ids,
        as well as relationship types and their ids
        """
        for n1, n2, edge in self.g.edges(data=True):
            rel = edge[self.rel_name]
            if n1 not in self.node_ids:
                self.node_ids[n1] = len(self.node_ids)
            if n2 not in self.node_ids:
                self.node_ids[n2] = len(self.node_ids)
            if rel not in self.rel_ids:
                self.rel_ids[rel] = len(self.rel_ids)
        self.node_id_list = [i for i in self.node_ids.keys() if i != ""]
        self.node_id_inv = {v: k for k, v in self.node_ids.items()}
        self.rel_id_inv = {v: k for k, v in self.rel_ids.items()}

    def node_to_features(self, node_id):
        """
        Given a node on the graph, featurize its relations and
        return its identity
        """
        node = self.g[node_id]
        node_ind = self.node_ids[node_id]
        n_relations = np.random.randint(self.min_relations, self.max_relations + 1)
        keys = list(node.keys())
        random.shuffle(keys)
        keys = keys[:n_relations]
        key_ids = np.array([self.node_ids[n] for n in keys])
        rel_ids = np.array([self.rel_ids[node[n2][self.rel_name]] for n2 in keys])
        return (key_ids, rel_ids), node_ind

    def get_batch(self, batch_size):
        x_zero = []
        x_one = []
        y = []
        for n in range(0, batch_size):
            node_id = random.choice(self.node_id_list)
            xsub, ysub = self.node_to_features(node_id)
            x_zero.append(xsub[0])
            x_one.append(xsub[1])
            y.append(ysub)
        mx = np.max([len(i) for i in x_zero])
        return (
            (
                krs.preprocessing.sequence.pad_sequences(x_zero, mx),
                krs.preprocessing.sequence.pad_sequences(x_one, mx),
            ),
            np.array(y),
        )

In [4]:
# load our example data
ex = Example('data/')

# turn it into a networkx graph
music = ex.graph_from_data()

# fit the Learner on our dataset
l = Learner(music)
l.fit_graph()

In [5]:
"""
Here we define our model. All of this will move to its own class with wrappers
for training and running queries.
"""
import tensorflow as tf
import tensorflow.keras as krs

EMBEDDING_DIM = 128

with tf.device("/GPU:0"):
    # embedding layers
    node_embedding_in = krs.layers.Embedding(len(l.node_ids) + 1, EMBEDDING_DIM)
    rel_embedding = krs.layers.Embedding(len(l.rel_ids) + 1, EMBEDDING_DIM)
    # vars for loss
    output_weights = tf.Variable(tf.random.normal([len(l.node_ids) + 1, EMBEDDING_DIM]))
    output_biases = tf.Variable(tf.zeros([len(l.node_ids) + 1]))


def get_embedding(x):
    with tf.device("/GPU:0"):
        # Lookup the corresponding embedding vectors for each sample in X.
        drop1 = krs.layers.Dropout(0.25)
        drop2 = krs.layers.Dropout(0.25)
        x_node_embed = drop1(node_embedding_in(x[0]))
        x_rel_embed = drop2(rel_embedding(x[1]))
        # multiply the relationship embeddings by the node embeddings
        mul = tf.math.multiply(x_node_embed,x_rel_embed)
        # max pool
        out = tf.math.reduce_max(mul, axis=1)
        return out


def get_loss(x_inp, y):
    with tf.device("/GPU:0"):
        # Compute the average NCE loss for the batch.
        y_true = tf.cast(tf.reshape(y, [-1, 1]), tf.int64)
        y = tf.cast(y, tf.int64)
        samp = tf.random.uniform_candidate_sampler(y_true, 1, 128, True, len(l.node_ids) + 1)
        loss = tf.reduce_mean(
            tf.nn.sampled_softmax_loss(
                weights=output_weights,
                biases=output_biases,
                labels=y_true,
                inputs=x_inp,
                num_sampled=128,
                num_classes=len(l.node_ids) + 1,
                num_true=1,
                sampled_values = samp,
            )
        )
        return loss


def run_optimization(x, y, optimizer):
    y = tf.convert_to_tensor(y)
    with tf.device("/GPU:0"):
        # Wrap computation inside a GradientTape for automatic differentiation.
        with tf.GradientTape() as g:
            emb = get_embedding(x)
            loss = get_loss(emb, y)

        # Compute gradients.
        to_diff = node_embedding_in.weights + rel_embedding.weights + [output_weights, output_biases]
        gradients = g.gradient(loss, to_diff)

        # Update W and b following gradients.
        optimizer.apply_gradients(zip(gradients, to_diff))
    return loss


def evaluate(x_embed):
    with tf.device("/GPU:0"):
        # Compute the cosine similarity between input data embedding and every embedding vectors
        x_embed = tf.cast(x_embed, tf.float32)
        x_embed_norm = x_embed / tf.sqrt(tf.reduce_sum(tf.square(x_embed)))
        embedding_norm = output_weights / tf.sqrt(
            tf.reduce_sum(tf.square(output_weights), 1, keepdims=True), tf.float32
        )
        cosine_sim_op = tf.matmul(x_embed_norm, embedding_norm, transpose_b=True)
        return cosine_sim_op


In [None]:
"""
Training loop.

Sample sets of relationships from sampled nodes and minimize their sampled_softmax_loss on
all the nodes in the graph.

Occasionally evaluate the most-similar output node vectors to our predictions
"""
PRINT_EVERY = 1000
EVAL_EVERY = 10000
INCREASE_BS_EVERY = 2000
LEARNING_RATE = 0.5
BATCH_SIZE = 4
MAX_BATCH_SIZE = 256

# increase batch size linearly
def increase_bs(bs, mx):
    return np.min([bs + 4, mx])

optimizer = tf.optimizers.SGD(LEARNING_RATE)

# monitor loss
train_losses = []
bs = BATCH_SIZE
for ep in tqdm(range(0, 1000000)):
    train_losses.append(run_optimization(*(l.get_batch(bs)), optimizer))
    
    if ep % PRINT_EVERY == 0:
        print(f"training loss at epoch {ep}: {np.mean(train_losses)}")
        train_losses = train_losses[-20:]
        
    if (ep % INCREASE_BS_EVERY == 0) and ep != 0:
        increase_bs(bs, MAX_BATCH_SIZE)

    # run an evaluation loop; pick random nodes+relations and find the most similar output nodes
    if (ep % EVAL_EVERY == 0) and ep != 0:
        batch_x, batch_y = l.get_batch(bs)
        sim = evaluate(get_embedding(batch_x)).numpy()
        for i in range(3):
            top_k = 3  # number of nearest neighbors.
            nearest = (-sim[i, :]).argsort()[:top_k]
            print(f"for node {l.node_id_inv[int(batch_y[i])]}- {music.nodes[l.node_id_inv[int(batch_y[i])]]} closest are nodes:")

            for k in range(top_k):
                print(music.nodes[l.node_id_inv[nearest[k]]])

In [80]:
# ask the model a question
def get_query(node_names, query_nodes):
    query_nodes = [
        l.node_ids[
            [
                i
                for i in music.nodes
                if "name" in music.nodes[i] and music.nodes[i]["name"] == z
            ][0]
        ]
        for z in node_names
    ]

    query_rels = [l.rel_ids[z] for z in relationships]
    return query_nodes, query_rels

# "what band has the relationship 'has_member' with both Tom Petty and Roy Orbison"
node_names = ["Tom Petty", "Roy Orbison"]
relationships = ["has_member", "has_member"]

query_nodes, query_rels = get_query(node_names, relationships)
batch_x = [np.array(query_nodes).reshape(-1, 1),
            np.array(query_rels).reshape(-1, 1)]

# get most similar output vectors
sim = evaluate(get_embedding(batch_x)).numpy()
top_k = 1
nearest = (-sim[0, :]).argsort()[:top_k]
nearest_scores = [sim[0, j] for j in nearest]

for k in range(top_k):
    print(music.nodes[l.node_id_inv[nearest[k]]], nearest_scores[k])

{'name': 'Traveling Wilburys'} 0.6650992


In [78]:
# what about genres?
node_names = ["classic rock"]
relationships = ["has_genre"]

query_nodes, query_rels = get_query(node_names, relationships)

batch_x = [np.array(query_nodes).reshape(-1, 1),
            np.array(query_rels).reshape(-1, 1)]

sim = evaluate(get_embedding(batch_x)).numpy()
top_k = 1  # number of nearest neighbors.
nearest = (-sim[0, :]).argsort()[:top_k]
nearest_scores = [sim[0, j] for j in nearest]

for k in range(top_k):
    print(music.nodes[l.node_id_inv[nearest[k]]], nearest_scores[k])

{'name': 'Allman Brothers Band'} 0.6447925
