In [None]:
import os
import dask
from dask import delayed
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Run GRNBoost2 on the cells of a single tumor subtype and write the resulting network to disk
def run_grnboost_for(tumor_type, adata, tf_names, network_dir,
                     client_or_address=None, obs_key="subtype"):
    """Run GRNBoost2 on the cells of one tumor subtype and save the network.

    Args:
        tumor_type: Value of ``adata.obs[obs_key]`` selecting the cells to use.
        adata: Annotated data object with an ``obs`` table (presumably an
            AnnData -- confirm against the loading cell).
        tf_names: Transcription-factor names forwarded to ``grnboost2``.
        network_dir: Output directory for ``<tumor_type>_network.tsv``;
            created if it does not exist.
        client_or_address: Optional dask client/address forwarded to
            ``grnboost2``. Pass the driver's ``Client`` so the regressions run
            on the shared cluster. When ``None`` (default) the argument is not
            passed at all, i.e. ``grnboost2`` is called exactly as before.
        obs_key: Column of ``adata.obs`` holding the subtype labels.

    Returns:
        The status string ``"<tumor_type> done."``.

    Raises:
        ValueError: If no cell has ``tumor_type`` in ``adata.obs[obs_key]``.
    """
    # 1) Subset the data; fail early with a clear message instead of handing
    #    an empty matrix to grnboost2.
    mask = adata.obs[obs_key] == tumor_type
    if not mask.any():
        raise ValueError(
            f"No cells with {obs_key}={tumor_type!r}; cannot build a network."
        )
    adata_sub = adata[mask, :].copy()

    # 2) Convert to an expression matrix plus gene/cell labels.
    mat, genes, cells = fetch_adata(adata_sub)
    n_genes = len(genes)
    print(f"{tumor_type}: {mat.shape=} {n_genes=} ...")

    # 3) Run GRNBoost2. Only forward the client when one was given so the
    #    default call stays identical to the original three-argument call.
    extra = {} if client_or_address is None else {"client_or_address": client_or_address}
    network = grnboost2(
        expression_data=mat,
        gene_names=genes,
        tf_names=tf_names,
        **extra,
    )

    # 4) Save results (headerless TSV, as before).
    os.makedirs(network_dir, exist_ok=True)
    network_file = os.path.join(network_dir, f"{tumor_type}_network.tsv")
    network.to_csv(network_file, sep='\t', header=False, index=False)

    return f"{tumor_type} done."

# -------------- MAIN SCRIPT --------------
if __name__ == "__main__":
    # NOTE(review): `adata`, `tf_names` and `network_dir` are not defined in
    # this cell; they are assumed to come from earlier notebook cells.
    tumor_types = ['ER', 'HER2', 'TNBC']

    # Context managers release the client and the SLURM jobs even when a run
    # raises; otherwise the jobs would hold their allocation until walltime.
    with SLURMCluster(queue="short", cores=16, processes=1, memory="16GB",
                      walltime="05:00:00",
                      scheduler_options={"dashboard_address": ":40748", "host": "0.0.0.0"}) as cluster:
        cluster.scale(8)  # or cluster.adapt(...)

        with Client(cluster) as client:
            print(client)
            print(cluster)

            # Drive the subtypes from here rather than via dask.delayed.
            # Each grnboost2 call already distributes its per-target
            # regressions over every worker through `client`. Wrapping the
            # calls in delayed tasks would copy the full `adata` into the
            # task graph. It would also make each worker call grnboost2
            # without a client; assuming arboreto, that tries to start a
            # LocalCluster inside the worker and leaves the SLURM cluster idle.
            results = tuple(
                run_grnboost_for(tumor_type, adata, tf_names, network_dir,
                                 client_or_address=client)
                for tumor_type in tumor_types
            )

    print("All done:", results)
