In [1]:
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import scvi
import hdf5plugin
import os

In [2]:
def _solo_doublet_labels(file_path, n_top_genes=3000, max_epochs=None):
    """Train scVI + SOLO on the highly-variable-gene subset of one file.

    Parameters
    ----------
    file_path : str
        Path of the ``.h5ad`` file to read.
    n_top_genes : int
        Number of highly variable genes kept for training.
    max_epochs : int or None
        Forwarded to ``SCVI.train``; ``None`` keeps the scvi-tools default.

    Returns
    -------
    dict
        Maps each cell barcode (obs name) to the hard SOLO label
        (compared downstream against ``"singlet"``).
    """
    adata = sc.read_h5ad(file_path)
    # NOTE(review): scanpy documents `span` as used only with flavor="seurat_v3", so
    # span=0.8 has no effect here; flavor="seurat" also expects log-normalized data,
    # while scVI is trained on raw counts. Confirm which flavor was intended
    # (seurat_v3 additionally requires scikit-misc).
    sc.pp.highly_variable_genes(
        adata, n_top_genes=n_top_genes, subset=True, flavor="seurat", span=0.8
    )
    scvi.model.SCVI.setup_anndata(adata)
    vae = scvi.model.SCVI(adata)
    vae.train(max_epochs=max_epochs)
    solo = scvi.external.SOLO.from_scvi_model(vae)
    solo.train()
    # Soft scores per cell, plus the hard call aligned on the same barcode index.
    predictions = solo.predict()
    predictions["prediction"] = solo.predict(soft=False)
    return dict(zip(predictions.index, predictions["prediction"]))


def filter_doublets(
    directory="./data/3_filtered_h5ads",
    output_dir="./data/4_fwd_h5ads",
    n_top_genes=3000,
    max_epochs=None,
    skip_existing=False,
    seed=None,
):
    """Remove SOLO-predicted doublets from every ``.h5ad`` file in ``directory``.

    Each ``<name>.h5ad`` file is processed in sorted order:

    1. Train scVI + SOLO on its highly-variable-gene subset.
    2. Re-read the full file, so the output keeps every gene.
    3. Store the per-cell label in ``obs["doublet"]``. Cells SOLO gave no call
       for are labelled ``"filtered"``.
    4. Keep only ``"singlet"`` cells.
    5. Write the result to ``<output_dir>/fwd_<name>.h5ad`` with zstd compression.

    Parameters
    ----------
    directory : str
        Folder scanned (non-recursively) for ``.h5ad`` inputs.
    output_dir : str
        Destination folder; created if it does not exist.
    n_top_genes : int
        Number of highly variable genes used to train scVI/SOLO.
    max_epochs : int or None
        Forwarded to ``SCVI.train``; ``None`` keeps the scvi-tools default.
        Training dominates runtime (~54 s/epoch in the recorded run), so this
        is the knob for quick iterations.
    skip_existing : bool
        If True, inputs whose output file already exists are skipped, so an
        interrupted run can be resumed without retraining finished files.
    seed : int or None
        If given, assigned to ``scvi.settings.seed`` before any training.

    Returns
    -------
    list of str
        Paths of the output files written during this call.
    """
    os.makedirs(output_dir, exist_ok=True)
    if seed is not None:
        scvi.settings.seed = seed

    written = []
    # os.listdir order is arbitrary; sort so runs process files reproducibly.
    for file in sorted(os.listdir(directory)):
        if not file.endswith(".h5ad"):
            continue
        file_path = os.path.join(directory, file)
        out_path = os.path.join(output_dir, f"fwd_{file}")
        if skip_existing and os.path.exists(out_path):
            print(f"Skipping file (output exists): {out_path}")
            continue

        print(f"Processing file: {file_path}")
        labels = _solo_doublet_labels(file_path, n_top_genes=n_top_genes, max_epochs=max_epochs)

        # Re-read: the HVG step subset genes in place; the written output keeps all genes.
        adata = sc.read_h5ad(file_path)
        # Cells without a SOLO call are marked "filtered" and dropped with the doublets.
        adata.obs["doublet"] = adata.obs_names.map(
            lambda barcode: labels.get(barcode, "filtered")
        )
        # .copy() materializes the view so an actual AnnData (not a view) is written.
        adata = adata[adata.obs["doublet"] == "singlet"].copy()
        adata.write_h5ad(out_path, compression=hdf5plugin.FILTERS["zstd"])
        written.append(out_path)
    return written

In [3]:
# Run doublet removal over every filtered .h5ad (scVI training took ~54 s/epoch in the recorded run).
filter_doublets(directory="./data/3_filtered_h5ads");

Processing file: ./data/3_filtered_h5ads/filtered_luad_gse131907.h5ad


  self.validate_field(adata)
  _verify_and_correct_data_format(adata, self.attr_name, self.attr_key)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/kostas/miniconda3/envs/scrna/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 1/65:   0%|                                                                                          | 0/65 [00:00<?, ?it/s]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 2/65:   2%|▍                             | 1/65 [00:54<57:55, 54.30s/it, v_num=1, train_loss_step=431, train_loss_epoch=481]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 3/65:   3%|▉                             | 2/65 [01:48<56:40, 53.98s/it, v_num=1, train_loss_step=483, train_loss_epoch=434]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 4/65:   5%|█▍                            | 3/65 [02:40<55:08, 53.36s/it, v_num=1, train_loss_step=437, train_loss_epoch=422]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 5/65:   6%|█▊                            | 4/65 [03:34<54:15, 53.37s/it, v_num=1, train_loss_step=428, train_loss_epoch=415]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 6/65:   8%|██▎                           | 5/65 [04:27<53:12, 53.22s/it, v_num=1, train_loss_step=436, train_loss_epoch=410]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 7/65:   9%|██▊                           | 6/65 [05:20<52:28, 53.37s/it, v_num=1, train_loss_step=413, train_loss_epoch=407]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 8/65:  11%|███▏                          | 7/65 [06:13<51:23, 53.17s/it, v_num=1, train_loss_step=397, train_loss_epoch=405]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 9/65:  12%|███▋                          | 8/65 [07:06<50:31, 53.18s/it, v_num=1, train_loss_step=402, train_loss_epoch=404]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 10/65:  14%|████                         | 9/65 [08:07<51:53, 55.61s/it, v_num=1, train_loss_step=416, train_loss_epoch=402]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 11/65:  15%|████▎                       | 10/65 [09:07<52:17, 57.05s/it, v_num=1, train_loss_step=417, train_loss_epoch=401]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 12/65:  17%|████▋                       | 11/65 [10:00<50:09, 55.73s/it, v_num=1, train_loss_step=392, train_loss_epoch=401]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 13/65:  18%|█████▏                      | 12/65 [10:54<48:39, 55.09s/it, v_num=1, train_loss_step=407, train_loss_epoch=400]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 14/65:  20%|█████▌                      | 13/65 [11:46<47:00, 54.23s/it, v_num=1, train_loss_step=421, train_loss_epoch=400]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 15/65:  22%|██████                      | 14/65 [12:39<45:39, 53.71s/it, v_num=1, train_loss_step=433, train_loss_epoch=399]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 16/65:  23%|██████▍                     | 15/65 [13:31<44:34, 53.49s/it, v_num=1, train_loss_step=425, train_loss_epoch=399]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 17/65:  25%|██████▉                     | 16/65 [14:24<43:20, 53.08s/it, v_num=1, train_loss_step=393, train_loss_epoch=399]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 18/65:  26%|███████▎                    | 17/65 [15:15<42:06, 52.64s/it, v_num=1, train_loss_step=439, train_loss_epoch=398]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 19/65:  28%|███████▊                    | 18/65 [16:08<41:10, 52.57s/it, v_num=1, train_loss_step=398, train_loss_epoch=398]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 20/65:  29%|████████▏                   | 19/65 [17:00<40:14, 52.49s/it, v_num=1, train_loss_step=389, train_loss_epoch=398]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 21/65:  31%|████████▌                   | 20/65 [17:53<39:31, 52.69s/it, v_num=1, train_loss_step=409, train_loss_epoch=398]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 22/65:  32%|█████████                   | 21/65 [18:47<38:50, 52.96s/it, v_num=1, train_loss_step=386, train_loss_epoch=398]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 23/65:  34%|█████████▍                  | 22/65 [19:39<37:51, 52.83s/it, v_num=1, train_loss_step=416, train_loss_epoch=398]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 24/65:  35%|█████████▉                  | 23/65 [20:31<36:50, 52.63s/it, v_num=1, train_loss_step=409, train_loss_epoch=398]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 25/65:  37%|██████████▎                 | 24/65 [21:26<36:24, 53.27s/it, v_num=1, train_loss_step=389, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 26/65:  38%|██████████▊                 | 25/65 [22:18<35:16, 52.92s/it, v_num=1, train_loss_step=429, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 27/65:  40%|███████████▏                | 26/65 [23:11<34:21, 52.86s/it, v_num=1, train_loss_step=383, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 28/65:  42%|███████████▋                | 27/65 [24:04<33:34, 53.03s/it, v_num=1, train_loss_step=398, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 29/65:  43%|████████████                | 28/65 [24:58<32:53, 53.35s/it, v_num=1, train_loss_step=380, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 30/65:  45%|████████████▍               | 29/65 [25:51<31:54, 53.17s/it, v_num=1, train_loss_step=392, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 31/65:  46%|████████████▉               | 30/65 [26:45<31:09, 53.41s/it, v_num=1, train_loss_step=378, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 32/65:  48%|█████████████▎              | 31/65 [27:39<30:17, 53.47s/it, v_num=1, train_loss_step=419, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 33/65:  49%|█████████████▊              | 32/65 [28:32<29:17, 53.26s/it, v_num=1, train_loss_step=401, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 34/65:  51%|██████████████▏             | 33/65 [29:25<28:30, 53.45s/it, v_num=1, train_loss_step=409, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 35/65:  52%|██████████████▋             | 34/65 [30:19<27:34, 53.36s/it, v_num=1, train_loss_step=418, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 36/65:  54%|███████████████             | 35/65 [31:11<26:34, 53.16s/it, v_num=1, train_loss_step=416, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 37/65:  55%|███████████████▌            | 36/65 [32:03<25:31, 52.79s/it, v_num=1, train_loss_step=394, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 38/65:  57%|███████████████▉            | 37/65 [32:56<24:38, 52.80s/it, v_num=1, train_loss_step=395, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 39/65:  58%|████████████████▎           | 38/65 [33:48<23:38, 52.55s/it, v_num=1, train_loss_step=406, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 40/65:  60%|████████████████▊           | 39/65 [34:40<22:40, 52.31s/it, v_num=1, train_loss_step=405, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 41/65:  62%|█████████████████▏          | 40/65 [35:31<21:40, 52.03s/it, v_num=1, train_loss_step=408, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 42/65:  63%|█████████████████▋          | 41/65 [36:24<20:52, 52.20s/it, v_num=1, train_loss_step=390, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 43/65:  65%|██████████████████          | 42/65 [37:16<20:02, 52.29s/it, v_num=1, train_loss_step=415, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 44/65:  66%|██████████████████▌         | 43/65 [38:10<19:22, 52.82s/it, v_num=1, train_loss_step=385, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 45/65:  68%|██████████████████▉         | 44/65 [39:04<18:31, 52.95s/it, v_num=1, train_loss_step=401, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 46/65:  69%|███████████████████▍        | 45/65 [39:56<17:35, 52.76s/it, v_num=1, train_loss_step=384, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 47/65:  71%|███████████████████▊        | 46/65 [40:48<16:41, 52.71s/it, v_num=1, train_loss_step=369, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 48/65:  72%|████████████████████▏       | 47/65 [41:41<15:45, 52.55s/it, v_num=1, train_loss_step=380, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 49/65:  74%|████████████████████▋       | 48/65 [42:32<14:48, 52.27s/it, v_num=1, train_loss_step=391, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 50/65:  75%|█████████████████████       | 49/65 [43:26<14:05, 52.86s/it, v_num=1, train_loss_step=404, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 51/65:  77%|█████████████████████▌      | 50/65 [44:20<13:14, 52.97s/it, v_num=1, train_loss_step=387, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 52/65:  78%|█████████████████████▉      | 51/65 [45:12<12:17, 52.66s/it, v_num=1, train_loss_step=424, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 53/65:  80%|██████████████████████▍     | 52/65 [46:03<11:20, 52.32s/it, v_num=1, train_loss_step=391, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 54/65:  82%|██████████████████████▊     | 53/65 [46:59<10:39, 53.30s/it, v_num=1, train_loss_step=396, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 55/65:  83%|███████████████████████▎    | 54/65 [47:56<09:59, 54.54s/it, v_num=1, train_loss_step=378, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 56/65:  85%|███████████████████████▋    | 55/65 [48:50<09:03, 54.37s/it, v_num=1, train_loss_step=414, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 57/65:  86%|████████████████████████    | 56/65 [49:43<08:04, 53.89s/it, v_num=1, train_loss_step=406, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 58/65:  88%|████████████████████████▌   | 57/65 [50:36<07:09, 53.69s/it, v_num=1, train_loss_step=401, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)


Epoch 59/65:  89%|████████████████████████▉   | 58/65 [51:29<06:14, 53.46s/it, v_num=1, train_loss_step=422, train_loss_epoch=397]

  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
  reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1)
/home/kostas/miniconda3/envs/scrna/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


[34mINFO    [0m Creating doublets, preparing SOLO model.                                                                  


  self.validate_field(adata)
  _verify_and_correct_data_format(adata, self.attr_name, self.attr_key)


KeyboardInterrupt: 