In [2]:
import numpy as np
import pandas as pd
import networkx as nx
import torch
from torch_geometric.nn import Node2Vec
from torch_geometric.data import Data

# Load data

In [3]:
bp_db = pd.read_csv("../data/sub_bp_db.csv", index_col=0)
bp_db.head()

Unnamed: 0,ENSEMBL,GO
0,ENSG00000000003,GO:0039532
1,ENSG00000000005,GO:0001937
2,ENSG00000000005,GO:0016525
3,ENSG00000000419,GO:0006488
4,ENSG00000000419,GO:0006506


In [4]:
counts1 = pd.read_csv("../data/counts1.csv", index_col=0)
counts1.head()

Unnamed: 0,ENSG00000000003,ENSG00000000005,ENSG00000000419,ENSG00000000457,ENSG00000000460,ENSG00000000938,ENSG00000000971,ENSG00000001036,ENSG00000001084,ENSG00000001167,...,ENSGR0000167393,ENSGR0000169084,ENSGR0000169093,ENSGR0000178605,ENSGR0000182378,ENSGR0000185291,ENSGR0000198223,ENSGR0000214717,ENSGR0000223511,ENSGR0000223773
089357B,14,7,103,241,72,2057,30,60,207,367,...,1,0,0,0,0,0,0,0,0,0
089366A,11,2,194,511,110,3325,36,111,186,530,...,0,0,0,0,0,0,1,0,0,1
089412B,8,0,312,450,106,3751,45,160,325,653,...,0,0,0,0,0,0,1,0,0,0
089425B,9,0,135,496,133,2758,26,93,182,620,...,0,0,0,0,0,0,0,0,0,0
089687A,4,0,89,267,49,2181,24,75,122,263,...,0,0,0,0,0,0,1,0,0,0


In [5]:
pheno1 = pd.read_csv("../data/pheno1.csv", index_col=0)
pheno1.drop(["diagnosis"], axis=1, inplace=True)
pheno1["condition"] = pheno1["condition"].apply(lambda x: 0 if x == "Control" else 1)
pheno1.head()

Unnamed: 0,age,sex,lithium,condition
089357B,18,F,0,0
089366A,19,F,0,0
089412B,23,F,0,0
089425B,47,F,0,0
089687A,52,F,0,0


# Train Node Embeddings from Graph

In [6]:
bp_graph = nx.read_gml("../data/sub_graph.gml")

In [7]:
bp_db_go = sorted(set(bp_graph.nodes))

map_int_go = {int(idx): go for idx, go in enumerate(bp_db_go)}
map_go_int = {go: idx for idx, go in map_int_go.items()}

_graph = nx.relabel_nodes(bp_graph, map_go_int, copy=False)
edge_index = torch.tensor(list(_graph.edges()), dtype=torch.long).t().contiguous()

data = Data(x=None, edge_index=edge_index)

In [8]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = Node2Vec(data.edge_index, embedding_dim=8, walk_length=20,
                 context_size=10, walks_per_node=10,
                 num_negative_samples=1, p=1, q=1, sparse=True).to(device)

loader = model.loader(batch_size=4, shuffle=True)
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.001)

In [9]:
def train():
    model.train()
    total_loss = 0
    for pos_rw, neg_rw in loader:
        optimizer.zero_grad()
        loss = model.loss(pos_rw.to(device), neg_rw.to(device))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

In [10]:
for epoch in range(1, 50):
    loss = train()
    #acc = test()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')

Epoch: 01, Loss: 1.3162
Epoch: 02, Loss: 1.1122
Epoch: 03, Loss: 1.0225
Epoch: 04, Loss: 0.9686
Epoch: 05, Loss: 0.9325
Epoch: 06, Loss: 0.9060
Epoch: 07, Loss: 0.8861
Epoch: 08, Loss: 0.8699
Epoch: 09, Loss: 0.8570
Epoch: 10, Loss: 0.8467
Epoch: 11, Loss: 0.8373
Epoch: 12, Loss: 0.8297
Epoch: 13, Loss: 0.8228
Epoch: 14, Loss: 0.8174
Epoch: 15, Loss: 0.8121
Epoch: 16, Loss: 0.8074
Epoch: 17, Loss: 0.8036
Epoch: 18, Loss: 0.7999
Epoch: 19, Loss: 0.7969
Epoch: 20, Loss: 0.7936
Epoch: 21, Loss: 0.7914
Epoch: 22, Loss: 0.7888
Epoch: 23, Loss: 0.7866
Epoch: 24, Loss: 0.7849
Epoch: 25, Loss: 0.7828
Epoch: 26, Loss: 0.7812
Epoch: 27, Loss: 0.7798
Epoch: 28, Loss: 0.7785
Epoch: 29, Loss: 0.7772
Epoch: 30, Loss: 0.7761
Epoch: 31, Loss: 0.7752
Epoch: 32, Loss: 0.7741
Epoch: 33, Loss: 0.7733
Epoch: 34, Loss: 0.7725
Epoch: 35, Loss: 0.7717
Epoch: 36, Loss: 0.7712
Epoch: 37, Loss: 0.7706
Epoch: 38, Loss: 0.7700
Epoch: 39, Loss: 0.7695
Epoch: 40, Loss: 0.7687
Epoch: 41, Loss: 0.7684
Epoch: 42, Loss:

In [11]:
go_embedding = model()  # go_index -> go_embedding
go_embedding.shape

torch.Size([15518, 8])

In [12]:
_embedding = go_embedding.cpu().detach().numpy()

map_int_gene = {int(idx): gene for idx, gene in enumerate(counts1.columns)}
map_gene_int = {gene: idx for idx, gene in map_int_gene.items()}

# convert to gene_index -> go_embedding
embedding_gene = np.zeros((len(counts1.columns), _embedding.shape[1]))
for idx, gene in map_int_gene.items():
    if gene in map_go_int:
        embedding_gene[idx] = _embedding[map_go_int[gene]]

embedding_gene = torch.tensor(embedding_gene, dtype=torch.float32, device=device)
embedding_gene.shape

torch.Size([52645, 8])

# process the data

In [13]:
# add age, sex, lithium of pheno1 to counts1
tmp_pheno1 = pheno1[["age", "sex", "lithium"]].apply(lambda x: x.replace("M", 0).replace("F", 1))  # chagne sex to 0, 1
counts1_merge = pd.merge(counts1, tmp_pheno1, left_index=True, right_index=True)

counts1_merge = (counts1_merge - counts1_merge.mean()) / counts1_merge.std()
counts1_merge.head()

Unnamed: 0,ENSG00000000003,ENSG00000000005,ENSG00000000419,ENSG00000000457,ENSG00000000460,ENSG00000000938,ENSG00000000971,ENSG00000001036,ENSG00000001084,ENSG00000001167,...,ENSGR0000178605,ENSGR0000182378,ENSGR0000185291,ENSGR0000198223,ENSGR0000214717,ENSGR0000223511,ENSGR0000223773,age,sex,lithium
089357B,2.110648,4.704691,-0.899272,-1.436269,-1.13286,-1.111674,-0.8044,-1.458287,-0.62945,-1.0402,...,-0.152399,-0.453101,-0.212458,-0.502577,-0.45051,-0.222561,-0.226593,-2.043569,0.879916,-0.720677
089366A,1.348549,1.129061,-0.091576,0.088378,-0.357019,-0.170067,-0.695273,-0.660205,-0.785903,-0.267949,...,-0.152399,-0.453101,-0.212458,-0.146685,-0.45051,-0.222561,2.492525,-1.972198,0.879916,-0.720677
089412B,0.58645,-0.30119,0.955766,-0.256079,-0.438687,0.146277,-0.531583,0.10658,0.249663,0.314793,...,-0.152399,-0.453101,-0.212458,-0.146685,-0.45051,-0.222561,-0.226593,-1.686712,0.879916,-0.720677
089425B,0.840483,-0.30119,-0.615247,0.003676,0.112569,-0.591117,-0.877151,-0.941881,-0.815703,0.158447,...,-0.152399,-0.453101,-0.212458,-0.502577,-0.45051,-0.222561,-0.226593,0.026202,0.879916,-0.720677
089687A,-0.429682,-0.30119,-1.023533,-1.289451,-1.602448,-1.019593,-0.913527,-1.223557,-1.26271,-1.532924,...,-0.152399,-0.453101,-0.212458,-0.146685,-0.45051,-0.222561,-0.226593,0.383059,0.879916,-0.720677


In [14]:
bp_db_genes = set(bp_db.ENSEMBL)

In [15]:
dataset = []
input_dim = None

for row in counts1_merge.iterrows():
    idx = row[0]
    values = row[1]
    _data = {
        "gene_with_go_idx": [],
        "gene_with_go_value": [],
        "gene_without_go_value": [],
        "other_info": []
    }
    for k, v in values.items():
        if k in bp_db_genes:
            _data["gene_with_go_idx"].append(map_gene_int[k])
            _data["gene_with_go_value"].append(v)
        elif k in map_gene_int:
            _data["gene_without_go_value"].append(v)
        else:
            _data["other_info"].append(v)
    dataset.append(_data)
    if input_dim is None:
        input_dim = len(_data["gene_with_go_idx"]) * 2 + len(_data["gene_without_go_value"]) + len(_data["other_info"])

In [16]:
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        sample = self.X[idx]
        gene_with_go_idx = torch.tensor(sample['gene_with_go_idx'], dtype=torch.long, device=device)
        gene_with_go_value = torch.tensor(sample['gene_with_go_value'], dtype=torch.float, device=device)
        gene_without_go_value = torch.tensor(sample['gene_without_go_value'], dtype=torch.float, device=device)
        other_info = torch.tensor(sample['other_info'], dtype=torch.float, device=device)

        return {
            'gene_with_go_idx': gene_with_go_idx,
            'gene_with_go_value': gene_with_go_value,
            'gene_without_go_value': gene_without_go_value,
            'other_info': other_info
        }, torch.tensor(self.y[idx], dtype=torch.long, device=device)

In [17]:
from sklearn.model_selection import train_test_split

X = dataset
y = pheno1["condition"].values

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(len(X_train), len(X_test))

355 89


In [18]:
train_dataset = MyDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

test_dataset = MyDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

In [19]:
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):
    def __init__(self, node_embedding, input_dim, output_dim):
        super().__init__()

        self.embedding = node_embedding
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, output_dim)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()

    def forward(self, X):
        gene_with_go_idx = X['gene_with_go_idx']  # (B, N_1)
        gene_with_go_value = X['gene_with_go_value']  # (B, N_1)
        gene_without_go_value = X['gene_without_go_value']  # (B, N_2)
        other_info = X['other_info']  # (B, N_3)

        gene_with_go_embedding = self.embedding[gene_with_go_idx]  # (B, N_1, 8)
        gene_with_go_embedding = gene_with_go_embedding * gene_with_go_value.unsqueeze(-1)  # (B, N_1, 8)
        gene_with_go_embedding = gene_with_go_embedding.mean(dim=2)  # (B, N_1)

        output = torch.cat([gene_with_go_embedding, gene_with_go_value, gene_without_go_value, other_info],
                           dim=1)  # (B, N_1 * 2 + N_2 + N_3)

        output = self.fc1(output)  # (B, 256)
        output = self.relu1(output)
        output = self.fc2(output)  # (B, 128)
        output = self.relu2(output)
        output = self.fc3(output)  # (B, 2)
        return output


model = MyModel(node_embedding=embedding_gene, input_dim= input_dim, output_dim=2)

In [20]:
import torch.optim as optim

# Assume model is an instance of our RNN class
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
model = model.to(device)
criterion = criterion.to(device)


def train(model, iterator, optimizer, criterion):
    epoch_loss = 0
    total_correct = 0
    total_instances = 0

    model.train()

    for X, y in iterator:
        optimizer.zero_grad()

        # Assuming your model's forward method automatically handles padding, then no need to pack sequence here
        predictions = model(X)

        loss = criterion(predictions, y)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        # Compute the number of correct predictions
        _, predicted_classes = predictions.max(dim=1)
        correct_predictions = (predicted_classes == y).float()  # Convert to float for summation
        total_correct += correct_predictions.sum().item()
        total_instances += y.size(0)

    epoch_acc = total_correct / total_instances

    return epoch_loss / len(iterator), epoch_acc


train_loss, train_acc = train(model, train_loader, optimizer, criterion)

In [21]:
def evaluate(model, iterator, criterion):
    epoch_loss = 0
    total_correct = 0
    total_instances = 0

    model.eval()

    with torch.no_grad():
        for X, y in iterator:
            predictions = model(X)

            loss = criterion(predictions, y)
            epoch_loss += loss.item()

            # Compute the number of correct predictions
            _, predicted_classes = predictions.max(dim=1)
            correct_predictions = (predicted_classes == y).float()  # Convert to float for summation
            total_correct += correct_predictions.sum().item()
            total_instances += y.size(0)

    epoch_acc = total_correct / total_instances
    return epoch_loss / len(iterator), epoch_acc

In [22]:
import time


def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = elapsed_time - (elapsed_mins * 60)
    return elapsed_mins, elapsed_secs, elapsed_time

In [23]:
N_EPOCHS = 200

best_valid_loss = float('inf')

best_valid_acc = 0

elapsed_times = []

for epoch in range(N_EPOCHS):

    start_time = time.time()

    train_loss, train_acc = train(model, train_loader, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, test_loader, criterion)

    end_time = time.time()

    epoch_mins, epoch_secs, elapsed_time = epoch_time(start_time, end_time)
    elapsed_times.append(elapsed_time)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        best_valid_acc = valid_acc
        #torch.save(model.state_dict(), '../data/final_model.pt')

    if epoch % 1 == 0:
        print(f'Epoch: {epoch + 1:02} | Epoch Time: {epoch_mins}m {epoch_secs:.3f}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc * 100:.2f}%')
        print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc * 100:.2f}%')

avg_elapsed_time = sum(elapsed_times) / len(elapsed_times)
print(f'Avg Epoch Time: {avg_elapsed_time:.3f}s')
print(f'Best Val. Loss: {best_valid_loss:.3f} | Best Val. Acc: {best_valid_acc * 100:.2f}%')

Epoch: 01 | Epoch Time: 0m 2.818s
	Train Loss: 0.472 | Train Acc: 82.54%
	 Val. Loss: 0.769 |  Val. Acc: 66.29%
Epoch: 02 | Epoch Time: 0m 2.800s
	Train Loss: 0.363 | Train Acc: 87.61%
	 Val. Loss: 0.920 |  Val. Acc: 68.54%
Epoch: 03 | Epoch Time: 0m 2.242s
	Train Loss: 0.139 | Train Acc: 94.65%
	 Val. Loss: 0.641 |  Val. Acc: 74.16%
Epoch: 04 | Epoch Time: 0m 1.806s
	Train Loss: 0.068 | Train Acc: 98.03%
	 Val. Loss: 0.533 |  Val. Acc: 75.28%
Epoch: 05 | Epoch Time: 0m 1.897s
	Train Loss: 0.013 | Train Acc: 100.00%
	 Val. Loss: 0.488 |  Val. Acc: 77.53%
Epoch: 06 | Epoch Time: 0m 1.971s
	Train Loss: 0.004 | Train Acc: 100.00%
	 Val. Loss: 0.540 |  Val. Acc: 75.28%
Epoch: 07 | Epoch Time: 0m 1.974s
	Train Loss: 0.003 | Train Acc: 100.00%
	 Val. Loss: 0.586 |  Val. Acc: 75.28%
Epoch: 08 | Epoch Time: 0m 1.959s
	Train Loss: 0.002 | Train Acc: 100.00%
	 Val. Loss: 0.617 |  Val. Acc: 76.40%
Epoch: 09 | Epoch Time: 0m 1.966s
	Train Loss: 0.002 | Train Acc: 100.00%
	 Val. Loss: 0.620 |  Val.

KeyboardInterrupt: 