In [32]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score
import torch.nn.functional as F

In [65]:
df = pd.read_pickle('../data/1000_pmid_dataset.pkl')
df.head(1)

Unnamed: 0,pmid,pmid_text,pmid_embeddings,sr_preds,sr_pairs,sr_pairs_embeddings,negative_sr_pairs,negative_sr_pairs_embeddings
0,23622459,Mitral annulus calcification and sudden death....,"[-0.746443, 0.33572936, -0.23443037, -0.280785...","[[], [{'subj_id': 'C0018787', 'subj_name': 'He...","[(Degenerative disorder, Heart), (Degenerative...","[([-0.34357956, 0.12238737, -0.69660324, -0.57...","[(Degenerative disorder, Blood), (Degenerative...","[([-0.34357956, 0.12238737, -0.69660324, -0.57..."


In [66]:
def prepare_data(df):
    data = []
    labels = []
    pair_info = []
    
    # Positive pairs
    for pmid_emb, pair_embs in zip(df['pmid_embeddings'], df['sr_pairs_embeddings']):
        for pair_emb in pair_embs:
            combined_emb = np.concatenate((pmid_emb, pair_emb[0], pair_emb[1]))
            data.append(combined_emb)
            labels.append(1)
    
    # Negative pairs
    for pmid_emb, neg_pairs_embs in zip(df['pmid_embeddings'], df['negative_sr_pairs_embeddings']):
        for neg_pair_emb in neg_pairs_embs:
            combined_emb = np.concatenate((pmid_emb, neg_pair_emb[0], neg_pair_emb[1]))
            data.append(combined_emb)
            labels.append(0)
    
    return np.array(data), np.array(labels)

In [67]:
X, y = prepare_data(df)

In [68]:
class RelationDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [69]:
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.2, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

train_dataset = RelationDataset(X_train, y_train)
test_dataset = RelationDataset(X_test, y_test)
val_dataset = RelationDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [70]:
class RelationClassifier(nn.Module):
    def __init__(self, input_dim):
        super(RelationClassifier, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

In [71]:
input_dim = X.shape[1]
model = RelationClassifier(input_dim)

In [72]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

cuda


RelationClassifier(
  (model): Sequential(
    (0): Linear(in_features=2304, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

In [73]:
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [74]:
def evaluate(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for X_batch, y_batch in dataloader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            
            output= model(X_batch).squeeze()
            loss = criterion(output, y_batch)
            total_loss += loss.item()

            preds = output.squeeze().cpu().numpy()
            labels = y_batch.cpu().numpy()

            all_preds.extend(preds)
            all_labels.extend(labels)

    avg_loss = total_loss / len(dataloader)
    roc_auc = roc_auc_score(all_labels, all_preds)
    pr_auc = average_precision_score(all_labels, all_preds)

    return avg_loss, roc_auc, pr_auc

In [75]:
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        outputs = model(X_batch).squeeze()
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}")
        
    val_loss, val_roc_auc, val_pr_auc = evaluate(model, val_loader, criterion)

    print(f"Epoch {epoch+1}/{num_epochs} - Validation Loss: {val_loss:.4f}, ROC AUC: {val_roc_auc:.4f}, PR AUC: {val_pr_auc:.4f}\n")

Epoch [1/10], Loss: 0.3762
Epoch 1/10 - Validation Loss: 0.1996, ROC AUC: 0.9743, PR AUC: 0.9785

Epoch [2/10], Loss: 0.1661
Epoch 2/10 - Validation Loss: 0.1465, ROC AUC: 0.9883, PR AUC: 0.9888

Epoch [3/10], Loss: 0.1310
Epoch 3/10 - Validation Loss: 0.1461, ROC AUC: 0.9874, PR AUC: 0.9879

Epoch [4/10], Loss: 0.1092
Epoch 4/10 - Validation Loss: 0.1687, ROC AUC: 0.9875, PR AUC: 0.9880

Epoch [5/10], Loss: 0.1002
Epoch 5/10 - Validation Loss: 0.2126, ROC AUC: 0.9884, PR AUC: 0.9888

Epoch [6/10], Loss: 0.0876
Epoch 6/10 - Validation Loss: 0.1718, ROC AUC: 0.9879, PR AUC: 0.9881

Epoch [7/10], Loss: 0.0794
Epoch 7/10 - Validation Loss: 0.1477, ROC AUC: 0.9889, PR AUC: 0.9890

Epoch [8/10], Loss: 0.0793
Epoch 8/10 - Validation Loss: 0.1657, ROC AUC: 0.9864, PR AUC: 0.9869

Epoch [9/10], Loss: 0.0814
Epoch 9/10 - Validation Loss: 0.1470, ROC AUC: 0.9885, PR AUC: 0.9888

Epoch [10/10], Loss: 0.0772
Epoch 10/10 - Validation Loss: 0.1487, ROC AUC: 0.9899, PR AUC: 0.9903



In [62]:
model.eval()
correct = 0
predicted_pairs = []
with torch.no_grad():
    for X_batch, y_batch in test_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        outputs = model(X_batch).squeeze()
        predictions = (outputs > 0.5).float()
        correct += (predictions == y_batch).sum().item()

In [63]:
accuracy = correct / len(test_dataset)
print(f"Test Accuracy: {accuracy * 100:.2f}%")

Test Accuracy: 94.15%


In [76]:
test_loss, test_roc_auc, test_pr_auc = evaluate(model, test_loader, criterion)
print(f"Test Loss: {test_loss}\nTest ROC AUC: {test_roc_auc}\nTest PR AUC: {test_pr_auc}")

Test Loss: 0.15265478007495403
Test ROC AUC: 0.990186578507372
Test PR AUC: 0.9926108201099503


In [64]:
torch.save(model.state_dict(), '../model/1000_pmid_model.pth')