In [12]:
import torch
import torch.nn as nn
from torch.distributions import Normal
import pandas as pd

In [13]:

class LDMprobit(torch.nn.Module):
    def __init__(self, Aij, thresholds, embedding_dim, device, n_epochs, lr, seed=0):
        super(LDMprobit, self).__init__()
        self.Aij = Aij.to(device)
        self.thresholds = thresholds.to(device)
        self.device = device
        self.n_drugs, self.n_effects = Aij.shape
        self.n_ordinal_classes = (len(thresholds) - 1)

        #set seed
        self.seed = seed
        self.__set_seed(seed)

        #Variables for the learning process
        self.n_epochs = n_epochs
        self.lr = lr

        #parameters to be learned
        self.beta = nn.Parameter(torch.randn(self.n_effects, device=device))
        self.w = torch.nn.Parameter(torch.randn(self.n_drugs, embedding_dim))  # Latent embeddings for drugs
        self.v = torch.nn.Parameter(torch.randn(self.n_effects, embedding_dim))  # Latent embeddings for side effects
        self.beta_thilde = nn.Parameter(torch.randn(self.n_ordinal_classes, device=device))
    
    def __set_seed(self, seed):
        if seed is not None:
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

    def get_embeddings(self):
        return self.w, self.v
    
    def get_thresholds(self):
        softmax_beta = torch.softmax(self.beta_thilde, dim=0)

    def probit(self):
        #n_ordinal_classes = len(self.thresholds) - 1
        #n_drugs, n_effects = self.Aij.shape
        normal_dist = Normal(0, 1) # Noise contaminated by normal distribution
        probit_matrix = torch.zeros((self.n_drugs, self.n_effects, self.n_ordinal_classes), device=self.device)

    
        #Linear term (\beta^T x_{i,j})
        linear_term = torch.matmul(self.Aij, self.beta.unsqueeze(1))

        # Distance term -|w_i - v_j|
        dist = -torch.norm(self.w.unsqueeze(1) - self.v.unsqueeze(0), dim=2)

        # Latent variable \beta^T x_{i,j} + \alpha(u_i - u_j)
        latent_var = linear_term + dist
        
        for y in range(self.n_ordinal_classes):
            z1 = latent_var - self.thresholds[y]
            z2 = latent_var - self.thresholds[y+1]
            probit_matrix[:, :, y] = normal_dist.cdf(z1) - normal_dist.cdf(z2)
        return probit_matrix

    def ordinal_cross_entropy_loss(self):
    # Compute the predicted probabilities using the probit function
        probit_matrix = self.probit() 

        # Initialize loss variable
        loss = 0.0

        # Iterate over all drugs and side effects
        for i in range(self.n_drugs):  # For each drug
            for j in range(self.n_effects):  # For each side effect
                if self.Aij[i, j] != 0:  # Only compute loss for nonzero entries
                    target_class = int(self.Aij[i, j]) - 1  # Convert severity to class (0-based)
                    
                    #One-hot encode target 
                    one_hot_target = torch.zeros(self.n_ordinal_classes, device=self.device)
                    one_hot_target[target_class] = 1  # Set the correct class to 1

                    # Compute the log-likelihood 
                    prob = probit_matrix[i, j]
                    loss -= torch.log(torch.sum(prob * one_hot_target) + 1e-8)  # Negative log-likelihood, addition of small number to avoid log(0)

        return loss
    
    def forward(self):
        return self.probit()
        
    def learn(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)

        for epoch in range(self.n_epochs):
            avg_loss = self.train_one_epoch(optimizer)
            print(f"Epoch {epoch+1}/{self.n_epochs}, Loss: {avg_loss:.4f}")
        return self.ordinal_cross_entropy_loss()

    
    def train_one_epoch(self, optimizer, batch_size=32):
        total_loss = 0
        n_batches = (self.n_drugs + batch_size - 1) // batch_size
        for _ in range(n_batches):
            batch_loss = self.train_one_batch(optimizer)
            total_loss += batch_loss
        return total_loss / n_batches
    

    def train_one_batch(self, optimizer):
        optimizer.zero_grad()
        loss = self.ordinal_cross_entropy_loss()
        loss.backward()
        optimizer.step()
        return loss.item()

    def get_params(self):
        return self.beta, self.w.detach().cpu(), self.v.detach().cpu()

    def save_embeddings():
        raise NotImplementedError

In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embedding_dim = 4
Aij = torch.tensor([[0, 2, 0, 3, 1, 2], 
                    [0, 0, 0, 0, 1, 2],
                    [0, 0, 2, 0, 1, 2],
                    [3, 3, 0, 0, 0, 1],
                    [3, 3, 0, 0, 0, 1],
                    [0, 0, 2, 0, 0, 0]],dtype=torch.float32) #column: side effect, row:drug, value: frequency (ordinal)
thresholds = torch.tensor([0, 1, 2,3], dtype=torch.float32, device=device) 
n_epochs = 100
lr = 0.01
seed = 42
model = LDMprobit(Aij, thresholds, embedding_dim, device, n_epochs, lr, seed)
probit_output = model.probit()  # Compute the probit probability matrix
loss_out = model.ordinal_cross_entropy_loss()  # Compute the ordinal cross-entropy loss
print(loss_out)

tensor(140.2836, grad_fn=<SubBackward0>)


In [23]:
model = LDMprobit(Aij, thresholds, embedding_dim, device, n_epochs, lr, seed).to(device)

# Run a forward pass
probit_output = model.forward()  # Should return a probability matrix
print("\nProbit Output Shape:", probit_output.shape)  # Expect (n_drugs, n_effects, n_ordinal_classes)
print(probit_output)


Probit Output Shape: torch.Size([6, 6, 3])
tensor([[[4.9788e-04, 9.0003e-06, 5.9605e-08],
         [1.5269e-03, 3.7998e-05, 3.5763e-07],
         [1.9052e-02, 1.1198e-03, 2.5451e-05],
         [1.2656e-02, 6.3485e-04, 1.2279e-05],
         [6.7136e-04, 1.3202e-05, 8.9407e-08],
         [2.6510e-02, 1.7850e-03, 4.6521e-05]],

        [[8.6129e-06, 5.9605e-08, 0.0000e+00],
         [3.2783e-07, 0.0000e+00, 0.0000e+00],
         [6.9588e-05, 7.7486e-07, 0.0000e+00],
         [1.5327e-04, 2.0564e-06, 0.0000e+00],
         [6.1095e-06, 2.9802e-08, 0.0000e+00],
         [4.3716e-03, 1.5101e-04, 2.0266e-06]],

        [[7.9435e-04, 1.6361e-05, 1.1921e-07],
         [6.7288e-04, 1.3232e-05, 8.9407e-08],
         [1.5348e-04, 2.0564e-06, 0.0000e+00],
         [3.2181e-03, 1.0064e-04, 1.2219e-06],
         [6.9737e-06, 5.9605e-08, 0.0000e+00],
         [1.2899e-03, 3.0547e-05, 2.6822e-07]],

        [[7.4274e-02, 8.1477e-03, 3.4806e-04],
         [2.1345e-01, 4.7558e-02, 4.1633e-03],
         [

In [29]:
loss = model.ordinal_cross_entropy_loss()
print("\nInitial Loss:", loss.item())


Initial Loss: 176.02639770507812


In [30]:
final_loss = model.learn()
print("\nFinal Loss after training:", final_loss.item())

Epoch 1/100, Loss: 176.0264
Epoch 2/100, Loss: 172.7043
Epoch 3/100, Loss: 168.0842
Epoch 4/100, Loss: 165.0133
Epoch 5/100, Loss: 162.0621
Epoch 6/100, Loss: 159.1644
Epoch 7/100, Loss: 156.3022
Epoch 8/100, Loss: 153.5302
Epoch 9/100, Loss: 150.2718
Epoch 10/100, Loss: 147.6735
Epoch 11/100, Loss: 143.4125
Epoch 12/100, Loss: 140.9678
Epoch 13/100, Loss: 138.3568
Epoch 14/100, Loss: 135.3163
Epoch 15/100, Loss: 132.1927
Epoch 16/100, Loss: 129.5783
Epoch 17/100, Loss: 126.9500
Epoch 18/100, Loss: 124.2507
Epoch 19/100, Loss: 120.3727
Epoch 20/100, Loss: 117.8882
Epoch 21/100, Loss: 114.9585
Epoch 22/100, Loss: 112.2982
Epoch 23/100, Loss: 109.8339
Epoch 24/100, Loss: 107.0345
Epoch 25/100, Loss: 104.4933
Epoch 26/100, Loss: 102.0181
Epoch 27/100, Loss: 99.5922
Epoch 28/100, Loss: 97.2574
Epoch 29/100, Loss: 94.9764
Epoch 30/100, Loss: 91.3880
Epoch 31/100, Loss: 89.2669
Epoch 32/100, Loss: 87.2227
Epoch 33/100, Loss: 85.2425
Epoch 34/100, Loss: 82.7717
Epoch 35/100, Loss: 80.5749
Epo

In [25]:
embeddings = model.get_embeddings()
print("\nDrug Embeddings Shape:", embeddings)  # Expect (n_drugs, embedding_dim)


Drug Embeddings Shape: (Parameter containing:
tensor([[-0.2483, -1.2082, -0.4777,  0.5201],
        [ 1.6423, -0.1596, -0.4974,  0.4396],
        [ 0.3189, -0.4245,  0.3057, -0.7746],
        [ 0.0349,  0.3211,  1.5736, -0.8455],
        [-1.2742,  2.1228, -1.2347, -0.4879],
        [-1.4181,  0.8963,  0.0499,  2.2667]], requires_grad=True), Parameter containing:
tensor([[-0.4880,  1.1914, -0.8140, -0.7360],
        [-0.8371, -0.9224,  1.8113,  0.1606],
        [ 0.1971, -1.1441,  0.3383,  1.6992],
        [ 0.0109, -0.3387, -1.3407, -0.5854],
        [-0.5644,  1.0563, -1.4692,  1.4332],
        [ 0.7440, -0.4816, -1.0495,  0.6039]], requires_grad=True))


In [26]:
beta, w, v = model.get_params()
print(beta)

Parameter containing:
tensor([ 0.3367,  0.1288,  0.2345,  0.2303, -1.1229, -0.1863],
       requires_grad=True)


## Trying on real data

In [2]:
def load_data(path_to_csv, device):
    df = pd.read_csv(path_to_csv, index_col=0)
    Aij = torch.tensor(df.values, dtype=torch.float32).to(device)
    return Aij

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
csv_path = "/Users/christine/Bachelor/src/data/adj_matrix.csv" 
Aij = load_data(csv_path, device)
print(Aij.shape)

torch.Size([968, 3964])


In [19]:
# Define thresholds for ordinal categories
thresholds = torch.tensor([0, 1, 2, 3], dtype=torch.float).to(device)  # Adjust as needed

# Define model hyperparameters
embedding_dim = 5  # Number of dimensions in latent space
n_epochs = 50
lr = 0.01
seed = 42

# Initialize the model
model = LDMprobit(Aij, thresholds, embedding_dim, device, n_epochs, lr, seed)

# Train the model
model.learn()

# Get final learned embeddings
drug_embeddings, side_effect_embeddings = model.get_embeddings()


Epoch 1/50, Loss: 181.0123
Epoch 2/50, Loss: 176.5727
Epoch 3/50, Loss: 172.0259
Epoch 4/50, Loss: 167.6300
Epoch 5/50, Loss: 161.9679
Epoch 6/50, Loss: 157.7904
Epoch 7/50, Loss: 151.7343
Epoch 8/50, Loss: 147.1560
Epoch 9/50, Loss: 142.7737
Epoch 10/50, Loss: 137.9396
Epoch 11/50, Loss: 132.0063
Epoch 12/50, Loss: 127.6929
Epoch 13/50, Loss: 122.8156
Epoch 14/50, Loss: 118.3766
Epoch 15/50, Loss: 113.9239
Epoch 16/50, Loss: 109.6384
Epoch 17/50, Loss: 105.4291
Epoch 18/50, Loss: 101.4095
Epoch 19/50, Loss: 97.4802
Epoch 20/50, Loss: 93.7245
Epoch 21/50, Loss: 90.1176
Epoch 22/50, Loss: 86.6545
Epoch 23/50, Loss: 83.3182
Epoch 24/50, Loss: 80.1134
Epoch 25/50, Loss: 77.0573
Epoch 26/50, Loss: 74.1201
Epoch 27/50, Loss: 71.3184
Epoch 28/50, Loss: 68.6457
Epoch 29/50, Loss: 66.0966
Epoch 30/50, Loss: 63.6686
Epoch 31/50, Loss: 61.3564
Epoch 32/50, Loss: 59.1587
Epoch 33/50, Loss: 57.0710
Epoch 34/50, Loss: 55.0897
Epoch 35/50, Loss: 53.2117
Epoch 36/50, Loss: 51.4322
Epoch 37/50, Loss: 

In [207]:
import networkx as nx
import numpy as np
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

class LDMprobit(torch.nn.Module):
    def __init__(self, Aij, embedding_dim, device, n_epochs, lr, seed=None):
        super(LDMprobit, self).__init__()
        self.Aij = Aij.to(device)
        self.device = device
        self.n_drugs, self.n_effects = Aij.shape
        self.n_ordinal_classes = Aij.max().int().item() +1

        #set seed
        self.seed = seed
        self.__set_seed(seed)

        #Variables for the learning process
        self.n_epochs = n_epochs
        self.lr = lr

        #parameters to be learned (latent representations)
        self.beta = nn.Parameter(torch.randn(self.n_effects, device=device))
        self.w = torch.nn.Parameter(torch.randn(self.n_drugs, embedding_dim))  # Latent embeddings for drugs
        self.v = torch.nn.Parameter(torch.randn(self.n_effects, embedding_dim))  # Latent embeddings for side effects

        # Parameters to be learned (thresholds)
        self.beta_thilde = nn.Parameter(torch.randn(self.n_ordinal_classes, device=device))
        self.a = nn.Parameter(torch.rand(1, device=device))
        self.b = nn.Parameter(torch.rand(1, device=device))
    
    def __set_seed(self, seed):
        if seed is not None:
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

    def get_embeddings(self):
        return self.w, self.v
    
    # def get_thresholds(self):
    #     softmax_values = torch.softmax(self.beta_thilde, dim=0)  # Ensure positive values
    #     scaled_thresholds = torch.cumsum(softmax_values, dim=0) * torch.abs(self.a)  # Ensure increasing values
    #     positive_thresholds = scaled_thresholds + torch.abs(self.b)  # Shift to ensure positivity
    #     return torch.cat([torch.tensor([-float("inf")], device=self.device), positive_thresholds, torch.tensor([float("inf")], device=self.device)])
    def get_thresholds(self):
        # Ensure thresholds remain positive and increasing
        deltas = torch.softmax(self.beta_thilde, dim = 0)  # Ensure positive increments
        thresholds = torch.cumsum(deltas, dim=0)* self.a - self.b
        return torch.cat([torch.tensor([-float("inf")], device=self.device), thresholds, torch.tensor([float("inf")], device=self.device)])
    
    def probit(self):
        # #n_ordinal_classes = len(self.thresholds) - 1
        # #n_drugs, n_effects = self.Aij.shape
        normal_dist = Normal(0, 1) # Noise contaminated by normal distribution
        probit_matrix = torch.zeros((self.n_ordinal_classes, self.n_drugs, self.n_effects), device=self.device)
        thresholds = self.get_thresholds()
    
        #Linear term (\beta^T x_{i,j})
        linear_term = torch.matmul(self.Aij, self.beta.unsqueeze(1))

        # Distance term -|w_i - v_j|
        dist = -torch.norm(self.w.unsqueeze(1) - self.v.unsqueeze(0), dim=2)

        # Latent variable \beta^T x_{i,j} + \alpha(u_i - u_j)
        latent_var = linear_term + dist
        
        for y in range(self.n_ordinal_classes):
            z1 = latent_var - thresholds[y]
            z2 = latent_var - thresholds[y+1]
            probit_matrix[y, :, :] = normal_dist.cdf(z1) - normal_dist.cdf(z2)
        return probit_matrix

    
    def predict_categories(self):
        probit_matrix = self.probit()  # Call probit to get probabilities
        return torch.argmax(probit_matrix, dim=0), probit_matrix
    
    def ordinal_cross_entropy_loss(self):
    # Compute the predicted probabilities using the probit function
        probit_matrix = self.probit() 

        # Initialize loss variable
        loss = 0.0

        # Iterate over all drugs and side effects
        for i in range(self.n_drugs):  # For each drug
            for j in range(self.n_effects):  # For each side effect
                if True:#self.Aij[i, j] != 0:  # Only compute loss for nonzero entries
                    target_class = int(self.Aij[i, j])  # Convert severity to class (0-based)
                    
                    #One-hot encode target 
                    one_hot_target = torch.zeros(self.n_ordinal_classes, device=self.device)
                    one_hot_target[target_class] = 1  # Set the correct class to 1

                    # Compute the log-likelihood 
                    prob = probit_matrix[:,i,j]
                    loss -= torch.log(torch.sum(prob * one_hot_target) + 1e-8)  # Negative log-likelihood, addition of small number to avoid log(0)

        return loss

    def train(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        final_loss = None  # Store the last loss

        for epoch in range(self.n_epochs):
            optimizer.zero_grad()  # Reset gradients
            loss = self.ordinal_cross_entropy_loss()  # Compute loss
            loss.backward()  # Backpropagation
            optimizer.step()  # Update parameters
            
            final_loss = loss.item()  # Store latest loss value
            
            if epoch % 10 == 0:  # Print every 10 epochs
                print(f"Epoch {epoch}/{self.n_epochs}, Loss: {loss.item():.4f}")

        return final_loss
    def plot_network(self):
        G = nx.Graph()

        # Get embeddings (drug and effect embeddings)
        drug_embeddings = self.w.detach().cpu().numpy()  # (n_drugs, embedding_dim)
        effect_embeddings = self.v.detach().cpu().numpy()  # (n_effects, embedding_dim)

        # Add nodes for drugs and side effects
        for i in range(self.n_drugs):
            G.add_node(f"Drug_{i}", bipartite=0)
        for j in range(self.n_effects):
            G.add_node(f"Effect_{j}", bipartite=1)

        # Calculate the probability matrix using the probit function
        probit_matrix = self.probit()

        # Add edges based on the probit matrix (non-zero probability indicates a link)
        for i in range(self.n_drugs):
            for j in range(self.n_effects):
                prob = probit_matrix[i, j].max().item()  # Use the max probability for this drug-effect pair
                if prob > 0.01:  # Threshold for displaying an edge
                    G.add_edge(f"Drug_{i}", f"Effect_{j}", weight=prob)

        # Create a layout based on the embeddings
        pos = {}
        
        # Position drugs based on their embeddings
        for i in range(self.n_drugs):
            pos[f"Drug_{i}"] = (drug_embeddings[i, 0], drug_embeddings[i, 1])  # 2D position based on first two embedding dims
        
        # Position effects based on their embeddings
        for j in range(self.n_effects):
            pos[f"Effect_{j}"] = (effect_embeddings[j, 0], effect_embeddings[j, 1])

        # Draw the graph
        plt.figure(figsize=(12, 8))
        nx.draw(G, pos, with_labels=True, node_size=500, node_color=["blue" if "Drug" in node else "red" for node in G.nodes], font_size=10, font_weight='bold', edge_color='gray')

        # Display edge weights (probabilities) as labels
        # labels = nx.get_edge_attributes(G, 'weight')
        # nx.draw_networkx_edge_labels(G, pos, edge_labels=labels)

        plt.title('Drug-Side Effect Network based on Embeddings and Probit Output')
        plt.show()

    def plot_links(self):
        G = nx.Graph()

        # Add nodes for drugs and side effects
        for i in range(self.n_drugs):
            G.add_node(f"Drug_{i}", bipartite=0)
        for j in range(self.n_effects):
            G.add_node(f"Effect_{j}", bipartite=1)

        # Calculate the probability matrix using the probit function
        probit_matrix = self.probit()

        # Add edges based on the probit matrix (non-zero probability indicates a link)
        for i in range(self.n_drugs):
            for j in range(self.n_effects):
                prob = probit_matrix[i, j].max().item()  # Use the max probability for this drug-effect pair
                if prob > 0.35:  # Threshold for displaying an edge
                    G.add_edge(f"Drug_{i}", f"Effect_{j}", weight=prob)

        pos = {}
        pos.update((node, (1, index)) for index, node in enumerate(f"Drug_{i}" for i in range(self.n_drugs)))  # Position for drugs
        pos.update((node, (2, index)) for index, node in enumerate(f"Effect_{j}" for j in range(self.n_effects)))  # Position for effects

        # Draw the graph
        plt.figure(figsize=(12, 8))
        nx.draw(G, pos, with_labels=True, node_size=500, node_color='skyblue', font_size=10, font_weight='bold', edge_color='gray')

        # Display edge weights (optional)
        labels = nx.get_edge_attributes(G, 'weight')
        nx.draw_networkx_edge_labels(G, pos, edge_labels=labels)

        plt.title('Drug-Side Effect Network')
        plt.show()

    def get_params(self):
        return self.beta, self.w.detach().cpu().numpy(), self.v.detach().cpu().numpy(), self.beta_thilde.detach().cpu().numpy()
        
    def save_embeddings():
        raise NotImplementedError

In [220]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embedding_dim = 5
Aij = torch.tensor([[0, 2, 0, 3, 1, 2], 
                    [0, 0, 2, 0, 1, 0],
                    [3, 3, 0, 0, 0, 1],
                    [3, 3, 0, 0, 0, 1],
                    [0, 0, 2, 0, 0, 0]],dtype=torch.float32) #column: side effect, row:drug, value: frequency (ordinal)
n_epochs = 500
lr = 0.01
seed = 20
model = LDMprobit(Aij, embedding_dim, device, n_epochs, lr, seed)
probit_output = model.probit()  # Compute the probit probability matrix
loss_out = model.train()  # Compute the ordinal cross-entropy loss
print(loss_out)

Epoch 0/500, Loss: 207.0220
Epoch 10/500, Loss: 193.6110
Epoch 20/500, Loss: 171.1986
Epoch 30/500, Loss: 133.7646
Epoch 40/500, Loss: 93.9879
Epoch 50/500, Loss: 63.5178
Epoch 60/500, Loss: 48.1958
Epoch 70/500, Loss: 41.1482
Epoch 80/500, Loss: 36.7619
Epoch 90/500, Loss: 33.6167
Epoch 100/500, Loss: 31.3641
Epoch 110/500, Loss: 29.6154
Epoch 120/500, Loss: 28.1393
Epoch 130/500, Loss: 26.8318
Epoch 140/500, Loss: 25.6392
Epoch 150/500, Loss: 24.5358
Epoch 160/500, Loss: 23.5083
Epoch 170/500, Loss: 22.5495
Epoch 180/500, Loss: 21.6552
Epoch 190/500, Loss: 20.8219
Epoch 200/500, Loss: 20.0465
Epoch 210/500, Loss: 19.3259
Epoch 220/500, Loss: 18.6565
Epoch 230/500, Loss: 18.0351
Epoch 240/500, Loss: 17.4580
Epoch 250/500, Loss: 16.9216
Epoch 260/500, Loss: 16.4225
Epoch 270/500, Loss: 15.9572
Epoch 280/500, Loss: 15.5225
Epoch 290/500, Loss: 15.1156
Epoch 300/500, Loss: 14.7336
Epoch 310/500, Loss: 14.3743
Epoch 320/500, Loss: 14.0353
Epoch 330/500, Loss: 13.7148
Epoch 340/500, Loss: 

In [221]:
model.get_thresholds()

tensor([   -inf, -1.1659, -0.0777,  1.1145,  2.9408,     inf],
       grad_fn=<CatBackward0>)

In [222]:
print("Goal matrice")
print(Aij)
print("Obtained through loss")
print(model.predict_categories())
print("Embeddings: ")
model.get_embeddings()

Goal matrice
tensor([[0., 2., 0., 3., 1., 2.],
        [0., 0., 2., 0., 1., 0.],
        [3., 3., 0., 0., 0., 1.],
        [3., 3., 0., 0., 0., 1.],
        [0., 0., 2., 0., 0., 0.]])
Obtained through loss
(tensor([[0, 2, 0, 3, 1, 2],
        [0, 0, 2, 0, 1, 0],
        [3, 3, 0, 0, 0, 1],
        [3, 3, 0, 0, 0, 1],
        [0, 0, 2, 0, 0, 0]]), tensor([[[7.7926e-01, 1.1680e-01, 9.4371e-01, 2.3836e-03, 2.6842e-01,
          4.6688e-02],
         [9.9986e-01, 9.9079e-01, 5.7178e-02, 9.7995e-01, 3.4095e-01,
          9.6395e-01],
         [1.3719e-03, 2.3763e-03, 9.9507e-01, 9.0207e-01, 9.7478e-01,
          2.7960e-01],
         [1.4328e-03, 2.1979e-03, 9.9363e-01, 8.9940e-01, 9.7071e-01,
          2.7990e-01],
         [9.9994e-01, 9.9559e-01, 5.7039e-02, 9.9951e-01, 9.5587e-01,
          9.9863e-01]],

        [[1.8915e-01, 3.4223e-01, 5.2550e-02, 3.9065e-02, 4.1263e-01,
          2.3104e-01],
         [1.4058e-04, 8.9287e-03, 2.5465e-01, 1.9209e-02, 4.1028e-01,
          3.4108e-02]

(Parameter containing:
 tensor([[-1.8760, -1.1874,  0.8197,  0.1494, -1.7796],
         [-0.2435,  3.0271, -0.9507, -1.3117, -1.1426],
         [-0.4712, -2.4295, -0.7742,  0.5608,  1.4998],
         [-0.5940, -2.3083, -0.8633,  0.4821,  1.5446],
         [ 2.0860,  2.7021, -2.3055,  0.5721, -1.8318]], requires_grad=True),
 Parameter containing:
 tensor([[-0.3805, -2.4470, -0.4952,  0.5661,  2.5484],
         [-0.8714, -2.0413, -0.7425,  0.4219,  0.3779],
         [ 0.4108,  3.3098, -1.9245, -0.2011, -1.4660],
         [-2.3958, -0.5789,  1.6886,  0.0518, -2.4329],
         [-1.1747,  2.0305,  1.4991, -0.9615, -1.4354],
         [-0.7426, -1.0268,  1.7253, -0.7753,  0.0243]], requires_grad=True))

In [212]:
model.plot_network()

IndexError: index 5 is out of bounds for dimension 1 with size 5

In [11]:
import torch

import math
import time
import sys
import random

    
    def plot_network(self):
        G = nx.Graph()

        # Get embeddings (drug and effect embeddings)
        drug_embeddings = self.w.detach().cpu().numpy()  # (n_drugs, embedding_dim)
        effect_embeddings = self.v.detach().cpu().numpy()  # (n_effects, embedding_dim)

        # Add nodes for drugs and side effects
        for i in range(self.n_drugs):
            G.add_node(f"Drug_{i}", bipartite=0)
        for j in range(self.n_effects):
            G.add_node(f"Effect_{j}", bipartite=1)

        # Calculate the probability matrix using the probit function
        probit_matrix = self.probit()

        # Add edges based on the probit matrix (non-zero probability indicates a link)
        for i in range(self.n_drugs):
            for j in range(self.n_effects):
                prob = probit_matrix[i, j].max().item()  # Use the max probability for this drug-effect pair
                if prob > 0.01:  # Threshold for displaying an edge
                    G.add_edge(f"Drug_{i}", f"Effect_{j}", weight=prob)

        # Create a layout based on the embeddings
        pos = {}
        
        # Position drugs based on their embeddings
        for i in range(self.n_drugs):
            pos[f"Drug_{i}"] = (drug_embeddings[i, 0], drug_embeddings[i, 1])  # 2D position based on first two embedding dims
        
        # Position effects based on their embeddings
        for j in range(self.n_effects):
            pos[f"Effect_{j}"] = (effect_embeddings[j, 0], effect_embeddings[j, 1])

        # Draw the graph
        plt.figure(figsize=(12, 8))
        nx.draw(G, pos, with_labels=True, node_size=500, node_color=["blue" if "Drug" in node else "red" for node in G.nodes], font_size=10, font_weight='bold', edge_color='gray')

        # Display edge weights (probabilities) as labels
        # labels = nx.get_edge_attributes(G, 'weight')
        # nx.draw_networkx_edge_labels(G, pos, edge_labels=labels)

        plt.title('Drug-Side Effect Network based on Embeddings and Probit Output')
        plt.show()

    def plot_links(self):
        G = nx.Graph()

        # Add nodes for drugs and side effects
        for i in range(self.n_drugs):
            G.add_node(f"Drug_{i}", bipartite=0)
        for j in range(self.n_effects):
            G.add_node(f"Effect_{j}", bipartite=1)

        # Calculate the probability matrix using the probit function
        probit_matrix = self.probit()

        # Add edges based on the probit matrix (non-zero probability indicates a link)
        for i in range(self.n_drugs):
            for j in range(self.n_effects):
                prob = probit_matrix[i, j].max().item()  # Use the max probability for this drug-effect pair
                if prob > 0.7:  # Threshold for displaying an edge
                    G.add_edge(f"Drug_{i}", f"Effect_{j}", weight=prob)

        pos = {}
        pos.update((node, (1, index)) for index, node in enumerate(f"Drug_{i}" for i in range(self.n_drugs)))  # Position for drugs
        pos.update((node, (2, index)) for index, node in enumerate(f"Effect_{j}" for j in range(self.n_effects)))  # Position for effects

        # Draw the graph
        plt.figure(figsize=(12, 8))
        nx.draw(G, pos, with_labels=True, node_size=500, node_color='skyblue', font_size=10, font_weight='bold', edge_color='gray')

        # Display edge weights (optional)
        labels = nx.get_edge_attributes(G, 'weight')
        nx.draw_networkx_edge_labels(G, pos, edge_labels=labels)

        plt.title('Drug-Side Effect Network')
        plt.show()