# Alice White python function 

```Python
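from pyspark.sql import functions as f
from pyspark.sql.window import Window


def stratified_random_sample(sampling_frame, strata, unit, i):
    # Seed Spark's generator directly; random.seed() has no effect on f.rand()
    sampling_frame = sampling_frame.withColumn('random_number', f.rand(seed=i))

    # Rank rows within each stratum by their random number
    window_spec = Window.partitionBy(strata).orderBy(f.col('random_number'))
    sampling_frame = sampling_frame.withColumn('random_number_rank',
                                               f.rank().over(window_spec))

    # Flag rows ranked within the required sample size as 1, all others as 9
    sampling_frame = (sampling_frame.withColumn(f'{unit}_sampled{i}',
                                               f.when((f.col('random_number_rank') <= f.col('required_ss')), 1)
                                               .otherwise(9)))
    sampling_frame = sampling_frame.drop('random_number', 'random_number_rank')

    return sampling_frame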
```


In [1]:
import os
import yaml
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master("local[2]").appName("sampling").getOrCreate()

with open("../../../config.yaml") as f:
    config = yaml.safe_load(f)
    
#rescue_path = 
rescue_path = "../../data/animal_rescue.csv"
rescue = spark.read.csv(rescue_path, header=True, inferSchema=True)
rescue = rescue.withColumnRenamed('AnimalGroupParent','animal_type')

In [2]:
rescue.show(5)

+--------------+----------------+-------+-------+---------------+---------+--------------+---------------------+-----------------------+--------------------+-----------+------------------+--------------------+-----------------+--------------------------+--------------------+---------+--------------------+-----------+----------+-------------+----------------+---------+----------+---------------+----------------+
|IncidentNumber|  DateTimeOfCall|CalYear|FinYear| TypeOfIncident|PumpCount|PumpHoursTotal|HourlyNotionalCost(£)|IncidentNotionalCost(£)|    FinalDescription|animal_type|      OriginofCall|        PropertyType| PropertyCategory|SpecialServiceTypeCategory|  SpecialServiceType| WardCode|                Ward|BoroughCode|   Borough|StnGroundName|PostcodeDistrict|Easting_m|Northing_m|Easting_rounded|Northing_rounded|
+--------------+----------------+-------+-------+---------------+---------+--------------+---------------------+-----------------------+--------------------+-----------+-

In [3]:
fraction = 0.1
row_count = round(rescue.count() * fraction)
row_count

590

In [4]:
rescue.withColumn("rand_no", F.rand()).orderBy("rand_no").limit(row_count).drop("rand_no").count()

590

# Find a stratum:
What is a reasonable stratum variable to sample by?
Alex used animal type, which could work here.
How could we generalise this to work for any number of strata, for example by passing a dictionary?

Check the GitHub code for stratified sampling in PySpark; one of its input variables is a dictionary.

In [23]:
#rescue.groupBy('animal_type').agg(F.count('animal_type').alias('count')).sort('count',ascending = False).show()

# strata_dict = rescue.select('animal_type').distinct().collect()
restructured = rescue.withColumn('animal_type',F.when(F.col('animal_type') == 'Cat','Cat')
                                                .when(F.col('animal_type') == 'Bird','Bird')
                                                .when(F.col('animal_type') == 'Dog','Dog')
                                                .when(F.col('animal_type') == 'Fox','Fox')
                                                .otherwise('Other'))
restructured.groupBy('animal_type').agg(F.count('animal_type').alias('count')).sort('count',ascending = False).show()

restructured = restructured.withColumn('sample_size',F.when(F.col('animal_type') == 'Cat',20)
                                                .when(F.col('animal_type') == 'Bird',15)
                                                .when(F.col('animal_type') == 'Dog',10)
                                                .when(F.col('animal_type') == 'Fox',2)
                                                .otherwise(5))
# restructured.show()


+-----------+-----+
|animal_type|count|
+-----------+-----+
|        Cat| 2909|
|       Bird| 1100|
|        Dog| 1008|
|      Other|  643|
|        Fox|  238|
+-----------+-----+



In [24]:
from pyspark.sql.window import Window
restructured = restructured.withColumn('random_number',F.rand())
window_spec = Window.partitionBy('animal_type').orderBy('random_number')
restructured = restructured.withColumn("strata_rank",F.rank().over(window_spec))

restructured = (restructured.withColumn('sampled',F.when((F.col('strata_rank') <= F.col('sample_size')), 1)
                                    .otherwise(0)))

sample = restructured.filter(F.col('sampled') == 1)

# restructured.filter(F.col('strata_rank') <= F.col('required_ss'))
# sampling_frame = (restructured.withColumn('sampled',
#                                                f.when((f.col('random_number_rank') <= f.col('required_ss')), 1)
#                                                .otherwise(9)))

In [26]:
sample.groupBy('animal_type','sample_size').agg(F.count('animal_type').alias('count')).sort('count',ascending = False).show()


+-----------+-----------+-----+
|animal_type|sample_size|count|
+-----------+-----------+-----+
|        Cat|         20|   20|
|       Bird|         15|   15|
|        Dog|         10|   10|
|      Other|          5|    5|
|        Fox|          2|    2|
+-----------+-----------+-----+



# Applying map()

In [31]:
#rescue.groupBy('animal_type').agg(F.count('animal_type').alias('count')).sort('count',ascending = False).show()

# strata_dict = rescue.select('animal_type').distinct().collect()
strata_dict = {'Cat':100, 'Dog':20,'Bird':15, 'Fox': 49,'Other':0}
print(strata_dict)
restructured = rescue.withColumn('animal_type',F.when(F.col('animal_type') == 'Cat','Cat')
                                                .when(F.col('animal_type') == 'Bird','Bird')
                                                .when(F.col('animal_type') == 'Dog','Dog')
                                                .when(F.col('animal_type') == 'Fox','Fox')
                                                .otherwise('Other'))

restructured.groupBy('animal_type').agg(F.count('animal_type').alias('count')).sort('count',ascending = False).show()

# restructured = restructured.withColumn('sample_size',F.when(F.col('animal_type') == 'Cat',20)
#                                                 .when(F.col('animal_type') == 'Bird',15)
#                                                 .when(F.col('animal_type') == 'Dog',10)
#                                                 .when(F.col('animal_type') == 'Fox',2)
#                                                 .otherwise(5))
# restructured.show()

# chain is a function from itertools, a standard library module
from itertools import chain
mapping_expr = F.create_map([F.lit(x) for x in chain(*strata_dict.items())]) 
restructured = restructured.withColumn("sample_size", 
              mapping_expr[F.col("animal_type")])


{'Cat': 100, 'Dog': 20, 'Bird': 15, 'Fox': 49, 'Other': 0}
+-----------+-----+
|animal_type|count|
+-----------+-----+
|        Cat| 2909|
|       Bird| 1100|
|        Dog| 1008|
|      Other|  643|
|        Fox|  238|
+-----------+-----+



In [32]:
from pyspark.sql.window import Window
restructured = restructured.withColumn('random_number',F.rand())
window_spec = Window.partitionBy('animal_type').orderBy('random_number')
restructured = restructured.withColumn("strata_rank",F.rank().over(window_spec))

restructured = (restructured.withColumn('sampled',F.when((F.col('strata_rank') <= F.col('sample_size')), 1)
                                    .otherwise(0)))

sample = restructured.filter(F.col('sampled') == 1)

In [33]:
sample.groupBy('animal_type','sample_size').agg(F.count('animal_type').alias('count')).sort('count',ascending = False).show()


+-----------+-----------+-----+
|animal_type|sample_size|count|
+-----------+-----------+-----+
|        Cat|        100|  100|
|        Fox|         49|   49|
|        Dog|         20|   20|
|       Bird|         15|   15|
+-----------+-----------+-----+



## Returning an exact sample per stratum
If you need a form of stratified sampling where an exact number of rows is drawn from each stratum, this can be implemented using a window function. More details and examples of window functions can be found on the [window function page](https://best-practice-and-impact.github.io/ons-spark/spark-functions/window-functions.html). We found that the simplest way of returning an exact number of samples per stratum is to create a column holding the required sample size for each row's stratum. To simplify this example we have reduced the `animal_type` column to five categories: Bird, Cat, Dog, Fox and Other.

In [18]:
# Preprocessing to simplify the number of animals in the `animal_type` column
simplified_animal_types = rescue.withColumn('animal_type',F.when(~F.col('animal_type').isin(['Cat','Bird','Dog','Fox']),'Other')
                                              .otherwise(F.col('animal_type')))

# Counting the number of animals in each group
simplified_animal_types.groupBy('animal_type').agg(F.count('animal_type').alias('count')).sort('count',ascending = False).show()


+-----------+-----+
|animal_type|count|
+-----------+-----+
|        Cat| 2909|
|       Bird| 1100|
|        Dog| 1008|
|      Other|  643|
|        Fox|  238|
+-----------+-----+



```r
# Preprocessing to simplify the number of animals in the `animal_type` column
simplified_animal_types <-rescue %>% sparklyr::mutate(animal_type = case_when(!animal_group %in% c('Cat','Bird','Dog','Fox') ~ 'Other',
                                                       .default = animal_group))

# Counting the number of animals in each group
simplified_animal_types %>%
        dplyr::group_by(animal_type) %>%
        dplyr::count(animal_type,name = 'row_count') %>%
        sdf_sort('animal_type')
```

Next we create the column holding the required number of samples per stratum. This can be done simply using `F.when` statements; however, this may not be viable when the stratification variable has many distinct values, as each one would require its own `when` statement. A workaround is to use UDFs or a mapping function, although the efficiency of these approaches has not been assessed. We also give each row a randomly generated number from a uniform distribution, which is used to order the rows within each stratum.

In [19]:
# Create a sample_size column for each strata 
simplified_animal_types = (simplified_animal_types.withColumn('sample_size',F.when(F.col('animal_type') == 'Bird',15)
                                                .when(F.col('animal_type') == 'Cat',20)
                                                .when(F.col('animal_type') == 'Dog',10)
                                                .when(F.col('animal_type') == 'Fox',2)
                                                .when(F.col('animal_type') == 'Other',5)
                                                .otherwise(0))
                                                .withColumn('random_number',F.rand())
)

```r
simplified_animal_types <- simplified_animal_types %>% 
                            sparklyr::mutate(sample_size = case_when(
                                            animal_type == 'Bird' ~ 15,
                                            animal_type == 'Cat' ~ 20,
                                            animal_type == 'Dog' ~ 10,
                                            animal_type == 'Fox' ~ 2,
                                            animal_type == 'Other' ~ 5,
                                            .default = 0)) %>%
                            sparklyr::mutate(random_number = rand())
```

Using a window function, we create a column `strata_rank` which ranks the values in the `random_number` column within each stratum. From this, we create a final column `sampled`, which contains `1` if a row's rank is less than or equal to the required sample size for its stratum and `0` otherwise. To obtain our final sample, we simply filter on the `sampled` column and keep all rows where it equals `1`. We can check the number of rows sampled per stratum by using an aggregating function (`agg`) with a `groupBy`.

In [22]:
from pyspark.sql.window import Window
# use Window function to rank based on size of random number
simplified_animal_types = simplified_animal_types.withColumn("strata_rank",F.rank().over(Window.partitionBy('animal_type').orderBy('random_number')))
# Set sampled column to 1 if rank is less than or equal to sample_size, 0 otherwise
simplified_animal_types = (simplified_animal_types.withColumn('sampled',F.when((F.col('strata_rank') <= F.col('sample_size')), 1)
                                    .otherwise(0)))
 
# Take our sample by filtering where `sampled` is 1
sample = simplified_animal_types.filter(F.col('sampled') == 1)
# Count rows in sample
sample.groupBy('animal_type','sample_size').agg(F.count('animal_type').alias('count')).sort('count',ascending = False).show()


+-----------+-----------+-----+
|animal_type|sample_size|count|
+-----------+-----------+-----+
|        Cat|         20|   20|
|       Bird|         15|   15|
|        Dog|         10|   10|
|      Other|          5|    5|
|        Fox|          2|    2|
+-----------+-----------+-----+



```r

simplified_animal_types <- simplified_animal_types %>% 
              dplyr::group_by(animal_type) %>%
              sparklyr::mutate(strata_rank = rank(desc(random_number))) %>%
              dplyr::ungroup()

simplified_animal_types <- simplified_animal_types %>% sparklyr::mutate(sampled = case_when(
              strata_rank <= sample_size ~ 1,
              .default = 0))

sampled <- simplified_animal_types %>% sparklyr::filter(sampled == 1)

sampled %>%
        dplyr::group_by(animal_type,sample_size) %>%
        dplyr::count(animal_type,name = 'row_count') %>%
        sdf_sort('animal_type')

```

Note that using a window function partitioned by stratum can cause problems in datasets with heavily skewed strata. If one stratum is significantly larger than the others, all of its rows are processed together, which could cause memory overflow on the executors and crash the Spark session.

## Simplifying the required sample column
In PySpark, we can simplify the creation of the required sample size column by using a dictionary. This requires the `chain` function from the standard library `itertools` module to flatten the dictionary into the alternating keys and values that `F.create_map` expects. A similar workaround may also be possible in R.
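To see what `chain` contributes, it helps to look at the flattening step on its own. `dict.items()` yields `(key, value)` pairs, whereas `F.create_map` takes a flat sequence of key, value, key, value and so on. The short sketch below uses only the standard library, so no Spark session is needed to run it.

```python
from itertools import chain

strata_dictionary = {'Bird': 15, 'Cat': 20, 'Dog': 10, 'Fox': 2, 'Other': 5}

# dict.items() yields (key, value) pairs
pairs = list(strata_dictionary.items())

# chain(*pairs) flattens the pairs into one alternating key, value sequence
flattened = list(chain(*strata_dictionary.items()))
print(flattened)
# ['Bird', 15, 'Cat', 20, 'Dog', 10, 'Fox', 2, 'Other', 5]
```

Each element of the flattened list is then wrapped in `F.lit` so that it can be passed to `F.create_map` as a literal column.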

In [11]:
from itertools import chain
# Create dictionary
strata_dictionary = {'Bird':15, 'Cat':20, 'Dog':10, 'Fox': 2, 'Other':5}


mapping_example = rescue.withColumn('animal_type',F.when(~F.col('animal_type').isin(['Cat','Bird','Dog','Fox']),'Other')
                                              .otherwise(F.col('animal_type')))

mapping_expr = F.create_map([F.lit(x) for x in chain(*strata_dictionary.items())]) 
mapping_example = mapping_example.withColumn("sample_size", 
              mapping_expr[F.col("animal_type")])

mapping_example.groupBy('animal_type','sample_size').agg(F.count('animal_type').alias('count')).sort('count',ascending = False).show()

+-----------+-----------+-----+
|animal_type|sample_size|count|
+-----------+-----------+-----+
|        Cat|         20| 2909|
|       Bird|         15| 1100|
|        Dog|         10| 1008|
|      Other|          5|  643|
|        Fox|          2|  238|
+-----------+-----------+-----+

