In [11]:
import torch
from performer_pytorch import PerformerLM
import scanpy as sc
import anndata as ad
from utils import *
import pickle as pkl
from tqdm.notebook import tqdm

In [12]:
# Notebook-wide configuration.
CLASS = 7             # passed to PerformerLM as num_tokens; scbert_embed clips values to CLASS - 2
SEED = 2021           # random seed, applied to torch below
EPOCHS = 1            # not referenced by the cells visible in this notebook
SEQ_LEN = 16906 + 1   # passed to PerformerLM as max_seq_len (the +1 presumably covers the appended token)

# Apply the seed so that any randomly initialised weights are reproducible
# across kernel restarts (previously SEED was declared but never used).
torch.manual_seed(SEED)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [13]:
# Location of the AnnData (.h5ad) file to read.
data_path = 'data/test.h5ad'
adata = sc.read_h5ad(data_path)
# The 'log1p_norm' layer (not adata.X) is the matrix handed to scbert_embed later on.
data = adata.layers['log1p_norm']
# Sanity check: in the recorded run this printed a SciPy CSR matrix type
# and the cell displayed a shape of (2000, 2000).
print(type(data))
data.shape

<class 'scipy.sparse._csr.csr_matrix'>


(2000, 2000)

In [4]:
# Boolean mask with one entry per column of adata.X (True where the column has
# at least one non-zero value), followed by a single trailing True.
# NOTE(review): the extra True presumably lines up with the extra token that
# scbert_embed appends to every sequence - confirm. The mask is built from
# adata.X, whereas the embedded matrix is adata.layers['log1p_norm'].
non_zero_columns = np.append(
    (adata.X.toarray() != 0).any(axis=0),
    True,
)

In [14]:
#without gene2vec
model = PerformerLM(
    num_tokens = CLASS,
    dim = 200,
    depth = 6,
    max_seq_len = SEQ_LEN,
    heads = 10,
    local_attn_heads = 0,
    g2v_position_emb = False
)

path = 'model/scbert_pretrain.pth'
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
ckpt = torch.load(path, map_location='cpu')
checkpoint_state = ckpt['model_state_dict']

# Build the set of parameter names once instead of calling model.state_dict()
# for every checkpoint key.
expected_keys = set(model.state_dict())

# Keep only the checkpoint entries that this model actually has.
model_state_dict = {k: v for k, v in checkpoint_state.items()
                    if k in expected_keys}
skipped_keys = sorted(set(checkpoint_state) - expected_keys)

# If nothing matches, the model would silently keep its random initialisation,
# which almost certainly means the wrong checkpoint or architecture was used.
if not model_state_dict:
    raise RuntimeError(
        f"No parameter in '{path}' matches the model; refusing to continue "
        "with randomly initialised weights."
    )

# strict=False tolerates parameters that are absent from the checkpoint; the
# returned result lists them so they can be reported instead of ignored.
load_result = model.load_state_dict(model_state_dict, strict=False)
if skipped_keys:
    print(f"Ignored {len(skipped_keys)} checkpoint key(s) not present in the model: {skipped_keys}")
if load_result.missing_keys:
    print(f"{len(load_result.missing_keys)} model parameter(s) not found in the checkpoint "
          f"(left at their initial values): {list(load_result.missing_keys)}")

# The model is used for inference only: freeze every parameter.
for param in model.parameters():
    param.requires_grad = False
model = model.to(device)

In [26]:
def scbert_embed(data, flush_every=500):
    """Run every row of ``data`` through the module-level ``model`` and collect the encodings.

    Each row is densified, values above ``CLASS - 2`` are clipped to
    ``CLASS - 2``, the row is cast to integer tokens, and a single ``0`` token
    is appended before the sequence is passed to ``model`` with
    ``return_encodings=True``.

    Args:
        data: 2-D sparse matrix whose rows support ``data[i].toarray()``.
        flush_every: how many per-row results may accumulate on ``device``
            before they are moved to host memory (default 500).

    Returns:
        A float16 ``np.ndarray`` with one entry per row of ``data``, stacked
        along a new leading axis. In the recorded run a (2000, 2000) input
        produced an array of shape (2000, 2001, 200).

    Raises:
        ValueError: if ``data`` has no rows or ``flush_every`` is below 1.
    """
    n_rows = data.shape[0]
    if n_rows == 0:
        raise ValueError("scbert_embed received no rows to embed")
    if flush_every < 1:
        raise ValueError("flush_every must be at least 1")

    model.eval()
    extra_token = torch.tensor([0])  # appended to the end of every sequence
    on_device = []                   # results still on `device`
    on_host = []                     # float16 numpy arrays already moved to host

    def _flush():
        # Move pending results off the device and downcast immediately, so the
        # host never holds the complete float32 result set.
        on_host.extend(t.cpu().numpy().astype(np.float16) for t in on_device)
        on_device.clear()

    with torch.no_grad():
        for index in tqdm(range(n_rows)):
            full_seq = data[index].toarray()[0]
            full_seq[full_seq > (CLASS - 2)] = CLASS - 2
            full_seq = torch.from_numpy(full_seq).long()
            full_seq = torch.cat((full_seq, extra_token)).to(device)
            full_seq = full_seq.unsqueeze(0)
            cell_embedding = model(full_seq, return_encodings = True, output_attentions = False)
            # Drop the singleton batch axis added by unsqueeze(0) rather than
            # reshaping to a hard-coded [2001, 200], so other input widths and
            # model dims work too.
            # Column filtering with non_zero_columns is deliberately not applied here.
            on_device.append(cell_embedding.squeeze(0))
            if len(on_device) >= flush_every:
                #regularly empty GPU of data
                _flush()
        _flush()
    return np.stack(on_host)

def get_cuts(N, k):
    """Return the ``k - 1`` interior cut positions that divide ``N`` items into ``k`` parts.

    Cut number ``step`` (for ``step`` in 1..k-1) is ``(step * N) // k``. The
    result is an empty list when ``k`` is 1 or smaller.
    """
    cuts = []
    for step in range(1, k):
        cuts.append((step * N) // k)
    return cuts

In [27]:
number_of_batch = 1
# Cut positions are derived from the actual number of rows instead of a
# hard-coded 2000. get_cuts returns an empty list for a single batch, so fall
# back to one cut at the end of the data, which is what the batch loop expects
# to find at batch_cut[0].
batch_cut = get_cuts(data.shape[0], number_of_batch) or [data.shape[0]]

In [28]:
# Embed the data batch by batch and save each batch to its own .npy file.
# The [start, stop) ranges are derived from batch_cut: 0 and the row count are
# always used as the outer edges, and only cuts strictly inside that range are
# kept. This works whether batch_cut holds just the interior cuts or also the
# final row count, and it covers the single-batch case without special handling.
n_rows_total = data.shape[0]
interior_cuts = sorted({cut for cut in batch_cut if 0 < cut < n_rows_total})
batch_edges = [0] + interior_cuts + [n_rows_total]

for i, (start, stop) in enumerate(zip(batch_edges[:-1], batch_edges[1:])):
    print(i)
    batch_data = data[start:stop]
    print(f'batch : {start}-{stop}')
    embeddings = scbert_embed(batch_data)
    print(embeddings.shape)
    # One output file per batch, numbered by batch index.
    path = 'data/embeddings_' + str(i) + '.npy'
    np.save(path, embeddings)
      

0
batch : 0-2000


  0%|          | 0/2000 [00:00<?, ?it/s]

(2000, 2001, 200)
