# Goal here is to correct the label before the finetuning of ESM2
***
## I. Generate the DF_info and the DF_embeddings 
## II. Tryout multiple algoritms : 
> NN : Graph Neural Networks  <br>
***
https://theaisummer.com/gnn-architectures/
***

### I.

In [42]:
import torch
from torch_geometric.data import Data

from torch import nn 
from torch.utils.data import Dataset , DataLoader
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder , label_binarize , OneHotEncoder
from sklearn.metrics import average_precision_score
import os 
import pandas as pd
import numpy as np
from tqdm import tqdm

from collections import Counter
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning) 


> Open the Dataframe

In [18]:
#path_work = "/home/conchae/prediction_depolymerase_tropism/prophage_prediction/depolymerase_decipher/ficheros_28032023"
path_work = "/media/concha-eloko/Linux/PPT_clean"

# Open the DF
DF_info = pd.read_csv(f"{path_work}/DF_Dpo.final.2705.tsv", sep = "\t" ,  header = 0 )
# Open the embeddings
DF_embeddings = pd.read_csv(f"{path_work}/Dpo.2705.embeddings.ultimate.csv", sep = ",", header= None )
DF_embeddings.rename(columns={0: 'index'}, inplace=True)

# Filter the DF :
DF_info_filtered = DF_info[~DF_info["KL_type_LCA"].str.contains("\\|")]
DF_info_ToReLabel = DF_info[DF_info["KL_type_LCA"].str.contains("\\|")]
all_data = pd.merge(DF_info_filtered , DF_embeddings , on = "index")

# Mind the over representation of outbreaks :
all_data = all_data.drop_duplicates(subset = ["Infected_ancestor","index","prophage_id"] , keep = "first").reset_index(drop=True)


In [45]:
DF_embeddings

Unnamed: 0,index,1,2,3,4,5,6,7,8,9,...,1271,1272,1273,1274,1275,1276,1277,1278,1279,1280
0,ppt__2930,-0.000061,-0.017329,0.012884,0.037123,-0.123747,0.004186,-0.061367,-0.056718,-0.037215,...,0.098806,0.012989,-0.001155,0.139749,-0.030987,0.059306,0.107041,-0.041463,-0.085581,0.114973
1,ppt__3300,0.004044,0.040011,-0.001234,-0.095745,-0.058056,-0.002394,0.007648,-0.059740,0.060850,...,-0.020369,0.016287,0.062586,-0.024336,0.019276,0.069623,0.035261,-0.118962,0.035672,0.085582
2,ppt__1182,0.018767,0.068116,-0.009109,-0.012598,-0.107001,0.011569,-0.030943,-0.045359,0.048923,...,0.014524,-0.024645,0.071878,0.018206,0.042790,0.088410,0.031970,-0.124592,0.070040,0.065348
3,ppt__3540,-0.028261,-0.047253,-0.027340,-0.052824,-0.089644,-0.023079,0.094861,0.026104,0.024001,...,0.051728,0.005634,-0.077874,0.030336,-0.037648,0.050625,0.046142,-0.158841,-0.007670,0.034556
4,ppt__942,0.014863,0.028030,0.014927,-0.025997,-0.096138,0.016290,0.015008,-0.066254,0.077959,...,0.008521,-0.019820,0.123201,-0.040306,0.030893,0.051362,0.047316,-0.102698,0.044830,0.084530
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3603,anubis__304,0.006264,0.006471,-0.031665,0.078502,-0.131247,0.077167,0.043005,-0.183636,-0.022181,...,0.044299,-0.061847,0.017696,0.054798,-0.035830,-0.030202,0.039051,-0.127020,-0.113630,0.211258
3604,anubis__1273,-0.019114,0.063302,0.006635,-0.060343,-0.034054,-0.003895,0.033920,-0.080352,0.073579,...,-0.004504,-0.007906,0.075141,-0.052423,0.027127,0.073984,0.030664,-0.096409,0.011906,0.124885
3605,anubis__1311,0.051261,0.067942,0.005061,-0.019131,-0.060296,0.000984,0.037515,-0.033887,0.091774,...,0.044678,0.052609,0.112994,-0.000592,0.027122,0.086020,0.013660,-0.055491,0.021665,0.049301
3606,anubis__1525,-0.010655,0.083864,0.009084,-0.042220,-0.066479,0.008724,0.010109,-0.078033,0.065285,...,0.020752,0.024543,0.071302,0.035980,0.012171,0.054399,0.032167,-0.151018,0.042541,0.035221


In [52]:
len(DF_embeddings[DF_embeddings["index"] == "ppt__942"].values[0][1:1281])

1280

***
## Build the Graph Data

> Indexation process (shall add the N phages to predict)

In [28]:
indexation = all_data["Infected_ancestor"].unique().tolist() + all_data["Phage"].unique().tolist() + all_data["index"].unique().tolist() + [f"Dpo_to_predict_{n}" for n in DF_info["index"].unique().tolist()]

dico_ID = {item:index for index, item in enumerate(indexation)}


> Make edge file 

In [61]:
edge_index = []

# Node A (bacteria) - Node B1 (prophage) :
for _, row in all_data.iterrows() :
    edge_index.append([dico_ID[row["Phage"]], dico_ID[row["Infected_ancestor"]]])
    
# Node B1 - Node B2 (depolymerase) :
for phage in all_data.Phage.unique() :
    all_data_phage = all_data[all_data["Phage"] == phage]
    for _, row in all_data_phage.iterrows() :
        edge_index.append([dico_ID[row["index"]], dico_ID[row["Phage"]]])

# Transform into tensor : 
edge_index_tensor = torch.tensor(edge_index , dtype=torch.long)

# Write file : 
numpy_array = edge_index_tensor.numpy()
df = pd.DataFrame(numpy_array)
df.to_csv(f"{path_work}/edge_index.csv", index=False, header=False)

> Make the node feature file : 

In [39]:
LE  = LabelEncoder()
di = LE.fit_transform(all_data["KL_type_LCA"])
label_mapping = dict(zip(LE.classes_, LE.transform(LE.classes_)))
label_mapping

{'KL1': 0,
 'KL10': 1,
 'KL101': 2,
 'KL102': 3,
 'KL103': 4,
 'KL104': 5,
 'KL105': 6,
 'KL106': 7,
 'KL107': 8,
 'KL108': 9,
 'KL109': 10,
 'KL11': 11,
 'KL110': 12,
 'KL111': 13,
 'KL112': 14,
 'KL113': 15,
 'KL114': 16,
 'KL115': 17,
 'KL116': 18,
 'KL117': 19,
 'KL118': 20,
 'KL119': 21,
 'KL12': 22,
 'KL121': 23,
 'KL122': 24,
 'KL123': 25,
 'KL124': 26,
 'KL125': 27,
 'KL126': 28,
 'KL127': 29,
 'KL128': 30,
 'KL13': 31,
 'KL130': 32,
 'KL131': 33,
 'KL132': 34,
 'KL134': 35,
 'KL136': 36,
 'KL137': 37,
 'KL139': 38,
 'KL14': 39,
 'KL140': 40,
 'KL141': 41,
 'KL142': 42,
 'KL143': 43,
 'KL144': 44,
 'KL145': 45,
 'KL146': 46,
 'KL147': 47,
 'KL148': 48,
 'KL149': 49,
 'KL15': 50,
 'KL150': 51,
 'KL151': 52,
 'KL152': 53,
 'KL153': 54,
 'KL154': 55,
 'KL155': 56,
 'KL157': 57,
 'KL158': 58,
 'KL159': 59,
 'KL16': 60,
 'KL162': 61,
 'KL163': 62,
 'KL164': 63,
 'KL166': 64,
 'KL169': 65,
 'KL17': 66,
 'KL170': 67,
 'KL18': 68,
 'KL19': 69,
 'KL2': 70,
 'KL20': 71,
 'KL21': 72,
 'KL

In [59]:
node_feature = []

for index, item in tqdm(enumerate(indexation)) :
    features = [index]
    if item in all_data["Infected_ancestor"].unique() : 
        KL_type = all_data[all_data["Infected_ancestor"] == item]["KL_type_LCA"].values[0]
        features = features + [label_mapping[KL_type]] + [-1]*1280
    elif item in all_data["Phage"].unique() : 
        features = features + [-1]*1281
    elif item in all_data["index"].unique() : 
        features = features + [-1] + DF_embeddings[DF_embeddings["index"] == item].values[0][1:1281].tolist()
    elif item in [f"Dpo_to_predict_{n}" for n in DF_info["index"].unique().tolist()] : 
        features = features + [-1]*1281
    node_feature.append(features)
    
# Transform into tensor : 
node_feature_tensor = torch.tensor(node_feature , dtype=torch.float)

# Write file : 
numpy_array = node_feature_tensor.numpy()
df = pd.DataFrame(numpy_array)
df.to_csv(f"{path_work}/node_features.csv", index=False, header=False)

26766it [02:01, 219.57it/s]


> Make the Y file : 

In [56]:
y_file = [1] * len(edge_index)

# Transform into tensor : 
y_tensor = torch.tensor(y_file , dtype=torch.float)

# Write file : 
numpy_array = y_tensor.numpy()
df = pd.DataFrame(numpy_array)
df.to_csv(f"{path_work}/y_file.csv", index=False, header=False)

***
## Create the Data instance 

In [69]:
import torch
from torch_geometric.data import Data , DataLoader

edge_index = edge_index_tensor
x = node_feature_tensor
y = y_tensor

data = Data(x=x, edge_index=edge_index.t().contiguous(), y=y)

# print out the data instance
print(data)

data.validate(raise_on_error=True)
data.is_undirected()


Data(x=[26766, 1282], edge_index=[2, 19354], y=[19354])


False

In [70]:
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt

def visualize_graph(G, color):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                     node_color=color, cmap="Set2")
    plt.show()
    
data_sub = data[:50]
    
G = to_networkx(data_sub, to_undirected=False)
visualize_graph(G, color=data.y)

TypeError: unhashable type: 'slice'

***
# Transductive Learning

In [None]:
from torch_geometric.nn import GCNConv

class GNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        return x
    
class Classifier(torch.nn.Module):
    def forward(self, x, edge_index):
        edge_feat_start = x[edge_index[0]]
        edge_feat_end = x[edge_index[1]]
        return (edge_feat_start * edge_feat_end).sum(dim=-1)

class Model(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.gnn = GNN(in_channels, hidden_channels)
        self.classifier = Classifier()

    def forward(self, data):
        x = self.gnn(data.x, data.edge_index)
        pred = self.classifier(x, data.edge_index)
        return pred

model = Model(in_channels=1280, hidden_channels=64)

In [None]:
from torch.nn import Linear, ReLU
from torch_geometric.nn import Sequential, GCNConv

model = Sequential('x, edge_index', [
    (GCNConv(in_channels, 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    (GCNConv(64, 64), 'x, edge_index -> x'),
    ReLU(inplace=True),
    Linear(64, out_channels),
])

>Without negative loss sampling

In [None]:
# Define a DataLoader and loss function
data_loader = DataLoader(data_list, batch_size=32)
criterion = torch.nn.BCEWithLogitsLoss()

# Define an optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Loop over epochs
for epoch in range(100):
    # Training phase
    model.train()
    
    total_loss = 0
    for batch in data_loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(data_loader)

    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

> With negative loss sampling

In [None]:
from torch_geometric.utils import negative_sampling

# Define a DataLoader and loss function
data_loader = DataLoader(data_list, batch_size=32)
criterion = torch.nn.BCEWithLogitsLoss()
# Define an optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    model.train()
    total_loss = 0
    for batch in data_loader:
        optimizer.zero_grad()
        # Compute embeddings for nodes in batch
        node_embeddings = model(batch)
        # Compute positive score
        pos_edge_index = batch.edge_index
        pos_out = model.classifier(node_embeddings[pos_edge_index[0]], node_embeddings[pos_edge_index[1]])
        # Generate negative edges and compute negative score
        num_nodes = node_embeddings.size(0)
        neg_edge_index = negative_sampling(edge_index=pos_edge_index, num_nodes=num_nodes, num_neg_samples=pos_edge_index.size(1))
        neg_out = model.classifier(node_embeddings[neg_edge_index[0]], node_embeddings[neg_edge_index[1]])
        # Compute the loss
        pos_loss = criterion(pos_out, batch.y)
        neg_loss = criterion(neg_out, torch.zeros(neg_out.size()).to(device))
        loss = pos_loss + neg_loss
        
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(data_loader)

    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

In [None]:
import torch
from torch_geometric.utils import negative_sampling
from sklearn.metrics import f1_score, precision_score, roc_auc_score

def train():
    model.train()

    # Negative sampling for each batch
    neg_edge_index = negative_sampling(edge_index, num_nodes=data.num_nodes)

    optimizer.zero_grad()

    # Get the model's predictions for positive and negative samples
    pos_out = model(data.x, edge_index)
    neg_out = model(data.x, neg_edge_index)

    pos_loss = F.binary_cross_entropy_with_logits(pos_out, torch.ones(pos_out.shape))
    neg_loss = F.binary_cross_entropy_with_logits(neg_out, torch.zeros(neg_out.shape))

    loss = pos_loss + neg_loss
    loss.backward()

    optimizer.step()

    # Returning the loss values for potential logging
    return pos_loss.item(), neg_loss.item()

def test():
    model.eval()

    # Negative sampling for the full dataset
    neg_edge_index = negative_sampling(edge_index, num_nodes=data.num_nodes)

    with torch.no_grad():
        # Get the model's predictions for positive and negative samples
        pos_out = model(data.x, edge_index)
        neg_out = model(data.x, neg_edge_index)

        pos_target = torch.ones(pos_out.shape[0])
        neg_target = torch.zeros(neg_out.shape[0])

        # Concatenate the predictions and their respective targets
        predictions = torch.cat([pos_out, neg_out]).numpy()
        targets = torch.cat([pos_target, neg_target]).numpy()

        # Calculate the metrics
        f1 = f1_score(targets, predictions.round())
        precision = precision_score(targets, predictions.round())
        auroc = roc_auc_score(targets, predictions)

    return f1, precision, auroc

# Validation strategy: split the edges and their respective y into train and validation sets.
edge_index_train, edge_index_val, y_train, y_val = train_test_split(edge_index.T, y, test_size=0.2)
edge_index_train = edge_index_train.T
edge_index_val = edge_index_val.T

# To monitor overfitting, you can compare the performance on the validation set.
best_val_auroc = 0

for epoch in range(1, 201):
    pos_loss, neg_loss = train()
    val_f1, val_precision, val_auroc = test(edge_index_val, y_val)

    # Save the model if it has the best validation performance so far
    if val_auroc > best_val_auroc:
        best_val_auroc = val_auroc
        torch.save(model.state_dict(), "best_model.pth")

    print(f"Epoch: {epoch}, Pos loss: {pos_loss}, Neg loss: {neg_loss}, Val F1: {val_f1}, Val Precision: {val_precision}, Val AUROC: {val_auroc}")


### The GNN architecture

> GCN layers

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import DataLoader
from torch_geometric.nn import NegativeSamplingLoss

class Net(torch.nn.Module):
    def __init__(self, num_node_features, hidden_channels):
        super(Net, self).__init__()
        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)

        # create edge embeddings by taking the dot product of the node embeddings
        start, end = edge_index
        edge_embeddings = torch.einsum("ef,ef->e", x[start], x[end])

        return edge_embeddings

# Create an instance of the model
model = Net(num_node_features=data.num_node_features, hidden_channels=64)

> GAT layers

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_mean_pool
from torch_geometric.data import DataLoader

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super(GAT, self).__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=8, dropout=0.6)
        self.conv2 = GATConv(8*hidden_channels, hidden_channels, heads=8, concat=False, dropout=0.6)
        self.lin1 = torch.nn.Linear(2*hidden_channels, hidden_channels)
        self.lin2 = torch.nn.Linear(hidden_channels, 1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        edge_attr = data.edge_attr

        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)

        # Get the embeddings of the start and end nodes of each edge
        start, end = edge_index[:, edge_attr==1]
        edge_embeddings = torch.cat([x[start], x[end]], dim=1)

        # Pass through a couple fully connected layers
        edge_embeddings = F.elu(self.lin1(edge_embeddings))
        edge_scores = torch.sigmoid(self.lin2(edge_embeddings))

        return edge_scores.view(-1)

Other Method Somehow

In [None]:
from torch_geometric.nn import SAGEConv, to_hetero

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.conv1 = SAGEConv(hidden_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        # Define a 2-layer GNN computation graph.
        # Use a *single* `ReLU` non-linearity in-between.
        # TODO:
        raise NotImplementedError

# Our final classifier applies the dot-product between source and destination
# node embeddings to derive edge-level predictions:
class Classifier(torch.nn.Module):
    def forward(self, x_user: Tensor, x_movie: Tensor, edge_label_index: Tensor) -> Tensor:
        # Convert node embeddings to edge-level representations:
        edge_feat_user = x_user[edge_label_index[0]]
        edge_feat_movie = x_movie[edge_label_index[1]]

        # Apply dot-product to get a prediction per supervision edge:
        return (edge_feat_user * edge_feat_movie).sum(dim=-1)


class Model(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        # Since the dataset does not come with rich features, we also learn two
        # embedding matrices for users and movies:
        self.movie_lin = torch.nn.Linear(20, hidden_channels)
        self.user_emb = torch.nn.Embedding(data["user"].num_nodes, hidden_channels)
        self.movie_emb = torch.nn.Embedding(data["movie"].num_nodes, hidden_channels)

        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels)

        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())

        self.classifier = Classifier()

    def forward(self, data: HeteroData) -> Tensor:
        x_dict = {
          "user": self.user_emb(data["user"].node_id),
          "movie": self.movie_lin(data["movie"].x) + self.movie_emb(data["movie"].node_id),
        } 

        # `x_dict` holds feature matrices of all node types
        # `edge_index_dict` holds all edge indices of all edge types
        x_dict = self.gnn(x_dict, data.edge_index_dict)

        pred = self.classifier(
            x_dict["user"],
            x_dict["movie"],
            data["user", "rates", "movie"].edge_label_index,
        )

        return pred

        
model = Model(hidden_channels=64)

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
from torch_geometric.nn import NegativeSamplingLoss

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.loss_fn = NegativeSamplingLoss()

    def forward(self, x, edge_index, batch=None):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        return x

class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super(GAT, self).__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=2, dropout=0.6)
        self.conv2 = GATConv(2*hidden_channels, hidden_channels, heads=2, dropout=0.6)
        self.conv3 = GATConv(2*hidden_channels, hidden_channels, heads=2, concat=False, dropout=0.6)
        self.loss_fn = NegativeSamplingLoss()

    def forward(self, x, edge_index):
        x = F.elu(self.conv1(x, edge_index))
        x = F.elu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        return x

In [None]:
from torch_geometric.data import DataLoader
from torch_geometric.utils import negative_sampling
from torch.nn import BCEWithLogitsLoss
import numpy as np

# Initialize your model
model = GAT(num_node_features, hidden_channels)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = BCEWithLogitsLoss()

data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

for epoch in range(epochs):
    for batch in data_loader:
        model.train()
        optimizer.zero_grad()
        # Forward pass
        out = model(batch.x, batch.edge_index)
        # Compute scores for positive pairs
        pos_out = torch.sum(out[batch.edge_index[0]] * out[batch.edge_index[1]], dim=-1)
        # Negative sampling: for each src node in edge_index, we sample one
        # negative target node and add it to the edge_index tensor.
        neg_edge_index = negative_sampling(batch.edge_index, num_nodes=batch.num_nodes)
        neg_out = torch.sum(out[neg_edge_index[0]] * out[neg_edge_index[1]], dim=-1)
        # The final output tensor should have this form: [pos1, neg1, pos2, neg2, pos3, neg3, ...]
        # where pos{i} is the score for the i-th positive pair, and neg{i} is the score for the i-th negative pair.
        total_out = torch.stack([pos_out, neg_out], dim=1).view(-1)
        # The target tensor should have this form: [1, 0, 1, 0, 1, 0, ...]
        total_target = torch.tensor(np.repeat([1, 0], pos_out.shape[0]), dtype=torch.float)
        # Compute loss and backpropagate
        loss = criterion(total_out, total_target)
        loss.backward()
        optimizer.step()

        print('Epoch: {:03d}, Loss: {:.4f}'.format(epoch, loss.item()))

***
### The training : 

In [None]:
from torch_geometric.data import Data
from torch_geometric.utils import negative_sampling

# Prepare the graph data
x = ...  # Node features tensor (shape: [num_nodes, num_features])
edge_index = ...  # Edge indices tensor (shape: [2, num_edges])

data = Data(x=x, edge_index=edge_index)

# Create positive and negative edge samples
pos_edge_index = data.edge_index
neg_edge_index = negative_sampling(edge_index, num_nodes=data.num_nodes, num_neg_samples=data.num_edges)

# Train the model
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.BCEWithLogitsLoss()

for epoch in range(200):
    model.train()
    optimizer.zero_grad()
    node_embeddings = model(data.x, data.edge_index)

    pos_edge_preds = (node_embeddings[pos_edge_index[0]] * node_embeddings[pos_edge_index[1]]).sum(dim=1)
    neg_edge_preds = (node_embeddings[neg_edge_index[0]] * node_embeddings[neg_edge_index[1]]).sum(dim=1)

    edge_preds = torch.cat([pos_edge_preds, neg_edge_preds], dim=0)
    edge_labels = torch.cat([torch.ones(pos_edge_index.size(1)), torch.zeros(neg_edge_index.size(1))], dim=0)

    loss = criterion(edge_preds, edge_labels)
    loss.backward()
    optimizer.step()

# Predict edge probabilities
model.eval()
with torch.no_grad():
    node_embeddings = model(data.x, data.edge_index)
    edge_probs = torch.sigmoid((node_embeddings[edge_index[0]] * node_embeddings[edge_index[1]]).sum(dim=1))


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GAT(in_channels, hidden_channels).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)

train_dataset = ...
val_dataset = ...
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

def train():
    model.train()

    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = F.binary_cross_entropy(out, data.y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_dataset)


def validate():
    model.eval()

    total_loss = 0
    for data in val_loader:
        data = data.to(device)
        with torch.no_grad():
            pred = model(data)
        loss = F.binary_cross_entropy(pred, data.y)
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(val_dataset)

for epoch in range(1, 101):
    loss = train()
    val_loss = validate()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val Loss: {val_loss:.4f}')


***
## Transductive predictions : 


In [None]:
def predict():
    model.eval()

    data = full_dataset.to(device)  # The whole graph
    with torch.no_grad():
        preds = model(data)  # predictions for all edges

    # Get predictions only for the desired edges
    desired_edge_preds = preds[data.y == -1] 

    return desired_edge_preds

# After training the model
predictions = predict()

A1: Inductive Learning

In [None]:
from torch_geometric.data import DataLoader

# Define your model and optimizer
model = GCN(num_node_features=1280, num_classes=3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Create DataLoader instances for your training and test datasets
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

# Training loop
for epoch in range(100):  # 100 epochs
    for data in train_loader:
        optimizer.zero_grad()
        out = model(data)
        loss = torch.nn.functional.cross_entropy(out, data.y)
        loss.backward()
        optimizer.step()
        
model.eval()
for data in test_loader:
    with torch.no_grad():
        predictions = model(data)

A2: Transductive Learning

In [None]:
model = GCN(num_node_features=1280, num_classes=3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):  # 100 epochs
    optimizer.zero_grad()
    out = model(data)
    loss = torch.nn.functional.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    
model.eval()
with torch.no_grad():
    predictions = model(data)
unlabeled_predictions = predictions[data.test_mask]

***

In [None]:
dataset = dataset.shuffle()
train_dataset = dataset[:540]
test_dataset = dataset[540:]
loader = DataLoader(dataset, batch_size=32, shuffle=True)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    
    
import torch_geometric.transforms as T
from torch_geometric.data import HeteroData

    # Maybe would be nice to 
# We also need to make sure to add the reverse edges from movies to users
# in order to let a GNN be able to pass messages in both directions.
# We can leverage the `T.ToUndirected()` transform for this from PyG: