In [1]:
# author: John Pougué-Biyong

In [None]:
import os

from pathlib import Path
import multiprocessing as mp
from joblib import Parallel, delayed
from collections import defaultdict
import pkg_resources
import gc
import time
from copy import deepcopy

import networkx as nx
import json
import pandas as pd
import numpy as np
import scipy as sp
import random as rd
import matplotlib.pyplot as plt

from imblearn.under_sampling import RandomUnderSampler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, auc, roc_curve, roc_auc_score

import torch as t
import torch.nn as nn
from torch import LongTensor as LT
from torch import FloatTensor as FT
from tqdm.auto import tqdm
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
# Device selection. The notebook was developed on GPU index 6; when that index
# is not present, fall back to the first visible GPU and finally to the CPU so
# that tensor transfers do not fail on hosts with fewer GPUs.
cuda = t.cuda.is_available()
if cuda and t.cuda.device_count() > 6:
    device = t.device("cuda:6")
elif cuda:
    device = t.device("cuda:0")
else:
    device = t.device("cpu")

import utils

In [None]:
def parallel_generate_walks(d_graph: dict, global_walk_length: int, num_walks: int, cpu_num: int,
                            sampling_strategy: dict = None, num_walks_key: str = None,
                            walk_length_key: str = None, neighbors_key: str = None,
                            probabilities_key: str = None, first_travel_key: str = None,
                            quiet: bool = True) -> list:
    """
    Generates the random walks which will be used as the skip-gram input.

    :param d_graph: Per-node dict holding, under the given keys, the neighbor list,
        the first-step probabilities and the probabilities conditioned on the previous node
    :param global_walk_length: Number of nodes per walk unless overridden per node
    :param num_walks: Number of passes over the start nodes performed by this call
    :param cpu_num: Worker number, only used in the progress bar description
    :param sampling_strategy: Optional per-node overrides (looked up with
        ``num_walks_key`` and ``walk_length_key``); ``None`` means no overrides
    :param num_walks_key: Key of the per-node number of walks in ``sampling_strategy``
    :param walk_length_key: Key of the per-node walk length in ``sampling_strategy``
    :param neighbors_key: Key of the neighbor list in ``d_graph[node]``
    :param probabilities_key: Key of the previous-node-conditioned probabilities in ``d_graph[node]``
    :param first_travel_key: Key of the first-step probabilities in ``d_graph[node]``
    :param quiet: When False, a tqdm progress bar is shown
    :return: List of walks. Each walk is a list of node identifiers converted to ``str``.
    """

    # The signature advertises None as the default, so treat it as "no overrides"
    # instead of failing on the membership tests below.
    if sampling_strategy is None:
        sampling_strategy = {}

    walks = list()

    pbar = None
    if not quiet:
        pbar = tqdm(total=num_walks, desc='Generating walks (CPU: {})'.format(cpu_num))

    # Only nodes with at least one neighbor can start a walk; this set does not
    # change between passes, so compute it once and shuffle a copy per pass.
    start_nodes = [key for key in d_graph if len(d_graph[key][neighbors_key]) > 0]

    for n_walk in range(num_walks):

        # Update progress bar
        if pbar is not None:
            pbar.update(1)

        # Shuffle the nodes
        shuffled_nodes = list(start_nodes)
        rd.shuffle(shuffled_nodes)

        # Start a random walk from every node
        for source in shuffled_nodes:
            node_strategy = sampling_strategy.get(source, {})

            # Skip nodes whose node-specific num_walks has already been reached
            if num_walks_key in node_strategy and node_strategy[num_walks_key] <= n_walk:
                continue

            # Node-specific walk length, falling back to the global one
            walk_length = node_strategy.get(walk_length_key, global_walk_length)

            # Perform walk
            walk = [source]
            while len(walk) < walk_length:
                current = walk[-1]
                walk_options = d_graph[current].get(neighbors_key, None)

                # Stop at dead end nodes
                if not walk_options:
                    break

                if len(walk) == 1:
                    # First step: no previous node to condition on
                    probabilities = d_graph[current][first_travel_key]
                else:
                    # Later steps: probabilities depend on the previously visited node
                    probabilities = d_graph[current][probabilities_key][walk[-2]]

                walk.append(np.random.choice(walk_options, size=1, p=probabilities)[0])

            walks.append(list(map(str, walk)))  # Convert all to strings

    if pbar is not None:
        pbar.close()
    return walks


class WalksGenerator:
    """Precomputes node2vec-style transition probabilities and generates random walks."""

    # Keys used inside self.d_graph[node] and inside sampling_strategy[node].
    FIRST_TRAVEL_KEY = 'first_travel_key'
    PROBABILITIES_KEY = 'probabilities'
    NEIGHBORS_KEY = 'neighbors'
    WEIGHT_KEY = 'weight'
    NUM_WALKS_KEY = 'num_walks'
    WALK_LENGTH_KEY = 'walk_length'
    P_KEY = 'p'
    Q_KEY = 'q'

    def __init__(self, graph: nx.Graph, dimensions: int = 64,
                 walk_length: int = 80, num_walks: int = 10, p: float = 1,
                 q: float = 1, weight_key: str = 'weight', workers: int = 1,
                 sampling_strategy: dict = None,
                 quiet: bool = False, temp_folder: str = None, seed: int = None):
        """
        Stores the configuration and precomputes the walking probabilities.
        Walks are only generated when ``generate_walks`` is called.

        :param graph: Input graph
        :param dimensions: Embedding dimensions (default: 64); stored only, not used by this class
        :param walk_length: Number of nodes in each walk (default: 80)
        :param num_walks: Number of walks per node (default: 10)
        :param p: Return hyper parameter (default: 1)
        :param q: Inout parameter (default: 1)
        :param weight_key: On weighted graphs, this is the key for the weight attribute (default: 'weight')
        :param workers: Number of workers for parallel execution (default: 1)
        :param sampling_strategy: Node specific sampling strategies, supports setting node specific
            'q', 'p', 'num_walks' and 'walk_length'. Use these keys exactly. If not set, will use
            the global ones which were passed on the object initialization
        :param quiet: When True, no progress bars are shown
        :param temp_folder: Path to folder with enough space to hold the memory map of self.d_graph
            (for big graphs); to be passed joblib.Parallel.temp_folder
        :param seed: Seed for the random number generator.
        :raises NotADirectoryError: if ``temp_folder`` is given but is not an existing directory
        """

        self.graph = graph
        self.dimensions = dimensions
        self.walk_length = walk_length
        self.num_walks = num_walks
        self.p = p
        self.q = q
        self.weight_key = weight_key
        self.workers = workers
        self.quiet = quiet
        self.d_graph = defaultdict(dict)

        if sampling_strategy is None:
            self.sampling_strategy = {}
        else:
            self.sampling_strategy = sampling_strategy

        self.temp_folder, self.require = None, None
        if temp_folder:
            if not os.path.isdir(temp_folder):
                raise NotADirectoryError("temp_folder does not exist or is not a directory. ({})" \
                                         .format(temp_folder))

            self.temp_folder = temp_folder
            self.require = "sharedmem"

        # Seeds both generators used by parallel_generate_walks (random.shuffle and np.random.choice)
        if seed is not None:
            rd.seed(seed)
            np.random.seed(seed)

        self._precompute_probabilities()

    def _precompute_probabilities(self):
        """
        Precomputes transition probabilities for each node and stores them in self.d_graph.

        For every node the following entries are filled:
        ``d_graph[node][PROBABILITIES_KEY][previous]`` (normalized weights over the neighbors of
        ``node`` given the walk arrived from ``previous``), ``d_graph[node][FIRST_TRAVEL_KEY]``
        (normalized edge weights for the first step) and ``d_graph[node][NEIGHBORS_KEY]``.
        """

        d_graph = self.d_graph

        nodes_generator = self.graph.nodes() if self.quiet \
            else tqdm(self.graph.nodes(), desc='Computing transition probabilities')

        for source in nodes_generator:

            # Init probabilities dict for first travel
            if self.PROBABILITIES_KEY not in d_graph[source]:
                d_graph[source][self.PROBABILITIES_KEY] = dict()

            for current_node in self.graph.neighbors(source):

                # Init probabilities dict
                if self.PROBABILITIES_KEY not in d_graph[current_node]:
                    d_graph[current_node][self.PROBABILITIES_KEY] = dict()

                # p and q depend only on current_node, so resolve them once per node
                node_strategy = self.sampling_strategy.get(current_node, {})
                p = node_strategy.get(self.P_KEY, self.p)
                q = node_strategy.get(self.Q_KEY, self.q)

                # Calculate unnormalized weights
                unnormalized_weights = list()
                for destination in self.graph.neighbors(current_node):
                    edge_weight = self.graph[current_node][destination].get(self.weight_key, 1)

                    if destination == source:  # Backwards probability
                        ss_weight = edge_weight / p
                    elif destination in self.graph[source]:  # If the neighbor is connected to the source
                        ss_weight = edge_weight
                    else:
                        ss_weight = edge_weight / q

                    unnormalized_weights.append(ss_weight)

                # Normalize
                unnormalized_weights = np.array(unnormalized_weights)
                d_graph[current_node][self.PROBABILITIES_KEY][
                    source] = unnormalized_weights / unnormalized_weights.sum()

            # Calculate first_travel weights for source (plain normalized edge weights)
            neighbors = list(self.graph.neighbors(source))
            first_travel_weights = np.array(
                [self.graph[source][destination].get(self.weight_key, 1) for destination in neighbors])
            d_graph[source][self.FIRST_TRAVEL_KEY] = first_travel_weights / first_travel_weights.sum()

            # Save neighbors
            d_graph[source][self.NEIGHBORS_KEY] = neighbors

    def generate_walks(self, workers=None) -> list:
        """
        Generates the random walks which will be used as the skip-gram input.

        :param workers: Number of parallel jobs; when None, the value given at construction is used
        :return: List of walks (also stored in ``self.walks``). Each walk is a list of nodes.
        """

        if workers is None:
            workers = self.workers

        # Split num_walks for each worker; only the size of each share is passed on
        num_walks_lists = np.array_split(range(self.num_walks), workers)

        # temp_folder / require are None unless a temp_folder was given at construction
        walk_results = Parallel(n_jobs=workers, temp_folder=self.temp_folder, require=self.require)(
            delayed(parallel_generate_walks)(self.d_graph,
                                             self.walk_length,
                                             len(num_walks),
                                             idx,
                                             self.sampling_strategy,
                                             self.NUM_WALKS_KEY,
                                             self.WALK_LENGTH_KEY,
                                             self.NEIGHBORS_KEY,
                                             self.PROBABILITIES_KEY,
                                             self.FIRST_TRAVEL_KEY,
                                             self.quiet) for
            idx, num_walks
            in enumerate(num_walks_lists, 1))

        # Flatten the per-worker lists into a single list of walks
        self.walks = [walk for worker_walks in walk_results for walk in worker_walks]
        return self.walks

class Walks(object):
    """Convenience wrapper that configures a WalksGenerator and exposes its walks."""

    def __init__(self, abs_digraph,
                 walk_len=15, num_walks=10, p=.5, q=1.5,
                 workers=1, seed=2021):
        """
        Build the underlying walk generator.
        Input:
            abs_digraph nx.DiGraph: absolute signed graph
            walk_len int: length of the walks
            num_walks int: #walks per node
            p float: return hyperparameter
            q float: inout hyperparameter
            workers int: #workers for parallelization
            seed int: random seed state
        """
        generator_options = dict(
            dimensions=64,
            walk_length=walk_len,
            num_walks=num_walks,
            weight_key='weight',
            quiet=False,
            workers=workers,
            p=p,
            q=q,
            seed=seed,
        )
        self.walk_generator = WalksGenerator(abs_digraph, **generator_options)

    def generate(self, workers=1):
        """Generate the walks, keep them in ``self.walks`` and drop the generator."""
        generator = self.walk_generator
        generator.generate_walks(workers)
        self.walks = generator.walks
        # The generator is no longer reachable through this object after this call.
        del self.walk_generator

In [None]:
def _edge_sign(adj_mat, src, dst):
    """Read the sign stored at adj_mat[src, dst]; print a diagnostic if it is not -1 or 1."""
    sign = adj_mat[src, dst]
    if sign not in [-1, 1]:
        print('!!')
        print(sign)
        print('---')
    return sign


def _signed_context(node, path_sign, unk):
    """Format a context token as '<node>+' / '<node>-', or ``unk`` when the sign is undetermined."""
    if path_sign == 1:
        return str(node) + '+'
    if path_sign == -1:
        return str(node) + '-'
    return unk


def create_contexts(walk, i, window, adj_mat, unk):
    """
    Creates contexts with multi-step edge sign estimation.

    The sign attached to a context node is the product of the edge signs read from
    ``adj_mat`` along the walk between the current node and that context node.
    Positions that fall outside the walk are padded with ``unk``. If an edge sign
    is neither -1 nor 1, the accumulated sign is no longer +/-1 and that context
    (and every context further out on the same side) is emitted as ``unk``.

    Input:
      walk list: node identifiers convertible with int()
      i int: position of the current node in the walk
      window int: number of contexts on each side
      adj_mat: signed adjacency, indexed as adj_mat[source, target]
      unk str: padding token
    Returns:
      inode int: current node
      left + right list: 2 * window contexts, left side first (farthest to nearest),
        then right side (nearest to farthest)
    """
    inode = int(walk[i])
    left_sign, right_sign = 1, 1
    left, right = [], []
    last_visited_left, last_visited_right = inode, inode

    for offset in range(1, window + 1):
        # Left side: the walk goes from the context node to the last visited node
        if i - offset > -1:
            node = int(walk[i - offset])
            left_sign *= _edge_sign(adj_mat, node, last_visited_left)
            left.append(_signed_context(node, left_sign, unk))
            last_visited_left = node
        else:
            left.append(unk)

        # Right side: the walk goes from the last visited node to the context node
        if i + offset < len(walk):
            node = int(walk[i + offset])
            right_sign *= _edge_sign(adj_mat, last_visited_right, node)
            right.append(_signed_context(node, right_sign, unk))
            last_visited_right = node
        else:
            right.append(unk)

    # Left contexts were collected nearest-first; restore walk order
    left.reverse()

    return inode, left + right

def parallel_create_contexts(walks, topic_idx, cpu_num, window, adj_mat, unk):
    """
    Turn every position of every walk into a (node, contexts, topic_idx) tuple.
    Input:
      walks list: walks to convert
      topic_idx int: topic index attached to every produced tuple
      cpu_num int: worker number shown in the progress bar description
      window int: window size for the contexts
      adj_mat sparse.csr_matrix: signed adjacency matrix
      unk str: padding token
    Returns:
      list of (inode, ocontexts, topic_idx) tuples
    """
    progress = tqdm(total=len(walks), desc='Generating contexts (CPU: {})'.format(cpu_num))
    samples = []
    for current_walk in walks:
        progress.update(1)
        samples.extend(
            (*create_contexts(current_walk, position, window, adj_mat, unk), topic_idx)
            for position in range(len(current_walk))
        )
    progress.close()
    return samples

class TrainingSamples(object):
    """Holds (node, contexts, topic) training tuples and the vocabularies built from them."""

    def __init__(self, window=5, adj_mat=None,
                 idx2node=None, node2idx=None, idx2topic=None, topic2idx=None,
                 unk='<UNK>'):
        """
        Input:
          window int: window size for the contexts
          adj_mat sparse.csr_matrix: signed adjacency matrix
          idx2node dict: mapping index to node (the padding token is appended in place)
          node2idx dict: mapping node to index (the padding token is appended in place)
          idx2topic dict: mapping index to topic
          topic2idx dict: mapping topic to index
          unk str: padding token
        Omitted mappings default to a fresh empty dict per instance.
        """
        self.window = window
        self.unk = unk
        self.adj_mat = adj_mat
        # Fresh dicts per instance: a shared `{}` default would accumulate the
        # padding entry across instances.
        self.node2idx = {} if node2idx is None else node2idx
        self.node2idx[self.unk] = len(self.node2idx)
        self.idx2node = {} if idx2node is None else idx2node
        self.idx2node[len(self.idx2node)] = self.unk
        self.node_vocab = set(self.node2idx)
        self.topic2idx = {} if topic2idx is None else topic2idx
        self.idx2topic = {} if idx2topic is None else idx2topic
        self.topic_vocab = set(self.topic2idx)
        self.data = None

    @staticmethod
    def _split_evenly(items, parts):
        """Split ``items`` into ``parts`` consecutive lists whose sizes differ by at most one."""
        if parts <= 0:
            raise ValueError('number sections must be larger than 0.')
        items = list(items)
        base, extra = divmod(len(items), parts)
        chunks, start = [], 0
        for part in range(parts):
            stop = start + base + (1 if part < extra else 0)
            chunks.append(items[start:stop])
            start = stop
        return chunks

    def build_context_vocab(self):
        """
        Count context tokens in self.data, build the context vocabulary
        (self.cc, self.idx2context, self.context2idx, self.context_vocab) and
        replace the context tokens in self.data by their indices.
        The order of self.data is reversed by this call.
        """
        self.cc = {self.unk: 1}
        print("computing context frequencies...")
        for (inode, ocontexts, topic_idx) in tqdm(self.data):
            for context in ocontexts:
                self.cc[context] = self.cc.get(context, 0) + 1
        print("")
        print("building context vocab...")
        # The padding token's count is pinned to 1 regardless of how often it occurred
        self.cc[self.unk] = 1
        # NOTE(review): self.cc already contains the padding token, so it appears
        # twice in idx2context (index 0 and its frequency-sorted position); the
        # dict below keeps the later index for it. Downstream code sizes its
        # tables with len(idx2context), so this layout is left untouched.
        self.idx2context = [self.unk] + sorted(self.cc, key=self.cc.get, reverse=True)
        self.context2idx = {self.idx2context[idx]: idx for idx, _ in enumerate(self.idx2context)}
        self.context_vocab = set(self.context2idx)
        data = []
        size_data = len(self.data)
        # pop() consumes the old list while the new one grows, so both are never held in full
        for _ in tqdm(range(size_data)):
            inode, ocontexts, topic_idx = self.data.pop()
            data.append((inode, [self.context2idx[ocontext] for ocontext in ocontexts], topic_idx))
        self.data = data
        print("building done")

    def convert(self, walks, topic_idx, workers):
        """
        Convert walks into (node, contexts, topic_idx) tuples in parallel and
        store the result in self.data.
        Input:
          walks list: walks (lists of node identifiers), possibly of different lengths
          topic_idx int: topic index attached to every tuple
          workers int: number of parallel jobs (and of chunks the walks are split into)
        """
        print("converting corpus in parallel...")

        # Split the walks for each worker. Plain list slicing is used because
        # walks may have different lengths, which numpy cannot stack.
        walks_lists = self._split_evenly(walks, workers)

        context_results = Parallel(n_jobs=workers, temp_folder=None, require=None)(
            delayed(parallel_create_contexts)(wlks,
                                              topic_idx,
                                              idx,
                                              self.window,
                                              self.adj_mat,
                                              self.unk)
            for idx, wlks in enumerate(walks_lists, 1)
        )
        self.data = [sample for chunk in context_results for sample in chunk]
        print('conversion done')

    def info(self):
        """Print vocabulary sizes, the number of samples and one example sample."""
        print('#nodes:', len(self.node_vocab))
        print('#contexts:', len(self.context_vocab))
        print('#training samples:', len(self.data))
        if self.data:
            # Sample 10 when available, otherwise the last one
            print('An example:', self.data[min(10, len(self.data) - 1)])

In [None]:
class DataProcessor(object):
    """Builds (or loads from JSON caches) the walks and context samples used for training."""

    def __init__(self, edge_data, node2idx, idx2node, dataset,
                 path_walks, path_contexts,
                 dataset_type='hetero', topic2idx=None, idx2topic=None):
        """
        Input:
          edge_data: table with columns topic_idx, source_idx, target_idx and weight
          node2idx dict: mapping node to index
          idx2node dict: mapping index to node
          dataset str: dataset name
          path_walks str: path prefix of the walks cache file
          path_contexts str: path prefix of the contexts cache file
          dataset_type str: dataset type label
          topic2idx dict: mapping topic to index (default: {'unk': 0})
          idx2topic dict: mapping index to topic (default: {0: 'unk'})
        """
        self.dataset = dataset
        self.dataset_type = dataset_type
        self.edge_data = edge_data
        self.node2idx = node2idx
        self.idx2node = idx2node
        # Fresh default dicts per instance so that instances never share state
        self.topic2idx = {'unk': 0} if topic2idx is None else topic2idx
        self.idx2topic = {0: 'unk'} if idx2topic is None else idx2topic
        self.path_walks = path_walks
        self.path_contexts = path_contexts
        self.size = len(self.node2idx)
        print('#nodes:', self.size)
        self.nb_topics = len(self.topic2idx)

    @staticmethod
    def _cache_exists(path, label):
        """Report on stdout whether the cache file at ``path`` exists and return that flag."""
        found = Path(path).is_file()
        if found:
            print('{} exist.'.format(label))
        else:
            print('{} dont exist.'.format(label))
        return found

    @staticmethod
    def _digraph_from_sparse(matrix):
        """Build an nx.DiGraph from a scipy sparse matrix across networkx versions."""
        # from_scipy_sparse_matrix was removed in networkx 3.0; prefer its
        # replacement when the installed version provides it.
        builder = getattr(nx, 'from_scipy_sparse_array', None)
        if builder is None:
            builder = nx.from_scipy_sparse_matrix
        return builder(matrix, create_using=nx.DiGraph)

    def _generate_walks_and_contexts(self, p, q, nw, wl, wind, workers):
        """Generate walks and context samples topic by topic into self.all_walks / self.all_contexts."""
        self.all_walks = []
        self.all_contexts = []
        for topic_idx in tqdm(self.idx2topic):
            print('\ntopic no {}:'.format(topic_idx), self.idx2topic[topic_idx])
            dt = self.edge_data[self.edge_data.topic_idx == topic_idx]
            row = dt.source_idx.values
            col = dt.target_idx.values
            data = dt.weight.values
            adj_mat = sp.sparse.csr_matrix((data, (row, col)),
                                           shape=(self.size, self.size))
            # Element-wise square removes the sign, so walks run on non-negative weights
            abs_adj_mat = sp.sparse.csr_matrix(adj_mat.multiply(adj_mat))
            digraph = self._digraph_from_sparse(abs_adj_mat)
            start_time = time.time()
            print('generating walks...')
            walks_instance = Walks(digraph, p=p, q=q, num_walks=nw, walk_len=wl)
            walks_instance.generate()
            walks = walks_instance.walks
            print('execution time:', time.time() - start_time)
            print('#walks:', len(walks))
            self.all_walks += [(walk, topic_idx) for walk in walks]
            # The signed matrix (not the squared one) is used to sign the contexts
            training_samples = TrainingSamples(window=wind,
                                               adj_mat=adj_mat,
                                               idx2node=deepcopy(self.idx2node),
                                               node2idx=deepcopy(self.node2idx),
                                               idx2topic=deepcopy(self.idx2topic),
                                               topic2idx=deepcopy(self.topic2idx),
                                               unk='<UNK>')
            training_samples.convert(walks, topic_idx, workers=workers)
            self.all_contexts += training_samples.data
            # Release the per-topic structures before the next topic is processed
            del training_samples, walks_instance, walks, adj_mat, abs_adj_mat, digraph

    def build_training_data(self, p, q, nw, wl, wind, workers=16):
        """
        Build self.training_samples, reusing cached walks and contexts when both cache files exist.

        Input:
          p float: return hyperparameter of the walks
          q float: inout hyperparameter of the walks
          nw int: #walks per node
          wl int: length of the walks
          wind int: window size for the contexts
          workers int: #workers used to convert walks into contexts
        When either cache file is missing, walks and contexts are regenerated for
        every topic and both cache files are (re)written.
        """
        print('building training samples...')

        my_walks_path = self.path_walks + '_p{p}_q{q}_nw{nw}_wl{wl}.json' \
            .format(p=p, q=q, nw=nw, wl=wl)
        my_contexts_path = self.path_contexts + '_p{p}_q{q}_nw{nw}_wl{wl}_wind{wind}.json' \
            .format(p=p, q=q, nw=nw, wl=wl, wind=wind)

        walks_available = self._cache_exists(my_walks_path, 'Walks')
        contexts_available = self._cache_exists(my_contexts_path, 'Contexts')

        if walks_available and contexts_available:
            # Only parse the caches when both can actually be used
            with open(my_walks_path, 'r') as F:
                self.all_walks = json.load(F)
            with open(my_contexts_path, 'r') as F:
                self.all_contexts = json.load(F)
        else:
            self._generate_walks_and_contexts(p, q, nw, wl, wind, workers)
            with open(my_walks_path, 'w') as F:
                json.dump(self.all_walks, F)
            with open(my_contexts_path, 'w') as F:
                json.dump(self.all_contexts, F)

        training_samples = TrainingSamples(idx2node=deepcopy(self.idx2node),
                                           node2idx=deepcopy(self.node2idx),
                                           idx2topic=deepcopy(self.idx2topic),
                                           topic2idx=deepcopy(self.topic2idx),
                                           unk='<UNK>')
        training_samples.data = self.all_contexts
        training_samples.build_context_vocab()
        training_samples.info()
        self.training_samples = training_samples
        print('built training samples.')

In [None]:
class Bundler(nn.Module):
    """Abstract module interface: subclasses provide forward, forward_i and forward_o."""

    def forward(self, data):
        """Default forward pass; must be overridden."""
        raise NotImplementedError()

    def forward_i(self, data):
        """Input-side lookup; must be overridden."""
        raise NotImplementedError()

    def forward_o(self, data):
        """Output-side lookup; must be overridden."""
        raise NotImplementedError()

class Embeddings(Bundler):
    """Node, topic and context embedding tables used by the skip-gram model."""

    def __init__(self, node_vocab_size=5000, topic_vocab_size=100,
                 context_vocab_size=10001, embedding_size=64, op='addition'):
        """
        Input:
          node_vocab_size int: #nodes
          topic_vocab_size int: #topics; topic embeddings are only trained and
            combined with node embeddings when this is larger than 1
          context_vocab_size int: #contexts; the last index is the padding index
          embedding_size int
          op str: how node and topic embeddings are combined, 'addition' or 'hadamard'
        """
        super(Embeddings, self).__init__()
        self.node_vocab_size = node_vocab_size
        self.topic_vocab_size = topic_vocab_size
        self.context_vocab_size = context_vocab_size
        self.embedding_size = embedding_size
        self.padding_idx = self.context_vocab_size - 1
        self.op = op

        # Half-width of the uniform initialization interval
        bound = 0.5 / self.embedding_size

        self.ivectors = nn.Embedding(self.node_vocab_size, self.embedding_size)
        self.ivectors.weight = nn.Parameter(FT(self.node_vocab_size, self.embedding_size) \
                                            .uniform_(-bound, bound))
        self.ivectors.weight.requires_grad = True
        print(self.ivectors.weight.data.cpu().numpy().shape)

        self.itopics = nn.Embedding(self.topic_vocab_size, self.embedding_size)
        # Always defined, so forward_i never hits a missing attribute
        self.multitopic = self.topic_vocab_size > 1
        if self.multitopic:
            self.itopics.weight = nn.Parameter(FT(self.topic_vocab_size, self.embedding_size) \
                                               .uniform_(-bound, bound))
            self.itopics.weight.requires_grad = True
        else:
            # Single (or no) topic: an all-zero, frozen table that forward_i ignores
            self.itopics.weight = nn.Parameter(FT(self.topic_vocab_size, self.embedding_size) \
                                               .uniform_(0, 0))
            self.itopics.weight.requires_grad = False

        self.ovectors = nn.Embedding(self.context_vocab_size,
                                     self.embedding_size,
                                     padding_idx=self.padding_idx)
        # Last row (the padding index) starts at zero
        self.ovectors.weight = nn.Parameter(t.cat([FT(self.context_vocab_size - 1, self.embedding_size) \
                                                   .uniform_(-bound, bound),
                                                   t.zeros(1, self.embedding_size)]))
        self.ovectors.weight.requires_grad = True

    def forward(self, data, topic=None):
        """Same as forward_i(data, topic); ``topic`` may be omitted for single-topic models."""
        return self.forward_i(data, topic)

    def forward_i(self, node, topic):
        """
        Look up input embeddings.
        Input:
          node: node indices
          topic: topic indices (ignored for single-topic models)
        Returns:
          node embeddings, combined with topic embeddings via ``op`` for multi-topic models
        Raises:
          ValueError: if the model is multi-topic and ``topic`` is None
        """
        u = LT(node)
        u = u.to(device)
        if self.multitopic:
            if topic is None:
                raise ValueError("topic indices are required when topic_vocab_size > 1")
            v = LT(topic)
            v = v.to(device)
            return self.__operation(self.ivectors(u), self.itopics(v))
        else:
            return self.ivectors(u)

    def forward_o(self, context):
        """Look up output (context) embeddings for the given context indices."""
        v = LT(context)
        v = v.to(device)
        return self.ovectors(v)

    def __operation(self, x, y):
        """Combine node and topic embeddings according to ``self.op``."""
        if self.op == 'addition':
            return x + y
        elif self.op == 'hadamard':
            return x * y
        # Fail loudly instead of returning None for an unsupported operator
        raise ValueError("unknown op {!r}; expected 'addition' or 'hadamard'".format(self.op))

In [None]:
class SGNS(nn.Module):
    """ Skipgram with negative sampling."""

    def __init__(self, embedding, n_negs=20, weights=None):
        """
        Input:
          embedding: Embeddings object
          n_negs int: #negative samples per context
          weights np.array: context weights; raised to the power 0.75 and
            normalized to form the negative sampling distribution. When None,
            negatives are drawn uniformly.
        """
        super(SGNS, self).__init__()
        self.embedding = embedding
        self.context_vocab_size = self.embedding.context_vocab_size
        self.n_negs = n_negs
        self.weights = None
        if weights is not None:
            cf = np.power(weights, 0.75)
            cf = cf / cf.sum()
            self.weights = FT(cf)

    def forward(self, inode, ocontexts, itopic):
        """
        Compute the negative-sampling loss of one mini-batch.
        Input:
          inode: node indices, first dimension is the batch
          ocontexts: context indices, indexed as (batch, context)
          itopic: topic indices
        Returns:
          scalar loss tensor
        """
        batch_size = inode.size()[0]
        context_size = ocontexts.size()[1]
        if self.weights is not None:
            ncontexts = t.multinomial(self.weights,
                                      batch_size * context_size * self.n_negs,
                                      replacement=True).view(batch_size, -1)
        else:
            ncontexts = FT(batch_size, context_size * self.n_negs) \
                        .uniform_(0, self.context_vocab_size - 1).long()
        ivectors = self.embedding.forward_i(inode, itopic).unsqueeze(2)
        ovectors = self.embedding.forward_o(ocontexts)
        nvectors = self.embedding.forward_o(ncontexts).neg()
        # squeeze(2) removes only the trailing bmm dimension; a bare squeeze()
        # would also drop a batch (or context) dimension of size 1.
        # logsigmoid is the numerically stable form of sigmoid().log().
        oloss = nn.functional.logsigmoid(t.bmm(ovectors, ivectors).squeeze(2)).mean(1)
        nloss = nn.functional.logsigmoid(t.bmm(nvectors, ivectors).squeeze(2)) \
                .view(-1, context_size, self.n_negs).sum(2).mean(1)
        return -(oloss + nloss).mean()

In [None]:
class PermutedSubsampledCorpus:
    """ For subsampling data if needed."""

    def __init__(self, data, idx2context, us=True, ratio=.8, cs=None):
        """
        Input:
          data list: (inode, ocontexts, itopic) samples
          idx2context: maps a context index to its token; tokens ending in '-'
            are treated as negative contexts
          us bool: whether undersample() balances positive and negative samples
          ratio float: #majority samples kept per minority sample when undersampling
          cs: optional discard thresholds indexed by inode; a sample is kept when
            a uniform random draw exceeds cs[inode]
            (NOTE(review): confirm the intended indexing with the caller)
        """
        if cs is not None:
            # Samples are 3-tuples and the threshold is keyed by the sample's node
            self.data = [sample for sample in data if rd.random() > cs[sample[0]]]
        else:
            self.data = data
        self.idx2context = idx2context
        self.negative_data = []
        self.positive_data = []
        self.us = us
        # Lets len() and indexing work before the first undersample() call
        self.data_undersampled = self.data
        if self.us:
            print('spliting negative and positive contexts...')
            for inode, ocontexts, itopic in tqdm(self.data):
                # A sample is negative as soon as one of its contexts is negative
                has_negative = any(self.idx2context[context][-1] == '-' for context in ocontexts)
                bucket = self.negative_data if has_negative else self.positive_data
                bucket.append((inode, ocontexts, itopic))
            print('split negative and positive contexts.')
            self.ratio = ratio
            self.num_negative = len(self.negative_data)
            self.num_positive = len(self.positive_data)
            self.min_class = 'negative' if self.num_negative < self.num_positive else 'positive'
            self.ratio_dict = {0: self.num_negative, 1: self.num_positive}
            print('#negative contexts:', self.num_negative)
            print('#positive contexts:', self.num_positive)

    def undersample(self):
        """
        Draw the samples served by __getitem__: all minority-class samples plus a
        random subset of the majority class (int(ratio * #minority), capped at the
        majority size). Without undersampling, all data is served.
        """
        if not self.us:
            self.data_undersampled = self.data
            return
        if self.min_class == 'negative':
            minority, majority = self.negative_data, self.positive_data
        else:
            minority, majority = self.positive_data, self.negative_data
        n_majority = max(0, min(int(self.ratio * len(minority)), len(majority)))
        # Build a new list: extending the class list in place would make it grow
        # with every call (i.e. every epoch).
        self.data_undersampled = list(minority) + rd.sample(majority, n_majority)

    def __len__(self):
        return len(self.data_undersampled)

    def __getitem__(self, idx):
        inode, ocontexts, itopic = self.data_undersampled[idx]
        return inode, np.array(ocontexts), itopic

In [None]:
def train(training_samples, idx2context, idx2node, idx2topic, cc, us=False, ratio=0.8,
          ss_t=1e-5, wgts=False, e_dim=64, n_negs=20, epoch=100, mb=4096, op='addition'):
    """
    Train the skip-gram model with negative sampling.

    Input:
        training_samples list: (inode, ocontexts, itopic) samples with context indices
        idx2context list: context vocabulary, position = context index
        idx2node: node vocabulary (only its length is used)
        idx2topic: topic vocabulary (only its length is used)
        cc dict: context count
        us bool: whether positive/negative samples are rebalanced every epoch
        ratio float: undersampling ratio passed to PermutedSubsampledCorpus
        ss_t float: subsampling threshold (accepted for compatibility; not used)
        wgts bool: draw negatives from the context frequency distribution
            instead of uniformly
        e_dim int: embedding size
        n_negs int: #negative samples per context
        epoch int: number of epochs
        mb int: minibatch size
        op str: operator combining node and topic embeddings
    Returns:
        idx2vec np.array: trained node embeddings
        topic embeddings np.array: trained topic embeddings
    Side effect:
        the module-level name ``model`` is rebound to the model built here.
    """

    global model

    # Relative context frequencies, in idx2context order
    cf = np.array([cc[context] for context in idx2context])
    cf = cf / cf.sum()
    context_vocab_size = len(idx2context)
    node_vocab_size = len(idx2node)
    topic_vocab_size = len(idx2topic)
    weights = cf if wgts else None
    print('Initialize model...')
    model = Embeddings(node_vocab_size=node_vocab_size,
                       topic_vocab_size=topic_vocab_size,
                       context_vocab_size=context_vocab_size,
                       embedding_size=e_dim,
                       op=op)
    multi_gpu = t.cuda.device_count() > 1
    if multi_gpu:
        print("Used:", t.cuda.device_count(), "GPUs")
        # NOTE(review): the device ids are hard-coded for the original host, and
        # the loss below is computed through model.module, not through the wrapper.
        model = nn.DataParallel(model, device_ids=[0, 1, 6, 7], output_device=6)
    # The bare Embeddings module, whether or not it has been wrapped
    embedding = model.module if multi_gpu else model
    sgns = SGNS(embedding=embedding,
                n_negs=n_negs,
                weights=weights).to(device)
    print('Initialized model...')
    optim = Adam(sgns.parameters())
    dataset = PermutedSubsampledCorpus(training_samples, idx2context, us, ratio)
    print('Input data ready.')
    for epoch_no in range(1, epoch + 1):
        # Re-draw the (optionally undersampled) training set for this epoch
        dataset.undersample()
        dataloading = DataLoader(dataset, batch_size=mb, shuffle=True)
        pbar = tqdm(dataloading)
        pbar.set_description("[Epoch {}]".format(epoch_no))
        for inode, ocontexts, itopic in pbar:
            loss = sgns(inode, ocontexts, itopic)
            optim.zero_grad()
            loss.backward()
            optim.step()
            pbar.set_postfix(loss=loss.item())

    idx2vec = embedding.ivectors.weight.data.cpu().numpy()
    topic_vectors = embedding.itopics.weight.data.cpu().numpy()

    return idx2vec, topic_vectors

In [None]:
# Training
# Dataset selection
dataset = 'birdwatch'
dataset_type = 'hetero'
# Walk hyperparameters: return / inout parameters, #walks per node, walk length
p = 1.5
q = .5
nw = 5
wl = 40
# Workers used to convert walks into contexts
workers = 3
# Context window size
wind = 5
# Skip-gram hyperparameters: frequency-weighted negatives, subsampling threshold,
# embedding size, #negatives, #epochs, minibatch size, node/topic combination
wgts = True
sst = 1e-5
edim = 64
nnegs = 20
ep = 5
mb = 4096
op = 'addition'

edge_data = pd.read_csv('cached_data/full_datasets/{}.csv'.format(dataset))
databuilder = utils.DataBuilder(edge_data, dataset_name=dataset)
dataloader = utils.Dataloader(dataset_name=dataset)

topic2id = dataloader.topic2idx
id2topic = dataloader.idx2topic

# One model per training split; splits are numbered 1..5
for split in range(1, 6):
    print('Training set: {}'.format(split))
    data = DataProcessor(edge_data=dataloader.training_data[dataset_type][split],
                         node2idx=dataloader.node2idx,
                         idx2node=dataloader.idx2node,
                         topic2idx=topic2id,
                         idx2topic=id2topic,
                         dataset=dataloader.dataset_name,
                         dataset_type=dataset_type,
                         path_walks='cached_data/SE/{data}_type{tp}_walks_train{t}' \
                         .format(data=dataset, tp=dataset_type, t=split),
                         path_contexts='cached_data/SE/{data}_type{tp}_contexts_train{t}' \
                         .format(data=dataset, tp=dataset_type, t=split))
    data.build_training_data(p=p, q=q, nw=nw, wl=wl, wind=wind, workers=workers)
    idx2vec, idx2topic = train(training_samples=data.training_samples.data,
                               idx2context=data.training_samples.idx2context,
                               idx2node=data.training_samples.idx2node,
                               idx2topic=data.training_samples.idx2topic,
                               cc=data.training_samples.cc,
                               us=False,
                               ratio=1.,
                               ss_t=sst,
                               wgts=wgts,
                               e_dim=edim,
                               n_negs=nnegs,
                               epoch=ep,
                               mb=mb,
                               op=op)