# Lambda Functions Primer
## SIADS 516 Week 2 Office Hours
### Created by Josh Haskins
<br>

## Contents:
* What is a lambda function?
* Basic examples
* Intermediate examples
* Overly complex example (don't do this please)
* Lambda functions in PySpark

### What is a lambda function?
<br>
Lambda functions are small, anonymous (unnamed), one-line functions that can take the place of a named function. They are generally used for simple operations, but they can handle more complex ones if you want. <br>
<br>
Even though a lambda function is unnamed, you can still store it in a variable for later use.
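To make the "anonymous" part concrete, here is a small sketch with made-up names (`words`, `by_length`, `shout`). The first lambda is passed straight into the built-in `sorted` as its `key` and never named. The second is stored in a variable, yet its `__name__` is still `<lambda>`: the variable name does not become the function's name.

```python
# Anonymous use: the lambda is handed straight to sorted() and never named
words = ["banana", "fig", "apple"]
by_length = sorted(words, key=lambda w: len(w))
print(by_length)  # ['fig', 'apple', 'banana']

# Stored use: assign the lambda to a variable and call it like any function
shout = lambda s: s.upper() + "!"
print(shout("hi"))     # HI!
print(shout.__name__)  # <lambda>
```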
<br>
<br>
Syntax of a lambda function:
<br>
<br>
   <b><h4 style="text-align:center"> <span style="color:green">lambda </span> <span style="color:blue"> arguments </span>: <span style="color:red"> expression </span> </h4></b>
<br>

***<span style="color:green">lambda </span>***: Keyword that tells the Python interpreter we are creating a lambda function
<br>
***<span style="color:blue"> arguments </span>***: A lambda function can take any number of comma-separated arguments, including none
<br>
***<span style="color:red"> expression </span>*** : The body must be a single expression (statements such as `return` or `x = 1` are not allowed), and its value is returned automatically
<br>
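A short sketch of the syntax rules above, using made-up names: a lambda with no arguments, one with a default value, one that collects any number of positional arguments with `*args`, and, commented out, a body containing a statement, which Python rejects.

```python
# No arguments: the body is still a single expression whose value is returned
get_answer = lambda: 42

# Default values and *args work the same way they do in a def statement
greet = lambda name, greeting="Hello": f"{greeting}, {name}!"
total = lambda *nums: sum(nums)

print(get_answer())        # 42
print(greet("Ada"))        # Hello, Ada!
print(greet("Ada", "Hi"))  # Hi, Ada!
print(total(1, 2, 3, 4))   # 10

# A statement such as return is not allowed; uncommenting this is a SyntaxError:
# bad = lambda x: return x
```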


### Basic examples
<a id='basic_examples'></a>

In [1]:
# The classic square function in lambda form
square_lambda = lambda x: x * x    # argument - input value, expression - input value multiplied by itself
print(square_lambda(5))

25


In [2]:
# The regular form of the same function
def square(x):
    return x * x
    
print(square(5))

25


In [3]:
# We can call regular functions inside lambda functions
squared_square = lambda x: square(x) * square(x)
print(squared_square(4))

256


In [4]:
# We can call lambda functions inside lambda functions
lambda_squared_square = lambda x: square_lambda(x) * square_lambda(x)
print(lambda_squared_square(4))

256


In [5]:
# We can mix and match
mixed_squared_square = lambda x: square_lambda(x) * square(x)
print(mixed_squared_square(4))

256


In [6]:
# We can have more than one input
add_lambda = lambda x, y: x + y
print(add_lambda(3, 5))

8


### Intermediate Examples

In [1]:
# Given a, b, and c as inputs, calculate both roots (one with + and one with - in front of the square root). Note: if you see j in your result, that is Python's symbol for the imaginary unit ^_^
quadratic_solver = lambda a,b,c: ((-b + (b**2 - 4*a*c)**(1/2)) / (2*a), (-b - (b**2 - 4*a*c)**(1/2)) / (2*a))
print(quadratic_solver(3,2,1))

((-0.3333333333333333+0.47140452079103173j), (-0.3333333333333333-0.47140452079103173j))


In [8]:
# Normal version of the quadratic solver. Note: we can get the square root of something by raising it to the 1/2 power
def quadratic_normal(a,b,c):
    positive = (-b + (b**2 - 4*a*c)**(1/2)) / (2*a)
    negative = (-b - (b**2 - 4*a*c)**(1/2)) / (2*a)
    return positive, negative
print(quadratic_normal(1,1,-6))

(2.0, -3.0)
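Raising to the 1/2 power is not quite the same as the standard library's square root functions when the input is negative. A minimal comparison, assuming Python 3: `** 0.5` quietly returns a complex number (which is where the `j` in the lambda's output above comes from), `math.sqrt` raises a `ValueError`, and `cmath.sqrt` returns a complex result on purpose. The variable names are illustrative only.

```python
import cmath
import math

discriminant = 2**2 - 4*3*1  # b**2 - 4*a*c for a=3, b=2, c=1, which is -8

# ** 0.5 on a negative number returns a complex number in Python 3;
# the real part is tiny but not exactly zero because of floating-point rounding
power_root = discriminant ** 0.5
print(power_root)

# cmath.sqrt is designed for complex results (real part 0.0, imaginary part about 2.828)
complex_root = cmath.sqrt(discriminant)
print(complex_root)

# math.sqrt only accepts non-negative real numbers
try:
    math.sqrt(discriminant)
    sqrt_failed = False
except ValueError as err:
    sqrt_failed = True
    print("math.sqrt raised ValueError:", err)
```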


In [9]:
# Flatten a nested list one level
flatten_lambda = lambda full_list: [item for sub_list in full_list for item in sub_list] 
nested_list = [[1, 2], [3, 4], [5]]
flatten_lambda(nested_list)

[1, 2, 3, 4, 5]

In [10]:
# Normal version of flatten
def flatten_normal(full_list):
    new_list = []
    for sub_list in full_list:
        for item in sub_list:
            new_list.append(item)
    return new_list
flatten_normal(nested_list)

[1, 2, 3, 4, 5]
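For the same one-level flatten, the standard library's `itertools.chain.from_iterable` is another option, shown here as a sketch wrapped in a lambda with the hypothetical name `flatten_chain`. Like both versions above, it removes exactly one level of nesting, so deeper lists stay nested.

```python
from itertools import chain

# Lambda wrapper around the standard-library one-level flatten
flatten_chain = lambda full_list: list(chain.from_iterable(full_list))

nested_list = [[1, 2], [3, 4], [5]]
print(flatten_chain(nested_list))  # [1, 2, 3, 4, 5]

# Only one level is removed, so the inner [4] stays a list
print(flatten_chain([[1, 2], [3, [4]]]))  # [1, 2, 3, [4]]
```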

Let's write a lambda function to evaluate the following piecewise mathematical function: <br>

$$f(x) =
\begin{cases}
x^2 & \text{if } x < 0, \\
\sqrt{x} & \text{if } 0 \leq x < 10, \\
x + 10 & \text{if } x \geq 10.
\end{cases} $$



In [11]:
# A lambda can branch with a conditional expression (a if condition else b), much like an if/else in a regular function. There is no elif, so extra branches have to be nested
piecewise_lambda = lambda x: x**2 if x < 0 else (x**0.5 if x < 10 else x + 10)
print(piecewise_lambda(-4))
print(piecewise_lambda(4)) 
print(piecewise_lambda(15))

16
2.0
25
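Conditional expressions group from the right, so `a if c1 else b if c2 else d` is read as `a if c1 else (b if c2 else d)`. The parentheses in `piecewise_lambda` are therefore optional; this sketch (with the made-up name `piecewise_flat`) checks that both spellings agree. Keeping the parentheses is arguably the clearer choice.

```python
piecewise_lambda = lambda x: x**2 if x < 0 else (x**0.5 if x < 10 else x + 10)

# Same logic without the inner parentheses; Python nests it the same way
piecewise_flat = lambda x: x**2 if x < 0 else x**0.5 if x < 10 else x + 10

for value in [-4, 0, 4, 9.99, 10, 15]:
    assert piecewise_flat(value) == piecewise_lambda(value)

print(piecewise_flat(-4), piecewise_flat(4), piecewise_flat(15))  # 16 2.0 25
```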


Note: This is probably right on the edge of being too complex for what you should be doing with a lambda function in the first place. The general rule of thumb I have heard is that if you have to do more than 2 things in a lambda function, you should probably write a normal function just for readability purposes. Here is the same thing in a more readable format:

In [12]:
def piecewise_normal(x):
    if x < 0:
        return x**2
    elif x < 10:
        return x ** 0.5
    else:
        return x + 10
print(piecewise_normal(-4))
print(piecewise_normal(4)) 
print(piecewise_normal(15))

16
2.0
25


### Overly Complex Example
The following example shows that lambda functions can do some extremely complex things, but it also hopefully demonstrates why you shouldn't: the result is effectively unreadable.
<br>
<br>
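Example - Compute the following equation (both versions below approximate sin and cos with truncated Taylor series, which are only accurate for x near 0):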

$$f(x) =
\frac{\text{factorial}(x) + \sin^2(x) - \cos(x)}{
\begin{cases} 
x^3 & \text{if } x \neq 0, \\
1 & \text{if } x = 0
\end{cases}
}
+ \sum_{i=1}^{x} i^2$$

In [13]:
# Insane lambda version (note: the factorial is a recursive lambda that receives itself as an argument!)
complex_math_lambda = lambda x: (
    (lambda f, n: f(f, n))(lambda f, n: 1 if n == 0 else n * f(f, n - 1), x)  # factorial(x)
    + (x - (x**3)/6 + (x**5)/120)**2  # sin(x)**2 approximation
    - (1 - (x**2)/2 + (x**4)/24)  # cos(x) approximation
) / (x**3 if x != 0 else 1) + sum(i**2 for i in range(1, x+1))

complex_math_lambda(5)

56.677347222222224

In [14]:
# Normal function for the same equation. Much longer but also understandable to anyone else who looks at it.
def complex_math_normal(x):
    # Helper function to compute factorial recursively
    def factorial(n):
        if n == 0:
            return 1
        return n * factorial(n - 1)

    # Helper function to approximate sin(x) using Taylor series expansion
    def sine(x):
        # sin(x) ≈ x - (x^3)/6 + (x^5)/120
        return x - (x**3) / 6 + (x**5) / 120

    # Helper function to approximate cos(x) using Taylor series expansion
    def cosine(x):
        # cos(x) ≈ 1 - (x^2)/2 + (x^4)/24
        return 1 - (x**2) / 2 + (x**4) / 24

    # Compute the factorial of x
    numerator = factorial(x)

    # Add the square of the sine of x to the numerator
    numerator += sine(x)**2

    # Subtract the cosine of x from the numerator
    numerator -= cosine(x)

    # Compute the denominator: x^3, but avoid division by zero
    denominator = x**3 if x != 0 else 1

    # Compute the sum of squares from 1 to x
    sum_of_squares = sum(i**2 for i in range(1, x + 1))

    # Final result: (numerator / denominator) + sum_of_squares
    return numerator / denominator + sum_of_squares

complex_math_normal(5)  # the same equation, written out step by step

56.677347222222224
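Both versions above approximate sin and cos with truncated Taylor series, which drift far from the true values at x = 5. For comparison, here is a sketch of a hypothetical `complex_math_exact` that uses the exact `math.factorial`, `math.sin`, and `math.cos` from the standard library instead.

```python
import math

# Hypothetical variant of complex_math_normal that uses the exact math functions
def complex_math_exact(x):
    numerator = math.factorial(x) + math.sin(x)**2 - math.cos(x)
    denominator = x**3 if x != 0 else 1  # same divide-by-zero guard as above
    sum_of_squares = sum(i**2 for i in range(1, x + 1))
    return numerator / denominator + sum_of_squares

print(complex_math_exact(5))  # about 55.965, versus about 56.677 with the Taylor approximations
```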

### Lambda Functions in PySpark

Lambda functions in PySpark work largely the same way as in regular Python, but since we are usually passing them into Spark methods, we need to understand exactly what Spark will pass into each lambda. Let's use the `reduceByKey` and `map` methods as examples to illustrate this.
<br>
<br>
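`reduceByKey` operates on pair RDDs, whose elements are by definition (key, value) 2-tuples (think MRJob inputs and outputs). Spark groups the values by key and repeatedly calls your function on two values that share a key; the key itself is never passed in. That is why `reduceByKey` takes the two-argument lambda you see in the lecture slides: `.reduceByKey(lambda a, v: a + v)`, where `a` is the accumulator (the result so far) and `v` is the next value being added to it. Because Spark also merges partial results from different partitions, the function should be associative and commutative.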
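Spark's real implementation is distributed, but its calling convention can be sketched in plain Python. The helper `reduce_by_key` below is hypothetical (it is not part of PySpark): it groups the values by key and then folds each group with `functools.reduce`, so the lambda only ever receives two values at a time and never the key.

```python
from functools import reduce

# Plain-Python stand-in for a pair RDD: a list of (key, value) 2-tuples
pairs = [("a", 1), ("b", 4), ("a", 2), ("b", 5), ("a", 3)]

def reduce_by_key(func, pairs):
    """Hypothetical helper (not PySpark) that mimics reduceByKey's calling convention."""
    grouped = {}
    for key, value in pairs:
        grouped.setdefault(key, []).append(value)  # collect the values for each key
    # func only ever receives two values at a time; the key is never passed to it
    return {key: reduce(func, values) for key, values in grouped.items()}

print(reduce_by_key(lambda a, v: a + v, pairs))  # {'a': 6, 'b': 9}
print(reduce_by_key(max, pairs))                 # {'a': 3, 'b': 5}
```

In PySpark itself, the rough equivalent would be `sc.parallelize(pairs).reduceByKey(lambda a, v: a + v).collect()`, assuming a `SparkContext` named `sc`; the order of the returned pairs is not guaranteed.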
<br>
<br>
`map`, on the other hand, can be used on any RDD, so how can Spark know in advance what each element will contain? Short version: it can't. `map` always calls your function with exactly one argument, the whole element, so if you write `.map(lambda x, y: y)` instead of `.map(lambda x: x[1])` you will get an error, even though it sure feels like it should work. What about `.map(lambda (x, y): y)`? That tuple-unpacking syntax worked in Python 2, but it was removed in Python 3 (PEP 3113), so it is now a `SyntaxError`.
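Python's built-in `map` follows the same one-argument convention as Spark's `map`, so this behavior can be checked locally without a cluster. This plain-Python sketch (not PySpark) shows the one-argument lambda working, the two-argument lambda failing, and the Python 2 style `lambda (x, y): y` being rejected by the Python 3 parser. The flag names are illustrative only.

```python
pairs = [("a", 1), ("b", 2)]

# map passes each element (here, a whole tuple) as a single argument
values = list(map(lambda x: x[1], pairs))
print(values)  # [1, 2]

# A two-argument lambda gets only one argument, so Python raises a TypeError
try:
    list(map(lambda x, y: y, pairs))
    two_arg_failed = False
except TypeError as err:
    two_arg_failed = True
    print("TypeError:", err)

# Python 2's tuple-parameter syntax no longer even parses in Python 3
try:
    compile("lambda (x, y): y", "<example>", "eval")
    tuple_syntax_failed = False
except SyntaxError:
    tuple_syntax_failed = True
    print("SyntaxError: tuple parameters were removed in Python 3")
```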

Here is a quick table of some of the common PySpark methods for your reference:

| **Method**           | **Input Function** | **What It Operates On**                                   | **Example**                                   |
|----------------------|--------------------|-----------------------------------------------------------|-----------------------------------------------|
| `.map`               | `lambda x`         | Each element in the RDD                                   | `lambda x: x[1]` (element at index 1)         |
| `.flatMap`           | `lambda x`         | Each element in the RDD, returning zero or more elements  | `lambda x: x.split()`                         |
| `.filter`            | `lambda x`         | Each element in the RDD (kept if the lambda returns `True`) | `lambda x: x > 5`                           |
| `.reduceByKey`       | `lambda a, b`      | Two values that share a key                               | `lambda a, b: a + b`                          |
| `.groupByKey`        | none               | Groups all values for each key (not a reduction)          | `.groupByKey().mapValues(list)`               |
| `.sortBy`            | `lambda x`         | Each element; the lambda returns its sort key             | `lambda x: x[1]` (sort by element at index 1) |
| `.keyBy`             | `lambda x`         | Each element `x`, turned into the pair `(f(x), x)`        | `lambda x: x[0]`                              |
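Because the lambdas in the table are ordinary Python functions, one practical habit is to try them on a small local list before handing them to Spark. This sketch reuses the table's example lambdas under made-up names; the list comprehensions and built-ins only imitate the output shape of each Spark method on one machine and are not how Spark actually runs them.

```python
split_words = lambda x: x.split()  # .flatMap function
greater_than_5 = lambda x: x > 5   # .filter function
second_item = lambda x: x[1]       # .map / .sortBy function
first_item = lambda x: x[0]        # .keyBy function

lines = ["to be or", "not to be"]
flat_words = [w for line in lines for w in split_words(line)]  # flatMap shape
print(flat_words)  # ['to', 'be', 'or', 'not', 'to', 'be']

numbers = [3, 8, 1, 9]
big = list(filter(greater_than_5, numbers))  # filter shape
print(big)  # [8, 9]

pairs = [("x", 3), ("y", 1), ("z", 2)]
by_second = sorted(pairs, key=second_item)  # sortBy shape
print(by_second)  # [('y', 1), ('z', 2), ('x', 3)]

keyed = [(first_item(p), p) for p in pairs]  # keyBy shape: (f(x), x)
print(keyed)  # [('x', ('x', 3)), ('y', ('y', 1)), ('z', ('z', 2))]
```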
