In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql import SparkSession

In [3]:
from shared.paths import DatasetPath

DS = DatasetPath('icews05-15')

In [4]:
spark = (SparkSession.builder
         .appName(str(DS))
         .config('spark.sql.legacy.timeParserPolicy', 'LEGACY')
         .config("spark.executor.memory", "8g")
         .config("spark.driver.memory", "8g")
         .config("spark.memory.offHeap.enabled", True)
         .config("spark.memory.offHeap.size", "16g")
         .getOrCreate())

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/04/06 11:22:45 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [5]:
schema = T.StructType([
    T.StructField('sub', T.StringType(), False),
    T.StructField('pred', T.StringType(), False),
    T.StructField('obj', T.StringType(), False),
    T.StructField('time', T.StringType(), True),
])

df = (
    spark.read.csv([
        DS.raw_str('icews_2005-2015_train.txt'),
        DS.raw_str('icews_2005-2015_test.txt'),
        DS.raw_str('icews_2005-2015_valid.txt'),
    ], sep='\t', header=False, schema=schema)
)
df.head(5)

                                                                                

[Row(sub='Media Personnel (Pakistan)', pred='Make statement', obj='Chaudhry Nisar Ali Khan', time='2013-11-06'),
 Row(sub='William Ruto', pred='Make a visit', obj='The Hague', time='2013-02-13'),
 Row(sub='Catherine Ashton', pred='Express intent to meet or negotiate', obj='Grigol Vashadze', time='2010-07-14'),
 Row(sub='Ronnie Shikapwasha', pred='Make statement', obj='Michael Sata', time='2009-03-16'),
 Row(sub='Nuri al-Maliki', pred='Criticize or denounce', obj='Iraq', time='2011-11-16')]

In [6]:
print('Total Amount of predicates:' + str(df.groupby('pred').count().count()))
preds = df.groupby('pred').count().filter(F.col('count') > 100)
print('Selected Amount of predicates:' + str(preds.count()))
preds.sort('count', ascending=False).show(5)

                                                                                

Total Amount of predicates:251
Selected Amount of predicates:123
+--------------------+-----+
|                pred|count|
+--------------------+-----+
|      Make statement|76184|
|             Consult|49306|
|Express intent to...|30576|
|Make an appeal or...|26661|
|        Make a visit|24422|
+--------------------+-----+
only showing top 5 rows



In [7]:
df_raw_triples = (
    df.select(
        F.col('sub'),
        F.col('pred'),
        F.col('obj'),
        F.col('time').cast(T.TimestampType()),
        F.input_file_name().contains('train').alias('train'),
        F.input_file_name().contains('test').alias('test'),
        F.input_file_name().contains('valid').alias('valid'),
    )
        .join(preds.select(F.col('pred').alias('other__pred')), F.col('pred') == F.col('other__pred'), 'inner')
        .drop('other__pred')
        .sort('train', 'valid', 'test', ascending=False)
)
print('Sanity Check Pred count: ' + str(df_raw_triples.groupby('pred').count().count()))
df_raw_triples.head(5)

Sanity Check Pred count: 123


                                                                                

[Row(sub='Member of Parliament (India)', pred='Make statement', obj='Planning Commission (India)', time=datetime.datetime(2014, 8, 17, 0, 0), train=True, test=False, valid=False),
 Row(sub='Cabinet / Council of Ministers / Advisors (United States)', pred='Praise or endorse', obj='China', time=datetime.datetime(2009, 9, 17, 0, 0), train=True, test=False, valid=False),
 Row(sub='Citizen (Nigeria)', pred='Appeal for diplomatic cooperation (such as policy support)', obj='Media (Nigeria)', time=datetime.datetime(2015, 8, 17, 0, 0), train=True, test=False, valid=False),
 Row(sub='Citizen (Australia)', pred='Demonstrate or rally', obj='Police (Australia)', time=datetime.datetime(2014, 9, 18, 0, 0), train=True, test=False, valid=False),
 Row(sub='China', pred='Make statement', obj='Vietnam', time=datetime.datetime(2014, 6, 19, 0, 0), train=True, test=False, valid=False)]

In [8]:
df_node_entities = (
    df_raw_triples.select(F.col('sub').alias('name'))
        .union(df_raw_triples.select(F.col('obj').alias('name')))
        .distinct()
        .coalesce(1)
        .withColumn('id', F.monotonically_increasing_id())
)
print('Entity Count: ' + str(df_node_entities.count()))
df_node_entities.show(5)

Entity Count: 10463
+--------------------+---+
|                name| id|
+--------------------+---+
|Media Personnel (...|  0|
|Emmanuel Eweta Ud...|  1|
|      Moeletsi Mbeki|  2|
|Insurgent (Afghan...|  3|
|Christian (Indone...|  4|
+--------------------+---+
only showing top 5 rows



In [9]:
import re

@F.udf(T.StringType(), returnType=T.StringType())
def string_to_identifier(s):
    # Remove invalid characters
    s = re.sub('[^0-9a-zA-Z_\s]', '', s)

    # Remove leading characters until we find a letter or underscore
    s = re.sub('^[^a-zA-Z_]+', '', s)

    return s.title().replace(' ', '')

In [10]:
df_all_edges = (
    df_raw_triples
        .select(
            F.col('sub'),
            F.col('obj'),
            string_to_identifier(F.col('pred')).alias('type'),
            F.unix_timestamp('time').alias('timestamp_from'),
            F.unix_timestamp('time').alias('timestamp_to'),
            F.col('train').alias('train'),
            F.col('test').alias('test'),
            F.col('valid').alias('valid'),
        )
        .join(df_node_entities.withColumnRenamed('id', 'src'), F.col('sub') == F.col('name'))
        .drop('name')
        .join(df_node_entities.withColumnRenamed('id', 'dst'), F.col('obj') == F.col('name'))
        .drop('name')
    .drop('sub', 'obj')

)
print('Edge Count: ' + str(df_all_edges.count()))
df_all_edges.show(5)

                                                                                

Edge Count: 457514




+---------------+--------------+------------+-----+-----+-----+---+----+
|           type|timestamp_from|timestamp_to|train| test|valid|src| dst|
+---------------+--------------+------------+-----+-----+-----+---+----+
|  MakeStatement|    1383433200|  1383433200| true|false|false|  0|7871|
|PraiseOrEndorse|    1377986400|  1377986400| true|false|false|  1|4342|
|  MakeStatement|    1348783200|  1348783200| true|false|false|  1| 785|
|        Consult|    1214604000|  1214604000| true|false|false|  2| 393|
|          Yield|    1267398000|  1267398000| true|false|false|  1|5324|
+---------------+--------------+------------+-----+-----+-----+---+----+
only showing top 5 rows



                                                                                

In [11]:
df_node_entities.write.parquet(DS.processed_str('node__Entity'), mode='overwrite')

df_all_edges.write.parquet(DS.processed_str('edge__Entity_Rel_Entity'), mode='overwrite')

                                                                                