In [253]:
import tqdm

import deconV.deconV as dv

import pandas as pd
import numpy as np
import scanpy as sc

import dgl
import torch
import torch.nn as nn

from sklearn.model_selection import train_test_split

In [5]:
adata = sc.read("../../data/GSE136148/adata.h5ad")
adata

with open("../../data/reactome/genes.csv", "r") as f:
    genes = f.read().splitlines()

common_genes = list(set(adata.var.index.tolist()) & set(genes))
missing = list(set(genes) - set(adata.var.index.tolist()))
adata = adata[:, common_genes].copy()
adata


AnnData object with n_obs × n_vars = 2230 × 8424
    obs: 'cell_type', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt'
    var: 'n_cells', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts'
    uns: 'log1p'
    layers: 'centered', 'counts', 'logcentered', 'ncounts'

In [6]:
mapping = dict([(g, i) for i, g in enumerate(genes)])

In [7]:
idxs = [mapping[g] for g in adata.var.index]
n_genes = len(genes)
len(idxs)

8424

In [141]:
def get_signature(adata, groupby, n_genes):
    n_ct = len(adata.obs["cell_type"].cat.categories)

    X = torch.zeros((n_genes, 5, n_ct))
    idxs = [mapping[g] for g in adata.var.index]

    for ct in tqdm.tqdm(adata.obs["cell_type"].cat.categories):
        ct_idx = adata.obs["cell_type"] == ct
        X[idxs, 0, ct] = torch.tensor(adata[ct_idx, :].X.mean(0))
        X[idxs, 1, ct] = torch.tensor(adata[ct_idx, :].layers["centered"].mean(0))
        X[idxs, 2, ct] = torch.tensor(adata[ct_idx, :].layers["counts"].mean(0))
        X[idxs, 3, ct] = torch.tensor(adata[ct_idx, :].layers["logcentered"].mean(0))
        X[idxs, 4, ct] = torch.tensor(adata[ct_idx, :].layers["ncounts"].mean(0))

    return X.reshape(n_genes, -1)

In [142]:
true_df = pd.read_csv("../../data/synthetic100/bulk_proportions.csv", index_col=0)
true_df.drop(columns="n_cells", inplace=True)
true_df


Unnamed: 0,0,1,2
0,0.511972,0.152481,0.335547
1,0.307759,0.523730,0.168511
2,0.305641,0.194240,0.500119
3,0.266643,0.165227,0.568130
4,0.200115,0.453825,0.346060
...,...,...,...
95,0.746312,0.038374,0.215314
96,0.204091,0.036037,0.759872
97,0.017337,0.942209,0.040453
98,0.137069,0.474475,0.388456


In [266]:
bulk = pd.read_csv("../../data/synthetic100/bulk.csv", index_col=0)
bulk = pd.concat([bulk, pd.DataFrame(np.zeros((len(missing), bulk.shape[1])), index=missing, columns=bulk.columns)])
bulk = bulk.loc[genes]
bulk = bulk.T.reset_index(drop=True)
bulk

Unnamed: 0,BANF1,HMGA1,LIG4,PSIP1,XRCC4,XRCC5,XRCC6,PRPS1,PRPS1L1,PRPS2,...,STK3,STK4,THBS3,THBS4,TMED5,VPS26A,VPS29,VPS35,WLS,ZFYVE16
0,4543918.0,16054073.0,163339.0,3664115.0,340191.0,5320793.0,6807409.0,1577103.0,0.0,1336048.0,...,755661.0,1022285.0,136108.0,7104.0,2046038.0,2557282.0,8122577.0,5103096.0,895479.0,700080.0
1,3472092.0,9597301.0,112444.0,1982997.0,249125.0,3775303.0,4142645.0,819419.0,0.0,846817.0,...,560754.0,917381.0,145114.0,8374.0,1232796.0,1619654.0,5765895.0,2894347.0,439918.0,548281.0
2,5096191.0,20248248.0,172953.0,3750117.0,498850.0,5896177.0,7247867.0,2256510.0,0.0,1441552.0,...,890154.0,1269905.0,239443.0,5888.0,2401091.0,2932931.0,9609277.0,5686886.0,919086.0,946693.0
3,3104498.0,14401305.0,129482.0,2340656.0,302953.0,3612230.0,4355471.0,1652735.0,0.0,868127.0,...,511321.0,747909.0,158023.0,6895.0,1586064.0,1894975.0,5918321.0,3472058.0,632998.0,615687.0
4,5802625.0,19882051.0,216582.0,3454254.0,509444.0,6820993.0,6714497.0,2207612.0,0.0,1459334.0,...,998764.0,1674790.0,284143.0,9165.0,2470133.0,2911018.0,9942873.0,5284853.0,785499.0,1133948.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,1095115.0,3723335.0,27197.0,960064.0,94932.0,1309083.0,1860808.0,326180.0,0.0,349419.0,...,180937.0,219071.0,23031.0,1462.0,551314.0,640520.0,2071134.0,1263977.0,247195.0,142283.0
96,3661369.0,21795983.0,186590.0,3379433.0,412840.0,5220792.0,6076406.0,3040429.0,0.0,1059917.0,...,713126.0,1007321.0,217381.0,12762.0,2315906.0,2545502.0,8177620.0,4949774.0,1022370.0,843320.0
97,2853413.0,5323010.0,59009.0,866883.0,184975.0,2624184.0,2008274.0,470810.0,0.0,627881.0,...,422993.0,803561.0,78450.0,6367.0,615545.0,1038002.0,3842976.0,1537567.0,172398.0,454812.0
98,8083310.0,24360945.0,250666.0,4018386.0,637560.0,8744111.0,7866163.0,2766422.0,0.0,1748087.0,...,1194782.0,2108716.0,336464.0,22263.0,2891537.0,3686242.0,12767782.0,6195397.0,964434.0,1447131.0


In [271]:
X_train, X_val = train_test_split(bulk, test_size=0.2)
Y_train = true_df.loc[X_train.index]
Y_val = true_df.loc[X_val.index]

In [285]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, X, Y) -> None:
        super().__init__()
        self.x = torch.tensor(X.values, dtype=torch.float32)
        self.y = torch.tensor(Y.values, dtype=torch.float32)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        # X = torch.cat((signature.T, self.bulk[idx].reshape(1, -1)), dim=0).T
        return self.x[idx], self.y[idx].reshape(-1)


In [278]:
signature = get_signature(adata, "cell_type", len(genes))

G = dgl.data.utils.load_graphs("../../data/reactome/graph.bin")[0][0]
G.ndata["signature"] = signature
G.ndata["bulk"] = torch.zeros((len(genes), 1))

100%|██████████| 3/3 [00:00<00:00, 15.09it/s]


In [286]:
train_dataset = CustomDataset(X_train, Y_train)
val_dataset = CustomDataset(X_val, Y_val)
train_loader  = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader  = torch.utils.data.DataLoader(val_dataset, batch_size=32, shuffle=True)

In [288]:
X, y = next(iter(val_loader))
X.shape, y.shape

(torch.Size([20, 11457]), torch.Size([20, 3]))

In [299]:
X[0, :].shape

torch.Size([11457])

In [300]:
model(X[0, :])

tensor([[1.2533e+10, 0.0000e+00, 0.0000e+00]], grad_fn=<SegmentReduceBackward>)

In [236]:
torch.cat([G.ndata["signature"], G.ndata["bulk"].reshape(-1, 1)], dim=1)

tensor([[1.0094, 0.7759, 0.4785,  ..., 1.5392, 0.9998, 0.0000],
        [1.9458, 1.1331, 2.2040,  ..., 2.6209, 9.9814, 0.0000],
        [0.0491, 0.0286, 0.0330,  ..., 0.0404, 0.0556, 0.0000],
        ...,
        [1.1726, 0.4728, 0.6989,  ..., 0.7961, 1.5101, 0.0000],
        [0.3004, 0.0455, 0.1233,  ..., 0.0667, 0.2427, 0.0000],
        [0.1687, 0.1870, 0.2301,  ..., 0.2806, 0.4266, 0.0000]])

In [248]:
class DeCoNNV(nn.Module):
    def __init__(self, g, n_ct):
        super().__init__()
        self.g = g
        self.in_features = g.ndata["signature"].shape[1] + 1
        self.n_ct = n_ct

        self.activation = nn.ReLU()

        self.gcl1 = dgl.nn.GraphConv(in_feats=self.in_features, out_feats=64)
        self.gcl2 = dgl.nn.GraphConv(in_feats=64, out_feats=64)
        self.gcl3 = dgl.nn.GraphConv(in_feats=64, out_feats=self.n_ct)

    def forward(self, x):
        x = torch.cat([self.g.ndata["signature"], x.reshape(-1, 1)], dim=1)
        x = self.activation(self.gcl1(self.g, x))
        x = self.activation(self.gcl2(self.g, x))
        x = self.activation(self.gcl3(self.g, x))

        with self.g.local_scope():
            self.g.ndata["h"] = x
            return dgl.sum_nodes(self.g, "h")


In [293]:
model = DeCoNNV(G, n_ct=3)
loss_fn = torch.nn.MSELoss(reduction="sum")
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

L = []

for epoch in tqdm.tqdm(range(100)):
    model.train()
    L = 0.0
    for x, y in train_loader:
        optimizer.zero_grad()

        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        L += loss.item()
        
        loss.backward()
        optimizer.step()    

    

  0%|          | 0/100 [00:00<?, ?it/s]


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 11457 but got size 366624 for tensor number 1 in the list.

In [250]:
model(X)

tensor([[2.4660e+09, 1.3710e+09, 1.5161e+09]], grad_fn=<SegmentReduceBackward>)