In [None]:
import polars as pl

# NOTE(review): `bitcoinalpha` is never created in this notebook; it is assumed to already be a
# polars DataFrame with exactly four columns in (source, target, rating, timestamp) order.
# Load it explicitly in this cell so that Restart & Run All works on a fresh kernel.
bitcoinalpha.columns = ['source', 'target', 'weight', 'timestamp']
# Bare last expression -> rich table display (the original printed the bound method, not the rows).
bitcoinalpha.head()

ModuleNotFoundError: No module named 'torch'

In [None]:
# Only a cell's final expression is auto-displayed, so the shape must be printed explicitly.
print(bitcoinalpha.shape)
bitcoinalpha.describe()

In [None]:
import torch
from torch_geometric.data import Data
import math
import numpy as np
import networkx as nx

# Edge list in PyG's (2, num_edges) layout: row 0 = source ids, row 1 = target ids.
edge_index = torch.tensor(bitcoinalpha.select(['source', 'target']).to_numpy().T, dtype=torch.long)
# Ratings are stored as edge_attr, but the GraphSAGE model in this notebook only passes
# (x, edge_index) to its convolutions, so the ratings are never seen by the encoder.
edge_weight = torch.tensor(bitcoinalpha['weight'].to_numpy(), dtype=torch.float)

# Node ids are used directly as row indices into the feature/embedding matrices, so they
# must be non-negative: a negative id would silently wrap around under tensor indexing.
assert edge_index.numel() > 0, "bitcoinalpha has no edges"
assert int(edge_index.min()) >= 0, "node ids must be non-negative"
# Node count = largest id + 1; any id that never appears in an edge becomes an isolated node.
num_nodes = int(edge_index.max()) + 1

# NOTE(review): every node gets the same constant feature 1.0. Assuming SAGEConv's default
# mean aggregation, nodes with the same neighbourhood pattern then get identical embeddings,
# which severely limits what the dot-product link predictor can learn. Consider degree-based
# or learnable node features.
x = torch.ones((num_nodes, 1))

data = Data(x=x, edge_index=edge_index, edge_attr=edge_weight)

In [None]:
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class GraphSAGE(torch.nn.Module):
    """Two-layer GraphSAGE encoder that maps node features to node embeddings.

    Args:
        in_channels: size of each input node feature vector.
        hidden_channels: width of the intermediate (post-ReLU) representation.
        out_channels: size of each output node embedding.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Embed every node: conv1 -> ReLU -> conv2 (no activation on the final layer).

        ``edge_index`` is passed straight to SAGEConv, so it presumably follows PyG's
        (2, num_edges) layout.
        """
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)
        
class LinkPredictor(torch.nn.Module):
    """Parameter-free dot-product scorer for directed edges.

    ``edge_index`` is indexed column-wise, i.e. it is expected to be (num_edges, 2) with
    column 0 = source node and column 1 = target node. This is the transpose of PyG's
    (2, num_edges) layout, so do not pass ``Data.edge_index`` here directly.
    """

    def forward(self, emb, edge_index):
        """Return one raw logit per edge: the dot product of its two endpoint embeddings."""
        src_ids = edge_index[:, 0]
        dst_ids = edge_index[:, 1]
        products = emb[src_ids] * emb[dst_ids]
        return products.sum(dim=1)

In [None]:
# Smoke test: a single forward pass of an untrained encoder to check the wiring and output shape.
# (`model` is re-created from scratch in the training cell.)
model = GraphSAGE(in_channels=1, hidden_channels=16, out_channels=8)
with torch.no_grad():  # inspection only; no need to record an autograd graph
    out = model(data.x, data.edge_index)
assert out.shape == (data.x.size(0), 8), f"unexpected embedding shape {tuple(out.shape)}"
print(out)

In [None]:
from sklearn.model_selection import train_test_split

# Supervised task: only the two extreme ratings are kept, +10 (full trust) vs -10 (full
# distrust); every intermediate rating is dropped from the train/test edges.
TRUST_RATING = 10
DISTRUST_RATING = -10
TEST_SIZE = 0.3      # 70/30 train/test split
SPLIT_SEED = 128

pos_edges = bitcoinalpha.filter(pl.col("weight") == TRUST_RATING)
neg_edges = bitcoinalpha.filter(pl.col("weight") == DISTRUST_RATING)

# Splitting each class separately keeps the trust/distrust ratio (approximately) equal in
# both sets, i.e. this is a stratified split.
pos_train, pos_test = train_test_split(pos_edges, test_size=TEST_SIZE, random_state=SPLIT_SEED)
neg_train, neg_test = train_test_split(neg_edges, test_size=TEST_SIZE, random_state=SPLIT_SEED)

# combine the per-class train and test edges
train_edges = pl.concat([pos_train, neg_train])
test_edges = pl.concat([pos_test, neg_test])

# The split must neither drop nor duplicate edges.
assert len(train_edges) + len(test_edges) == len(pos_edges) + len(neg_edges)
# NOTE(review): the two classes are likely imbalanced; inspect len(pos_edges) vs
# len(neg_edges) before reading too much into plain accuracy.

def to_edge_tensor(df):
    """Turn a frame's ``source``/``target`` columns into an int64 tensor.

    The result has one row per edge, (num_edges, 2), which is the layout ``LinkPredictor``
    indexes column-wise (the transpose of PyG's (2, num_edges) ``edge_index``).
    """
    endpoint_array = df.select(['source', 'target']).to_numpy()
    return torch.tensor(endpoint_array, dtype=torch.long)

def to_label_tensor(df, trust_rating=10):
    """Return float binary labels: 1.0 where ``weight == trust_rating`` (trust), else 0.0.

    ``trust_rating`` defaults to 10, matching the rating used to select trust edges, so
    existing calls behave as before; pass another value if the trust rating changes.
    Float dtype is what ``BCEWithLogitsLoss`` expects for its targets.
    """
    is_trust = (df['weight'] == trust_rating).to_numpy()
    return torch.tensor(is_trust, dtype=torch.float)

# Convert each split into an (edges, labels) tensor pair; rows of the two stay aligned
# because both are derived from the same frame in the same order.
train_edge_index, train_labels = to_edge_tensor(train_edges), to_label_tensor(train_edges)
test_edge_index, test_labels = to_edge_tensor(test_edges), to_label_tensor(test_edges)

In [None]:
# Seed before building the model so weight initialisation (and the loss curve) is reproducible.
torch.manual_seed(128)

model = GraphSAGE(1, 16, 16)
predictor = LinkPredictor()

optimizer = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()), lr=0.01)
# NOTE(review): no pos_weight is set, so if the labels are imbalanced the majority class
# dominates the loss.
loss_fn = torch.nn.BCEWithLogitsLoss()

NUM_EPOCHS = 150
LOG_EVERY = 10
loss_history = []

# 1-based epoch counter that runs exactly NUM_EPOCHS epochs.
for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    optimizer.zero_grad()

    # Full-batch: embed every node over the whole graph, then score only the training edges.
    node_emb = model(data.x, data.edge_index)
    pred = predictor(node_emb, train_edge_index)

    loss = loss_fn(pred, train_labels)
    loss.backward()
    optimizer.step()

    loss_history.append(loss.item())
    if epoch == 1 or epoch % LOG_EVERY == 0:
        print(f"Epoch {epoch} | Loss: {loss.item():.5f}")

In [None]:
model.eval()
with torch.no_grad():
    node_emb = model(data.x, data.edge_index)
    pred = predictor(node_emb, test_edge_index)
    pred_label = torch.sigmoid(pred) > 0.5

    true_label = test_labels.bool()
    acc = (pred_label == true_label).float().mean()

    # Accuracy alone is misleading when one class dominates: compare it with always
    # predicting the majority class, and report recall for each class separately.
    # (A class absent from the test set yields nan recall.)
    trust_rate = true_label.float().mean()
    majority_acc = torch.maximum(trust_rate, 1 - trust_rate)
    trust_recall = pred_label[true_label].float().mean()
    distrust_recall = (~pred_label[~true_label]).float().mean()

print(f"Test Accuracy: {acc:.4f}")
print(f"Majority-class baseline: {majority_acc:.4f}")
print(f"Trust recall: {trust_recall:.4f} | Distrust recall: {distrust_recall:.4f}")

In [None]:
# Directed graph containing only the test edges, used for the visualisation below.
G = nx.DiGraph()
G.add_edges_from(tuple(edge) for edge in test_edge_index.tolist())

In [None]:
model.eval()
with torch.no_grad():
    node_emb = model(data.x, data.edge_index)
    pred_scores = predictor(node_emb, test_edge_index)
# Scores carry no autograd history, so these post-processing steps need no no_grad block.
pred_probs = pred_scores.sigmoid()
pred_labels = pred_probs.gt(0.5).int()  # 1: trust, 0: distrust

# One colour per test edge, in test_edge_index order: green = predicted trust, red = predicted distrust.
edge_colors = ['green' if is_trust else 'red' for is_trust in (pred_labels == 1).tolist()]

In [None]:
import matplotlib.pyplot as plt

# Node positions for the test-edge graph; the seed keeps the layout stable across runs.
pos = nx.spring_layout(G, seed=1)

# edge_colors follows test_edge_index order, but G.edges() iterates in adjacency order, so the
# edge list must be passed explicitly or colours end up attached to the wrong edges.
drawn_edges = [tuple(edge) for edge in test_edge_index.tolist()]
assert len(drawn_edges) == len(edge_colors), "one colour is required per test edge"

fig, ax = plt.subplots(figsize=(10, 10))
nx.draw_networkx_nodes(G, pos, node_size=10, node_color='gray', ax=ax)
nx.draw_networkx_edges(G, pos, edgelist=drawn_edges, edge_color=edge_colors,
                       arrows=False, width=1, ax=ax)

ax.set_title("Bitcoin Alpha Trust Predictions (Green = Trust, Red = Distrust)")
ax.axis("off")
plt.show()