In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

In [2]:
df = pd.read_pickle('dataset_df.pkl')
df.head(1)

Unnamed: 0,pmid,pmid_text,pmid_embeddings,sr_preds,sr_pairs,sr_pairs_embeddings,negative_sr_pairs,negative_sr_pairs_embeddings
0,23622459,Mitral annulus calcification and sudden death....,"[-0.746443, 0.33572936, -0.23443037, -0.280785...","[[], [{'subj_id': 'C0018787', 'subj_name': 'He...","[(Degenerative disorder, Heart), (Degenerative...","[([-0.34357956, 0.12238737, -0.69660324, -0.57...","[(Degenerative disorder, Blood), (Degenerative...","[([-0.34357956, 0.12238737, -0.69660324, -0.57..."


In [69]:
def prepare_data(df):
    data = []
    labels = []
    pair_info = []
    
    # Positive pairs
    for pmid_emb, pair_embs, pmid, pairs in zip(df['pmid_embeddings'], df['sr_pairs_embeddings'], df['pmid'], df['sr_pairs']):
        for pair_emb, pair in zip(pair_embs, pairs):
            combined_emb = np.concatenate((pmid_emb, pair_emb[0], pair_emb[1]))
            data.append(combined_emb)
            labels.append(1)
            pair_info.append(tuple([pmid, pair]))
    
    # Negative pairs
    for pmid_emb, neg_pairs, pmid, neg_pairs_list in zip(df['pmid_embeddings'], df['negative_sr_pairs_embeddings'], df['pmid'], df['negative_sr_pairs']):
        for neg_pair_emb, neg_pair in zip(neg_pairs, neg_pairs_list):
            combined_emb = np.concatenate((pmid_emb, neg_pair_emb[0], neg_pair_emb[1]))
            data.append(combined_emb)
            labels.append(0)
            pair_info.append(tuple([pmid, neg_pair]))
    
    return np.array(data), np.array(labels), pair_info

In [70]:
X, y, pair_info = prepare_data(df)

In [71]:
class RelationDataset(Dataset):
    def __init__(self, X, y, pair_info):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)
        self.pair_info = pair_info

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx], self.pair_info[idx]

In [72]:
X_train, X_test, y_train, y_test, pair_info_train, pair_info_test = train_test_split(X, y, pair_info, test_size=0.2, random_state=42)

train_dataset = RelationDataset(X_train, y_train, pair_info_train)
test_dataset = RelationDataset(X_test, y_test, pair_info_test)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [73]:
class RelationClassifier(nn.Module):
    def __init__(self, input_dim):
        super(RelationClassifier, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

In [74]:
input_dim = X.shape[1]
model = RelationClassifier(input_dim)

In [75]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

cuda


RelationClassifier(
  (model): Sequential(
    (0): Linear(in_features=2304, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=1, bias=True)
    (5): Sigmoid()
  )
)

In [76]:
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [77]:
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for X_batch, y_batch, _ in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        outputs = model(X_batch).squeeze()
        loss = criterion(outputs, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}")

Epoch [1/10], Loss: 0.3665
Epoch [2/10], Loss: 0.1653
Epoch [3/10], Loss: 0.1325
Epoch [4/10], Loss: 0.1053
Epoch [5/10], Loss: 0.1017
Epoch [6/10], Loss: 0.1227
Epoch [7/10], Loss: 0.0837
Epoch [8/10], Loss: 0.0761
Epoch [9/10], Loss: 0.0726
Epoch [10/10], Loss: 0.0788


In [78]:
model.eval()
correct = 0
predicted_pairs = []
with torch.no_grad():
    for X_batch, y_batch, info_batch in test_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        outputs = model(X_batch).squeeze()
        predictions = (outputs > 0.5).float()
        correct += (predictions == y_batch).sum().item()

        for info, output in zip(info_batch, outputs):
            if output > 0.5:
                predicted_pairs.append(info)

In [79]:
accuracy = correct / len(test_dataset)
print(f"Test Accuracy: {accuracy * 100:.2f}%")

Test Accuracy: 92.85%


In [80]:
print("Positive pairs identified")
predicted_pairs

Positive pairs identified


[('3478483',
  '7253702',
  '22484800',
  '26283888',
  '3196525',
  '30120647',
  '35431250',
  '2678040',
  '33963949',
  '35971326',
  '32838261',
  '2491934',
  '17893502',
  '25386903',
  '8643441',
  '30683148',
  '486226',
  '30252781',
  '3028251',
  '19286696',
  '24705495',
  '30683148',
  '32743309',
  '6095477',
  '36334397',
  '23742124',
  '1317734',
  '9675323',
  '26189138',
  '33413233',
  '31606874',
  '27479077'),
 [('Diabetes Mellitus',
   'Capillary blood',
   'Water Quality',
   'Haemophilus influenzae',
   'Dietary Modification',
   'Detection',
   'Agenesis',
   'Cholecystectomy procedure',
   'Cavia',
   'Cancer Stem Cells',
   'Hemoglobin',
   'Hand',
   'Patients',
   'Cryoelectron Microscopy',
   'Kidney Failure, Acute',
   'Memory',
   'Clone Cells',
   'Delirium',
   'Migration, Cell',
   'Immunoglobulin deficiency',
   'Anti-Inflammatory Agents, Non-Steroidal',
   'Chronic Kidney Diseases',
   'Vertebral Artery Dissection',
   'CAT scan of head',
   'Chro