In [1]:
import torch
from performer_pytorch import PerformerLM
import scanpy as sc
import anndata as ad
from utils import *
import pickle as pkl
from tqdm.notebook import tqdm
import zarr

In [2]:
# Notebook-wide settings.
# CLASS is handed to PerformerLM as `num_tokens`; scbert_embed clips values at CLASS - 2.
CLASS = 7
# NOTE(review): SEED and EPOCHS are not referenced by any cell visible in this notebook.
SEED = 2021
EPOCHS = 1
# Handed to PerformerLM as `max_seq_len`.
# Presumably 16906 positions plus the one token scbert_embed appends -- TODO confirm.
SEQ_LEN = 16906 + 1

# Use the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

In [3]:
# Read the AnnData file from disk and keep its X matrix for the embedding cells below.
data_path = 'data/low_feature.h5ad'
adata = sc.read_h5ad(data_path)
# `data` is sliced by rows and densified row by row (`.toarray()`) inside scbert_embed.
data = adata.X
# Sanity check: show the container type, then the matrix dimensions as the cell output.
print(type(data))
data.shape

<class 'scipy.sparse._csr.csr_matrix'>


(36712, 400)

In [7]:
# Build the Performer language model used as a frozen feature extractor.
# NOTE(review): these hyperparameters presumably have to match the checkpoint
# loaded below, otherwise load_state_dict would reject it -- confirm before editing.
model = PerformerLM(
    num_tokens=CLASS,
    dim=200,
    depth=6,
    max_seq_len=SEQ_LEN,
    heads=10,
    local_attn_heads=0,
    g2v_position_emb=True,
)

# Restore the pretrained weights. The checkpoint is mapped onto the CPU first;
# the model is moved to `device` only after the weights are in place.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
path = 'model/scbert_pretrain.pth'
ckpt = torch.load(path, map_location='cpu')
model.load_state_dict(ckpt['model_state_dict'])

# Freeze every weight: the model is only used for inference in this notebook.
for param in model.parameters():
    param.requires_grad_(False)

model = model.to(device)

In [8]:
def scbert_embed(data, flush_every=500):
    """Encode every row of ``data`` with the notebook-level ``model``.

    Each row is densified, values above ``CLASS - 2`` are clipped to
    ``CLASS - 2``, a trailing ``0`` token is appended, and the sequence is
    run through ``model(..., return_encodings=True)`` as a batch of one.

    Parameters
    ----------
    data : row-indexable matrix
        ``data[i]`` must expose ``.toarray()`` (the notebook passes row slices
        of ``adata.X``) and ``data.shape[0]`` must be the number of rows.
    flush_every : int, optional
        How many encodings are allowed to accumulate on ``device`` before they
        are copied to host memory. Values below 1 are treated as 1.

    Returns
    -------
    numpy.ndarray
        float16 array whose first axis indexes the rows of ``data``; the
        remaining axes are whatever the model returns for one sequence
        (the recorded run produced shape ``(12237, 1, 401, 200)``).

    Raises
    ------
    ValueError
        If ``data`` has no rows.

    Notes
    -----
    Relies on the notebook globals ``model``, ``device``, ``CLASS`` and
    ``tqdm``. Calling it switches ``model`` to eval mode.
    """
    # `np` otherwise appears to reach this notebook only through the
    # `from utils import *` wildcard; import it here so the function does not
    # depend on that.
    import numpy as np

    n_rows = data.shape[0]
    if n_rows == 0:
        # Same exception type np.stack would raise on an empty list, but with
        # a message that names the actual problem.
        raise ValueError("scbert_embed: `data` has no rows to embed")

    flush_every = max(1, int(flush_every))
    max_token = CLASS - 2

    model.eval()
    embeddings = None   # allocated once the shape of one encoding is known
    pending = []        # encodings still living on `device`
    filled = 0          # rows of `embeddings` already written

    with torch.no_grad():
        for index in tqdm(range(n_rows)):
            # .toarray() returns a fresh dense copy, so clipping in place does
            # not touch the caller's matrix.
            tokens = data[index].toarray()[0]
            tokens[tokens > max_token] = max_token
            tokens = torch.from_numpy(tokens).long()
            tokens = torch.cat((tokens, torch.tensor([0]))).to(device)
            tokens = tokens.unsqueeze(0)
            pending.append(model(tokens, return_encodings=True))

            # Periodically drain the device buffer (and always after the last
            # row). Converting to float16 here keeps only one half-precision
            # copy on the host instead of every float32 tensor plus a second
            # stacked copy at the end.
            if len(pending) >= flush_every or index == n_rows - 1:
                host = np.stack(
                    [enc.cpu().numpy().astype(np.float16) for enc in pending]
                )
                if embeddings is None:
                    embeddings = np.empty(
                        (n_rows,) + host.shape[1:], dtype=np.float16
                    )
                embeddings[filled:filled + len(host)] = host
                filled += len(host)
                pending = []

    return embeddings

def get_cuts(N, k):
    """Return the ``k - 1`` interior split points dividing ``N`` items into ``k`` parts.

    Split point ``i`` (for ``1 <= i < k``) is ``i * N // k``. The list is
    empty when ``k`` is 1 or smaller.
    """
    cuts = []
    for part in range(1, k):
        cuts.append(part * N // k)
    return cuts

In [None]:
# Embed the dataset in `number_of_batch` contiguous row ranges and save each
# range to its own .npy file, so a single result array never has to hold the
# whole dataset.
number_of_batch = 3

# Derive the row count from the loaded matrix instead of hard-coding it, so
# the cuts stay correct if a different file is loaded.
n_rows = data.shape[0]
batch_cut = get_cuts(n_rows, number_of_batch)

# Row boundaries of every batch: batch i covers bounds[i]:bounds[i + 1].
# Building the list once also covers number_of_batch == 1, where `batch_cut`
# is empty and indexing it directly would raise IndexError.
bounds = [0] + batch_cut + [n_rows]

for i in range(number_of_batch):
    start, stop = bounds[i], bounds[i + 1]
    print(i)
    batch_data = data[start:stop]
    print(f'batch : {start}-{stop}')
    embeddings = scbert_embed(batch_data)
    print(embeddings.shape)
    path = 'data/embeddings_' + str(i) + '.npy'
    np.save(path, embeddings)
      

0
batch : 0-12237


  0%|          | 0/12237 [00:00<?, ?it/s]

(12237, 1, 401, 200)
1
batch : 12237-24474


  0%|          | 0/12237 [00:00<?, ?it/s]

(12237, 1, 401, 200)
2
batch : 24474-36712


  0%|          | 0/12238 [00:00<?, ?it/s]