In [1]:
import os
from io import BytesIO

import boto3
import torch
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import split
from pyspark.sql.functions import udf, col
from pyspark.sql.types import MapType, StringType, ArrayType, IntegerType, StructType, StructField
from transformers import AutoTokenizer


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# S3 bucket that holds the input parquet file.
s3_bucket = 'reddit-tifu'

# Never hard-code AWS credentials in a notebook: they end up in the saved file
# (and possibly its outputs). Read them from the environment instead; each is
# None when the corresponding variable is not set.
AWS_ACCESS_KEY_ID = os.environ.get('AWS_ACCESS_KEY_ID')
AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY')

In [3]:
def create_spark(appname, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY,
                 master="local[*]", hadoop_aws_version="3.3.4"):
    """Create (or reuse) a SparkSession configured to read ``s3a://`` paths.

    Args:
        appname: Spark application name.
        AWS_ACCESS_KEY_ID: access key for the S3A connector, or None to leave
            ``spark.hadoop.fs.s3a.access.key`` unset.
        AWS_SECRET_ACCESS_KEY: secret key for the S3A connector, or None to
            leave ``spark.hadoop.fs.s3a.secret.key`` unset.
        master: Spark master URL (default: local mode using all cores).
        hadoop_aws_version: version of ``org.apache.hadoop:hadoop-aws`` pulled
            in through ``spark.jars.packages``. NOTE(review): this presumably
            must match the Hadoop version bundled with the installed PySpark.

    Returns:
        The SparkSession returned by ``getOrCreate()``. NOTE(review): if a
        session already exists in this kernel, static settings such as
        ``spark.jars.packages`` may not take effect — restart the kernel after
        changing them.
    """
    settings = [
        ("spark.jars.packages", f"org.apache.hadoop:hadoop-aws:{hadoop_aws_version}"),
    ]
    # SparkConf stores str(value), so a None credential would be configured as
    # the literal string "None". Only set the keys that were actually provided;
    # otherwise S3A presumably falls back to its default credential providers.
    if AWS_ACCESS_KEY_ID is not None:
        settings.append(("spark.hadoop.fs.s3a.access.key", AWS_ACCESS_KEY_ID))
    if AWS_SECRET_ACCESS_KEY is not None:
        settings.append(("spark.hadoop.fs.s3a.secret.key", AWS_SECRET_ACCESS_KEY))

    conf = SparkConf()
    conf.setAll(settings)

    spark = SparkSession.builder \
        .master(master) \
        .config(conf=conf) \
        .appName(appname) \
        .getOrCreate()
    return spark

In [4]:
# One Spark session for the whole notebook, configured with the S3 settings
# from the config cell.
spark = create_spark(
    appname="data prep",
    AWS_ACCESS_KEY_ID=AWS_ACCESS_KEY_ID,
    AWS_SECRET_ACCESS_KEY=AWS_SECRET_ACCESS_KEY,
)

:: loading settings :: url = jar:file:/home/ec2-user/.local/lib/python3.9/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /home/ec2-user/.ivy2/cache
The jars for the packages stored in: /home/ec2-user/.ivy2/jars
org.apache.hadoop#hadoop-aws added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-f1b32e4b-19e0-4713-8f80-bd5b9cc05b1e;1.0
	confs: [default]
	found org.apache.hadoop#hadoop-aws;3.3.4 in central
	found com.amazonaws#aws-java-sdk-bundle;1.12.262 in central
	found org.wildfly.openssl#wildfly-openssl;1.0.7.Final in central
:: resolution report :: resolve 286ms :: artifacts dl 8ms
	:: modules in use:
	com.amazonaws#aws-java-sdk-bundle;1.12.262 from central in [default]
	org.apache.hadoop#hadoop-aws;3.3.4 from central in [default]
	org.wildfly.openssl#wildfly-openssl;1.0.7.Final from central in [default]
	---------------------------------------------------------------------
	|                  |            modules            ||   artifacts   |
	|       conf       | number| search|dwnlded|evicted|| number|dwnlded|
	-----------------------------

In [5]:
# Load the prompt dataset; its 'Prompt' column is parsed further below.
input_path = f"s3a://{s3_bucket}/reddit_tifu_llm.parquet"
df = spark.read.parquet(input_path)

25/04/15 22:28:51 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-s3a-file-system.properties,hadoop-metrics2.properties
                                                                                

In [6]:
# Peek at one full, untruncated prompt to see how its sections are laid out.
df.show(n=1, truncate=False)

                                                                                

+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|Prompt                                                                                                                                                                                                                                                                                                                                                                 |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [7]:
def extract_sections(prompt):
    r"""Split one raw prompt into its story text and its title.

    The prompt is expected to look like ``"Text: <story> \nTitle: <title> ####..."``,
    which is the layout implied by the separators used below.

    Args:
        prompt: the raw prompt string, or None (a null Spark value).

    Returns:
        ``{"text": <story>, "title": <title>}``, or None when ``prompt`` is
        None or lacks the ``" \nTitle: "`` separator. A malformed row thus
        becomes a null map instead of failing the whole Spark job.
    """
    if prompt is None:
        return None

    parts = prompt.split(" \nTitle: ")
    if len(parts) < 2:
        return None

    # Strip only the leading "Text: " label; a blanket str.replace would also
    # delete any "Text: " that happens to occur inside the story itself.
    text = parts[0]
    label = "Text: "
    if text.startswith(label):
        text = text[len(label):]

    # The title runs up to the " ####" marker (or to the end if it is absent).
    title = parts[1].split(" ####")[0]
    return {
        "text": text,
        "title": title,
    }

In [8]:
# Wrap the pure-Python parser as a Spark UDF producing a map<string,string>
# column (keys 'text' and 'title'), then preview one parsed row.
extract_sections_udf = udf(extract_sections, returnType=MapType(StringType(), StringType()))
annotated_df = df.withColumn(
    "AnnotatedSections",
    extract_sections_udf(col("Prompt")),
)
annotated_df.select("AnnotatedSections").show(n=1, truncate=False)

[Stage 2:>                                                          (0 + 1) / 1]

+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|AnnotatedSections                                                                                                                                                                                                                                                                                                                                                      |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

                                                                                

In [9]:
# Tokenizer shared by every section; the checkpoint id lives in one place.
tokenizer_checkpoint = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)

In [10]:
def tokens_to_ids_attention_mask(prompt, max_length=512):
    """Tokenize the 'title' and 'text' sections of one annotated prompt.

    Each section is truncated/padded to exactly ``max_length`` tokens using the
    module-level ``tokenizer``. Intended to be wrapped in a Spark UDF whose
    struct schema lists the six returned fields in the same order.

    Args:
        prompt: mapping with at least the keys 'title' and 'text' (the shape
            returned by ``extract_sections``), or None for a null row.
        max_length: fixed sequence length for every section (default 512).

    Returns:
        ``(title_ids, text_ids, title_mask, text_mask, title_tokens,
        text_tokens)`` where, per section, the three lists have equal length;
        or None when ``prompt`` is None.

    Raises:
        KeyError: if ``prompt`` lacks 'title' or 'text'.
    """
    if prompt is None:
        return None

    # padding='max_length' needs a pad token and the tokenizer may not have one.
    # NOTE(review): add_special_tokens registers '[PAD]' as a new token, so its
    # id presumably lies past the model's original vocabulary; the model's
    # embeddings would then need resizing before training — confirm.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    token_ids, attention_masks, tokens = {}, {}, {}
    for section in ("title", "text"):
        # Without return_tensors the tokenizer returns plain Python lists, so
        # no per-row torch tensor is built just to be converted back.
        encoded = tokenizer(prompt[section], truncation=True,
                            padding='max_length', max_length=max_length)
        token_ids[section] = list(encoded['input_ids'])
        attention_masks[section] = list(encoded['attention_mask'])
        # Decode one id at a time so tokens[i] lines up with token_ids[i].
        tokens[section] = [tokenizer.decode(token_id) for token_id in token_ids[section]]
        assert len(token_ids[section]) == len(attention_masks[section]) == len(tokens[section])

    return (token_ids['title'], token_ids['text'],
            attention_masks['title'], attention_masks['text'],
            tokens['title'], tokens['text'])

In [11]:
# Return type of tokens_to_ids_attention_mask: one field per element of the
# 6-tuple it returns, in the same order.
schema = StructType([
    StructField("title_token_ids", ArrayType(IntegerType()), False),
    StructField("text_token_ids", ArrayType(IntegerType()), False),
    StructField("title_attention_mask", ArrayType(IntegerType()), False),
    StructField("text_attention_mask", ArrayType(IntegerType()), False),
    # The UDF returns *lists* of decoded token strings. Declared as plain
    # StringType, Spark stored each list's Java toString() (e.g.
    # "[<s>, forget, ...]"), which loses token boundaries; keep them as arrays.
    StructField("title_tokens", ArrayType(StringType()), False),
    StructField("text_tokens", ArrayType(StringType()), False),
])

In [12]:
# Spark UDF yielding one struct per row, typed by `schema`.
tokens_to_ids_attention_mask_udf = udf(tokens_to_ids_attention_mask, returnType=schema)

In [13]:
# Tokenize both sections of every row into a single struct column.
df_with_ids_masks = annotated_df.withColumn(
    "ids_masks_tokens",
    tokens_to_ids_attention_mask_udf(col("AnnotatedSections")),
)

In [14]:
# Promote each struct field to its own top-level column (raw prompt and parsed
# map are dropped).
id_mask_token_fields = [
    "title_token_ids", "text_token_ids",
    "title_attention_mask", "text_attention_mask",
    "title_tokens", "text_tokens",
]
df_final = df_with_ids_masks.select(
    *[f"ids_masks_tokens.{field}" for field in id_mask_token_fields]
)

In [15]:
# Truncated preview keeps the six wide columns readable.
df_final.show(n=1, truncate=True)

[Stage 3:>                                                          (0 + 1) / 1]

+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|     title_token_ids|      text_token_ids|title_attention_mask| text_attention_mask|        title_tokens|         text_tokens|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|[1, 9566, 1259, 3...|[1, 306, 471, 373...|[1, 1, 1, 1, 1, 1...|[1, 1, 1, 1, 1, 1...|[<s>, forget, tin...|[<s>, I, was, on,...|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
only showing top 1 row



                                                                                

In [15]:
# Total number of rows in the final tokenized dataset.
row_count = df_final.count()
row_count

79949