<a href="https://colab.research.google.com/github/krajkumar6/MachineLearning/blob/master/3_creating_DGL_graph.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Creating a heterogeneous DGL graph with the following node types and relationships from the raw data prepared in the previous notebooks

*Nodes*

1.   Users
2.   News
3.   Category
4.   SubCategory

*Edges/Relationships*

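1.   Users -> read -> News
2.   News -> belongs -> Category
3.   SubCategory -> belongs -> Category

Note: the graph built below currently contains only the Users–News click relationship, stored as `('user', 'clicks', 'article')` plus its reverse `('article', 'clicked_by', 'user')`. The Category and SubCategory nodes and their relationships are not yet added.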
















In [1]:
# Install dgl library
!pip install dgl -f https://data.dgl.ai/wheels/repo.html

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.dgl.ai/wheels/repo.html


In [2]:
# Select DGL's PyTorch backend before `import dgl`. A `!export` only sets the
# variable inside a throwaway subshell, so it never reached this Python
# process (hence the "DGL backend not selected" warning on import).
import os
os.environ['DGLBACKEND'] = 'pytorch'

In [3]:
# Import all libraries
import pandas as pd
import numpy as np
import torch as th
import dgl

DGL backend not selected or invalid.  Assuming PyTorch for now.


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


In [4]:
# Mount Google Drive at /content/drive so the MIND-small CSV files stored
# under MyDrive can be read from this Colab runtime.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [5]:
# Load the node-feature tables produced by the previous notebooks.
# The first column of each CSV is a saved pandas index (it shows up as
# 'Unnamed: 0' when read back), so it is dropped and only feature columns remain.
path = '/content/drive/MyDrive/Raj/Work related/Colab Notebooks/PaaS demo/MIND-small/train/'


def load_node_features(filename):
    """Read ``path + filename`` and return it without its leading index column."""
    return pd.read_csv(path + filename).iloc[:, 1:]


user = load_node_features('users_emb.csv')
news_title = load_node_features('Title_emb.csv')
news_abs = load_node_features('Abs_emb.csv')
cat = load_node_features('cat_enc.csv')
sub_cat = load_node_features('sub_cat_enc.csv')

for node_table in (user, cat, sub_cat, news_title, news_abs):
    print(node_table.shape)

# Title and abstract embeddings both become features of the same 'article'
# nodes, so they must describe the same number of articles.
assert len(news_title) == len(news_abs), 'title/abstract embedding row counts differ'

(50000, 384)
(17, 17)
(264, 264)
(51282, 384)
(51282, 384)


In [6]:
# Preview the first five rows of the article-title embeddings.
news_title.iloc[:5]

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,374,375,376,377,378,379,380,381,382,383
0,-0.009286,0.042442,0.058988,0.012106,0.033419,0.019646,0.027369,-0.061527,-0.036167,-0.039047,...,-0.006674,-0.003653,0.048061,-0.001848,-0.037265,-0.010018,0.051743,-0.131551,0.050825,-0.030139
1,0.018617,0.038452,0.026627,0.083813,-0.007007,-0.030662,0.020416,0.000688,-0.092173,0.017821,...,0.003076,0.005209,-0.005404,-0.031541,-0.033012,0.087852,0.073781,-0.039684,-0.008007,0.000351
2,-0.026763,0.067353,0.047776,0.065173,0.001775,-0.004801,-0.026565,0.03151,-0.043916,0.001157,...,0.088919,-0.00192,0.030752,0.028289,-0.030664,0.027943,-0.060686,-0.068997,-0.010749,-0.007543
3,0.067175,0.093692,0.010179,0.030391,-0.01182,0.076098,0.070887,0.035816,0.03656,0.003015,...,0.043614,-0.012557,-0.003307,-0.04403,-0.058982,-0.079833,0.060218,0.004372,-0.063992,-0.013305
4,0.001295,0.104388,-0.016191,0.016463,0.023895,-0.089162,0.073511,0.030496,-0.076674,-0.051121,...,-0.051557,-0.053719,0.118986,0.02045,-0.029104,0.016139,0.042121,-0.010622,0.048454,0.075503


In [7]:
# Import relationship (edge-list) data from CSV files.
# Positive edges: user-article pairs where the user clicked the article.
user_art_rel = pd.read_csv(path+'user_art_rel1.csv')
print(user_art_rel.shape)
# Negative samples: user-article pairs where the user saw an article but did not click it.
user_art_neg_rel = pd.read_csv(path+'user_art_neg_samp_rel1.csv')
print(user_art_neg_rel.shape)

# Both edge lists are expected to carry the integer node-ID columns that the
# graph-building cell reads by name; fail here with a clear message if not.
required_edge_cols = {'UserID_int', 'NewsID_int'}
for rel_name, rel_df in [('user_art_rel', user_art_rel),
                         ('user_art_neg_rel', user_art_neg_rel)]:
    missing_cols = required_edge_cols - set(rel_df.columns)
    assert not missing_cols, f'{rel_name} is missing columns {sorted(missing_cols)}'

# TODO: category relationships are not yet part of the graph.
# art_cat_rel = pd.read_csv(path+'art_cat_rel.csv')
# print(art_cat_rel.shape)
# cat_sub_cat_rel = pd.read_csv(path+'cat_subcat_rel.csv')
# print(cat_sub_cat_rel.shape)

(1318793, 2)
(1157708, 2)


### Creating the DGL dataset (and DGL graph)

In [8]:
from dgl.data import DGLDataset
from dgl.data.utils import save_graphs,load_graphs

class MIND_mini_DGLdataset(DGLDataset):
  """Single-graph DGL dataset wrapping the MIND-small user/article click graph.

  The heterograph has node types 'user' and 'article' and two edge types:
  ('user', 'clicks', 'article') and its reverse ('article', 'clicked_by', 'user').

  process() reads the notebook-level DataFrames ``user``, ``news_title``,
  ``news_abs`` and ``user_art_rel``, so the cells that load them must run
  before the dataset is constructed. Constructing the dataset runs process()
  followed by save(). has_cache()/load() are not overridden here, so the
  binary written by save() is not read back by this class.
  """

  def __init__(self,
               url=None,
               raw_dir=None,
               save_dir=None,
               force_reload=False,
               verbose=False):
    super().__init__(name='MIND_mini',
                     url=url,
                     raw_dir=raw_dir,
                     save_dir=save_dir,
                     force_reload=force_reload,
                     verbose=verbose)

  def process(self):
    """Build ``self.hg`` from the loaded node-feature and edge DataFrames."""
    print('DGL Dataset process begins')

    # Node features as float32 tensors, one row per node.
    user_th = th.from_numpy(user.to_numpy().astype(np.float32))
    news_title_th = th.from_numpy(news_title.to_numpy().astype(np.float32))
    news_abs_th = th.from_numpy(news_abs.to_numpy().astype(np.float32))

    # Source/destination node IDs for each clicked (user, article) pair.
    user_art_rel_u = th.from_numpy(user_art_rel['UserID_int'].to_numpy())
    user_art_rel_v = th.from_numpy(user_art_rel['NewsID_int'].to_numpy())

    # Each relation is stored together with its reverse edge type.
    graph_data = {
        ('user', 'clicks', 'article'): (user_art_rel_u, user_art_rel_v),
        ('article', 'clicked_by', 'user'): (user_art_rel_v, user_art_rel_u),
    }
    # TODO: relations planned but not yet enabled (add forward + reverse):
    #   ('user', 'no-click', 'article') / ('article', 'not_clicked_by', 'user')
    #       from user_art_neg_rel['UserID_int'] and user_art_neg_rel['NewsID_int']
    #   ('article', 'belongs', 'category') / ('category', 'contains', 'article')
    #   ('category', 'parent', 'subcategory') / ('subcategory', 'child', 'category')
    # Category / subcategory nodes would also need their 'feat' tensors set below.

    # Derive node counts from the feature tables instead of hard-coding them,
    # so they always agree with the feature tensors assigned below.
    num_nodes_dict = {'user': user_th.shape[0],
                      'article': news_title_th.shape[0]}
    self.hg = dgl.heterograph(graph_data, num_nodes_dict)

    # Attach node features.
    self.hg.nodes['user'].data['feat'] = user_th
    self.hg.nodes['article'].data['title'] = news_title_th
    self.hg.nodes['article'].data['abs'] = news_abs_th
    print(self.hg.num_nodes('article'))

  def __len__(self):
    """Number of graphs in the dataset (always 1)."""
    return 1

  # Kept for callers that used the ``dataset.len()`` spelling; Python's
  # built-in ``len(dataset)`` dispatches to ``__len__`` above.
  len = __len__

  def __getitem__(self, idx):
    """Return the single heterograph; only ``idx == 0`` is valid.

    Raises:
      IndexError: for any other index, as Python's sequence protocol expects.
    """
    if idx != 0:
      raise IndexError('There is just one graph in this dataset')
    return self.hg

  def save(self):
    """Write ``self.hg`` to ``path + 'MIND_small.bin'``.

    ``path`` is the notebook-level data directory (not ``self.save_dir``).
    """
    dgl.save_graphs(path + 'MIND_small.bin', [self.hg])
    print(f'DGL dataset {self.name} is stored in binary format in {path}')

  # To reuse the saved binary without rebuilding, read it back with
  # ``graphs, _ = dgl.load_graphs(path + 'MIND_small.bin')`` and take ``graphs[0]``.

In [9]:
# Constructing the dataset builds the heterograph (process) and writes it to disk (save).
dataset = MIND_mini_DGLdataset(force_reload=False, verbose=False)

DGL Dataset process begins
51282
DGL dataset MIND_mini is stored in binary format in /content/drive/MyDrive/Raj/Work related/Colab Notebooks/PaaS demo/MIND-small/train/


In [10]:
# The dataset contains a single heterograph, at index 0.
hg = dataset[0]

In [None]:
# Inspect node types, edge types and canonical (src_type, etype, dst_type) triples.
# NOTE: the saved output below is stale. The cell was not executed in the final
# run, and the output lists 'no-click' / 'not_clicked_by' edge types that the
# current graph does not build. Re-run the cell to refresh it.
for attr_name in ('ntypes', 'etypes', 'canonical_etypes'):
    print(f'hg.{attr_name} {getattr(hg, attr_name)}')

hg.ntypes ['article', 'user']
hg.etypes ['clicked_by', 'not_clicked_by', 'clicks', 'no-click']
hg.canonical_etypes [('article', 'clicked_by', 'user'), ('article', 'not_clicked_by', 'user'), ('user', 'clicks', 'article'), ('user', 'no-click', 'article')]


In [11]:
# Display the heterograph summary: node counts per type, edge counts per
# canonical edge type, and the metagraph.
hg

Graph(num_nodes={'article': 51282, 'user': 50000},
      num_edges={('article', 'clicked_by', 'user'): 1318793, ('user', 'clicks', 'article'): 1318793},
      metagraph=[('article', 'user', 'clicked_by'), ('user', 'article', 'clicks')])