In [1]:
import os
import json
from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split

In [2]:
def get_doc_texts(path='./dataset/docs.json'):
    """Load document texts from a JSON-lines file into an ``{id: text}`` dict.

    Each non-blank line of the file must be a JSON object with ``'id'`` and
    ``'text'`` keys. If an id occurs more than once, the last line wins.

    Args:
        path: Location of the JSON-lines file. Defaults to the dataset path
            this notebook has always used.

    Returns:
        dict mapping each document's ``id`` to its ``text``.

    Raises:
        OSError: the file cannot be opened.
        json.JSONDecodeError: a non-blank line is not valid JSON.
        KeyError: a record lacks ``'id'`` or ``'text'``.
    """
    docs = {}

    # The file is read in binary mode so that len(line) is a byte count and
    # matches the byte-based total taken from os.path.getsize; json.loads
    # decodes the bytes itself, independent of the locale's default encoding.
    # Using tqdm as a context manager closes the bar even if parsing fails.
    with open(path, 'rb') as docs_file, tqdm(
        total=os.path.getsize(path),
        unit='B',
        unit_scale=True,
        unit_divisor=1024,
        desc='Loading docs',
    ) as progress_bar:
        for line in docs_file:
            progress_bar.update(len(line))

            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue

            doc = json.loads(line)
            docs[doc['id']] = doc['text']

    return docs

def get_dataset():
    """Assemble the scored document-pair samples and split them into train/test.

    Pair ids and scores come from ``./dataset/docs.csv`` (every column is cast
    to pandas' string dtype); texts come from ``get_doc_texts()``. A pair is
    kept only when both of its ids have a text. The kept samples are split
    with ``train_test_split(test_size=0.2, random_state=42)``.

    Returns:
        ``(train, test)`` — two lists of dicts shaped like
        ``{"doc1": {"id", "text"}, "doc2": {"id", "text"}, "score": float}``.
    """
    texts_by_id = get_doc_texts()
    pairs_df = pd.read_csv("./dataset/docs.csv").astype('string')

    def as_node(doc_id):
        # One endpoint of a pair: the id together with its document text.
        return {"id": doc_id, "text": texts_by_id[doc_id]}

    samples = []
    for _, pair in pairs_df.iterrows():
        # Keep the pair only when both endpoints have a known text.
        if pair.doc1_id in texts_by_id and pair.doc2_id in texts_by_id:
            samples.append({
                "doc1": as_node(pair.doc1_id),
                "doc2": as_node(pair.doc2_id),
                "score": float(pair.score),
            })

    train_part, test_part = train_test_split(samples, test_size=0.2, random_state=42)
    return train_part, test_part


In [3]:
# Load the scored document pairs and split them 80/20 into train and test lists.
train, test = get_dataset()

Loading docs: 100%|█████████████████████████| 43.7M/43.7M [00:00<00:00, 352MB/s]


In [4]:
# Collect the documents (id -> text) that appear in the training pairs:
# pairs scoring above 0.5 feed the positive node set, all others the negative
# one. A document can end up in both sets if it occurs in both kinds of pair.
pos_nodes = {}
neg_nodes = {}

for edge in train:
    (pos_nodes if edge['score'] > 0.5 else neg_nodes).update(
        (edge[side]['id'], edge[side]['text']) for side in ('doc1', 'doc2')
    )

In [5]:
# Give every node id a dense 0-based index, following the insertion order of
# the node dicts, so edges can be expressed as integer pairs.
pos_node_id_to_index_map = dict(zip(pos_nodes, range(len(pos_nodes))))
neg_node_id_to_index_map = dict(zip(neg_nodes, range(len(neg_nodes))))

print("pos num_nodes: ", len(pos_nodes))
print("neg num_nodes: ", len(neg_nodes))

pos num_nodes:  3458
neg num_nodes:  2004


In [6]:
import torch
from torch_geometric.data import Data

In [7]:
def _to_edge_index(pairs):
    """Convert a list of ``[src, dst]`` index pairs into a ``[2, num_edges]`` long tensor.

    The explicit ``reshape(len(pairs), 2)`` is a no-op for a non-empty list but
    keeps an empty list at shape ``[2, 0]``; without it ``torch.tensor([])`` is
    1-D and transposing it yields shape ``[0]``.
    """
    return torch.tensor(pairs, dtype=torch.long).reshape(len(pairs), 2).t().contiguous()


# Translate every training pair into node indices of its own graph (score > 0.5
# -> positive graph, otherwise negative graph) and store the edge in both
# directions so each graph is undirected.
pos_edges = []
neg_edges = []

for edge in train:
    if edge['score'] > 0.5:
        index_map, edge_list = pos_node_id_to_index_map, pos_edges
    else:
        index_map, edge_list = neg_node_id_to_index_map, neg_edges

    n1 = index_map[edge['doc1']['id']]
    n2 = index_map[edge['doc2']['id']]
    edge_list.append([n1, n2])
    edge_list.append([n2, n1])

pos_edges_tensor = _to_edge_index(pos_edges)
neg_edges_tensor = _to_edge_index(neg_edges)

In [8]:
# Positive graph: one node per document in pos_nodes, edges from the
# training pairs that scored above 0.5.
pos_graph_data = Data(edge_index=pos_edges_tensor, num_nodes=len(pos_nodes))

# Every pair was stored in both directions, so this is expected to print True.
print('is undirected: ', pos_graph_data.is_undirected())
pos_graph_data

is unidirected:  True


Data(edge_index=[2, 14130], num_nodes=3458)

In [9]:
# Negative graph: one node per document in neg_nodes, edges from the
# training pairs that scored 0.5 or below.
neg_graph_data = Data(edge_index=neg_edges_tensor, num_nodes=len(neg_nodes))

# Every pair was stored in both directions, so this is expected to print True.
print('is undirected: ', neg_graph_data.is_undirected())
neg_graph_data

is unidirected:  True


Data(edge_index=[2, 19274], num_nodes=2004)

In [10]:
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')

# Fixed feature width for both graphs. padding=True pads each batch only to its
# own longest sequence, so the two graphs would get different widths whenever
# one of them has no document reaching the truncation limit. 512 is the width
# the recorded run produced for both graphs.
MAX_TOKENS = 512


def _tokenize(texts):
    """Tokenize ``texts`` into a batch padded/truncated to exactly MAX_TOKENS ids per text."""
    return tokenizer(
        texts,
        truncation=True,
        padding='max_length',
        max_length=MAX_TOKENS,
        return_tensors='pt',
    )


pos_x = _tokenize(list(pos_nodes.values()))
neg_x = _tokenize(list(neg_nodes.values()))

# The original cell left `encoded_input` bound to the negative batch; keep the
# name defined in case a later cell still reads it.
encoded_input = neg_x

# NOTE(review): the node features are raw token ids cast to float. Token ids
# are categorical labels, not magnitudes, and their large unscaled values are
# a plausible cause of the network saturating during training. Consider
# replacing them with sentence embeddings from the imported (currently unused)
# AutoModel; that changes the feature width, so the model's num_features must
# be updated together with it.
pos_graph_data.x = pos_x['input_ids'].to(torch.float)
neg_graph_data.x = neg_x['input_ids'].to(torch.float)

In [11]:
# Display the positive graph again; it should now also list the node feature matrix `x`.
pos_graph_data

Data(edge_index=[2, 14130], num_nodes=3458, x=[3458, 512])

In [12]:
# Display the negative graph again; it should now also list the node feature matrix `x`.
neg_graph_data

Data(edge_index=[2, 19274], num_nodes=2004, x=[2004, 512])

In [13]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GraphConv
from torch_geometric.utils import negative_sampling

class GCN(torch.nn.Module):
    """Four-layer graph convolutional network producing a per-node probability.

    Layer widths are ``num_features -> 1000 -> 3000 -> 1000 -> num_classes``.
    ReLU follows each of the first three layers, dropout (``F.dropout`` with
    its default rate) follows the first layer only, and a sigmoid is applied
    to the final layer's output.
    """

    def __init__(self, num_features, num_classes):
        """Create the four GCNConv layers.

        Args:
            num_features: Width of each node's input feature vector.
            num_classes: Width of the per-node output.
        """
        super().__init__()
        self.conv1 = GCNConv(num_features, 1000)
        self.conv2 = GCNConv(1000, 3000)
        self.conv3 = GCNConv(3000, 1000)
        self.conv4 = GCNConv(1000, num_classes)

    def forward(self, x, edge_index):
        """Run the network on node features ``x`` over the graph ``edge_index``.

        Returns:
            Tensor of sigmoid outputs, one row per row of ``x``.
        """
        # Dropout is active only in training mode and only after the first layer.
        hidden = F.dropout(F.relu(self.conv1(x, edge_index)), training=self.training)
        hidden = F.relu(self.conv2(hidden, edge_index))
        hidden = F.relu(self.conv3(hidden, edge_index))
        scores = self.conv4(hidden, edge_index)
        return torch.sigmoid(scores)

# Seed torch so weight initialisation (and dropout during training) repeat
# across Restart & Run All; 42 mirrors the random_state used for the data split.
torch.manual_seed(42)

# Take the input width from the data instead of hard-coding it, and fail with
# a clear message if the two graphs were featurised to different widths.
assert pos_graph_data.num_node_features == neg_graph_data.num_node_features, (
    "positive and negative graphs must have the same node feature width"
)
model = GCN(num_features=pos_graph_data.num_node_features, num_classes=1)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.BCELoss()
loss_list = []

# The targets never change between epochs, so build them once: 1 for every
# node of the positive graph, 0 for every node of the negative graph.
labels = torch.cat(
    [torch.ones(pos_graph_data.x.size(0)), torch.zeros(neg_graph_data.x.size(0))],
    dim=0,
)

# NOTE(review): the run recorded in this notebook stalls at a loss of 63.3101
# from epoch 2 onward. That equals 100 * 3458 / 5462, the value BCELoss gives
# when every prediction is exactly 0 (it clamps its log terms at -100), so the
# sigmoid output appears to be saturated and no gradient flows. The unscaled
# token-id features are the likely cause — confirm before tuning anything here.
for epoch in range(200):
    model.train()
    optimizer.zero_grad()

    # One forward pass per graph. squeeze(-1) drops only the trailing class
    # dimension, so a single-node graph still yields a 1-D tensor for torch.cat.
    pred_pos = model(pos_graph_data.x, pos_graph_data.edge_index).squeeze(-1)
    pred_neg = model(neg_graph_data.x, neg_graph_data.edge_index).squeeze(-1)

    # Same positive-then-negative order as `labels`.
    pred_all = torch.cat([pred_pos, pred_neg], dim=0)

    # Compute the loss, backpropagate, and update the model parameters.
    loss = criterion(pred_all, labels)
    loss.backward()
    optimizer.step()

    epoch_loss = loss.item()
    loss_list.append(epoch_loss)
    print(f'Epoch {epoch+1}: Loss = {epoch_loss}')


Epoch 1: Loss = 62.97243881225586
Epoch 2: Loss = 63.310142517089844
Epoch 3: Loss = 63.310142517089844
Epoch 4: Loss = 63.310142517089844
Epoch 5: Loss = 63.310142517089844
Epoch 6: Loss = 63.310142517089844
Epoch 7: Loss = 63.310142517089844
Epoch 8: Loss = 63.310142517089844
Epoch 9: Loss = 63.310142517089844
Epoch 10: Loss = 63.310142517089844
Epoch 11: Loss = 63.310142517089844
Epoch 12: Loss = 63.310142517089844
Epoch 13: Loss = 63.310142517089844
Epoch 14: Loss = 63.310142517089844
Epoch 15: Loss = 63.310142517089844
Epoch 16: Loss = 63.310142517089844
Epoch 17: Loss = 63.310142517089844
Epoch 18: Loss = 63.310142517089844
Epoch 19: Loss = 63.310142517089844
Epoch 20: Loss = 63.310142517089844
Epoch 21: Loss = 63.310142517089844
Epoch 22: Loss = 63.310142517089844
Epoch 23: Loss = 63.310142517089844
Epoch 24: Loss = 63.310142517089844
Epoch 25: Loss = 63.310142517089844
Epoch 26: Loss = 63.310142517089844
Epoch 27: Loss = 63.310142517089844
Epoch 28: Loss = 63.310142517089844
Ep