In [79]:
import torch 
import anndata as ad
import pandas as pd
import numpy as np
import tqdm
from sklearn.metrics import r2_score

seed = 16
use_gpu = True

np.random.seed(seed)
torch.manual_seed(seed)

# df = pd.read_csv('df.csv', index_col=0)
# data = df.values
adata = ad.read_h5ad('../output/preprocess/bulk_adata_f.h5ad')

In [81]:
data = adata.layers['X_norm_pearson']

In [None]:
grn_net_df = pd.read_csv("https://github.com/pablormier/omnipath-static/raw/main/op/collectri-26.09.2023.zip")


In [2]:
# adata = ad.read_h5ad('..//output/preprocess/bulk_adata_f.h5ad')

# adata = adata[:, adata.var_names.isin(df_collectri['target'].unique())]
# gene_names = adata.var_names
# adata.X = adata.layers['X_norm_pearson']
# df = pd.DataFrame(adata.X, index=pd.MultiIndex.from_frame(adata.obs[['cell_type', 'sm_name', 'plate_name']]), columns=adata.var_names)
# data = df.values

gene_names = df.columns
gene_name_dict = {gene_name: i for i, gene_name in enumerate(gene_names)}

# Iterate over the edges (regulatory relationships)
edge_idx = set()
for gene_a, gene_b in zip(df_collectri['source'], df_collectri['target']):
    if (gene_a not in gene_name_dict) or (gene_b not in gene_name_dict):
        continue  # Consider only gene names that are present in the training data
    i = gene_name_dict[gene_a]  # Index of first gene
    j = gene_name_dict[gene_b]  # Index of second gene
    edge_idx.add((i, j))
edge_idx = np.asarray(list(edge_idx), dtype=int)

# Convert list of edges into an adjacency matrix
grn = np.zeros((len(gene_names), len(gene_names)))
grn[edge_idx[:, 0], edge_idx[:, 1]] = 1
# # Remove rows and columns with no annotation
grn_idx = np.where(np.logical_or(grn.sum(axis=0) > 0, grn.sum(axis=1) > 0))[0]
grn = grn[grn_idx, :][:, grn_idx]
n_genes = len(grn_idx)

In [3]:
grn_net = np.zeros((len(gene_names), len(gene_names)))
for source, target, weight in zip(df_collectri['source'], df_collectri['target'], df_collectri['weight']):
    if (source not in gene_name_dict) or (target not in gene_name_dict):
        continue  # Consider only gene names that are present in the training data
    i = gene_name_dict[source]  # Index of first gene
    j = gene_name_dict[target]  # Index of second gene
    grn_net[i,j] = weight
# # Remove rows and columns with no annotation
grn_idx = np.where(np.logical_or(grn_net.sum(axis=0) > 0, grn_net.sum(axis=1) > 0))[0]
grn_net = grn_net[grn_idx, :][:, grn_idx]
n_genes = len(grn_idx)

In [4]:
torch.det(torch.tensor(grn_net)+torch.eye(grn_net.shape[0])*.5)
# grn_net[grn_net!=0]

tensor(0., dtype=torch.float64)

In [70]:
import torch.nn as nn 
from sklearn.metrics import roc_auc_score

class NN(torch.nn.Module):
    def __init__(self, n_genes:int, n_nodes_latent:int=10000):
        torch.nn.Module.__init__(self)
        self.n_genes = n_genes
        # self.A  = nn.Parameter(.5*torch.eye(n_genes))
        self.A = torch.tensor(grn_net, dtype=torch.float32, device='cuda', requires_grad=False)
        
        self.encoder = nn.Sequential(
            nn.Linear(n_genes, 120),
            nn.LeakyReLU(.2),
            nn.Linear(120, 120),
            nn.LeakyReLU(.2),
            nn.Linear(120, n_genes))

        self.mu_scaler = nn.Linear(n_genes, n_genes)
        self.logvar_scaler = nn.Linear(n_genes, n_genes)

        self.decoder = nn.Sequential(
            nn.Linear(n_genes, 120),
            nn.LeakyReLU(.2),
            nn.Linear(120, 120),
            nn.LeakyReLU(.2),
            nn.Linear(120, n_genes))
        

    def reparametrize(self, mu, log_var):
        std = torch.sqrt(torch.exp(log_var))
        eps = torch.randn_like(log_var)
        # return mu + std*eps
        return mu + log_var*eps

    def forward(self, x: torch.Tensor):
        x = self.encoder(x)
        mu = self.mu_scaler(x)
        log_var = self.logvar_scaler(x)

        # A = torch.eye(self.n_genes).to(device='cuda') - self.A.t()
        A = self.A
        mu = torch.matmul(mu, A)
        # print(mu[0])
        log_var = torch.matmul(log_var, A)

        z = self.reparametrize(mu, log_var)
        # A_inv = torch.linalg.inv(A)
        # z = torch.matmul(z, A_inv)

        x = self.decoder(z)
        # print(x[0])
        return x, mu, log_var

In [71]:
#TODO: dataloader = DataLoader(train_data, batch_size=self.opt.batch_size, shuffle=True, num_workers=1)
#TODO: how to account for batch effects        
train_data = data
batch_size = 10
n_epoch = 200
# torch data
train_data = torch.FloatTensor(train_data)
if use_gpu:
    train_data = train_data.cuda()

model = NN(n_genes=n_genes)
if use_gpu:
    model = model.cuda()
#TODO: loss function. cross entropy loss? other types  
criterion = lambda Y_true, Y_pred: torch.sum(torch.square(Y_true-Y_pred))
# optimizer
# optimizer = torch.optim.SGD(model.parameters(), lr=.00001, momentum=.9)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-7)

# scheduler = torch.lr_scheduler.StepLR(optimizer, step_size=10, gamma=.1) 
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, 
            patience=5, threshold_mode='rel', threshold=0.0001, cooldown=5, min_lr=1e-10, eps=1e-8)

pbar = tqdm.tqdm(range(n_epoch))
for i_epoch in pbar:
    # shuffle the data
    train_data = train_data[torch.randperm(train_data.size(dim=0))]
    # train for a epoch
    model.train()
    rel_loss_store = []
    Y_pred_stack = []
    Y_true_stack = []

    for X in np.array_split(train_data, batch_size, axis=0):
        X = X[:, grn_idx]

        optimizer.zero_grad()

        Y_true = X #TODO: add noise to the data
        Y_pred, mu, log_var = model.forward(X)

        Y_true_stack.append(Y_true.cpu().detach().numpy())
        Y_pred_stack.append(Y_pred.cpu().detach().numpy())
        # calculate loss and backpropagate 
        loss_x = criterion(Y_true, Y_pred)
        loss_KL =  - 0.5 * torch.sum(1.0 + log_var - mu.pow(2) - log_var.exp())

        beta = 1
        loss = loss_x + beta*loss_KL
        loss.backward()
        optimizer.step()
        # print(loss_x)

        # baseline pred
        Y_pred_mean = torch.mean(Y_true, axis=0)
        loss_baseline = criterion(Y_true, Y_pred_mean)
        rel_loss = loss_x/loss_baseline
        rel_loss_store.append(rel_loss.item())
    if i_epoch%10==0:
        # AUROC
        mask = ~np.eye(n_genes, dtype=bool)
        grn_pred = np.abs(model.A.cpu().data.numpy().T)
        print('AUROC', roc_auc_score(np.abs(grn_net[mask]), grn_pred[mask]))
        
    mean_rel_loss = np.mean(rel_loss_store)
    scheduler.step(mean_rel_loss)

    y_pred = np.concatenate(Y_pred_stack, axis=0)
    y_true = np.concatenate(Y_true_stack, axis=0)

    r2 = r2_score(y_true, y_pred, multioutput='variance_weighted')
    
    pbar.set_description(f'Rel loss: {mean_rel_loss:.3f}, R2:{r2:.3f}')


Rel loss: 1.162, R2:-0.156:   0%|          | 1/200 [00:04<15:04,  4.55s/it]

AUROC 0.524682849175731


Rel loss: 0.717, R2:0.287:   6%|▌         | 11/200 [00:12<05:17,  1.68s/it]

AUROC 0.524682849175731


Rel loss: 0.466, R2:0.536:  10%|█         | 21/200 [00:20<04:58,  1.67s/it]

AUROC 0.524682849175731


Rel loss: 0.392, R2:0.610:  16%|█▌        | 31/200 [00:29<04:54,  1.74s/it]

AUROC 0.524682849175731


Rel loss: 0.377, R2:0.624:  20%|██        | 41/200 [00:37<04:31,  1.71s/it]

AUROC 0.524682849175731


Rel loss: 0.363, R2:0.639:  26%|██▌       | 51/200 [00:47<04:36,  1.86s/it]

AUROC 0.524682849175731


Rel loss: 0.371, R2:0.630:  30%|███       | 61/200 [00:55<03:58,  1.72s/it]

AUROC 0.524682849175731


Rel loss: 0.403, R2:0.599:  36%|███▌      | 71/200 [01:04<03:44,  1.74s/it]

AUROC 0.524682849175731


Rel loss: 0.403, R2:0.599:  40%|████      | 80/200 [01:10<01:46,  1.13it/s]


KeyboardInterrupt: 

In [75]:
model.A
print('AUROC', roc_auc_score(np.abs(model.A.cpu().data.numpy().T[mask]), model.A.cpu().data.numpy()[mask]))

AUROC 0.5146390800748502


In [30]:
grn_net.shape

(3494, 3494)

In [25]:
model.A

Parameter containing:
tensor([[ 4.9868e-01, -7.2643e-04, -4.8510e-03,  ...,  2.6554e-03,
         -1.0545e-03, -8.5110e-03],
        [ 2.2719e-03,  5.1852e-01,  4.2199e-03,  ...,  2.1642e-03,
         -1.4096e-04, -1.9116e-03],
        [-2.2845e-04,  1.4502e-03,  5.1587e-01,  ..., -2.0177e-03,
         -6.1622e-04, -9.3648e-04],
        ...,
        [-1.7440e-03, -2.0144e-03,  1.7241e-03,  ...,  5.0684e-01,
         -5.1308e-04, -5.8981e-03],
        [ 1.2055e-03, -5.6728e-03, -1.6025e-03,  ..., -3.2979e-04,
          5.2474e-01, -9.3714e-04],
        [-2.4000e-03, -1.1062e-03, -2.2459e-03,  ..., -1.1234e-03,
          1.1746e-03,  5.1869e-01]], device='cuda:0', requires_grad=True)