---
layout: post
title:  Spark Dataset APIs
date:   2025-11-07
categories: [Spark, Scala]
mermaid: true
maths: true
typora-root-url: /Users/ojitha/GitHub/ojitha.github.io
typora-copy-images-to: ../../blog/assets/images/${filename}
---

<style>
/* Styles for the two-column layout */
.image-text-container {
    display: flex; /* Enables flexbox */
    flex-wrap: wrap; /* Allows columns to stack on small screens */
    gap: 20px; /* Space between the image and text */
    align-items: center; /* Vertically centers content in columns */
    margin-bottom: 20px; /* Space below this section */
}

.image-column {
    flex: 1; /* Allows this column to grow */
    min-width: 250px; /* Minimum width for the image column before stacking */
    max-width: 40%; /* Maximum width for the image column to not take up too much space initially */
    box-sizing: border-box; /* Include padding/border in element's total width/height */
}

.text-column {
    flex: 2; /* Allows this column to grow more (e.g., twice as much as image-column) */
    min-width: 300px; /* Minimum width for the text column before stacking */
    box-sizing: border-box;
}

</style>

<div class="image-text-container">
    <div class="image-column">
        <img src="https://raw.githubusercontent.com/ojitha/blog/master/assets/images/2025-10027-Scala-2-Collections/scala-collections-illustration.svg" alt="Scala Functors" width="150" height="150">
    </div>
    <div class="text-column">
<p>TBC</p>
    </div>
</div>

<!--more-->

------

* TOC
{:toc}
------

## Introduction

### What are Datasets?

Apache Spark Datasets are the foundational type in Spark's Structured APIs, providing a **type-safe**, distributed collection of strongly typed JVM objects. While DataFrames are Datasets of type `Row`, Datasets allow you to define custom domain-specific objects that each row will consist of, combining the benefits of RDDs (type safety, custom objects) with the optimizations of DataFrames (Catalyst optimizer, Tungsten execution).

**Key Characteristics:**

1. **Type Safety**: Typed operations are checked at compile time, so many type errors surface before the job runs
2. **Encoders**: Special serialization mechanism that maps domain-specific types to Spark's internal binary format
3. **Catalyst Optimization**: Benefits from Spark SQL's query optimizer
4. **JVM Language Feature**: Available only in Scala and Java (not Python or R)
5. **Functional API**: Supports functional transformations like `map`, `filter`, `flatMap`
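
To make the type-safety point concrete, here is a minimal sketch contrasting the untyped and typed APIs. The `Film` case class, the `session` value and the sample row are assumptions made for this illustration, and a plain local `SparkSession` is used rather than the notebook session created later in the post.

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}
import scala.util.Try

// Hypothetical record type, used only for this illustration
case class Film(movieId: Int, title: String)

// Assumption: a plain local session is available on the classpath
val session = SparkSession.builder().master("local[*]").appName("type-safety-sketch").getOrCreate()
import session.implicits._

val filmsDS = Seq(Film(1, "Toy Story (1995)")).toDS()
val filmsDF = filmsDS.toDF()

// Untyped API: the misspelt column name compiles, then fails during analysis at runtime
val untyped = Try(filmsDF.select("titel"))

// Typed API: the equivalent typo is rejected by the compiler, so it stays commented out
// filmsDS.map(_.titel)   // does not compile: titel is not a member of Film

// The correctly spelt field is checked by the compiler
val titles = filmsDS.map(_.title).collect().toSeq
```

The trade-off is that typed lambdas are opaque to Catalyst, so column expressions may still be preferable where optimisation matters more than compile-time checks.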

**Dataset[T]**: A distributed collection of data elements of type `T`, where `T` is a domain-specific class (case class in Scala, JavaBean in Java) that Spark can encode and optimize.

$$
\text{Dataset}[T] = \{t_1, t_2, \ldots, t_n\} \text{ where } t_i \in T
$$

Translation: A Dataset of type T is a collection of n elements, where each element belongs to type T.

**Encoder[T]**: A mechanism that converts between JVM objects of type `T` and Spark SQL's internal binary format (InternalRow).

$$
\text{Encoder}[T]: T \leftrightarrow \text{InternalRow}
$$

Translation: An Encoder for type T provides bidirectional conversion between objects of type T and Spark's internal row representation.
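
As a rough illustration of what an encoder carries, the sketch below derives one explicitly with `Encoders.product` and prints the schema it implies. The `Film` case class is a hypothetical stand-in for the movie records used later; in normal use `import spark.implicits._` supplies the encoder implicitly.

```scala
import org.apache.spark.sql.{Encoder, Encoders}

// Hypothetical record type mirroring the movie fields used in this post
case class Film(movieId: Int, title: String, genres: String)

// Derive the encoder explicitly instead of relying on the implicits
val filmEncoder: Encoder[Film] = Encoders.product[Film]

// The encoder exposes the schema Spark will use for a Dataset[Film]
println(filmEncoder.schema.treeString)
```

No `SparkSession` is needed for this step, since deriving an encoder only inspects the case class.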

### Mathematical Foundations

Datasets embody key functional programming concepts:

1. **Functor Laws** (for `map`):
    - Identity: `ds.map(x => x) = ds`
    - Composition: `ds.map(f).map(g) = ds.map(x => g(f(x)))`

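2. **Monad Laws** (for `flatMap`), read as an analogy: `Dataset(x)` is shorthand for a single-element Dataset, and Spark's `flatMap` takes a function that returns a local collection rather than another Dataset:
    - Left identity: `Dataset(x).flatMap(f) = f(x)`
    - Right identity: `ds.flatMap(x => Dataset(x)) = ds`
    - Associativity: `ds.flatMap(f).flatMap(g) = ds.flatMap(x => f(x).flatMap(g))`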
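
These laws can be spot-checked on a small sample. The sketch below is not a proof: it compares collected contents (ignoring order) for one tiny Dataset, and the names `session`, `contents`, `f`, `g`, `h` and `k` are assumptions introduced for the illustration.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

// Assumption: a plain local session is available on the classpath
val session = SparkSession.builder().master("local[*]").appName("law-check-sketch").getOrCreate()
import session.implicits._

val ds = Seq(1, 2, 3).toDS()
val f: Int => Int = _ + 1
val g: Int => Int = _ * 2

// Collect and sort so the comparison does not depend on partition order
def contents(d: Dataset[Int]): Seq[Int] = d.collect().toSeq.sorted

// Functor laws for map
val identityHolds = contents(ds.map(x => x)) == contents(ds)
val compositionHolds = contents(ds.map(f).map(g)) == contents(ds.map(x => g(f(x))))

// flatMap associativity, using functions that return local collections
val h: Int => Seq[Int] = x => Seq(x, x + 10)
val k: Int => Seq[Int] = x => Seq(x * 2)
val associativityHolds =
  contents(ds.flatMap(h).flatMap(k)) == contents(ds.flatMap(x => h(x).flatMap(k)))
```

Equality here means "same elements", since a distributed Dataset has no guaranteed ordering unless one is imposed.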

### The MovieLens Dataset

The examples in this post use the small MovieLens dataset, which GroupLens [recommends for education and development](https://grouplens.org/datasets/movielens/){:target="_blank"}, chosen here for its simplicity.

```mermaid
erDiagram
    MOVIES ||--o{ RATINGS : "receives"
    MOVIES ||--o{ TAGS : "has"
    MOVIES ||--|| LINKS : "references"
    
    MOVIES {
        int movieId PK "Primary Key"
        string title "Movie title with year"
        string genres "Pipe-separated genres"
    }
    
    RATINGS {
        int userId FK "Foreign Key to User"
        int movieId FK "Foreign Key to Movie"
        float rating "Rating value (0.5-5.0)"
        long timestamp "Unix timestamp"
    }
    
    TAGS {
        int userId FK "Foreign Key to User"
        int movieId FK "Foreign Key to Movie"
        string tag "User-generated tag"
        long timestamp "Unix timestamp"
    }
    
    LINKS {
        int movieId PK, FK "Primary Key, Foreign Key to Movie"
        string imdbId "IMDB identifier"
        string tmdbId "TMDB identifier"
    }
```

#### **Entities and Attributes:**

1.  **MOVIES** (9,742 movies)
    -   `movieId` (Primary Key)
    -   `title` (includes release year)
    -   `genres` (pipe-separated list)
2.  **RATINGS** (100,836 ratings)
    -   `userId` (Foreign Key)
    -   `movieId` (Foreign Key)
    -   `rating` (0.5 to 5.0 stars)
    -   `timestamp` (Unix timestamp)
3.  **TAGS** (3,683 tags)
    -   `userId` (Foreign Key)
    -   `movieId` (Foreign Key)
    -   `tag` (user-generated metadata)
    -   `timestamp` (Unix timestamp)
4.  **LINKS** (9,742 links)
    -   `movieId` (Primary Key & Foreign Key)
    -   `imdbId` (IMDB identifier)
    -   `tmdbId` (The Movie Database identifier)

#### **Relationships:**

-   **MOVIES ↔ RATINGS**: One-to-Many (a movie can have multiple ratings)
-   **MOVIES ↔ TAGS**: One-to-Many (a movie can have multiple tags)
-   **MOVIES ↔ LINKS**: One-to-One (each movie has one set of external links)
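
One way to carry this model into Scala is a case class per entity. The sketch below is an assumption rather than code from the original notebook: the class names and field types follow the diagram (with `Double` and `Long` chosen for `rating` and `timestamp`), and the sample values are made up for illustration.

```scala
// Hypothetical case classes mirroring the ER diagram; names and types are assumptions
case class Rating(userId: Int, movieId: Int, rating: Double, timestamp: Long)
case class Tag(userId: Int, movieId: Int, tag: String, timestamp: Long)
case class Link(movieId: Int, imdbId: String, tmdbId: String)

// Illustrative values only, not rows taken from the CSV files
val sampleRatings = Seq(
  Rating(userId = 1, movieId = 1, rating = 4.0, timestamp = 1000000000L),
  Rating(userId = 2, movieId = 1, rating = 3.5, timestamp = 1000000500L),
  Rating(userId = 2, movieId = 3, rating = 5.0, timestamp = 1000000900L)
)

// One-to-many: several ratings can point at the same movie through movieId
val ratingsForMovie1 = sampleRatings.filter(_.movieId == 1)
```

The foreign keys are plain `Int` fields; nothing enforces referential integrity, so joins on `movieId` are where the relationships actually materialise.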



In [1]:
// Add Maven Central as a Coursier repository
interp.repositories() ++= Seq(
  coursierapi.MavenRepository.of("https://repo1.maven.org/maven2")
)

// Let the compiler use the Java classpath
interp.configureCompiler(c => {
  c.settings.usejavacp.value = true
})

// Import Spark SQL 3.3.1; adjust the version to match your environment
import $ivy.`org.apache.spark::spark-sql:3.3.1`


import $ivy.$

In [2]:
import org.apache.logging.log4j.{LogManager, Level}
import org.apache.logging.log4j.core.config.Configurator

// Set log levels BEFORE creating SparkSession
Configurator.setRootLevel(Level.WARN)
Configurator.setLevel("org.apache.spark", Level.WARN)
Configurator.setLevel("org.apache.spark.executor.Executor", Level.WARN)


import org.apache.logging.log4j.{LogManager, Level}
import org.apache.logging.log4j.core.config.Configurator

In [3]:

import org.apache.spark.sql._

val spark = {
  NotebookSparkSession.builder()
    .master("local[*]")
    .getOrCreate()
}

// Optional: reduce Spark's own log output further
// spark.sparkContext.setLogLevel("ERROR")

import spark.implicits._

import org.apache.spark.sql._
spark: SparkSession = org.apache.spark.sql.SparkSession@5b4f9fd2
import spark.implicits._

Let's define a case class for the movie records:

In [4]:
case class Movie(
  movieId: Int,
  title: String,
  genres: String
)

defined class Movie

Create a Dataset using the above case class:

In [5]:
// Read CSV and convert to Dataset
val moviesDS = spark.read
  .option("header", "true")
  .option("inferSchema", "true")
  .csv("ml-latest-small/movies.csv")
  .as[Movie]

// Preview the first two rows
moviesDS.show(2)


+-------+----------------+--------------------+
|movieId|           title|              genres|
+-------+----------------+--------------------+
|      1|Toy Story (1995)|Adventure|Animati...|
|      2|  Jumanji (1995)|Adventure|Childre...|
+-------+----------------+--------------------+
only showing top 2 rows



moviesDS: Dataset[Movie] = [movieId: int, title: string ... 1 more field]

**Key Points:**

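- Scala case classes are serializable by default, which Spark relies on when shipping objects between the driver and executors
- Every field must have a type Spark can encode (for example primitives, `String`, collections, or nested case classes)
- The `.as[T]` method converts a DataFrame to a Dataset, matching columns to case class fields by name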
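
As an alternative to `inferSchema`, the schema can be derived from the case class itself, which avoids an extra pass over the data and keeps the column types aligned with the fields. The sketch below is a variation on the notebook cell, not the author's code: `Film`, `session` and the in-memory CSV lines are assumptions so that it runs without `movies.csv`.

```scala
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}

// Hypothetical record type with the same fields as the Movie case class
case class Film(movieId: Int, title: String, genres: String)

// Assumption: a plain local session is available on the classpath
val session = SparkSession.builder().master("local[*]").appName("explicit-schema-sketch").getOrCreate()
import session.implicits._

// In-memory CSV lines stand in for movies.csv
val csvLines: Dataset[String] = Seq(
  "movieId,title,genres",
  "1,Toy Story (1995),Adventure|Animation|Children|Comedy|Fantasy",
  "2,Jumanji (1995),Adventure|Children|Fantasy"
).toDS()

// Derive the schema from the case class instead of asking Spark to infer it
val filmSchema = Encoders.product[Film].schema

val filmsDS: Dataset[Film] = session.read
  .option("header", "true")
  .schema(filmSchema)
  .csv(csvLines)
  .as[Film]
```

With a real file, the same reader chain should work by passing the path to `.csv(...)` in place of `csvLines`.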

### Understanding Encoders

Encoders are a critical component of the Dataset API. They provide:

1. <span>Efficient Serialisation</span>{:gtxt}: Convert JVM objects to Spark's internal Tungsten binary format
2. <span>Schema Generation</span>{:gtxt}: Automatically infer schema from case class structure
3. <span>Code Generation</span>{:gtxt}: Enable whole-stage code generation for better performance


In [6]:
import org.apache.spark.sql.Dataset
// for primitive types
val intDS : Dataset[Int] = Seq(1,2,3).toDS()

import org.apache.spark.sql.Dataset
intDS: Dataset[Int] = [value: int]

In [7]:
val tupleDS: Dataset[(String, Int)] = Seq(("a",1), ("b", 2)).toDS

tupleDS: Dataset[(String, Int)] = [_1: string, _2: int]

Using case classes:

In [8]:
case class Dog(name: String, age: Int)

val dogsDS: Dataset[Dog] = Seq(Dog("Liela",3), Dog("Tommy", 5)).toDS

defined class Dog
dogsDS: Dataset[Dog] = [name: string, age: int]

In [9]:
dogsDS.show()

+-----+---+
| name|age|
+-----+---+
|Liela|  3|
|Tommy|  5|
+-----+---+



## Dataset Transformations

### map Transformation

The `map` transformation applies a function to each element in the Dataset, producing a new Dataset with transformed elements. It's a **narrow transformation** (no shuffle required) and maintains a **one-to-one relationship** between input and output elements.

```scala
def map[U](func: T => U)(implicit encoder: Encoder[U]): Dataset[U]
```
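`func`: the function applied to each element. `encoder`: the implicit `Encoder[U]` for the result type, normally provided by `import spark.implicits._`.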
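
A small sketch of how the result type `U` drives the output schema through its encoder: mapping to a tuple yields positional column names, while mapping to a case class yields the field names. `Film`, `TitleAndGenres` and `session` are hypothetical names introduced for this illustration.

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical types for this sketch; Film stands in for the Movie case class
case class Film(movieId: Int, title: String, genres: String)
case class TitleAndGenres(title: String, genres: String)

// Assumption: a plain local session is available on the classpath
val session = SparkSession.builder().master("local[*]").appName("map-encoder-sketch").getOrCreate()
import session.implicits._

val filmsDS = Seq(
  Film(1, "Toy Story (1995)", "Adventure|Animation|Children|Comedy|Fantasy"),
  Film(3, "Grumpier Old Men (1995)", "Comedy|Romance")
).toDS()

// U = (String, String): the tuple encoder names the columns _1 and _2
val tupleDS = filmsDS.map(film => (film.title, film.genres))

// U = TitleAndGenres: the product encoder names the columns after the fields
val namedDS = filmsDS.map(film => TitleAndGenres(film.title, film.genres))
namedDS.printSchema()
```

Preferring a small case class over a tuple tends to make downstream column references easier to read.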

For example, to extract the movie title:


In [10]:
moviesDS.map(m => m.title).show(3, truncate=false)

+-----------------------+
|value                  |
+-----------------------+
|Toy Story (1995)       |
|Jumanji (1995)         |
|Grumpier Old Men (1995)|
+-----------------------+
only showing top 3 rows



In [11]:
def extractMovieInfoFun(movie: Movie): (String, String) = (movie.title, movie.genres)
moviesDS.map(extractMovieInfoFun)

defined function extractMovieInfoFun
res11_1: Dataset[(String, String)] = [_1: string, _2: string]

As shown above, you can pass a named method to `map`.

Alternatively, you can assign an anonymous function to a value and pass that:

In [12]:
val extractMovieInfoAnonymousFun: Movie => (String, String) = movie => (movie.title, movie.genres)
moviesDS.map(extractMovieInfoAnonymousFun)

extractMovieInfoAnonymousFun: Movie => (String, String) = ammonite.$sess.cmd12$Helper$$Lambda$6500/771828955@2bdc4462
res12_1: Dataset[(String, String)] = [_1: string, _2: string]

The same function can also be written inline in the `map` call:

In [13]:
moviesDS.map(movie => (movie.title, movie.genres))

res13: Dataset[(String, String)] = [_1: string, _2: string]

### flatMap Transformation

The `flatMap` transformation applies a function to each element and **flattens** the results. Each input element can produce **zero, one, or multiple output elements**. This is essential for transformations like tokenization, exploding nested structures, or filtering with expansion.

```scala
def flatMap[U](func: T => TraversableOnce[U])(implicit encoder: Encoder[U]): Dataset[U]
```

Translation: Given a function that turns each element of type `T` into a collection of elements of type `U`, flatten all of those collections into a single Dataset of type `U`.
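
The "zero, one, or multiple" behaviour can be seen on a tiny in-memory Dataset. This is a minimal sketch with made-up input lines; `session`, `lines` and `tokens` are names assumed for the illustration.

```scala
import org.apache.spark.sql.SparkSession

// Assumption: a plain local session is available on the classpath
val session = SparkSession.builder().master("local[*]").appName("flatmap-sketch").getOrCreate()
import session.implicits._

// Three inputs: one yields three tokens, one yields a single token, one yields none
val lines = Seq("Adventure|Children|Fantasy", "Drama", "").toDS()

// Split on a literal pipe and drop empty tokens, so the empty line contributes nothing
val tokens = lines.flatMap(line => line.split("\\|").filter(_.nonEmpty).toSeq)
```

Three input elements become four output elements here, which is the difference from `map`, where the counts always match.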

In [14]:
case class MovieGenres (id: Int, genres: String)
val genres = moviesDS.map { movie =>
    MovieGenres(movie.movieId, movie.genres)
}

defined class MovieGenres
genres: Dataset[MovieGenres] = [id: int, genres: string]

In [15]:
genres.show(3, truncate=false)

+---+-------------------------------------------+
|id |genres                                     |
+---+-------------------------------------------+
|1  |Adventure|Animation|Children|Comedy|Fantasy|
|2  |Adventure|Children|Fantasy                 |
|3  |Comedy|Romance                             |
+---+-------------------------------------------+
only showing top 3 rows



In [16]:
val genresDS = genres.flatMap(m => m.genres.split("\\|"))
genresDS.show()

+---------+
|    value|
+---------+
|Adventure|
|Animation|
| Children|
|   Comedy|
|  Fantasy|
|Adventure|
| Children|
|  Fantasy|
|   Comedy|
|  Romance|
|   Comedy|
|    Drama|
|  Romance|
|   Comedy|
|   Action|
|    Crime|
| Thriller|
|   Comedy|
|  Romance|
|Adventure|
+---------+
only showing top 20 rows



genresDS: Dataset[String] = [value: string]

> The `split()` method takes a *regex pattern, and `|` is a special character in regex meaning "OR"*{:rtxt}. An unescaped `split("|")` therefore matches the empty string and cuts the input between every character. *Instead, escape the pipe with `split("\\|")`*{:gtxt}.
{:.yellow}
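
The difference is easy to check in plain Scala, without Spark. The sketch below assumes Java 8 or later for the behaviour of the unescaped split, and also shows two alternatives to hand-written escaping: `Pattern.quote` and Scala's `Char` overload of `split`.

```scala
import java.util.regex.Pattern

val genreString = "Comedy|Romance"

// Unescaped: the regex "|" matches the empty string, so the input is cut between characters
val naive = genreString.split("|")

// Escaped pipe, as used in this post
val escaped = genreString.split("\\|")

// Alternatives that avoid hand-written escaping
val quoted = genreString.split(Pattern.quote("|"))
val byChar = genreString.split('|')

println(naive.mkString(","))
println(escaped.mkString(","))
```

The `Char` overload treats its argument literally, which makes it a convenient choice for single-character separators.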

#### Complex Example: Nested Structure Explosion

In [26]:
case class Sentence(id: Int, words: Seq[String], occurrences: Seq[Int])

object Sentence {
  // Build a Sentence from a Movie, treating each genre as a word
  def fromMovie(movie: Movie): Sentence = {
    val id = movie.movieId
    val text = movie.genres

    // Split the pipe-separated genres into words
    val words = text.split("\\|").toSeq

    // Count occurrences of each word
    val wordCounts = words.groupBy(identity).mapValues(_.size).toMap
    val occurrences = words.map(word => wordCounts(word))

    Sentence(id, words, occurrences)
  }
}

defined class Sentence
defined object Sentence

In [27]:
val sentenceDS = moviesDS.map(Sentence.fromMovie)

sentenceDS: Dataset[Sentence] = [id: int, words: array<string> ... 1 more field]

In [28]:
sentenceDS.show(3)

+---+--------------------+---------------+
| id|               words|    occurrences|
+---+--------------------+---------------+
|  1|[Adventure, Anima...|[1, 1, 1, 1, 1]|
|  2|[Adventure, Child...|      [1, 1, 1]|
|  3|   [Comedy, Romance]|         [1, 1]|
+---+--------------------+---------------+
only showing top 3 rows



In [32]:
sentenceDS.flatMap { sentence =>
    sentence.words.zip(sentence.occurrences).map { case (word, numOccurred) =>
        (sentence.id, word, numOccurred)
    }
}.show(5)

+---+---------+---+
| _1|       _2| _3|
+---+---------+---+
|  1|Adventure|  1|
|  1|Animation|  1|
|  1| Children|  1|
|  1|   Comedy|  1|
|  1|  Fantasy|  1|
+---+---------+---+
only showing top 5 rows
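
For comparison, the same kind of explosion can be expressed with the untyped `explode` column function, and the tuple columns can be given readable names with `toDF`. This is a sketch under assumptions: `GenreList` is a hypothetical, simplified stand-in for `Sentence` (without the occurrences), and `session` is a plain local session.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.explode

// Hypothetical, simplified stand-in for the Sentence case class
case class GenreList(id: Int, words: Seq[String])

// Assumption: a plain local session is available on the classpath
val session = SparkSession.builder().master("local[*]").appName("explode-sketch").getOrCreate()
import session.implicits._

val listsDS = Seq(
  GenreList(1, Seq("Adventure", "Animation")),
  GenreList(3, Seq("Comedy", "Romance"))
).toDS()

// Typed route: flatMap to tuples, then name the columns instead of keeping _1 and _2
val viaFlatMap = listsDS
  .flatMap(list => list.words.map(word => (list.id, word)))
  .toDF("id", "word")

// Untyped route: explode the array column into one row per element
val viaExplode = listsDS.select($"id", explode($"words").as("word"))
```

The `explode` version stays within column expressions that Catalyst can inspect, while the `flatMap` version keeps compile-time checking of the fields it touches.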




{:gtxt: .message color="green"}
{:ytxt: .message color="yellow"}
{:rtxt: .message color="red"}

In [1]:
scala.util.Properties.versionString


res1: String = "version 2.12.20"

In [16]:
spark.stop()