In [1]:
import sys
import time
import tracemalloc
from pathlib import Path

import cupy as cp
import cupyx
import numpy as np
import pandas as pd
import scanpy as sc
import scipy
from icecream import ic
from tqdm import tqdm

from SEACells.core_copy import SEACells

findfont: Font family ['Raleway'] not found. Falling back to DejaVu Sans.
findfont: Font family ['Lato'] not found. Falling back to DejaVu Sans.


In [2]:
def get_data(ad, num_cells, use_gpu, use_sparse, A_init=None, B_init=None, K_init=None, l1_penalty=0):
    """Fit a SEACells model on ``ad`` and return the results with timing/memory stats.

    Parameters
    ----------
    ad
        Data object handed to ``SEACells``; the kernel is built on the
        ``"X_pca"`` key of ``ad.obsm`` (this would be ``"X_svd"`` for ATAC data).
    num_cells : int
        Number of cells being benchmarked; the model is asked for
        ``num_cells // 75`` SEACells.
    use_gpu, use_sparse
        Forwarded unchanged to the ``SEACells`` constructor.
    A_init, B_init : optional
        When not ``None``, assigned to ``model.A_`` / ``model.B_`` before ``fit``.
    K_init : optional
        Precomputed kernel matrix. When ``None`` the model constructs its own
        kernel; otherwise ``K_init`` is registered through
        ``add_precomputed_kernel_matrix``.
    l1_penalty
        Stored on the model as ``model.l1_penalty``.

    Returns
    -------
    tuple
        ``(assignments, tot_time, mem, A, B, K, sparsity)`` where ``tot_time``
        is the duration of ``model.fit`` in seconds, ``mem`` is the
        ``(current, peak)`` byte pair reported by
        ``tracemalloc.get_traced_memory()``, and the remaining items are read
        from the fitted model.

    Notes
    -----
    Any exception raised by ``model.fit`` propagates to the caller, but
    tracing is always switched off again first so a failed run does not leave
    ``tracemalloc`` active for the rest of the session.
    """
    ## Core parameters
    # number of SEACells
    n_SEACells = num_cells // 75
    # key in ad.obsm to use for computing metacells
    # (this would be replaced by 'X_svd' for ATAC data)
    build_kernel_on = "X_pca"

    ## Additional parameters
    # Number of eigenvalues to consider when initializing metacells
    n_waypoint_eigs = 10

    model = SEACells(
        ad,
        use_gpu=use_gpu,
        use_sparse=use_sparse,
        build_kernel_on=build_kernel_on,
        n_SEACells=n_SEACells,
        n_waypoint_eigs=n_waypoint_eigs,
        convergence_epsilon=1e-5,
    )
    model.l1_penalty = l1_penalty

    if K_init is None:
        model.construct_kernel_matrix()
    else:
        model.add_precomputed_kernel_matrix(K_init)

    if A_init is not None:
        model.A_ = A_init
    if B_init is not None:
        model.B_ = B_init

    # perf_counter is monotonic, so the measured duration cannot be distorted
    # by system clock adjustments the way time.time() can.
    start = time.perf_counter()
    # NOTE(review): tracemalloc traces allocations made through Python's
    # allocators; GPU device memory is presumably not reflected in `mem`.
    tracemalloc.start()
    try:
        model.fit(min_iter=10, max_iter=150)
        tot_time = time.perf_counter() - start
        mem = tracemalloc.get_traced_memory()
    finally:
        # Always stop tracing, even when fit() raises.
        tracemalloc.stop()

    assignments = model.get_hard_assignments()

    # Get the final A and B matrices, the kernel and the sparsity dataframe
    A = model.A_
    B = model.B_
    K = model.kernel_matrix
    sparsity = model.sparsity_ratios

    return assignments, tot_time, mem, A, B, K, sparsity

In [3]:
def gpu_versions(ad, num_cells):
    """Run one GPU + sparse SEACells benchmark and write its results to disk.

    The run is delegated to ``get_data(ad, num_cells=num_cells, use_gpu=True,
    use_sparse=True)``. Results are written to
    ``results14-files_in_copy/{num_cells}_cells/`` with a ``time.time()``
    timestamp in every file name:

    * ``v4_{timestamp}.txt``             - fit time and traced memory
    * ``assignments_v4_{timestamp}.csv`` - hard assignments (skipped if ``None``)
    * ``A_v4_{timestamp}.npy`` / ``B_v4_{timestamp}.npy`` - final A and B matrices
    * ``sparsity_v4_{timestamp}.csv``    - sparsity dataframe

    Parameters
    ----------
    ad
        Data object forwarded to ``get_data``.
    num_cells : int
        Number of cells in the run; also selects the output sub-directory.

    Returns
    -------
    None

    Notes
    -----
    The CuPy default device and pinned memory pools are asked to free their
    blocks before the run and again afterwards; the second call now happens
    even if the run or the writing step raises.
    """
    out_dir = Path(f"results14-files_in_copy/{num_cells}_cells")
    # Create the destination up front: discovering a missing directory only
    # after the (long) fit would throw the whole run away.
    out_dir.mkdir(parents=True, exist_ok=True)

    # Clear the GPU memory
    cp.get_default_memory_pool().free_all_blocks()
    cp.get_default_pinned_memory_pool().free_all_blocks()

    try:
        assignments4, time4, mem4, A, B, _K, sparsity = get_data(
            ad,
            num_cells=num_cells,
            use_gpu=True,
            use_sparse=True,
        )

        # Timestamp (as a number) shared by all files written for this run
        timestamp = time.time()

        # Write the time and memory data
        with open(out_dir / f"v4_{timestamp}.txt", "w") as f:
            f.write(f"Time: {time4}\n")
            f.write(f"Memory: {mem4}\n")

        # If assignments is not None, write it to a file
        if assignments4 is not None:
            assignments4.to_csv(out_dir / f"assignments_v4_{timestamp}.csv")

        # Write the A and B matrices
        np.save(out_dir / f"A_v4_{timestamp}.npy", A)
        np.save(out_dir / f"B_v4_{timestamp}.npy", B)

        # Write the sparsity dataframe
        sparsity.to_csv(out_dir / f"sparsity_v4_{timestamp}.csv")
    finally:
        # Clear the GPU memory, whether or not the run succeeded
        cp.get_default_memory_pool().free_all_blocks()
        cp.get_default_pinned_memory_pool().free_all_blocks()

In [4]:
def get_results(
    num_cell,
    data_path="/home/aparna/DATA/aparnakumar/150000_cells/mouse_marioni_150k.h5ad",
    n_trials=3,
):
    """Load the dataset, keep its first ``num_cell`` entries and benchmark them.

    Parameters
    ----------
    num_cell : int
        Number of leading observations to keep; also passed to
        ``gpu_versions`` as the cell count.
    data_path : str, optional
        File handed to ``sc.read``. Defaults to the path this notebook was
        originally run against, so existing ``get_results(n)`` calls are
        unaffected.
    n_trials : int, optional
        How many times ``gpu_versions`` is run on the subset (default 3).

    Returns
    -------
    None
        Progress is printed after each trial; results are written by
        ``gpu_versions``.
    """
    ad = sc.read(data_path)
    # NOTE(review): if the file holds fewer than `num_cell` observations this
    # slice is presumably just shorter, while `gpu_versions` still receives
    # `num_cell` - confirm the dataset is large enough for the requested size.
    ad = ad[:num_cell]
    for trial in range(n_trials):
        gpu_versions(ad, num_cell)

        print(f"Done with {num_cell} cells, trial {trial + 1}")

In [5]:
# Benchmark size: how many cells are passed to get_results below.
num_cells = 50_000

In [6]:
# Kick off the benchmark for the cell count configured in `num_cells`.
get_results(num_cells)

SPARSE AND GPU
TRYING SEACellsGPU
Welcome to SEACells GPU!
build_graph.SEACellGraph completed
Computing kNN graph using scanpy NN ...
Computing radius for adaptive bandwidth kernel...


HBox(children=(FloatProgress(value=0.0, max=50000.0), HTML(value='')))

In [None]:
# Check memory usage with CuPy: print the bytes currently in use in the
# default device memory pool, the total bytes that pool is holding, and the
# number of free blocks in the default pinned-memory pool.


print(cp.get_default_memory_pool().used_bytes())
print(cp.get_default_memory_pool().total_bytes())
print(cp.get_default_pinned_memory_pool().n_free_blocks())

502522368
11524222464
0


In [None]:
# Clear memory: grab handles to CuPy's default device and pinned-memory pools
# (the next cell reuses `mempool` / `pinned_mempool`) and ask both to free
# their blocks.
# NOTE(review): free_all_blocks() is documented to release only blocks that
# are not in use, so memory still referenced by live arrays stays allocated.

mempool = cp.get_default_memory_pool()
pinned_mempool = cp.get_default_pinned_memory_pool()

mempool.free_all_blocks() 
pinned_mempool.free_all_blocks()

In [None]:
# Check memory usage with CuPy again after freeing, using the `mempool` and
# `pinned_mempool` handles created in the previous cell.
print(mempool.used_bytes()) 
print(mempool.total_bytes()) 
print(pinned_mempool.n_free_blocks())

502522368
11524222464
0
