## Import dependencies

In [1]:
!pip install pyspark py4j



In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col,substring,trim,to_timestamp,lit
from pyspark.sql.types import StructType, StructField, StringType, FloatType, TimestampType
import json
import os
from abc import ABC, abstractmethod
import uuid


In [3]:
# One SparkSession shared by every cell below; getOrCreate() returns the
# already-running session when this cell is re-executed instead of starting a new one.
spark = (
    SparkSession.builder
    .appName("test_pyspark")
    .getOrCreate()
)

## Load utils

In [4]:
def load_json_file(path, encoding="utf-8"):
    """Read a JSON file from disk and return the parsed value.

    Args:
        path: Filesystem path of the JSON document.
        encoding: Text encoding used to open the file. Defaults to UTF-8, the
            encoding JSON documents are exchanged in; without an explicit
            value, open() falls back to the platform's locale encoding and can
            mis-decode non-ASCII config values.

    Returns:
        The deserialized JSON value (the config cells below index it as a dict).
    """
    with open(path, "r", encoding=encoding) as file:
        return json.load(file)


def array_to_colspecs(arr):
    """Turn a sequence of field widths into consecutive (start, end) pairs.

    Offsets are 0-based and half-open, so widths [3, 2] become [(0, 3), (3, 5)].
    An empty input yields an empty list.
    """
    # Running boundaries: 0, w0, w0+w1, ... ; each adjacent pair is one column.
    boundaries = [0]
    for width in arr:
        boundaries.append(boundaries[-1] + width)
    return list(zip(boundaries[:-1], boundaries[1:]))

def generate_file_schema(columns_name,columns_type):
  schema_fields = []
  for name, dtype in zip(config["metadata"]["columns_name"], config["metadata"]["columns_type"]):
      if dtype == "string":
          schema_fields.append(StructField(name, StringType(), True))
      elif dtype == "timestamp":
          schema_fields.append(StructField(name, TimestampType(), True))
      elif dtype == "float":
          schema_fields.append(StructField(name, FloatType(), True))

  schema = StructType(schema_fields)
  return schema

def generate_column_exprs(widths, names):
    """Build one substring expression per fixed-width field of the raw ``value`` column.

    Args:
        widths: Character width of each field, in file order.
        names: Alias for each field, paired positionally with ``widths``
            (zip() ignores surplus entries in the longer list).

    Returns:
        A list of Spark column expressions to pass to ``DataFrame.select``.
    """
    column_exprs = []
    offset = 1  # Spark's substring() positions are 1-based
    for name, width in zip(names, widths):
        field = substring(col("value"), offset, width)
        column_exprs.append(field.alias(name))
        offset += width
    return column_exprs

def verify_dict(dict_to_check, required_fields):
    """Raise ValueError naming the first entry of ``required_fields`` absent from ``dict_to_check``.

    Returns None when every required key is present.
    """
    for required in required_fields:
        if required in dict_to_check:
            continue
        raise ValueError(f"Field {required} is required in the config file")

## Loading config

In [15]:
# Pipeline config to run. Set the CONFIG_PATH environment variable to switch
# configs without editing this cell, e.g. "/content/config_fixed_width.json"
# for the fixed-width example; the default is the delimited example.
config_path = os.environ.get("CONFIG_PATH", "/content/config.json")
config = load_json_file(config_path)

## Strategies

### Abstract

In [6]:
class DataLoaderStrategy(ABC):
    """Interface for file loaders chosen by ``config['metadata']['type']``.

    Concrete subclasses must implement both abstract methods; instantiating a
    subclass that leaves either one unimplemented raises TypeError.
    """

    @abstractmethod
    def load_data(self, config):
        """Load the file described by ``config`` and return it as a DataFrame.

        Implementations in this notebook call ``validateConfig`` first.
        """
        pass

    @abstractmethod
    def validateConfig(self, config):
        """Check that ``config`` holds every key this loader needs.

        Implementations are expected to raise ValueError when a key is missing.
        """
        pass

### Fixed-width strategy

In [7]:
class FixedWidthFileLoader(DataLoaderStrategy):
    """Loads a fixed-width text file, slicing each line into named columns by width.

    Required ``config['metadata']`` keys: ``columns_widths``, ``columns_name`` and
    ``columns_type`` (parallel lists of equal length). Optional:
    ``timestamp_format`` - when present, "timestamp" columns are parsed with
    ``to_timestamp`` using that pattern; otherwise they stay trimmed strings.
    """

    def validateConfig(self, config):
        """Raise ValueError if required keys are missing or the column lists differ in length."""
        required_fields = ['origin_path', 'final_path', 'metadata']
        verify_dict(config, required_fields)

        metadata = config['metadata']
        metadata_required_fields = ['columns_widths', 'columns_name', 'columns_type']
        verify_dict(metadata, metadata_required_fields)

        # zip() would otherwise silently drop surplus entries and misalign columns.
        lengths = {field: len(metadata[field]) for field in metadata_required_fields}
        if len(set(lengths.values())) != 1:
            raise ValueError(
                f"columns_widths, columns_name and columns_type must have the same length, got {lengths}"
            )

    @staticmethod
    def _typed_column(name, dtype, timestamp_format=None):
        """Return the expression converting raw column ``name`` to ``dtype``, keeping its name."""
        if dtype == "float":
            return col(name).cast(FloatType()).alias(name)
        if dtype == "timestamp" and timestamp_format is not None:
            return to_timestamp(trim(col(name)), timestamp_format).alias(name)
        if dtype in ("string", "timestamp"):
            # Without a timestamp_format, timestamps are only trimmed, as before.
            return trim(col(name)).alias(name)
        # Unrecognised type names pass through unchanged.
        return col(name)

    def load_data(self, config):
        """Validate ``config``, read ``origin_path`` as text and return the typed DataFrame."""
        self.validateConfig(config)
        metadata = config["metadata"]

        raw = spark.read.text(config["origin_path"])
        # Slice the single `value` column produced by read.text into named fields.
        sliced = raw.select(
            *generate_column_exprs(metadata["columns_widths"], metadata["columns_name"])
        )

        # Apply every conversion in one projection: calling withColumn once per
        # column in a loop adds a projection each time and bloats the plan.
        timestamp_format = metadata.get("timestamp_format")
        dtype_by_name = dict(zip(metadata["columns_name"], metadata["columns_type"]))
        return sliced.select(*[
            self._typed_column(name, dtype_by_name.get(name), timestamp_format)
            for name in sliced.columns
        ])

### Delimited Parser


In [12]:
class DelimitedFileLoader(DataLoaderStrategy):
    """Loads a delimited (CSV-style) file using the separator given in the config metadata."""

    def validateConfig(self, config):
        """Require origin_path/final_path/metadata at the top level and a delimiter in metadata."""
        verify_dict(config, ['origin_path', 'final_path', 'metadata'])
        verify_dict(config['metadata'], ['delimiter'])

    def load_data(self, config):
        """Validate ``config`` and read ``origin_path`` with a header row and inferred column types."""
        self.validateConfig(config)
        separator = config['metadata']['delimiter']
        return spark.read.csv(
            config['origin_path'],
            header=True,
            inferSchema=True,
            sep=separator,
        )

### Getting Strategy


In [13]:
# Registry of loader instances, keyed by the value expected in config['metadata']['type'].
strategy_map = dict(
    delimited=DelimitedFileLoader(),
    fixed_width=FixedWidthFileLoader(),
)


def get_strategy(strategy_type, strategies=None):
    """Return the loader registered under ``strategy_type``.

    Args:
        strategy_type: Loader key, e.g. 'delimited' or 'fixed_width'.
        strategies: Optional mapping of key -> loader; defaults to the
            notebook-level ``strategy_map``.

    Raises:
        ValueError: if no loader is registered for ``strategy_type``.
    """
    if strategies is None:
        strategies = strategy_map
    # .get() rather than [] so an unknown key reaches the ValueError below;
    # indexing raised KeyError first and made the old check unreachable.
    strategy = strategies.get(strategy_type)
    if strategy is None:
        valid = ", ".join(sorted(map(str, strategies)))
        raise ValueError(f"Invalid strategy: {strategy_type}. Expected one of: {valid}")
    return strategy

## Read DF

In [16]:
# Pick the loader registered for this file type, then read the raw data with it.
file_type = config['metadata']['type']
strategy = get_strategy(file_type)
df = strategy.load_data(config)

## Adding Additional Data

In [18]:
# Append constant-valued columns declared under metadata.additional_data.
# The loop variable must NOT be named `col`: that rebinds the imported pyspark
# `col` function to a string, so any later col(...) call in this kernel
# (e.g. re-running load_data or generate_column_exprs) fails.
if 'additional_data' in config['metadata']:
    for column_name, value in config['metadata']['additional_data'].items():
        df = df.withColumn(column_name, lit(value))

In [19]:
# Preview the first loaded row (first() is PySpark's alias for head() with no argument).
df.first()

Row(stock='A', transaction_date='02-Jun-2012', open_price=34.93, close_price=34.93, max_price=34.93, min_price=34.93, variation=0, partition_date='2024-04-05')

## Saving Data

In [20]:
# Write the result as Parquet into a new uuid4-named directory under final_path,
# partitioned by config['partition_by'] when the config provides that key.
full_path = f"{config['final_path']}/{uuid.uuid4()}"
writer = df.write
if 'partition_by' in config:
    writer = writer.partitionBy(config['partition_by'])
writer.parquet(full_path)