In [1]:
import os

import torch
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import split
from pyspark.sql.functions import udf, col
from pyspark.sql.types import MapType, StringType, ArrayType, IntegerType, StructType, StructField
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Name of the S3 bucket that holds the dataset.
s3_bucket = 'reddit-tifu'

# AWS credentials are read from the environment so that real secrets never have
# to be typed into (or committed with) the notebook. The '***' fallbacks are the
# redacted placeholders that were previously hard-coded here; they are not valid
# credentials and only keep the cell runnable when the variables are unset.
AWS_ACCESS_KEY_ID = os.environ.get('AWS_ACCESS_KEY_ID', '***')
AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY', '***')

In [3]:
def create_spark(appname, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY):
    """Build (or fetch the already-running) local SparkSession with S3A access.

    Args:
        appname: Application name given to the Spark session.
        AWS_ACCESS_KEY_ID: Value stored under ``spark.hadoop.fs.s3a.access.key``.
        AWS_SECRET_ACCESS_KEY: Value stored under ``spark.hadoop.fs.s3a.secret.key``.

    Returns:
        The SparkSession returned by ``getOrCreate()``.
    """
    # The hadoop-aws package is requested through spark.jars.packages, next to
    # the two S3A credential settings.
    s3a_settings = [
        ("spark.jars.packages", "org.apache.hadoop:hadoop-aws:3.3.4"),
        ("spark.hadoop.fs.s3a.access.key", AWS_ACCESS_KEY_ID),
        ("spark.hadoop.fs.s3a.secret.key", AWS_SECRET_ACCESS_KEY),
    ]
    spark_conf = SparkConf()
    spark_conf.setAll(s3a_settings)

    # Run on the local machine using all available cores.
    session_builder = SparkSession.builder.master("local[*]")
    session_builder = session_builder.config(conf=spark_conf)
    session_builder = session_builder.appName(appname)
    return session_builder.getOrCreate()

In [4]:
# Create (or reuse) the Spark session named "data prep", configured with the AWS credentials above.
spark = create_spark("data prep", AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY)

:: loading settings :: url = jar:file:/home/ec2-user/.local/lib/python3.9/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /home/ec2-user/.ivy2/cache
The jars for the packages stored in: /home/ec2-user/.ivy2/jars
org.apache.hadoop#hadoop-aws added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-dc79ae75-ee42-4f74-8b10-88cb38bb3808;1.0
	confs: [default]
	found org.apache.hadoop#hadoop-aws;3.3.4 in central
	found com.amazonaws#aws-java-sdk-bundle;1.12.262 in central
	found org.wildfly.openssl#wildfly-openssl;1.0.7.Final in central
:: resolution report :: resolve 214ms :: artifacts dl 17ms
	:: modules in use:
	com.amazonaws#aws-java-sdk-bundle;1.12.262 from central in [default]
	org.apache.hadoop#hadoop-aws;3.3.4 from central in [default]
	org.wildfly.openssl#wildfly-openssl;1.0.7.Final from central in [default]
	---------------------------------------------------------------------
	|                  |            modules            ||   artifacts   |
	|       conf       | number| search|dwnlded|evicted|| number|dwnlded|
	----------------------------

In [5]:
# Load the parquet dataset from the S3 bucket through the s3a:// connector.
df = spark.read.parquet(f"s3a://{s3_bucket}/reddit_tifu_llm.parquet")

25/04/14 21:37:34 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-s3a-file-system.properties,hadoop-metrics2.properties
                                                                                

In [8]:
# Preview a single row without truncating long column values.
df.show(1, truncate = False)

                                                                                

+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|Prompt                                                                                                                                                                                                                                                                                                                                                                 |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [6]:
def extract_sections(prompt):
    """Split a raw prompt string into its post text and its title.

    The prompt is expected to look like
    ``"Text: <body> \\nTitle: <title> ####..."``. Everything before the first
    ``" \\nTitle: "`` delimiter is the text (with the first ``"Text: "`` label
    removed); the title is what follows the delimiter, up to the first
    ``" ####"`` marker.

    Args:
        prompt: The raw prompt string, or ``None``.

    Returns:
        A dict with the keys ``"text"`` and ``"title"``. A missing or empty
        prompt yields two empty strings, and a prompt without the title
        delimiter yields an empty title, so a single malformed row cannot
        abort the whole Spark job with an exception.
    """
    if not prompt:
        return {"text": "", "title": ""}

    split_prompt = prompt.split(" \nTitle: ")

    # Strip only the first "Text: " label; later occurrences belong to the
    # post body itself and must be kept.
    text = split_prompt[0].replace("Text: ", "", 1)

    # Without the delimiter there is only one piece, hence no title.
    title = split_prompt[1].split(" ####")[0] if len(split_prompt) > 1 else ""

    return {
        "text": text,
        "title": title,
    }

In [7]:
# Wrap extract_sections as a Spark UDF that returns a map<string, string>.
extract_sections_udf = udf(extract_sections, MapType(StringType(), StringType()))
# Add the parsed sections of the "Prompt" column as a new "AnnotatedSections" column.
annotated_df = df.withColumn("AnnotatedSections", extract_sections_udf(df["Prompt"]))
# Preview the new column for one row without truncation.
annotated_df.select("AnnotatedSections").show(1, truncate=False)

                                                                                

+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|AnnotatedSections                                                                                                                                                                                                                                                                                                                                                      |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [8]:
# Load the tokenizer for meta-llama/Llama-2-7b-hf, requesting the fast implementation.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_fast=True)

In [9]:
@udf(returnType=ArrayType(StringType()))
def tokenize_text(text):
    """Tokenize the title and the body of one row into a single token list.

    Args:
        text: Mapping that provides the strings under the keys ``'title'`` and
            ``'text'``.

    Returns:
        The literal tokens ``'title'`` and ``':'``, then the tokenizer's tokens
        for the title, then the literal tokens ``'text'`` and ``':'``, then the
        tokenizer's tokens for the body.
    """
    # Title first, body second -- the same order in which they are emitted.
    title_tokens = tokenizer.tokenize(text['title'])
    body_tokens = tokenizer.tokenize(text['text'])
    return ['title', ':'] + title_tokens + ['text', ':'] + body_tokens

In [10]:
# Add a "tokens" column by applying the tokenize_text UDF to the parsed sections.
tokenized_df = annotated_df.withColumn("tokens", tokenize_text(col("AnnotatedSections")))

In [11]:
# Preview one tokenized row without truncating column values.
tokenized_df.show(1, truncate = False)

[Stage 2:>                                                          (0 + 1) / 1]

+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

                                                                                

In [42]:
# IPython help lookup for `tokenizer` (prints its call signature and docstring).
# NOTE(review): interactive exploration only -- it assigns nothing, so no later
# cell depends on it; consider removing it before sharing the notebook.
tokenizer?

[0;31mSignature:[0m     
[0mtokenizer[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mtext[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mList[0m[0;34m[[0m[0mstr[0m[0;34m][0m[0;34m,[0m [0mList[0m[0;34m[[0m[0mList[0m[0;34m[[0m[0mstr[0m[0;34m][0m[0;34m][0m[0;34m,[0m [0mNoneType[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtext_pair[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mList[0m[0;34m[[0m[0mstr[0m[0;34m][0m[0;34m,[0m [0mList[0m[0;34m[[0m[0mList[0m[0;34m[[0m[0mstr[0m[0;34m][0m[0;34m][0m[0;34m,[0m [0mNoneType[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtext_target[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mstr[0m[0;34m,[0m [0mList[0m[0;34m[[0m[0mstr[0m[0;34m][0m[0;34m,[0m [0mList[0m[0;34m[[0m[0mList[0m[0;34m[[0m[0mstr[0m[0;34m][0m[0;34m][0m[0;34m,[0m [0mNoneType[0m[0;34m

In [63]:
def tokens_to_ids_attention_mask(tokens, max_length=512):
    """Turn one row's token strings into fixed-length ids and an attention mask.

    ``tokens`` is a list of already-tokenized strings. Passing such a list
    straight to ``tokenizer(...)`` makes the tokenizer treat it as a *batch* of
    separate texts, so taking element ``[0]`` of the result would only encode
    the first token of the row. The tokens are therefore mapped to their
    vocabulary ids directly, and the special tokens, truncation and padding
    are applied here.

    Args:
        tokens: List of token strings for a single row (``None`` is treated
            as an empty list).
        max_length: Length every returned list is truncated or padded to.

    Returns:
        A ``(token_ids, attention_mask)`` tuple of two integer lists, each
        exactly ``max_length`` long. The mask holds 1 for real tokens and 0
        for padding.
    """
    if tokenizer.pad_token is None:
        # NOTE(review): this registers a new '[PAD]' token; a model used with
        # these ids presumably needs its embedding matrix resized -- confirm.
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    token_ids = tokenizer.convert_tokens_to_ids(list(tokens or []))

    # Leave room for the special tokens the tokenizer adds to a single
    # sequence, so truncation can never cut them off.
    room = max(max_length - tokenizer.num_special_tokens_to_add(pair=False), 0)
    token_ids = tokenizer.build_inputs_with_special_tokens(token_ids[:room])
    token_ids = list(token_ids)[:max_length]
    attention_mask = [1] * len(token_ids)

    # Pad on the side the tokenizer itself is configured to pad on.
    pad_count = max_length - len(token_ids)
    pad_ids = [tokenizer.pad_token_id] * pad_count
    pad_mask = [0] * pad_count
    if getattr(tokenizer, "padding_side", "right") == "left":
        token_ids = pad_ids + token_ids
        attention_mask = pad_mask + attention_mask
    else:
        token_ids = token_ids + pad_ids
        attention_mask = attention_mask + pad_mask

    return (token_ids, attention_mask)

In [64]:
# Return type of the ids/mask UDF: a struct holding two non-nullable integer
# arrays named "token_ids" and "attention_mask".
schema = (
    StructType()
    .add("token_ids", ArrayType(IntegerType()), False)
    .add("attention_mask", ArrayType(IntegerType()), False)
)

In [65]:
# Wrap tokens_to_ids_attention_mask as a Spark UDF that returns the struct described by `schema`.
tokens_to_ids_attention_mask_udf = udf(tokens_to_ids_attention_mask, schema)

In [66]:
# Add an "ids_masks" struct column computed from each row's "tokens" column.
df_with_ids_masks = tokenized_df.withColumn("ids_masks", tokens_to_ids_attention_mask_udf(tokenized_df["tokens"]))

In [68]:
# Keep the tokens and pull the two fields of the "ids_masks" struct out as top-level columns.
df_final = df_with_ids_masks.select("tokens", "ids_masks.token_ids", "ids_masks.attention_mask")
# Preview one row of the final frame without truncation.
df_final.show(1, truncate=False)

[Stage 16:>                                                         (0 + 1) / 1]

+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

                                                                                