In [2]:
"""Step to run Locus to Gene either for inference or for training."""
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

import pyspark.sql.functions as f
from hydra import compose, initialize_config_dir

if TYPE_CHECKING:
    from pyspark.sql import DataFrame

from otg.common.session import ETLSession
from otg.config import LocusToGeneConfig

from otg.common.spark_helpers import get_record_with_maximum_value
from otg.dataset.study_locus import StudyLocus
from otg.dataset.study_locus_overlap import StudyLocusOverlap
from otg.dataset.v2g import V2G
from otg.method.locus_to_gene import LocusToGeneTrainer


In [3]:
etl=ETLSession("local[*]", "ot_genetics_local", "overwrite")
cfg = LocusToGeneConfig(
    run_mode="train",
    study_locus_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/mock_study_locus",
    study_locus_overlap_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/mock_study_locus_overlap",
    variant_gene_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/mock_v2g",
    colocalisation_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/mock_colocalisation",
    study_index_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/mock_study_index",
    gold_standard_curation_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/curation.json",
    gene_interactions_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/interaction",
)

23/03/05 02:07:16 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/03/05 02:07:18 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [5]:
@dataclass
class LocusToGeneStep:
    """Locus to gene step."""

    cfg: LocusToGeneConfig

    def run(self: LocusToGeneStep) -> None:
        """Run Locus to Gene step."""
        self.etl.logger.info(f"Executing {self.id} step")

        if self.run_mode == "train":
            gold_standards = get_gold_standards(
                etl=self.etl,
                study_locus_path=self.study_locus_path,
                v2g_path=self.variant_gene_path,
                study_locus_overlap_path=self.study_locus_overlap_path,
                gold_standard_curation=self.gold_standard_curation_path,
                interactions_path=self.gene_interactions_path,
            )
            # print(gold_standards.printSchema())
            # gold_standards.write.parquet(
            #     "/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/staging/gs"
            # )
            fm = L2GFeatureMatrix  # FIXME: debug credset
            # data = gold_standards.join(
            #     fm, on="studyLocusId", how="inner"
            # ).train_test_split(frac=0.1, seed=42)
            # # TODO: data normalization and standardisation of features

            # LocusToGeneTrainer.train(
            #     train_set=data["train"],
            #     test_set=data["test"],
            #     **self.hyperparameters,
            #     # TODO: Add push to hub, and push to W&B
            # )


def get_gold_standards(
    etl: ETLSession,
    gold_standard_curation: str,
    v2g_path: str,
    study_locus_path: str,
    study_locus_overlap_path: str,
    interactions_path: str,
) -> DataFrame:
    """Process gold standard curation to use as training data."""
    # FIXME: assign function to class - something is wrong instantiating the classes, used to work
    overlaps_df = StudyLocusOverlap.from_parquet(
        etl, study_locus_overlap_path
    ).df.select("left_studyLocusId", "right_studyLocusId")
    interactions = process_gene_interactions(etl, interactions_path)
    return (
        etl.spark.read.json(gold_standard_curation)
        .select(
            f.col("association_info.otg_id").alias("studyId"),
            f.col("gold_standard_info.gene_id").alias("geneId"),
            f.concat_ws(
                "_",
                f.col("sentinel_variant.locus_GRCh38.chromosome"),
                f.col("sentinel_variant.locus_GRCh38.position"),
                f.col("sentinel_variant.alleles.reference"),
                f.col("sentinel_variant.alleles.alternative"),
            ).alias("variantId"),
        )
        .filter(f.col("gold_standard_info.highest_confidence").isin(["High", "Medium"]))
        # Bring studyLocusId - TODO: what if I don't have one?
        .join(
            StudyLocus.from_parquet(etl, study_locus_path)._df.select(
                "studyId", "variantId", "studyLocusId"
            ),
            on=["studyId", "variantId"],
            how="inner",
        )
        # Assign Positive or Negative Status based on confidence
        .join(
            V2G.from_parquet(etl, v2g_path)._df.select(
                "variantId", "geneId", "distance"
            ),
            on=["variantId", "geneId"],
            how="inner",
        )
        .withColumn(
            "gsStatus",
            f.when(f.col("distance") <= 500_000, "Positive").otherwise("Negative"),
        )
        # Remove redundant loci
        .alias("left")
        .join(
            overlaps_df.alias("right"),
            (f.col("left.variantId") == f.col("right.left_studyLocusId"))
            | (f.col("left.variantId") == f.col("right.right_studyLocusId")),
            how="left",
        )
        .distinct()
        # Remove redundant genes
        .join(
            interactions.alias("interactions"),
            (f.col("left.geneId") == f.col("interactions.geneIdA"))
            | (f.col("left.geneId") == f.col("interactions.geneIdB")),
            how="left",
        )
        .withColumn("interacting", (f.col("score") > 0.7))
        # filter out genes where geneIdA has gsStatus Negative but geneIdA and gene IdB are interacting
        .filter(
            ~(
                (f.col("gsStatus") == "Negative")
                & (f.col("interacting"))
                & (
                    (f.col("left.geneId") == f.col("interactions.geneIdA"))
                    | (f.col("left.geneId") == f.col("interactions.geneIdB"))
                )
            )
        )
    )


def process_gene_interactions(etl: ETLSession, interactions_path: str) -> DataFrame:
    """Extract top scoring gene-gene interaction from the Platform."""
    # FIXME: assign function to class
    return get_record_with_maximum_value(
        etl.spark.read.parquet(interactions_path),
        ["targetA", "targetB"],
        "scoring",
    ).selectExpr(
        "targetA as geneIdA",
        "targetB as geneIdB",
        "scoring as score",
    )


In [11]:
from otg.common.schemas import parse_spark_schema
from otg.dataset.study_index import StudyIndex
from otg.dataset.study_locus import StudyLocus
from otg.dataset.colocalisation import Colocalisation

from dataclasses import field

spark = etl.spark
from functools import partial
@dataclass
class L2GFeatureMatrix:
    """Dataset with features for Locus to Gene prediction."""

    _schema: StructType = parse_spark_schema("l2g_feature_matrix.json")
    _df: DataFrame = field(
        default_factory=partial(spark.createDataFrame, [], schema=_schema)
    )

    def run(self):
        print(self._schema)
        print(self._df)

    # def __post_init__(self: L2GFeatureMatrix) -> None:
    #     """Post init."""
    #     if self._df is None:
    #         self._df = spark.createDataFrame([], schema=self._schema)

    # def get_distance_features(
    #     self: L2GFeatureMatrix, study_locus: StudyLocus, distances: V2G
    # ) -> L2GFeatureMatrix:
    #     """Get distance features."""
    #     distance = study_locus._get_tss_distance_features(distances)
    #     return L2GFeatureMatrix(
    #         _df=self._df.unionByName(distance, allowMissingColumns=True)
    #     )
    
    # def get_coloc_features(
    #     self: L2GFeatureMatrix, colocalisation: Colocalisation, study_locus: StudyLocus, studies: StudyIndex
    # ):
    #     """Get coloc features."""
    #     ss_coloc = colocalisation.get_max_llr_per_study_locus(study_locus, studies)
    #     return L2GFeatureMatrix(
    #         _df=self._df.unionByName(ss_coloc, allowMissingColumns=True)
    #     )
    

    # def get_all_features()


# feature_matrix = (
#     L2GFeatureMatrix().get_distance_features(
#         StudyLocus.from_parquet(etl, cfg.study_locus_path),
#         V2G.from_parquet(etl, cfg.variant_gene_path),
#     )
#     .get_coloc_features(
#         Colocalisation.from_parquet(etl, cfg.colocalisation_path),
#         StudyLocus.from_parquet(etl, cfg.study_locus_path),
#         StudyIndex.from_parquet(etl, cfg.study_index_path),

#     )
# )

In [12]:
L2GFeatureMatrix().run()

StructType(List(StructField(studyLocusId,StringType,false),StructField(feature,StringType,true),StructField(geneId,StringType,true),StructField(value,DoubleType,true)))
DataFrame[studyLocusId: string, feature: string, geneId: string, value: double]


In [40]:
from typing import Optional, TYPE_CHECKING
from omegaconf import DictConfig, MISSING
@dataclass
class LocusToGeneConfig:
    """Config for Locus to Gene classifier."""

    run_mode: str
    study_locus_path: str = MISSING
    variant_gene_path: str = MISSING
    colocalisation_path: str = MISSING
    study_index_path: str = MISSING
    study_locus_overlap_path: str = MISSING
    gold_standard_curation_path: str = MISSING
    gene_interactions_path: str = MISSING
    hyperparameters: dict = MISSING
    l2g_model_path: Optional[str] = None
    id: str = "locus_to_gene"
    __target__: str = MISSING
    etl: ETLSession = MISSING

cfg = LocusToGeneConfig(
    etl=ETLSession("local[*]", "ot_genetics_local", "overwrite"),
    run_mode="train",
    study_locus_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/mock_study_locus",
    study_locus_overlap_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/mock_study_locus_overlap",
    variant_gene_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/mock_v2g",
    colocalisation_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/mock_colocalisation",
    study_index_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/mock_study_index",
    gold_standard_curation_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/curation.json",
    gene_interactions_path="/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/interaction",
)

In [6]:
LocusToGeneStep.run(cfg)

root
 |-- variantId: string (nullable = false)
 |-- geneId: string (nullable = true)
 |-- studyId: string (nullable = true)
 |-- studyLocusId: string (nullable = true)
 |-- distance: long (nullable = true)
 |-- gsStatus: string (nullable = false)
 |-- left_studyLocusId: string (nullable = true)
 |-- right_studyLocusId: string (nullable = true)
 |-- geneIdA: string (nullable = true)
 |-- geneIdB: string (nullable = true)
 |-- score: double (nullable = true)
 |-- interacting: boolean (nullable = true)

None


In [None]:
# get_gold_standards is not computing correctly - debugging

# overlaps_df = StudyLocusOverlap.from_parquet(
#         cfg.etl, cfg.study_locus_overlap_path
#     ).df.select("left_studyLocusId", "right_studyLocusId")
# interactions_df = process_gene_interactions(cfg.etl, cfg.gene_interactions_path)

gs = (
    cfg.etl.spark.read.json(cfg.gold_standard_curation_path)
    .select(
            f.col("association_info.otg_id").alias("studyId"),
            f.col("gold_standard_info.gene_id").alias("geneId"),
            f.concat_ws(
                "_",
                f.col("sentinel_variant.locus_GRCh38.chromosome"),
                f.col("sentinel_variant.locus_GRCh38.position"),
                f.col("sentinel_variant.alleles.reference"),
                f.col("sentinel_variant.alleles.alternative"),
            ).alias("variantId"),
        )
    .filter(f.col("gold_standard_info.highest_confidence").isin(["High", "Medium"]))
    .join(
            StudyLocus.from_parquet(cfg.etl, cfg.study_locus_path).df.select(
                "studyId", "variantId", "studyLocusId"
            ),
            on=["studyId", "variantId"],
            how="inner",
        )
    .join(
            V2G.from_parquet(cfg.etl, cfg.variant_gene_path).df.select(
                "variantId", "geneId", "distance"
            ),
            on=["variantId", "geneId"],
            how="inner",
        )
        .withColumn(
            "gsStatus",
            f.when(f.col("distance") <= 500_000, "Positive").otherwise("Negative"),
        )
        .alias("left")
        .join(
            overlaps_df.alias("right"),
            (f.col("left.variantId") == f.col("right.left_studyLocusId"))
            | (f.col("left.variantId") == f.col("right.right_studyLocusId")),
            how="left",
        )
        .distinct()
        # here is the error
        # .join(
        #     interactions_df.alias("interactions"),
        #     (f.col("left.geneId") == f.col("interactions.geneIdA"))
        #     | (f.col("left.geneId") == f.col("interactions.geneIdB")),
        #     how="left",
        # )
)

gs.show()

+---------+------+-------+------------+--------+--------+-----------------+------------------+
|variantId|geneId|studyId|studyLocusId|distance|gsStatus|left_studyLocusId|right_studyLocusId|
+---------+------+-------+------------+--------+--------+-----------------+------------------+
+---------+------+-------+------------+--------+--------+-----------------+------------------+



In [None]:
sample = interactions_df.limit(5).selectExpr("geneIdA as geneId")

(
    sample.alias("left")
    .join(
            interactions_df.alias("interactions"),
            (f.col("left.geneId") == f.col("interactions.geneIdA"))
            | (f.col("left.geneId") == f.col("interactions.geneIdB")),
            how="left",
    )
    .withColumn("interacting", (f.col("score") > 0.5))
    .withColumn("gsStatus", f.lit("Negative"))
    .filter(
            ~(
                (f.col("gsStatus") == "Negative")
                & (f.col("interacting"))
                & (
                    (f.col("left.geneId") == f.col("interactions.geneIdA"))
                    | (f.col("left.geneId") == f.col("interactions.geneIdB"))
                )
            )
        )
    .show()
)

# the condition works as expected when the data is not null



+---------------+---------------+---------------+-----+-----------+--------+
|         geneId|        geneIdA|        geneIdB|score|interacting|gsStatus|
+---------------+---------------+---------------+-----+-----------+--------+
|ENSG00000000003|ENSG00000000003|ENSG00000124422| 0.15|      false|Negative|
|ENSG00000000003|ENSG00000000003|ENSG00000124422| 0.15|      false|Negative|
|ENSG00000000003|ENSG00000000003|ENSG00000124422| 0.15|      false|Negative|
|ENSG00000000003|ENSG00000000003|ENSG00000124422| 0.15|      false|Negative|
|ENSG00000000003|ENSG00000000003|ENSG00000154146|0.173|      false|Negative|
|ENSG00000000003|ENSG00000000003|ENSG00000154146|0.173|      false|Negative|
|ENSG00000000003|ENSG00000000003|ENSG00000154146|0.173|      false|Negative|
|ENSG00000000003|ENSG00000000003|ENSG00000154146|0.173|      false|Negative|
|ENSG00000000003|ENSG00000000003|ENSG00000162407|0.242|      false|Negative|
|ENSG00000000003|ENSG00000000003|ENSG00000162407|0.242|      false|Negative|

                                                                                

In [None]:
# created a processed gs file ad hoc from the old one

gs_processed = spark.read.parquet("gs://genetics-portal-dev-staging/l2g/221107/gold_standards/featurematrix_w_goldstandards.training_only.221107.parquet")

gs_processed_slim = (
    gs_processed.filter(f.col("gs_confidence") != "Low")
    .select(
        f.col("study_id").alias("studyId"),
        f.concat_ws(
            "_", f.col("chrom"), f.col("pos"), f.col("ref"), f.col("alt")
        ).alias("variantId"),
        f.col("gene_id").alias("geneId"),
        f.col("gold_standard_status").alias("gsStatus"),
    )
)

# studyLocusId must be created separately
assocs = assocs = gs_processed_slim.select("studyId", "variantId").distinct().withColumn("studyLocusId", f.monotonically_increasing_id())
gs_processed_slim = gs_processed_slim.join(assocs, on=["studyId", "variantId"], how="inner")




In [None]:
# Distance feature matrix

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from otg.common.schemas import parse_spark_schema
from otg.dataset.dataset import Dataset
from otg.dataset.study_locus import StudyLocus
from otg.dataset.v2g import V2G

if TYPE_CHECKING:
    from pyspark.sql.types import StructType

    from otg.common.session import ETLSession

@dataclass
class L2GFeature:
    """Property of a study locus pair."""

    study_id: str  # TODO: think about moving this to a trait id - so that we can extract the best study for that trait to train on
    locus_id: str
    gene_id: str
    feature_name: str
    feature_value: float


@dataclass
class L2GFeatureMatrix(Dataset):
    """Dataset with features for Locus to Gene prediction."""

    _schema: StructType = parse_spark_schema("l2g_feature_matrix.json")

    @classmethod
    def from_parquet(
        cls: type[L2GFeatureMatrix], etl: ETLSession, path: str
    ) -> Dataset:
        """Initialise L2GFeatureMatrix from parquet file.

        Args:
            etl (ETLSession): ETL session
            path (str): Path to parquet file

        Returns:
            Dataset: Locus to gene feature matrix
        """
        return super().from_parquet(etl, path), cls.schema
    
    def get_distance_features(
            self: L2GFeatureMatrix, study_locus: StudyLocus, distances: V2G
        ) -> L2GFeatureMatrix:
            """Get distance features."""
            distance = study_locus._get_tss_distance_features(distances)
            # return L2GFeatureMatrix(
            #     _df=self._df.unionByName(distance, allowMissingColumns=True)
            # )   
            self._df = self.df.unionByName(distance, allowMissingColumns=True)

    @classmethod
    def get_all_features(
            cls: type[L2GFeatureMatrix], study_locus: StudyLocus, distances: V2G
        ) -> L2GFeatureMatrix:
            """Get all features."""
            distance_features = L2GFeatureMatrix()
            return self


In [None]:
fm = L2GFeatureMatrix(_df=cfg.etl.spark.createDataFrame([], schema=_schema), path=None).get_distance_features(
    study_locus=StudyLocus.from_parquet(
        cfg.etl,
        "/Users/irenelopez/MEGAsync/EBI/repos/genetics_etl_python/mock_data/processed_gs",
    ),
    distances=V2G.from_parquet(cfg.etl, cfg.variant_gene_path),
)


AttributeError: 'NoneType' object has no attribute 'unionByName'

23/02/27 17:57:59 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 547751 ms exceeds timeout 120000 ms
23/02/27 17:57:59 WARN SparkContext: Killing executors is not supported by current scheduler.


In [None]:
L2GFeatureMatrix.__dict__

mappingproxy({'__module__': '__main__',
              '__annotations__': {'schema': 'StructType'},
              '__doc__': 'Dataset with features for Locus to Gene prediction.',
              'schema': StructType(List(StructField(studyLocusId,StringType,false),StructField(feature,StringType,true),StructField(geneId,StringType,true),StructField(value,DoubleType,true))),
              'from_parquet': <classmethod at 0x12c3d28e0>,
              'get_distance_features': <function __main__.L2GFeatureMatrix.get_distance_features(self: 'L2GFeatureMatrix', study_locus: 'StudyLocus', distances: 'V2G') -> 'L2GFeatureMatrix'>,
              '__dict__': <attribute '__dict__' of 'L2GFeatureMatrix' objects>,
              '__weakref__': <attribute '__weakref__' of 'L2GFeatureMatrix' objects>,
              '__dataclass_params__': _DataclassParams(init=True,repr=True,eq=True,order=False,unsafe_hash=False,frozen=False),
              '__dataclass_fields__': {'schema': Field(name='schema',type='Struct

In [57]:
def validate_schema(expected_schema, observed_schema) -> None:
    """Validate DataFrame schema against expected class schema.

    Raises:
        ValueError: DataFrame schema is not valid
    """
    # Do not look at the nullable flag
    from collections import namedtuple

    schema_tuple = namedtuple("schema", ["name", "type"])
    expected_schema = [schema_tuple(e.name, e.dataType) for e in expected_schema.fields]
    observed_schema = [schema_tuple(e.name, e.dataType) for e in observed_schema.fields]

    # Unexpected fields in the observed schema
    missing_struct_fields = [x for x in observed_schema if x not in expected_schema]
    error_message = f"The {missing_struct_fields} StructFields are not included in DataFrame schema: {expected_schema}"
    if missing_struct_fields:
        raise ValueError(error_message)

    # Required fields not in dataset
    required_fields = [x for x in expected_schema if not x.nullable]
    missing_required_fields = [
        x for x in required_fields if x not in observed_schema
    ]
    error_message = f"The {missing_required_fields} StructFields are required but missing from the DataFrame schema: {expected_schema}"
    if missing_required_fields:
        raise ValueError(error_message)

In [60]:
v2g = etl.spark.read.parquet(cfg.variant_gene_path)
observed_schema = v2g.schema
expected_schema = V2G.schema


In [65]:
observed_schema.fields[0].name

'geneId'

In [59]:
schema_tuple = namedtuple("schema", ["name", "type"])
expected_schema = [schema_tuple(e.name, e.dataType) for e in expected_schema.fields]
observed_schema = [schema_tuple(e.name, e.dataType) for e in observed_schema.fields]

validate_schema(expected_schema, observed_schema)


AttributeError: 'list' object has no attribute 'fields'

In [56]:
missing_struct_fields = [x.name for x in observed_schema if x not in expected_schema]

missing_struct_fields

[]

In [70]:
observed_schema.dataTypes()

AttributeError: 'StructType' object has no attribute 'dataTypes'

In [55]:
expected_schema

[schema(name='geneId', type=StringType),
 schema(name='variantId', type=StringType),
 schema(name='distance', type=LongType),
 schema(name='chromosome', type=StringType),
 schema(name='datatypeId', type=StringType),
 schema(name='datasourceId', type=StringType),
 schema(name='score', type=DoubleType),
 schema(name='resourceScore', type=DoubleType),
 schema(name='pmid', type=StringType),
 schema(name='biofeature', type=StringType),
 schema(name='position', type=IntegerType),
 schema(name='label', type=StringType),
 schema(name='variantFunctionalConsequenceId', type=StringType),
 schema(name='isHighQualityPlof', type=BooleanType)]

In [None]:
from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate

@dataclass
class Foo:
    a: int = 123

class MyTarget:
    def __init__(self, foo, bar):
        self.foo = foo
        self.bar = bar

cfg = OmegaConf.create(
    {
        "_target_": "__main__.MyTarget",
        "foo": Foo(),
        "bar": {"b": 456},
    }
)

obj_none = instantiate(cfg, _convert_="none")
assert isinstance(obj_none, MyTarget)
assert isinstance(obj_none.foo, DictConfig)
assert isinstance(obj_none.bar, DictConfig)

obj_partial = instantiate(cfg, _convert_="partial")
assert isinstance(obj_partial, MyTarget)
assert isinstance(obj_partial.foo, DictConfig)
assert isinstance(obj_partial.bar, dict)

obj_object = instantiate(cfg, _convert_="object")
assert isinstance(obj_object, MyTarget)
assert isinstance(obj_object.foo, Foo)
assert isinstance(obj_object.bar, dict)

obj_all = instantiate(cfg, _convert_="all")
assert isinstance(obj_none, MyTarget)
assert isinstance(obj_all.foo, dict)
assert isinstance(obj_all.bar, dict)

In [2]:
from dataclasses import dataclass
@dataclass
class Foo:
    a: int = 123

class MyTarget(Foo):
    def run(self):
        print(self.a)

MyTarget().run()

123


In [85]:
from dataclasses import dataclass, field
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate
from typing import Any, List

@dataclass
class Foo:
    a: int = 123

class MyTarget:
    def __init__(self, foo, bar):
        self.foo = foo
        self.bar = bar

@dataclass
class StructConfig:
    defaults: List[Any] = field(
        default_factory=lambda: [
            "_self_",
            {"foo": "Foo", "bar": {"b": 456}},
        ]
    )
    _target_ : str = "__main__.MyTarget"



obj_none = instantiate(StructConfig, _convert_="none")
# assert isinstance(obj_none, MyTarget)
# assert isinstance(obj_none.foo, DictConfig)
# assert isinstance(obj_none.bar, DictConfig)

obj_none

InstantiationException: Error in call to target '__main__.MyTarget':
TypeError("__init__() got an unexpected keyword argument 'defaults'")

In [79]:
obj_partial = instantiate(cfg, _convert_="partial")
assert isinstance(obj_partial, MyTarget)
assert isinstance(obj_partial.foo, DictConfig)
assert isinstance(obj_partial.bar, dict)

obj_partial

<__main__.MyTarget at 0x12b69d0d0>

In [78]:
obj_all = instantiate(cfg, _convert_="all")
assert isinstance(obj_none, MyTarget)
assert isinstance(obj_all.foo, dict)
assert isinstance(obj_all.bar, dict)

obj_all

<__main__.MyTarget at 0x12b6a7850>

In [None]:
_df: DataFrame = field(default_factory=lambda: spark.createDataFrame([], StructType([])))