Some nodes represent tables, or are expensive to recompute on every run. In these cases you can cache or persist the node's output in whatever storage suits you.

To enable caching for a specific node, create a class that inherits from `Cache`. The class tells Flypipe how to `read` and `write` the data, and how to check whether the cached or persisted data `exists`.


Example:

``` py
import os

import pandas as pd
from flypipe.cache import Cache


class MyCustomPersistance(Cache):
    def __init__(self, csv_path_name: str):
        self.csv_path_name = csv_path_name

    def read(self, spark):
        """
        Reads the persisted/cached data into a dataframe
        """
        return pd.read_csv(self.csv_path_name)

    def write(self, spark, df):
        """
        Cache or persist the data
        """
        df.to_csv(self.csv_path_name, index=False)
        
    def exists(self, spark):
        """
        Check if the data has been cached or persisted.
        """
        return os.path.exists(self.csv_path_name)
```

Once the cache/persistence class is defined, mark a node as cached by passing an instance of the class to the `cache` argument of `@node`, for instance:

``` py
@node(
    ...
    cache = MyCustomPersistance("data.csv")
    ...
)
def t0():
    ...
```

## Cache/Persistence workflow

For every node that has a cache set up, Flypipe does the following:

* Does the node have a cache?
    * Yes: does the cache exist (calls the `exists` method)?
        * Yes: calls the `read` method and returns the dataframe.
        * No: runs the node, collects the output dataframe and calls the `write` method.
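The workflow above can be sketched in plain Python. This is a simplified model and not Flypipe's implementation. `DictCache` and `load_or_compute` are hypothetical names. `DictCache` keeps data in a dictionary but exposes the same `read`, `write` and `exists` methods as the cache classes on this page. The `spark` argument is only passed through to match their signatures.

``` py
from typing import Any, Callable, Dict


class DictCache:
    """Hypothetical in-memory cache with the same read/write/exists methods."""

    def __init__(self, key: str, store: Dict[str, Any]):
        self.key = key
        self.store = store

    def read(self, spark):
        return self.store[self.key]

    def write(self, spark, df):
        self.store[self.key] = df

    def exists(self, spark):
        return self.key in self.store


def load_or_compute(cache, compute: Callable[[], Any], spark=None):
    """Return cached data if it exists, otherwise compute it and write it."""
    if cache is None:
        # Node has no cache: always run it
        return compute()
    if cache.exists(spark):
        # Cache hit: the node function is not executed
        return cache.read(spark)
    # Cache miss: run the node, persist its output, then return it
    df = compute()
    cache.write(spark, df)
    return df


calls = []


def expensive_node():
    calls.append(1)
    return [{"id": 1, "sales": 100.0}, {"id": 2, "sales": 34.1}]


store = {}
cache = DictCache("expensive_node", store)
first = load_or_compute(cache, expensive_node)   # runs the node and writes the cache
second = load_or_compute(cache, expensive_node)  # reads the cache, node is not run
print(len(calls))  # 1
```

The node function is called only once. The second call is served by `read` because `exists` now returns `True`.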

## Example 1: CSV persistence

``` py
import os
import pandas as pd
from flypipe.cache import Cache


class SaveAsCSV(Cache):
    def __init__(self, csv_path_name: str):
        self.csv_path_name = csv_path_name

    def read(self, spark):
        print(f"Reading CSV `{self.csv_path_name}`...")
        return pd.read_csv(self.csv_path_name)

    def write(self, spark, df):
        print(f"Writing CSV `{self.csv_path_name}`...")
        df.to_csv(self.csv_path_name, index=False)

    def exists(self, spark):
        csv_exists = os.path.exists(self.csv_path_name)
        print(f"CSV `{self.csv_path_name}` exists?", csv_exists)
        return csv_exists
```

## Example 2: Spark persistence

``` py
from typing import List

from flypipe.cache import Cache


class SparkTable(Cache):
    def __init__(self,
                 table_name: str,
                 schema: str,
                 merge_keys: List[str] = None,
                 partition_columns: List[str] = None,
                 schema_location: str = None):

        self.table_name = table_name
        self.schema = schema
        self.merge_keys = merge_keys
        self.partition_columns = partition_columns
        self.schema_location = schema_location

    @property
    def table(self):
        return f"{self.schema}.{self.table_name}"

    def read(self, spark):
        return spark.table(self.table)

    def write(self, spark, df):
        # Create the database (optionally at a given location) if it does not exist
        if not spark.catalog.databaseExists(self.schema):
            print(f"Creating database `{self.schema}`")
            location = f"LOCATION '{self.schema_location}'" if self.schema_location else ""
            spark.sql(f"CREATE DATABASE IF NOT EXISTS {self.schema} {location}")

        if not spark.catalog.tableExists(self.table_name, self.schema):
            # First write: create the Delta table from the dataframe
            print(f"Creating table `{self.table}`")
            writer = df.write.format("delta").mode("overwrite")

            if self.partition_columns:
                writer = writer.partitionBy(*self.partition_columns)

            writer.saveAsTable(self.table)
        else:
            # Table already exists: upsert the new rows using the merge keys
            if not self.merge_keys:
                raise ValueError(f"`merge_keys` are required to merge into `{self.table}`")

            print(f"Merging into table `{self.table}`")
            df.createOrReplaceTempView("updates")
            keys = " AND ".join([f"s.{col} = t.{col}" for col in self.merge_keys])

            merge_query = f"""
                MERGE INTO {self.table} t
                USING updates s
                ON {keys}
                WHEN MATCHED THEN UPDATE SET *
                WHEN NOT MATCHED THEN INSERT *
            """
            spark.sql(merge_query)

    def exists(self, spark):
        table_exists = spark.catalog.tableExists(self.table_name, self.schema)
        print(f"Table `{self.table}` exists?", table_exists)
        return table_exists
```
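The MERGE statement in `write` depends only on the table name and `merge_keys`. It can therefore be built by a pure function and inspected without a Spark session. `build_merge_query` is a hypothetical helper, not part of Flypipe. It reproduces the query text above on a single line. A MERGE needs at least one key to match source rows to target rows, so the sketch rejects an empty or missing `merge_keys`.

``` py
from typing import List, Optional


def build_merge_query(table: str, merge_keys: Optional[List[str]], source_view: str = "updates") -> str:
    """Build a Delta MERGE statement that upserts `source_view` into `table`."""
    if not merge_keys:
        # Without keys there is no way to match source rows to target rows
        raise ValueError("merge_keys must be a non-empty list to merge into an existing table")
    on_clause = " AND ".join(f"s.{col} = t.{col}" for col in merge_keys)
    return (
        f"MERGE INTO {table} t "
        f"USING {source_view} s "
        f"ON {on_clause} "
        "WHEN MATCHED THEN UPDATE SET * "
        "WHEN NOT MATCHED THEN INSERT *"
    )


print(build_merge_query("tmp.my_table", ["id"]))
# MERGE INTO tmp.my_table t USING updates s ON s.id = t.id WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *
```

With several keys, the conditions are joined with `AND`. For example, `["id", "date"]` produces `ON s.id = t.id AND s.date = t.date`.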

## Execution Graph

``` py
import pandas as pd
import pyspark.sql.functions as F
from flypipe import node


@node(
    type="pandas",
    cache=SaveAsCSV("/tmp/data.csv")
)
def csv_cache():
    return pd.DataFrame(data={"id": [1, 2], "sales": [100.0, 34.1]})


@node(
    type="pyspark",
    cache=SparkTable("my_table", "tmp", merge_keys=["id"], schema_location="/tmp"),
    dependencies=[csv_cache.select("id", "sales").alias("df")]
)
def spark_cache(df):
    return df.withColumn("above_50", F.col("sales") > 50.0)


@node(
    type="pyspark",
    dependencies=[spark_cache.select("id", "sales", "above_50").alias("df")]
)
def t0(df):
    return df
```

### 1st run

When no cache exists yet, every node with a cache is active:

``` py
displayHTML(t0.html(spark))
```

``` py
t0.run(spark)
```

After the first run, all caches have been written.

### Subsequent runs

Now that the caches exist, the cached nodes are inactive. Instead of running them, Flypipe loads their caches on the fly:

``` py
displayHTML(t0.html(spark))
```

``` py
t0.run(spark)
```

Flypipe only loads the caches it needs. For instance, loading the cache of `csv_cache` is skipped, because only the cache of `spark_cache` is needed to run `t0`.
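This skipping behaviour follows from treating an existing cache as a stopping point when walking the graph upstream from the requested node. The sketch below models that idea with a plain dependency dictionary. `plan_execution` is a hypothetical helper, not Flypipe's planner. It ignores column selection, aliases and node types.

``` py
from typing import Dict, List, Set


def plan_execution(target: str, dependencies: Dict[str, List[str]], cached: Set[str]) -> Dict[str, str]:
    """Walk the graph from `target`, stopping at nodes whose cache already exists.

    Returns a mapping node -> "run" or "read cache"; nodes missing from the
    mapping are skipped entirely.
    """
    plan: Dict[str, str] = {}
    stack = [target]
    while stack:
        name = stack.pop()
        if name in plan:
            continue  # already planned through another path
        if name in cached:
            # The cached output replaces this node and everything upstream of it
            plan[name] = "read cache"
        else:
            plan[name] = "run"
            stack.extend(dependencies.get(name, []))
    return plan


graph = {"t0": ["spark_cache"], "spark_cache": ["csv_cache"], "csv_cache": []}

print(plan_execution("t0", graph, cached=set()))
# {'t0': 'run', 'spark_cache': 'run', 'csv_cache': 'run'}
print(plan_execution("t0", graph, cached={"csv_cache", "spark_cache"}))
# {'t0': 'run', 'spark_cache': 'read cache'}
```

In the second plan, `csv_cache` is reachable only through the cached `spark_cache`. It therefore never appears in the plan, which matches the skipped cache read described above.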

## Merging Data

Often we need to merge data: insert new rows, and update rows that already exist. Whether a row already exists is determined by `merge_keys`:

``` py
@node(
    ...
    cache = SparkTable(
        "my_table", 
        "tmp", 
        merge_keys=["id"], # <--
        schema_location="/tmp"),
    ...
)
def spark_cache(df):
    ...
```

Rows whose `id` does not yet exist in `tmp.my_table` are inserted into `tmp.my_table`; rows whose `id` already exists are updated.
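The effect of `MERGE` with `merge_keys=["id"]` can be modelled as a keyed upsert over plain Python rows. This only illustrates the semantics. It uses a hypothetical `upsert` helper, and lists of dicts stand in for a Delta table. The sketch uses the values from the example below: the rows for ids 1 and 2 written by the first run, and the recomputed rows for ids 1 and 3. Delta rejects a merge in which several source rows match the same target row. This sketch simply keeps the last one.

``` py
from typing import Dict, Iterable, List, Tuple


def upsert(existing: List[dict], updates: Iterable[dict], merge_keys: List[str]) -> List[dict]:
    """Merge `updates` into `existing`, matching rows on `merge_keys`."""

    def key_of(row: dict) -> Tuple:
        return tuple(row[col] for col in merge_keys)

    merged: Dict[Tuple, dict] = {key_of(row): row for row in existing}
    for row in updates:
        # Matched key: the row is replaced (UPDATE SET *)
        # Unmatched key: the row is appended (INSERT *)
        merged[key_of(row)] = row
    return list(merged.values())


# Rows written by the first run (above_50 = sales > 50.0)
table = [
    {"id": 1, "sales": 100.0, "above_50": True},
    {"id": 2, "sales": 34.1, "above_50": False},
]

# New input for csv_cache, recomputed through the same transformation
new_input = [{"id": 1, "sales": 17.25}, {"id": 3, "sales": 547.39}]
updates = [{**row, "above_50": row["sales"] > 50.0} for row in new_input]

for row in upsert(table, updates, merge_keys=["id"]):
    print(row)
# {'id': 1, 'sales': 17.25, 'above_50': False}
# {'id': 2, 'sales': 34.1, 'above_50': False}
# {'id': 3, 'sales': 547.39, 'above_50': True}
```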

Whether a row is inserted or updated, it must first be recomputed by the graph. Nodes that were previously skipped must therefore become active again, which is done by setting their `CacheMode` to `CacheMode.MERGE`:

``` py
from flypipe.cache import CacheMode

displayHTML(
    t0.html(
        spark,
        cache={
            csv_cache: CacheMode.MERGE,
            spark_cache: CacheMode.MERGE
        }
    )
)
```

The current data in `tmp.my_table` is:

``` sql
select * from tmp.my_table
```

``` py
import pandas as pd
from flypipe.cache import CacheMode

t0.run(
    spark,
    cache={
        csv_cache: CacheMode.MERGE,
        spark_cache: CacheMode.MERGE
    },
    inputs={
        csv_cache: pd.DataFrame(data={"id": [1, 3], "sales": [17.25, 547.39]})
    }
)
```

Checking the data in `tmp.my_table`, we can see that:

* for id 1, `sales` changed from `100.0` to `17.25` and `above_50` changed to `false`;
* the row with id 2 is unchanged;
* a new row with id 3 was added.

``` sql
select * from tmp.my_table
```