# 03.1 - Extending Kedro

## Custom datasets

The core team maintains a large collection of datasets in the [`kedro-datasets`](https://pypi.org/project/kedro-datasets) package. Some of them have already been used in this bootcamp, for example `pandas.ExcelDataset`, `spark.SparkDataset`, or `databricks.ManagedTableDataset`.

Still, you will sometimes want to implement your own custom dataset, either because you work with custom data formats or Python libraries, or because you need extra features that the official datasets do not provide.

For example, let's create a simple Delta Table dataset backed by Spark, compatible with the existing Delta Tables you have declared in your Kedro catalog.

Usage will look like this:

```yaml
companies:
  type: rocketfuel.datasets.SimpleDeltaTableDataset
  catalog: ${_uc_catalog}
  database: ${_uc_schema}
  table: companies
  write_mode: overwrite
```

Notice that the `type` is set to `rocketfuel.datasets.SimpleDeltaTableDataset`. That is the class that Kedro will try to `import`.
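
To make that lookup concrete, here is a rough sketch of how a dotted `type` string can be turned into a class, with the remaining YAML keys passed as keyword arguments. It illustrates the idea and is not Kedro's actual implementation, which also resolves shorthand names such as `pandas.ExcelDataset`. The helper names `resolve_type` and `build_dataset` are hypothetical, and `fractions.Fraction` stands in for a dataset class so that the snippet runs without Kedro or Spark.

```python
import importlib
from typing import Any


def resolve_type(dotted_path: str) -> type:
    """Import the module part of ``dotted_path`` and return the named class."""
    module_path, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


def build_dataset(entry: dict[str, Any]) -> Any:
    """Instantiate a catalog entry: ``type`` selects the class, the rest are kwargs."""
    kwargs = {key: value for key, value in entry.items() if key != "type"}
    return resolve_type(entry["type"])(**kwargs)


# Stand-in entry (assumption): a standard-library class instead of a real dataset.
entry = {"type": "fractions.Fraction", "numerator": 3, "denominator": 4}
instance = build_dataset(entry)
print(type(instance).__name__, instance)
```

This is why the keys of the YAML entry must match the parameter names of the dataset's `__init__` method: a misspelled key would surface as an unexpected keyword argument.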

For that import to succeed, start by creating a `rocketfuel/src/rocketfuel/datasets.py` file containing a subclass of `kedro.io.AbstractDataset`, with an `__init__` method that mirrors the properties of the YAML entry:

In [None]:
%%writefile ../src/rocketfuel/datasets.py
from kedro.io import AbstractDataset
from pyspark.sql import DataFrame


class SimpleDeltaTableDataset(AbstractDataset[DataFrame, DataFrame]):
    def __init__(
        self, catalog: str, database: str, table: str, write_mode: str = "overwrite"
    ):
        self._catalog = catalog
        self._schema = database
        self._table = table

        if write_mode != "overwrite":
            raise NotImplementedError("Only overwrite mode is supported")
        self._write_mode = write_mode

Next, implement the required abstract methods:

- `load`
- `save`
- `_describe`

The `load` method will be invoked when the dataset is an _input_ of a node in a pipeline, and is expected to return the data that the node will use. The `save` method, on the other hand, will be invoked when the dataset is an _output_ of a node, and is expected to receive the data that the node produced.

```python
from kedro.pipeline import node, pipeline


def _noop(df):
    return df


pipeline([
    node(
        func=_noop,
        inputs="companies_raw",  # .load() is called; its return value is passed to _noop(df)
        outputs="companies",  # the return value of _noop(df) is passed to .save(data)
        name="companies_load_node",
    ),
])
```
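
The load, call, save sequence described above can be sketched without Kedro or Spark. The snippet below is a simplified illustration and not the real runner: `InMemoryDataset` and `run_node` are hypothetical names, and the sample record is made up for the example.

```python
from typing import Any, Callable


class InMemoryDataset:
    """Toy dataset (assumption): keeps its data in a Python attribute."""

    def __init__(self, data: Any = None) -> None:
        self._data = data

    def load(self) -> Any:
        return self._data

    def save(self, data: Any) -> None:
        self._data = data


def run_node(
    func: Callable[[Any], Any],
    input_ds: InMemoryDataset,
    output_ds: InMemoryDataset,
) -> None:
    """Simplified view of one node run: load the input, call the function, save the result."""
    data = input_ds.load()
    result = func(data)
    output_ds.save(result)


def _noop(df: Any) -> Any:
    return df


companies_raw = InMemoryDataset([{"id": 1, "name": "ACME"}])
companies = InMemoryDataset()
run_node(_noop, companies_raw, companies)
print(companies.load())
```

The node function never touches storage directly: it only receives what `load` returned and hands back what `save` will receive, which is what lets you swap dataset implementations in the catalog without changing pipeline code.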

In [None]:
%%writefile ../src/rocketfuel/datasets.py
import typing as t

from kedro.io import AbstractDataset
from pyspark.sql import DataFrame, SparkSession


class SimpleDeltaTableDataset(AbstractDataset[DataFrame, DataFrame]):
    def __init__(
        self, catalog: str, database: str, table: str, write_mode: str = "overwrite"
    ):
        self._catalog = catalog
        self._schema = database
        self._table = table

        if write_mode != "overwrite":
            raise NotImplementedError("Only overwrite mode is supported")
        self._write_mode = write_mode

        self._full_table_location = (
            f"`{self._catalog}`.`{self._schema}`.`{self._table}`"
        )

    def load(self) -> DataFrame:
        spark = SparkSession.builder.getOrCreate()

        data = spark.table(self._full_table_location)
        return data

    def save(self, data: DataFrame) -> None:
        writer = (
            data.write.format("delta")
            .mode(self._write_mode)
            .option("overwriteSchema", "true")
        )
        writer.saveAsTable(self._full_table_location)

    def _describe(self) -> dict[str, t.Any]:
        return {
            "catalog": self._catalog,
            "database": self._schema,
            "table": self._table,
            "write_mode": self._write_mode,
        }
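
As a quick check of the string that `load` and `save` pass to Spark, the following sketch reproduces the table location formatting on its own. The function name `full_table_location` and the sample values are hypothetical. Each part is wrapped in backticks so that names containing characters such as hyphens remain valid Spark SQL identifiers.

```python
def full_table_location(catalog: str, schema: str, table: str) -> str:
    """Build a backtick-quoted, three-part table name (catalog.schema.table)."""
    return f"`{catalog}`.`{schema}`.`{table}`"


# Sample values (assumption): replace them with your own catalog and schema names.
location = full_table_location("my-catalog", "bootcamp", "companies")
print(location)
```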

Next, update the `companies` entry in the catalog to use the new dataset class:

```diff
 companies:
-  type: databricks.ManagedTableDataset
+  type: rocketfuel.datasets.SimpleDeltaTableDataset
   catalog: ${_uc_catalog}
   database: ${_uc_schema}
   table: companies
   write_mode: overwrite
```

And finally, bootstrap the project to verify that it works:

In [None]:
%load_ext kedro.ipython

In [None]:
%reload_kedro --env databricks

In [None]:
catalog._get_dataset("companies")

In [None]:
display(catalog.load("companies"))

### Exercise NN

...

## Hooks

![Kedro lifecycle](kedro_run_lifecycle.png)