# Big Data for Transportation: Mapping Operations Across Python

This notebook provides a tutorial on using "map" operations for efficient data processing. We'll explore several Python tools, starting from the built-in `map` function and progressing to NumPy and JAX.
Later on, we will use magic commands to interact with larger datasets on Google Cloud.

This tutorial is structured to showcase how the concept of "map" evolves and scales as we move from basic Python to specialized libraries for numerical computation and big data processing.

Let's get started!

## 1. Python's Built-in `map` Function

The `map()` function in Python is a fundamental tool for applying a given function to each item of an iterable (like a list, tuple, etc.) and returning an iterator that yields the results. It embodies the core concept of "mapping" a transformation over a dataset.

Let's look at some basic examples to understand how `map` works.

### Example 1: Squaring Numbers


Suppose we have a list of numbers and we want to square each number. We can use `map` along with a lambda function or a defined function to achieve this.


In [None]:
numbers = [1, 2, 3, 4, 5]

# Using a lambda function for squaring
squared_numbers_map = map(lambda x: x**2, numbers)

# `map` returns an iterator, so we need to convert it to a list to see the results
squared_numbers_list = list(squared_numbers_map)
# The iterator is now exhausted: a second pass such as tuple(squared_numbers_map) would return ()
print(f"Original numbers: {numbers}")
print(f"Squared numbers using map: {squared_numbers_list}")

Original numbers: [1, 2, 3, 4, 5]
Squared numbers using map: [1, 4, 9, 16, 25]


**Explanation:**

*   `lambda x: x**2` is an anonymous function that takes an input `x` and returns its square.
*   `map(lambda x: x**2, numbers)` applies this lambda function to each element in the `numbers` list.
*   `list(squared_numbers_map)` converts the iterator returned by `map` into a list for display.
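
Because `map` returns a one-shot iterator, results are computed lazily and can be consumed only once. The following is a minimal sketch of that behavior, reusing the squaring example (the variable names `first`, `remaining`, and `leftover` are illustrative):

```python
numbers = [1, 2, 3, 4, 5]

squares = map(lambda x: x**2, numbers)  # nothing is computed yet

first = next(squares)       # computes only the first element
remaining = list(squares)   # consumes the rest of the iterator
leftover = list(squares)    # the iterator is now exhausted

print(first)      # 1
print(remaining)  # [4, 9, 16, 25]
print(leftover)   # []
```

If the results are needed more than once, store them in a list (as the example above does) instead of reusing the `map` object.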

### Example 2: Converting Strings to Uppercase


Let's say we have a list of city names and we want to convert them to uppercase.


In [None]:
cities = ["new york", "london", "paris", "tokyo"]

uppercase_cities_map = map(str.upper, cities) # Using the built-in str.upper method

uppercase_cities_list = list(uppercase_cities_map)
print(f"Original cities: {cities}")
print(f"Uppercase cities using map: {uppercase_cities_list}")

Original cities: ['new york', 'london', 'paris', 'tokyo']
Uppercase cities using map: ['NEW YORK', 'LONDON', 'PARIS', 'TOKYO']


In [None]:
tmp = []
for c in cities:
  tmp.append(str.upper(c))
print(tmp)

['NEW YORK', 'LONDON', 'PARIS', 'TOKYO']


**Explanation:**

*   `str.upper` is a built-in string method that converts a string to uppercase. We pass this method directly as the function to `map`.
*   `map(str.upper, cities)` applies the `str.upper` method to each string in the `cities` list.

### Example 3: Applying a Custom Function

Let's define a function that calculates the distance traveled given speed and time, and use `map` to apply it to lists of speeds and times.

In [None]:
def calculate_distance(speed, time):
    return speed * time

speeds = [60, 70, 80] # km/h
times = [1.5, 2, 0.75] # hours

distances_map = map(calculate_distance, speeds, times) # Note: map can take multiple iterables

distances_list = list(distances_map)
print(f"Speeds: {speeds}")
print(f"Times: {times}")
print(f"Distances using map: {distances_list} km")

Speeds: [60, 70, 80]
Times: [1.5, 2, 0.75]
Distances using map: [90.0, 140, 60.0] km


**Explanation:**

*   `calculate_distance(speed, time)` is a function we defined.
*   `map(calculate_distance, speeds, times)` applies `calculate_distance` to corresponding elements from the `speeds` and `times` lists. `map` can take multiple input iterables; it advances them in lockstep, passes one element from each to the function per call, and stops as soon as the shortest iterable is exhausted.
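
As a small illustration of how `map` pairs up multiple iterables, the sketch below deliberately uses lists of unequal length (the shortened `times` list is an assumption made for the demonstration):

```python
speeds = [60, 70, 80]   # km/h
times = [1.5, 2]        # hours (one entry fewer than speeds)

# map stops when the shortest iterable runs out, so the third speed is ignored
distances = list(map(lambda s, t: s * t, speeds, times))
print(distances)  # [90.0, 140]

# Equivalent formulation with zip and a list comprehension
distances_zip = [s * t for s, t in zip(speeds, times)]
print(distances_zip == distances)  # True
```

No error is raised for the unmatched element, so it is worth checking that the input lengths agree when that matters.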

### Exercise

**Exercise 1:  Celsius to Fahrenheit Conversion**

Write a function `celsius_to_fahrenheit(c)` that converts Celsius to Fahrenheit using the formula:  F = (9/5) * C + 32.
Then, use `map` to apply this function to a list of Celsius temperatures: `celsius_temps = [0, 10, 20, 30, 100]`.

In [None]:
# Your code for Exercise 1 here:
def celsius_to_fahrenheit(c):
    return (9/5) * c + 32

celsius_temps = [0, 10, 20, 30, 100]
fahrenheit_temps_map = map(celsius_to_fahrenheit, celsius_temps)
fahrenheit_temps_list = list(fahrenheit_temps_map)
print(f"Celsius temperatures: {celsius_temps}")
print(f"Fahrenheit temperatures: {fahrenheit_temps_list}")

Celsius temperatures: [0, 10, 20, 30, 100]
Fahrenheit temperatures: [32.0, 50.0, 68.0, 86.0, 212.0]


**Exercise 2:  Calculate Speed from Distance and Time (Multiple Arrays)**

You have two lists: `distances_km = [100, 250, 50, 120]` and `times_hr = [2, 5, 1, 1.5]`.
Use `map` and a lambda function to calculate the speed (speed = distance / time) for each corresponding pair of distance and time values.

In [None]:
# Your code for Exercise 2 here:
distances_km = [100, 250, 50, 120]
times_hr = [2, 5, 1, 1.5]

speeds_map = map(lambda d, t: d / t, distances_km, times_hr)
speeds_list = list(speeds_map)
print(f"Distances (km): {distances_km}")
print(f"Times (hr): {times_hr}")
print(f"Speeds (km/hr): {speeds_list}")

Distances (km): [100, 250, 50, 120]
Times (hr): [2, 5, 1, 1.5]
Speeds (km/hr): [50.0, 50.0, 50.0, 80.0]


**Exercise 3:  Categorize Trip Distance (Indexing within lambda)**

You have a list of trip distances: `trip_distances = [2.5, 15.8, 0.9, 7.3, 22.1]`.
And categories: `categories = ["short", "medium", "long"]`.
Use `map` and a lambda function that categorizes each trip distance as follows:
- If distance < 5 km: "short"
- If 5 km <= distance < 15 km: "medium"
- If distance >= 15 km: "long"

*Hint: You might need to use indexing within your lambda function to access the `categories` list based on conditions.*


In [None]:
# Your code for Exercise 3 here:
trip_distances = [2.5, 15.8, 0.9, 7.3, 22.1]
categories = ["short", "medium", "long"]

trip_categories_map = map(lambda dist: categories[0] if dist < 5 else (categories[1] if dist < 15 else categories[2]), trip_distances)
trip_categories_list = list(trip_categories_map)
print(f"Trip distances: {trip_distances}")
print(f"Trip categories: {trip_categories_list}")

Trip distances: [2.5, 15.8, 0.9, 7.3, 22.1]
Trip categories: ['short', 'long', 'short', 'medium', 'long']


### Limitations of Python's `map` for Big Data

While `map` is elegant and useful for many tasks, it has limitations when dealing with very large datasets and computationally intensive operations, especially in the context of transportation data, which can be massive.

*   **Performance:** Python's built-in `map` is implemented in C, but the function you apply is still executed in Python's interpreter. For complex operations or large datasets, this can become a performance bottleneck.
*   **No Vectorization or Parallelism:**  Python's `map` doesn't inherently offer vectorization (performing operations on entire arrays at once) or automatic parallelism (utilizing multiple CPU cores or GPUs).

To overcome these limitations, we turn to libraries like NumPy and JAX, which are designed for numerical computation and offer significant performance improvements, especially when dealing with arrays and large datasets.

## 2. NumPy: Numerical Powerhouse for Arrays

NumPy (Numerical Python) is the cornerstone of numerical computing in Python. It introduces the concept of **arrays**, which are highly efficient data structures for storing and manipulating numerical data. NumPy operations are often implemented in compiled languages (like C and Fortran), making them significantly faster than standard Python loops, especially for large datasets.

### 2.1. Concepts of NumPy Arrays

*   **Homogeneous Data:** NumPy arrays store elements of the same data type (e.g., all integers, all floats). This homogeneity allows for efficient memory storage and faster operations.
*   **Multidimensional:** Arrays can be one-dimensional (vectors), two-dimensional (matrices or tables), or have even higher dimensions.
*   **Vectorization:** NumPy operations are often vectorized, meaning they operate on entire arrays element-wise without explicit loops. This is a key factor in NumPy's performance.
*   **Broadcasting:** NumPy can perform operations on arrays with different but compatible shapes using a mechanism called broadcasting, which conceptually stretches the smaller array to match the shape of the larger one without copying its data.
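
To make vectorization and broadcasting concrete, here is a minimal sketch. The trip values and the per-column `scale` vector are illustrative assumptions, not data from a real source:

```python
import numpy as np

# Rows are trips; columns are distance (miles) and fare (dollars)
trips = np.array([
    [5.2, 15.5],
    [8.1, 22.0],
    [3.5, 10.8],
])

# One scale factor per column: miles -> km, fare left unchanged
scale = np.array([1.609344, 1.0])

# Shapes (3, 2) and (2,) broadcast: `scale` is applied to every row
converted = trips * scale

print(converted.shape)  # (3, 2)
print(converted)
```

No explicit loop is written; the multiplication runs element-wise over the whole array.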

### 2.2. Creating NumPy Arrays

Let's see how to create different types of NumPy arrays.

In [None]:
import numpy as np

# 2.2.1 A 1D array is like a list or a vector.
one_d_array = np.array([10, 20, 30, 40, 50])
print("1D Array:")
print(one_d_array)
print(f"Shape: {one_d_array.shape}") # Shape is (5,), indicating 5 elements in 1 dimension
print(f"Data type: {one_d_array.dtype}") # Data type is inferred automatically (int64 here)

1D Array:
[10 20 30 40 50]
Shape: (5,)
Data type: int64


#### 2.2.2. 2D Array (Matrix/Table) - Representing Excel-like Data

A 2D array can represent tabular data, much like an Excel spreadsheet or a table in a database. Let's imagine we have data about taxi trips: trip ID, distance, and fare.

In [None]:
taxi_data_2d = np.array([
    [1, 5.2, 15.5],  # Trip 1: ID, Distance, Fare
    [2, 8.1, 22.0],  # Trip 2
    [3, 3.5, 10.8],  # Trip 3
    [4, 12.0, 35.2] # Trip 4
])

print("\n2D Array (Taxi Trip Data - Table):")
print(taxi_data_2d)
print(f"Shape: {taxi_data_2d.shape}") # Shape is (4, 3), 4 rows (trips), 3 columns (features)



2D Array (Taxi Trip Data - Table):
[[ 1.   5.2 15.5]
 [ 2.   8.1 22. ]
 [ 3.   3.5 10.8]
 [ 4.  12.  35.2]]
Shape: (4, 3)


#### 2.2.3. 3D Array (Sequential Data) - Representing Time Series

A 3D array can be used to represent sequential data, such as sensor readings over time for multiple vehicles. Imagine we have sensor readings (speed, acceleration) taken every second for 3 vehicles over 5 seconds.

In [None]:
sequential_data_3d = np.array([
    [ # Vehicle 1
        [10, 0.1], # Second 1: Speed, Acceleration
        [15, 0.2], # Second 2
        [20, 0.0], # Second 3
        [22, -0.1],# Second 4
        [20, -0.2] # Second 5
    ],
    [ # Vehicle 2
        [5, 0.05],
        [8, 0.1],
        [12, 0.0],
        [10, -0.05],
        [8, -0.1]
    ],
    [ # Vehicle 3
        [25, 0.2],
        [30, 0.1],
        [32, 0.05],
        [35, 0.0],
        [33, -0.1]
    ]
])

print("\n3D Array (Sequential Sensor Data):")
print(sequential_data_3d)
print(f"Shape: {sequential_data_3d.shape}") # Shape is (3, 5, 2): 3 vehicles, 5 time steps, 2 features


3D Array (Sequential Sensor Data):
[[[10.    0.1 ]
  [15.    0.2 ]
  [20.    0.  ]
  [22.   -0.1 ]
  [20.   -0.2 ]]

 [[ 5.    0.05]
  [ 8.    0.1 ]
  [12.    0.  ]
  [10.   -0.05]
  [ 8.   -0.1 ]]

 [[25.    0.2 ]
  [30.    0.1 ]
  [32.    0.05]
  [35.    0.  ]
  [33.   -0.1 ]]]
Shape: (3, 5, 2)



**Interpretation:**

*   The first dimension (axis 0) represents vehicles.
*   The second dimension (axis 1) represents time steps (seconds).
*   The third dimension (axis 2) represents sensor features (speed, acceleration).

#### 2.2.4. Quick Array Creation: `np.ones`, `np.zeros`, `np.random.rand`

NumPy provides convenient functions to quickly create large arrays filled with specific values or random numbers. These are useful for initializing arrays, creating placeholders, or generating test data.

In [None]:
# Create an array of ones
ones_array = np.ones((10, 10)) # 10x10 array filled with 1.0s
print("Array of ones (10x10):")
print(ones_array)
print(f"Shape: {ones_array.shape}")

# Create an array of zeros
zeros_array = np.zeros((5, 20), dtype=int) # 5x20 array filled with 0s, explicitly set integer dtype
print("\nArray of zeros (5x20, integer type):")
print(zeros_array)
print(f"Shape: {zeros_array.shape}, Data type: {zeros_array.dtype}")

# Create an array with random values between 0 and 1
random_array = np.random.rand(1000, 1000) # 1000x1000 array with random floats [0, 1)
print("\nArray with random values (1000x1000, first few elements):")
print(random_array[:3, :3]) # Print only the top-left 3x3 portion for brevity
print(f"Shape: {random_array.shape}")
del random_array

Array of ones (10x10):
[[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]]
Shape: (10, 10)

Array of zeros (5x20, integer type):
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]]
Shape: (5, 20), Data type: int64

Array with random values (1000x1000, first few elements):
[[0.03237564 0.35566141 0.9619268 ]
 [0.37896167 0.92276219 0.75943766]
 [0.04149752 0.38544571 0.54552464]]
Shape: (1000, 1000)


#### 2.2.5. Visualizing Memory Usage of NumPy Arrays

Large NumPy arrays can consume significant memory. It's helpful to understand how to check the memory usage of your arrays. We can use `sys.getsizeof` (the size Python reports for the array object) and the `.nbytes` attribute (the size of the array's data buffer).


In [None]:
import sys
import numpy as np

large_random_array = np.random.rand(10**9) # 1 billion random floats (about 8 GB of RAM; reduce the size on smaller machines)

# Size Python reports for the array object; NumPy includes the data buffer when the array owns it
shallow_size_bytes = sys.getsizeof(large_random_array)
shallow_size_mb = shallow_size_bytes / (1024**2) # Convert to megabytes

# Size of the raw data buffer (number of elements * bytes per element)
deep_size_bytes = large_random_array.nbytes
deep_size_mb = deep_size_bytes / (1024**2) # Convert to megabytes

print(f"\nMemory usage of large_random_array (shape: {large_random_array.shape}, dtype: {large_random_array.dtype}):")
print(f"Shallow Size (sys.getsizeof): {shallow_size_mb:.2f} MB")
print(f"Deep Size (array.nbytes): {deep_size_mb:.2f} MB") # This is the more relevant size for data
del large_random_array


Memory usage of large_random_array (shape: (1000000000,), dtype: float64):
Shallow Size (sys.getsizeof): 7629.39 MB
Deep Size (array.nbytes): 7629.39 MB


**Explanation:**

*   `sys.getsizeof()` gives the size Python reports for the object. For a NumPy array that owns its data, this includes the data buffer plus a small header, which is why the two numbers above agree to two decimal places. For a view onto another array's data, it covers only the header.
*   The `.nbytes` attribute of a NumPy array gives the total number of bytes consumed by the array's data elements (number of elements times bytes per element, here 10^9 × 8 bytes). It is the more reliable measure because it does not depend on whether the array owns its buffer.

When working with very large datasets, being mindful of memory usage is crucial. NumPy's efficient storage and operations help in managing memory effectively compared to standard Python lists.
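
The difference between the two measurements shows up most clearly when comparing an array that owns its buffer with a view onto it. The sketch below uses a much smaller array than the cell above; the names `owner` and `view` are illustrative:

```python
import sys
import numpy as np

owner = np.zeros(1_000_000)  # owns its 8 MB data buffer
view = owner[::2]            # a view: shares the owner's buffer

print(owner.flags.owndata, view.flags.owndata)  # True False
print(sys.getsizeof(owner), owner.nbytes)  # getsizeof includes the buffer
print(sys.getsizeof(view), view.nbytes)    # getsizeof covers only the header
```

For views, `sys.getsizeof` is therefore far smaller than `.nbytes`, even though the view still keeps the owner's buffer alive.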

### 2.3. `np.vectorize` and Performance

`np.vectorize` is a function in NumPy that can be used to convert a Python function into a vectorized function.  However, it's important to understand that **`np.vectorize` does not actually make your Python function run faster in the way that true NumPy vectorization does.** It's essentially a convenience function that applies your Python function element-wise to NumPy arrays, often using a Python loop under the hood.

**It's generally NOT recommended for performance-critical code when working with NumPy arrays. True NumPy vectorization (using built-in NumPy functions and operations) is much more efficient.**

Let's illustrate this with an example and compare its performance to standard NumPy operations and Python's `map`.

In [None]:
def complex_function(x):
    # A somewhat computationally intensive Python function (just for demonstration)
    return (x**2 + np.sin(x)) / (1 + np.abs(x))

# Create a large NumPy array
large_array = np.random.rand(int(1e6))

# 1. Using np.vectorize
vectorized_func = np.vectorize(complex_function)

# Time using np.vectorize
%timeit vectorized_result = vectorized_func(large_array)

# 2. Using Python's map (for comparison, not ideal for NumPy arrays directly)
list_array = large_array.tolist() # Convert to list for map
%timeit _ = list(map(complex_function, list_array))

# 3. Using NumPy's vectorized operations (the efficient way) - if possible to rewrite the function
def vectorized_complex_function(x):
    return (x**2 + np.sin(x)) / (1 + np.abs(x)) # This function is already vectorized for NumPy

# Time using vectorized NumPy operation
%timeit _ = vectorized_complex_function(large_array)

del large_array
del list_array

2.41 s ± 492 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.71 s ± 562 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
21.7 ms ± 1.47 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


**Observations:**

*   In the run above, `np.vectorize` (2.41 s) and Python's `map` (2.71 s) are within noise of each other, while the NumPy-vectorized version (21.7 ms) is roughly 100 times faster. Exact numbers vary by machine, but `np.vectorize` generally performs like a Python loop, not like true NumPy vectorization.
*   **True NumPy vectorization** (using NumPy's built-in functions and operations that are already vectorized) is the most efficient approach for numerical operations on NumPy arrays.

**When to (rarely) consider `np.vectorize`:**

*   When you have a pre-existing Python function that is difficult or time-consuming to rewrite using NumPy's vectorized operations, and you want to apply it element-wise to a NumPy array.
*   For simpler tasks where performance is not the absolute top priority and convenience is preferred.

**In most cases, you should aim to leverage NumPy's built-in vectorized operations for optimal performance.**
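
Branching logic can often be rewritten with built-in vectorized operations too. As one possible sketch, the trip categorization from Exercise 3 can be expressed with `np.digitize` instead of a per-element lambda (this is an alternative formulation, not the exercise's intended solution):

```python
import numpy as np

trip_distances = np.array([2.5, 15.8, 0.9, 7.3, 22.1])
categories = np.array(["short", "medium", "long"])

# np.digitize returns, for each distance, the index of the bin it falls in:
# 0 for d < 5, 1 for 5 <= d < 15, 2 for d >= 15
bin_index = np.digitize(trip_distances, bins=[5, 15])

# Fancy indexing maps each bin index to its category label
trip_categories = categories[bin_index]
print(trip_categories)  # ['short' 'long' 'short' 'medium' 'long']
```

Both steps run in compiled code over the whole array, so no Python-level function is called per element.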

### 2.4. Other Basic NumPy Operations (Briefly)

NumPy offers a vast array of functions for array manipulation, linear algebra, Fourier transforms, random number generation, and much more. Here are a few very basic examples related to "mapping" or element-wise operations.


In [None]:
# Element-wise addition, subtraction, multiplication, division
array1 = np.array([1, 2, 3])
array2 = np.array([4, 5, 6])

print("\nElement-wise operations:")
print(f"array1 + array2: {array1 + array2}")
print(f"array1 * array2: {array1 * array2}")

# Broadcasting example: scalar multiplication
scalar = 2
print(f"\nBroadcasting (scalar * array1): {scalar * array1}")

# Applying mathematical functions element-wise
print(f"\nElement-wise sine (np.sin(array1)): {np.sin(array1)}")


Element-wise operations:
array1 + array2: [5 7 9]
array1 * array2: [ 4 10 18]

Broadcasting (scalar * array1): [2 4 6]

Element-wise sine (np.sin(array1)): [0.84147098 0.90929743 0.14112001]


In [None]:
# Exercise: create two random arrays of shapes (10, 50) and (50, 20), then multiply them as matrices
arr1 = np.random.rand(10, 50)
arr2 = np.random.rand(50, 20)

res = arr1 @ arr2
print(res.shape)
print(res.nbytes)

(10, 20)
1600


## 3. JAX: NumPy on Steroids with Auto-differentiation and Acceleration

JAX is a powerful library developed by Google that builds upon NumPy and adds automatic differentiation, just-in-time (JIT) compilation, and excellent support for GPUs and TPUs. It's designed for high-performance numerical computing and machine learning.

### 3.1. Concepts of JAX

*   **NumPy API:** JAX largely follows the NumPy API, making it relatively easy to transition from NumPy to JAX. Most NumPy code can be run with JAX with minimal changes.
*   **Automatic Differentiation (Autograd):** JAX can automatically compute gradients of Python functions, which is crucial for machine learning and optimization tasks.
*   **Just-In-Time (JIT) Compilation:** JAX can compile Python functions into optimized machine code using `jax.jit`. This can dramatically speed up execution, especially for repeated computations.
*   **Vectorization and Parallelization (`vmap`, `pmap`):** JAX provides `vmap` for automatic vectorization and `pmap` for parallelization across multiple devices (GPUs/TPUs or CPU cores).
*   **GPU/TPU Acceleration:** JAX is designed to run efficiently on GPUs and TPUs, enabling significant speedups for computationally intensive tasks.

In [None]:
import jax
import jax.numpy as jnp

def f(x):
  return x**2

x = jnp.arange(5, dtype=jnp.float32)
grad_fn = jax.vmap(jax.grad(f)) # jax.grad expects a scalar-output function; vmap applies it to each element
x_square = f(x)
gradients = grad_fn(x)
print(x)
print(x_square)
print(gradients)

[0. 1. 2. 3. 4.]
[ 0.  1.  4.  9. 16.]
[0. 2. 4. 6. 8.]


### 3.2. JIT Compilation with `jax.jit`


In [None]:
import jax
import jax.numpy as jnp # JAX's NumPy replacement

# Define a function (similar to our complex_function from NumPy)
def jax_complex_function(x):
    return (x**2 + jnp.sin(x)) / (1 + jnp.abs(x))

# Create a JAX array (similar to NumPy array)
large_array = np.random.rand(int(1e7))
jax_large_array = jnp.array(large_array) # Convert our NumPy array to JAX array

# JIT compile the function
jit_jax_complex_function = jax.jit(jax_complex_function)

# **Run the JIT-compiled function once to trigger compilation**
# jax.jit compiles per input shape and dtype, so warm up with the real array
_ = jit_jax_complex_function(jax_large_array).block_until_ready() # Discard the result, we just want to compile it

# Now time the JIT-compiled function (it's already compiled)
print("\nTiming JIT-compiled function (after initial compilation):")
%timeit _ = jit_jax_complex_function(jax_large_array)

# For comparison, time the function without JIT (optional, but good to see the difference)
print("\nTiming function without JIT:")
%timeit _ = jax_complex_function(jax_large_array)

del large_array
del jax_large_array
del _


Timing JIT-compiled function (after initial compilation):
104 ms ± 19.4 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

Timing function without JIT:
401 ms ± 62 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


**Observations:**

*   The first call to a JIT-compiled function with a given input shape and dtype includes tracing and compilation overhead. Subsequent calls with the same shape and dtype reuse the compiled version and are much faster.
*   Calling the compiled function once *before* the `%timeit` is meant to keep compilation time out of the measurement. Because compilation is cached per input shape and dtype, the warm-up call has to use an input shaped like the real one.
*   In the run above, the JIT-compiled version (104 ms) is roughly 4 times faster than the non-JIT version (401 ms). JAX dispatches work asynchronously, so call `.block_until_ready()` on the result when a timing must cover the full computation.
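
The per-shape caching can be observed directly. The sketch below relies on the fact that Python side effects inside a jitted function run only while JAX traces it; `trace_log` and `squared_plus_one` are hypothetical names used just for this demonstration:

```python
import jax
import jax.numpy as jnp

trace_log = []  # records the input shape each time the function is traced

@jax.jit
def squared_plus_one(x):
    trace_log.append(x.shape)  # executed during tracing only, not on cached calls
    return x**2 + 1

squared_plus_one(jnp.zeros((1,)))     # first call: trace and compile for shape (1,)
squared_plus_one(jnp.ones((1,)))      # same shape and dtype: cached, no retrace
squared_plus_one(jnp.zeros((1000,)))  # new shape: traced and compiled again

print(trace_log)  # [(1,), (1000,)]
```

Three calls produce only two traces, which is why a warm-up call with a differently shaped dummy input does not remove compilation from a later timing.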

### 3.3. Vectorization with `jax.vmap`

`jax.vmap` (vectorized map) is a powerful tool for automatically vectorizing functions. It takes a function that operates on single data points and transforms it into a function that operates on batches of data points (i.e., vectors or higher-dimensional arrays).

Let's define a simple `single_process_function` and compare `vmap` with several loop-based ways of processing a batch of inputs.

In [None]:
import jax
import jax.numpy as jnp

# Function that operates on a single input
@jax.jit
def single_process_function(x):
    return x**2 + 1

# Generate a large batch of random inputs (1000x1000 matrix now for demonstration)
large_batch_inputs = jnp.array(np.random.rand(1000, 1000))

# 1. Vectorize with vmap
vectorized_process_function_vmap = jax.vmap(single_process_function, in_axes=(0,))
jit_vectorized_process_function_vmap = jax.jit(vectorized_process_function_vmap) # JIT-compile vmap


@jax.jit
def loop_process_batch_fori(batch):
    batch_size = batch.shape[0]

    def body_fun(i, carry):
        result = single_process_function(batch[i])
        return carry.at[i].set(result)

    # Initialize the output array
    init_carry = jnp.zeros_like(batch)
    # Use fori_loop to accumulate results
    results = jax.lax.fori_loop(0, batch_size, body_fun, init_carry)
    return results

@jax.jit
def loop_process_batch_map(batch):
    return jax.lax.map(single_process_function, batch)

# Python loop version for baseline comparison
def loop_process_batch_python(batch):
    results = []
    batch_size = batch.shape[0]
    for i in range(batch_size):
        x = batch[i]
        results.append(single_process_function(x))
    return jnp.array(results)


# **Run each function once on the real input to trigger compilation**
# jax.jit compiles per input shape and dtype, so a differently shaped dummy input would not warm the cache
_ = single_process_function(large_batch_inputs[0])
_ = jit_vectorized_process_function_vmap(large_batch_inputs)
_ = loop_process_batch_fori(large_batch_inputs)
_ = loop_process_batch_map(large_batch_inputs)
_ = loop_process_batch_python(large_batch_inputs)


print("\nTiming: Vectorized function with vmap (and JIT):")
%timeit _ = jit_vectorized_process_function_vmap(large_batch_inputs)

print("\nTiming: Loop-based batch processing with jax.lax.fori_loop (and JIT):")
%timeit _ = loop_process_batch_fori(large_batch_inputs)

print("\nTiming: Loop-based batch processing with Python loop (and JIT):")
%timeit _ = loop_process_batch_python(large_batch_inputs)


print("\nTiming: Map-based batch processing with jax.lax.map (and JIT):")
%timeit _ = loop_process_batch_map(large_batch_inputs)


Timing: Vectorized function with vmap (and JIT):
853 µs ± 142 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Timing: Loop-based batch processing with jax.lax.fori_loop (and JIT):
The slowest run took 19.66 times longer than the fastest. This could mean that an intermediate result is being cached.
60.8 µs ± 100 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

Timing: Loop-based batch processing with Python loop (and JIT):
485 ms ± 55.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Timing: Loop-based batch processing with jax.lax.map (and JIT):
1.77 ms ± 240 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


**Explanation:**

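*   **`jax.vmap(single_process_function, in_axes=(0,))`**: We use `jax.vmap` to vectorize `single_process_function`.
    *   `in_axes=(0,)` specifies that the vectorization should happen along the *first axis* (axis 0) of the input argument. Since `large_batch_inputs` is a 2D array (1000x1000), and we want to apply `single_process_function` to each row, `in_axes=(0,)` is correct. For a function with a single argument, `in_axes=0` is equivalent, and omitting `in_axes` also works because the default is `0`.
*   **Reading the timings:** The Python-loop variant (485 ms) is not compiled as a whole. Only `single_process_function` is jitted, so each of the 1000 rows costs a separate dispatch. The `fori_loop` figure has a standard deviation larger than its mean and triggered a caching warning, so treat it as unreliable. For dependable JAX timings, call `.block_until_ready()` on the result.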
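
`in_axes` becomes more useful when a function takes several arguments and only some of them are batched. The following is a minimal sketch, assuming a hypothetical per-vehicle function `fraction_above` and reusing the speed values from the 3D sensor example in Section 2:

```python
import jax
import jax.numpy as jnp

# Operates on ONE vehicle's speed series (1D) and a scalar threshold
def fraction_above(speeds, threshold):
    return jnp.mean(speeds > threshold)

# 3 vehicles x 5 time steps
speeds = jnp.array([
    [10.0, 15.0, 20.0, 22.0, 20.0],
    [ 5.0,  8.0, 12.0, 10.0,  8.0],
    [25.0, 30.0, 32.0, 35.0, 33.0],
])

# Map over axis 0 of `speeds`; None means the threshold is shared by every call
batched_fraction_above = jax.vmap(fraction_above, in_axes=(0, None))
result = batched_fraction_above(speeds, 18.0)
print(result)  # one fraction per vehicle
```

Each entry of `in_axes` corresponds to one positional argument: an integer names the axis to map over, and `None` broadcasts the argument unchanged.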

### 3.4. Parallelization with `jax.pmap`

`jax.pmap` is for parallelizing computations across multiple devices (CPUs/GPUs/TPUs). Let's illustrate with a simpler example of processing sequential data in parallel.

**Instruction: Ensure you are using a multi-device runtime in Colab (Runtime -> Change runtime type), for example a TPU runtime. The output below comes from a runtime with 8 TPU cores.**

In [None]:
import jax
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [None]:
import numpy as np
import jax.numpy as jnp

# Simplified function: sum of squares over every element of its input
@jax.jit
def sum_of_squares(time_series): # under pmap: one device's shard, [vehicles_per_device, time_steps, features]
    return jnp.sum(time_series**2)

def sum_of_squares_np(time_series): # time_series is 2D: [time_steps, features]
    return np.sum(time_series**2)

# Generate sequential data for multiple vehicles
num_vehicles = 8
time_steps = int(1e8)  # about 6.4 GB as float32 (more while the float64 original is alive); reduce to fit your RAM
num_features = 2
num_devices = jax.device_count()

# Generate data
data = np.random.rand(num_vehicles, time_steps, num_features).astype(np.float32)
data = data.reshape((num_devices, num_vehicles // num_devices, time_steps, num_features))
sharded_data = jax.device_put_sharded(list(data), jax.devices())

# Parallelize processing of each vehicle's time series using pmap
parallel_sum_of_squares = jax.pmap(sum_of_squares, axis_name='devices')
# Warm up on the real input: compilation is specific to input shape and dtype
_ = parallel_sum_of_squares(sharded_data)

print("\n Parallelization with pmap on multi-device runtime:")
print(f"Shape of Input Data (Sequential Data): {sharded_data.shape}")

# Time parallel processing
%timeit parallel_results_pmap = parallel_sum_of_squares(sharded_data)

del sharded_data

data = data.reshape(-1, time_steps, num_features)
# Sequential processing for comparison
def sequential_sum_of_squares_all_vehicles(all_vehicle_data):
    results = []
    for vehicle_data in all_vehicle_data:
        results.append(sum_of_squares_np(vehicle_data))
    return jnp.array(results)

print("\nSequential processing (for comparison):")
%timeit sequential_results_loop = sequential_sum_of_squares_all_vehicles(data)



 Parallelization with pmap on multi-device runtime:
Shape of Input Data (Sequential Data): (8, 1, 100000000, 2)
2.44 ms ± 432 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Sequential processing (for comparison):
2.41 s ± 12.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


**Explanation:**

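*   `sum_of_squares` sums the squares of every element of its input. Under `pmap` the input is one device's shard of shape `(num_vehicles // num_devices, time_steps, num_features)`. With 8 vehicles on 8 devices each shard holds a single vehicle, so the result is one value per vehicle.
*   **`jax.pmap(sum_of_squares, axis_name='devices')`**: We parallelize `sum_of_squares` across the leading axis (axis 0) of the sharded data, so each shard is processed on a separate device (the 8 TPU cores listed above).
*   **Performance Comparison:** We compare `pmap` with a sequential NumPy loop on the host CPU. The gap therefore reflects the accelerator hardware as well as the parallelism, and the `pmap` timing was taken without `.block_until_ready()`, so treat the ratio as indicative rather than exact.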

**Key takeaway:** `pmap` allows you to distribute computations across devices, potentially speeding up processing when you have independent units of work (like processing each vehicle's data separately). The actual speedup you observe in Colab will depend on the runtime environment and the nature of the computation. For true large-scale parallelism, consider dedicated GPU/TPU environments.
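
For experimenting without a multi-device runtime, the same pattern can be sketched on however many devices are available locally. This is a minimal illustration with made-up data; on a plain CPU runtime `n_dev` is typically 1, so no actual parallelism takes place:

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# One row per device: the leading axis must not exceed the device count
data = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)

# Each device computes the sum of squares of its own row
per_device_sum = jax.pmap(lambda shard: jnp.sum(shard**2))(data)

print(per_device_sum.shape)  # (n_dev,)
```

The result has one entry per device and matches `jnp.sum(data**2, axis=1)` computed on a single device.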

## 4. Magic Commands for Data Interaction (NYC Taxi Data Example)

Let's try accessing the NYC Taxi dataset in BigQuery with the `%%bigquery` cell magic. The free tier includes:

*   1 TB of query processing per month
*   10 GB of storage

In [None]:
project_id = "bda-6893"  # Replace with your actual project ID

In [None]:
%%bigquery --project $project_id
SELECT
    vendor_id,
    pickup_datetime,
    dropoff_datetime,
    passenger_count,
    trip_distance,
    fare_amount,
    tip_amount,
    total_amount,
    pickup_location_id,
    dropoff_location_id
FROM
    `bigquery-public-data.new_york_taxi_trips.tlc_yellow_trips_2020` -- Yellow taxi trips for 2020
WHERE
    EXTRACT(MONTH FROM pickup_datetime) = 1  -- Keep January only
    AND passenger_count > 0  -- Filter out invalid records
    AND fare_amount > 0      -- Filter out invalid records
LIMIT 10

|   | vendor_id | pickup_datetime | dropoff_datetime | passenger_count | trip_distance | fare_amount | tip_amount | total_amount | pickup_location_id | dropoff_location_id |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2 | 2020-01-14 17:17:58+00:00 | 2020-01-14 17:18:10+00:00 | 4 | 0.06 | 1.3 | 0.0 | 1.6 | 28 | 28 |
| 1 | 2 | 2020-01-06 10:41:48+00:00 | 2020-01-06 10:41:53+00:00 | 4 | 0.01 | 0.3 | 0.0 | 0.6 | 138 | 138 |
| 2 | 1 | 2020-01-15 18:11:20+00:00 | 2020-01-15 18:12:45+00:00 | 1 | 0.0 | 1.0 | 0.0 | 1.3 | 146 | 146 |
| 3 | 2 | 2020-01-05 17:12:20+00:00 | 2020-01-05 17:12:36+00:00 | 1 | 0.0 | 1.0 | 0.0 | 1.3 | 226 | 226 |
| 4 | 1 | 2020-01-01 03:08:02+00:00 | 2020-01-01 03:08:51+00:00 | 2 | 0.6 | 0.46 | 0.0 | 0.76 | 234 | 234 |
| 5 | 1 | 2020-01-24 08:54:10+00:00 | 2020-01-24 09:01:01+00:00 | 1 | 1.3 | 0.41 | 0.0 | 0.41 | 239 | 141 |
| 6 | 1 | 2020-01-16 17:50:11+00:00 | 2020-01-16 17:55:29+00:00 | 1 | 1.0 | 0.41 | 0.0 | 0.41 | 43 | 43 |
| 7 | 1 | 2020-01-28 17:09:33+00:00 | 2020-01-28 17:10:26+00:00 | 3 | 0.0 | 0.69 | 0.0 | 0.99 | 132 | 132 |
| 8 | 1 | 2020-01-17 16:17:58+00:00 | 2020-01-17 17:06:33+00:00 | 1 | 7.6 | 0.01 | 0.0 | 0.31 | 138 | 161 |
| 9 | 1 | 2020-01-28 23:17:10+00:00 | 2020-01-28 23:58:41+00:00 | 1 | 17.9 | 0.01 | 0.0 | 0.31 | 209 | 20 |
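
To connect the query result back to the mapping theme, here is a small sketch that applies `map` to a few of the returned rows to compute trip durations. The timestamp pairs are copied from rows 0, 5, and 8 of the result above; the helper `duration_seconds` is a hypothetical name:

```python
from datetime import datetime

# (pickup, dropoff) pairs copied from rows 0, 5 and 8 of the query result
trips = [
    ("2020-01-14 17:17:58+00:00", "2020-01-14 17:18:10+00:00"),
    ("2020-01-24 08:54:10+00:00", "2020-01-24 09:01:01+00:00"),
    ("2020-01-17 16:17:58+00:00", "2020-01-17 17:06:33+00:00"),
]

def duration_seconds(trip):
    # Parse both timestamps, then subtract to get a timedelta
    pickup, dropoff = map(datetime.fromisoformat, trip)
    return (dropoff - pickup).total_seconds()

durations = list(map(duration_seconds, trips))
print(durations)  # [12.0, 411.0, 2915.0]
```

A 12-second trip with a 4-passenger count hints that filtering on `passenger_count` and `fare_amount` alone does not remove every implausible record.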
