
## STEP 0.A: Setup Notebook

In [1]:
!rm -rf efficiently_testing_etl_pipelines
!git clone https://github.com/jmasonlee/efficiently_testing_etl_pipelines.git
!cp -r /content/efficiently_testing_etl_pipelines/src/ .
!cp -r /content/efficiently_testing_etl_pipelines/tests/ .
!rm -rf efficiently_testing_etl_pipelines
!rm -rf tests/diamond_pricing_test*
!rm -rf tests/test_helpers/*verification_helpers.py
!rm -rf tests/conftest.py
!rm -rf sample_data


Cloning into 'efficiently_testing_etl_pipelines'...
remote: Enumerating objects: 664, done.
remote: Counting objects: 100% (305/305), done.
remote: Compressing objects: 100% (136/136), done.
remote: Total 664 (delta 196), reused 259 (delta 162), pack-reused 359
Receiving objects: 100% (664/664), 287.11 KiB | 3.99 MiB/s, done.
Resolving deltas: 100% (402/402), done.


# STEP 0.B: Setup Tests

### Install Dependencies

For the exercise, we will need some special dependencies to allow us to run lots of tests in a notebook.

`ipytest` lets us run our tests in a notebook.



In [2]:
!pip install ipytest

Collecting ipytest
  Downloading ipytest-0.13.3-py3-none-any.whl (14 kB)
Collecting jedi>=0.16 (from ipython->ipytest)
  Downloading jedi-0.18.2-py2.py3-none-any.whl (1.6 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.6/1.6 MB 20.8 MB/s eta 0:00:00
Installing collected packages: jedi, ipytest
Successfully installed ipytest-0.13.3 jedi-0.18.2


`ipytest.autoconfig()` configures ipytest so that pytest can run the tests defined in this notebook's cells. You don't need this cell if you are writing your tests in a separate pytest file.

In [3]:
import ipytest
ipytest.autoconfig()

We are installing `pyspark` because it doesn't come with the base Colab environment.

In [4]:
!pip install pyspark

Collecting pyspark
  Downloading pyspark-3.4.1.tar.gz (310.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 310.8/310.8 MB 3.3 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... done
  Created wheel for pyspark: filename=pyspark-3.4.1-py2.py3-none-any.whl size=311285398 sha256=95b0199d2c831f775e2fd026d3048c454d35b05cbac25d03b8e17dd2fb36560f
  Stored in directory: /root/.cache/pip/wheels/0d/77/a3/ff2f74cc9ab41f8f594dabf0579c2a7c6de920d584206e0834
Successfully built pyspark
Installing collected packages: pyspark
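Successfully installed pyspark-3.4.1


`chispa` provides assertions for comparing PySpark DataFrames and columns, with readable error messages when they don't match.
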
In [5]:
!pip install chispa

Collecting chispa
  Downloading chispa-0.9.2-py3-none-any.whl (28 kB)
Installing collected packages: chispa
Successfully installed chispa-0.9.2


## Create a local SparkSession

Normally Spark runs on a cluster of executors in the cloud. Since we want our tests to run on a single dev machine, we make a session-scoped pytest fixture that gives us a local `SparkSession` and stops it when the test session ends.

In [6]:
import pytest
from pyspark import SparkConf
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark(request: pytest.FixtureRequest):
    conf = (SparkConf()
        .setMaster("local")
        .setAppName("sample_pyspark_testing_starter"))

    spark = SparkSession \
        .builder \
        .config(conf=conf) \
        .getOrCreate()

    request.addfinalizer(lambda: spark.stop())
    return spark

## Create Helpers

These helpers read our input JSON into a DataFrame, turn a DataFrame into a list of dicts sorted by `id`, and load the expected test output from a JSON file such as `expected.json`.

In [7]:
import json
from typing import List

from pyspark.sql import DataFrame


def create_df_from_json(json_file, spark):
    # Read a JSON file whose records span multiple lines (multiline mode).
    return spark.read.option("multiline", "true").json(json_file)


def data_frame_to_json(df: DataFrame) -> List:
    # Convert each row to a dict and sort by id, so row order doesn't affect comparisons.
    output = [json.loads(item) for item in df.toJSON().collect()]
    output.sort(key=lambda item: item["id"])
    print(output)
    return output

def expected_json(name: str) -> List:
    # Load the expected output (a list of row objects) from tests/fixtures.
    with open(f"tests/fixtures/{name}") as f:
        return json.loads(f.read())
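
`data_frame_to_json` relies on `DataFrame.toJSON()` producing one JSON string per row. Sorting the parsed rows by `id` means two DataFrames with the same rows in a different order compare equal. Here is a plain-Python sketch of that step, with hypothetical row strings standing in for what `collect()` returns:

```python
import json
from typing import Dict, List


def rows_to_sorted_dicts(json_rows: List[str]) -> List[Dict]:
    """Parse one JSON object per row and sort by 'id', as data_frame_to_json does."""
    output = [json.loads(row) for row in json_rows]
    output.sort(key=lambda item: item["id"])
    return output


# Hypothetical strings in the shape df.toJSON().collect() returns: one JSON object per row.
collected = [
    '{"id": "DI-2", "price": 500.0}',
    '{"id": "DI-1", "price": 326.0}',
]

print(rows_to_sorted_dicts(collected))
# [{'id': 'DI-1', 'price': 326.0}, {'id': 'DI-2', 'price': 500.0}]
```

Every row needs an `id` for this to work; a row without one raises `KeyError` during the sort.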



In [8]:
import pyspark
def build_indep_vars(df, independent_vars, categorical_vars=None, keep_intermediate=False, summarizer=True):
    check_input(categorical_vars, df, independent_vars)

    from pyspark.ml import Pipeline
    from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
    idx = 'index'
    vec = 'vector'
    if categorical_vars:
        string_indexer = [StringIndexer(inputCol=x,
                                        outputCol=f"{x}_{idx}")
                          for x in categorical_vars]

        encoder        = [OneHotEncoder(dropLast=True,
                                        inputCol =f'{x}_{idx}',
                                        outputCol=f'{x}_{vec}')
                          for x in categorical_vars]

        independent_vars = ['{}_vector'.format(x) if x in categorical_vars else x for x in independent_vars]
    else:
        string_indexer, encoder = [], []

    assembler = VectorAssembler(inputCols=independent_vars,
                                outputCol='indep_vars')
    pipeline  = Pipeline(stages=string_indexer+encoder+[assembler])
    model = pipeline.fit(df)
    df = model.transform(df)

    if not keep_intermediate:
        fcols = [c for c in df.columns if f'_{idx}' not in c[-3:] and f'_{vec}' not in c[-7:]]
        df = df[fcols]

    return df


def check_input(categorical_vars, df, independent_vars):
    assert (type(
        df) is pyspark.sql.dataframe.DataFrame), 'pyspark_glm: A PySpark dataframe is required as the first argument.'
    assert (type(
        independent_vars) is list), 'pyspark_glm: List of independent variable column names must be the second argument.'
    for iv in independent_vars:
        assert (type(iv) is str), 'pyspark_glm: Independent variables must be column name strings.'
        assert (iv in df.columns), 'pyspark_glm: Independent variable name is not a dataframe column.'
    if categorical_vars:
        for cv in categorical_vars:
            assert (type(cv) is str), 'pyspark_glm: Categorical variables must be column name strings.'
            assert (cv in df.columns), 'pyspark_glm: Categorical variable name is not a dataframe column.'
            assert (cv in independent_vars), 'pyspark_glm: Categorical variables must be independent variables.'
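
The `keep_intermediate=False` branch of `build_indep_vars` filters columns by suffix. A plain-Python check with hypothetical column names (the ones the `StringIndexer`, `OneHotEncoder` and `VectorAssembler` stages above would add for `clarity` and `color`) shows what that filter actually keeps: `'_index'` is six characters, so it can never appear inside the three-character slice `c[-3:]`, and only the `_vector` columns are dropped.

```python
idx = 'index'
vec = 'vector'

# Hypothetical columns after pipeline.transform(df) with categorical_vars=['clarity', 'color']:
# StringIndexer adds <name>_index, OneHotEncoder adds <name>_vector, VectorAssembler adds indep_vars.
columns = ['id', 'carat', 'clarity', 'color', 'price',
           'clarity_index', 'color_index',
           'clarity_vector', 'color_vector', 'indep_vars']

# The same filter expression used in build_indep_vars.
fcols = [c for c in columns if f'_{idx}' not in c[-3:] and f'_{vec}' not in c[-7:]]
print(fcols)
# ['id', 'carat', 'clarity', 'color', 'price', 'clarity_index', 'color_index', 'indep_vars']

# '_index' (6 characters) can never fit inside a 3-character slice.
print('_index' in 'clarity_index'[-3:])  # False
```

This section doesn't change `build_indep_vars`, so treat this as an observation about the legacy code. Changing the filter would change the columns that the original large test compares against its expected JSON.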


In [9]:
from pyspark.sql import DataFrame, Window, Column
from pyspark.sql.functions import log, when, mean, col

def replace_null(orig: Column, average: Column):
    return when(orig.isNull(), average).otherwise(orig)

def transform(df: DataFrame) -> DataFrame:

    df = df.withColumn('lprice', log('price'))
    window = Window.partitionBy('cut', 'clarity').orderBy('price').rowsBetween(-3, 3)
    moving_avg = mean(df['price']).over(window)
    df = df.withColumn('moving_avg', moving_avg)

    df = df.withColumn('price', when(df.price.isNull(), df.moving_avg).otherwise(df.price))
    df = df[['id', 'carat', 'clarity', 'color', 'price']]
    df = build_indep_vars(df, ['carat', 'clarity', 'color'],
                                      categorical_vars=['clarity', 'color'],
                                      keep_intermediate=False,
                                      summarizer=True)
    return df
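
To see what the window in `transform` computes, here is a plain-Python model of it (not Spark code). Rows are grouped by the partition columns, sorted by `price` with nulls first (Spark's default for an ascending `orderBy`), and a null price is replaced by the mean of the non-null prices from 3 rows before to 3 rows after it. The diamonds are hypothetical, chosen to mirror the bug described in Step 1.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Sequence


def fill_null_prices(rows: List[Dict], keys: Sequence[str],
                     before: int = 3, after: int = 3) -> Dict[str, Optional[float]]:
    """Plain-Python model of:
    Window.partitionBy(*keys).orderBy('price').rowsBetween(-before, after)
    followed by replacing a null price with mean('price') over that window.
    """
    partitions = defaultdict(list)
    for row in rows:
        partitions[tuple(row[k] for k in keys)].append(row)

    result = {}
    for members in partitions.values():
        # Ascending order with nulls first, matching Spark's default for orderBy('price').
        members.sort(key=lambda r: (r["price"] is not None, r["price"] or 0.0))
        for i, row in enumerate(members):
            # Frame: up to `before` rows before and `after` rows after this one.
            frame = members[max(0, i - before): i + after + 1]
            # mean() skips nulls; it is null if every price in the frame is null.
            prices = [r["price"] for r in frame if r["price"] is not None]
            moving_avg = sum(prices) / len(prices) if prices else None
            result[row["id"]] = moving_avg if row["price"] is None else row["price"]
    return result


# Hypothetical data: one unpriced D-color diamond, two priced D-color diamonds,
# and one E-color diamond with the same cut and clarity.
diamonds = [
    {"id": "unpriced", "cut": "Good", "clarity": "VVS1", "color": "D", "price": None},
    {"id": "d-1", "cut": "Good", "clarity": "VVS1", "color": "D", "price": 3333.0},
    {"id": "d-2", "cut": "Good", "clarity": "VVS1", "color": "D", "price": 3333.0},
    {"id": "e-1", "cut": "Good", "clarity": "VVS1", "color": "E", "price": 1233.0},
]

print(fill_null_prices(diamonds, ["cut", "clarity"])["unpriced"])           # 2633.0
print(fill_null_prices(diamonds, ["cut", "clarity", "color"])["unpriced"])  # 3333.0
```

Partitioning only by cut and clarity lets the E-color diamond pull the average down; adding `color` to the keys keeps it out. Because nulls sort first and the frame reaches only three rows ahead, an unpriced diamond in this model gets the mean of at most the three cheapest priced diamonds in its partition, not necessarily all of them.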

# STEP 0.C: Run The Test

In [None]:
%%ipytest -qq
from pyspark.sql import SparkSession

def test_will_do_the_right_thing(spark: SparkSession):
    diamonds_df = create_df_from_json("tests/fixtures/diamonds.json", spark)

    actual_df = transform(diamonds_df)
    assert data_frame_to_json(actual_df) == expected_json("expected.json")

.                                                                                            [100%]


# Step 1: Setup For the Saff Squeeze

## Instructions
Let's get ready to improve the test.

**Our bug**: Diamonds of the same cut and clarity are influencing the calculated price of diamonds with a different color. Only diamonds with the same cut, clarity _**and color**_ should influence the calculated price for diamonds with a null price.

**Expected behaviour**:
An unpriced diamond with cut=Good, color=D and clarity=VVS1, in a dataset where the other diamonds with the same cut, clarity and color are all priced at 3333.0, should have its price set to the average price of those diamonds: 3333.0.



- [ ] Change the test to check for the behaviour we want.
There is a second JSON file (`expected_correct.json`) where the expected price for the unpriced diamond has been updated to the correct value. Use that file name as the argument passed to `expected_json`.
- [ ] Run the test. It should fail.
- [ ] Duplicate the test. Now you should have two copies of the same test.
One copy will stay the same, so we can make sure that nothing is broken. The second copy is what we will change in the next steps.
- [ ] Rename the test.
Pick a name that tells you what behaviour you are verifying with the test you are using for the Saff Squeeze. I chose `test_null_price_is_replaced_based_on_cut_clarity_and_color`, but names are hard. You might be able to think of a better one :-)
- [ ] Run the tests. They should both fail because our diamond is being returned with a price of 2460.0, when we expect 3333.0

## Exercise

In [None]:
%%ipytest -qq
from pyspark.sql import SparkSession

def test_will_do_the_right_thing(spark: SparkSession):
    diamonds_df = create_df_from_json("tests/fixtures/diamonds.json", spark)

    actual_df = transform(diamonds_df)
    assert data_frame_to_json(actual_df) == expected_json("expected.json")

.                                                                                            [100%]


# Step 2: Make The Assert Specific

Right now, our test compares everything in the output dataframe to everything in a large json file. That's a lot of rows to compare and the assert is wrong anyways!

Let's make this test assert on the thing we actually care about - the output price of the diamond!

## Instructions - Chispa

- [ ] Add these imports to the top of the cell, below the `%%ipytest -qq` line:  
`from chispa import assert_column_equality`  
`from pyspark.sql.functions import lit`
- [ ] Filter the dataframe for the unique id of the diamond we care about:  
`actual_df=actual_df.filter(actual_df.id == 'DI-26-null-price')`
- [ ] Create a new column in our dataframe that contains our expected price:  
`actual_df=actual_df.withColumn('expected_price', lit(3333.0))`
- [ ] Assert the value in the price column matches the value we want:  
`assert_column_equality(actual_df, 'price', 'expected_price')`
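
One thing to watch with this approach: if the `filter` matches no rows (say, because of a typo in the id), there is nothing to compare. Assuming chispa's `assert_column_equality` compares the collected values of the two columns as lists, a zero-row DataFrame gives it nothing to disagree about, so the assert passes. The plain-Python model below (not chispa's source) shows the effect and a hypothetical guard that checks the row count first, which is what Step 4 does with `assert actual_df.count() == 1`.

```python
from typing import Dict, List


def columns_equal(rows: List[Dict], col1: str, col2: str) -> bool:
    """Compare two columns row by row, like a column-equality assertion."""
    return [row[col1] for row in rows] == [row[col2] for row in rows]


def assert_single_row_columns_equal(rows: List[Dict], col1: str, col2: str) -> None:
    """Hypothetical guard: fail loudly if the filter matched nothing (or too much)."""
    assert len(rows) == 1, f"expected exactly one row, got {len(rows)}"
    assert columns_equal(rows, col1, col2), f"{col1} != {col2}"


# Hypothetical rows after withColumn('expected_price', lit(3333.0)).
rows = [{"id": "DI-26-null-price", "price": 2460.0, "expected_price": 3333.0}]

matched = [r for r in rows if r["id"] == "DI-26-null-price"]
mistyped = [r for r in rows if r["id"] == "DI-26-nul-price"]  # hypothetical typo in the id

print(columns_equal(matched, "price", "expected_price"))   # False: the bug is caught
print(columns_equal(mistyped, "price", "expected_price"))  # True: nothing was compared
```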


## Exercise - Chispa

In [None]:
%%ipytest -qq
from pyspark.sql import SparkSession


def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = create_df_from_json("tests/fixtures/diamonds.json", spark)

    actual_df = transform(diamonds_df)


.                                                                                            [100%]


## Instructions - Pandas
- [ ] Import pandas:  
`import pandas as pd`
- [ ] Filter the dataframe for the unique id of the diamond we care about:  
  `actual_df=actual_df.filter(actual_df.id == 'DI-26-null-price')`
- [ ] Create your expected dataframe using Pandas:  
 `expected = pd.DataFrame({'id': ["DI-26-null-price"], 'price': [3333.0]})`
- [ ] Select the columns you care about and convert the result to a pandas dataframe:  
  `actual_pd = actual_df.select(['id', 'price']).toPandas()`
- [ ] Assert for dataframe equality using pandas:  
  `pd.testing.assert_frame_equal(actual_pd, expected)`

## Exercise - Pandas

In [None]:
%%ipytest -qq
from pyspark.sql import SparkSession


def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = create_df_from_json("tests/fixtures/diamonds.json", spark)

    actual_df = transform(diamonds_df)


.                                                                                            [100%]


## Instructions - JSON properties

- [ ] Filter the dataframe for the unique id of the diamond we care about:  
  `actual_df=actual_df.filter(actual_df.id == 'DI-26-null-price')`
- [ ] Convert your dataframe to JSON:  
`actual_df_json = data_frame_to_json(actual_df)`
- [ ] Assert the price property of the first object matches your expected price:  
`assert actual_df_json[0]['price'] == 3333.0`
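
If the bug ever leaves `price` null, this assert may fail with a `KeyError` instead of a readable comparison: by default, Spark 3.x leaves null fields out of its JSON output (see `spark.sql.jsonGenerator.ignoreNullFields`), so the `'price'` key may simply be missing. A plain-Python sketch of a slightly more defensive check, using a hypothetical helper and hypothetical parsed rows:

```python
from typing import Dict, List, Optional


def price_of(rows: List[Dict], diamond_id: str) -> Optional[float]:
    """Hypothetical helper: price of the single row with diamond_id, or None if the key is missing."""
    matches = [row for row in rows if row.get("id") == diamond_id]
    assert len(matches) == 1, f"expected one row for {diamond_id}, got {len(matches)}"
    # .get() turns a missing key (a null left out of the JSON) into None.
    return matches[0].get("price")


# Hypothetical parsed output in which the null price was omitted from the JSON.
rows = [{"id": "DI-26-null-price", "color": "D"}]

print(price_of(rows, "DI-26-null-price"))  # None
```

With this, `assert price_of(actual_df_json, 'DI-26-null-price') == 3333.0` reports a failed comparison against `None` rather than a `KeyError`, and it also fails clearly if the filter matched no rows.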

## Exercise - JSON properties

In [None]:
%%ipytest -qq
from pyspark.sql import SparkSession


def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = create_df_from_json("tests/fixtures/diamonds.json", spark)

    actual_df = transform(diamonds_df)



.                                                                                            [100%]


# Step 3: Reduce Duplicate Coverage and Fix the Bug

Right now, our test runs the entire `transform` function. Because `diamonds.json` holds the data for several test scenarios, every test that reads it runs the same large block of code over and over again.

## STEP 3.A: Prep

### Instructions
Let's get ready to reduce the duplicate coverage.

#### 1. Put the transform function where you can work with it
- [ ] Run the test. It should be failing.
- [ ] Replace the call to the transform function with the body of that function.
- [ ] Change the last line of the function body to assign to `actual_df` instead of `df`
```
df = build_indep_vars(df, ['carat', 'clarity', 'color'],
                                      categorical_vars=['clarity', 'color'],
                                      keep_intermediate=False,
                                      summarizer=True)
```
becomes
```
actual_df = build_indep_vars(df, ['carat', 'clarity', 'color'],
                                      categorical_vars=['clarity', 'color'],
                                      keep_intermediate=False,
                                      summarizer=True)
```
- [ ] Change the first line of the function body to read from `diamonds_df` instead of `df`
```
 df = df.withColumn('lprice', log('price'))
```
becomes
```
 df = diamonds_df.withColumn('lprice', log('price'))
```
- [ ] Run the test. It should still be failing for the same reasons.

#### 2. Test FOR the bug
- [ ] Change your assert code so that it is testing _for_ the bug.
```
actual_df=actual_df.withColumn('expected_price', lit(3333.0))
```
becomes
```
actual_df=actual_df.withColumn('expected_price', lit(2460.0))
```
- [ ] Run the test. It should pass.

#### 3. Make the assert easier to work with
- [ ] Extract your assert code into a helper function, so the assert in your test becomes a single line:
```
def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'DI-26-null-price')
    actual_df=actual_df.withColumn('expected_price', lit(2460.0))
    assert_column_equality(actual_df, 'price', 'expected_price')
```
- [ ] Run the test. It should pass.

### Exercise

#### The Code

##### The `transform` Function

In [None]:
from pyspark.sql import DataFrame, Window, Column
from pyspark.sql.functions import log, when, mean, col

def transform(df: DataFrame) -> DataFrame:
# The body of the transform function STARTS HERE
    df = df.withColumn('lprice', log('price'))  #<-- In the test, this line should be: df = diamonds_df.withColumn('lprice', log('price'))
    window = Window.partitionBy('cut', 'clarity').orderBy('price').rowsBetween(-3, 3)
    moving_avg = mean(df['price']).over(window)
    df = df.withColumn('moving_avg', moving_avg)

    df = df.withColumn('price', when(df.price.isNull(), df.moving_avg).otherwise(df.price))
    df = df[['id', 'carat', 'clarity', 'color', 'price']]
    df = build_indep_vars(df, ['carat', 'clarity', 'color'], #<-- In the test, this line should be: actual_df = build_indep_vars(df, ['carat', 'clarity', 'color']
                                      categorical_vars=['clarity', 'color'],
                                      keep_intermediate=False,
                                      summarizer=True)
# The body of the transform function ENDS HERE
    return df

#### The Test

In [None]:
%%ipytest -qq
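from pyspark.sql import SparkSession
from pyspark.sql.functions import lit
from chispa import assert_column_equality

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):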
    diamonds_df = create_df_from_json("tests/fixtures/diamonds.json", spark)

    actual_df = transform(diamonds_df) #<-- We will be replacing this line with the body of the transform function

    actual_df=actual_df.filter(actual_df.id == 'DI-26-null-price')
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

.                                                                                            [100%]


#### The original test



In [None]:
%%ipytest -qq
from pyspark.sql import SparkSession

def test_will_do_the_right_thing(spark: SparkSession):
    diamonds_df = create_df_from_json("tests/fixtures/diamonds.json", spark)

    actual_df = transform(diamonds_df)
    assert data_frame_to_json(actual_df) == expected_json("expected_correct.json")

F                                                                                            [100%]
___________________________________ test_will_do_the_right_thing ___________________________________

spark = <pyspark.sql.session.SparkSession object at 0x7f23801641c0>

    def test_will_do_the_right_thing(spark: SparkSession):
        diamonds_df = create_df_from_json("tests/fixtures/diamonds.json", spark)
    
        actual_df = transform(diamonds_df)
>       assert data_frame_to_json(actual_df) == expected_json("expected_correct.json")
E       AssertionError: assert [{'carat': 0....G', ...}, ...] == [{'carat': 0....G', ...}, ...]
E         At index 3 diff: {'id': 'DI-26-null-price', 'carat': 0.21, 'clarity': '

## STEP 3.B: Squeeze the Bottom!



### Instructions
**our bug**: Diamonds of the same cut and clarity are influencing the calculated price of diamonds with a different color. Only diamonds with the same cut, clarity and color should be influencing the calculated price for diamonds with a null price.

**Squeeze the bottom until you find the bug**
- [ ] Move your assert up one line at a time.  
- [ ] After each move, run your test.  
- [ ] If it fails, figure out why it's failing. (You may need to rename columns in the assert.)
- [ ] If the test passes, the line wasn't important for the bug you wanted to catch. Delete it.
- [ ] Continue until you find the source of the bug.
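
Squeezing the bottom is a search over prefixes of the inlined code: keep cutting lines off the end while the test still catches the bug, and the last line you can't remove is where the behaviour comes from. A plain-Python sketch of that loop, with hypothetical steps standing in for the lines of the test and a predicate standing in for the assert:

```python
from typing import Callable, Dict, List, Tuple

Step = Tuple[str, Callable[[Dict], Dict]]


def run(steps: List[Step], state: Dict) -> Dict:
    """Run each step in order, like executing the inlined lines of the test."""
    for _name, step in steps:
        state = step(state)
    return state


def squeeze_bottom(steps: List[Step], state: Dict,
                   shows_bug: Callable[[Dict], bool]) -> List[Step]:
    """Drop steps from the end for as long as the shorter run still shows the bug."""
    assert shows_bug(run(steps, state)), "start from a test that catches the bug"
    while steps and shows_bug(run(steps[:-1], state)):
        steps = steps[:-1]
    return steps


# Hypothetical stand-ins for the lines in the test; each returns a new dict.
steps: List[Step] = [
    ("add_lprice", lambda s: {**s, "lprice": None}),       # log(null) is null
    ("moving_avg", lambda s: {**s, "moving_avg": 2460.0}),  # the window result we suspect
    ("fill_price", lambda s: {**s, "price": s["moving_avg"] if s["price"] is None else s["price"]}),
    ("select", lambda s: {"id": s["id"], "price": s["price"]}),
]

start = {"id": "DI-26-null-price", "price": None}


def shows_bug(s: Dict) -> bool:
    # Like renaming the column in the assert: check 'price' or, earlier on, 'moving_avg'.
    return 2460.0 in (s.get("price"), s.get("moving_avg"))


remaining = squeeze_bottom(steps, start, shows_bug)
print([name for name, _ in remaining])  # ['add_lprice', 'moving_avg']
```

The last surviving step is the stand-in for the window calculation. Steps above it, like `add_lprice`, are dealt with separately when you simplify the top in Step 4.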

### Exercise

#### The Code

##### The `transform` Function

In [None]:
from pyspark.sql import DataFrame, Window, Column
from pyspark.sql.functions import log, when, mean, col

def transform(df: DataFrame) -> DataFrame:

    df = df.withColumn('lprice', log('price'))
    window = Window.partitionBy('cut', 'clarity').orderBy('price').rowsBetween(-3, 3)
    moving_avg = mean(df['price']).over(window)
    df = df.withColumn('moving_avg', moving_avg)

    df = df.withColumn('price', when(df.price.isNull(), df.moving_avg).otherwise(df.price))
    df = df[['id', 'carat', 'clarity', 'color', 'price']]
    df = build_indep_vars(df, ['carat', 'clarity', 'color'],
                                      categorical_vars=['clarity', 'color'],
                                      keep_intermediate=False,
                                      summarizer=True)
    return df

#### The Test

In [None]:
%%ipytest -qq
from pyspark.sql import SparkSession, DataFrame, Window, Column
from pyspark.sql.functions import lit, log, when, mean, col
from chispa import assert_column_equality

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'DI-26-null-price')
    actual_df=actual_df.withColumn('expected_price', lit(2460.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = create_df_from_json("tests/fixtures/diamonds.json", spark)

    df = diamonds_df.withColumn('lprice', log('price'))
    window = Window.partitionBy('cut', 'clarity').orderBy('price').rowsBetween(-3, 3)
    moving_avg = mean(df['price']).over(window)
    df = df.withColumn('moving_avg', moving_avg)

    df = df.withColumn('price', when(df.price.isNull(), df.moving_avg).otherwise(df.price))
    df = df[['id', 'carat', 'clarity', 'color', 'price']]
    actual_df = build_indep_vars(df, ['carat', 'clarity', 'color'],
                                      categorical_vars=['clarity', 'color'],
                                      keep_intermediate=False,
                                      summarizer=True)

    assert_diamond_has_expected_price(actual_df)

F                                                                                            [100%]
____________________ test_null_price_is_replaced_based_on_cut_clarity_and_color ____________________

spark = <pyspark.sql.session.SparkSession object at 0x7f238031aa70>

    def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
        diamonds_df = create_df_from_json("tests/fixtures/diamonds.json", spark)
    
        df = diamonds_df.withColumn('lprice', log('price'))
        window = Window.partitionBy('cut', 'clarity').orderBy('price').rowsBetw

#### The Original Test

In [None]:
%%ipytest -qq
from pyspark.sql import SparkSession

def test_will_do_the_right_thing(spark: SparkSession):
    diamonds_df = create_df_from_json("tests/fixtures/diamonds.json", spark)

    actual_df = transform(diamonds_df)
    assert data_frame_to_json(actual_df) == expected_json("expected_correct.json")

F                                                                                            [100%]
___________________________________ test_will_do_the_right_thing ___________________________________

spark = <pyspark.sql.session.SparkSession object at 0x7f23803e5d80>

    def test_will_do_the_right_thing(spark: SparkSession):
        diamonds_df = create_df_from_json("tests/fixtures/diamonds.json", spark)
    
        actual_df = transform(diamonds_df)
>       assert data_frame_to_json(actual_df) == expected_json("expected_correct.json")
E       AssertionError: assert [{'carat': 0....G', ...}, ...] == [{'carat': 0....G', ...}, ...]
E         At index 3 diff: {'id': 'DI-26-null-price', 'carat': 0.21, 'clarity': '

## STEP 3.C: Let's Fix the Bug!


### Instructions

**Our Bug**: Diamonds of the same cut and clarity are influencing the calculated price of diamonds with a different color. Only diamonds with the same cut, clarity and color should be influencing the calculated price for diamonds with a null price.

**Our Desired Behaviour**:  
The input diamond with these properties:
- id: `"DI-26-null-price"`
- cut: `"Good"`
- color: `"D"`
- clarity: `"VVS1"`
- price: `null`

Should be output with a price of `3333.0` - the average price of the other diamonds with `cut="Good"`, `clarity="VVS1"` and `color="D"`

**Fix the bug**
#### Test for the behaviour you actually want
- [ ] Update your test so that it checks for the good behaviour.  
Replace the expected price on this line with `3333.0`:  
```
actual_df=actual_df.withColumn('expected_price', lit(2460.0))
```
- [ ] Run your test. It should fail with a `ColumnsNotEqualError`:  
```
E           chispa.column_comparer.ColumnsNotEqualError:
E           +------------+----------------+
E           | moving_avg | expected_price |
E           +------------+----------------+
E           |   2460.0   |     3333.0     |
E           +------------+----------------+
```

#### Fix the bug
- [ ] Fix the code _in your test_ so that the bug is gone. We need to add `'color'` to `'cut'` and `'clarity'` on this line:
```
window = Window.partitionBy('cut', 'clarity').orderBy('price').rowsBetween(-3, 3)
```
- [ ] Run your test. It should pass.

#### Encapsulate the code necessary for the behaviour
- [ ] The behaviour belongs to a group of lines working together. Extract them into a method.  
These lines can't be separated without changing the behaviour we are testing:  
```
    window = Window.partitionBy('cut', 'clarity', 'color').orderBy('price').rowsBetween(-3, 3)
    moving_avg = mean(df['price']).over(window)
    df = df.withColumn('moving_avg', moving_avg)
    df = df.withColumn('price', when(df.price.isNull(), df.moving_avg).otherwise(df.price))
```  
Use them to make a method:
```
def calculate_avg_price_for_similar_diamonds(df: DataFrame) -> DataFrame:
    window = Window.partitionBy('cut', 'clarity', 'color').orderBy('price').rowsBetween(-3, 3)
    moving_avg = mean(df['price']).over(window)
    df = df.withColumn('moving_avg', moving_avg)
    df = df.withColumn('price', when(df.price.isNull(), df.moving_avg).otherwise(df.price))
    return df
```

Replace those lines with a call to the new method in your test:  
```
df = calculate_avg_price_for_similar_diamonds(df)
```
- [ ] Run your test. It should pass.

#### Move the encapsulated behaviour to the actual code
- [ ] Move the new method out of your test and into the transform code.
- [ ] Replace the lines in your transform code with the new method call.
- [ ] We've changed the original code, so we need to check that everything still works the way we expect. Run the copy of your original large test. It should also pass.
- [ ] Run your test. It should pass.


### Exercise

#### The Code

##### The `transform` Function

In [None]:
from pyspark.sql import DataFrame, Window, Column
from pyspark.sql.functions import log, when, mean, col


def transform(df: DataFrame) -> DataFrame:

    df = df.withColumn('lprice', log('price'))
    #THIS IS WHERE THE BEHAVIOUR WE CARE ABOUT STARTS - REPLACE THIS BLOCK
    window = Window.partitionBy('cut', 'clarity').orderBy('price').rowsBetween(-3, 3)
    moving_avg = mean(df['price']).over(window)
    df = df.withColumn('moving_avg', moving_avg)

    df = df.withColumn('price', when(df.price.isNull(), df.moving_avg).otherwise(df.price))
    #THIS IS WHERE THE BEHAVIOUR WE CARE ABOUT ENDS - REPLACE THIS BLOCK
    df = df[['id', 'carat', 'clarity', 'color', 'price']]
    df = build_indep_vars(df, ['carat', 'clarity', 'color'],
                                      categorical_vars=['clarity', 'color'],
                                      keep_intermediate=False,
                                      summarizer=True)
    return df

#### The Test

In [None]:
%%ipytest -qq
from pyspark.sql import SparkSession, DataFrame, Window, Column
from pyspark.sql.functions import lit, log, when, mean, col
from chispa import assert_column_equality

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'DI-26-null-price')
    actual_df=actual_df.withColumn('expected_price', lit(2460.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = create_df_from_json("tests/fixtures/diamonds.json", spark)

    df = diamonds_df.withColumn('lprice', log('price'))

    #THIS IS WHERE THE BEHAVIOUR WE CARE ABOUT STARTS - EXTRACT THIS BLOCK
    window = Window.partitionBy('cut', 'clarity').orderBy('price').rowsBetween(-3, 3)
    moving_avg = mean(df['price']).over(window)
    df = df.withColumn('moving_avg', moving_avg)
    df = df.withColumn('price', when(df.price.isNull(), df.moving_avg).otherwise(df.price))
    #THIS IS WHERE THE BEHAVIOUR WE CARE ABOUT ENDS

    assert_diamond_has_expected_price(df)

.                                                                                            [100%]


#### The Original Test

In [None]:
%%ipytest -qq
from pyspark.sql import SparkSession

def test_will_do_the_right_thing(spark: SparkSession):
    diamonds_df = create_df_from_json("tests/fixtures/diamonds.json", spark)

    actual_df = transform(diamonds_df)
    assert data_frame_to_json(actual_df) == expected_json("expected_correct.json")

F                                                                                            [100%]
___________________________________ test_will_do_the_right_thing ___________________________________

spark = <pyspark.sql.session.SparkSession object at 0x7f238031bdc0>

    def test_will_do_the_right_thing(spark: SparkSession):
        diamonds_df = create_df_from_json("tests/fixtures/diamonds.json", spark)
    
        actual_df = transform(diamonds_df)
>       assert data_frame_to_json(actual_df) == expected_json("expected_correct.json")
E       AssertionError: assert [{'carat': 0....G', ...}, ...] == [{'carat': 0....G', ...}, ...]
E         At index 3 diff: {'id': 'DI-26-null-price', 'carat': 0.21, 'clarity': '

# Step 4: Simplify the Top

We don't want our test to be reading in the entire diamonds.json file in order to test this single behaviour. We are going to simplify things so that we have the absolute minimum number of inputs we need to reproduce the behaviour.
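
Before trimming the input, it helps to pin down what the smallest dataset that still shows the behaviour looks like. It needs the unpriced diamond, at least one diamond with the same cut, clarity and color (so the fixed window has a price to average), and at least one diamond with the same cut and clarity but a different color (without it, the buggy and fixed windows give the same answer). The plain-Python sketch below mirrors `check_we_have_all_the_rows_we_need_for_the_behaviour` from the next cell, which checks only for the different-color row, and adds the same-color check. The helper name and rows are hypothetical.

```python
from typing import Dict, List

NULL_ID = "DI-26-null-price"


def has_rows_needed_for_behaviour(rows: List[Dict], null_id: str = NULL_ID) -> bool:
    """True if rows can tell a cut+clarity window apart from a cut+clarity+color window."""
    target = next(row for row in rows if row["id"] == null_id)
    same_cut_clarity = [row for row in rows
                        if row["id"] != null_id
                        and row["cut"] == target["cut"]
                        and row["clarity"] == target["clarity"]]
    same_color = [row for row in same_cut_clarity if row["color"] == target["color"]]
    other_color = [row for row in same_cut_clarity if row["color"] != target["color"]]
    return len(same_color) >= 1 and len(other_color) >= 1


# Hypothetical minimal input: three rows instead of the whole diamonds.json file.
minimal = [
    {"id": NULL_ID, "cut": "Good", "clarity": "VVS1", "color": "D", "price": None},
    {"id": "same-color", "cut": "Good", "clarity": "VVS1", "color": "D", "price": 3333.0},
    {"id": "other-color", "cut": "Good", "clarity": "VVS1", "color": "E", "price": 1233.0},
]

print(has_rows_needed_for_behaviour(minimal))      # True
print(has_rows_needed_for_behaviour(minimal[:2]))  # False: no row with a different color
```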

In [48]:
%%ipytest -qq
from pyspark.sql import DataFrame, SparkSession, Window
from pyspark.sql.functions import col, lit, log, when, mean
from chispa import assert_column_equality

### FOR REFERENCE: THIS IS THE CODE WE ARE TESTING
def calculate_avg_price_for_similar_diamonds(df: DataFrame) -> DataFrame:
    window = Window.partitionBy('cut', 'clarity', 'color').orderBy('price').rowsBetween(-3, 3)
    moving_avg = mean(df['price']).over(window)
    df = df.withColumn('moving_avg', moving_avg)
    df = df.withColumn('price', when(df.price.isNull(), df.moving_avg).otherwise(df.price))
    return df

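def check_we_have_all_the_rows_we_need_for_the_behaviour(actual_df):
    # The bug only shows up if the input has at least one diamond (other than the
    # null-price one) with the same cut and clarity but a different color.
    actual_df = actual_df.select('id', 'cut', 'clarity', 'color')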
    # print("ACTUAL_DF")
    # actual_df.show()
    null_price_df = actual_df.filter(actual_df.id == 'DI-26-null-price')
    # print("NULL_PRICE_DF")
    # null_price_df.show()
    mismatched_properties_df = actual_df.crossJoin(null_price_df.select(
        col('id').alias('null_id'),
        col('cut').alias('null_cut'),
        col('clarity').alias('null_clarity'),
        col('color').alias('null_color')
    ))

    # print("MISMATCHED_PROPERTIES_DF")
    # mismatched_properties_df.show()
    mismatched_properties_df = mismatched_properties_df.filter(mismatched_properties_df.null_id != mismatched_properties_df.id)
    # print("MISMATCHED_PROPERTIES_DF - WITHOUT THE NULL PRICE ROW")
    # mismatched_properties_df.show()
    mismatched_properties_df = mismatched_properties_df.filter(mismatched_properties_df.null_clarity == mismatched_properties_df.clarity)
    # print("MISMATCHED_PROPERTIES_DF - WITH MATCHING CLARITY")
    # mismatched_properties_df.show()
    mismatched_properties_df = mismatched_properties_df.filter(mismatched_properties_df.null_cut == mismatched_properties_df.cut)
    # print("MISMATCHED_PROPERTIES_DF - WITH MATCHING CUT")
    # mismatched_properties_df.show()
    mismatched_properties_df = mismatched_properties_df.filter(mismatched_properties_df.null_color != mismatched_properties_df.color)
    # print("MISMATCHED_PROPERTIES_DF - WITH MISMATCHED COLOR ROWS")
    # mismatched_properties_df.show()
    assert mismatched_properties_df.count() >= 1
    # assert False


def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    check_we_have_all_the_rows_we_need_for_the_behaviour(actual_df)
    actual_df=actual_df.filter(actual_df.id == 'DI-26-null-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = create_df_from_json("tests/fixtures/diamonds.json", spark)

    actual_df = diamonds_df.withColumn('lprice', log('price'))  # <-- Delete this line
    actual_df = calculate_avg_price_for_similar_diamonds(actual_df) # <-- Change this line to: actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)


[32m.[0m[32m                                                                                            [100%][0m



#### 4.A Remove lines that aren't related to the behaviour
- [ ] We can see that we only care about the price column, and this line makes a column called `'lprice'`. It should be safe to delete:
```
actual_df = diamonds_df.withColumn('lprice', log('price'))
```
- [ ] Change the argument we pass to `calculate_avg_price_for_similar_diamonds` from `actual_df` to `diamonds_df`
- [ ] Run your test. It should pass.
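How can you tell a line is unrelated to the behaviour? Check that it never touches the column the assertion reads. Here is a plain-Python sketch of that check. `with_log_price` is a hypothetical stand-in for the `withColumn('lprice', log('price'))` step, not the Spark API itself:

```python
import math
from typing import Dict, List, Optional


def with_log_price(rows: List[Dict]) -> List[Dict]:
    """Hypothetical stand-in for withColumn('lprice', log('price')): adds a key, never edits 'price'."""
    def safe_log(price: Optional[float]) -> Optional[float]:
        # Missing or non-positive prices give a missing log price
        return math.log(price) if price is not None and price > 0 else None

    return [{**row, "lprice": safe_log(row["price"])} for row in rows]


rows = [
    {"id": "DI-26-null-price", "price": None},
    {"id": "DI-26", "price": 3333.0},
]

before = [row["price"] for row in rows]
after = [row["price"] for row in with_log_price(rows)]
print(before == after)  # True
```

`calculate_avg_price_for_similar_diamonds` only reads `price`, `cut`, `clarity` and `color`, so an extra `lprice` column cannot change its result. Re-running the test after the deletion confirms this.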



In [49]:
%%ipytest -qq
from pyspark.sql import DataFrame, Window
from pyspark.sql.functions import col, lit, log, when, mean
from chispa import assert_column_equality

### FOR REFERENCE: THIS IS THE CODE WE ARE TESTING
def calculate_avg_price_for_similar_diamonds(df: DataFrame) -> DataFrame:
    window = Window.partitionBy('cut', 'clarity', 'color').orderBy('price').rowsBetween(-3, 3)
    moving_avg = mean(df['price']).over(window)
    df = df.withColumn('moving_avg', moving_avg)
    df = df.withColumn('price', when(df.price.isNull(), df.moving_avg).otherwise(df.price))
    return df

def check_we_have_all_the_rows_we_need_for_the_behaviour(actual_df):
    actual_df = actual_df.select('id', 'cut', 'clarity', 'color')
    null_price_df = actual_df.filter(actual_df.id == 'DI-26-null-price')
    mismatched_properties_df = actual_df.crossJoin(null_price_df.select(
        col('id').alias('null_id'),
        col('cut').alias('null_cut'),
        col('clarity').alias('null_clarity'),
        col('color').alias('null_color')
    ))

    mismatched_properties_df = mismatched_properties_df.filter(mismatched_properties_df.null_id != mismatched_properties_df.id)
    mismatched_properties_df = mismatched_properties_df.filter(mismatched_properties_df.null_clarity == mismatched_properties_df.clarity)
    mismatched_properties_df = mismatched_properties_df.filter(mismatched_properties_df.null_cut == mismatched_properties_df.cut)
    mismatched_properties_df = mismatched_properties_df.filter(mismatched_properties_df.null_color != mismatched_properties_df.color)
    assert mismatched_properties_df.count() >= 1

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    check_we_have_all_the_rows_we_need_for_the_behaviour(actual_df)
    actual_df=actual_df.filter(actual_df.id == 'DI-26-null-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = create_df_from_json("tests/fixtures/diamonds.json", spark)

    actual_df = diamonds_df.withColumn('lprice', log('price'))  # <-- Delete this line
    actual_df = calculate_avg_price_for_similar_diamonds(actual_df) # <-- Change this line to: actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)


[32m.[0m[32m                                                                                            [100%][0m


Instead of reading from a large JSON file, we can build a dataframe containing only the values we need to reproduce the behaviour.

**our bug**: Diamonds with the same cut and clarity are influencing the calculated price of diamonds with a different color. Only diamonds with the same cut, clarity and color should influence the calculated price for diamonds with a null price.

**expected behaviour**: Consider an unpriced diamond with cut=Good, color=D and clarity=VVS1, in a dataset where the other diamonds with the same cut, clarity and color are all priced at 3333.0. Its price will be set to the average price of those diamonds (3333.0). Prices from diamonds with a different color, cut or clarity will be ignored.
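To see where 3333.0 comes from, here is a plain-Python model of `calculate_avg_price_for_similar_diamonds`. It is a sketch, not Spark, and it makes these assumptions:

- Rows are dicts, grouped by the partition keys.
- Each group is ordered by price with nulls first, which is Spark's default for an ascending sort.
- The model averages the non-null prices from 3 rows before to 3 rows after, like `rowsBetween(-3, 3)`.

`fill_null_prices` and its `keys` parameter are hypothetical names, used here to compare the fixed partitioning with the buggy one. Spark does not guarantee an order for tied rows, while this model keeps input order.

```python
from collections import defaultdict
from statistics import mean
from typing import Dict, List, Optional, Sequence, Tuple


def fill_null_prices(
    rows: List[Dict],
    keys: Sequence[str] = ("cut", "clarity", "color"),
    before: int = 3,
    after: int = 3,
) -> Dict[str, Optional[float]]:
    """Return {id: price}, replacing each null price with a windowed average of its partition."""
    partitions: Dict[Tuple, List[Dict]] = defaultdict(list)
    for row in rows:
        partitions[tuple(row[k] for k in keys)].append(row)

    prices: Dict[str, Optional[float]] = {}
    for group in partitions.values():
        # Ascending by price with nulls first; sorted() keeps input order for ties
        ordered = sorted(
            group,
            key=lambda r: (r["price"] is not None, r["price"] if r["price"] is not None else 0.0),
        )
        for i, row in enumerate(ordered):
            frame = ordered[max(0, i - before): i + after + 1]  # like rowsBetween(-3, 3)
            known = [r["price"] for r in frame if r["price"] is not None]  # Spark's mean() ignores nulls
            moving_avg = mean(known) if known else None
            prices[row["id"]] = moving_avg if row["price"] is None else row["price"]
    return prices


rows = [
    {"id": "no-price", "price": None, "color": "D", "cut": "Good", "clarity": "VVS1"},
    {"id": "with-price", "price": 3333.0, "color": "D", "cut": "Good", "clarity": "VVS1"},
    {"id": "with-price-wrong-color", "price": 2000.0, "color": "G", "cut": "Good", "clarity": "VVS1"},
    {"id": "with-price-wrong-clarity", "price": 2000.0, "color": "D", "cut": "Good", "clarity": "S1"},
    {"id": "with-price-wrong-cut", "price": 2000.0, "color": "D", "cut": "Very Good", "clarity": "VVS1"},
]

print(fill_null_prices(rows)["no-price"])                           # 3333.0
print(fill_null_prices(rows, keys=("cut", "clarity"))["no-price"])  # 2666.5
```

The buggy version partitions on cut and clarity only. That lets the G-color diamond pull the average down to (3333.0 + 2000.0) / 2 = 2666.5.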

#### 4.B Create the minimum inputs needed to reproduce the bug
- [ ] Delete the line that reads from diamonds.json
- [ ] Replace it with a new Spark DataFrame containing the minimum inputs needed to test the behaviour
- [ ] Change the id that we filter on in the assert statement to match the id of the diamond without a price in the new Spark DataFrame.
- [ ] Run your test. It should pass.
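One way to think about "minimum inputs" is this:

- Start from the unpriced diamond.
- Add one row that matches it on every partition key.
- Add one distractor per key that differs only in that key.

The sketch below uses a hypothetical plain-Python helper, `one_key_different`, to build those rows. They have the same dict shape the notebook passes to `spark.createDataFrame`, and the values match the ones used in the Step 5 cells.

```python
from typing import Dict, List

KEYS = ("cut", "clarity", "color")

unpriced = {"id": "no-price", "price": None, "color": "D", "cut": "Good", "clarity": "VVS1"}


def one_key_different(base: Dict, key: str, value: str, price: float) -> Dict:
    """Hypothetical helper: copy base, change a single partition key, and give the copy a price."""
    return {**base, "id": f"with-price-wrong-{key}", key: value, "price": price}


rows: List[Dict] = [
    unpriced,
    {**unpriced, "id": "with-price", "price": 3333.0},
    one_key_different(unpriced, "color", "G", 2000.0),
    one_key_different(unpriced, "clarity", "S1", 2000.0),
    one_key_different(unpriced, "cut", "Very Good", 2000.0),
]

for row in rows[2:]:
    changed = [k for k in KEYS if row[k] != unpriced[k]]
    print(row["id"], changed)
# with-price-wrong-color ['color']
# with-price-wrong-clarity ['clarity']
# with-price-wrong-cut ['cut']
```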


In [12]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log
from chispa import assert_column_equality

############################################ THE NEW SPARK DATAFRAME ############################################
"""
diamonds_df = spark.createDataFrame([
  {
    "id": 1,
    "carat": 0.23,
    "cut": "Ideal",
    "color": "E",
    "clarity": "SI2",
    "depth": 61.5,
    "table": 55,
    "price": 326,
    "x": 3.95,
    "y": 3.98,
    "z": 2.43
  },
  {
    "id": "minimum_inputs",
    "carat": 0.23,
    "cut": "Good",
    "color": "F",
    "clarity": "SI1",
    "depth": null,
    "table": null,
    "price": null,
    "x": null,
    "y": null,
    "z": null
  },
  {
    "id": "DI-26-null-price",
    "carat": 0.21,
    "cut": "Good",
    "color": "D",
    "clarity": "VVS1",
    "depth": null,
    "table": null,
    "price": null,
    "x": null,
    "y": null,
    "z": null
  },
  {"id": "DI-26","carat": 0.21,"cut": "Good","color": "D","clarity": "VVS1","depth": null,"table": null,"price": 3333,"x": null,"y": null,"z": null},
  {"id": "DI-26","carat": 0.21,"cut": "Good","color": "D","clarity": "VVS1","depth": null,"table": null,"price": 3333,"x": null,"y": null,"z": null},
  {"id": "DI-27","carat": 0.21,"cut": "Very Good","color": "D","clarity": "VVS1","depth": null,"table": null,"price": 2692,"x": null,"y": null,"z": null},
  {"id": "DI-28","carat": 0.21,"cut": "Good","color": "G","clarity": "VVS1","depth": null,"table": null,"price": 1665,"x": null,"y": null,"z": null},
  {"id": "DI-30","carat": 0.32,"cut": "Good","color": "D","clarity": "I1","depth": 60.9,"table": 58,"price": 345,"x": 4.38,"y": 4.42,"z": 2.68}
]
)
"""
#################################################################################################################

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'DI-26-null-price')  #<- Change this line to: actual_df=actual_df.filter(actual_df.id == 'no-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = create_df_from_json("tests/fixtures/diamonds.json", spark) # <- Delete this line
    ############################ PASTE THE NEW SPARK DATAFRAME HERE ##############################

    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)

[32m.[0m[32m                                                                                            [100%][0m


In [15]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log
from chispa import assert_column_equality

from pathlib import Path

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'DI-26-null-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    json_text = Path("./tests/fixtures/diamonds.json").read_text()
    diamonds_df = spark.read.option("multiline", "true").json(spark.sparkContext.parallelize([json_text]))

    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)

[32m.[0m[32m                                                                                            [100%][0m


In [None]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log
from chispa import assert_column_equality

from pathlib import Path

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'DI-26-null-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    json_text = '''
      PASTE CONTENTS OF tests/fixtures/diamonds.json HERE
    '''
    json_text = Path("./tests/fixtures/diamonds.json").read_text()
    diamonds_df = spark.read.option("multiline", "true").json(spark.sparkContext.parallelize([json_text]))

    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)

[32m.[0m[32m                                                                                            [100%][0m


In [16]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log
from chispa import assert_column_equality

from pathlib import Path

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'DI-26-null-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    json_text = '''
[
  {
    "id": 1,
    "carat": 0.23,
    "cut": "Ideal",
    "color": "E",
    "clarity": "SI2",
    "depth": 61.5,
    "table": 55,
    "price": 326,
    "x": 3.95,
    "y": 3.98,
    "z": 2.43
  },
  {
    "id": "minimum_inputs",
    "carat": 0.23,
    "cut": "Good",
    "color": "F",
    "clarity": "SI1",
    "depth": null,
    "table": null,
    "price": null,
    "x": null,
    "y": null,
    "z": null
  },
  {
    "id": "DI-26-null-price",
    "carat": 0.21,
    "cut": "Good",
    "color": "D",
    "clarity": "VVS1",
    "depth": null,
    "table": null,
    "price": null,
    "x": null,
    "y": null,
    "z": null
  },
  {
    "id": "DI-26",
    "carat": 0.21,
    "cut": "Good",
    "color": "D",
    "clarity": "VVS1",
    "depth": null,
    "table": null,
    "price": 3333,
    "x": null,
    "y": null,
    "z": null
  },
  {
    "id": "DI-26",
    "carat": 0.21,
    "cut": "Good",
    "color": "D",
    "clarity": "VVS1",
    "depth": null,
    "table": null,
    "price": 3333,
    "x": null,
    "y": null,
    "z": null
  },
  {
    "id": "DI-27",
    "carat": 0.21,
    "cut": "Very Good",
    "color": "D",
    "clarity": "VVS1",
    "depth": null,
    "table": null,
    "price": 2692,
    "x": null,
    "y": null,
    "z": null
  },
  {
    "id": "DI-28",
    "carat": 0.21,
    "cut": "Good",
    "color": "G",
    "clarity": "VVS1",
    "depth": null,
    "table": null,
    "price": 1665,
    "x": null,
    "y": null,
    "z": null
  },
  {
    "id": "DI-30",
    "carat": 0.32,
    "cut": "Good",
    "color": "D",
    "clarity": "I1",
    "depth": 60.9,
    "table": 58,
    "price": 345,
    "x": 4.38,
    "y": 4.42,
    "z": 2.68
  }
]

    '''
    #json_text = Path("./tests/fixtures/diamonds.json").read_text()
    diamonds_df = spark.read.option("multiline", "true").json(spark.sparkContext.parallelize([json_text]))

    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)

[32m.[0m[32m                                                                                            [100%][0m


In [31]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log
from chispa import assert_column_equality

from pathlib import Path

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'DI-26-null-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    json_text = '''
[
  {
    "id": "DI-26-null-price",
    "carat": 0.21,
    "cut": "Good",
    "color": "D",
    "clarity": "VVS1",
    "depth": null,
    "table": null,
    "price": null,
    "x": null,
    "y": null,
    "z": null
  },
  {
    "id": "DI-26",
    "carat": 0.21,
    "cut": "Good",
    "color": "D",
    "clarity": "VVS1",
    "depth": null,
    "table": null,
    "price": 3333,
    "x": null,
    "y": null,
    "z": null
  }
]

    '''
    #json_text = Path("./tests/fixtures/diamonds.json").read_text()
    diamonds_df = spark.read.option("multiline", "true").json(spark.sparkContext.parallelize([json_text]))

    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)

[32m.[0m[32m                                                                                            [100%][0m


In [37]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log
from chispa import assert_column_equality

from pathlib import Path

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'DI-26-null-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    json_text = '''
[
  {
    "id": "DI-26-null-price",
    "cut": "Good",
    "color": "D",
    "clarity": "VVS1",
    "price": null
  },
  {
    "id": "DI-26",
    "cut": "Good",
    "color": "D",
    "clarity": "VVS1",
    "price": 3333
  }
]

    '''
    #json_text = Path("./tests/fixtures/diamonds.json").read_text()
    diamonds_df = spark.read.option("multiline", "true").json(spark.sparkContext.parallelize([json_text]))

    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)

[32m.[0m[32m                                                                                            [100%][0m


# Step 5: Clean up!

This test can still be cleaner and easier to read.

It passes, so it satisfies the first rule of simple design: the tests pass.

What about the other three rules: reveal intent, no duplication, and fewest elements?

### Step 5A: Reveal Intent - Extract the unpriced diamond to a method

- [ ] Look at the dict marked `THE UNPRICED DIAMOND` in the cell below. This is our unpriced diamond. In the next step, we will pull it out into a method.

In [55]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log
from chispa import assert_column_equality

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'no-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = spark.createDataFrame([
      ##### THE UNPRICED DIAMOND #####
      {"id": "no-price",                       "price": None,  "color": "D", "cut": "Good", "clarity": "VVS1"},
      ##### END OF THE UNPRICED DIAMOND #####
      {"id": "with-price",                     "price": 3333.0,"color": "D", "cut": "Good", "clarity": "VVS1"},
      {"id": "with-price-wrong-color",         "price": 2000.0,"color": "G", "cut": "Good", "clarity": "VVS1"},
      {"id": "with-price-wrong-clarity",       "price": 2000.0,"color": "D", "cut": "Good", "clarity": "S1"},
      {"id": "with-price-wrong-cut",           "price": 2000.0,"color": "D", "cut": "Very Good", "clarity": "VVS1"},
    ])

    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)


[32m.[0m[32m                                                                                            [100%][0m


### Step 5B: Reveal Intent - Extract the unpriced diamond to a method

- [ ] We've created a new method called `unpriced_diamond` that returns the dict representing the unpriced diamond. It is commented out.
- [ ] Uncomment it.
- [ ] Run the test. It should pass.

In [56]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log
from chispa import assert_column_equality

from typing import Dict

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'no-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

##### UNCOMMENT THIS METHOD
#def unpriced_diamond() -> Dict:
#  return {"id": "no-price", "price": None,  "color": "D", "cut": "Good", "clarity": "VVS1"}

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = spark.createDataFrame([
      ##### THE UNPRICED DIAMOND #####
      unpriced_diamond(),
      ##### END OF THE UNPRICED DIAMOND #####
      {"id": "with-price",                     "price": 3333.0,"color": "D", "cut": "Good", "clarity": "VVS1"},
      {"id": "with-price-wrong-color",         "price": 2000.0,"color": "G", "cut": "Good", "clarity": "VVS1"},
      {"id": "with-price-wrong-clarity",       "price": 2000.0,"color": "D", "cut": "Good", "clarity": "S1"},
      {"id": "with-price-wrong-cut",           "price": 2000.0,"color": "D", "cut": "Very Good", "clarity": "VVS1"},
    ])

    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)


[32m.[0m[32m                                                                                            [100%][0m


### Step 5C: Reveal Intent - Extract the matching diamond to a method

- [ ] Look at the second dict in the list passed to `spark.createDataFrame`. This is our matching diamond. In the next step, we will pull it out into a method.

In [57]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log
from chispa import assert_column_equality

from typing import Dict

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'no-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def unpriced_diamond() -> Dict:
  return {"id": "no-price", "price": None,  "color": "D", "cut": "Good", "clarity": "VVS1"}

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = spark.createDataFrame([
      unpriced_diamond(),
      ##### THE MATCHING DIAMOND #####
      {"id": "with-price",                     "price": 3333.0,"color": "D", "cut": "Good", "clarity": "VVS1"},
      ##### END OF THE MATCHING DIAMOND #####
      {"id": "with-price-wrong-color",         "price": 2000.0,"color": "G", "cut": "Good", "clarity": "VVS1"},
      {"id": "with-price-wrong-clarity",       "price": 2000.0,"color": "D", "cut": "Good", "clarity": "S1"},
      {"id": "with-price-wrong-cut",           "price": 2000.0,"color": "D", "cut": "Very Good", "clarity": "VVS1"},
    ])

    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)


[32m.[0m[32m                                                                                            [100%][0m


### Step 5D: Reveal Intent - Extract the matching diamond to a method

- [ ] We've created a new method called `matching_diamond` that returns the dict representing the matching diamond. It is commented out.
- [ ] Uncomment it.
- [ ] Notice how we've set up the method interface so it takes a price. This lets us indicate we care about the price of the matching diamond.
- [ ] The price of the matching diamond should be `3333.0`. Pass it into the method call.
- [ ] Run the test. It should pass.
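Passing the price in is one way to reveal intent. A common variation is a single builder with defaults that match the unpriced diamond, so each helper only spells out what makes it different. This is a hypothetical alternative, not what this notebook does, and `diamond` and `diamond_id` are illustrative names:

```python
from typing import Dict, Optional


def diamond(diamond_id: str, price: Optional[float] = None, color: str = "D",
            cut: str = "Good", clarity: str = "VVS1") -> Dict:
    """Hypothetical test-data builder; the defaults describe the diamond under test."""
    return {"id": diamond_id, "price": price, "color": color, "cut": cut, "clarity": clarity}


def unpriced_diamond() -> Dict:
    return diamond("no-price")


def matching_diamond(price: float) -> Dict:
    return diamond("with-price", price=price)


print(matching_diamond(price=3333.0))
# {'id': 'with-price', 'price': 3333.0, 'color': 'D', 'cut': 'Good', 'clarity': 'VVS1'}
```

The trade-off is that the defaults hide the shared attributes. Each helper stays short, but a reader has to look at `diamond` to see them.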

In [58]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log
from chispa import assert_column_equality

from typing import Dict

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'no-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def unpriced_diamond() -> Dict:
  return {"id": "no-price", "price": None,  "color": "D", "cut": "Good", "clarity": "VVS1"}

#def matching_diamond(price: float) -> Dict:
#  return {"id": "with-price", "price": price ,"color": "D", "cut": "Good", "clarity": "VVS1"}

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = spark.createDataFrame([
      unpriced_diamond(),
      ##### THE MATCHING DIAMOND #####
      matching_diamond(), # <- CHANGE THIS LINE TO: matching_diamond(price=3333.0),
      ##### END OF THE MATCHING DIAMOND #####
      {"id": "with-price-wrong-color",         "price": 2000.0,"color": "G", "cut": "Good", "clarity": "VVS1"},
      {"id": "with-price-wrong-clarity",       "price": 2000.0,"color": "D", "cut": "Good", "clarity": "S1"},
      {"id": "with-price-wrong-cut",           "price": 2000.0,"color": "D", "cut": "Very Good", "clarity": "VVS1"},
    ])

    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)


[32m.[0m[32m                                                                                            [100%][0m


### Step 5E: Reveal Intent - Extract the diamond with a different color to a method

- [ ] Look at the third dict in the list passed to `spark.createDataFrame`. This is the diamond with a different color. In the next step, we will pull it out into a method.

In [60]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log
from chispa import assert_column_equality

from typing import Dict

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'no-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def unpriced_diamond() -> Dict:
  return {"id": "no-price", "price": None,  "color": "D", "cut": "Good", "clarity": "VVS1"}

def matching_diamond(price: float) -> Dict:
  return {"id": "with-price", "price": price ,"color": "D", "cut": "Good", "clarity": "VVS1"}

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = spark.createDataFrame([
      unpriced_diamond(),
      matching_diamond(price=3333.0),
      ##### THE DIAMOND WITH A DIFFERENT COLOR #####
      {"id": "with-price-wrong-color",         "price": 2000.0,"color": "G", "cut": "Good", "clarity": "VVS1"},
      ##### END OF THE DIAMOND WITH A DIFFERENT COLOR #####
      {"id": "with-price-wrong-clarity",       "price": 2000.0,"color": "D", "cut": "Good", "clarity": "S1"},
      {"id": "with-price-wrong-cut",           "price": 2000.0,"color": "D", "cut": "Very Good", "clarity": "VVS1"},
    ])

    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)


[32m.[0m[32m                                                                                            [100%][0m


### Step 5F: Reveal Intent - Extract the diamond with a different color to a method

- [ ] We've created a new method called diamond_with_different_color to return the diamond with a different color.
- [ ] Notice how we've set up the method interface so we can see we care about the price of the diamond.
- [ ] We want to use the price we pass into the method. Inside `diamond_with_different_color`, replace the hardcoded `2000.0` with `price`
- [ ] The price of this diamond should be `2000.0`. Pass it into the method call.
- [ ] Run the test. It should pass.
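Why does the price of the different-color diamond matter? If it were also 3333.0, the buggy average would still come out at 3333.0, and the test could not tell the bug apart from the fix. Here is a quick arithmetic check in plain Python. It assumes the minimal dataset above, where the buggy window sees only the matching diamond and the different-color diamond. `MATCHING_PRICE` and `price_if_color_ignored` are illustrative names:

```python
from statistics import mean

MATCHING_PRICE = 3333.0


def price_if_color_ignored(different_color_price: float) -> float:
    """Average the buggy code would compute when the different-color diamond leaks in."""
    return mean([MATCHING_PRICE, different_color_price])


print(price_if_color_ignored(2000.0))  # 2666.5 -> differs from 3333.0, so the test catches the bug
print(price_if_color_ignored(3333.0))  # 3333.0 -> same as expected, so the bug would go unnoticed
```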

In [61]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log
from chispa import assert_column_equality

from typing import Dict

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'no-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def unpriced_diamond() -> Dict:
  return {"id": "no-price", "price": None,  "color": "D", "cut": "Good", "clarity": "VVS1"}

def matching_diamond(price: float) -> Dict:
  return {"id": "with-price", "price": price ,"color": "D", "cut": "Good", "clarity": "VVS1"}

def diamond_with_different_color(price: float) -> Dict:
  return {"id": "with-price-wrong-color", "price": 2000.0, "color": "G", "cut": "Good", "clarity": "VVS1"}  # <- CHANGE "price": 2000.0 TO "price": price

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = spark.createDataFrame([
      unpriced_diamond(),
      matching_diamond(price=3333.0),
      ##### THE DIAMOND WITH A DIFFERENT COLOR #####
      diamond_with_different_color(), # <- CHANGE THIS LINE TO: diamond_with_different_color(price=2000.0),
      ##### END OF THE DIAMOND WITH A DIFFERENT COLOR #####
      {"id": "with-price-wrong-clarity",       "price": 2000.0,"color": "D", "cut": "Good", "clarity": "S1"},
      {"id": "with-price-wrong-cut",           "price": 2000.0,"color": "D", "cut": "Very Good", "clarity": "VVS1"},
    ])

    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)


[32m.[0m[32m                                                                                            [100%][0m


### Step 5G: Reveal Intent - Extract the diamond with a different clarity to a method

- [ ] Look at the fourth dict in the list passed to `spark.createDataFrame`. This is the diamond with a different clarity. In the next step, we will pull it out into a method.

In [62]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log
from chispa import assert_column_equality

from typing import Dict

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'no-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def unpriced_diamond() -> Dict:
  return {"id": "no-price", "price": None,  "color": "D", "cut": "Good", "clarity": "VVS1"}

def matching_diamond(price: float) -> Dict:
  return {"id": "with-price", "price": price ,"color": "D", "cut": "Good", "clarity": "VVS1"}

def diamond_with_different_color(price: float) -> Dict:
  return {"id": "with-price-wrong-color", "price": price, "color": "G", "cut": "Good", "clarity": "VVS1"}

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = spark.createDataFrame([
      unpriced_diamond(),
      matching_diamond(price=3333.0),
      diamond_with_different_color(price=2000.0),
      ##### THE DIAMOND WITH A DIFFERENT CLARITY #####
      {"id": "with-price-wrong-clarity",       "price": 2000.0,"color": "D", "cut": "Good", "clarity": "S1"},
      ##### END OF THE DIAMOND WITH A DIFFERENT CLARITY #####
      {"id": "with-price-wrong-cut",           "price": 2000.0,"color": "D", "cut": "Very Good", "clarity": "VVS1"},
    ])

    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)


[32m.[0m[32m                                                                                            [100%][0m


### Step 5H: Reveal Intent - Extract the diamond with a different clarity to a method

- [ ] We've created a new method called `diamond_with_different_clarity` to return the diamond with a different clarity.
- [ ] The line that calls this method is commented out.
- [ ] Uncomment it.
- [ ] Cut the dict that represents the diamond with a different clarity and paste it after the `return` statement in `diamond_with_different_clarity`.
- [ ] Change the hardcoded `"price": 2000.0` to `"price": price`.
- [ ] Run the test. It should pass.
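When you copy a dict into a new helper, it is easy to carry over the wrong `id`. `assert_diamond_has_expected_price` filters by `id`, so a small guard on the input rows can catch that slip early. This is a plain-Python sketch, and `assert_unique_ids` is a hypothetical helper name:

```python
from collections import Counter
from typing import Dict, List


def assert_unique_ids(rows: List[Dict]) -> None:
    """Hypothetical guard: fail if two test rows share an id (e.g. a copy-pasted helper kept the old one)."""
    duplicates = [row_id for row_id, count in Counter(r["id"] for r in rows).items() if count > 1]
    assert not duplicates, f"duplicate ids: {duplicates}"


rows = [
    {"id": "no-price", "price": None, "color": "D", "cut": "Good", "clarity": "VVS1"},
    {"id": "with-price-wrong-color", "price": 2000.0, "color": "G", "cut": "Good", "clarity": "VVS1"},
    {"id": "with-price-wrong-color", "price": 2000.0, "color": "D", "cut": "Good", "clarity": "S1"},
]

try:
    assert_unique_ids(rows)
except AssertionError as error:
    print(error)  # duplicate ids: ['with-price-wrong-color']
```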

In [63]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log
from chispa import assert_column_equality

from typing import Dict

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'no-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def unpriced_diamond() -> Dict:
  return {"id": "no-price", "price": None,  "color": "D", "cut": "Good", "clarity": "VVS1"}

def matching_diamond(price: float) -> Dict:
  return {"id": "with-price", "price": price ,"color": "D", "cut": "Good", "clarity": "VVS1"}

def diamond_with_different_color(price: float) -> Dict:
  return {"id": "with-price-wrong-color", "price": price, "color": "G", "cut": "Good", "clarity": "VVS1"}

def diamond_with_different_clarity(price: float) -> Dict:
  return ####### PASTE THE DICT HERE AND CHANGE "price": 2000.0 TO "price": price #####

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = spark.createDataFrame([
      unpriced_diamond(),
      matching_diamond(price=3333.0),
      diamond_with_different_color(price=2000.0),
      ##### THE DIAMOND WITH A DIFFERENT CLARITY #####
      # diamond_with_different_clarity(price=2000.0), # <- UNCOMMENT THIS LINE
      {"id": "with-price-wrong-clarity", "price": 2000.0,"color": "D", "cut": "Good", "clarity": "S1"}, # <- CUT THIS LINE AND PASTE IT AFTER THE RETURN IN diamond_with_different_clarity
      ##### END OF THE DIAMOND WITH A DIFFERENT CLARITY #####
      {"id": "with-price-wrong-cut",           "price": 2000.0,"color": "D", "cut": "Very Good", "clarity": "VVS1"},
    ])

    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)


[32m.[0m[32m                                                                                            [100%][0m


### Step 5I: Reveal Intent - Extract the diamond with a different cut to a method

- [ ] Look at the fifth dict in the list passed to `spark.createDataFrame`. This is the diamond with a different cut. In the next step, we will pull it out into a method.

In [None]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log

from typing import Dict

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'no-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def unpriced_diamond() -> Dict:
  return {"id": "no-price", "price": None, "color": "D", "cut": "Good", "clarity": "VVS1"}

def matching_diamond(price: float) -> Dict:
  return {"id": "with-price", "price": price, "color": "D", "cut": "Good", "clarity": "VVS1"}

def diamond_with_different_color(price: float) -> Dict:
  return {"id": "with-price-wrong-color", "price": price, "color": "G", "cut": "Good", "clarity": "VVS1"}

def diamond_with_different_clarity(price: float) -> Dict:
  return {"id": "with-price-wrong-clarity", "price": price, "color": "D", "cut": "Good", "clarity": "S1"}

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = spark.createDataFrame([
      unpriced_diamond(),
      matching_diamond(price=3333.0),
      diamond_with_different_color(price=2000.0),
      diamond_with_different_clarity(price=2000.0),
      ##### THE DIAMOND WITH A DIFFERENT CUT #####
      {"id": "with-price-wrong-cut", "price": 2000.0, "color": "D", "cut": "Very Good", "clarity": "VVS1"},
      ##### END OF THE DIAMOND WITH A DIFFERENT CUT #####
    ])

    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)


[32m.[0m[32m                                                                                            [100%][0m


### Step 5J: Reveal Intent - Extract the diamond with a different cut to a method

- [ ] We've created a new method, `diamond_with_different_cut`, that returns the diamond with a different cut.
- [ ] Notice that the method takes `price` as its only parameter, so the call site shows that the price of this diamond is what the test cares about.
- [ ] Replace the dict literal between the comments with a call to the new method.
- [ ] Run the test. It should pass.
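One way to convince yourself that this extraction is a pure refactoring is to check that the method returns exactly the literal it replaces. This plain-Python sketch (no Spark needed) copies `diamond_with_different_cut` from the cell below and compares it with the inline dict.

```python
from typing import Dict


def diamond_with_different_cut(price: float) -> Dict:
    # Every column except price is fixed, so the call site only states the price it needs
    return {"id": "with-price-wrong-cut", "price": price, "color": "D", "cut": "Very Good", "clarity": "VVS1"}


literal = {"id": "with-price-wrong-cut", "price": 2000.0, "color": "D", "cut": "Very Good", "clarity": "VVS1"}

# The extracted method and the inline literal describe the same row
print(diamond_with_different_cut(price=2000.0) == literal)  # True
```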

In [None]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log

from typing import Dict

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None:
    actual_df=actual_df.filter(actual_df.id == 'no-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0))
    assert_column_equality(actual_df, 'price', 'expected_price')

def unpriced_diamond() -> Dict:
  return {"id": "no-price", "price": None, "color": "D", "cut": "Good", "clarity": "VVS1"}

def matching_diamond(price: float) -> Dict:
  return {"id": "with-price", "price": price, "color": "D", "cut": "Good", "clarity": "VVS1"}

def diamond_with_different_color(price: float) -> Dict:
  return {"id": "with-price-wrong-color", "price": price, "color": "G", "cut": "Good", "clarity": "VVS1"}

def diamond_with_different_clarity(price: float) -> Dict:
  return {"id": "with-price-wrong-clarity", "price": price, "color": "D", "cut": "Good", "clarity": "S1"}

def diamond_with_different_cut(price: float) -> Dict:
  return {"id": "with-price-wrong-cut", "price": price, "color": "D", "cut": "Very Good", "clarity": "VVS1"}

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = spark.createDataFrame([
      unpriced_diamond(),
      matching_diamond(price=3333.0),
      diamond_with_different_color(price=2000.0),
      diamond_with_different_clarity(price=2000.0),
      ##### THE DIAMOND WITH A DIFFERENT CUT #####
      {"id": "with-price-wrong-cut", "price": 2000.0,"color": "D", "cut": "Very Good", "clarity": "VVS1"},  # <- DELETE THIS
      # CALL diamond_with_different_cut(price=2000.0) HERE
      ##### END OF THE DIAMOND WITH A DIFFERENT CUT #####
    ])

    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)


[32m.[0m[32m                                                                                            [100%][0m


### Step 5K: Reveal Intent - Make it obvious which price we want

- [ ] Add a `price` parameter to your assert helper, `assert_diamond_has_expected_price`, and pass the expected price in from the test.
- [ ] Use `price` in place of the hardcoded `3333.0` when you create the `expected_price` column.
- [ ] Run the test. It should pass.
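Here is the same idea in plain Python, without Spark: the expected price becomes an argument instead of a constant buried inside the helper, so the test itself says which price it expects. `assert_price_for_id` and the sample rows are hypothetical stand-ins for the DataFrame-based helper in the cell below.

```python
from typing import Dict, List


def assert_price_for_id(rows: List[Dict], diamond_id: str, price: float) -> None:
    """Hypothetical plain-Python analogue of assert_diamond_has_expected_price."""
    matches = [row for row in rows if row["id"] == diamond_id]
    # Exactly one row should carry the id we are checking
    assert len(matches) == 1
    # The expected value comes from the caller, not from a hardcoded literal
    assert matches[0]["price"] == price


# Hypothetical output rows after the null price has been filled in
rows = [
    {"id": "no-price", "price": 3333.0},
    {"id": "with-price", "price": 3333.0},
    {"id": "with-price-wrong-cut", "price": 2000.0},
]

assert_price_for_id(rows, "no-price", price=3333.0)
```

Because the expected price now appears at the call site, a reader can match it against `matching_diamond(price=3333.0)` without opening the helper.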

In [52]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log

from typing import Dict

def assert_diamond_has_expected_price(actual_df: DataFrame) -> None: # <- CHANGE THIS LINE TO: def assert_diamond_has_expected_price(actual_df: DataFrame, price: float) -> None:
    actual_df=actual_df.filter(actual_df.id == 'no-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(3333.0)) # <- CHANGE THIS LINE TO: actual_df=actual_df.withColumn('expected_price', lit(price))
    assert_column_equality(actual_df, 'price', 'expected_price')

def unpriced_diamond() -> Dict:
  return {"id": "no-price", "price": None, "color": "D", "cut": "Good", "clarity": "VVS1"}

def matching_diamond(price: float) -> Dict:
  return {"id": "with-price", "price": price, "color": "D", "cut": "Good", "clarity": "VVS1"}

def diamond_with_different_color(price: float) -> Dict:
  return {"id": "with-price-wrong-color", "price": price, "color": "G", "cut": "Good", "clarity": "VVS1"}

def diamond_with_different_clarity(price: float) -> Dict:
  return {"id": "with-price-wrong-clarity", "price": price, "color": "D", "cut": "Good", "clarity": "S1"}

def diamond_with_different_cut(price: float) -> Dict:
  return {"id": "with-price-wrong-cut", "price": price, "color": "D", "cut": "Very Good", "clarity": "VVS1"}

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = spark.createDataFrame([
      unpriced_diamond(),
      matching_diamond(price=3333.0),
      diamond_with_different_color(price=2000.0),
      diamond_with_different_clarity(price=2000.0),
      diamond_with_different_cut(price=2000.0)
    ])


    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df)  # <- CHANGE THIS LINE TO: assert_diamond_has_expected_price(actual_df, price=3333.0)


[32m.[0m[32m                                                                                            [100%][0m


### Step 5L: Remove Duplication - Use the Mismatched Diamond Function

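There is duplication in the `diamond_with_different_color`, `diamond_with_different_clarity` and `diamond_with_different_cut` methods: each repeats the full dict, yet each differs from the matching diamond in only one column.
- [ ] There is a new method, `mismatched_diamond`, that starts from the same columns as the matching diamond and overrides only the columns you pass in `different_columns`. Have each of these methods return a call to it.
- [ ] Run the test. It should pass.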
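`mismatched_diamond` relies on the dict union operator `|`, added in Python 3.9 (PEP 584). It builds a new dict, the right-hand operand wins on duplicate keys, and neither operand is modified, which is why each caller only has to pass the columns that differ. A minimal sketch of that behavior, including an equivalent spelling for runtimes older than 3.9:

```python
defaults = {"id": "with-price", "price": 2000.0, "color": "D", "cut": "Good", "clarity": "VVS1"}
overrides = {"color": "G"}

# Python 3.9+: a new dict; keys from the right-hand side override the left-hand side
merged = defaults | overrides

# Equivalent spelling that also works before Python 3.9
merged_legacy = {**defaults, **overrides}

print(merged)
# {'id': 'with-price', 'price': 2000.0, 'color': 'G', 'cut': 'Good', 'clarity': 'VVS1'}
print(merged == merged_legacy)  # True
print(defaults["color"])  # D (the defaults are left untouched)
```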

In [67]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log

from typing import Dict

def assert_diamond_has_expected_price(actual_df: DataFrame, price: float) -> None:
    actual_df=actual_df.filter(actual_df.id == 'no-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(price))
    assert_column_equality(actual_df, 'price', 'expected_price')

def unpriced_diamond() -> Dict:
  return {"id": "no-price", "price": None, "color": "D", "cut": "Good", "clarity": "VVS1"}

def matching_diamond(price: float) -> Dict:
  return {"id": "with-price", "price": price, "color": "D", "cut": "Good", "clarity": "VVS1"}

def diamond_with_different_color(price: float) -> Dict:
  # CHANGE THE LINE BELOW TO: return mismatched_diamond(price=price, different_columns={"color": "G"})
  return {"id": "with-price-wrong-color", "price": price, "color": "G", "cut": "Good", "clarity": "VVS1"}

def diamond_with_different_clarity(price: float) -> Dict:
  # CHANGE THE LINE BELOW TO: return mismatched_diamond(price=price, different_columns={"clarity": "S1"})
  return {"id": "with-price-wrong-clarity", "price": price, "color": "D", "cut": "Good", "clarity": "S1"}

def diamond_with_different_cut(price: float) -> Dict:
  # CHANGE THE LINE BELOW TO: return mismatched_diamond(price=price, different_columns={"cut": "Very Good"})
  return {"id": "with-price-wrong-cut", "price": price, "color": "D", "cut": "Very Good", "clarity": "VVS1"}

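def mismatched_diamond(price: float, different_columns: Dict) -> Dict:
  # Start from the same color, cut and clarity as the unpriced diamond
  diamond = {"id": "with-price", "price": price, "color": "D", "cut": "Good", "clarity": "VVS1"}
  # Dict union (Python 3.9+): values in different_columns override the defaults above
  diamond = diamond | different_columns
  return diamond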

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = spark.createDataFrame([
      unpriced_diamond(),
      matching_diamond(price=3333.0),
      diamond_with_different_color(price=2000.0),
      diamond_with_different_clarity(price=2000.0),
      diamond_with_different_cut(price=2000.0)
    ])


    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df, price=3333.0)


[32m.[0m[32m                                                                                            [100%][0m


### Step 5M: Remove Duplication - Inline the duplicate functions

The `diamond_with_different_color`, `diamond_with_different_clarity` and `diamond_with_different_cut` methods now do nothing except forward to `mismatched_diamond`, whose `different_columns` argument already states which column differs.
- [ ] Replace each call to these functions in the test with the `mismatched_diamond` call it wraps.
- [ ] Delete the duplicated functions.
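Inlining is safe here because each wrapper only forwards fixed arguments to `mismatched_diamond`. The plain-Python sketch below checks this for the color wrapper (copied from the cell below); the clarity and cut wrappers follow the same pattern.

```python
from typing import Dict


def mismatched_diamond(price: float, different_columns: Dict) -> Dict:
    # Start from the matching diamond's columns, then override the ones passed in (Python 3.9+)
    diamond = {"id": "with-price", "price": price, "color": "D", "cut": "Good", "clarity": "VVS1"}
    return diamond | different_columns


def diamond_with_different_color(price: float) -> Dict:
    return mismatched_diamond(price=price, different_columns={"color": "G"})


# Inlining replaces the wrapper call with the call it makes, so both rows are identical
wrapped = diamond_with_different_color(price=2000.0)
inlined = mismatched_diamond(price=2000.0, different_columns={"color": "G"})
print(wrapped == inlined)  # True
```

After inlining, each row in the test shows both its price and its one differing column, so the wrapper names are no longer needed to reveal intent.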

In [68]:
%%ipytest -qq
from pyspark.sql import DataFrame
from pyspark.sql.functions import lit, log

from typing import Dict

def assert_diamond_has_expected_price(actual_df: DataFrame, price: float) -> None:
    actual_df=actual_df.filter(actual_df.id == 'no-price')
    assert actual_df.count() == 1
    actual_df=actual_df.withColumn('expected_price', lit(price))
    assert_column_equality(actual_df, 'price', 'expected_price')

def unpriced_diamond() -> Dict:
  return {"id": "no-price", "price": None, "color": "D", "cut": "Good", "clarity": "VVS1"}

def matching_diamond(price: float) -> Dict:
  return {"id": "with-price", "price": price, "color": "D", "cut": "Good", "clarity": "VVS1"}

##### DELETE THESE FUNCTIONS #####
def diamond_with_different_color(price: float) -> Dict:
  return mismatched_diamond(price=price, different_columns={"color": "G"})

def diamond_with_different_clarity(price: float) -> Dict:
  return mismatched_diamond(price=price, different_columns={"clarity": "S1"})

def diamond_with_different_cut(price: float) -> Dict:
  return mismatched_diamond(price=price, different_columns={"cut": "Very Good"})
##################################


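def mismatched_diamond(price: float, different_columns: Dict) -> Dict:
  # Start from the same color, cut and clarity as the unpriced diamond
  diamond = {"id": "with-price", "price": price, "color": "D", "cut": "Good", "clarity": "VVS1"}
  # Dict union (Python 3.9+): values in different_columns override the defaults above
  diamond = diamond | different_columns
  return diamond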

def test_null_price_is_replaced_based_on_cut_clarity_and_color(spark: SparkSession):
    diamonds_df = spark.createDataFrame([
      unpriced_diamond(),
      matching_diamond(price=3333.0),
      # CHANGE THE LINE BELOW TO: mismatched_diamond(price=2000.0, different_columns={"color": "G"})
      diamond_with_different_color(price=2000.0),
      # CHANGE THE LINE BELOW TO: mismatched_diamond(price=2000.0, different_columns={"clarity": "S1"})
      diamond_with_different_clarity(price=2000.0),
      # CHANGE THE LINE BELOW TO: mismatched_diamond(price=2000.0, different_columns={"cut": "Very Good"})
      diamond_with_different_cut(price=2000.0)
    ])


    actual_df = calculate_avg_price_for_similar_diamonds(diamonds_df)

    assert_diamond_has_expected_price(actual_df, price=3333.0)


[32m.[0m[32m                                                                                            [100%][0m
