In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import average_precision_score
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [2]:
label_df = pd.read_csv("/home/kongge/projects/Phd/data/all_labels.csv")
prs_df = pd.read_csv('/home/kongge/projects/Phd/data/Only_PGS.csv')
data = pd.merge(prs_df, label_df, left_on='FID', right_on='eid')

In [3]:
data.shape

(336979, 187)

In [4]:
feature_columns = prs_df.columns.drop('FID')
label_columns = label_df.columns.drop('eid')
X = data[feature_columns].values
y = data[label_columns].values

In [5]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [6]:
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.float32)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.float32)
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)

In [7]:
class MLP(nn.Module):
    def __init__(self, input_size, output_size):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_size)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x

In [17]:
input_size = X_train.shape[1]
output_size = y_train.shape[1]
model = MLP(input_size, output_size)

In [24]:
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

In [25]:
num_epochs = 200
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for X_batch, y_batch in train_loader:
        outputs = model(X_batch)
        loss = criterion(outputs, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

Epoch [1/200], Loss: 0.0568
Epoch [2/200], Loss: 0.0568
Epoch [3/200], Loss: 0.0568
Epoch [4/200], Loss: 0.0568
Epoch [5/200], Loss: 0.0568
Epoch [6/200], Loss: 0.0568
Epoch [7/200], Loss: 0.0568
Epoch [8/200], Loss: 0.0567
Epoch [9/200], Loss: 0.0567
Epoch [10/200], Loss: 0.0568
Epoch [11/200], Loss: 0.0567
Epoch [12/200], Loss: 0.0567
Epoch [13/200], Loss: 0.0567
Epoch [14/200], Loss: 0.0567
Epoch [15/200], Loss: 0.0568
Epoch [16/200], Loss: 0.0567
Epoch [17/200], Loss: 0.0567
Epoch [18/200], Loss: 0.0567
Epoch [19/200], Loss: 0.0567
Epoch [20/200], Loss: 0.0567
Epoch [21/200], Loss: 0.0567
Epoch [22/200], Loss: 0.0567
Epoch [23/200], Loss: 0.0567
Epoch [24/200], Loss: 0.0567
Epoch [25/200], Loss: 0.0567
Epoch [26/200], Loss: 0.0567
Epoch [27/200], Loss: 0.0567
Epoch [28/200], Loss: 0.0567
Epoch [29/200], Loss: 0.0567
Epoch [30/200], Loss: 0.0567
Epoch [31/200], Loss: 0.0567
Epoch [32/200], Loss: 0.0567
Epoch [33/200], Loss: 0.0567
Epoch [34/200], Loss: 0.0567
Epoch [35/200], Loss: 0

In [29]:
model.eval()  # 切换到评估模式
with torch.no_grad():
    y_pred = model(X_test_tensor).numpy()
    y_true = y_test_tensor.numpy()

    aupr_results = {}

    for i in range(y_true.shape[1]):
        aupr = average_precision_score(y_true[:, i], y_pred[:, i])
        disease_name = label_columns[i]  # 获取疾病名称
        aupr_results[disease_name] = aupr

    mean_aupr = np.mean(list(aupr_results.values()))

  recall = tps / tps[-1]
  recall = tps / tps[-1]
  recall = tps / tps[-1]


In [31]:
aupr_results = {k: v for k, v in aupr_results.items() if not np.isnan(v)}
sorted_aupr_results = sorted(aupr_results.items(), key=lambda item: item[1], reverse=True)

In [32]:
sorted_aupr_results

[('Class_I10', 0.5912832565184755),
 ('Class_J45', 0.31490851466463526),
 ('Class_E03', 0.2942917891416554),
 ('Class_I25', 0.2282831738292497),
 ('Class_F17', 0.21319730023049605),
 ('Class_I48', 0.2029449849522778),
 ('Class_E11', 0.1832767216951004),
 ('Class_I20', 0.1656175298760879),
 ('Class_I21', 0.1163326634024214),
 ('Class_E14', 0.1122783392267296),
 ('Class_D64', 0.09374645342883481),
 ('Class_D50', 0.08475763876397645),
 ('Class_I50', 0.08152256336170131),
 ('Class_A09', 0.0667528275677719),
 ('Class_I63', 0.054549659203597015),
 ('Class_C50', 0.05247698403093434),
 ('Class_F10', 0.04446388752997917),
 ('Class_M05', 0.04403705871575638),
 ('Class_E05', 0.03267414788706921),
 ('Class_I08', 0.029732646551248),
 ('Class_F00', 0.02833188793902882),
 ('Class_F05', 0.027605111426185654),
 ('Class_C18', 0.024829858121088583),
 ('Class_E10', 0.024008670085211785),
 ('Class_F03', 0.023921524685292975),
 ('Class_E16', 0.02082497780809134),
 ('Class_A04', 0.020208493892878855),
 ('Cla

In [35]:
y_test_df = pd.DataFrame(y_test, columns=label_columns)
test_counts = y_test_df.sum(axis=0)
test_counts_dict = test_counts.to_dict()

In [37]:
sorted_test_counts_dict = sorted(test_counts_dict.items(), key=lambda item: item[1], reverse=True)

In [38]:
sorted_test_counts_dict

[('Class_I10', 27119),
 ('Class_J45', 9971),
 ('Class_I25', 7991),
 ('Class_I48', 6100),
 ('Class_F17', 5897),
 ('Class_E11', 5825),
 ('Class_E03', 5442),
 ('Class_I20', 5112),
 ('Class_D64', 4756),
 ('Class_D50', 4236),
 ('Class_A09', 3608),
 ('Class_I21', 3424),
 ('Class_E14', 3273),
 ('Class_I50', 2923),
 ('Class_C50', 2585),
 ('Class_I63', 2417),
 ('Class_F10', 1859),
 ('Class_M05', 1774),
 ('Class_C18', 1436),
 ('Class_I08', 1260),
 ('Class_F05', 1228),
 ('Class_E05', 1160),
 ('Class_A04', 1113),
 ('Class_A08', 775),
 ('Class_E04', 729),
 ('Class_F03', 719),
 ('Class_E16', 666),
 ('Class_E10', 652),
 ('Class_F00', 623),
 ('Class_G20', 618),
 ('Class_D51', 512),
 ('Class_E07', 416),
 ('Class_G03', 357),
 ('Class_A15', 320),
 ('Class_F01', 314),
 ('Class_I07', 307),
 ('Class_I12', 265),
 ('Class_C56', 264),
 ('Class_D63', 263),
 ('Class_D61', 240),
 ('Class_F06', 229),
 ('Class_I00', 220),
 ('Class_E06', 207),
 ('Class_D52', 158),
 ('Class_F20', 157),
 ('Class_F02', 148),
 ('Class_I

In [None]:
# 患病人数越多的类，aupr的指数越高。