In [1]:
import torch
import numpy as np
import pandas as pd

from datasets import Dataset, load_from_disk, concatenate_datasets, DatasetDict

  from .autonotebook import tqdm as notebook_tqdm


In [60]:
# Load the datasets from disk
sft_dataset = load_from_disk("/users/zlyu12/Desktop/c2s-RL/new_sft_dataset_0114")
test_dataset = load_from_disk("/users/zlyu12/Desktop/c2s-RL/new_test_hf_dataset_new")

In [29]:
test_dataset['Question']

['Does the [GABAergic neuron] cell gene expression contain a key gene known to be involved in GABA production, suggesting it is an inhibitory neuron?',
 'Which tubulin genes appear highly ranked in the [GABAergic neuron] cell gene expression, possibly indicating active cytoskeletal dynamics?',
 'Can the [glutamatergic neuron] cell gene expression provide evidence for excitatory neurotransmission based on a key genetic marker?',
 'Which calmodulin-related genes in the [glutamatergic neuron] cell gene expression might be important for calcium-mediated signaling?',
 'Do both the [GABAergic neuron] and [glutamatergic neuron] cell gene expressions share similar mitochondrial transcripts that suggest strong energy requirements?',
 'Which transcription factor from the [GABAergic neuron] cell gene expression might be linked to neuronal differentiation or subtype specification in this inhibitory cell?',
 'Is there evidence of MAPT or MAP1B expression in the [GABAergic neuron] cell gene expressi

In [61]:
def filter_train_sft_data(example):
    # Filter out examples with standalone "yes" or "no"
    if "yes " in example["Answer"].lower() or " yes " in example["Answer"].lower() or "yes," in example["Answer"].lower() or "yes." in example["Answer"].lower() or \
       "no " in example["Answer"].lower() or " no " in example["Answer"].lower() or "no," in example["Answer"].lower() or "no." in example["Answer"].lower():
        return False
    for keyword in example["Keyword"].split(", "):
        if keyword in example["Context"].lower() or keyword in example["Question"].lower():
            return False
    if example["Question"] in test_dataset['Question']:
        return False
    return True


# Apply the filter and create new filtered dataset
filtered_train_sft_dataset = sft_dataset["train_sft"].filter(filter_train_sft_data)
filtered_val_sft_dataset = sft_dataset["val_sft"].filter(filter_train_sft_data)

# Print statistics
print(f"Original dataset size: {len(sft_dataset['train_sft'])}")
print(f"Filtered dataset size: {len(filtered_train_sft_dataset)}")


Filter: 100%|██████████| 6299/6299 [00:00<00:00, 17703.11 examples/s]
Filter: 100%|██████████| 38/38 [00:00<00:00, 2028.35 examples/s]

Original dataset size: 6299
Filtered dataset size: 1036





In [62]:
# Create DatasetDict with train and val splits
filtered_dataset = DatasetDict({
    'train_sft': filtered_train_sft_dataset,
    'val_sft': filtered_val_sft_dataset
})

print(f"Final train dataset size: {len(filtered_dataset['train_sft'])}")
print(f"Final validation dataset size: {len(filtered_dataset['val_sft'])}")

# Save the filtered dataset to disk
filtered_dataset.save_to_disk("/users/zlyu12/Desktop/c2s-RL/filtered_train_dataset_3-5-25")


Final train dataset size: 1036
Final validation dataset size: 35


Saving the dataset (1/1 shards): 100%|██████████| 1036/1036 [00:00<00:00, 22545.69 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 35/35 [00:00<00:00, 3069.99 examples/s]


In [63]:
def filter_test_sft_data(example):
    if "yes " in example["Answer"].lower() or " yes " in example["Answer"].lower() or "yes," in example["Answer"].lower() or "yes." in example["Answer"].lower() or \
       "no " in example["Answer"].lower() or " no " in example["Answer"].lower() or "no," in example["Answer"].lower() or "no." in example["Answer"].lower():
        return False
    for keyword in example["Keyword"].split(", "):
        if keyword in example["Context"].lower() or keyword in example["Question"].lower():
            return False
    return True

filtered_test_dataset = test_dataset.filter(filter_test_sft_data)

print(f"Original test dataset size: {len(test_dataset)}")
print(f"Filtered test dataset size: {len(filtered_test_dataset)}")

filtered_test_dataset.save_to_disk("/users/zlyu12/Desktop/c2s-RL/filtered_test_dataset_3-5-25")


Filter: 100%|██████████| 170/170 [00:00<00:00, 8328.25 examples/s]


Original test dataset size: 170
Filtered test dataset size: 109


Saving the dataset (1/1 shards): 100%|██████████| 109/109 [00:00<00:00, 6381.35 examples/s]


In [44]:
filtered_test_dataset[15]

{'Context': 'Neurodevelopmental processes often involve stress response or chaperone proteins',
 'Summary_Dataset': 'Cell Type: GABAergic neuron, Tissue: diencephalon, Gene Expression: MALAT1 TMSB4X TUBA1A MT-ATP6 MT-CO3 EEF1A1 MT-CO2 RPL13 PTMA RPS8 RPSA RPS24 TPT1 RPS27A H3F3A HSP90AB1 FABP7 RPS10 RPS11 RPLP1 RPL11 RPL23 RPL27A RPL18A RPL8 RPL19 RPS9 ACTG1 RPS20 FAU RPL41 RPS14 FTL1 RPS3A1 RPL15 H3F3B STMN1 MT-CO1 RPS4X RPL28 VIM RPL30 RPL24 RPS23 RPL18 RPL9 RPL6 RPS16 RPS19 RPS13 RPL32 ACTB TUBB2B PPIA RPS3 RPS18 RPLP0 CKB RPS2 RPS15A RPS29 RPS26 RPS5 MT-CYTB MARCKS RPL37 RPL3 RPL37A RPL35A TUBB5 RPLP2 RPL17 RPS7 RPL34 PRDX2 NACA RPL7 RPL14 MT-ND2 RACK1 HSPA8 UBB RPL5 CHCHD2 RPS21 PPP1R14B RPL21 NPM1 HSP90AA1 RPS6 RPS12 TCF7L2 TMSB10 EIF1 SOX4 YBX1 RPL10A RPL38 RPS28 RPL36 H2AZ1 NNAT H2AZ2 HMGB1 GAPDH COX4I1 RPS15 RPL7A PEBP1 IGFBPL1 MT-ND4 RPS25 CFL1 HES6 RPL22 MT-ND1 RPS27 RPL4 RPS17 RAN XIST BASP1 EIF4A1 EIF4G2 HNRNPA2B1 RPL10 SUMO2 FABP5 SOX11 HNRNPA1 DDX5 SRSF3 RPL39 PTN SERBP1

In [46]:
filtered_dataset["train_sft"][15]

{'Context': 'Gene expression patterns in immune cells can offer insights into cellular adaptations in response to tumor signals, potentially informing therapeutic targets or intervention strategies.',
 'Summary_Dataset': 'Cell Type: endothelial cell, Tissue: basal ganglion, Gene Expression: MALAT1 MT-CO3 SPARC MT-CO2 MT-ATP6 MT-ND3 B2M TMSB10 EEF1A1 RPS28 MT-CYB IGFBP7 MT-CO1 A2M TMSB4X COL4A1 RPLP1 TPT1 CLDN5 VWF MT-ND4 PTMA RPL41 MT2A RPS8 NEAT1 RPL13 RPS23 COL4A2 IGFBP3 MT-ND1 APP RPS6 RPL28 IFI27 RPS19 RPL37A RPS27 RPS29 VIM RPS12 RPL10 RPL37 IFITM3 ITGB1 MT-ND2 RPS7 RPL13A RPL34 ACTB HLA-A RPS15A RPS2 ACTG1 RPLP2 RPL3 RPS14 SERF2 ITM2B RPL15 RPS24 RPL27A RPS4X ITM2A RPL26 SAT1 RPS27A RPL32 HLA-E RPL39 RPL19 RPS15 RPL12 PECAM1 HSPG2 RPL36 FLT1 MT-ND5 RPS16 CD59 RPL18A RPL21 RPS3A FTL HSP90AA1 HSP90B1 RPL11 RPS11 ATP5F1E TCF4 RPL30 RPL35A PPIA RPL23 ADAMTS1 CALM1 GNAS DDX5 RPL17 RPL24\nCell Type: microglial cell, Tissue: occipital lobe, Gene Expression: MALAT1 PLXDC2 NEAT1 DOCK4 LRM