In [1]:
import os
import pandas as pd
import zstandard as zstd
import json
from tqdm import tqdm
import datetime

import sqlite3
from sqlalchemy import create_engine

In [2]:
directory = '/mnt/processed/private/msds-pt2025a/lt6'
reddit_dir = '/mnt/data/public/reddit/'

subreddit_file = 'subreddits/subreddits_basic.csv'
subreddit_file_json = 'subreddits/subreddits.json.gz'

db_name = f'{directory}/database/lab1-test.db'

# Subreddit

In [4]:
# Define Subreddit DataFrame
column_names = [
    'submission_count',
    'subreddit_id',
    'created_utc',
    'display_name',
    'subscriber_count'
]

subreddit_df = pd.read_csv(os.path.join(reddit_dir, subreddit_file),
                           header=None,
                           names=column_names,
                           dtype=str)

subreddit_df = subreddit_df.drop('submission_count', axis=1)
subreddit_df['created_utc'] = pd.to_datetime(
    subreddit_df['created_utc'].astype('Int64'), unit='s')
subreddit_df['subscriber_count'] = subreddit_df['subscriber_count'].astype('Int64')

In [11]:
# Populate Subreddit Table
conn = sqlite3.connect(db_name)
engine = create_engine(f'sqlite:///{db_name}')

subreddit_df.to_sql('reddit_subreddits',
                    con=engine, if_exists='replace', index=False)

conn.close()

In [5]:
conn = sqlite3.connect(db_name)

# Define the SQL query to select the top 10 rows
select_query = 'SELECT count(*) FROM reddit_subreddits;'

# Execute the select query
cursor = conn.cursor()
cursor.execute(select_query)

# Fetch the results
results = cursor.fetchall()

# Close the connection
conn.close()

# Return the results
results

[(1067472,)]

In [13]:
conn = sqlite3.connect(db_name)
cursor = conn.cursor()

column_info_query = "PRAGMA table_info('reddit_subreddits');"
cursor.execute(column_info_query)
column_info = cursor.fetchall()

# Close the connection
conn.close()

# Format and return the results
column_info_list = [{"column_name": col[1], "data_type": col[2]}
                    for col in column_info]
column_info_list

[{'column_name': 'subreddit_id', 'data_type': 'TEXT'},
 {'column_name': 'created_utc', 'data_type': 'DATETIME'},
 {'column_name': 'display_name', 'data_type': 'TEXT'},
 {'column_name': 'subscriber_count', 'data_type': 'BIGINT'}]

# Submissions

## Define Table

In [14]:
columns = {
    'id': 'TEXT PRIMARY KEY',
    'created_utc': 'DATETIME',
    'subreddit': 'TEXT',
    'subreddit_id': 'TEXT',
    'title': 'TEXT',
    'selftext': 'TEXT',
    'score': 'BIGINT',
    'num_comments': 'BIGINT',
    'num_crossposts': 'BIGINT',
    'retrieved_on': 'DATETIME'
}

# Create or connect to the SQLite database
conn = sqlite3.connect(db_name)
cursor = conn.cursor()

cursor.execute(f'''
CREATE TABLE IF NOT EXISTS reddit_submissions (
    {', '.join(f'{col} {col_type}' for col, col_type in columns.items())}
)
''')

conn.commit()
conn.close()

## Extract Data

In [3]:
conn = sqlite3.connect(db_name)
cursor = conn.cursor()

cursor.execute("SELECT * FROM reddit_subreddits WHERE display_name LIKE '%music%'")
# cursor.execute('SELECT * FROM reddit_submissions LIMIT 10')

results = cursor.fetchall()
conn.close()

subreddits = [result[0] for result in results]

In [11]:
def unix_to_datetime(unix_time):
    if unix_time is not None:
        return datetime.datetime.utcfromtimestamp(unix_time)
    return None


# Function to insert a record into the database
def insert_records(records, conn, cursor):
    filtered_records = [record for record in records if record[3] in subreddits]
    cursor.executemany('''
        INSERT OR REPLACE INTO reddit_submissions (id, created_utc, subreddit, subreddit_id, title, selftext, score, num_comments, num_crossposts, retrieved_on)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        ''', filtered_records)


def process_file(file_path, conn, cursor):
    # Create a Zstandard decompressor
    decompressor = zstd.ZstdDecompressor(max_window_size=2147483648)

    # Open the compressed file
    with open(file_path, 'rb') as compressed:
        total_size = os.path.getsize(file_path) * 15
        # Create a progress bar using tqdm
        with tqdm(total=total_size, unit='B', unit_scale=True, desc='Decompressing') as pbar:
            # Create a stream reader for the compressed file
            with decompressor.stream_reader(compressed) as reader:
                buffer = ''
                batch_records = []
                batch_size = 10000  # Number of records to insert at once

                while True:
                    chunk = reader.read(8 * 1024)  # Read in chunks of 512KB
                    if not chunk:
                        break
                    pbar.update(len(chunk))

                    buffer += chunk.decode('utf-8', errors='ignore')
                    while '\n' in buffer:
                        line, buffer = buffer.split('\n', 1)
                        json_data = json.loads(line)

                        # Extract and convert the necessary fields
                        record = (
                            json_data.get('id'),
                            unix_to_datetime(json_data.get('created_utc')),
                            json_data.get('subreddit'),
                            json_data.get('subreddit_id'),
                            json_data.get('title'),
                            json_data.get('selftext'),
                            json_data.get('score'),
                            json_data.get('num_comments'),
                            json_data.get('num_crossposts'),
                            unix_to_datetime(json_data.get('retrieved_on'))
                        )
                        
                        batch_records.append(record)

                        if len(batch_records) == batch_size:
                            # insert_records(batch_records, conn, cursor)
                            batch_records.clear()

    insert_records(batch_records, conn, cursor)

In [9]:
# Loop through files by constructing file names for each month
for year in range(2022, 2023):
    for month in range(1, 13):
        if year == 2022 and month <= 3:
            continue
        else:
            # Skip the months that don't have data according to your range
            if year == 2022 and month > 8:
                break
            file_name = f"submissions/RS_{year}-{month:02d}.zst"
            file_path = os.path.join(reddit_dir, file_name)

            # Check if the file exists before trying to process it
            if os.path.exists(file_path):
                conn = sqlite3.connect(db_name)
                cursor = conn.cursor()

                try:
                    print(f"Processing {file_name}...")
                    process_file(file_path, conn, cursor)
                    conn.commit()
                    conn.close()
                except Exception as e:
                    print(f" An error occured: {e}")
                    conn.rollback()
                    conn.close()

                    raise e  # Stop file
            else:
                print(f"File {file_name} does not exist.")

Processing submissions/RS_2022-04.zst...


Decompressing:  91%|█████████ | 130G/144G [1:13:38<07:41, 29.5MB/s] 


Processing submissions/RS_2022-05.zst...


Decompressing:  92%|█████████▏| 143G/155G [1:18:26<06:53, 30.4MB/s] 


Processing submissions/RS_2022-06.zst...


Decompressing:  92%|█████████▏| 139G/151G [1:25:12<07:06, 27.2MB/s]   


Processing submissions/RS_2022-07.zst...


Decompressing:  92%|█████████▏| 146G/159G [1:26:34<07:53, 28.0MB/s]  


Processing submissions/RS_2022-08.zst...


Decompressing:  96%|█████████▌| 150G/157G [1:27:46<04:05, 28.4MB/s]  


In [5]:
conn = sqlite3.connect(db_name)
cursor = conn.cursor()

#cursor.execute('SELECT count(*) FROM reddit_submissions')
cursor.execute("SELECT count(*) FROM reddit_submissions where created_utc > '2021-09-01'")

results = cursor.fetchall()
conn.close()
results

[(2969458,)]