# Spark Basics

#### Count lines and words in a text file

In [1]:
# Prerequisites
from pyspark.sql import SparkSession

In [2]:
# Get SparkSession
spark = SparkSession.builder.master("local").getOrCreate()
print("Spark Version: ", spark.version)

Spark Version:  3.5.0


In [3]:
print("Type of Spark object: ", type(spark))

Type of Spark object:  <class 'pyspark.sql.session.SparkSession'>


In [4]:
# Read text file and show beginning
df_txt = spark.read.text("data/lorem_ipsum.txt")
df_txt.show(10, truncate=False)

+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [5]:
# Number of lines
print("Number of lines: ", df_txt.count())

Number of lines:  39


In [6]:
# Number of words "Lorem"
filtered = df_txt.filter(df_txt.value.contains("lorem"))
print("Number of times word 'lorem' occurs: ", filtered.count())

Number of times word 'lorem' occurs:  8


#### Read Text data into Partitions

In [7]:
# Read data into partitions
df_txt_2 = spark.read.text("data/lorem_ipsum.txt").repartition(4)
print("Numer of partitions: ", df_txt_2.rdd.getNumPartitions())
df_txt_2.show(n=10, truncate=False)

Numer of partitions:  4
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|value                                                                                                                                                                                                                                                                                                                        

### Partition table of integers

In [8]:
# place 1000 intergers into 8 partititions
df_ints = spark.range(0, 10000, 1, 8)
print("Numer of partitions: ", df_ints.rdd.getNumPartitions())
df_ints.show()

Numer of partitions:  8
+---+
| id|
+---+
|  0|
|  1|
|  2|
|  3|
|  4|
|  5|
|  6|
|  7|
|  8|
|  9|
| 10|
| 11|
| 12|
| 13|
| 14|
| 15|
| 16|
| 17|
| 18|
| 19|
+---+
only showing top 20 rows



#### collect() Action
In Apache Spark, the collect() function is used to retrieve all elements of a distributed dataset (RDD, DataFrame, or Dataset) to the driver program as a list or array. This is often used when you need to perform actions on the entire dataset locally.

In [9]:
# Create a DataFrame
data = [("Alice", 25), ("Bob", 30), ("Charlie", 35)]
columns = ["Name", "Age"]
df_names = spark.createDataFrame(data, columns)

df_names.show()

+-------+---+
|   Name|Age|
+-------+---+
|  Alice| 25|
|    Bob| 30|
|Charlie| 35|
+-------+---+



In [11]:
# Use collect to bring all rows to the driver
collected_data = df_names.collect()

print("Type of collected_data: ", type(collected_data))
print("Type of collected_data[0]: ", type(collected_data[0]))
# Print the collected rows
for row in collected_data:
    print(row)

Type of collected_data:  <class 'list'>
Type of collected_data[0]:  <class 'pyspark.sql.types.Row'>
Row(Name='Alice', Age=25)
Row(Name='Bob', Age=30)
Row(Name='Charlie', Age=35)


### take() Action
The take() function in Apache Spark is an action used to retrieve a specified number of elements from a distributed dataset (RDD, DataFrame, or Dataset).

In [12]:
# Create an RDD
rdd = spark.sparkContext.parallelize([10, 20, 30, 40, 50, 60, 70, 80, 90, 100])

# Use take to fetch the first 3 elements
result = rdd.take(3)

print("First 3 Elements:", result)
print("Type of result: ", type(result))

First 3 Elements: [10, 20, 30]
Type of result:  <class 'list'>


#### Create a DataFrame Manually from Rows

In [13]:
from pyspark.sql import Row

rows = [Row("John Smith", "CA"), Row("Jane Doe", "WA")]
df_persons = spark.createDataFrame(rows, ["Name", "State"])

df_persons.show()

+----------+-----+
|      Name|State|
+----------+-----+
|John Smith|   CA|
|  Jane Doe|   WA|
+----------+-----+

