In [1]:
import json
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, concat, lit, sum as spark_sum, avg, count
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
from pyspark.sql import functions as F
from io import StringIO
from pyspark.sql import DataFrame
import sys
import os
# Ask huggingface_hub to suppress its progress bars; set here, ahead of the
# huggingface_hub import that happens in a later cell.
os.environ['HF_HUB_DISABLE_PROGRESS_BARS'] = '1'

In [2]:
# Spark session for the lineage demo (adaptive execution switched off via config)
spark = (
    SparkSession.builder
    .appName("DAG_Lineage_Extraction")
    .config("spark.sql.adaptive.enabled", "false")
    .getOrCreate()
)

# Five employee records: (first_name, last_name, age, salary, department)
sample_data = [
    ("John", "Doe", 25, 50000.0, "Engineering"),
    ("Jane", "Smith", 30, 75000.0, "Marketing"),
    ("Bob", "Johnson", 35, 60000.0, "Engineering"),
    ("Alice", "Brown", 28, 80000.0, "Sales"),
    ("Charlie", "Wilson", 32, 70000.0, "Marketing"),
]

# Every column is nullable; names and types line up with the tuples above
schema = StructType([
    StructField(column_name, column_type, True)
    for column_name, column_type in (
        ("first_name", StringType()),
        ("last_name", StringType()),
        ("age", IntegerType()),
        ("salary", DoubleType()),
        ("department", StringType()),
    )
])

# Source DataFrame that the later transformations build on
df_original = spark.createDataFrame(sample_data, schema)

In [3]:
# Show the untouched input first
print("Original DataFrame:")
df_original.show()

# Transformation 1: add three derived columns in one chained expression
df_transformed = (
    df_original
    .withColumn(
        "full_name",
        F.concat(F.col("first_name"), F.lit(" "), F.col("last_name")),
    )
    .withColumn(
        "salary_category",
        F.when(F.col("salary") > 70000, "High")
        .when(F.col("salary") > 50000, "Medium")
        .otherwise("Low"),
    )
    .withColumn(
        "age_group",
        F.when(F.col("age") < 30, "Young").otherwise("Senior"),
    )
)

print("\nTransformed DataFrame:")
df_transformed.show()

Original DataFrame:
+----------+---------+---+-------+-----------+
|first_name|last_name|age| salary| department|
+----------+---------+---+-------+-----------+
|      John|      Doe| 25|50000.0|Engineering|
|      Jane|    Smith| 30|75000.0|  Marketing|
|       Bob|  Johnson| 35|60000.0|Engineering|
|     Alice|    Brown| 28|80000.0|      Sales|
|   Charlie|   Wilson| 32|70000.0|  Marketing|
+----------+---------+---+-------+-----------+


Transformed DataFrame:
+----------+---------+---+-------+-----------+--------------+---------------+---------+
|first_name|last_name|age| salary| department|     full_name|salary_category|age_group|
+----------+---------+---+-------+-----------+--------------+---------------+---------+
|      John|      Doe| 25|50000.0|Engineering|      John Doe|            Low|    Young|
|      Jane|    Smith| 30|75000.0|  Marketing|    Jane Smith|           High|   Senior|
|       Bob|  Johnson| 35|60000.0|Engineering|   Bob Johnson|         Medium|   Senior|
|   

In [4]:
# Per (department, salary_category) group: salary total, mean age and row count
df_aggregated = (
    df_transformed
    .groupBy("department", "salary_category")
    .agg(
        F.sum("salary").alias("total_salary"),
        F.avg("age").alias("avg_age"),
        F.count("*").alias("employee_count"),
    )
)

print("\nAggregated DataFrame:")
df_aggregated.show()


Aggregated DataFrame:
+-----------+---------------+------------+-------+--------------+
| department|salary_category|total_salary|avg_age|employee_count|
+-----------+---------------+------------+-------+--------------+
|Engineering|            Low|     50000.0|   25.0|             1|
|      Sales|           High|     80000.0|   28.0|             1|
|  Marketing|         Medium|     70000.0|   32.0|             1|
|Engineering|         Medium|     60000.0|   35.0|             1|
|  Marketing|           High|     75000.0|   30.0|             1|
+-----------+---------------+------------+-------+--------------+



In [5]:
# Function to extract DAG information
def extract_dag_info(dataframe, stage_name):
    """Capture a DataFrame's extended explain output and split it into plan sections.

    Args:
        dataframe: Object exposing ``explain(extended=True)``, which writes the
            plans to stdout, and ``schema.fields`` whose items carry ``name``
            and ``dataType`` attributes.
        stage_name: Label stored unchanged under ``"stage_name"``.

    Returns:
        dict with the keys ``stage_name``, ``logical_plan``, ``physical_plan``,
        ``full_explain`` (the complete captured text) and ``schema`` (a list of
        ``{"name": ..., "type": ...}`` dicts). A section whose header line is
        absent from the output is reported as the string ``"Not found"``.

    Raises:
        Whatever ``dataframe.explain`` raises; stdout is restored first.
    """
    # Local imports keep this cell runnable on its own.
    from contextlib import redirect_stdout
    from io import StringIO

    # explain() prints instead of returning text. redirect_stdout restores
    # sys.stdout even when explain() raises, so a failing plan cannot leave
    # the notebook's output silently swallowed by the buffer.
    buffer = StringIO()
    with redirect_stdout(buffer):
        dataframe.explain(extended=True)
    plan_output = buffer.getvalue()

    # Locate the section headers in the captured text
    lines = plan_output.split('\n')
    logical_start = -1
    physical_start = -1

    for i, line in enumerate(lines):
        if "Parsed Logical Plan" in line or "Analyzed Logical Plan" in line:
            # The last matching header before the physical plan wins, so when
            # both headers are present the section begins at the Analyzed one.
            logical_start = i
        elif "Physical Plan" in line:
            physical_start = i
            break

    if logical_start == -1:
        logical_plan_str = "Not found"
    else:
        # With no physical header the logical section runs to the end of the
        # output; slicing up to -1 would drop its final line.
        logical_end = physical_start if physical_start != -1 else len(lines)
        logical_plan_str = '\n'.join(lines[logical_start:logical_end])

    if physical_start == -1:
        physical_plan_str = "Not found"
    else:
        physical_plan_str = '\n'.join(lines[physical_start:])

    dag_info = {
        "stage_name": stage_name,
        "logical_plan": logical_plan_str,
        "physical_plan": physical_plan_str,
        "full_explain": plan_output,
        "schema": [{"name": field.name, "type": str(field.dataType)} for field in dataframe.schema.fields]
    }

    return dag_info


In [6]:
# Function to extract transformation details for LLM
def extract_transformation_context(df_before, df_after, transformation_type):
    """Summarise how the column set changes between two DataFrames.

    Args:
        df_before: Input-side object exposing ``schema.fields`` with ``name``
            attributes.
        df_after: Output-side object with the same shape of interface.
        transformation_type: Free-form label stored unchanged in the result.

    Returns:
        dict with ``transformation_type``, ``input_columns``,
        ``output_columns``, ``new_columns`` (names only in ``df_after``) and
        ``dropped_columns`` (names only in ``df_before``). The two difference
        lists follow schema order and hold each name once, so the result is
        the same on every run instead of depending on set iteration order.
    """
    before_columns = [field.name for field in df_before.schema.fields]
    after_columns = [field.name for field in df_after.schema.fields]

    # Sets are used for membership tests only; ordering comes from the schemas
    before_lookup = set(before_columns)
    after_lookup = set(after_columns)

    # dict.fromkeys drops repeated names while keeping first-seen order
    new_columns = list(dict.fromkeys(
        name for name in after_columns if name not in before_lookup
    ))
    dropped_columns = list(dict.fromkeys(
        name for name in before_columns if name not in after_lookup
    ))

    transformation_context = {
        "transformation_type": transformation_type,
        "input_columns": before_columns,
        "output_columns": after_columns,
        "new_columns": new_columns,
        "dropped_columns": dropped_columns
    }

    return transformation_context

In [7]:
# Banner and title in one call; sep="\n" reproduces three separate prints
print("\n" + "=" * 50, "EXTRACTING DAG INFORMATION", "=" * 50, sep="\n")

# One plan snapshot per pipeline stage, taken in pipeline order
original_dag, transformed_dag, aggregated_dag = (
    extract_dag_info(stage_frame, stage_label)
    for stage_frame, stage_label in (
        (df_original, "original"),
        (df_transformed, "transformed"),
        (df_aggregated, "aggregated"),
    )
)



EXTRACTING DAG INFORMATION


In [8]:
# Before/after column summaries for the two hops of the pipeline
transform_context_1, transform_context_2 = (
    extract_transformation_context(frame_in, frame_out, kind)
    for frame_in, frame_out, kind in (
        (df_original, df_transformed, "column_derivation"),
        (df_transformed, df_aggregated, "aggregation"),
    )
)

**LLM processing**

In [9]:

# Payload handed to the LLM: one entry per hop, pairing the column summary
# with the plan information of the DataFrame that the hop produces.
llm_input_data = {
    "job_id": "sample_lineage_job",
    "transformations": [
        {"stage": stage, "context": context, "dag_info": dag_info}
        for stage, context, dag_info in (
            ("original_to_transformed", transform_context_1, transformed_dag),
            ("transformed_to_aggregated", transform_context_2, aggregated_dag),
        )
    ],
}

In [10]:

# Preview only: the serialized payload is cut at 1000 characters
print("\nSample DAG Information for LLM:")
print(f'{json.dumps(llm_input_data, indent=2)[:1000]}...')

# Plan excerpts are cut at 500 characters each
print("\nLogical Plan for 'transformed' stage:")
print(f'{transformed_dag["logical_plan"][:500]}...')

print("\nPhysical Plan for 'aggregated' stage:")
print(f'{aggregated_dag["physical_plan"][:500]}...')

# Closing banner; sep="\n" gives the same three output lines as separate prints
print("\n" + "=" * 50, "DAG EXTRACTION COMPLETE", "=" * 50, sep="\n")


Sample DAG Information for LLM:
{
  "job_id": "sample_lineage_job",
  "transformations": [
    {
      "stage": "original_to_transformed",
      "context": {
        "transformation_type": "column_derivation",
        "input_columns": [
          "first_name",
          "last_name",
          "age",
          "salary",
          "department"
        ],
        "output_columns": [
          "first_name",
          "last_name",
          "age",
          "salary",
          "department",
          "full_name",
          "salary_category",
          "age_group"
        ],
        "new_columns": [
          "age_group",
          "salary_category",
          "full_name"
        ],
        "dropped_columns": []
      },
      "dag_info": {
        "stage_name": "transformed",
        "logical_plan": "== Analyzed Logical Plan ==\nfirst_name: string, last_name: string, age: int, salary: double, department: string, full_name: string, salary_category: string, age_group: string\nProject [first_

In [11]:
def prepare_llm_prompt(transformation_data, max_plan_chars=800):
    """Build the column-lineage prompt for one transformation entry.

    Args:
        transformation_data: Mapping with a ``'context'`` mapping (keys
            ``transformation_type``, ``input_columns``, ``output_columns``,
            ``new_columns``) and a ``'dag_info'`` mapping whose
            ``'logical_plan'`` is a string.
        max_plan_chars: Maximum number of logical-plan characters copied into
            the prompt. Defaults to 800, the limit this function always used.
            Pass ``None`` to include the whole plan; expected to be
            non-negative otherwise.

    Returns:
        The prompt text, including the JSON output format the model is asked
        to follow.

    Raises:
        KeyError: If one of the keys listed above is missing.
    """
    context = transformation_data['context']
    # Slicing with None keeps the full string, so None doubles as "no limit"
    plan_excerpt = transformation_data['dag_info']['logical_plan'][:max_plan_chars]

    prompt = f"""
    Analyze this Spark transformation and extract column lineage:

    Transformation Type: {context['transformation_type']}
    Input Columns: {context['input_columns']}
    Output Columns: {context['output_columns']}
    New Columns: {context['new_columns']}

    Logical Plan:
    {plan_excerpt}

    Extract column lineage in this format:
    {{
        "lineage": [
            {{"source_columns": ["col1", "col2"], "target_column": "derived_col", "operation": "concat"}},
            {{"source_columns": ["col3"], "target_column": "derived_col2", "operation": "case_when"}}
        ]
    }}
    """

    return prompt

In [12]:
!pip install transformers==4.36.0 torch accelerate bitsandbytes

Collecting transformers==4.36.0
  Downloading transformers-4.36.0-py3-none-any.whl.metadata (126 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/126.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m126.8/126.8 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
Collecting bitsandbytes
  Downloading bitsandbytes-0.46.1-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting tokenizers<0.19,>=0.14 (from transformers==4.36.0)
  Downloading tokenizers-0.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloadi

In [16]:
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

In [17]:
from google.colab import userdata
from huggingface_hub import login

# Fetch the value stored under 'HF_TOKEN' via Colab's userdata helper,
# so the token itself never appears in the notebook source.
token = userdata.get('HF_TOKEN')
# Authenticate with the Hugging Face Hub; add_to_git_credential is explicitly off.
login(token=token, add_to_git_credential=False)

In [18]:
# Single checkpoint id shared by tokenizer and model so the two cannot drift apart
CODELLAMA_CHECKPOINT = "meta-llama/CodeLlama-7b-Instruct-hf"

print("Loading CodeLlama model...")
tokenizer = AutoTokenizer.from_pretrained(CODELLAMA_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(
    CODELLAMA_CHECKPOINT,
    load_in_8bit=True,  # For Colab memory optimization
    device_map="auto",
    torch_dtype=torch.float16,
)

Loading CodeLlama model...


  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [19]:
def process_prompt_with_codellama(prompt):
    """Send one prompt through the loaded CodeLlama model and return its reply.

    Relies on the notebook-level ``tokenizer`` and ``model`` objects.

    Args:
        prompt: Instruction text. It is stripped and wrapped in the
            ``[INST] ... [/INST]`` markers before tokenization; input longer
            than 2048 tokens is truncated by the tokenizer.

    Returns:
        The generated continuation only (at most 512 new tokens), decoded
        without special tokens and stripped of surrounding whitespace.
    """
    # NOTE(review): the tokenizer may prepend its own BOS token, in which case
    # the literal "<s>" duplicates it. Confirm before changing the template.
    formatted_prompt = f"<s>[INST] {prompt.strip()} [/INST]"

    encoded = tokenizer(formatted_prompt, return_tensors="pt", truncation=True, max_length=2048)

    # Move inputs to the device the model lives on
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    prompt_token_count = encoded["input_ids"].shape[-1]

    with torch.no_grad():
        outputs = model.generate(
            input_ids=encoded["input_ids"],
            # Pass the mask explicitly so generate() does not have to guess it
            attention_mask=encoded.get("attention_mask"),
            max_new_tokens=512,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # The returned sequence starts with the prompt tokens. Slicing them off by
    # length is reliable, whereas searching the decoded text for "[/INST]"
    # returns text from index 6 whenever the marker is missing (for example
    # after truncation) and cuts at the wrong place if the prompt contains it.
    generated_tokens = outputs[0][prompt_token_count:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

In [21]:
# Build the prompt for the first transformation entry ("original_to_transformed")
sample_prompt = prepare_llm_prompt(llm_input_data["transformations"][0])

In [22]:
# Run the sample prompt through the model; sampling is enabled in the helper,
# so the text can differ between runs.
test_response = process_prompt_with_codellama(sample_prompt)

In [23]:
# Bare expression: the notebook shows the raw response string as the cell output
test_response

'The column lineage for this Spark transformation can be extracted as follows:\n\n{\n    "lineage": [\n        {\n            "source_columns": ["first_name", "last_name"],\n            "target_column": "full_name",\n            "operation": "concat"\n        },\n        {\n            "source_columns": ["salary"],\n            "target_column": "salary_category",\n            "operation": "case_when"\n        },\n        {\n            "source_columns": ["age"],\n            "target_column": "age_group",\n            "operation": "case_when"\n        }\n    ]\n}\n\nHere, we have three column derivations:\n\n1. The "full_name" column is derived from the "first_name" and "last_name" columns using the "concat" operation.\n2. The "salary_category" column is derived from the "salary" column using the "case_when" operation.\n3. The "age_group" column is derived from the "age" column using the "case_when" operation.\n\nNote that the "case_when" operation is used for both the "salary_category"