In [1]:
import pyspark.sql.functions as f
from clinical_mining.utils.spark import SparkSession

spark = SparkSession()


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/07/07 09:48:34 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/07/07 09:48:34 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [2]:
df = spark.session.read.parquet("/Users/irenelopez/EBI/repos/clinical_mining/data/outputs/clinical_trials/20250703")


In [8]:
df.filter(f.col("nct_id") == 'NCT06710847').show(truncate=False)

+-------+-----------+--------------------------+----------------------------+-----------+--------------+-------------+---------------+-----------+--------------+---------+----+--------------+----+------+----+--------+
|drug_id|disease_id |disease_name              |drug_name                   |nct_id     |overall_status|phase        |completion_date|why_stopped|number_of_arms|purpose  |pmid|reference_type|id  |source|url |approval|
+-------+-----------+--------------------------+----------------------------+-----------+--------------+-------------+---------------+-----------+--------------+---------+----+--------------+----+------+----+--------+
|NULL   |EFO_0004230|endometrial neoplasms     |immune checkpoint inhibitors|NCT06710847|RECRUITING    |PHASE1/PHASE2|2028-06-16     |NULL       |3             |TREATMENT|NULL|NULL          |NULL|NULL  |NULL|NULL    |
|NULL   |NULL       |microsatellite instability|immune checkpoint inhibitors|NCT06710847|RECRUITING    |PHASE1/PHASE2|2028-06-16

## Checks

In [3]:
df.groupBy("source").count().show()

                                                                                

+--------------+-------+
|        source|  count|
+--------------+-------+
|          USAN|   1675|
|           EMA|   1240|
|           ATC|   3298|
|           INN|    434|
|      DailyMed|  44349|
|           FDA|    860|
|ClinicalTrials| 177024|
|          NULL|1868896|
+--------------+-------+



In [19]:
df.select("drug_id", "disease_id", f.size("approval").alias("approvals")).distinct().groupBy("approvals").count().show()

+---------+------+
|approvals| count|
+---------+------+
|       -1|130041|
|        1|  6017|
|        3|    76|
|        2|   688|
+---------+------+



In [18]:
df.printSchema()

root
 |-- drug_id: string (nullable = true)
 |-- disease_id: string (nullable = true)
 |-- disease_name: string (nullable = true)
 |-- drug_name: string (nullable = true)
 |-- nct_id: string (nullable = true)
 |-- overall_status: string (nullable = true)
 |-- phase: string (nullable = true)
 |-- completion_date: date (nullable = true)
 |-- why_stopped: string (nullable = true)
 |-- number_of_arms: integer (nullable = true)
 |-- purpose: string (nullable = true)
 |-- pmid: string (nullable = true)
 |-- reference_type: string (nullable = true)
 |-- id: string (nullable = true)
 |-- source: string (nullable = true)
 |-- url: string (nullable = true)
 |-- approval: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- source: string (nullable = true)
 |    |    |-- date: date (nullable = true)



In [22]:
df.filter(f.col("source").isNull()).show(1, False, True)

-RECORD 0------------------------------------------
 drug_id         | NULL                            
 disease_id      | EFO_0006859                     
 disease_name    | head and neck cancer            
 drug_name       | cobalt                          
 nct_id          | NCT00002735                     
 overall_status  | TERMINATED                      
 phase           | PHASE2                          
 completion_date | 2005-03-31                      
 why_stopped     | Terminated due to poor accrual. 
 number_of_arms  | 1                               
 purpose         | TREATMENT                       
 pmid            | 15625363                        
 reference_type  | RESULT                          
 id              | NULL                            
 source          | NULL                            
 url             | NULL                            
 approval        | NULL                            
only showing top 1 row



In [17]:
df.select("drug_id", "disease_id", "approval").distinct().groupBy(f.col("approval")).count().orderBy(f.col("count").desc()).show(truncate=False)

+--------------------------------------------+------+
|approval                                    |count |
+--------------------------------------------+------+
|NULL                                        |130041|
|[{DailyMed, NULL}]                          |4998  |
|[{EMA, NULL}]                               |631   |
|[{FDA, NULL}]                               |388   |
|[{EMA, NULL}, {DailyMed, NULL}]             |362   |
|[{DailyMed, NULL}, {FDA, NULL}]             |229   |
|[{EMA, NULL}, {FDA, NULL}]                  |97    |
|[{EMA, NULL}, {DailyMed, NULL}, {FDA, NULL}]|76    |
+--------------------------------------------+------+



In [None]:
# No pairs in the non approved/approved pile
(
    df.filter(f.size(f.col("approval")) > 0).select("drug_id", "disease_id").distinct()
    .join(
        df.filter(f.col("approval").isNull()).select("drug_id", "disease_id").distinct(),
        on=["drug_id", "disease_id"],
    )
    .show()
)

                                                                                

+-------+----------+
|drug_id|disease_id|
+-------+----------+
+-------+----------+



### Comparison

In [None]:
# Are all CTs from ChEMBL in my set? NO, 5406 are missing

df.filter(f.col("source") == "ClinicalTrials").select("nct_id").distinct().join(
    df.filter(f.col("source").isNull()).select("nct_id").distinct(), "nct_id", "left_anti"
).count()

5406

In [10]:
df.filter(f.col("source") == "ClinicalTrials").select("nct_id").distinct().count()

86927

In [12]:
evd = spark.session.read.parquet("/Users/irenelopez/EBI/repos/clinical_mining/data/inputs/sourceId=chembl").filter(f.col("studyId").isNotNull())

evd.select("studyId").distinct().count()

76925

In [14]:
df.filter(f.col("source") == "ClinicalTrials").select("nct_id").distinct().unionByName(evd.selectExpr("studyId as nct_id").distinct(), allowMissingColumns=True).distinct().count()

94485

In [15]:
me = df.filter(f.col("source").isNull())
chembl = df.filter(f.col("source") == "ClinicalTrials")

In [21]:
me.join(chembl, on=["disease_id", "drug_id"]).select("disease_id", "drug_id").distinct().count()

                                                                                

20632

25/07/07 13:36:32 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 169061 ms exceeds timeout 120000 ms
25/07/07 13:36:32 WARN SparkContext: Killing executors is not supported by current scheduler.
25/07/07 13:36:35 WARN Executor: Issue communicating with driver in heartbeater
org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.SparkThreadUtils$.awaitResult(SparkThreadUtils.scala:56)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:310)
	at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
	at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:101)
	at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:85)
	at org.apache.spark.storage.BlockManagerMaster.registerBlockManager(BlockManagerMaster.scala:80)
	at org.apache.spark.storage.BlockManager.reregister(BlockManager.scala:642)
	at org.apache.spark.executor.Executor.reportHeartBeat(Executor.scala:1223)
	at o

In [20]:
chembl.select("disease_id", "drug_id").distinct().count()

52209

## Temporary

In [3]:
ind = spark.session.read.json("/Users/irenelopez/EBI/repos/clinical_mining/data/inputs/chembl_drug_indication.jsonl")

ind.show(1, truncate=False, vertical=True)
ind.count()

-RECORD 0---------------------------------------------------------------------------------------------------------------------------------------------------------------
 _metadata          | {[CHEMBL1562610, CHEMBL509]}                                                                                                                      
 efo_id             | EFO:0002609                                                                                                                                       
 indication_refs    | [{96f19af4-de8f-4fd7-90d8-55fe6ebdd81d, DailyMed, https://dailymed.nlm.nih.gov/dailymed/drugInfo.cfm?setid=96f19af4-de8f-4fd7-90d8-55fe6ebdd81d}] 
 max_phase_for_ind  | 4.0                                                                                                                                               
 molecule_chembl_id | CHEMBL1562610                                                                                                                        

55442

In [4]:
ind.printSchema()

root
 |-- _metadata: struct (nullable = true)
 |    |-- all_molecule_chembl_ids: array (nullable = true)
 |    |    |-- element: string (containsNull = true)
 |-- efo_id: string (nullable = true)
 |-- indication_refs: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- ref_id: string (nullable = true)
 |    |    |-- ref_type: string (nullable = true)
 |    |    |-- ref_url: string (nullable = true)
 |-- max_phase_for_ind: string (nullable = true)
 |-- molecule_chembl_id: string (nullable = true)



In [25]:
trans_inf = ind.withColumn("indication_ref", f.explode("indication_refs"))

In [20]:
trans_inf.groupBy("indication_ref.ref_type").count().orderBy("count").show()

+--------------+-----+
|      ref_type|count|
+--------------+-----+
|           INN|  427|
|           FDA|  608|
|           EMA| 1020|
|          USAN| 1425|
|           ATC| 3162|
|      DailyMed|32970|
|ClinicalTrials|48820|
+--------------+-----+



In [28]:
sources_for_approval = ["FDA", "EMA", "DailyMed"]

trans_ind = (
    ind.select(
        f.explode("_metadata.all_molecule_chembl_ids").alias("drug_id"),
        f.translate("efo_id", ":", "_").alias("disease_id"),
        "indication_refs",
    )
    .withColumn("indication_ref", f.explode("indication_refs"))
    # A drug/disease pair can be supported by multiple trials in a single record (e.g. CHEMBL108/EFO_0004263)
    .withColumn("id", f.explode(f.split(f.col("indication_ref.ref_id"), ",")))
    .withColumn(
        "nct_id",
        f.when(
            f.col("indication_ref.ref_type") == "ClinicalTrials",
            f.col("id"),
        ),
    )
    .withColumn("source", f.col("indication_ref.ref_type"))
    .withColumn(
        "approval",
        f.when(
            f.size(
                f.array_intersect(
                    "indication_refs.ref_type", f.lit(sources_for_approval)
                )
            )
            > 0,
            f.struct,
        ).otherwise(f.lit(False)),
    )
    .withColumn(
        "url",
        f.when(
            f.col("source") == "ClinicalTrials",
            f.concat(f.lit("https://clinicaltrials.gov/search?term="), f.col("nct_id")),
        ).otherwise(f.col("indication_ref.ref_url")),
    )
    .drop("indication_refs", "indication_ref")
    .distinct()
)

trans_ind.show(truncate=False)
trans_ind.select("drug_id", "disease_id").distinct().count()
print(trans_ind.groupBy("approved").count().show())

+-------------+-------------+-----------+-----------+--------------+--------+--------------------------------------------------+
|drug_id      |disease_id   |id         |nct_id     |source        |approved|url                                               |
+-------------+-------------+-----------+-----------+--------------+--------+--------------------------------------------------+
|CHEMBL1351   |MONDO_0002087|NCT00017303|NCT00017303|ClinicalTrials|false   |https://clinicaltrials.gov/search?term=NCT00017303|
|CHEMBL1351   |MONDO_0002158|NCT01493505|NCT01493505|ClinicalTrials|false   |https://clinicaltrials.gov/search?term=NCT01493505|
|CHEMBL384467 |EFO_0000365  |NCT03069950|NCT03069950|ClinicalTrials|false   |https://clinicaltrials.gov/search?term=NCT03069950|
|CHEMBL553    |EFO_0000707  |NCT00673049|NCT00673049|ClinicalTrials|false   |https://clinicaltrials.gov/search?term=NCT00673049|
|CHEMBL635    |EFO_1000158  |NCT00278278|NCT00278278|ClinicalTrials|false   |https://clinicaltria

In [20]:
sources = trans_ind.select("source").distinct().collect()
sources

[Row(source='USAN'),
 Row(source='EMA'),
 Row(source='ATC'),
 Row(source='INN'),
 Row(source='DailyMed'),
 Row(source='FDA'),
 Row(source='ClinicalTrials')]

In [24]:
for source in sources:
    source = source["source"]
    print(f"#### {source}")
    trans_ind.filter(f.col("source") == source).show(5, truncate=False)

#### USAN
+-------------+-----------+-------------+------+------+--------+----------------------------------------------------------------------------+
|drug_id      |disease_id |id           |nct_id|source|approved|url                                                                         |
+-------------+-----------+-------------+------+------+--------+----------------------------------------------------------------------------+
|CHEMBL3545154|EFO_0000616|POZIOTINIB   |NULL  |USAN  |false   |https://searchusan.ama-assn.org/finder/usan/search/POZIOTINIB/relevant/1/   |
|CHEMBL4081711|EFO_0000685|ZIMLOVISERTIB|NULL  |USAN  |false   |https://searchusan.ama-assn.org/finder/usan/search/ZIMLOVISERTIB/relevant/1/|
|CHEMBL3828074|EFO_0000768|ZIRITAXESTAT |NULL  |USAN  |false   |https://searchusan.ama-assn.org/finder/usan/search/ZIRITAXESTAT/relevant/1/ |
|CHEMBL2109427|EFO_0000616|ZOLBETUXIMAB |NULL  |USAN  |false   |https://searchusan.ama-assn.org/finder/usan/search/ZOLBETUXIMAB/relevant/1

## Data analysis

In [10]:
ingestable = (
    df.groupBy(f.col("nct_id"))
    .agg(f.collect_set("disease_id").alias("mapped_diseases"), f.collect_set("drug_id").alias("mapped_drugs"))
    .filter(f.size("mapped_diseases") > 0)
    .filter(f.size("mapped_drugs") > 0)
    .withColumn("diseaseId", f.explode("mapped_diseases"))
    .withColumn("drugId", f.explode("mapped_drugs"))
    .select("nct_id", "diseaseId", "drugId").distinct()
    .persist()
)


In [18]:
def expand_disease_index(disease):
    """Expand disease index to include ancestors to account for differences in granularity in the mapping."""
    return (
        disease.select(
            f.col("id").alias("diseaseId"),
            f.explode("ancestors").alias("propagatedDiseaseId"),
        )
        .union(
            disease.select(
                f.col("id").alias("diseaseId"), f.col("id").alias("propagatedDiseaseId")
            )
        )
        .distinct()
    )

disease_ancestors = expand_disease_index(spark.session.read.parquet("/Users/irenelopez/EBI/repos/clinical_mining/data/disease"))

In [20]:
disease_ancestors.show(2)

+----------+-------------------+
| diseaseId|propagatedDiseaseId|
+----------+-------------------+
|DOID_13406|        EFO_0000684|
|GO_0006695|         GO_0008152|
+----------+-------------------+
only showing top 2 rows



In [14]:
ingestable.select("diseaseId", "drugId").distinct().count()

                                                                                

91477

In [21]:
indirect_ingestable = (
    ingestable.join(disease_ancestors, on="diseaseId", how="left")
    .drop("diseaseId")
    .withColumnRenamed("propagatedDiseaseId", "diseaseId")
    .distinct()
)

indirect_ingestable.select("diseaseId", "drugId").distinct().count()


                                                                                

339797

In [16]:
chembl.select("diseaseId", "drugId").distinct().count()

43025

In [None]:
chembl = spark.session.read.parquet("/Users/irenelopez/EBI/repos/clinical_mining/data/inputs/sourceId=chembl").select("diseaseId", "drugId").distinct()

In [22]:
indirect_ingestable.join(chembl, on=["drugId", "diseaseId"], how="inner").select("diseaseId", "drugId").distinct().count()

                                                                                

19390

In [24]:
indirect_ingestable.join(chembl, on=["drugId", "diseaseId"], how="left_anti").select("diseaseId", "drugId").distinct().count()

                                                                                

320407

In [25]:
ingestable.join(chembl, on=["drugId", "diseaseId"], how="left_anti").select("diseaseId", "drugId").distinct().count()

                                                                                

75719

In [26]:
ingestable.join(chembl, on=["drugId", "diseaseId"], how="left_anti").show()

+-------------+-------------+-----------+
|       drugId|    diseaseId|     nct_id|
+-------------+-------------+-----------+
| CHEMBL560511|  EFO_0005611|NCT00000221|
|    CHEMBL115|  EFO_0000764|NCT00000861|
|    CHEMBL129|  EFO_0000764|NCT00002108|
|    CHEMBL141|  EFO_0000764|NCT00002108|
|    CHEMBL178|  EFO_1001257|NCT00003190|
|  CHEMBL44657|  EFO_1001257|NCT00003190|
|CHEMBL1200645|  EFO_1001257|NCT00003190|
|CHEMBL1980825|  EFO_0000574|NCT00003502|
|   CHEMBL1351|  EFO_0003893|NCT00003998|
|   CHEMBL1351|MONDO_0021117|NCT00005059|
|CHEMBL5483015|  EFO_0000565|NCT00005795|
|CHEMBL1079742|  EFO_0003860|NCT00026338|
|    CHEMBL709|  EFO_0002460|NCT00029822|
|    CHEMBL709|  EFO_0000536|NCT00029822|
|CHEMBL1201566|  EFO_0003869|NCT00036023|
|CHEMBL1201566|MONDO_0021117|NCT00036023|
|CHEMBL1201566|  EFO_0001378|NCT00036023|
|CHEMBL5483015|MONDO_0018177|NCT00045565|
|CHEMBL5483015|  EFO_0005543|NCT00045565|
|CHEMBL5483015|  EFO_1001465|NCT00045565|
+-------------+-------------+-----

In [27]:
chembl.filter(f.col("studyId") == "NCT00000221").show()

+---------+------+
|diseaseId|drugId|
+---------+------+
+---------+------+



In [23]:
19390/43025

0.4506682161533992

In [6]:
df.groupBy("nct_id").agg(f.first("purpose").alias("purpose")).groupBy("purpose").count().orderBy("count", ascending=False).show(truncate=False)

[Stage 13:>                                                       (0 + 10) / 10]

+------------------------+------+
|purpose                 |count |
+------------------------+------+
|TREATMENT               |147216|
|PREVENTION              |11833 |
|BASIC_SCIENCE           |7572  |
|OTHER                   |5400  |
|DIAGNOSTIC              |3898  |
|SUPPORTIVE_CARE         |3133  |
|NULL                    |2906  |
|HEALTH_SERVICES_RESEARCH|732   |
|SCREENING               |388   |
|ECT                     |42    |
|DEVICE_FEASIBILITY      |33    |
+------------------------+------+



                                                                                

## Arms distribution

In [3]:
df.select("nct_id", "number_of_arms").distinct().groupBy(
    "number_of_arms"
).count().orderBy(f.col("count").desc()).show()

df.select("nct_id", "number_of_arms").distinct().withColumn(
    "number_of_arms",
    f.when(f.col("number_of_arms") > 2, f.lit("2+")).otherwise(f.col("number_of_arms")),
).groupBy("number_of_arms").count().orderBy(f.col("count").desc()).show()

                                                                                

+--------------+-----+
|number_of_arms|count|
+--------------+-----+
|             2|82222|
|             1|45379|
|             3|19690|
|          NULL|16019|
|             4|10803|
|             5| 3181|
|             6| 2614|
|             7|  863|
|             8|  788|
|             9|  417|
|            10|  328|
|            12|  197|
|            11|  186|
|            13|   91|
|            14|   88|
|            15|   64|
|            16|   60|
|            17|   40|
|            18|   31|
|            19|   18|
+--------------+-----+
only showing top 20 rows

+--------------+-----+
|number_of_arms|count|
+--------------+-----+
|             2|82222|
|             1|45379|
|            2+|39533|
|          NULL|16019|
+--------------+-----+



In [25]:
df.filter(f.col("number_of_arms").isNull()).groupBy("nct_id").agg(
    f.collect_set("disease_name").alias("disease_names")
).withColumn("disease_count", f.size("disease_names")).groupBy(
    "disease_count"
).count().orderBy("disease_count").show(truncate=False)

+-------------+-----+
|disease_count|count|
+-------------+-----+
|1            |7730 |
|2            |4266 |
|3            |2278 |
|4            |967  |
|5            |411  |
|6            |214  |
|7            |95   |
|8            |60   |
|9            |21   |
|10           |18   |
|11           |18   |
|12           |7    |
|13           |2    |
|14           |6    |
|15           |4    |
|16           |1    |
|18           |1    |
|24           |1    |
+-------------+-----+



### No arms

In [28]:
no_arms = df.filter(f.col("number_of_arms").isNull()).filter(
    f.col("completion_date").isNotNull()
).groupBy("nct_id").agg(
    f.collect_set("disease_name").alias("disease_names"),
    f.collect_set("drug_name").alias("drug_names"),
)

print(no_arms.count())
print(no_arms.filter((f.size("disease_names") > 1) & (f.size("drug_names") > 1)).count())


no_arms.filter((f.size("disease_names") > 1) & (f.size("drug_names") > 1)).show(50, truncate=False)

12405
2221
+-----------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|nct_id     |disease_names                                                                                                                                                          |drug_names                                                                                                                                                                      |
+-----------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------------------------------------------------------

### Single arm

In [29]:
single_arm = df.filter(f.col("number_of_arms") == 1).groupBy("nct_id").agg(
    f.collect_set("disease_name").alias("disease_names"),
    f.collect_set("drug_name").alias("drug_names"),
)

print(single_arm.count())
print(single_arm.filter((f.size("disease_names") > 1) & (f.size("drug_names") > 1)).count())

single_arm.filter((f.size("disease_names") > 1) & (f.size("drug_names") > 1)).show(50, truncate=False)

45379
8492
+-----------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

### Double armed studies

In [32]:

double_arms = df.filter(f.col("number_of_arms") == 2).groupBy("nct_id").agg(
    f.collect_set("disease_name").alias("disease_names"),
    f.collect_set("drug_name").alias("drug_names"),
)

print(double_arms.count())
print(double_arms.filter((f.size("disease_names") > 1) & (f.size("drug_names") > 1)).count())

double_arms.filter((f.size("disease_names") > 1) & (f.size("drug_names") > 1)).show(70, truncate=False)

82222
17805
+-----------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|nct_id     |disease_names                                                                                                                                                  

25/06/26 16:02:44 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 1825142 ms exceeds timeout 120000 ms
25/06/26 16:02:44 WARN SparkContext: Killing executors is not supported by current scheduler.
25/06/26 16:02:47 ERROR Inbox: Ignoring error
org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.SparkThreadUtils$.awaitResult(SparkThreadUtils.scala:56)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:310)
	at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRefByURI(RpcEnv.scala:102)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRef(RpcEnv.scala:110)
	at org.apache.spark.util.RpcUtils$.makeDriverRef(RpcUtils.scala:36)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.driverEndpoint$lzycompute(BlockManagerMasterEndpoint.scala:124)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.org$apache$spark$storage$BlockManagerMasterEndpoint$

In [31]:
df.filter(f.col("nct_id") == "NCT04496024").show(truncate=False)

+-----------------------------------------------+------------+-----------+--------------+-----+---------------+-----------+--------------+-------------+-----------+
|disease_name                                   |drug_name   |nct_id     |overall_status|phase|completion_date|why_stopped|number_of_arms|drug_id      |disease_id |
+-----------------------------------------------+------------+-----------+--------------+-----+---------------+-----------+--------------+-------------+-----------+
|drug-related side effects and adverse reactions|ofloxacin   |NCT04496024|UNKNOWN       |NA   |2023-08-31     |NULL       |NULL          |CHEMBL4      |NULL       |
|infections                                     |ofloxacin   |NCT04496024|UNKNOWN       |NA   |2023-08-31     |NULL       |NULL          |CHEMBL4      |NULL       |
|arthritis, infectious                          |ofloxacin   |NCT04496024|UNKNOWN       |NA   |2023-08-31     |NULL       |NULL          |CHEMBL4      |EFO_1001351|
|drug-rela

## Multiple arms

In [31]:
multiple_arms = df.filter(f.col("number_of_arms") > 2).groupBy("nct_id").agg(
    f.collect_set("disease_name").alias("disease_names"),
    f.collect_set("drug_name").alias("drug_names"),
)

print(multiple_arms.count())
print(multiple_arms.filter((f.size("disease_names") > 1) & (f.size("drug_names") > 1)).count())

multiple_arms.filter((f.size("disease_names") > 1) & (f.size("drug_names") > 1)).show(50, truncate=False)

39533
9521
+-----------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|nct_id     |disease_names                                                                                                                                                                                                                                                                                             |drug_names                                                                                                                                      

In [5]:
from clinical_mining.utils.db import AACTConnector
import pyspark.sql.functions as f

db = AACTConnector(
    db_url="aact-db.ctti-clinicaltrials.org:5432/aact",
    user="irenelopez",
    password="Ephemeral2023",
    schema="ctgov",
)


    

## Looking at literature refs

In [2]:
df = db.spark.session.read.parquet("/Users/irenelopez/EBI/repos/clinical_mining/data/aact_trials_20250509")

df.show()


+-------------------+------------+-----------+--------------+------+---------------+--------------------+--------------+--------+--------------+----------+-----------+
|       disease_name|   drug_name|     nct_id|overall_status| phase|completion_date|         why_stopped|number_of_arms|    pmid|reference_type|   drug_id| disease_id|
+-------------------+------------+-----------+--------------+------+---------------+--------------------+--------------+--------+--------------+----------+-----------+
|prostatic neoplasms|bicalutamide|NCT00243646|    TERMINATED|PHASE3|     2009-07-31|        poor accrual|             2| 9169810|    BACKGROUND| CHEMBL409|       NULL|
|prostatic neoplasms|bicalutamide|NCT00243646|    TERMINATED|PHASE3|     2009-07-31|        poor accrual|             2|12126818|    BACKGROUND| CHEMBL409|       NULL|
|prostatic neoplasms|bicalutamide|NCT00243646|    TERMINATED|PHASE3|     2009-07-31|        poor accrual|             2|15337535|    BACKGROUND| CHEMBL409|     

In [22]:
agg_df = df.groupBy("nct_id").agg(
    f.collect_set("reference_type").alias("reference_types"),
    f.collect_set("pmid").alias("pmids"),
)

agg_df.groupBy("reference_types").count().orderBy(f.col("count").desc()).show(
    100, truncate=False
)

+-----------------------------+------+
|reference_types              |count |
+-----------------------------+------+
|[]                           |124294|
|[DERIVED]                    |26037 |
|[BACKGROUND]                 |16167 |
|[RESULT]                     |8407  |
|[DERIVED, BACKGROUND]        |3871  |
|[BACKGROUND, RESULT]         |2018  |
|[DERIVED, RESULT]            |1881  |
|[DERIVED, BACKGROUND, RESULT]|653   |
+-----------------------------+------+



In [24]:
agg_df.select(
    f.when(f.size("reference_types") >= 1, f.lit(True))
    .otherwise(f.lit(False))
    .alias("has_reference")
).groupBy("has_reference").count().show()

+-------------+------+
|has_reference| count|
+-------------+------+
|         true| 59034|
|        false|124294|
+-------------+------+



## Get abstract info

In [4]:
result_pubmeds = [
    e.pmid
    for e in df.filter(f.col("reference_type").isin(["RESULT", "DERIVED"]))
    .filter(f.col("pmid").isNotNull())
    .select("pmid")
    .distinct()
    .collect()
]
len(result_pubmeds)

74201

In [5]:
# por que hay nulos en pmid si no hay reference_type?
df.filter(f.col("reference_type").isin(["RESULT", "DERIVED"])).filter(
    f.col("pmid").isNull()
).show()

+--------------------+-------------+-----------+--------------+------+---------------+-----------+--------------+----------+----+--------------+-------------+-----------+
|        disease_name|    drug_name|     nct_id|overall_status| phase|completion_date|why_stopped|number_of_arms|   purpose|pmid|reference_type|      drug_id| disease_id|
+--------------------+-------------+-----------+--------------+------+---------------+-----------+--------------+----------+----+--------------+-------------+-----------+
|     liver neoplasms|  gemcitabine|NCT00006010|     COMPLETED|PHASE2|     2008-04-30|       NULL|             1| TREATMENT|NULL|        RESULT|    CHEMBL888|       NULL|
|     liver neoplasms|    docetaxel|NCT00006010|     COMPLETED|PHASE2|     2008-04-30|       NULL|             1| TREATMENT|NULL|        RESULT|CHEMBL3545252|       NULL|
|          infections|    ritonavir|NCT01061151|     COMPLETED|PHASE3|     2016-09-30|       NULL|             8|PREVENTION|NULL|        RESULT| 

In [6]:
from Bio import Entrez

Entrez.email = "irene.lopez@ebi.ac.uk"

handle = Entrez.efetch(
    db="pubmed",
    id=",".join(result_pubmeds[:100]),
    rettype="medline",
    retmode="xml",
    retmax=100,
)
records = Entrez.read(handle)
results = {}
for article in records["PubmedArticle"]:
    try:
        citation = article["MedlineCitation"]["Article"]
        abstract = " ".join(citation["Abstract"]["AbstractText"])
        results[str(article["MedlineCitation"]["PMID"])] = abstract
    except Exception:
        print(f"Error processing article: {str(article['MedlineCitation']['PMID'])}")
        results[article["MedlineCitation"]["PMID"]] = None

Error processing article: 12672864


In [7]:
from Bio import Entrez
import time
from typing import List, Dict
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def fetch_abstracts(pmids: List[str], batch_size: int = 10000, sleep_time: int = 3) -> Dict[str, str]:
    """
    Fetch abstracts from PubMed for a list of PMIDs.
    
    Args:
        pmids: List of PubMed IDs
        batch_size: Number of PMIDs to fetch in each batch (max 10000)
        sleep_time: Time to sleep between batches in seconds
    
    Returns:
        Dictionary mapping PMID to abstract text
    """
    Entrez.email = "irene.lopez@ebi.ac.uk"
    results = {}
    
    # Process PMIDs in batches
    for i in range(0, len(pmids), batch_size):
        batch = pmids[i:i + batch_size]
        batch_str = ",".join(batch)
        
        try:
            logger.info(f"Fetching batch {i//batch_size + 1} of {len(pmids)//batch_size + 1}")
            handle = Entrez.efetch(
                db="pubmed",
                id=batch_str,
                rettype="medline",
                retmode="xml",
                retmax=batch_size
            )
            records = Entrez.read(handle)
            
            for article in records["PubmedArticle"]:
                try:
                    citation = article["MedlineCitation"]["Article"]
                    abstract = " ".join(citation["Abstract"]["AbstractText"])
                    results[str(article["MedlineCitation"]["PMID"])] = abstract
                except KeyError:
                    logger.warning(f"No abstract found for PMID: {article['MedlineCitation']['PMID']}")
                    results[article["MedlineCitation"]["PMID"]] = None
                except Exception as e:
                    logger.error(f"Error processing article {article['MedlineCitation']['PMID']}: {str(e)}")
                    results[article["MedlineCitation"]["PMID"]] = None
            
            # Sleep between batches to avoid rate limiting
            if i + batch_size < len(pmids):
                logger.info(f"Sleeping for {sleep_time} seconds before next batch")
                time.sleep(sleep_time)
                
        except Exception as e:
            logger.error(f"Error fetching batch {i//batch_size + 1}: {str(e)}")
            # If batch fails, still return what we have
            continue
    
    return results

In [8]:
abstracts = fetch_abstracts(result_pubmeds)

INFO:__main__:Fetching batch 1 of 8
INFO:__main__:Sleeping for 3 seconds before next batch
INFO:__main__:Fetching batch 2 of 8
INFO:__main__:Sleeping for 3 seconds before next batch
INFO:__main__:Fetching batch 3 of 8
INFO:__main__:Sleeping for 3 seconds before next batch
INFO:__main__:Fetching batch 4 of 8
INFO:__main__:Sleeping for 3 seconds before next batch
INFO:__main__:Fetching batch 5 of 8
INFO:__main__:Sleeping for 3 seconds before next batch
INFO:__main__:Fetching batch 6 of 8
INFO:__main__:Sleeping for 3 seconds before next batch
INFO:__main__:Fetching batch 7 of 8
INFO:__main__:Sleeping for 3 seconds before next batch
INFO:__main__:Fetching batch 8 of 8


In [10]:
from pyspark.sql.types import StructType, StructField, StringType

data = [(str(pmid), str(abstract)) for pmid, abstract in abstracts.items()]

abstracts_df = spark.session.createDataFrame(data, schema=StructType([
    StructField("pmid", StringType(), True),
    StructField("abstract", StringType(), True)
]))

abstracts_df.show(5)

25/05/14 14:37:08 WARN TaskSetManager: Stage 9 contains a task of very large size (12850 KiB). The maximum recommended task size is 1000 KiB.
[Stage 9:>                                                          (0 + 1) / 1]

+--------+--------------------+
|    pmid|            abstract|
+--------+--------------------+
|22253412|We assessed the s...|
|24409418|We compared outco...|
|26867177|Highly pathogenic...|
|19965662|Certain malignant...|
|26024112|To describe the m...|
+--------+--------------------+
only showing top 5 rows



25/05/14 14:37:13 WARN PythonRunner: Detected deadlock while completing task 0.0 in stage 9 (TID 25): Attempting to kill Python Worker
                                                                                

In [11]:
abstracts_df.filter(f.col("abstract").isNull()).show(5)

25/05/14 14:37:18 WARN TaskSetManager: Stage 10 contains a task of very large size (12850 KiB). The maximum recommended task size is 1000 KiB.
25/05/14 14:37:18 WARN TaskSetManager: Stage 11 contains a task of very large size (12708 KiB). The maximum recommended task size is 1000 KiB.


+----+--------+
|pmid|abstract|
+----+--------+
+----+--------+



25/05/14 14:37:19 WARN TaskSetManager: Stage 12 contains a task of very large size (12699 KiB). The maximum recommended task size is 1000 KiB.


In [14]:
abstracts_df.select(f.col("abstract").isNotNull().alias("has_abstract")).groupBy("has_abstract").count().show()

25/05/14 14:38:04 WARN TaskSetManager: Stage 15 contains a task of very large size (12850 KiB). The maximum recommended task size is 1000 KiB.


+------------+-----+
|has_abstract|count|
+------------+-----+
|        true|74162|
+------------+-----+



In [None]:
abstracts_df.write.json("/Users/irenelopez/EBI/repos/clinical_mining/data/abstracts")

In [4]:
abstracts = spark.session.read.json("/Users/irenelopez/EBI/repos/clinical_mining/data/abstracts")

abstracts.show(5)

+--------------------+--------+
|            abstract|    pmid|
+--------------------+--------+
|Increases in andr...|32923850|
|Oral lichen planu...|22322481|
|SENIOR compared t...|29895556|
|In multiple myelo...|33476575|
|The combination o...|23319574|
+--------------------+--------+
only showing top 5 rows



In [17]:
abstracts = (
    abstracts_df.join(df.select("nct_id", "phase", "pmid"), "pmid", "inner")
    .select("nct_id", "pmid", "phase", "abstract")
    .distinct()
    .persist()
)

abstracts.show(5)

25/05/14 14:38:39 WARN TaskSetManager: Stage 25 contains a task of very large size (12850 KiB). The maximum recommended task size is 1000 KiB.

+-----------+--------+------+--------------------+
|     nct_id|    pmid| phase|            abstract|
+-----------+--------+------+--------------------+
|NCT01243242|23933905|PHASE2|To compare the ef...|
|NCT00129740|30885996|PHASE2|Cardiovascular or...|
|NCT01891344|35170751|PHASE2|Ovarian cancer is...|
|NCT00443599|24671945|    NA|Our previous rand...|
|NCT02388061|33408145|PHASE2|To investigate th...|
+-----------+--------+------+--------------------+
only showing top 5 rows



                                                                                

In [19]:
abstracts.write.parquet("/Users/irenelopez/EBI/repos/clinical_mining/data/abstracts", mode="overwrite")

                                                                                

In [20]:
from pyspark.sql import Window

def sample_100_per_phase(df):
    # Create a window partitioned by phase and ordered randomly
    window_spec = Window.partitionBy("phase").orderBy(f.rand())
    
    # Add row numbers within each phase partition
    df_with_row_numbers = df.withColumn("row_num", f.row_number().over(window_spec))
    
    # Filter to keep only the first 100 rows for each phase
    sampled_df = df_with_row_numbers.filter(f.col("row_num") <= 100).drop("row_num")
    
    return sampled_df



In [21]:
sampled_df = sample_100_per_phase(abstracts)
sampled_df.show(5)

+-----------+--------+------+--------------------+
|     nct_id|    pmid| phase|            abstract|
+-----------+--------+------+--------------------+
|NCT01880593|22840761|PHASE2|Ketamine is repor...|
|NCT01139281|10459293|PHASE2|Cisplatin (CDDP) ...|
|NCT00017225|11857322|PHASE2|The biologic beha...|
|NCT00409656|26313245|PHASE2|Penetrating kerat...|
|NCT02845297|36099049|PHASE2|BackgroundImmune ...|
+-----------+--------+------+--------------------+
only showing top 5 rows



In [22]:
sampled_df.groupBy("phase").count().show()

+-------------+-----+
|        phase|count|
+-------------+-----+
| EARLY_PHASE1|  100|
|           NA|  100|
|       PHASE1|  100|
|PHASE1/PHASE2|  100|
|       PHASE2|  100|
|PHASE2/PHASE3|  100|
|       PHASE3|  100|
|       PHASE4|  100|
|         NULL|    2|
+-------------+-----+



In [24]:
sampled_df.coalesce(1).write.csv(
    "/Users/irenelopez/EBI/repos/clinical_mining/data/sampled_abstracts",
    header=True,
    sep="\t",
    mode="overwrite",
)

25/05/14 17:08:37 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 907499 ms exceeds timeout 120000 ms
25/05/14 17:08:37 WARN SparkContext: Killing executors is not supported by current scheduler.
25/05/14 17:15:40 ERROR Inbox: Ignoring error
org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.SparkThreadUtils$.awaitResult(SparkThreadUtils.scala:56)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:310)
	at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRefByURI(RpcEnv.scala:102)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRef(RpcEnv.scala:110)
	at org.apache.spark.util.RpcUtils$.makeDriverRef(RpcUtils.scala:36)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.driverEndpoint$lzycompute(BlockManagerMasterEndpoint.scala:124)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.org$apache$spark$storage$BlockManagerMasterEndpoint$$

In [None]:
df.filter(f.col("why_stopped").isNotNull()).show()

root
 |-- disease_name: string (nullable = true)
 |-- drug_name: string (nullable = true)
 |-- nct_id: string (nullable = true)
 |-- overall_status: string (nullable = true)
 |-- phase: string (nullable = true)
 |-- completion_date: date (nullable = true)
 |-- why_stopped: string (nullable = true)
 |-- number_of_arms: integer (nullable = true)
 |-- purpose: string (nullable = true)
 |-- pmid: string (nullable = true)
 |-- reference_type: string (nullable = true)
 |-- drug_id: string (nullable = true)
 |-- disease_id: string (nullable = true)

