In [None]:
# if your device is cpu, install this version
pip install  dgl -f https://data.dgl.ai/wheels/torch-2.3/repo.html

In [None]:
# if cuda is available, install this version
pip install  dgl -f https://data.dgl.ai/wheels/torch-2.3/cu121/repo.html

In [None]:
import os
import random

import dgl
import dgl.function as fn
import dgl.nn as dglnn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import RelGraphConv
from tqdm import tqdm

In [None]:
SDOH_DRUG_DIR = 'data/sdoh_drug'  # folder read by DGL's CSVDataset
ds = dgl.data.CSVDataset(SDOH_DRUG_DIR)
g = ds[0]  # the heterograph stored at index 0 of the dataset

### Description of the graph

In [None]:
# Display the loaded heterograph's summary.
g

In [None]:
# Pairs of sdoh node IDs [a, b] to contrast: the bias probe links drugs to node a in one
# graph and to node b in the other. sdoh_lists names each pair, aligned by index.
sdoh_checks = [
    [3, 14],  # economic
    [4, 16],  # education
    [7, 18],  # environment
    [1, 11],  # community_absent
    [2, 12],  # community_present
]
sdoh_lists = [
    'economic',
    'education',
    'environment',
    'community_absent',
    'community_present',
]

### Define helper functions, mainly for separating test and train data

In [None]:
def get_unconnected(g, sdoh_check, etype, nodes = 'disease', print_out=True):
    """List the ``nodes``-type node IDs with no ``etype`` edge into either node of ``sdoh_check``.

    Args:
        g: DGL heterograph.
        sdoh_check: sequence whose first two entries are the destination node
            IDs to check (e.g. a pair of SDoH nodes).
        etype: relation whose in-edges into ``sdoh_check[0]`` and
            ``sdoh_check[1]`` define "connected".
        nodes: node type to enumerate.
        print_out: when True, print how many nodes are unconnected.

    Returns:
        list of node IDs (NumPy integer scalars) in ``g.nodes(nodes)`` order.
    """
    all_ids = g.nodes(nodes)
    linked = set()
    for sdoh_node in (sdoh_check[0], sdoh_check[1]):
        # The destination is fixed, so in_edges yields the source nodes linked to it.
        sources, _ = g.in_edges(sdoh_node, etype=etype)
        linked.update(sources.tolist())
    unconnected_target = [node_id for node_id in all_ids.cpu().numpy() if node_id not in linked]
    if print_out:
        print(f'the number of unconnected {nodes} nodes is: ', len(unconnected_target))
    return unconnected_target

def choose_dst(g, unconnected_target, n, etype, dst='phenotype', print_out=True):
    """Randomly sample ``1/n`` of the ``etype`` in-edges of the given nodes.

    All ``etype`` in-edges of every node in ``unconnected_target`` are
    gathered, then ``len(edges) // n`` of them are drawn uniformly without
    replacement (``random.sample``).

    Args:
        g: DGL heterograph.
        unconnected_target: node IDs on the destination side of ``etype``
            whose in-edges are collected.
        n: sampling divisor; ``1`` keeps every edge (in random order).
        etype: relation to read in-edges from.
        dst: unused; kept so existing callers keep working.
        print_out: when True, print how many edges were sampled.

    Returns:
        ``(targets, neighbours)``: two equal-length lists of ints. For the j-th
        sampled edge, ``targets[j]`` is its destination (a node from
        ``unconnected_target``) and ``neighbours[j]`` its source node.
    """
    neighbours = []  # source endpoint of every collected edge
    anchors = []     # destination endpoint (a node from unconnected_target)
    for node in unconnected_target:
        edge_src, edge_dst = g.in_edges(node, etype=etype)
        # extend() appends in place; rebuilding the list with `+` was quadratic.
        neighbours.extend(edge_src.tolist())
        anchors.extend(edge_dst.tolist())

    n_sample = len(neighbours) // n
    if print_out:
        print('number of edges to be selected:', n_sample)

    picked = random.sample(range(len(neighbours)), n_sample)

    edges_to_remove = [neighbours[i] for i in picked]
    sel_unsrc = [anchors[i] for i in picked]

    return sel_unsrc, edges_to_remove

def test_graph_construct(g, sdoh_check, n, etype_sdoh, etype, nodes='disease', dst='phenotype', print_out=True, device='cpu'):
    """Build a counterfactual pair of graphs for probing SDoH bias.

    Nodes of type ``nodes`` with no ``etype_sdoh`` edge into ``sdoh_check[0]``
    or ``sdoh_check[1]`` are found (``get_unconnected``), and ``1/n`` of their
    ``etype`` in-edges is sampled (``choose_dst``). The sampled nodes are then
    linked to ``sdoh_check[0]`` in ``g0`` and to ``sdoh_check[1]`` in ``g1``,
    in both directions. Apart from those links the two graphs equal ``g``.

    Args:
        g: source heterograph; it is cloned, never modified.
        sdoh_check: pair of SDoH node IDs to contrast.
        n: sampling divisor forwarded to ``choose_dst``.
        etype_sdoh: relation name of the form ``'<a>_<b>'`` from ``nodes`` to
            SDoH nodes; the reverse relation ``'<b>_<a>'`` must exist in ``g``.
        etype: relation whose in-edges are sampled.
        nodes: node type searched for unconnected nodes.
        dst: forwarded to ``choose_dst`` (which ignores it).
        print_out: forwarded to the helpers to toggle their prints.
        device: device the returned graphs and tensors are moved to.

    Returns:
        ``(g0, g1, chosen_src, chosen_dst)``: the two graphs, plus int64 tensors
        holding the sampled ``nodes`` IDs and their ``etype`` neighbours.
        A node sampled through several edges appears (and is linked) once per edge.
    """
    parts = etype_sdoh.split('_')
    inverse_etype_sdoh = parts[1] + '_' + parts[0]
    unconnected_target = get_unconnected(g, sdoh_check, etype_sdoh, nodes, print_out=print_out)
    chosen_src, chosen_dst = choose_dst(g, unconnected_target, n, etype, dst, print_out=print_out)
    n_new = len(chosen_src)

    def _with_sdoh_link(sdoh_node):
        # Clone g and connect every sampled node to ``sdoh_node`` in both directions.
        graph = g.clone()
        graph.add_edges(chosen_src, [sdoh_node] * n_new, etype=etype_sdoh)
        graph.add_edges([sdoh_node] * n_new, chosen_src, etype=inverse_etype_sdoh)
        return graph.to(device)

    g0 = _with_sdoh_link(sdoh_check[0])
    g1 = _with_sdoh_link(sdoh_check[1])
    # Explicit dtype: torch.tensor([]) would otherwise default to float32,
    # which is not usable as node IDs when nothing was sampled.
    chosen_src = torch.tensor(chosen_src, dtype=torch.int64).to(device)
    chosen_dst = torch.tensor(chosen_dst, dtype=torch.int64).to(device)

    return g0, g1, chosen_src, chosen_dst


def fair_loss(pred0, pred1):
    """Sum of squared element-wise differences between two prediction tensors.

    Used as the debiasing objective: it is zero when the predictions made on
    the two counterfactual graphs coincide.
    """
    gap = pred0 - pred1
    return gap.pow(2).sum()

def compute_mrr(pos_score, neg_score):
    """Per-edge reciprocal rank of each positive score among its own negatives.

    ``neg_score`` is reshaped to one row per positive edge. An edge's rank is
    1 plus the number of its negatives scoring strictly higher, so ties
    count in the positive's favour.

    Returns:
        list with one reciprocal rank per positive edge; average it for MRR.
    """
    num_edges = pos_score.shape[0]
    negatives = neg_score.view(num_edges, -1).detach().cpu().numpy()
    positives = pos_score.detach().cpu().numpy().reshape(num_edges, -1)
    ranks = (negatives > positives).sum(axis=1) + 1
    return list(1 / ranks)

def compute_loss(pos_score, neg_score):
    """Hinge (margin-ranking) loss with margin 1.

    ``neg_score`` is reshaped to ``(n_edges, -1)`` and the mean of
    ``max(0, 1 - pos + neg)`` is returned. Assumes ``pos_score`` is
    ``(n_edges, 1)`` (as the dot-product predictors produce), so each positive
    is compared against its own row of negatives — TODO confirm for other shapes.
    """
    n_edges = pos_score.shape[0]
    neg_rows = neg_score.view(n_edges, -1)
    hinge = (1 - pos_score + neg_rows).clamp(min=0)
    return hinge.mean()

def construct_negative_graph(graph, k, etype, device='cpu'):
    """Build a negative-sample graph with ``k`` corrupted copies of every edge.

    Each ``etype`` edge (u, v) of ``graph`` contributes ``k`` edges (u, v'),
    where every v' is drawn uniformly at random from all nodes of the
    destination type.

    Args:
        graph: DGL heterograph holding the positive edges.
        k: number of negatives per positive edge.
        etype: canonical edge type ``(src_type, relation, dst_type)``.
        device: device the returned graph is moved to.

    Returns:
        A heterograph containing only relation ``etype``, with the same node
        counts as ``graph`` for every node type.
    """
    _, _, vtype = etype
    src, _ = graph.edges(etype=etype)
    neg_src = src.repeat_interleave(k)
    # Draw destinations on src's device and with its dtype so both endpoint
    # tensors always agree, wherever ``graph`` lives (previously they were
    # created on CPU and then moved to ``device``, which mismatched src).
    neg_dst = torch.randint(0, graph.num_nodes(vtype), (len(src) * k,),
                            dtype=src.dtype, device=src.device)
    return dgl.heterograph(
        {etype: (neg_src, neg_dst)},
        num_nodes_dict={ntype: graph.num_nodes(ntype) for ntype in graph.ntypes}).to(device)

### Classes for the heterogeneous GCN (R-GCN) model and its edge scorers

In [None]:
class HeteroDotProductPredictor(nn.Module):
    """Scores every edge of one relation by the dot product of its endpoint embeddings."""

    def forward(self, graph, h, etype):
        """Return the scores of all ``etype`` edges of ``graph``.

        Args:
            graph: DGL heterograph whose ``etype`` edges are scored.
            h: dict mapping node type -> node embedding tensor (e.g. the
                output of the R-GCN encoder).
            etype: relation to score.

        Returns:
            Edge-score tensor (edge-ID order) produced by ``fn.u_dot_v``.
        """
        # local_scope keeps the temporary 'h'/'score' fields off the caller's graph.
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'), etype=etype)
            scores = graph.edges[etype].data['score']
        return scores


class RGCN(nn.Module):
    """Two-layer relational GCN: one ``GraphConv`` per relation, outputs summed per node type.

    Optional per-edge weights are applied in the first layer only; the second
    layer is unweighted. A ReLU sits between the layers.
    """

    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs, edge_weights=None):
        """Compute node embeddings.

        Args:
            graph: DGL heterograph.
            inputs: dict node type -> input feature tensor.
            edge_weights: optional dict relation name -> tensor with one weight
                per edge of that relation, passed to that relation's first-layer
                ``GraphConv``. Relations not listed (or ``None`` for the whole
                dict) are unweighted. Previously exactly the keys
                ``'sdoh_drug'`` and ``'drug_sdoh'`` were required.

        Returns:
            dict node type -> output embedding tensor.
        """
        conv1_kwargs = {rel: {'edge_weight': w} for rel, w in (edge_weights or {}).items()}
        h = self.conv1(graph, inputs, mod_kwargs=conv1_kwargs)
        h = {ntype: F.relu(feat) for ntype, feat in h.items()}
        return self.conv2(graph, h)


class Model(nn.Module):
    """Link predictor: weighted R-GCN encoder followed by a dot-product edge scorer."""

    def __init__(self, in_features, hid_feats, out_features, rel_names):
        super().__init__()
        # Attribute names are part of the interface: saved state_dicts and later
        # cells reach the encoder through ``.sage`` and the scorer through ``.pred``.
        self.sage = RGCN(in_features, hid_feats, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()

    def forward(self, g, neg_g, x, e, etype):
        """Score ``etype`` edges of the positive graph ``g`` and the negative graph ``neg_g``.

        Node embeddings are computed once on ``g`` (with edge weights ``e``) and
        reused for both graphs.

        Returns:
            ``(pos_score, neg_score)``.
        """
        embeddings = self.sage(g, x, e)
        pos_score = self.pred(g, embeddings, etype)
        neg_score = self.pred(neg_g, embeddings, etype)
        return pos_score, neg_score


class Bias_Predictor(nn.Module):
    """Dot-product scores for an explicit list of (src, dst) node pairs."""

    def forward(self, graph, h, src, dst, etype):
        """Score the node pairs ``(src[j], dst[j])`` under relation ``etype``.

        The score is the dot product of the source and destination embeddings,
        the same quantity ``fn.u_dot_v`` puts on an edge. Only the requested
        pairs are scored, in the order given; the pairs need not be edges of
        ``graph``.

        Args:
            graph: heterograph used to resolve ``etype`` to its
                ``(src_type, relation, dst_type)`` triple.
            h: dict node type -> embedding tensor.
            src, dst: equal-length integer tensors of node IDs.
            etype: relation name or canonical edge-type triple.

        Returns:
            Tensor of shape ``(len(src), 1)`` for 2-D embeddings.
        """
        # Computed directly instead of via apply_edges(edges=...): reading the edge
        # column afterwards returned a zero-padded score for *every* etype edge in
        # edge-ID order rather than the requested pairs.
        utype, _, vtype = graph.to_canonical_etype(etype)
        return (h[utype][src] * h[vtype][dst]).sum(dim=-1, keepdim=True)

class RGCN_noweight(nn.Module):
    """Two-layer relational GCN without edge weights.

    Each layer holds one ``GraphConv`` per relation; per-relation outputs are
    summed per destination node type, with a ReLU between the layers.
    """

    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        def _layer(d_in, d_out):
            # One GraphConv per relation, aggregated by summation.
            return dglnn.HeteroGraphConv(
                {rel: dglnn.GraphConv(d_in, d_out) for rel in rel_names},
                aggregate='sum')

        self.conv1 = _layer(in_feats, hid_feats)
        self.conv2 = _layer(hid_feats, out_feats)

    def forward(self, graph, inputs):
        """Map per-node-type input features to per-node-type embeddings (both dicts)."""
        hidden = self.conv1(graph, inputs)
        activated = {ntype: F.relu(feat) for ntype, feat in hidden.items()}
        return self.conv2(graph, activated)


class Model_noweight(nn.Module):
    """Link predictor built on the unweighted ``RGCN_noweight`` encoder."""

    def __init__(self, in_features, hid_feats, out_features, rel_names):
        super().__init__()
        self.sage = RGCN_noweight(in_features, hid_feats, out_features, rel_names)
        self.pred = HeteroDotProductPredictor()

    def forward(self, g, neg_g, x, etype):
        """Return ``(pos_score, neg_score)`` for the ``etype`` edges of ``g`` and ``neg_g``.

        Node embeddings are computed once on ``g`` and shared by both scorers.
        """
        node_emb = self.sage(g, x)
        return tuple(self.pred(graph, node_emb, etype) for graph in (g, neg_g))

### The fairness (debiasing) model class

In [None]:
class Model_fair(nn.Module):
    """Learns per-edge SDoH weights that make predictions insensitive to an SDoH link.

    A pre-trained ``RGCN`` is loaded from the state-dict file ``emb_model`` and
    frozen. The only trainable parameters are ``influence_weights`` (used for
    the ``'sdoh_drug'`` relation) and ``influenced_by_weights`` (used for
    ``'drug_sdoh'``), each with ``num_edges`` entries initialised uniformly in
    [0, 1).
    """

    def __init__(self, in_features, hid_feats, out_features, rel_names, num_edges, emb_model):
        super().__init__()
        self.rgcn = RGCN(in_features, hid_feats, out_features, rel_names)
        # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only
        # machine; load_state_dict copies the values into this module anyway.
        self.rgcn.load_state_dict(torch.load(emb_model, map_location='cpu'))
        # Freeze the encoder so only the edge weights below are optimised.
        for param in self.rgcn.parameters():
            param.requires_grad = False

        self.influence_weights = nn.Parameter(torch.rand(num_edges))
        self.influenced_by_weights = nn.Parameter(torch.rand(num_edges))
        self.get_bias = Bias_Predictor()
        self.num_edges = num_edges

    def forward(self, g0, g1, x, unconnected_drug, choosed_disease, etype):
        """Score the probe pairs on both counterfactual graphs.

        Each relation's edges beyond the ``num_edges`` learned ones (e.g. the
        probe links ``test_graph_construct`` appends) get the mean of the
        learned weights, so gradients flow through those entries as well.

        Args:
            g0, g1: counterfactual graphs; assumed to share their
                ``'sdoh_drug'`` / ``'drug_sdoh'`` edge counts, with the
                original edges first — TODO confirm for other graph builders.
            x: dict node type -> input features.
            unconnected_drug, choosed_disease: node-ID tensors of the pairs to score.
            etype: relation used to score the pairs.

        Returns:
            ``(pred0, pred1, e)``: pair scores on ``g0`` and ``g1`` plus the
            edge-weight dict used for both graphs. The result of ``torch.cat``
            already requires grad, so no ``requires_grad_`` call is needed.
        """
        def _padded(weights, rel):
            # Extend the learned weights with their mean for every extra edge of ``rel``.
            n_extra = g0.num_edges(rel) - self.num_edges
            return torch.cat([weights, weights.mean().expand(n_extra)], dim=0)

        e = {'sdoh_drug': _padded(self.influence_weights, 'sdoh_drug'),
             'drug_sdoh': _padded(self.influenced_by_weights, 'drug_sdoh')}

        h0 = self.rgcn(g0, x, e)
        h1 = self.rgcn(g1, x, e)

        pred0 = self.get_bias(g0, h0, unconnected_drug, choosed_disease, etype)
        pred1 = self.get_bias(g1, h1, unconnected_drug, choosed_disease, etype)
        return pred0, pred1, e

### Detect the bias and then debias

In [None]:
# Bias detection and debiasing, repeated for every SDoH pair in ``sdoh_checks``:
#   1. hold out sampled drug-disease pairs and train a link predictor on the rest;
#   2. bias = sum over held-out pairs of |score with drug linked to sdoh a - score with sdoh b|;
#   3. learn per-edge SDoH weights (Model_fair) that shrink that gap;
#   4. re-measure the bias with the learned weights.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds all devices
dgl.seed(SEED)
# NOTE(review): some GPU kernels are non-deterministic, so GPU runs may still vary slightly.

DD_ETYPE = ('drug', 'drug_disease', 'disease')
SDOH_RELS = ('sdoh_drug', 'drug_sdoh')
NODE_TYPES = ('disease', 'drug', 'sdoh', 'phenotype')
os.makedirs('model', exist_ok=True)  # torch.save below fails if the directory is missing

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
num_nodes_dict = {ntype: g.number_of_nodes(ntype) for ntype in g.ntypes}
n_hetero_features = 20
emb_dims = (n_hetero_features, 50, 20)  # (in, hidden, out) shared by Model and Model_fair
g = g.to(device)
# Random input features for each node type.
for ntype in NODE_TYPES:
    g.nodes[ntype].data['feature'] = torch.randn(num_nodes_dict[ntype], n_hetero_features).to(device)
edge_weights = {rel: torch.ones(g.num_edges(rel)).to(device) for rel in SDOH_RELS}


def _pad_with_mean(weights, total):
    """Append copies of ``weights.mean()`` until ``weights`` has ``total`` entries."""
    return torch.cat([weights, weights.mean().expand(total - weights.numel())], dim=0)


def _sdoh_bias(encoder, features, weights, graph0, graph1, drugs, diseases):
    """Sum over (drug, disease) pairs of |score on graph0 - score on graph1|, without gradients."""
    with torch.no_grad():
        h0 = encoder(graph0, features, weights)
        h1 = encoder(graph1, features, weights)
        pred0 = Bias_Predictor()(graph0, h0, drugs, diseases, DD_ETYPE)
        pred1 = Bias_Predictor()(graph1, h1, drugs, diseases, DD_ETYPE)
        return torch.sum(torch.abs(pred0 - pred1))


for sdoh_check, sdoh_name in zip(sdoh_checks, sdoh_lists):
    # Probe graphs: 1/5 of the candidate drug-disease edges; the sampled drugs are
    # linked to sdoh_check[0] in test0 and to sdoh_check[1] in test1.
    test0, test1, test_drug, test_disease = test_graph_construct(
        g, sdoh_check, 5, 'drug_sdoh', 'disease_drug', nodes='drug', dst='disease', device=device)
    # Hold the probed pairs out of training, in both directions.
    e_reid_ass = g.edge_ids(test_drug, test_disease, etype='drug_disease')
    e_reid_assd = g.edge_ids(test_disease, test_drug, etype='disease_drug')
    train_g = dgl.remove_edges(g, e_reid_ass, etype='drug_disease')
    train_g = dgl.remove_edges(train_g, e_reid_assd, etype='disease_drug')

    # 1. Train the link predictor on train_g.
    k = 20
    model = Model(*emb_dims, g.etypes).to(device)
    # g, train_g and all probe graphs share the same node set, hence the same features.
    node_features = {ntype: train_g.nodes[ntype].data['feature'] for ntype in NODE_TYPES}
    opt = torch.optim.Adam(model.parameters())
    for epoch in tqdm(range(50)):
        negative_graph = construct_negative_graph(train_g, k, DD_ETYPE, device=device)
        pos_score, neg_score = model(train_g, negative_graph, node_features, edge_weights, DD_ETYPE)
        loss = compute_loss(pos_score, neg_score)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if (epoch + 1) % 10 == 0:
            print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

    emb_path = f'model/emb_model-{sdoh_name}.pth'
    torch.save(model.sage.state_dict(), emb_path)

    # 2. Bias of the trained model (all SDoH edges weighted 1).
    model.eval()
    edge_weights_test = {rel: torch.ones(test0.num_edges(rel)).to(device) for rel in SDOH_RELS}
    sum_of_differences = _sdoh_bias(model.sage, node_features, edge_weights_test,
                                    test0, test1, test_drug, test_disease)
    print(f'the bias of {sdoh_name} is:', sum_of_differences.detach().cpu().numpy())

    print('-----------------------------------------------------')
    print('begin debias')
    print('-----------------------------------------------------')
    # 3. Learn per-edge SDoH weights on top of the frozen encoder.
    debias_model = Model_fair(*emb_dims, train_g.etypes, train_g.num_edges('drug_sdoh'), emb_path).to(device)
    opt = torch.optim.Adam(debias_model.parameters())
    for epoch in tqdm(range(100)):
        # Fresh probe pairs every epoch, drawn from the training graph only.
        g0, g1, unconnected_drug, choosed_disease = test_graph_construct(
            train_g, sdoh_check, 1, 'drug_sdoh', 'disease_drug',
            nodes='drug', dst='disease', print_out=False, device=device)
        pred0, pred1, epoch_weights = debias_model(g0, g1, node_features, unconnected_drug, choosed_disease, DD_ETYPE)
        loss = fair_loss(pred0, pred1)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if (epoch + 1) % 10 == 0:
            print(f'Epoch {epoch + 1}, Loss: {loss.item()}')

    # 4. Bias with the learned weights. Read the parameters *after* the final optimiser
    # step: the weights returned by the last forward pass predate that step.
    debias_model.eval()
    with torch.no_grad():
        e_inf = debias_model.influence_weights.detach()[:train_g.num_edges('sdoh_drug')]
        e_infby = debias_model.influenced_by_weights.detach()[:train_g.num_edges('drug_sdoh')]
        # Probe links present only in test0/test1 get the mean learned weight.
        edge_weights_test_debias = {
            'sdoh_drug': _pad_with_mean(e_inf, test0.num_edges('sdoh_drug')),
            'drug_sdoh': _pad_with_mean(e_infby, test0.num_edges('drug_sdoh'))}
    sum_of_differences_af = _sdoh_bias(debias_model.rgcn, node_features, edge_weights_test_debias,
                                       test0, test1, test_drug, test_disease)
    print(f'the af-bias of {sdoh_name} is:', sum_of_differences_af.detach().cpu().numpy())
    print('the bias improvement is:', sum_of_differences.detach().cpu().numpy() - sum_of_differences_af.detach().cpu().numpy())
    print('-----------------------------------------------------')
    print('-----------------------------------------------------')

### Check the change in link-prediction performance (MRR) after debiasing

In [None]:
# Link-prediction quality (MRR) before and after debiasing.
# A random 1/10 of the drug->disease edges is held out. The base model is trained on
# the rest; then, for each SDoH pair, a Model_fair learns edge weights on top of it
# and the held-out MRR is recomputed with those weights.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds all devices
dgl.seed(SEED)

DD_ETYPE = ('drug', 'drug_disease', 'disease')
SDOH_RELS = ('sdoh_drug', 'drug_sdoh')
NODE_TYPES = ('disease', 'drug', 'sdoh', 'phenotype')
TEST_EDGE_DIVISOR = 10  # hold out 1/10 of the drug->disease edges
os.makedirs('model', exist_ok=True)  # torch.save below fails if the directory is missing

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
num_nodes_dict = {ntype: g.number_of_nodes(ntype) for ntype in g.ntypes}
n_hetero_features = 20
emb_dims = (n_hetero_features, 50, 20)  # (in, hidden, out) shared by Model and Model_fair
g = g.to(device)
# Random input features for each node type.
for ntype in NODE_TYPES:
    g.nodes[ntype].data['feature'] = torch.randn(num_nodes_dict[ntype], n_hetero_features).to(device)
edge_weights = {rel: torch.ones(g.num_edges(rel)).to(device) for rel in SDOH_RELS}

# Train/test split on drug->disease edges; the matching disease->drug edges are removed too.
drug_ids, disease_ids = g.edges(etype='drug_disease')
num_edges = drug_ids.shape[0]
num_test_edges = num_edges // TEST_EDGE_DIVISOR
sample_indices = torch.randperm(num_edges)[:num_test_edges]
test_drug_ids = drug_ids[sample_indices]
test_disease_ids = disease_ids[sample_indices]
e_reid_ass = g.edge_ids(test_drug_ids, test_disease_ids, etype='drug_disease')
e_reid_assd = g.edge_ids(test_disease_ids, test_drug_ids, etype='disease_drug')
train_g = dgl.remove_edges(g, e_reid_ass, etype='drug_disease')
train_g = dgl.remove_edges(train_g, e_reid_assd, etype='disease_drug')
test_nodes_dict = {'drug': num_nodes_dict['drug'], 'disease': num_nodes_dict['disease']}
test_g = dgl.heterograph({DD_ETYPE: (test_drug_ids, test_disease_ids)}, num_nodes_dict=test_nodes_dict)

# Train the base link predictor.
k = 20
model = Model(*emb_dims, g.etypes).to(device)
# g, train_g and test_g share the same node set, hence the same features.
node_features = {ntype: train_g.nodes[ntype].data['feature'] for ntype in NODE_TYPES}
opt = torch.optim.Adam(model.parameters())
for epoch in tqdm(range(100)):
    negative_graph = construct_negative_graph(train_g, k, DD_ETYPE, device=device)
    pos_score, neg_score = model(train_g, negative_graph, node_features, edge_weights, DD_ETYPE)
    loss = compute_loss(pos_score, neg_score)
    opt.zero_grad()
    loss.backward()
    opt.step()
torch.save(model.sage.state_dict(), 'model/emb_model-dd.pth')

# Baseline MRR on the held-out edges; the same negative graph is reused after debiasing.
model.eval()
with torch.no_grad():
    trained_features = model.sage(train_g, node_features, edge_weights)
    test_features = {'drug': trained_features['drug'], 'disease': trained_features['disease']}
    neg_graph = construct_negative_graph(test_g, k, DD_ETYPE, device=device)
    pos_score = model.pred(test_g, test_features, DD_ETYPE)
    neg_score = model.pred(neg_graph, test_features, DD_ETYPE)

mrr = compute_mrr(pos_score, neg_score)
print('the mrr of drug-disease is:', np.mean(mrr))

for sdoh_check, sdoh_name in zip(sdoh_checks, sdoh_lists):
    debias_model = Model_fair(*emb_dims, train_g.etypes, train_g.num_edges('drug_sdoh'), 'model/emb_model-dd.pth').to(device)
    opt = torch.optim.Adam(debias_model.parameters())
    for epoch in tqdm(range(100)):
        # Fresh probe pairs every epoch, drawn from the training graph only.
        g0, g1, unconnected_drug, choosed_disease = test_graph_construct(
            train_g, sdoh_check, 1, 'drug_sdoh', 'disease_drug',
            nodes='drug', dst='disease', print_out=False, device=device)
        pred0, pred1, epoch_weights = debias_model(g0, g1, node_features, unconnected_drug, choosed_disease, DD_ETYPE)
        loss = fair_loss(pred0, pred1)
        opt.zero_grad()
        loss.backward()
        opt.step()

    debias_model.eval()
    with torch.no_grad():
        # Read the parameters *after* the final optimiser step: the weights returned
        # by the last forward pass predate that step. train_g carries no probe links,
        # so no mean padding is needed here.
        edge_weights_test_debias = {
            'sdoh_drug': debias_model.influence_weights.detach()[:train_g.num_edges('sdoh_drug')],
            'drug_sdoh': debias_model.influenced_by_weights.detach()[:train_g.num_edges('drug_sdoh')]}
        h0 = debias_model.rgcn(train_g, node_features, edge_weights_test_debias)
        test_features = {'drug': h0['drug'], 'disease': h0['disease']}
        pos_score = model.pred(test_g, test_features, DD_ETYPE)
        neg_score = model.pred(neg_graph, test_features, DD_ETYPE)

    de_mrr = compute_mrr(pos_score, neg_score)
    # Name the SDoH pair so the five result lines can be told apart.
    print(f'after debias ({sdoh_name}) the mrr of drug-disease is:', np.mean(de_mrr))
    print('the change of mrr is:', np.mean(mrr) - np.mean(de_mrr))