In [None]:
import warnings
from math import ceil

from pyspark.sql.functions import explode, col, lit, xxhash64

# Import glow.py and register Glow with the active Spark session.
# Glow is the package that provides the "vcf" reader used by spark.read.format("vcf") below.
# NOTE(review): in Glow >= 1.0, register() returns a session with Glow registered and the
# Glow docs use `spark = glow.register(spark)`; here the return value is discarded -
# confirm which Glow version is installed on the cluster.
import glow
glow.register(spark)

In [None]:
# Provide your storage account, container and SAS token for the .parquet output.
# NOTE: do not commit a real SAS token to the notebook; prefer a secret scope, e.g.
#   outputSAS = dbutils.secrets.get(scope="<scope>", key="<key>")
outputStorageAccount = ""  # storage account name, e.g. "mystorageaccount"
outputContainer = ""       # container name inside that account
outputSAS = ""             # SAS token for the output container
outputDir = ""             # optional path prefix inside the container: "" or starting with "/" (e.g. "/vcf")

# Warn early instead of failing later with an unreadable wasbs:// URL error.
missingOutputSettings = [
  name
  for name, value in (
    ("outputStorageAccount", outputStorageAccount),
    ("outputContainer", outputContainer),
    ("outputSAS", outputSAS),
  )
  if not value
]
if missingOutputSettings:
  warnings.warn("Output settings not provided: " + ", ".join(missingOutputSettings))
# outputDir is appended directly after "<account>.blob.core.windows.net" when building sink paths.
if outputDir and not outputDir.startswith("/"):
  warnings.warn("outputDir should be empty or start with '/', got: " + repr(outputDir))

In [None]:
# Register SAS tokens with the Spark session so wasbs:// paths can be accessed.
# Each token is stored under the key "fs.azure.sas.<container>.<account>.blob.core.windows.net".

# Source: container "dataset" of the public 1000 Genomes account (public SAS) - leave as is.
spark.conf.set(
  key="fs.azure.sas.dataset.dataset1000genomes.blob.core.windows.net",
  value="sv=2019-10-10&si=prod&sr=c&sig=9nzcxaQn0NprMPlSh4RhFQHcXedLQIcFgbERiooHEqM%3D",
)

# Sink: your own container that receives the .parquet files (settings from the cell above).
spark.conf.set(
  key="fs.azure.sas." + outputContainer + "." + outputStorageAccount + ".blob.core.windows.net",
  value=outputSAS,
)


In [None]:
# Shuffles such as dropDuplicates() produce spark.sql.shuffle.partitions partitions
# (Spark default: 200). Raise it for the large VCF inputs processed here.
# partitionMax is also used later as the upper bound for coalesce().
# Uses spark.conf.set instead of the deprecated sqlContext.setConf.
partitionMax = 1500
spark.conf.set("spark.sql.shuffle.partitions", str(partitionMax))

In [None]:
# Flatten struct columns
def flattenStructFields(df):
  """Promote the fields of every top-level struct column to top-level columns.

  Non-struct columns are kept by name, in their original order, and come first.
  Each struct column ``s`` with field ``f`` is then replaced by a column named
  ``s_f``. Only one level of nesting is expanded.

  Args:
    df: Spark DataFrame.

  Returns:
    A new DataFrame with the struct columns expanded.
  """
  plainNames = []
  structNames = []
  for name, typeString in df.dtypes:
    if typeString.startswith('struct'):
      structNames.append(name)
    else:
      plainNames.append(name)

  expanded = []
  for structName in structNames:
    for fieldName in df.select(structName + '.*').columns:
      expanded.append(col(structName + '.' + fieldName).alias(structName + '_' + fieldName))

  return df.select(plainNames + expanded)

# Add empty columns to match schema
def completeSchema(df, diffSet):
  """Return ``df`` with an all-null column added for each field in ``diffSet``.

  Args:
    df: Spark DataFrame to extend.
    diffSet: iterable of StructField-like objects (with ``.name`` and
      ``.dataType``) describing columns ``df`` should have.

  Returns:
    A new DataFrame where each added column is NULL cast to the field's type.
    Fields are added in name order, so the column order does not depend on set
    iteration order, which can vary between Python processes. Fields whose name
    already exists in ``df`` are skipped, so existing data is never replaced
    with nulls.
  """
  full_df = df
  present = set(df.columns)
  for field in sorted(diffSet, key=lambda f: f.name):
    if field.name in present:
      # withColumn on an existing name would overwrite its data with NULL.
      continue
    full_df = full_df.withColumn(field.name, lit(None).cast(field.dataType.simpleString()))
    present.add(field.name)
  return full_df

# Transform dataframe with original vcf schema
def transformVcf(df, toFlatten, toHash, fullSchemaFields):
  """Deduplicate a VCF DataFrame, optionally hash-identify and flatten it.

  Args:
    df: DataFrame read with the ``vcf`` format; expected to contain a
      ``genotypes`` array column.
    toFlatten: if True, explode ``genotypes`` into one row per element, promote
      struct fields to top-level columns (see flattenStructFields), and add
      null columns for every field of ``fullSchemaFields`` whose name is missing.
    toHash: if True, add a ``hashId`` column: xxhash64 over all columns except
      ``genotypes``, computed before flattening.
    fullSchemaFields: set of StructFields of the flattened full dataset; only
      used when ``toFlatten`` is True.

  Returns:
    The transformed DataFrame.
  """
  # Drop exact duplicate rows
  dataDedup = df.dropDuplicates()

  # Add hashId column to identify variants. The hash inputs come from this
  # DataFrame's own columns, in schema order: a set's iteration order for
  # strings can change between Python processes (hash randomization), which
  # would change hashId for identical rows across runs.
  if toHash:
    hashCols = [c for c in dataDedup.columns if c != 'genotypes']
    dataHashed = dataDedup.withColumn('hashId', xxhash64(*hashCols))
  else:
    dataHashed = dataDedup

  if not toFlatten:
    return dataHashed

  # Explode on genotypes and create a separate column for each genotypes field
  dataExploded = dataHashed.withColumn('genotypes', explode('genotypes'))
  dataExplodedFlatten = flattenStructFields(dataExploded)

  # Add null columns for fields of the full schema this contig lacks. Compare by
  # name: comparing whole StructFields would also select columns that exist but
  # differ only in type/nullability, and those would be overwritten with nulls.
  presentNames = set(dataExplodedFlatten.columns)
  diffSet = {f for f in fullSchemaFields if f.name not in presentNames}
  return completeSchema(dataExplodedFlatten, diffSet)

In [None]:
# Widgets: whether to flatten the output ("flatten") and which contigs to process
# ("contigsToProcess": 1-22, X, Y, MT, or All).
flatOptions = [False, True]
dbutils.widgets.dropdown("flatten", "False", list(map(str, flatOptions)))

contigLiterals = ['X', 'Y', 'MT', 'All']
contigOptions = [str(n) for n in range(1, 23)] + contigLiterals
dbutils.widgets.multiselect("contigsToProcess", "22", contigOptions)

In [None]:
# Define parameters
# Parse the widget value explicitly rather than with eval(): widget values can be
# overridden (e.g. by job parameters), and eval would execute arbitrary text.
flattenArg = getArgument("flatten")
if flattenArg not in ("True", "False"):
  raise ValueError("Widget 'flatten' must be 'True' or 'False', got: " + repr(flattenArg))
toFlatten = flattenArg == "True"
toHash = True
repartitionCoef = 45 / 1000000 # gives ~20MB .parquet files

# Define contig list
contigs = getArgument("contigsToProcess").split(",")
if "All" in contigs:
  # Build a new list: mutating contigOptions (defined in the widget cell) would make
  # a re-run of this cell fail on remove('All').
  contigs = [c for c in contigOptions if c != "All"]

# Find schema for full dataset (used to give every flattened contig the same columns)
sourceAll = "wasbs://dataset@dataset1000genomes.blob.core.windows.net/release/20130502/ALL.chr*.vcf.gz"
dataAll = spark.read\
  .format("vcf")\
  .option("includeSampleIds", True)\
  .option("flattenInfoFields", True)\
  .load(sourceAll)

dataAllExploded = dataAll.withColumn('genotypes', explode('genotypes'))
dataAllExplodedFlatten = flattenStructFields(dataAllExploded)
fullSet = set(dataAllExplodedFlatten.schema.fields)

# Output layout: <container>/<outputDir>/{original|flattened}/chr<contig>
outputLayout = "flattened" if toFlatten else "original"
sinkRoot = "wasbs://" + outputContainer + "@" + outputStorageAccount + ".blob.core.windows.net" + outputDir + "/" + outputLayout

for contig in contigs:
  source = "wasbs://dataset@dataset1000genomes.blob.core.windows.net/release/20130502/ALL.chr" + contig + ".*.vcf.gz"

  # Load data
  data = spark.read\
    .format("vcf")\
    .option("includeSampleIds", True)\
    .option("flattenInfoFields", True)\
    .load(source)

  # Number of output partitions for coalesce: proportional to the input row count,
  # at least 1 (coalesce(0) is rejected, e.g. for an empty input) and at most partitionMax.
  # NOTE(review): rowCount is measured before genotypes are exploded, so flattened
  # output has far more rows per file - confirm repartitionCoef for toFlatten=True.
  rowCount = data.count()
  partCount = min(max(1, ceil(repartitionCoef * rowCount)), partitionMax)

  dataFinal = transformVcf(data, toFlatten, toHash, fullSet)
  sink = sinkRoot + "/chr" + contig

  dataFinal.coalesce(partCount) \
    .write \
    .mode("overwrite") \
    .format("parquet") \
    .save(sink)
