In [None]:
import random
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import anndata
import time
from tqdm import trange
import sys
sys.path.append("../method/fastscbatch/")
import fast_scBatch

# Seed every RNG source the notebook may draw from (Python, NumPy, PyTorch)
# with the same value so that reruns are repeatable.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
# Tensors and models below are moved to this device; the benchmark runs on CPU.
device = torch.device("cpu")

In [None]:
class Network(nn.Module):
	"""Affine correction model used by the scBatch baseline: ``Y = X @ alpha + beta``.

	``alpha`` is an ``(n_cells, n_cells)`` matrix initialised to the identity, and
	``beta`` is an ``(n_genes, n_cells)`` offset initialised to zeros. At
	initialisation the model is therefore the identity map on ``X``.

	Args:
		n_cells: Number of columns of ``X``. Defaults to the notebook global ``n``
			(set from ``cells.shape`` in the timing loop) so that ``Network()``
			keeps working as before.
		n_genes: Number of rows of ``X``. Defaults to the notebook global ``p``.
	"""
	def __init__(self, n_cells=None, n_genes=None) -> None:
		super().__init__()
		# Fall back to the notebook globals only when no explicit size is given,
		# so the class can also be constructed independently of hidden state.
		if n_cells is None:
			n_cells = n
		if n_genes is None:
			n_genes = p
		self.alpha = nn.Parameter(torch.eye(n_cells))
		self.beta = nn.Parameter(torch.zeros(n_genes, n_cells))
	def forward(self, X):
		"""Return ``X @ alpha + beta``. ``X`` is expected to be ``(n_genes, n_cells)``."""
		return torch.matmul(X, self.alpha) + self.beta
def loss(Y, std):
	"""Frobenius distance between the column-wise correlation of ``Y`` and ``std``.

	``torch.corrcoef`` treats rows as variables, so transposing ``Y`` yields the
	correlation matrix between the columns of ``Y``. ``std`` is the target matrix
	it is compared against (the precomputed correlation ``D`` in this notebook).
	"""
	residual = torch.corrcoef(Y.T) - std
	return torch.norm(residual, p='fro')
# x-axis values for the efficiency plot: the number of cells in
# ./sample/sample1.h5ad ... sample12.h5ad, in that order.
# NOTE(review): this is hard-coded; it presumably matches ``n = cells.shape[1]``
# for each sample. Confirm it if the sample files are regenerated.
n_set = [
	240, 360, 600, 840, 960, 1440,
	2160, 3000, 4020, 5100, 6000, 7020,
]

In [None]:
# Time the plain scBatch-style optimisation (300 Adam steps) on each sample.
# File loading happens before the timer starts, so I/O is excluded from the timing.
time1 = []
for i in trange(1, 13):
	cell = anndata.read_h5ad(f"./sample/sample{i}.h5ad")
	batch = cell.obs[["Batch"]].copy()
	cells = cell.to_df().T
	corr = pd.read_csv(f"../method/fastqn/sample/sample{i}.csv", index_col=0)
	# perf_counter is monotonic and high-resolution, so it is the right clock for intervals.
	time_s = time.perf_counter()
	corr.columns = cells.columns
	corr.index = cells.columns
	# ``p`` and ``n`` are read as globals by ``Network()``, so they must be set here.
	p, n = cells.shape
	cells = cells.values
	corr = corr.values
	X = torch.from_numpy(cells).float().to(device)
	D = torch.from_numpy(corr).float().to(device)

	model = Network().to(device)
	optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)
	EPOCHS = 300
	losses = []
	for epoch in range(EPOCHS):
		optimizer.zero_grad()
		Y = model(X)
		loss_val = loss(Y, D)
		losses.append(loss_val.item())
		loss_val.backward()
		optimizer.step()
	model.eval()
	# The final pass is inference only, so skip building an autograd graph.
	with torch.no_grad():
		Y = model(X)
	adata = anndata.AnnData(X=Y.cpu().numpy().T, obs=cell.obs, var=cell.var)
	time_t = time.perf_counter()
	time1.append(time_t - time_s)

In [None]:
# Time fast_scBatch.solver on the same samples. File loading happens before
# the timer starts, matching the scBatch loop above.
time2 = []
for i in trange(1, 13):
	# read_h5ad is the supported reader; ``anndata.read`` is a deprecated alias.
	cell = anndata.read_h5ad(f"./sample/sample{i}.h5ad")
	batch = cell.obs[["Batch"]].copy()
	cells = cell.to_df().T
	corr = pd.read_csv(f"../method/fastqn/sample/sample{i}.csv", index_col=0)
	time_s = time.perf_counter()
	corr.columns = cells.columns
	corr.index = cells.columns
	p, n = cells.shape
	res = fast_scBatch.solver(cells, corr, batch, p=0.3, k=20,
		lr=(0.0002, 0.0001, 0.0001), EPOCHS=(0, 0, 300), verbose=False)
	# NOTE(review): assumes ``res`` is oriented like ``cells`` (genes x cells) and
	# supports ``.T``; confirm this against fast_scBatch.solver.
	adata = anndata.AnnData(X=res.T, obs=cell.obs, var=cell.var)
	time_t = time.perf_counter()
	time2.append(time_t - time_s)

In [None]:
# Running time of both methods against the sample size, saved to disk.
fig, ax = plt.subplots()
# Markers make the individual (unevenly spaced) sample sizes visible.
ax.plot(n_set, time1, marker="o", label="scBatch")
ax.plot(n_set, time2, marker="o", label="fast-scBatch")
ax.set_xlabel("Number of cells")
ax.set_ylabel("Running time (s)")
ax.set_title("Running time vs. number of cells")
ax.legend()
# NOTE(review): a fixed 300 s ceiling silently clips any run slower than that;
# confirm this cap is intended.
ax.set_ylim(0, 300)
fig.savefig("simu_efficiency.png", bbox_inches="tight")