In [1]:
from google.colab import drive

# Mount Google Drive so the input files under /content/drive/MyDrive/... can be read.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [2]:
from pyspark.sql import SparkSession

# Reuse the active Spark session if there is one, otherwise start a new one.
spark = (
    SparkSession.builder
    .appName('merge and split data')
    .getOrCreate()
)

In [3]:
from pyspark.sql.functions import *
import pyspark.sql.functions as F
from pyspark.sql.types import StringType

## 1. Merge the course-teacher, course-field, course-concept and course-school files

### Load and preview the relation files

In [4]:
# Plain-text read: each line becomes one row of the single string column `value`.
# (`header`/`inferSchema` are CSV-only options and have no effect on .text(), so they are omitted.)
course_teacher = spark.read.text("/content/drive/MyDrive/Big Data/Output/course_teacher.txt")
course_teacher.show(5)

+-----------+
|      value|
+-----------+
|1158 0 7266|
|1158 0 7157|
|1159 0 7375|
|1160 0 7906|
|1161 0 9365|
+-----------+
only showing top 5 rows



In [5]:
# Plain-text read into a single string column `value` (one row per line).
# CSV-only options (`header`, `inferSchema`) are ignored by .text(), so they are omitted.
course_field = spark.read.text('/content/drive/MyDrive/Big Data/Output/course_field.txt')
course_field.show(5)

+------------+
|       value|
+------------+
|2491 1 13023|
|1925 1 13023|
|1924 1 13023|
|1768 1 13023|
|1419 1 13023|
+------------+
only showing top 5 rows



In [6]:
# Plain-text read into a single string column `value` (one row per line).
# CSV-only options (`header`, `inferSchema`) are ignored by .text(), so they are omitted.
course_concept = spark.read.text('/content/drive/MyDrive/Big Data/Output/course_concept.txt')
course_concept.show(5)

+------------+
|       value|
+------------+
|2339 2 13101|
|1171 2 13101|
|2387 2 13101|
|1795 2 13102|
|2343 2 13102|
+------------+
only showing top 5 rows



In [7]:
# Plain-text read into a single string column `value` (one row per line).
# CSV-only options (`header`, `inferSchema`) are ignored by .text(), so they are omitted.
course_school = spark.read.text('/content/drive/MyDrive/Big Data/Output/course_school.txt')
course_school.show(5)

+-------------+
|        value|
+-------------+
|1158 3 229197|
|1159 3 229197|
|1160 3 229197|
|1161 3 229528|
|1162 3 229197|
+-------------+
only showing top 5 rows



In [8]:
# Stack the four single-column (`value`) frames by position; duplicates are kept.
all_df = (
    course_field
    .union(course_teacher)
    .union(course_concept)
    .union(course_school)
)
all_df.show(5)

+------------+
|       value|
+------------+
|2491 1 13023|
|1925 1 13023|
|1924 1 13023|
|1768 1 13023|
|1419 1 13023|
+------------+
only showing top 5 rows



In [9]:
# Row count of each input and of their union. Each count is computed once and reused.
part_sizes = {
    "course_teacher": course_teacher.count(),
    "course_field": course_field.count(),
    "course_concept": course_concept.count(),
    "course_school": course_school.count(),
}
total_size = all_df.count()

for part_name, part_size in part_sizes.items():
    print(f"Size of {part_name}: {part_size}")
print(f"Size of {' + '.join(part_sizes)}: {total_size}")

# union() keeps duplicates, so the merged size must equal the sum of the parts.
# Summed by hand: the star import of pyspark.sql.functions shadows the builtin `sum`.
expected_total = 0
for part_size in part_sizes.values():
    expected_total += part_size
assert total_size == expected_total, f"union size {total_size} != sum of parts {expected_total}"

Size of course_teacher: 33249
Size of course_field: 576
Size of course_concept: 369142
Size of course_school: 3089
Size of course_teacher + course_field + course_concept + course_school: 406056


In [10]:
# Write the merged lines as a single text part file into the relative directory
# `course_all` (resolved against the working directory, not the Drive input folder).
(
    all_df.select("value")
    .coalesce(1)
    .write
    .mode("overwrite")
    .text("course_all")
)

### Filter rare entities and relations

In [11]:
# Plain-text read of the knowledge-graph triplets: one "h r t" line per row in column `value`.
# CSV-only options (`header`, `inferSchema`) are ignored by .text(), so they are omitted.
triplets = spark.read.text('/content/drive/MyDrive/Big Data/Output/kg_initial.txt')
triplets.show(5)

+-----------+
|      value|
+-----------+
|2492 1 3149|
|1926 1 3149|
|1925 1 3149|
|1769 1 3149|
|1420 1 3149|
+-----------+
only showing top 5 rows



In [12]:
# Split each space-separated "h r t" line into three string columns and drop the raw `value`.
# NOTE(review): reassigning `triplets` makes this cell non-idempotent (re-running fails
# because `value` no longer exists).
triplets = triplets.select(
    *[split(col('value'), ' ')[idx].alias(name) for idx, name in enumerate(('h', 'r', 't'))]
)
triplets.show(5)

+----+---+----+
|   h|  r|   t|
+----+---+----+
|2492|  1|3149|
|1926|  1|3149|
|1925|  1|3149|
|1769|  1|3149|
|1420|  1|3149|
+----+---+----+
only showing top 5 rows



In [13]:
# Lookup table from relation id `r` to a readable relation name, used to label
# the per-relation statistics printed during filtering.
course_type = {
    0: 'course.teacher',
    1: 'course.field',
    2: 'course.concept',
    3: 'course.school'
}

course_type_df = spark.createDataFrame(list(course_type.items()), ["r", "relation_name"])

course_type_df.show()

+---+--------------+
|  r| relation_name|
+---+--------------+
|  0|course.teacher|
|  1|  course.field|
|  2|course.concept|
|  3| course.school|
+---+--------------+



In [14]:
def _print_relation_stats(triplets, relation_names_df, header):
    """Print per-relation triplet counts and distinct-tail counts, labelled by relation name.

    Args:
        triplets: DataFrame with at least columns 'r' and 't'.
        relation_names_df: DataFrame mapping 'r' -> 'relation_name'. It is left-joined,
            so relations missing from it show a null name.
        header: Banner line printed before the triplet-count table.
    """
    print(header)
    # Number of triplets per relation, largest first.
    (
        triplets.groupBy("r").count()
        .join(relation_names_df, on="r", how="left")
        .select("relation_name", "count")
        .orderBy(col("count").desc())
        .show(truncate=False)
    )

    print("===== # of attribute type in each relation =====")
    # Number of distinct tail entities per relation (a plain groupBy aggregation).
    (
        triplets.groupBy("r")
        .agg(countDistinct("t").alias("unique_t_count"))
        .join(relation_names_df, on="r", how="left")
        .select("relation_name", "unique_t_count")
        .orderBy("unique_t_count", ascending=False)
        .show(truncate=False)
    )


def filter_invalid_relations_and_entities(triplets, min_entities=5, min_rel=25,
                                          relation_names_df=None):
    """Iteratively drop rare tail entities and rare relations until the triplet count stops changing.

    Each round first keeps only tail entities ('t') that occur in at least
    `min_entities` triplets, then keeps only relations ('r') that still have at
    least `min_rel` triplets. Rounds repeat until a round removes nothing.
    Statistics are printed before and after filtering.

    Args:
        triplets: DataFrame with columns 'h', 'r', 't'.
        min_entities: Minimum number of triplets a tail entity must appear in.
        min_rel: Minimum number of triplets a relation must have.
        relation_names_df: Optional DataFrame mapping 'r' -> 'relation_name', used
            only for the printed statistics. Defaults to the notebook-level
            `course_type_df`.

    Returns:
        The filtered DataFrame. Because of the joins its column order is
        ('r', 't', 'h'); reselect columns if a specific order is needed.
    """
    if relation_names_df is None:
        # Backward-compatible default: the notebook-level relation lookup table.
        relation_names_df = course_type_df

    current_size = triplets.count()
    print(f"Old size: {current_size}")
    _print_relation_stats(triplets, relation_names_df,
                          "================ Before Filtering ===============")

    old_size = -1  # sentinel so the loop runs at least once
    while old_size != current_size:
        old_size = current_size

        # Keep tail entities with enough triplets. Counts are taken over all relations
        # together -- NOTE(review): this presumes tail ids do not collide across
        # relation types; confirm for kg_initial.txt.
        valid_entities = (
            triplets.groupBy("t").count()
            .filter(col("count") >= min_entities)
            .select("t")
        )
        triplets = triplets.join(valid_entities, on="t", how="inner")

        # Then keep relations that still have enough triplets.
        valid_rels = (
            triplets.groupBy("r").count()
            .filter(col("count") >= min_rel)
            .select("r")
        )
        triplets = triplets.join(valid_rels, on="r", how="inner")

        # NOTE: Spark is lazy, so each count() re-executes the whole (growing) join
        # lineage; consider cache()/checkpoint if many rounds are expected.
        current_size = triplets.count()
        print(f"New size: {current_size}")

    _print_relation_stats(triplets, relation_names_df,
                          "================ Valid relations ===============")
    return triplets

In [15]:
# Thresholds spelled out explicitly (they equal the function defaults).
fil_triplets = filter_invalid_relations_and_entities(triplets, min_entities=5, min_rel=25)

Old size: 406056
+--------------+------+
|relation_name |count |
+--------------+------+
|course.concept|369142|
|course.teacher|33249 |
|course.school |3089  |
|course.field  |576   |
+--------------+------+

===== # of attribute type in each relation =====
+--------------+--------------+
|relation_name |unique_t_count|
+--------------+--------------+
|course.concept|214837        |
|course.teacher|9427          |
|course.school |419           |
|course.field  |78            |
+--------------+--------------+

New size: 84314
New size: 84314
+--------------+-----+
|relation_name |count|
+--------------+-----+
|course.concept|72220|
|course.teacher|9019 |
|course.school |2577 |
|course.field  |498  |
+--------------+-----+

===== # of attribute type in each relation =====
+--------------+--------------+
|relation_name |unique_t_count|
+--------------+--------------+
|course.concept|8163          |
|course.teacher|1322          |
|course.school |155           |
|course.field  |43        

In [16]:
# Preview the first filtered triplets (columns come back as r, t, h after the joins).
fil_triplets.show(n=5)

+---+----+----+
|  r|   t|   h|
+---+----+----+
|  1|3149|2492|
|  1|3149|1926|
|  1|3149|1925|
|  1|3149|1769|
|  1|3149|1420|
+---+----+----+
only showing top 5 rows



In [17]:
# Restore the (h, r, t) column order -- the filtering joins put the join keys first --
# then write one CSV part file with a header into the relative directory `fil_triplets`.
fil_triplets = fil_triplets.select('h', 'r', 't')
(
    fil_triplets.coalesce(1)
    .write
    .mode("overwrite")
    .option("header", True)
    .csv("fil_triplets")
)

## 2. Split user_course data into train / validation / test sets

In [19]:
# One row per user; `course_order` holds that user's courses as one comma-separated
# string (parsed into an array in the next cell).
user_course = spark.read.options(header=True, inferSchema=True).csv('/content/drive/MyDrive/Big Data/Output/user_course.csv')
# Course lists are very long; preview a few rows with values truncated instead of
# dumping full rows.
user_course.show(5, truncate=100)

+-----+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|user |course_order                                                                                                                                                                                                                                                                                                                    |
+-----+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|0    |2345,1

In [20]:
# Parse `course_order` from "c1,c2,..." into array<int> and make `user` an int.
# NOTE(review): tokens that fail the int cast (e.g. an empty token from a trailing
# comma) become null elements; reassigning `user_course` also makes this cell
# non-idempotent (re-running would try to split an array).
user_course = user_course.select(
    col("user").cast("int").alias("user"),
    split(col("course_order"), ",").cast("array<int>").alias("course_order"),
)
# Preview only a few rows, with truncated values.
user_course.show(5, truncate=100)

+-----+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|user |course_order                                                                                                                                                                                                                                                                                                                                                                                       |
+-----+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [21]:
# Confirm the casts from the previous cell: `user` integer, `course_order` array of integers.
user_course.printSchema()

root
 |-- user: integer (nullable = true)
 |-- course_order: array (nullable = true)
 |    |-- element: integer (containsNull = true)



In [22]:
# Split data
TRAIN_PATH = 'train_data'
VAL_PATH = 'val_data'
TEST_PATH = 'test_data'

In [23]:
def split_course_data(df, output_path="./"):
    """Split each user's course array into train / validation / test sets (leave-last-out).

    For every user, the last element of ``course_order`` goes to test, the
    second-to-last to validation, and all elements before those two to train.

    Args:
        df: PySpark DataFrame with a 'user' column and an array column 'course_order'.
        output_path: Unused -- this function writes nothing. Kept only so existing
            calls that pass it keep working; save the returned frames yourself.

    Returns:
        Tuple ``(train_df, val_df, test_df)``:
          * train_df -- columns ('user', 'course_list'); only users with more than 2 courses.
          * val_df -- columns ('user', 'course_id'); only users with at least 2 courses
            whose second-to-last element is non-null.
          * test_df -- columns ('user', 'course_id'); only users whose last element is non-null.
    """
    # `size`, `slice`, `element_at`, `when` are the pyspark.sql.functions versions
    # brought in by the star import (they shadow e.g. the builtin `slice`).
    courses = col("course_order")
    n_courses = size(courses)

    # Test: the last course. Guarded by size >= 1 so an empty array yields null (and
    # is filtered out) rather than an out-of-bounds error when ANSI mode is on.
    test_df = df.select(
        col("user"),
        when(n_courses >= 1, element_at(courses, -1)).alias("course_id")
    ).filter(col("course_id").isNotNull())

    # Validation: the second-to-last course, only when there are at least two.
    val_df = df.select(
        col("user"),
        when(n_courses >= 2, element_at(courses, -2)).alias("course_id")
    ).filter(col("course_id").isNotNull())

    # Train: every course except the last two (slice is 1-based: start=1, length=n-2).
    train_df = df.select(
        col("user"),
        when(n_courses > 2, slice(courses, 1, n_courses - 2)).alias("course_list")
    ).filter(col("course_list").isNotNull() & (size(col("course_list")) > 0))

    return train_df, val_df, test_df

In [24]:
# Leave-last-out split of every user's course sequence.
train_df, val_df, test_df = split_course_data(df=user_course)

In [25]:
# Training course lists are long arrays; preview a few users with truncated values.
train_df.show(5, truncate=100)

+-----+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|user |course_list                                                                                                                                                                                                                                                                                                                                                                             |
+-----+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [26]:
# Validation target: one (user, second-to-last course) pair per user.
val_df.show(n=20, truncate=False)

+-----+---------+
|user |course_id|
+-----+---------+
|0    |578      |
|10   |1241     |
|100  |2339     |
|1000 |2788     |
|10000|2976     |
|10003|1042     |
|10005|948      |
|10006|1077     |
|10007|1077     |
|10008|370      |
|10009|929      |
|1001 |2664     |
|10011|2361     |
|10013|538      |
|10014|573      |
|10016|1706     |
|10017|1860     |
|10018|129      |
|10019|1115     |
|10020|2931     |
+-----+---------+
only showing top 20 rows



In [27]:
# Test target: one (user, last course) pair per user.
test_df.show(n=20, truncate=False)

+-----+---------+
|user |course_id|
+-----+---------+
|0    |184      |
|10   |2338     |
|100  |1348     |
|1000 |1786     |
|10000|3116     |
|10003|817      |
|10005|2670     |
|10006|2783     |
|10007|1010     |
|10008|1086     |
|10009|1038     |
|1001 |2200     |
|10011|11       |
|10013|2667     |
|10014|1058     |
|10016|2481     |
|10017|2174     |
|10018|761      |
|10019|692      |
|10020|2776     |
+-----+---------+
only showing top 20 rows



In [28]:
# Save files
# Save each split as space-separated text lines: "<user> <course> <course> ..." for
# train and "<user> <course_id>" for val/test; coalesce(1) gives one part file per dir.
# Casts are explicit because concat_ws takes string / array<string> inputs. Relying on
# implicit int -> string coercion is fragile (ANSI type-coercion rules can reject it).
train_df.select(
    concat_ws(" ", col("user").cast("string"), col("course_list").cast("array<string>"))
).coalesce(1).write.mode("overwrite").text(TRAIN_PATH)

val_df.select(
    concat_ws(" ", col("user").cast("string"), col("course_id").cast("string"))
).coalesce(1).write.mode("overwrite").text(VAL_PATH)

test_df.select(
    concat_ws(" ", col("user").cast("string"), col("course_id").cast("string"))
).coalesce(1).write.mode("overwrite").text(TEST_PATH)