## Spark Schema Evolution using the Glue Catalog:

There are 4 cases:

1. Columns are added in the middle - which is an incompatible change.
2. Columns are dropped i.e. missing - which is a compatible change as nulls are expected for newer records.
3. Column types are changed - which is an incompatible change. Automatically casting the incoming column to the table's existing type can be tried as a resolution here.
4. Columns are added at the end - which is a compatible change.


In [1]:
import boto3

# Glue client used by update_glue_schema below; the region is fixed to us-west-2.
glue = boto3.client('glue',region_name='us-west-2')
# AWS account id, passed as CatalogId on the Glue get_table / update_table calls.
account='123456789012'

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
6,application_1598906256794_0007,pyspark,idle,Link,Link,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [2]:
from pyspark.sql.functions import when, lit, col

# Spark DataType class name -> Hive/Glue column type string.
# Keys are Spark type class names (e.g. "LongType"); values are the type
# names stored in the Glue table definition.
typeMapping={
    "LongType":"bigint",
    "StringType":"string",
    # Hive's canonical name for a 32-bit integer is "int".
    "IntegerType":"int",
    "DecimalType":"decimal",
    "BooleanType":"boolean",
    # A Spark float is a 4-byte floating point value, not a decimal.
    "FloatType":"float",
    "BinaryType":"binary",
    "DoubleType":"double",
    "ShortType":"smallint",
    "ByteType":"tinyint",
    "DateType":"date",
    "TimestampType":"timestamp"
}

def evolveSchema(df, table,strict=False,forcecast=True):
    """Align an incoming DataFrame with the schema of an existing table.

    Args:
        df: Incoming Spark DataFrame that is about to be appended to ``table``.
        table: Table name usable in ``SELECT * FROM <table>``
            (for example ``"schema_test.s_schema_test"``).
        strict: When True, any schema difference is reported and rejected.
        forcecast: When True (and ``strict`` is False), every column the table
            already has is cast to the table's data type.

    Returns:
        A ``(ok, dataframe)`` tuple:

        * schemas identical: ``(True, df)``;
        * ``strict`` and schemas differ: ``(False, df)``;
        * otherwise: ``(True, evolved)`` where ``evolved`` holds the table's
          columns first, in table order, followed by the columns that exist
          only in ``df``, in their incoming order.
    """
    # LIMIT 0 is used so that only the schema is of interest, not the rows.
    table_schema = spark.sql("SELECT * FROM "+table+" LIMIT 0").schema

    if df.schema == table_schema:
        # Hand the input back (not None) so callers can always use result[1].
        return (True, df)

    if strict:
        print ("Strict Schema Validation Failed : Incoming Schema is not compatible with existing Table.")
        # Report both directions: choosing one side by field count can
        # print the empty side and hide the actual difference.
        print ("Original Data Diff : "+str(set(table_schema)-set(df.schema)))
        print ("Incoming Data Diff : "+str(set(df.schema)-set(table_schema)))
        return (False,df)

    table_names = set(f.name for f in table_schema.fields)
    incoming_names = set(f.name for f in df.schema.fields)

    # Keep the incoming order of brand-new columns so the result is deterministic.
    new_cols = [col(f.name) for f in df.schema.fields if f.name not in table_names]

    # Columns the table has but the incoming data lacks become typed nulls.
    # An untyped lit(None) would be NullType, which Parquet cannot store.
    for field in table_schema.fields:
        if field.name not in incoming_names:
            df = df.withColumn(field.name, lit(None).cast(field.dataType))

    if forcecast:
        # Cast to the table's full DataType (keeps decimal precision/scale
        # and works for complex types, unlike a cast built from typeName()).
        existing_cols = [col(f.name).cast(f.dataType).alias(f.name) for f in table_schema.fields]
    else:
        # Only re-arrange the columns into the table's order.
        existing_cols = [col(f.name) for f in table_schema.fields]

    return (True, df.select(existing_cols+new_cols))

def update_glue_schema(dataframe, database, table):
    """Replace the Glue table's column list with the DataFrame's columns.

    Fetches the current table definition, keeps the attributes that
    ``update_table`` accepts in ``TableInput``, swaps
    ``StorageDescriptor['Columns']`` for the DataFrame's schema and writes
    the definition back.

    Args:
        dataframe: Spark DataFrame whose schema becomes the table's columns.
        database: Glue database name.
        table: Glue table name.

    Returns:
        The response of the Glue ``update_table`` call.
    """
    response_get_table=glue.get_table(
       CatalogId=account,
       DatabaseName=database,
       Name=table
    )
    print (response_get_table)

    current=response_get_table['Table']

    # get_table returns read-only attributes (DatabaseName, CreateTime,
    # UpdateTime, CreatedBy, ...) that TableInput rejects. Copying only the
    # accepted keys avoids a KeyError when one of them is absent and also
    # drops any other read-only attribute present in the response.
    table_input_keys=(
        'Name','Description','Owner','LastAccessTime','LastAnalyzedTime',
        'Retention','StorageDescriptor','PartitionKeys','ViewOriginalText',
        'ViewExpandedText','TableType','Parameters','TargetTable'
    )
    tableInput={key:current[key] for key in table_input_keys if key in current}

    # simpleString() yields the Hive-style type name (bigint, string,
    # decimal(10,2), array<...>), so types outside a fixed lookup table
    # no longer raise KeyError.
    columns=[{'Name':k.name,'Type':k.dataType.simpleString()} for k in dataframe.schema.fields]
    storage=dict(current['StorageDescriptor'])
    storage['Columns']=columns
    tableInput['StorageDescriptor']=storage

    response_update_table = glue.update_table(
       CatalogId=account,
       DatabaseName=database,
       TableInput=tableInput)

    return response_update_table

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [3]:
# Preview up to 10 rows of the table as it currently stands.
spark.sql("SELECT * FROM schema_test.s_schema_test LIMIT 10").show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---+-------------------+---+---+
| id|      modified_time| sk|txt|
+---+-------------------+---+---+
|  0|2020-08-31 22:27:03|  0|  A|
|  1|2020-08-31 22:27:03|  1|  B|
|  2|2020-08-31 22:27:03|  2|  C|
|  3|2020-08-31 22:27:03|  3|  D|
|  4|2020-08-31 22:27:03|  4|  E|
|  5|2020-08-31 22:27:03|  5|  F|
|  6|2020-08-31 22:27:03|  6|  G|
|  7|2020-08-31 22:27:03|  7|  H|
|  8|2020-08-31 22:27:03|  8|  I|
|  9|2020-08-31 22:27:03|  9|  J|
+---+-------------------+---+---+

### Case 1 : Columns are added in the middle

In [6]:
from datetime import datetime

## Generates Data
def get_json_data(start, count, increment=0):
    """Return ``count`` sample records with ids ``start`` .. ``start+count-1``.

    Each record carries ``id``, ``id1``, ``sk`` (id plus ``increment``),
    ``txt`` (a letter A-Z cycling with the id) and ``modified_time``, a
    single timestamp string shared by all records of the call.
    """
    stamp = str(datetime.today().replace(microsecond=0))
    records = []
    for idx in range(start, start + count):
        records.append({
            "id": idx,
            "id1": idx,
            "sk": idx + increment,
            "txt": chr(65 + (idx % 26)),
            "modified_time": stamp,
        })
    return records

# Creates the Dataframe
def create_json_df(spark, data):
    """Distribute ``data`` over 2 partitions and read it back as a JSON DataFrame."""
    context = spark.sparkContext
    records_rdd = context.parallelize(data, 2)
    return spark.read.json(records_rdd)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [7]:
# Case 1 input: 4000 generated records, including the extra id1 field.
df1 = create_json_df(spark, get_json_data(0, 4000))
print(df1.count())
# Show the schema of the incoming data and a small sample of it.
df1.printSchema()
df1.show(3)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

4000
root
 |-- id: long (nullable = true)
 |-- id1: long (nullable = true)
 |-- modified_time: string (nullable = true)
 |-- sk: long (nullable = true)
 |-- txt: string (nullable = true)

+---+---+-------------------+---+---+
| id|id1|      modified_time| sk|txt|
+---+---+-------------------+---+---+
|  0|  0|2020-08-31 22:30:52|  0|  A|
|  1|  1|2020-08-31 22:30:52|  1|  B|
|  2|  2|2020-08-31 22:30:52|  2|  C|
+---+---+-------------------+---+---+
only showing top 3 rows

In [8]:
# strict=False: reconcile df1 with the table's schema instead of rejecting it.
# evolveSchema returns a tuple; index 1 holds the evolved DataFrame when the schemas differ.
evolved_df=evolveSchema(df1,"schema_test.s_schema_test",False)
evolved_df[1].printSchema()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- id: long (nullable = true)
 |-- modified_time: string (nullable = true)
 |-- sk: long (nullable = true)
 |-- txt: string (nullable = true)
 |-- id1: long (nullable = true)

In [9]:
# Append the evolved rows as Parquet under the S3 prefix below.
evolved_df[1].write.mode("APPEND").parquet("s3://s3bucket/parquet/schema_test/")

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [10]:
# Push the evolved DataFrame's columns (now including id1) to the Glue table definition.
update_glue_schema(evolved_df[1],'schema_test','s_schema_test')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

{'Table': {'Name': 's_schema_test', 'DatabaseName': 'schema_test', 'Owner': 'owner', 'CreateTime': datetime.datetime(2020, 8, 31, 22, 28, 33, tzinfo=tzlocal()), 'UpdateTime': datetime.datetime(2020, 8, 31, 22, 28, 33, tzinfo=tzlocal()), 'LastAccessTime': datetime.datetime(2020, 8, 31, 22, 28, 33, tzinfo=tzlocal()), 'Retention': 0, 'StorageDescriptor': {'Columns': [{'Name': 'id', 'Type': 'bigint'}, {'Name': 'modified_time', 'Type': 'string'}, {'Name': 'sk', 'Type': 'bigint'}, {'Name': 'txt', 'Type': 'string'}], 'Location': 's3://s3bucket/parquet/schema_test/', 'InputFormat': 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat', 'OutputFormat': 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat', 'Compressed': False, 'NumberOfBuckets': -1, 'SerdeInfo': {'SerializationLibrary': 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe', 'Parameters': {'serialization.format': '1'}}, 'BucketColumns': [], 'SortColumns': [], 'Parameters': {'CrawlerSchemaDeserializer

In [11]:
# Re-read the table after the append and the catalog update.
spark.sql("SELECT * FROM schema_test.s_schema_test LIMIT 10").show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---+-------------------+---+---+---+
| id|      modified_time| sk|txt|id1|
+---+-------------------+---+---+---+
|  0|2020-08-31 22:30:52|  0|  A|  0|
|  1|2020-08-31 22:30:52|  1|  B|  1|
|  2|2020-08-31 22:30:52|  2|  C|  2|
|  3|2020-08-31 22:30:52|  3|  D|  3|
|  4|2020-08-31 22:30:52|  4|  E|  4|
|  5|2020-08-31 22:30:52|  5|  F|  5|
|  6|2020-08-31 22:30:52|  6|  G|  6|
|  7|2020-08-31 22:30:52|  7|  H|  7|
|  8|2020-08-31 22:30:52|  8|  I|  8|
|  9|2020-08-31 22:30:52|  9|  J|  9|
+---+-------------------+---+---+---+

In [12]:
# List rows where the id1 column is null.
spark.sql("SELECT * FROM schema_test.s_schema_test WHERE id1 is null LIMIT 10").show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---+-------------------+---+---+----+
| id|      modified_time| sk|txt| id1|
+---+-------------------+---+---+----+
|  0|2020-08-31 22:27:03|  0|  A|null|
|  1|2020-08-31 22:27:03|  1|  B|null|
|  2|2020-08-31 22:27:03|  2|  C|null|
|  3|2020-08-31 22:27:03|  3|  D|null|
|  4|2020-08-31 22:27:03|  4|  E|null|
|  5|2020-08-31 22:27:03|  5|  F|null|
|  6|2020-08-31 22:27:03|  6|  G|null|
|  7|2020-08-31 22:27:03|  7|  H|null|
|  8|2020-08-31 22:27:03|  8|  I|null|
|  9|2020-08-31 22:27:03|  9|  J|null|
+---+-------------------+---+---+----+

### Case 2 : Columns are dropped i.e. missing.

Note: The dropped column (`sk`) is filled with nulls, and `id1` is moved to the end to match the table's column order, making this a compatible change.

In [13]:
## Generates Data
def get_json_data(start, count, increment=0):
    """Return ``count`` sample records without the ``sk`` field.

    Each record carries ``id``, ``id1``, ``txt`` (a letter A-Z cycling with
    the id) and ``modified_time``, a single timestamp string shared by all
    records of the call. ``increment`` is accepted but not used.
    """
    stamp = str(datetime.today().replace(microsecond=0))
    records = []
    for idx in range(start, start + count):
        records.append({
            "id": idx,
            "id1": idx,
            "txt": chr(65 + (idx % 26)),
            "modified_time": stamp,
        })
    return records

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [14]:
# Case 2 input: 4000 generated records that no longer carry the sk field.
df2 = create_json_df(spark, get_json_data(0, 4000))
print(df2.count())
# Show the schema of the incoming data and a small sample of it.
df2.printSchema()
df2.show(3)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

4000
root
 |-- id: long (nullable = true)
 |-- id1: long (nullable = true)
 |-- modified_time: string (nullable = true)
 |-- txt: string (nullable = true)

+---+---+-------------------+---+
| id|id1|      modified_time|txt|
+---+---+-------------------+---+
|  0|  0|2020-08-31 22:32:26|  A|
|  1|  1|2020-08-31 22:32:26|  B|
|  2|  2|2020-08-31 22:32:26|  C|
+---+---+-------------------+---+
only showing top 3 rows

In [27]:
# strict=False: reconcile df2 with the table; the absent sk column is added back as nulls.
# evolveSchema returns a tuple; index 1 holds the evolved DataFrame when the schemas differ.
evolved_df=evolveSchema(df2,"schema_test.s_schema_test",False)
evolved_df[1].printSchema()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

['sk']
root
 |-- id: long (nullable = true)
 |-- modified_time: string (nullable = true)
 |-- sk: long (nullable = true)
 |-- txt: string (nullable = true)
 |-- id1: long (nullable = true)

In [35]:
# Display the first rows of the evolved DataFrame to check the null-filled sk column.
evolved_df[1].show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---+-------------------+----+---+---+
| id|      modified_time|  sk|txt|id1|
+---+-------------------+----+---+---+
|  0|2020-08-31 22:32:26|null|  A|  0|
|  1|2020-08-31 22:32:26|null|  B|  1|
|  2|2020-08-31 22:32:26|null|  C|  2|
|  3|2020-08-31 22:32:26|null|  D|  3|
|  4|2020-08-31 22:32:26|null|  E|  4|
|  5|2020-08-31 22:32:26|null|  F|  5|
|  6|2020-08-31 22:32:26|null|  G|  6|
|  7|2020-08-31 22:32:26|null|  H|  7|
|  8|2020-08-31 22:32:26|null|  I|  8|
|  9|2020-08-31 22:32:26|null|  J|  9|
| 10|2020-08-31 22:32:26|null|  K| 10|
| 11|2020-08-31 22:32:26|null|  L| 11|
| 12|2020-08-31 22:32:26|null|  M| 12|
| 13|2020-08-31 22:32:26|null|  N| 13|
| 14|2020-08-31 22:32:26|null|  O| 14|
| 15|2020-08-31 22:32:26|null|  P| 15|
| 16|2020-08-31 22:32:26|null|  Q| 16|
| 17|2020-08-31 22:32:26|null|  R| 17|
| 18|2020-08-31 22:32:26|null|  S| 18|
| 19|2020-08-31 22:32:26|null|  T| 19|
+---+-------------------+----+---+---+
only showing top 20 rows

In [36]:
# Append the evolved rows as Parquet under the S3 prefix below.
evolved_df[1].write.mode("APPEND").parquet("s3://s3bucket/parquet/schema_test/")

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [3]:
# List rows where the sk column is null.
spark.sql("SELECT * FROM schema_test.s_schema_test WHERE sk is null LIMIT 10").show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---+-------------------+----+---+---+
| id|      modified_time|  sk|txt|id1|
+---+-------------------+----+---+---+
|  0|2020-08-31 22:32:26|null|  A|  0|
|  1|2020-08-31 22:32:26|null|  B|  1|
|  2|2020-08-31 22:32:26|null|  C|  2|
|  3|2020-08-31 22:32:26|null|  D|  3|
|  4|2020-08-31 22:32:26|null|  E|  4|
|  5|2020-08-31 22:32:26|null|  F|  5|
|  6|2020-08-31 22:32:26|null|  G|  6|
|  7|2020-08-31 22:32:26|null|  H|  7|
|  8|2020-08-31 22:32:26|null|  I|  8|
|  9|2020-08-31 22:32:26|null|  J|  9|
+---+-------------------+----+---+---+

In [4]:
# List rows where the sk column is populated.
spark.sql("SELECT * FROM schema_test.s_schema_test WHERE sk is not null LIMIT 10").show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---+-------------------+---+---+---+
| id|      modified_time| sk|txt|id1|
+---+-------------------+---+---+---+
|  0|2020-08-31 22:30:52|  0|  A|  0|
|  1|2020-08-31 22:30:52|  1|  B|  1|
|  2|2020-08-31 22:30:52|  2|  C|  2|
|  3|2020-08-31 22:30:52|  3|  D|  3|
|  4|2020-08-31 22:30:52|  4|  E|  4|
|  5|2020-08-31 22:30:52|  5|  F|  5|
|  6|2020-08-31 22:30:52|  6|  G|  6|
|  7|2020-08-31 22:30:52|  7|  H|  7|
|  8|2020-08-31 22:30:52|  8|  I|  8|
|  9|2020-08-31 22:30:52|  9|  J|  9|
+---+-------------------+---+---+---+

### Case 3: Column types are changed.

In [5]:
from datetime import datetime

## Generates Data
def get_json_data(start, count, increment=0):
    """Return ``count`` sample records whose ``id1`` is a string.

    Each record carries ``id``, ``id1`` (the id rendered with ``str``),
    ``sk`` (id plus ``increment``), ``txt`` (a letter A-Z cycling with the
    id) and ``modified_time``, a single timestamp string shared by all
    records of the call.
    """
    stamp = str(datetime.today().replace(microsecond=0))
    records = []
    for idx in range(start, start + count):
        records.append({
            "id": idx,
            "id1": str(idx),
            "sk": idx + increment,
            "txt": chr(65 + (idx % 26)),
            "modified_time": stamp,
        })
    return records

# Creates the Dataframe
def create_json_df(spark, data):
    """Distribute ``data`` over 2 partitions and read it back as a JSON DataFrame."""
    context = spark.sparkContext
    records_rdd = context.parallelize(data, 2)
    return spark.read.json(records_rdd)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [6]:
# Case 3 input: 4000 generated records where id1 arrives as a string.
df3 = create_json_df(spark, get_json_data(0, 4000))
print(df3.count())
# Show the schema of the incoming data and a small sample of it.
df3.printSchema()
df3.show(3)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

4000
root
 |-- id: long (nullable = true)
 |-- id1: string (nullable = true)
 |-- modified_time: string (nullable = true)
 |-- sk: long (nullable = true)
 |-- txt: string (nullable = true)

+---+---+-------------------+---+---+
| id|id1|      modified_time| sk|txt|
+---+---+-------------------+---+---+
|  0|  0|2020-08-31 22:56:52|  0|  A|
|  1|  1|2020-08-31 22:56:52|  1|  B|
|  2|  2|2020-08-31 22:56:52|  2|  C|
+---+---+-------------------+---+---+
only showing top 3 rows

In [9]:
# strict=False, forcecast=True: cast the columns the table already has to the table's types.
# evolveSchema returns a tuple; index 1 holds the evolved DataFrame when the schemas differ.
evolved_df=evolveSchema(df3,"schema_test.s_schema_test",False, True)
evolved_df[1].printSchema()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- id: long (nullable = true)
 |-- modified_time: string (nullable = true)
 |-- sk: long (nullable = true)
 |-- txt: string (nullable = true)
 |-- id1: long (nullable = true)

In [10]:
# Append the evolved rows as Parquet under the S3 prefix below.
evolved_df[1].write.mode("APPEND").parquet("s3://s3bucket/parquet/schema_test/")

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

### Case 4: Columns are added at the end.

In [11]:
from datetime import datetime

## Generates Data
def get_json_data(start, count, increment=0):
    """Return ``count`` sample records with an additional ``txt1`` field.

    Each record carries ``id``, ``id1``, ``sk`` (id plus ``increment``),
    ``txt`` and ``txt1`` (both a letter A-Z cycling with the id) and
    ``modified_time``, a single timestamp string shared by all records of
    the call.
    """
    stamp = str(datetime.today().replace(microsecond=0))
    records = []
    for idx in range(start, start + count):
        letter = chr(65 + (idx % 26))
        records.append({
            "id": idx,
            "id1": idx,
            "sk": idx + increment,
            "txt": letter,
            "modified_time": stamp,
            "txt1": letter,
        })
    return records

# Creates the Dataframe
def create_json_df(spark, data):
    """Distribute ``data`` over 2 partitions and read it back as a JSON DataFrame."""
    context = spark.sparkContext
    records_rdd = context.parallelize(data, 2)
    return spark.read.json(records_rdd)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [12]:
# Case 4 input: 4000 generated records carrying the additional txt1 field.
df4 = create_json_df(spark, get_json_data(0, 4000))
print(df4.count())
# Show the schema of the incoming data and a small sample of it.
df4.printSchema()
df4.show(3)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

4000
root
 |-- id: long (nullable = true)
 |-- id1: long (nullable = true)
 |-- modified_time: string (nullable = true)
 |-- sk: long (nullable = true)
 |-- txt: string (nullable = true)
 |-- txt1: string (nullable = true)

+---+---+-------------------+---+---+----+
| id|id1|      modified_time| sk|txt|txt1|
+---+---+-------------------+---+---+----+
|  0|  0|2020-08-31 23:12:46|  0|  A|   A|
|  1|  1|2020-08-31 23:12:46|  1|  B|   B|
|  2|  2|2020-08-31 23:12:46|  2|  C|   C|
+---+---+-------------------+---+---+----+
only showing top 3 rows

In [13]:
# strict=False, forcecast=True: keep the table's columns (cast to its types) and append txt1.
# evolveSchema returns a tuple; index 1 holds the evolved DataFrame when the schemas differ.
evolved_df=evolveSchema(df4,"schema_test.s_schema_test",False, True)
evolved_df[1].printSchema()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

root
 |-- id: long (nullable = true)
 |-- modified_time: string (nullable = true)
 |-- sk: long (nullable = true)
 |-- txt: string (nullable = true)
 |-- id1: long (nullable = true)
 |-- txt1: string (nullable = true)

In [14]:
# Append the evolved rows as Parquet under the S3 prefix below.
evolved_df[1].write.mode("APPEND").parquet("s3://s3bucket/parquet/schema_test/")

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [15]:
# Push the evolved DataFrame's columns (now including txt1) to the Glue table definition.
update_glue_schema(evolved_df[1],'schema_test','s_schema_test')

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

{'Table': {'Name': 's_schema_test', 'DatabaseName': 'schema_test', 'Owner': 'owner', 'CreateTime': datetime.datetime(2020, 8, 31, 22, 28, 33, tzinfo=tzlocal()), 'UpdateTime': datetime.datetime(2020, 8, 31, 22, 31, 12, tzinfo=tzlocal()), 'LastAccessTime': datetime.datetime(2020, 8, 31, 22, 28, 33, tzinfo=tzlocal()), 'Retention': 0, 'StorageDescriptor': {'Columns': [{'Name': 'id', 'Type': 'bigint'}, {'Name': 'modified_time', 'Type': 'string'}, {'Name': 'sk', 'Type': 'bigint'}, {'Name': 'txt', 'Type': 'string'}, {'Name': 'id1', 'Type': 'bigint'}], 'Location': 's3://s3bucket/parquet/schema_test/', 'InputFormat': 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat', 'OutputFormat': 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat', 'Compressed': False, 'NumberOfBuckets': -1, 'SerdeInfo': {'SerializationLibrary': 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe', 'Parameters': {'serialization.format': '1'}}, 'BucketColumns': [], 'SortColumns': [], 'Param