In [None]:
import os
import re
import json
import joblib
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import defaultdict, Counter
import gzip
from glob import glob
from functools import partial
from datetime import datetime
from tqdm import tqdm
from filter_subreddits import filter_subs

In [None]:
# Analyze subreddits: count how many submissions + comments fall into each
# subreddit across all users, and collect the set of user ids seen.
WOLOHAN_DIR = "../emnlp-2020-mental-health-generalization/data/processed/reddit/wolohan"

subreddit_cnts = defaultdict(int)
all_users = set()
# Submissions and comments live in separate per-user files but are counted
# identically here, so iterate over both file kinds instead of duplicating the loop.
for kind in ("submissions", "comments"):
    for fname in tqdm(glob(f"{WOLOHAN_DIR}/*.{kind}.tar.gz")):
        # Each file is read as gzip-compressed JSON holding a list of records.
        with gzip.open(fname, "r") as f:
            data0 = json.load(f)
        if not data0:
            # No records means no user id to read from data0[0]; skip the file
            # rather than crash with IndexError.
            continue
        all_users.add(data0[0]["user_id_str"])
        for record in data0:
            subreddit_cnts[record["subreddit"]] += 1
subreddit_cnts = pd.Series(subreddit_cnts).sort_values(ascending=False)
print(subreddit_cnts)

In [None]:
# Total number of submissions + comments (over all users) posted in r/depression.
subreddit_cnts.loc["depression"]

In [None]:
# For every filter set, report: its size, how many of its subreddits actually
# occur in this data, and how many posts/comments those subreddits account for.
observed_subs = set(subreddit_cnts.index)  # computed once instead of per filter set
for dataset, subs in filter_subs.items():
    intersection = observed_subs & subs
    print(dataset, len(subs), len(intersection))
    matched_total = subreddit_cnts.loc[list(intersection)].sum()
    print(matched_total)

In [None]:
import random

# Seed the global `random` module RNG for reproducibility.
SEED = 2021
random.seed(SEED)

In [None]:
# Assign every user to train/val/test once; the same assignment is reused for
# every filtering strategy below (~10% test, ~10% val, rest train).
MIN_POSTS = 32  # minimum posts a user needs to be written to a split (used below)
TEST_FRAC = 0.1
VAL_FRAC = 0.1
users_split = {"train": [], "val": [], "test": []}
# Dedicated generator seeded with the same value as the notebook-level
# random.seed(2021): this cell now gives the same split however often it is re-run.
split_rng = random.Random(2021)
# Iterate in sorted order: iteration order of a set of str follows string hashes,
# which are randomized per process (PYTHONHASHSEED). Iterating the raw set would
# give a different split on every kernel restart despite the fixed seed.
for user in sorted(all_users):
    draw = split_rng.random()
    if draw < TEST_FRAC:
        users_split["test"].append(user)
    elif draw < TEST_FRAC + VAL_FRAC:
        users_split["val"].append(user)
    else:
        users_split["train"].append(user)

In [None]:
# For each filtering strategy, rebuild every user's post history with posts from
# the strategy's excluded subreddits removed. Then pickle one file per split
# holding [user_posts, labels]:
#   user_posts -- one list per user of post strings, ordered oldest first
#   labels     -- int per user: 1 if the file's "depression" field equals
#                 "depression", else 0; aligned with user_posts
# Other strategies previously run here: 'Depression', 'RSDD'.
FILTER_TYPES = ['All Mental Health']
WOLOHAN_DIR = "../emnlp-2020-mental-health-generalization/data/processed/reddit/wolohan"


def _load_user_file(fname):
    """Read one per-user gzip-compressed JSON file and return its list of records."""
    with gzip.open(fname, "r") as f:
        return json.load(f)


for filter_type in FILTER_TYPES:
    print(filter_type)
    excluded_subs = filter_subs[filter_type]
    out_dir = f"split_filter_{filter_type.replace(' ', '-')}"
    os.makedirs(out_dir, exist_ok=True)
    all_user_submissions = defaultdict(list)
    all_user_submissions_utc = defaultdict(list)
    all_user_comments = defaultdict(list)
    all_user_comments_utc = defaultdict(list)
    all_labels = {}

    for fname in tqdm(glob(f"{WOLOHAN_DIR}/*.submissions.tar.gz")):
        data0 = _load_user_file(fname)
        if not data0:  # no records -> no user id / label to read
            continue
        user = data0[0]["user_id_str"]
        # The label is per user, so record it unconditionally. Setting it only when
        # a submission survives filtering would silently drop users whose remaining
        # history consists of comments only.
        all_labels[user] = int(data0[0]["depression"] == "depression")
        for record in data0:
            if record['subreddit'] in excluded_subs:
                continue
            title = " ".join(record["title_tokenized"])
            text = ' '.join(record["text_tokenized"])
            all_user_submissions[user].append(title + "\n" + text)
            all_user_submissions_utc[user].append(int(record["created_utc"]))

    for fname in tqdm(glob(f"{WOLOHAN_DIR}/*.comments.tar.gz")):
        data0 = _load_user_file(fname)
        if not data0:
            continue
        user = data0[0]["user_id_str"]
        for record in data0:
            if record['subreddit'] in excluded_subs:
                continue
            all_user_comments[user].append(' '.join(record["text_tokenized"]))
            all_user_comments_utc[user].append(int(record["created_utc"]))

    print({k: len(v) for k, v in users_split.items()})
    for split, users in users_split.items():
        user_posts = []
        labels = []
        for user in tqdm(users):
            # Only users with a non-empty submissions file have a label.
            if user not in all_labels:
                continue
            posts0 = all_user_submissions[user] + all_user_comments[user]
            times0 = all_user_submissions_utc[user] + all_user_comments_utc[user]
            if len(posts0) < MIN_POSTS:
                continue
            # Sort by ascending time; sorted() is stable, so on equal timestamps
            # submissions stay ahead of comments.
            sorted_posts = [post for post, _ in sorted(zip(posts0, times0), key=lambda pair: pair[1])]
            # Store the time-ordered posts (the unsorted posts0 was stored before).
            user_posts.append(sorted_posts)
            labels.append(all_labels[user])
        print(split, len(labels))
        # Open the output only once the split is fully built, so a failure above
        # does not leave a truncated pickle behind.
        with open(f"{out_dir}/{split}.pkl", "wb") as f:
            pickle.dump([user_posts, labels], f)