In [None]:
import torch
from performer_pytorch import PerformerLM
import scanpy as sc
import anndata as ad
from utils import *
import pickle as pkl
from tqdm.notebook import tqdm
import numpy as np

In [2]:
# Vocabulary size handed to PerformerLM as `num_tokens`; scbert_embed caps
# expression values at CLASS - 2 before feeding them to the model.
CLASS = 7
# SEED and EPOCHS are not referenced by the cells visible in this notebook.
SEED = 2021
EPOCHS = 1
# 16906 matches the number of columns in the loaded matrix; the extra
# position is the single 0 token scbert_embed appends to every sequence.
SEQ_LEN = 16906 + 1
# Use the GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [3]:
# Location of the input AnnData file (the name suggests it is already preprocessed).
data_path = 'data/adata_preprocessed_4000.h5ad'
adata = sc.read_h5ad(data_path)
# Short alias for the expression matrix; later cells slice it row by row.
# Presumably rows are cells and columns are genes — confirm against the preprocessing step.
data = adata.X
# Quick sanity check: container type, then the matrix shape as the cell output.
print(type(data))
data.shape

<class 'scipy.sparse._csr.csr_matrix'>


(36763, 16906)

In [None]:
# Boolean mask of the gene positions for which we have a correspondence in our
# data, i.e. columns that are non-zero in at least one row.
# The comparison and the column-wise count are done on the matrix as stored
# (sparse or dense), so no dense copy of the full matrix is materialised.
nonzero_count_per_column = np.asarray((adata.X != 0).sum(axis=0)).ravel()
non_zero_columns = nonzero_count_per_column > 0
# Always keep the extra trailing position that scbert_embed appends to each sequence.
non_zero_columns = np.append(non_zero_columns, True)

In [5]:
# Instantiate the Performer language model with the hyper-parameters the
# checkpoint below is loaded into.
model = PerformerLM(
    num_tokens=CLASS,
    dim=200,
    depth=6,
    max_seq_len=SEQ_LEN,
    heads=10,
    local_attn_heads=0,
    g2v_position_emb=True,
)

# Read the checkpoint onto the CPU first; the model is moved to `device` afterwards.
path = 'model/scbert_pretrain.pth'
ckpt = torch.load(path, map_location='cpu')
model.load_state_dict(ckpt['model_state_dict'])

# Freeze every weight: the model is only used for inference in this notebook.
for param in model.parameters():
    param.requires_grad_(False)
model = model.to(device)

The boolean parameter 'some' has been replaced with a string parameter 'mode'.
Q, R = torch.qr(A, some)
should be replaced with
Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete') (Triggered internally at /pytorch/aten/src/ATen/native/BatchLinearAlgebra.cpp:2485.)
  q, r = torch.qr(unstructured_block.cpu(), some = True)


In [None]:
def scbert_embed(data):
    """Embed every row of ``data`` with the frozen model and return float16 encodings.

    Relies on the notebook-level globals ``model``, ``device``, ``CLASS`` and
    ``non_zero_columns``.

    Parameters
    ----------
    data : sparse matrix
        Row-indexable matrix whose rows support ``.toarray()`` (the notebook
        passes row slices of ``adata.X``). One row is embedded at a time.

    Returns
    -------
    numpy.ndarray
        Array of dtype float16 with shape
        ``(n_rows, number of True entries in non_zero_columns, model width)``.

    Raises
    ------
    ValueError
        If ``data`` has no rows.
    """
    n_cells = data.shape[0]
    if n_cells == 0:
        raise ValueError("scbert_embed needs at least one row to embed")

    model.eval()
    # Loop-invariant: convert the position mask to a tensor on the target device once.
    keep_mask = torch.as_tensor(non_zero_columns, dtype=torch.bool, device=device)
    # Single 0 token appended to the end of every sequence.
    extra_token = torch.zeros(1, dtype=torch.long)

    embeddings = None
    with torch.no_grad():
        for index in tqdm(range(n_cells)):
            # .toarray() returns a fresh dense row, so the in-place cap below
            # never touches the caller's matrix.
            full_seq = data[index].toarray()[0]
            full_seq[full_seq > (CLASS - 2)] = CLASS - 2
            full_seq = torch.from_numpy(full_seq).long()
            full_seq = torch.cat((full_seq, extra_token)).to(device)
            full_seq = full_seq.unsqueeze(0)
            cell_embedding = model(full_seq, return_encodings = True, output_attentions = False)
            # Drop the batch axis using the actual output shape instead of a
            # hard-coded (sequence length, width) pair.
            cell_embedding = cell_embedding.squeeze(0)
            cell_embedding = cell_embedding[keep_mask]
            # Move off the device and down-cast right away so only the compact
            # float16 result accumulates in host memory.
            cell_array = cell_embedding.cpu().numpy().astype(np.float16)
            if embeddings is None:
                # Output shape is only known after the first forward pass.
                embeddings = np.empty((n_cells,) + cell_array.shape, dtype=np.float16)
            embeddings[index] = cell_array
    return embeddings

def get_cuts(N, k):
    """Return the k - 1 interior indices that split ``range(N)`` into k parts.

    Each cut is ``(part * N) // k`` for ``part`` in ``1 .. k - 1``; the list is
    empty when ``k`` is 1 or less.
    """
    cuts = []
    for part in range(1, k):
        cuts.append((part * N) // k)
    return cuts

In [8]:
# Number of row batches the matrix is split into before embedding.
number_of_batch = 10
# Take the row count from the loaded matrix rather than a hard-coded literal,
# so the cut points stay correct if the input file changes.
batch_cut = get_cuts(data.shape[0], number_of_batch)
batch_cut

[3676, 7352, 11028, 14705, 18381, 22057, 25734, 29410, 33086]

In [9]:
# Embed the matrix one row range at a time and save each result to its own
# .npy file. Adding 0 and the total row count around the interior cut points
# gives consecutive [start, stop) ranges, which also works when there is a
# single batch (empty cut list).
batch_bounds = [0, *batch_cut, data.shape[0]]
for i, (start, stop) in enumerate(zip(batch_bounds[:-1], batch_bounds[1:])):
    print(i)
    batch_data = data[start:stop]
    print(f'batch : {start}-{stop}')
    embeddings = scbert_embed(batch_data)
    print(embeddings.shape)
    path = 'data/embeddings_' + str(i) + '.npy'
    np.save(path, embeddings)
      

0
batch : 0-3676


  0%|          | 0/3676 [00:00<?, ?it/s]

(3676, 2455, 200)
1
batch : 3676-7352


  0%|          | 0/3676 [00:00<?, ?it/s]

(3676, 2455, 200)
2
batch : 7352-11028


  0%|          | 0/3676 [00:00<?, ?it/s]

KeyboardInterrupt: 