# Graph Neural Network for Emergence of TEK - Patent Paper Pairs with Node Classification

In [1]:
import pandas as pd
import re
import numpy as np
import torch
import h5py
import ast
import torch
import multiprocessing as mp
import os.path as osp
import gcld3
from sqlalchemy import create_engine, URL, text, MetaData, Table
from tqdm import tqdm
from tqdm.auto import tqdm
tqdm.pandas()
from rapidfuzz import fuzz, process, distance
from rapidfuzz.distance import Levenshtein
# from concurrent.futures import ProcessPoolExecutor, as_completed
import torch.nn.functional as F
from torch_geometric.data import HeteroData, Dataset, Data
from torch_geometric.nn import SAGEConv, GATConv, HeteroConv, MessagePassing
from torch_geometric.loader import NeighborLoader, DataLoader, LinkNeighborLoader
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.transforms import RandomLinkSplit
from sentence_transformers import SentenceTransformer

In [2]:
model = SentenceTransformer('distilbert/distilbert-base-uncased')
detector = gcld3.NNetLanguageIdentifier(min_num_bytes=0, max_num_bytes=1000)
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'

No sentence-transformers model found with name /home/thiesen/.cache/torch/sentence_transformers/distilbert_distilbert-base-uncased. Creating a new one with MEAN pooling.


# Data Preprocessing

## OpenAlex Works from Reliance on Science

In [3]:
df_rel_pcs = pd.read_csv("/mnt/hdd01/Reliance on Science/Raw Files/_pcs_oa.csv")
df_rel_ppp = pd.read_csv("/mnt/hdd01/Reliance on Science/Raw Files/_patent_paper_pairs.tsv", sep="\t")

In [4]:
df_rel_ppp = df_rel_ppp.astype(str)
df_rel_ppp['patent'] = df_rel_ppp['patent'].apply(lambda x: x.lower())
df_rel_ppp = df_rel_ppp[df_rel_ppp['patent'].apply(lambda x: "us" in x)] # ONLY US patents in the original dataset - inference on US and EP patents!!
df_rel_ppp['patent_id'] = df_rel_ppp['patent'].apply(lambda x: x.split("-", 1)[1].rsplit("-", 1)[0])
df_rel_ppp['patent_paper_pair'] = 1
df_rel_ppp = df_rel_ppp.rename(columns={'magid': 'oaid'})

In [5]:
df_rel_pcs = df_rel_pcs.astype(str)
df_rel_pcs['patent'] = df_rel_pcs['patent'].apply(lambda x: x.lower())
df_rel_pcs = df_rel_pcs[df_rel_pcs['patent'].apply(lambda x: "us" in x)]
df_rel_pcs['patent_id'] = df_rel_pcs['patent'].apply(lambda x: x.split("-", 1)[1].rsplit("-", 1)[0])

In [6]:
df_rel_pcs_filtered = df_rel_pcs[~df_rel_pcs['oaid'].isin(df_rel_ppp['oaid'])]
df_rel_pcs_filtered = df_rel_pcs_filtered.groupby('oaid').filter(lambda x: len(x) >= 5) # 5 is the minimum number of mentions in the dataset, to ensure well connected graph
df_rel_pcs_filtered = df_rel_pcs_filtered.groupby('patent_id').filter(lambda x: len(x) >= 5) # 5 is the minimum number of mentions in the dataset, to ensure well connected graph
df_rel_pcs_sample = df_rel_pcs_filtered.sample(n=df_rel_ppp['oaid'].nunique(), random_state=42)
df_rel_pcs_sample = df_rel_pcs_sample.reset_index(drop=True)
df_rel_pcs_sample['patent_paper_pair'] = 0

In [7]:
oaid_list = np.unique(np.concatenate([df_rel_ppp['oaid'].unique(), df_rel_pcs_sample['oaid'].unique()]))
patent_list = np.unique(np.concatenate([df_rel_ppp['patent_id'].unique(), df_rel_pcs_sample['patent_id'].unique()]))

In [8]:
df_rel_info = pd.concat([df_rel_ppp, df_rel_pcs_sample], ignore_index=True)

### Extract Works from Postgres OpenAlex

In [None]:
url_object = URL.create(
    drivername="postgresql+psycopg2",
    username=user,
    password=password,
    host=host,
    port=port,
    database=db,
)
engine = create_engine(url_object)

In [None]:
with engine.begin() as connection:
    connection.execute(text("""
        CREATE TABLE temp_oaid_ppp (
            oaid VARCHAR PRIMARY KEY
        )
    """))
    oaid_prefixed = ['https://openalex.org/W' + str(oaid) for oaid in oaid_list]
    for oaid in tqdm(oaid_prefixed):
        connection.execute(text("INSERT INTO temp_oaid_ppp (oaid) VALUES (:oaid)"), {'oaid': oaid})

In [None]:
df_rel_postgres = pd.read_sql_query("""
    SELECT w.id, w.title, w.abstract_inverted_index 
    FROM openalex.works AS w
    JOIN temp_oaid_ppp AS t ON w.id = t.oaid
""", con=engine)
df_rel_postgres = df_rel_postgres.drop_duplicates(subset=['id'])

In [None]:
df_rel_postgres = df_rel_postgres[df_rel_postgres['abstract_inverted_index'].apply(lambda x: x is not None and x['InvertedIndex'] != {})]

def reconstruct_abstract(row):
    # Extract the InvertedIndex for the current row
    inverted_index = row['abstract_inverted_index']['InvertedIndex']
    
    # Create a mapping of positions to words
    position_to_word = {}
    for word, positions in inverted_index.items():
        for position in positions:
            position_to_word[position] = word
    
    # Sort positions and reconstruct the abstract
    sorted_positions = sorted(position_to_word.keys())
    full_text_abstract = " ".join(position_to_word[pos] for pos in sorted_positions)
    
    # Fix punctuation spacing
    full_text_abstract = re.sub(r'\s+([.,;?!:])', r'\1', full_text_abstract)
    
    return full_text_abstract

df_rel_postgres['abstract'] = df_rel_postgres.apply(reconstruct_abstract, axis=1)

In [None]:
df_rel_postgres.to_csv("/mnt/hdd01/Reliance on Science/ppp_oa_works.csv", index=False)

In [9]:
df_rel_postgres = pd.read_csv("/mnt/hdd01/Reliance on Science/ppp_oa_works.csv")

### Extract Authors from Postgres OpenAlex

In [None]:
df_rel_authors_postgres = pd.read_sql_query("""
    SELECT a.work_id, a.author_id
    FROM openalex.works_authorships as a
    JOIN temp_oaid_ppp AS t ON a.work_id = t.oaid
""", con=engine)
df_rel_authors_postgres = df_rel_authors_postgres.drop_duplicates(subset=['author_id', 'work_id'])

In [None]:
df_rel_authors_postgres_grouped = df_rel_authors_postgres.groupby('author_id')['work_id'].apply(list).reset_index()

In [None]:
with engine.begin() as connection:
    connection.execute(text("""
        CREATE TEMPORARY TABLE temp_author_id_ppp (
            author_id VARCHAR PRIMARY KEY
        )
    """))
    for author_id in tqdm(df_rel_authors_postgres_grouped['author_id']):
        connection.execute(text("INSERT INTO temp_author_id_ppp (author_id) VALUES (:author_id)"), {'author_id': author_id})

In [None]:
df_rel_authors_info_postgres = pd.read_sql_query("""
    SELECT a.id, a.display_name, a.display_name_alternatives
    FROM openalex.authors as a
    JOIN temp_author_id_ppp AS t ON a.id = t.author_id
""", con=engine)
df_rel_authors_info_postgres = df_rel_authors_info_postgres.drop_duplicates(subset=['display_name'])

In [None]:
df_rel_authors_complete = pd.merge(df_rel_authors_postgres_grouped, df_rel_authors_info_postgres, left_on='author_id', right_on='id', how='inner')
df_rel_authors_complete['oaid'] = df_rel_authors_complete['work_id'].apply(lambda x: [i.replace("https://openalex.org/W", "") for i in x])

In [None]:
df_rel_authors_complete.to_csv("/mnt/hdd01/Reliance on Science/ppp_oa_authors.csv", index=False)

In [10]:
df_rel_authors_complete = pd.read_csv("/mnt/hdd01/Reliance on Science/ppp_oa_authors.csv")

### Extract Paper Citations from Postgres OpenAlex

In [None]:
df_rel_citations_postgres = pd.read_sql_query("""
    SELECT w.work_id, w.referenced_work_id
    FROM openalex.works_referenced_works as w
    JOIN temp_oaid_ppp AS t1 ON w.work_id = t1.oaid
    JOIN temp_oaid_ppp AS t2 ON w.referenced_work_id = t2.oaid
""", con=engine)
df_rel_citations_postgres = df_rel_citations_postgres.drop_duplicates(subset=['work_id', 'referenced_work_id'])

In [None]:
df_rel_citations_postgres.to_csv("/mnt/hdd01/Reliance on Science/ppp_oa_citations.csv", index=False)

In [11]:
df_rel_citations_postgres = pd.read_csv("/mnt/hdd01/Reliance on Science/ppp_oa_citations.csv")

## Extract Patents from PATSTAT Postgres

In [21]:
url_object = URL.create(
    drivername="postgresql+psycopg2",
    username=user,
    password=password,
    host=host,
    port=port,
    database=db,
)
engine = create_engine(url_object)

In [None]:
with engine.begin() as connection:
    connection.execute(text("""
        CREATE TABLE temp_patentid_ppp (
            patentid VARCHAR PRIMARY KEY,
            appln_id VARCHAR
        )
    """))
    for patent_id in tqdm(patent_list):
        connection.execute(text("INSERT INTO temp_patentid_ppp (patentid) VALUES (:patentid)"), {'patentid': patent_id})

### Extract appln_id for further processing

In [24]:
df_patstat_applnid = pd.read_sql_query("""
    SELECT t.publn_nr, t.appln_id
    FROM tls211_pat_publn as t
    JOIN temp_patentid_ppp AS tp ON t.publn_nr = tp.patentid
    WHERE t.publn_auth = 'US'
""", con=engine)
df_patstat_applnid = df_patstat_applnid.drop_duplicates(subset=['publn_nr', 'appln_id'])

In [None]:
with engine.begin() as connection:
    connection.execute(text("""
        UPDATE temp_patentid_ppp
        SET appln_id = t.appln_id
        FROM (
            SELECT t.publn_nr, t.appln_id
            FROM tls211_pat_publn as t
            WHERE t.publn_auth = 'US'
        ) AS t
        WHERE temp_patentid_ppp.patentid = t.publn_nr
    """))

### Extract Cleantech Patents for later subtraction

In [22]:
df_patstat_cleantech = pd.read_sql_query("""
    SELECT t.appln_id, t.cpc_class_symbol
    FROM tls224_appln_cpc as t
    JOIN temp_patentid_ppp AS tp ON t.appln_id = tp.appln_id
    WHERE t.cpc_class_symbol LIKE '%%Y02%%'
""", con=engine)
df_patstat_cleantech = df_patstat_cleantech.drop_duplicates(subset=['appln_id', 'cpc_class_symbol'])

In [25]:
df_patstat_cleantech = pd.merge(df_patstat_cleantech, df_patstat_applnid, on='appln_id')
df_patstat_cleantech = df_patstat_cleantech.rename(columns={'publn_nr': 'patent_id'})

In [27]:
df_rel_cleantech = pd.merge(df_rel_info, df_patstat_cleantech, on='patent_id', how='inner')

### Extract Title from PATSTAT

In [None]:
df_patstat_title = pd.read_sql_query("""
    SELECT tp.appln_id, tp.patentid, tat.appln_title
    FROM temp_patentid_ppp AS tp
    JOIN tls202_appln_title AS tat ON tp.appln_id = tat.appln_id
    WHERE tat.appln_title_lg = 'en'
""", con=engine)
df_patstat_title = df_patstat_title.drop_duplicates(subset=['appln_id', 'patentid', 'appln_title'])

### Extract Abstract from PATSTAT

In [None]:
df_patstat_abstract = pd.read_sql_query("""
    SELECT tp.appln_id, tp.patentid, tab.appln_abstract
    FROM temp_patentid_ppp AS tp
    JOIN tls203_appln_abstr AS tab ON tp.appln_id = tab.appln_id
    WHERE tab.appln_abstract_lg = 'en'
""", con=engine)
df_patstat_abstract = df_patstat_abstract.drop_duplicates(subset=['appln_id', 'patentid', 'appln_abstract'])

### Extract Authors from PATSTAT

In [None]:
df_patstat_person_id = pd.read_sql_query("""
    SELECT tp.appln_id, pa.person_id
    FROM temp_patentid_ppp AS tp
    JOIN tls207_pers_appln AS pa ON tp.appln_id = pa.appln_id
""", con=engine)
df_patstat_person_id = df_patstat_person_id.drop_duplicates(subset=['appln_id', 'person_id'])

In [None]:
df_patstat_person_id.to_sql('temp_ppp_person_id', con=engine, if_exists='replace', index=False)

In [None]:
df_patstat_person_details = pd.read_sql_query("""
    SELECT tpi.appln_id, tpi.person_id, p.person_name, p.person_name_orig_lg, p.person_address, p.doc_std_name, p.psn_name, p.han_name
    FROM temp_ppp_person_id AS tpi
    JOIN tls206_person AS p ON tpi.person_id = p.person_id
""", con=engine)
df_patstat_person_details = df_patstat_person_details.drop_duplicates(subset=['appln_id', 'person_id', 'person_name', 'person_name_orig_lg', 'person_address', 'doc_std_name', 'psn_name', 'han_name'])

In [None]:
df_patstat_title.to_csv("/mnt/hdd01/Reliance on Science/ppp_patstat_title.csv", index=False)
df_patstat_abstract.to_csv("/mnt/hdd01/Reliance on Science/ppp_patstat_abstract.csv", index=False)
df_patstat_person_details.to_csv("/mnt/hdd01/Reliance on Science/ppp_patstat_person_details.csv", index=False)

In [12]:
df_patstat_title = pd.read_csv("/mnt/hdd01/Reliance on Science/ppp_patstat_title.csv")
df_patstat_abstract = pd.read_csv("/mnt/hdd01/Reliance on Science/ppp_patstat_abstract.csv")
df_patstat_person_details = pd.read_csv("/mnt/hdd01/Reliance on Science/ppp_patstat_person_details.csv")

### Extract Citations from PATSTAT

In [None]:
df_patstat_citations = pd.read_sql_query("""
    SELECT c.pat_publn_id::text, c.cited_pat_publn_id::text, p.appln_id::text
    FROM tls212_citation AS c
    JOIN tls211_pat_publn AS p ON c.pat_publn_id = p.pat_publn_id
    WHERE p.appln_id IN (SELECT appln_id FROM temp_patentid_ppp)
""", con=engine)
df_patstat_citations = df_patstat_citations.drop_duplicates(subset=['pat_publn_id', 'cited_pat_publn_id', 'appln_id'])
df_patstat_citations = df_patstat_citations[df_patstat_citations['cited_pat_publn_id'].isin(df_patstat_citations['pat_publn_id'])]

In [None]:
df_patstat_citations = df_patstat_citations.rename(columns={"appln_id": "pat_appln_id"})
df_patstat_citations = pd.merge(df_patstat_citations, df_patstat_citations[['pat_publn_id', 'pat_appln_id']].rename(columns={'pat_appln_id': 'cited_pat_appln_id'}), left_on='cited_pat_publn_id', right_on='pat_publn_id', how='inner')
df_patstat_citations = df_patstat_citations[['pat_publn_id_x', 'cited_pat_publn_id', 'pat_appln_id', 'cited_pat_appln_id']]
df_patstat_citations = df_patstat_citations.rename(columns={'pat_publn_id_x': 'pat_publn_id'})

In [None]:
df_patstat_citations.to_csv("/mnt/hdd01/Reliance on Science/ppp_patstat_citations.csv", index=False)

In [13]:
df_patstat_citations = pd.read_csv("/mnt/hdd01/Reliance on Science/ppp_patstat_citations.csv")

## Fuzzy matching of Authors, Inventor for Patents with Authors for Papers (PCS and PPP)

### Preprocessing

In [None]:
df_rel_authors_complete['display_name_alternatives'] = df_rel_authors_complete['display_name_alternatives'].apply(lambda x: ast.literal_eval(x))
df_rel_authors_complete['display_name_alternatives'] = df_rel_authors_complete.apply(lambda row: row['display_name_alternatives'] + [row['display_name']] if isinstance(row['display_name_alternatives'], list) else [row['display_name']], axis=1)
df_rel_authors_complete_exploded = df_rel_authors_complete.explode('display_name_alternatives')
df_rel_authors_complete_exploded['oaid'] = df_rel_authors_complete_exploded['work_id'].apply(lambda x: [i.replace("https://openalex.org/W", "") for i in eval(x)])
df_rel_authors_complete_exploded['display_name_alternatives'] = df_rel_authors_complete_exploded['display_name_alternatives'].apply(lambda x: x.lower())

In [None]:
df_patstat_person_details = df_patstat_person_details.drop_duplicates(subset=['appln_id', 'person_id'])
df_patstat_person_details_melted = df_patstat_person_details.melt(id_vars=['appln_id', 'person_id'], 
                                                                  value_vars=['person_name', 'person_name_orig_lg', 'doc_std_name', 'psn_name', 'han_name'],
                                                                  var_name='name_type', 
                                                                  value_name='name')
df_patstat_person_details_melted['name'] = df_patstat_person_details_melted['name'].apply(lambda x: x.lower())
df_patstat_person_details_melted = pd.merge(df_patstat_person_details_melted, df_patstat_title[['appln_id', 'patentid']], on='appln_id', how='inner')

In [None]:
df_rel_info_grouped = df_rel_info.groupby('patent_id').agg({'oaid': list}).reset_index()

In [None]:
df_patstat_person_details_melted['patentid'] = df_patstat_person_details_melted['patentid'].astype(str)
df_rel_info_grouped['patent_id'] = df_rel_info_grouped['patent_id'].astype(str)
df_patstat_person_details_melted = pd.merge(df_patstat_person_details_melted, df_rel_info_grouped[['oaid', 'patent_id']], left_on='patentid', right_on='patent_id', how='inner', validate='m:m')
# df_patstat_person_details_melted = df_patstat_person_details_melted.dropna(subset=['oaid'])
df_patstat_person_details_melted = df_patstat_person_details_melted.dropna(subset=['patentid'])
df_patstat_person_details_exploded = df_patstat_person_details_melted.explode('oaid')
df_patstat_person_details_exploded = df_patstat_person_details_exploded.dropna(subset=['oaid'])

### Matching depending on pcs and ppp relationships

In [None]:
df_rel_authors_complete_exploded_exploded = df_rel_authors_complete_exploded.explode('oaid')

In [None]:
df_merged = pd.merge(df_rel_authors_complete_exploded_exploded, df_patstat_person_details_exploded, on='oaid', how='inner', validate='m:m')
df_merged = df_merged[['patent_id', 'appln_id', 'person_id', 'name_type', 'name', 'oaid', 'author_id', 'display_name', 'display_name_alternatives']]

In [None]:
def match_names(row):
    full_name = row['name']
    match = distance.Levenshtein.normalized_similarity(full_name, row['display_name_alternatives'])
    return match

df_merged['best_match'] = df_merged.progress_apply(match_names, axis=1)
# df_merged_test['best_match'] = df_merged_test.progress_apply(match_names, axis=1)

# df_merged = df_merged.sort_values('best_match', ascending=False)
df_merged_filtered = df_merged[df_merged['best_match'] >= 0.75]
df_merged_filtered = df_merged_filtered.loc[df_merged_filtered.groupby(['patent_id', 'appln_id', 'person_id', 'oaid', 'author_id', 'display_name'])['best_match'].idxmax()]

### Construct final authors dataframe

In [None]:
df_patent_authors_filtered = df_patstat_person_details[~df_patstat_person_details[['appln_id', 'person_id']].apply(tuple, 1).isin(df_merged_filtered[['appln_id', 'person_id']].apply(tuple, 1))]
df_patent_authors_filtered = df_patent_authors_filtered.drop_duplicates(subset=['appln_id', 'person_id'])

In [None]:
df_patent_authors_filtered['appln_id'] = df_patent_authors_filtered['appln_id'].astype(str)
df_patstat_applnid['appln_id'] = df_patstat_applnid['appln_id'].astype(str)
df_patent_authors_filtered = pd.merge(df_patent_authors_filtered, df_patstat_applnid, on='appln_id', how='inner', validate='m:m')
df_patent_authors_filtered = df_patent_authors_filtered.rename(columns={'publn_nr': 'patent_id'})

In [None]:
df_rel_authors_complete_filtered = df_rel_authors_complete_exploded_exploded[~df_rel_authors_complete_exploded_exploded[['oaid', 'author_id']].apply(tuple, 1).isin(df_merged_filtered[['oaid', 'author_id']].apply(tuple, 1))]
df_rel_authors_complete_filtered = df_rel_authors_complete_filtered.drop_duplicates(subset=['oaid', 'author_id'])
df_rel_authors_complete_filtered = df_rel_authors_complete_filtered[['oaid', 'author_id', 'display_name', 'display_name_alternatives']]

In [None]:
df_authors = pd.concat([df_merged_filtered, df_patent_authors_filtered, df_rel_authors_complete_filtered], ignore_index=True)

In [None]:
df_authors.drop_duplicates(subset=['appln_id', 'person_id', 'oaid', 'author_id', 'patent_id'], inplace=True)

In [None]:
df_authors.to_csv("/mnt/hdd01/Reliance on Science/ppp_oa_patentsview_authors.csv", index=False)

In [14]:
df_authors = pd.read_csv("/mnt/hdd01/Reliance on Science/ppp_oa_patentsview_authors.csv", dtype=str)

# Graph Preparation - Embedding of all properties; Edge Indices for all properties; Create H5PY files

### Embedding of Node Properties

In [None]:
df_patstat_text = pd.merge(df_patstat_title, df_patstat_abstract, on=['appln_id', 'patentid'], how='inner')
df_patstat_text['embedding'] = model.encode(df_patstat_text['appln_title'] + ' [SEP] ' + df_patstat_text['appln_abstract'].apply(lambda x: " ".join(x)), device=device, show_progress_bar=True).tolist()

In [None]:
df_rel_postgres['title'] = df_rel_postgres['title'].astype(str)
df_rel_postgres['abstract'] = df_rel_postgres['abstract'].astype(str)
df_rel_postgres['embedding'] = model.encode(df_rel_postgres['title'] + ' [SEP] ' + df_rel_postgres['abstract'].apply(lambda x: " ".join(x.split())), device=device, show_progress_bar=True).tolist()

In [None]:
df_rel_postgres.to_csv("/mnt/hdd01/Reliance on Science/ppp_oa_works_embeddings.csv", index=False)
df_patstat_text.to_csv("/mnt/hdd01/Reliance on Science/ppp_patstat_text_embeddings.csv", index=False)

In [15]:
df_rel_postgres = pd.read_csv("/mnt/hdd01/Reliance on Science/ppp_oa_works_embeddings.csv")
df_patstat_text = pd.read_csv("/mnt/hdd01/Reliance on Science/ppp_patstat_text_embeddings.csv")

In [16]:
d = model.get_sentence_embedding_dimension()  
df_authors['embedding'] = df_authors.apply(lambda _: np.random.rand(d), axis=1)
df_authors['embedding'] = df_authors['embedding'].apply(lambda x: x / np.linalg.norm(x))

In [None]:
df_authors.to_csv("/mnt/hdd01/Reliance on Science/ppp_oa_patstat_authors_embeddings.csv", index=False)  

In [17]:
df_rel_postgres['oaid'] = df_rel_postgres['id'].apply(lambda x: x.replace("https://openalex.org/W", ""))
df_rel_info = df_rel_info[df_rel_info['oaid'].isin(df_rel_postgres['oaid'])]
df_authors['id'] = df_authors.index
df_rel_info = df_rel_info.sort_values('patent_paper_pair', ascending=False).drop_duplicates(subset=['oaid'])

In [18]:
df_rel_citations_postgres['work_id'] = df_rel_citations_postgres['work_id'].apply(lambda x: x.replace("https://openalex.org/W", ""))
df_rel_citations_postgres['referenced_work_id'] = df_rel_citations_postgres['referenced_work_id'].apply(lambda x: x.replace("https://openalex.org/W", ""))

### Subtract all Cleantech Entities (REL and PATSTAT)

In [None]:
# df_rel_cleantech = df_rel_postgres[df_rel_postgres['oaid'].isin(df_rel_cleantech['oaid'])]
# df_rel_info_cleantech = df_rel_info[df_rel_info['oaid'].isin(df_rel_cleantech['oaid'])]
# df_authors_cleantech = df_authors[df_authors['oaid'].isin(df_rel_cleantech['oaid'])]
# df_rel_citations_postgres_cleantech = df_rel_citations_postgres[df_rel_citations_postgres['work_id'].isin(df_rel_cleantech['oaid'])]
# df_rel_citations_postgres_cleantech_ref = df_rel_citations_postgres[df_rel_citations_postgres['referenced_work_id'].isin(df_rel_cleantech['oaid'])]

In [29]:
df_rel_postgres = df_rel_postgres[~df_rel_postgres['oaid'].isin(df_rel_cleantech['oaid'])]
df_rel_info = df_rel_info[~df_rel_info['oaid'].isin(df_rel_cleantech['oaid'])]
df_authors = df_authors[~df_authors['oaid'].isin(df_rel_cleantech['oaid'])]
df_rel_citations_postgres = df_rel_citations_postgres[~df_rel_citations_postgres['work_id'].isin(df_rel_cleantech['oaid'])]
df_rel_citations_postgres = df_rel_citations_postgres[~df_rel_citations_postgres['referenced_work_id'].isin(df_rel_cleantech['oaid'])]

In [30]:
df_patstat_citations[['pat_appln_id', 'cited_pat_appln_id']] = df_patstat_citations[['pat_appln_id', 'cited_pat_appln_id']].astype(str)
df_patstat_citations = df_patstat_citations.dropna(subset=['pat_appln_id', 'cited_pat_appln_id'])
df_patstat_citations = pd.merge(df_patstat_citations, df_patstat_applnid, left_on='pat_appln_id', right_on='appln_id', how='inner')
df_patstat_citations = df_patstat_citations.rename(columns={'publn_nr': 'patent_id'})
df_patstat_citations = pd.merge(df_patstat_citations, df_patstat_applnid, left_on='cited_pat_appln_id', right_on='appln_id', how='inner')
df_patstat_citations = df_patstat_citations.rename(columns={'publn_nr': 'citation_patent_id'})
df_patstat_citations = df_patstat_citations.drop_duplicates(subset=['patent_id', 'citation_patent_id'])

In [31]:
df_patstat_applnid = df_patstat_applnid[~df_patstat_applnid['publn_nr'].isin(df_patstat_cleantech['patent_id'])]
df_patstat_text = df_patstat_text[~df_patstat_text['patentid'].isin(df_patstat_cleantech['patent_id'])]
df_authors = df_authors[~df_authors['patent_id'].isin(df_patstat_cleantech['patent_id'])]
df_patstat_citations = df_patstat_citations[~df_patstat_citations['patent_id'].isin(df_patstat_cleantech['patent_id'])]
df_patstat_citations = df_patstat_citations[~df_patstat_citations['citation_patent_id'].isin(df_patstat_cleantech['patent_id'])]

In [32]:
df_rel_postgres = df_rel_postgres.reset_index(drop=True)
df_rel_info = df_rel_info.reset_index(drop=True)
df_authors = df_authors.reset_index(drop=True)
df_rel_citations_postgres = df_rel_citations_postgres.reset_index(drop=True)

df_patstat_citations = df_patstat_citations.reset_index(drop=True)
df_patstat_text = df_patstat_text.reset_index(drop=True)
df_patstat_applnid = df_patstat_applnid.reset_index(drop=True)

### Edge Indices for all relationships

In [33]:
df_patstat_applnid['appln_id'] = df_patstat_applnid['appln_id'].astype(str)
df_patstat_text['patentid'] = df_patstat_text['patentid'].astype(str)
df_rel_info[['patent_id', 'oaid']] = df_rel_info[['patent_id', 'oaid']].astype(str)
df_rel_postgres['oaid'] = df_rel_postgres['oaid'].astype(str)
df_rel_citations_postgres[['work_id', 'referenced_work_id']] = df_rel_citations_postgres[['work_id', 'referenced_work_id']].astype(str)
df_authors[['id', 'oaid', 'patent_id']] = df_authors[['id', 'oaid', 'patent_id']].astype(str)

In [34]:
df_patstat_text = df_patstat_text.reset_index(drop=True)
patent_id_to_index = pd.Series(df_patstat_text.index, index=df_patstat_text['patentid']).to_dict()
df_patent_edge_index = df_patstat_citations.copy()
df_patent_edge_index = df_patent_edge_index[['patent_id', 'citation_patent_id']]
df_patent_edge_index['patent_id'] = df_patent_edge_index['patent_id'].map(patent_id_to_index)
df_patent_edge_index['citation_patent_id'] = df_patent_edge_index['citation_patent_id'].map(patent_id_to_index)
df_patent_edge_index = df_patent_edge_index.drop_duplicates(subset=['patent_id', 'citation_patent_id']).reset_index(drop=True)

In [35]:
df_rel_postgres = df_rel_postgres.reset_index(drop=True)
paper_id_to_index = pd.Series(df_rel_postgres.index, index=df_rel_postgres['oaid']).to_dict()
df_paper_edge_index = df_rel_citations_postgres.copy()
df_paper_edge_index = df_paper_edge_index[['work_id', 'referenced_work_id']]
df_paper_edge_index['work_id'] = df_paper_edge_index['work_id'].map(paper_id_to_index)
df_paper_edge_index['referenced_work_id'] = df_paper_edge_index['referenced_work_id'].map(paper_id_to_index)
df_paper_edge_index = df_paper_edge_index.drop_duplicates(subset=['work_id', 'referenced_work_id']).reset_index(drop=True)

In [None]:
# df_patent_paper_edge_index = df_rel_info.copy()
# df_patent_paper_edge_index = df_patent_paper_edge_index[['patent_id', 'oaid']]
# df_patent_paper_edge_index['patent_id'] = df_patent_paper_edge_index['patent_id'].map(patent_id_to_index)
# df_patent_paper_edge_index['oaid'] = df_patent_paper_edge_index['oaid'].map(paper_id_to_index)
# df_patent_paper_edge_index = df_patent_paper_edge_index.drop_duplicates(subset=['patent_id', 'oaid']).reset_index(drop=True)

In [37]:
df_patent_paper_edge_index_pcs = df_rel_info.copy()
df_patent_paper_edge_index_pcs = df_patent_paper_edge_index_pcs[['patent_id', 'oaid']]
df_patent_paper_edge_index_pcs['patent_id'] = df_patent_paper_edge_index_pcs['patent_id'].map(patent_id_to_index)
df_patent_paper_edge_index_pcs['oaid'] = df_patent_paper_edge_index_pcs['oaid'].map(paper_id_to_index)
df_patent_paper_edge_index_pcs = df_patent_paper_edge_index_pcs.drop_duplicates(subset=['patent_id', 'oaid']).reset_index(drop=True)

In [39]:
df_patent_paper_edge_index_ppp = df_rel_info[df_rel_info['patent_paper_pair'] == 1].copy()
df_patent_paper_edge_index_ppp = df_patent_paper_edge_index_ppp[['patent_id', 'oaid']]
df_patent_paper_edge_index_ppp['patent_id'] = df_patent_paper_edge_index_ppp['patent_id'].map(patent_id_to_index)
df_patent_paper_edge_index_ppp['oaid'] = df_patent_paper_edge_index_ppp['oaid'].map(paper_id_to_index)
df_patent_paper_edge_index_ppp = df_patent_paper_edge_index_ppp.drop_duplicates(subset=['patent_id', 'oaid']).reset_index(drop=True)

In [40]:
# df_authors = df_authors.drop_duplicates(subset=['patent_id', 'person_id'])
# df_authors = df_authors.drop_duplicates(subset=['author_id', 'oaid'])
df_authors = df_authors.reset_index(drop=True)
df_authors['id'] = df_authors.index
author_id_to_index = pd.Series(df_authors.index, index=df_authors['id']).to_dict()
df_author_patent_edge_index = df_authors.copy()
df_author_patent_edge_index = df_author_patent_edge_index[['id', 'patent_id']]
df_author_patent_edge_index['id'] = df_author_patent_edge_index['id'].map(author_id_to_index)
df_author_patent_edge_index['patent_id'] = df_author_patent_edge_index['patent_id'].map(patent_id_to_index)
df_patent_author_edge_index = df_author_patent_edge_index[['patent_id', 'id']]

In [41]:
df_author_paper_edge_index = df_authors.copy()
# df_author_paper_edge_index = df_author_paper_edge_index.astype(str)
df_author_paper_edge_index = df_author_paper_edge_index[['id', 'oaid']]
df_author_paper_edge_index['id'] = df_author_paper_edge_index['id'].map(author_id_to_index)
df_author_paper_edge_index['oaid'] = df_author_paper_edge_index['oaid'].map(paper_id_to_index)
df_paper_author_edge_index = df_author_paper_edge_index[['oaid', 'id']]

In [42]:
df_author_patent_edge_index = df_author_patent_edge_index.drop_duplicates(subset=['id', 'patent_id'])
df_author_paper_edge_index = df_author_paper_edge_index.drop_duplicates(subset=['id', 'oaid'])
df_patent_author_edge_index = df_patent_author_edge_index.drop_duplicates(subset=['patent_id', 'id'])
df_paper_author_edge_index = df_paper_author_edge_index.drop_duplicates(subset=['oaid', 'id'])

### Create H5PY files

In [43]:
# Delete all rows where strings are "nan"
df_author_patent_edge_index = df_author_patent_edge_index[df_author_patent_edge_index['patent_id'] != "nan"]
df_author_patent_edge_index = df_author_patent_edge_index[df_author_patent_edge_index['id'] != "nan"]
df_author_paper_edge_index = df_author_paper_edge_index[df_author_paper_edge_index['oaid'] != "nan"]
df_author_paper_edge_index = df_author_paper_edge_index[df_author_paper_edge_index['id'] != "nan"]
df_patent_author_edge_index = df_patent_author_edge_index[df_patent_author_edge_index['id'] != "nan"]
df_patent_author_edge_index = df_patent_author_edge_index[df_patent_author_edge_index['patent_id'] != "nan"]
df_paper_author_edge_index = df_paper_author_edge_index[df_paper_author_edge_index['id'] != "nan"]
df_paper_author_edge_index = df_paper_author_edge_index[df_paper_author_edge_index['oaid'] != "nan"]

df_patent_edge_index = df_patent_edge_index[df_patent_edge_index['patent_id'] != "nan"]
df_patent_edge_index = df_patent_edge_index[df_patent_edge_index['citation_patent_id'] != "nan"]
df_paper_edge_index = df_paper_edge_index[df_paper_edge_index['work_id'] != "nan"]
df_paper_edge_index = df_paper_edge_index[df_paper_edge_index['referenced_work_id'] != "nan"]
# df_patent_paper_edge_index = df_patent_paper_edge_index[df_patent_paper_edge_index['patent_id'] != "nan"]
# df_patent_paper_edge_index = df_patent_paper_edge_index[df_patent_paper_edge_index['oaid'] != "nan"]
df_patent_paper_edge_index_pcs = df_patent_paper_edge_index_pcs[df_patent_paper_edge_index_pcs['patent_id'] != "nan"]
df_patent_paper_edge_index_pcs = df_patent_paper_edge_index_pcs[df_patent_paper_edge_index_pcs['oaid'] != "nan"]
df_patent_paper_edge_index_ppp = df_patent_paper_edge_index_ppp[df_patent_paper_edge_index_ppp['patent_id'] != "nan"]
df_patent_paper_edge_index_ppp = df_patent_paper_edge_index_ppp[df_patent_paper_edge_index_ppp['oaid'] != "nan"]

In [44]:
# Delete all rows with nan values
df_author_patent_edge_index = df_author_patent_edge_index.dropna(subset=['patent_id', 'id'])
df_author_paper_edge_index = df_author_paper_edge_index.dropna(subset=['oaid', 'id'])
df_patent_author_edge_index = df_patent_author_edge_index.dropna(subset=['patent_id', 'id'])
df_paper_author_edge_index = df_paper_author_edge_index.dropna(subset=['oaid', 'id'])

df_patent_edge_index = df_patent_edge_index.dropna(subset=['patent_id', 'citation_patent_id'])
df_paper_edge_index = df_paper_edge_index.dropna(subset=['work_id', 'referenced_work_id'])
# df_patent_paper_edge_index = df_patent_paper_edge_index.dropna(subset=['patent_id', 'oaid'])
df_patent_paper_edge_index_pcs = df_patent_paper_edge_index_pcs.dropna(subset=['patent_id', 'oaid'])
df_patent_paper_edge_index_ppp = df_patent_paper_edge_index_ppp.dropna(subset=['patent_id', 'oaid'])

In [45]:
df_patent_edge_index = df_patent_edge_index.astype(int)
df_paper_edge_index = df_paper_edge_index.astype(int)
# df_patent_paper_edge_index = df_patent_paper_edge_index.astype(int)
df_patent_paper_edge_index_pcs = df_patent_paper_edge_index_pcs.astype(int)
df_patent_paper_edge_index_ppp = df_patent_paper_edge_index_ppp.astype(int)
df_author_patent_edge_index = df_author_patent_edge_index.astype(int)
df_author_paper_edge_index = df_author_paper_edge_index.astype(int)
df_patent_author_edge_index = df_patent_author_edge_index.astype(int)
df_paper_author_edge_index = df_paper_author_edge_index.astype(int)

df_authors['id'] = df_authors['id'].astype(int)
df_patstat_text['patentid'] = df_patstat_text['patentid'].astype(int)
df_rel_postgres['oaid'] = df_rel_postgres['oaid'].astype(int)

In [46]:
df_author_patent_edge_index = df_author_patent_edge_index[df_author_patent_edge_index['id'] < len(df_authors) - 1]
df_author_patent_edge_index = df_author_patent_edge_index[df_author_patent_edge_index['patent_id'] < len(df_patstat_text) - 1]

df_patent_author_edge_index = df_patent_author_edge_index[df_patent_author_edge_index['id'] < len(df_authors) - 1]
df_patent_author_edge_index = df_patent_author_edge_index[df_patent_author_edge_index['patent_id'] < len(df_patstat_text) - 1]

df_author_paper_edge_index = df_author_paper_edge_index[df_author_paper_edge_index['id'] < len(df_authors) - 1]
df_author_paper_edge_index = df_author_paper_edge_index[df_author_paper_edge_index['oaid'] < len(df_rel_postgres) - 1]

df_paper_author_edge_index = df_paper_author_edge_index[df_paper_author_edge_index['id'] < len(df_authors) - 1]
df_paper_author_edge_index = df_paper_author_edge_index[df_paper_author_edge_index['oaid'] < len(df_rel_postgres) - 1]

df_patent_edge_index = df_patent_edge_index[df_patent_edge_index['patent_id'] < len(df_patstat_text) - 1]
df_patent_edge_index = df_patent_edge_index[df_patent_edge_index['citation_patent_id'] < len(df_patstat_text) - 1]

df_paper_edge_index = df_paper_edge_index[df_paper_edge_index['work_id'] < len(df_rel_postgres) - 1]
df_paper_edge_index = df_paper_edge_index[df_paper_edge_index['referenced_work_id'] < len(df_rel_postgres) - 1]

# df_patent_paper_edge_index = df_patent_paper_edge_index[df_patent_paper_edge_index['patent_id'] < len(df_patstat_text) - 1]
# df_patent_paper_edge_index = df_patent_paper_edge_index[df_patent_paper_edge_index['oaid'] < len(df_rel_postgres) - 1]

df_patent_paper_edge_index_pcs = df_patent_paper_edge_index_pcs[df_patent_paper_edge_index_pcs['patent_id'] < len(df_patstat_text) - 1]
df_patent_paper_edge_index_pcs = df_patent_paper_edge_index_pcs[df_patent_paper_edge_index_pcs['oaid'] < len(df_rel_postgres) - 1]

df_patent_paper_edge_index_ppp = df_patent_paper_edge_index_ppp[df_patent_paper_edge_index_ppp['patent_id'] < len(df_patstat_text) - 1]
df_patent_paper_edge_index_ppp = df_patent_paper_edge_index_ppp[df_patent_paper_edge_index_ppp['oaid'] < len(df_rel_postgres) - 1]

In [47]:
df_rel_postgres['oaid'] = df_rel_postgres['oaid'].astype(int)
df_rel_info['oaid'] = df_rel_info['oaid'].astype(int)
df_rel_postgres = pd.merge(df_rel_postgres, df_rel_info[['oaid', 'patent_paper_pair']], on='oaid', how='inner')
df_rel_postgres = df_rel_postgres.dropna(subset=['abstract'])

In [None]:
len(df_rel_postgres), len(df_patstat_text), len(df_authors), len(df_patent_edge_index), len(df_paper_edge_index), len(df_patent_paper_edge_index_pcs), len(df_patent_paper_edge_index_ppp), len(df_author_patent_edge_index), len(df_author_paper_edge_index), len(df_patent_author_edge_index), len(df_paper_author_edge_index)

In [None]:
def string_to_array(str_repr):
    return np.fromstring(str_repr.strip('[]'), sep=',')

# df_patstat_text['embedding'] = df_patstat_text['embedding'].apply(string_to_array)
# df_rel_postgres['embedding'] = df_rel_postgres['embedding'].apply(string_to_array)
# df_authors['embedding'] = df_authors['embedding'].apply(string_to_array)

# Delete all NaN values from edge indices
df_patent_edge_index = df_patent_edge_index.dropna()
df_paper_edge_index = df_paper_edge_index.dropna()
# df_patent_paper_edge_index = df_patent_paper_edge_index.dropna()
df_patent_paper_edge_index_pcs = df_patent_paper_edge_index_pcs.dropna()
df_patent_paper_edge_index_ppp = df_patent_paper_edge_index_ppp.dropna()
df_author_patent_edge_index = df_author_patent_edge_index.dropna()
df_paper_author_edge_index = df_paper_author_edge_index.dropna()
df_author_paper_edge_index = df_author_paper_edge_index.dropna()

# Open an HDF5 file
with h5py.File('/mnt/hdd01/patentsview/Graph Neural Network for EDV-TEK PPP/raw/torch_tek_dataset_distilbert_emergence.h5', 'w') as f:
    # Save node data
    f.create_dataset('g_patent/x', data=np.stack(df_patstat_text["embedding"].values))
    f.create_dataset('g_paper/x', data=np.stack(df_rel_postgres["embedding"].values))
    # f.create_dataset('g_paper/y', data=np.stack(df_rel_postgres["patent_paper_pair"].values))
    f.create_dataset('g_author/x', data=np.stack(df_authors["embedding"].values))
    
    # Save edge indices
    f.create_dataset('patent_edge_index', data=df_patent_edge_index.values, dtype=np.int64)
    f.create_dataset('paper_edge_index', data=df_paper_edge_index.values, dtype=np.int64)
    f.create_dataset('patent_paper_edge_index', data=df_patent_paper_edge_index_pcs.values, dtype=np.int64)
    f.create_dataset('patent_paper_pair_edge_index', data=df_patent_paper_edge_index_ppp.values, dtype=np.int64)
    f.create_dataset('author_patent_edge_index', data=df_author_patent_edge_index, dtype=np.int64)
    f.create_dataset('patent_author_edge_index', data=df_patent_author_edge_index, dtype=np.int64)
    f.create_dataset('author_paper_edge_index', data=df_author_paper_edge_index, dtype=np.int64)
    f.create_dataset('paper_author_edge_index', data=df_paper_author_edge_index, dtype=np.int64)

# Construct Heterogeneous Graph Model

In [3]:
class PPPHeteroDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(PPPHeteroDataset, self).__init__(root, transform, pre_transform)
        self.data = None
        # processed_path = osp.join(self.processed_dir, self.processed_file_names)
        # if osp.exists(processed_path):
        #     self.data = torch.load(processed_path)
        # else:
        self.process()

    @property
    def num_classes(self):
        return 2

    @property
    def raw_dir(self):
        return '/mnt/hdd01/patentsview/Graph Neural Network for EDV-TEK PPP/raw/'
    
    @property
    def processed_dir(self):
        return '/mnt/hdd01/patentsview/Graph Neural Network for EDV-TEK PPP/processed/'

    @property
    def raw_file_names(self):
        return [
            'torch_tek_dataset_distilbert_emergence.h5'
        ]

    @property
    def processed_file_names(self):
        return 'gnn_tek_data_distilbert_emergence.pt'

    def download(self):
        pass

    def process(self):
        # Initialize HeteroData object
        data = HeteroData()
    
        # Open an HDF5 file
        with h5py.File(osp.join(self.raw_dir, 'torch_tek_dataset_distilbert_emergence.h5'), 'r') as f:
            # Load and process node features
            data['patent'].x = torch.tensor(f['g_patent/x'][:], dtype=torch.float)
            data['paper'].x = torch.tensor(f['g_paper/x'][:], dtype=torch.float)
            # data['paper'].y = torch.tensor(f['g_paper/y'][:], dtype=torch.long)

            data['author'].x = torch.tensor(f['g_author/x'][:], dtype=torch.float)
            
            # Load and process edge indices
            data['patent', 'cites', 'patent'].edge_index = torch.tensor(f['patent_edge_index'][:], dtype=torch.long).t().contiguous()
            data['paper', 'cites', 'paper'].edge_index = torch.tensor(f['paper_edge_index'][:], dtype=torch.long).t().contiguous()
            data['patent', 'cites', 'paper'].edge_index = torch.tensor(f['patent_paper_edge_index'][:], dtype=torch.long).t().contiguous()

            data['author', 'author_of_patent', 'patent'].edge_index = torch.tensor(f['author_patent_edge_index'][:], dtype=torch.long).t().contiguous()
            data['author', 'author_of_paper', 'paper'].edge_index = torch.tensor(f['author_paper_edge_index'][:], dtype=torch.long).t().contiguous()
            data['patent', 'has_author_patent', 'author'].edge_index = torch.tensor(f['patent_author_edge_index'][:], dtype=torch.long).t().contiguous()
            data['paper', 'has_author_paper', 'author'].edge_index = torch.tensor(f['paper_author_edge_index'][:], dtype=torch.long).t().contiguous()

        if self.pre_transform is not None:
            data = self.pre_transform(data)

        # Create train_mask, val_mask, and test_mask
        data['paper'].train_mask = torch.zeros(data['paper'].num_nodes, dtype=torch.bool)
        # data['paper'].val_mask = torch.zeros(data['paper'].num_nodes, dtype=torch.bool)
        data['paper'].test_mask = torch.zeros(data['paper'].num_nodes, dtype=torch.bool)
        data['paper'].train_mask[:int(0.8*data['paper'].num_nodes)] = 1
        # data['paper'].val_mask[int(0.8*data['paper'].num_nodes):int(0.9*data['paper'].num_nodes)] = 1]
        data['paper'].test_mask[int(0.8*data['paper'].num_nodes):] = 1

        # Diagnostic print statements
        print("Data keys after processing:", data.keys())
        print("Node types and their feature shapes:")
        for node_type, node_data in data.node_items():
            print(f"Node type: {node_type}")
            for key, item in node_data.items():
                if key == 'x' or key == 'y':
                    print(f"Features ({key}) shape:", item.size())

        print("Edge types and their index shapes:")
        for edge_type, edge_data in data.edge_items():
            print(f"Edge type: {edge_type}")
            if 'edge_index' in edge_data:
                print("Edge index shape:", edge_data['edge_index'].size())
            else:
                print(f"{edge_type} has no edge index.")
        

        self.data = data  # Save the processed data to self.data
        torch.save(data, osp.join(self.processed_dir, self.processed_file_names))

    def len(self):
        return 1

    def get(self, idx):
        return self.data

In [4]:
ppp_dataset = PPPHeteroDataset(root='/mnt/hdd01/patentsview/Graph Neural Network for EDV-TEK PPP/raw/')

Data keys after processing: ['train_mask', 'test_mask', 'edge_index', 'x']
Node types and their feature shapes:
Node type: patent
Features (x) shape: torch.Size([143164, 768])
Node type: paper
Features (x) shape: torch.Size([120996, 768])
Node type: author
Features (x) shape: torch.Size([890200, 768])
Edge types and their index shapes:
Edge type: ('patent', 'cites', 'patent')
Edge index shape: torch.Size([2, 55069])
Edge type: ('paper', 'cites', 'paper')
Edge index shape: torch.Size([2, 278168])
Edge type: ('patent', 'cites', 'paper')
Edge index shape: torch.Size([2, 119357])
Edge type: ('author', 'author_of_patent', 'patent')
Edge index shape: torch.Size([2, 618115])
Edge type: ('author', 'author_of_paper', 'paper')
Edge index shape: torch.Size([2, 252207])
Edge type: ('patent', 'has_author_patent', 'author')
Edge index shape: torch.Size([2, 618115])
Edge type: ('paper', 'has_author_paper', 'author')
Edge index shape: torch.Size([2, 252207])


In [5]:
ppp_dataset_0 = ppp_dataset[0]

# Graph Neural Network Model

## Message Passing Algorithm

In [64]:
class FullMessagePassing(MessagePassing):
    def __init__(self):
        super(FullMessagePassing, self).__init__(aggr='mean', flow='source_to_target', node_dim=0) # Aggregation method: "mean", "add", "max", "min"

    def forward(self, data):
        # The `data` is the full HeteroData object
        # Iterate over all types of edges defined in the data and perform message passing for each type
        for edge_type in data.edge_types:
            src, rel, dst = edge_type
            if (src, rel, dst) in data.edge_index_dict:
                try:
                    edge_index = data[src, rel, dst].edge_index
                    # data[dst].x = self.propagate(edge_index, x=data[src].x, size=None)
                    data[dst].x = self.propagate(edge_index, x=data[src].x)
                    # result = data[dst].x
                    # data[dst].x = result / result.norm(dim=1, keepdim=True)
                except Exception as e:
                    print(f"Error processing edge type {src, rel, dst}: {str(e)}")
                    continue
        return data

    def message(self, x_j):
        return x_j

    def update(self, aggr_out):
        return aggr_out

In [72]:
model = FullMessagePassing().to(device)
ppp_dataset_0 = ppp_dataset[0].to(device)

In [73]:
for epoch in tqdm(range(1, 10)):
    ppp_dataset_0 = model(ppp_dataset_0)

  0%|          | 0/9 [00:00<?, ?it/s]

In [77]:
# Print first 5 entries of node 'patent' in ppp_dataset_0
ppp_dataset_0['patent'].x[:5, :]

tensor([[ 0.0408,  0.4221,  0.4951,  ..., -0.3793, -0.3343,  0.6341],
        [ 0.1250,  0.4755,  0.5594,  ..., -0.4141, -0.2955,  0.5524],
        [ 0.1467,  0.5069,  0.5324,  ..., -0.3620, -0.3164,  0.5036],
        [ 0.0889,  0.4809,  0.4878,  ..., -0.3810, -0.2824,  0.4345],
        [ 0.1171,  0.4442,  0.5310,  ..., -0.4051, -0.2892,  0.4538]])

## Construct Graph Neural Network Model

In [37]:
class HeteroGCN(MessagePassing):
    def __init__(self, hidden_channels, num_node_features_dict, num_classes):
        super(HeteroGCN, self).__init__(aggr='mean')
        torch.manual_seed(42) # For reproducible results
        
        self.conv1 = HeteroConv({
            ('patent', 'cites', 'patent'): SAGEConv(num_node_features_dict['patent'], hidden_channels, add_self_loops=True),
            ('paper', 'cites', 'paper'): SAGEConv(num_node_features_dict['paper'], hidden_channels, add_self_loops=True),
            ('patent', 'cites', 'paper'): SAGEConv(num_node_features_dict['patent'], hidden_channels, add_self_loops=True),
            ('author', 'author_of_patent', 'patent'): SAGEConv(num_node_features_dict['author'], hidden_channels, add_self_loops=True),
            ('author', 'author_of_paper', 'paper'): SAGEConv(num_node_features_dict['author'], hidden_channels, add_self_loops=True),
            ('patent', 'has_author_patent', 'author'): SAGEConv(num_node_features_dict['patent'], hidden_channels, add_self_loops=True),
            ('paper', 'has_author_paper', 'author'): SAGEConv(num_node_features_dict['paper'], hidden_channels, add_self_loops=True)
        }, aggr='mean')

        self.conv2 = HeteroConv({
            ('patent', 'cites', 'patent'): SAGEConv(hidden_channels, hidden_channels, add_self_loops=True),
            ('paper', 'cites', 'paper'): SAGEConv(hidden_channels, hidden_channels, add_self_loops=True),
            ('patent', 'cites', 'paper'): SAGEConv(hidden_channels, hidden_channels, add_self_loops=True),
            ('author', 'author_of_patent', 'patent'): SAGEConv(hidden_channels, hidden_channels, add_self_loops=True),
            ('author', 'author_of_paper', 'paper'): SAGEConv(hidden_channels, hidden_channels, add_self_loops=True),
            ('patent', 'has_author_patent', 'author'): SAGEConv(hidden_channels, hidden_channels, add_self_loops=True),
            ('paper', 'has_author_paper', 'author'): SAGEConv(hidden_channels, hidden_channels, add_self_loops=True)
        }, aggr='mean')

        self.lin = torch.nn.Linear(hidden_channels, num_classes)

        self.predictor = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels * 2, hidden_channels), 
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, 1),
            torch.nn.Sigmoid()
        )

    def forward(self, data):
        x_dict, edge_index_dict = data.x_dict, data.edge_index_dict
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}

        x_dict = {key: x.mean(dim=0, keepdim=True) for key, x in x_dict.items()} # Global pooling - TEST THIS!!

        edge_data = torch.cat([x_dict['patent'], x_dict['paper']], dim=1) # Example for 'patent', 'paper' edge
        return self.predictor(edge_data)

## Model Instantiation

In [38]:
num_node_features_dict = {'patent': 768, 'paper': 768, 'author': 768}
model = HeteroGCN(hidden_channels=64, num_node_features_dict=num_node_features_dict, num_classes=2)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCELoss()

In [39]:
transform = RandomLinkSplit(num_val=0, num_test=0.2, is_undirected=False, edge_types=[('patent', 'cites', 'paper')], add_negative_train_samples=True)
train_data, val_data, test_data = transform(ppp_dataset_0)

In [40]:
train_data['patent', 'cites', 'paper'].edge_index

tensor([[ 17362,  91313,  73932,  ..., 109586,  51197, 132319],
        [ 95899,  13539, 107852,  ...,  31284,  18699,   7637]])

## Neighbor Loader

In [41]:
model = model.to(device)
train_data = train_data.to(device)
val_data = val_data.to(device)
test_data = test_data.to(device)
# edge_label_index = ['patent', 'cites', 'paper']
# train_loader = LinkNeighborLoader(train_data, batch_size=32, num_neighbors=10, edge_label_index=edge_label_index)
# test_loader = LinkNeighborLoader(test_data, batch_size=32, num_neighbors=10, edge_label_index=edge_label_index)

## Train and Test Loop

In [18]:
def train():
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        
        out = model(batch)  
        loss = criterion(out.squeeze(), batch.y.float())
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()

    return total_loss / len(train_loader)

In [None]:
def test():
    model.eval()
    correct = 0
    total = 0

    for batch in test_loader:
        batch = batch.to(device)
        with torch.no_grad():
            out = model(batch)
            pred = (out.squeeze() > 0.5).int()  # Convert probabilities to binary predictions
            correct += int((pred == batch.y).sum())
            total += batch.y.size(0)

    test_acc = correct / total
    return test_acc

In [50]:
def train(data):
    model.train()
    total_loss = 0
    data = data.to(device)
    optimizer.zero_grad()

    # Explicitly get edge labels and edge indices for your specific edge type
    edge_label = data['patent', 'cites', 'paper'].edge_label
    edge_index = data['patent', 'cites', 'paper'].edge_index

    out = model(data)  
    # loss = criterion(out, edge_label.float())
    # loss.backward()
    # optimizer.step()

    # total_loss += loss.item()

    # return total_loss
    return

In [51]:
def test(data):
    model.eval()
    correct = 0
    total = 0

    data = data.to(device)

    with torch.no_grad():
        edge_label = data['patent', 'cites', 'paper'].edge_label
        edge_index = data['patent', 'cites', 'paper'].edge_label_index

        out = model(data)
        pred = (out.squeeze() > 0.5).int()

        correct += int((pred == edge_label).sum())  # Count correct predictions
        total += edge_label.size(0)  # Total number of edge labels to evaluate

    test_acc = correct / total  # Test accuracy
    return test_acc

In [None]:
num_epochs = 100
for epoch in range(1, num_epochs + 1):
    loss = train(train_data)
    test_acc = test(test_data)
    # print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test Acc: {test_acc:.4f}')
    print(f'Epoch: {epoch:03d}, Test Acc: {test_acc:.4f}')