In [None]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SQLContext, Row, SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, FloatType, ArrayType
from pyspark.sql.functions import collect_set
from argparse import ArgumentParser
import os
import json
import pandas as pd
from pyspark.ml.feature import NGram, CountVectorizer, Tokenizer
from pyspark.ml.linalg import SparseVector
import numpy as np
import re
from pyspark import StorageLevel

In [1]:
# Acquire the Spark session explicitly instead of relying on kernel-injected
# globals; getOrCreate() returns the already-running session when the PySpark
# kernel has started one, so behaviour there is unchanged.
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

bucket='gs://uga-dsp'
minInitialCount=2  # NOTE(review): not used in any cell shown here — confirm or remove
bytesDir=f"{bucket}/project1/data/bytes/"
asmDir=f"{bucket}/project1/data/asm/"
filesDir=f"{bucket}/project1/files/"
dest='gs://micky-practicum/'

In [None]:
def hexGen(i):
    """Return the upper-case, two-character hex label for byte value ``i``.

    For 0 <= i <= 255 this is the zero-padded code '00'..'FF'; for larger
    values only the last two hex digits are kept.
    """
    digits = hex(i).upper()[2:]
    return digits.rjust(2, '0')[-2:]

In [None]:
def _assembleSchema():
    """Build the ordered list of count-column names.

    Returns 514 names: '00'..'FF' and '??', followed by the same 257 codes
    prefixed with 'rel_'.
    """
    base_codes = [hexGen(i) for i in range(256)] + ['??']
    relative_codes = ['rel_' + code for code in base_codes]
    return base_codes + relative_codes

In [None]:
# Column names from _assembleSchema, wrapped in a broadcast variable;
# buildSingleSchema reads them back through _schema.value.
_schema = sc.broadcast(value=_assembleSchema())

In [None]:
def buildSingleSchema():
    """Return the Spark schema for one per-file count row.

    The row is a string 'hash' column followed by 257 LongType columns named
    after the first 257 entries of ``_schema.value`` ('00'..'FF', '??').
    """
    column_names = _schema.value
    fields = [StructField('hash', StringType())]
    for i in range(257):
        fields.append(StructField(column_names[i], LongType()))
    return StructType(fields)

In [None]:
# Read every file under bytesDir as a single (path, full text) record and wrap
# the pairs in a two-column DataFrame.
schema = StructType([
    StructField('file_name', StringType()),
    StructField('contents', StringType()),
])
rdd = sc.wholeTextFiles(bytesDir, minPartitions=45)
df = spark.createDataFrame(rdd, schema)

In [None]:
# Split each file's text into tokens in a new "words" column (Spark ML's
# Tokenizer lowercases the text and splits on whitespace).
tokenizer = Tokenizer(outputCol="words", inputCol="contents")
tokenized = tokenizer.transform(dataset=df)

In [None]:
# Reduce each row to (hash, byte tokens). The hash is the word immediately
# before ".bytes" in the file path. Only 2-character tokens (hex bytes and '??')
# are kept; presumably this drops the longer per-line address tokens — confirm
# against the data format.
rdd2=tokenized.drop('contents').rdd
# Raw string: in a plain literal '\w' and '\.' are invalid escape sequences
# (DeprecationWarning, SyntaxWarning on Python 3.12+).
pattern=sc.broadcast(r'\w+(?=\.bytes)')
# findall(...)[0] raises IndexError for any path without a ".bytes" suffix.
rdd3=rdd2.map(lambda x: (re.findall(pattern.value, x[0])[0],[y for y in x[1] if len(y)==2]))

schema2=StructType([StructField('hash',StringType()),StructField('words',ArrayType(StringType()))])
tokenized2=spark.createDataFrame(rdd3,schema2)
# Persisted because it is consumed twice: by cv.fit and by model.transform.
tokenized2.persist(StorageLevel.MEMORY_AND_DISK)

In [None]:
# Learn the token vocabulary across all files, then turn each file's token
# list into a sparse count vector ("word_count") indexed by that vocabulary.
cv = CountVectorizer(outputCol="word_count", inputCol="words")
model = cv.fit(dataset=tokenized2)
result = model.transform(dataset=tokenized2)

In [None]:
# `result` is a lazy transform of `tokenized2`. Clearing the whole cache before
# `result` has been computed (the previous spark.catalog.clearCache()) dropped
# tokenized2's persisted data, so the write below had to re-read and
# re-tokenize every file.
# Instead: cache `result`, materialize it while tokenized2 is still cached,
# then release tokenized2.
result.cache()
result.count()
tokenized2.unpersist()

In [None]:
# Save one row per file (hash, sparse count vector) so later cells can reload
# the counts without re-tokenizing.
(result
    .select("hash", "word_count")
    .write
    .mode('overwrite')
    .save(dest + "X_train_pre.parquet"))
print('finished')

In [None]:
# Reload the (hash, word_count) table written above. Uses the SparkSession
# entry point like the rest of the notebook; SQLContext is the legacy API and
# `sqlContext` was used only here.
intermediate = spark.read.load(dest+'X_train_pre.parquet')

In [None]:
def code_to_index(code, verbose=False):
    """Map a byte token to its position in the fixed 257-column count layout.

    Hex byte tokens ('00'..'ff', either case) map to 0..255 and the
    unreadable-byte marker '??' maps to 256.

    Parameters
    ----------
    code : str
        Token such as '8d', 'FF' or '??'. Only the last two characters are
        parsed as hex.
    verbose : bool, default False
        If True, print ``code:index``. This was previously printed
        unconditionally (once per vocabulary entry).

    Returns
    -------
    int
        Column index in the range 0..256 for well-formed tokens.

    Raises
    ------
    ValueError
        If ``code`` is not '??' and its last two characters are not valid hex.
    """
    index = 256 if code == '??' else int(code[-2:], 16)
    if verbose:
        print(f'{code}:{index}')
    return index

In [None]:
# decimal_vocab[pos] is the byte code (0..256) at vocabulary position pos;
# vocab_index[code] is the inverse: where that code's count sits in the
# CountVectorizer output vector.
decimal_vocab=[code_to_index(code) for code in model.vocabulary]
_vocab_position={code_idx: pos for pos, code_idx in enumerate(decimal_vocab)}
# Every later cell assumes all 257 codes are present. Fail here with a clear
# message instead of a ValueError/IndexError further down.
_missing_codes=sorted(set(range(257)).difference(_vocab_position))
assert not _missing_codes, f'vocabulary is missing byte codes: {_missing_codes}'
vocab_index=sc.broadcast([_vocab_position[i] for i in range(len(decimal_vocab))])

In [None]:
set_name='X_small_test'
# Hashes belonging to this split, one per line of the split file. A set makes
# the per-row `in` test in the filter O(1) instead of a list scan. Stray
# whitespace and blank lines are dropped so they cannot cause silent mismatches.
indicator_file=sc.broadcast({h.strip() for h in sc.textFile(f'{filesDir}{set_name}.txt').collect() if h.strip()})

In [None]:
# Keep only this split's files, then lay each file's counts out as
# [hash, count('00'), ..., count('ff'), count('??')].
# Filtering first avoids unpacking vectors for rows that would be discarded.
# toArray() plus fancy indexing with vocab_index replaces 257 individual
# SparseVector lookups per row.
# NOTE(review): rows follow intermediate's order, not the order of hashes in
# the split file; join on `hash` downstream if that order matters.
formatted=(intermediate.rdd
           .filter(lambda x: x[0] in indicator_file.value)
           .map(lambda x: [x[0]] + [int(c) for c in x[1].toArray()[vocab_index.value]]))
df_formatted=spark.createDataFrame(formatted,schema=buildSingleSchema())

In [None]:
# Save this split's per-file count table under counts/<set_name>.parquet.
(df_formatted.write
    .mode('overwrite')
    .save(f"{dest}counts/{set_name}.parquet"))