In [50]:
import ast
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()
from sklearn.cluster import KMeans

In [14]:
def converter(in_str):
    """Parse a serialized 1-D numeric vector such as ``"[0.1 0.2 0.3]"``.

    Used as a ``pd.read_csv`` converter for the ``mean_topic_dist`` column.
    Accepts numpy-style (whitespace separated, possibly wrapped across lines)
    and list-style (comma separated) text, with or without surrounding
    square brackets.

    Args:
        in_str: the raw cell text.

    Returns:
        A 1-D float ``np.ndarray`` (empty for ``""`` or ``"[]"``).

    Raises:
        ValueError: if any token is not a valid float. ``np.fromstring`` would
            instead return the partially parsed prefix with only a warning.
    """
    body = in_str.strip()
    # Only strip brackets that are actually present, instead of blindly
    # dropping the first and last characters.
    if body.startswith("[") and body.endswith("]"):
        body = body[1:-1]
    # Treat commas as whitespace so "[1 2]" and "[1, 2]" parse identically.
    return np.array(body.replace(",", " ").split(), dtype=float)

In [89]:
def show_similar_posts(post_df, target_post_idx, print_columns, neighbor_count=5):
    """Print a post and its nearest neighbors in embedding space.

    Embeddings are L2-normalized first, so ranking by squared Euclidean
    distance is equivalent to ranking by cosine similarity (d = 2 - 2*cos).

    Args:
        post_df: DataFrame with an ``"embedding"`` column whose cells are
            equal-length numeric sequences, plus every column in ``print_columns``.
        target_post_idx: positional (``iloc``) index of the query post.
        print_columns: column names printed for the target and each neighbor.
        neighbor_count: number of neighbors to show (default 5).
    """
    reps = np.stack(post_df["embedding"].to_numpy())
    norms = np.linalg.norm(reps, axis=1, keepdims=True)
    # Leave all-zero embeddings as zeros instead of turning them into NaN.
    reps = reps / np.where(norms == 0, 1.0, norms)
    # Resolve negative / numpy-int indices to a plain row position.
    target_pos = np.arange(len(reps))[target_post_idx]
    dists = ((reps - reps[target_pos]) ** 2).sum(axis=1)
    # Stable sort gives deterministic tie-breaking. Exclude the target
    # explicitly: with duplicate embeddings several rows tie at distance 0,
    # so the target is not guaranteed to sort first.
    order = np.argsort(dists, kind="stable")
    nearest = order[order != target_pos][:neighbor_count]

    print("target post {}".format(target_post_idx))
    for col in print_columns:
        print(col, post_df.iloc[target_post_idx][col])
    print()
    print("nearest neighbors:")
    for rank, j in enumerate(nearest, start=1):
        print("Neighbor #", rank, "Post #: ", j, "%.3f" % dists[j])
        for col in print_columns:
            print(col, post_df.iloc[j][col])
            print()
        print()

In [95]:
def get_clusters(post_df, print_columns, cluster_count=10, ex_in_cluster=3, random_state=None):
    """Fit k-means on post embeddings and print the posts closest to each center.

    Args:
        post_df: DataFrame with an ``"embedding"`` column of equal-length
            numeric sequences plus every column in ``print_columns``.
        print_columns: columns printed for each example post.
        cluster_count: number of k-means clusters to fit (default 10).
        ex_in_cluster: maximum number of examples printed per cluster;
            clusters with fewer members print all of them.
        random_state: forwarded to ``KMeans`` for reproducible clustering.
            The default ``None`` keeps the unseeded behavior.

    Returns:
        The fitted ``KMeans`` estimator.
    """
    reps = np.stack(post_df["embedding"].to_numpy())
    # Honor cluster_count (previously hard-coded to 10 while the loop below
    # iterated over cluster_count, so the two could disagree).
    k_means = KMeans(n_clusters=cluster_count, random_state=random_state).fit(reps)
    reps_pred = k_means.predict(reps)
    # Distance from every post to the center of its own assigned cluster.
    reps_min_dist = k_means.transform(reps)[np.arange(len(reps)), reps_pred]
    for cluster in range(cluster_count):
        member_idxs = np.flatnonzero(reps_pred == cluster)
        # Closest first; ties broken by row position. Slicing (rather than
        # indexing up to ex_in_cluster) tolerates small clusters.
        closest = sorted(zip(reps_min_dist[member_idxs], member_idxs))[:ex_in_cluster]
        print("cluster num {}".format(cluster))
        print("examples:")
        for dist, ex_idx in closest:
            print("dist to cluster center", dist)
            for col in print_columns:
                print(col, post_df.iloc[ex_idx][col])
            print()
    return k_means

In [97]:
def get_cluster_stats(post_df, stat_columns, k_means):
    """Print per-cluster mean, std (``np.std``, ddof=0) and median of columns.

    Args:
        post_df: DataFrame with an ``"embedding"`` column plus the ``stat_columns``.
        stat_columns: names of the columns to summarize.
        k_means: a fitted estimator exposing ``predict`` and
            ``cluster_centers_``, e.g. the value returned by ``get_clusters``.
    """
    labels = k_means.predict(np.stack(post_df["embedding"].to_numpy()))
    summaries = (
        ("mean {}: {}", np.mean),
        ("std {}: {}", np.std),
        ("median {}: {}", np.median),
    )
    for cluster in range(len(k_means.cluster_centers_)):
        member_rows = np.flatnonzero(labels == cluster)
        print("cluster num {}".format(cluster))
        print("stats:")
        for col in stat_columns:
            values = post_df.iloc[member_rows][col].to_numpy()
            for template, reducer in summaries:
                print(template.format(col, reducer(values)))

In [15]:
# Posts with per-post comment-topic features; mean_topic_dist is stored as text
# and parsed into a numeric vector by `converter`.
full_df = pd.read_csv(
    "data/all_posts_with_comment_topics.csv",
    index_col=0,
    converters={"mean_topic_dist": converter},
)

In [16]:
full_df.head(n=5)

Unnamed: 0,id,created_utc,author,author_fullname,author_flair_text,url,title,selftext,upvote_ratio,score,num_comments,data_split,mean_topic_dist,max_mean_topic,mode_max_topic,max_topic_sample
0,2t6afv,1421852000.0,AmarilloByMorning,t2_eka0c,"TTC#1 | 3 MCs, 1CP",https://www.reddit.com/r/ttcafterloss/comments...,A stranger made my night last night.,Last night I stopped at the grocery store to d...,0.88,11,2,val,"[0.02735043, 0.05982906, 0.02735043, 0.0324786...",28.0,28.0,28.0
3,2t5c8b,1421824000.0,PineappleSuppository,t2_ggn01,,https://www.reddit.com/r/ttcafterloss/comments...,Hello Everyone!,I have been mostly lurking and occasionally po...,0.91,8,10,train,"[0.02874407, 0.04055525, 0.02914665, 0.0311595...",16.0,16.0,29.0
7,2t2ksx,1421776000.0,angthrice,t2_bf1iq,"TTC#1 since May 2014, MC Dec 2014",https://www.reddit.com/r/ttcafterloss/comments...,AF here ! Never been more relieved!,I just got Af! Cramps are horrible and went fr...,1.0,8,13,train,"[0.03086178, 0.03109718, 0.03706493, 0.031568,...",29.0,29.0,29.0
10,2t0gqy,1421725000.0,gamingmamaftw,t2_jaw46,,https://www.reddit.com/r/ttcafterloss/comments...,Need some advice/totally lost :(,"Hi ladies,\n\nI am in need of some advice (and...",1.0,4,5,train,"[0.02873563, 0.05555556, 0.05363985, 0.0344827...",1.0,1.0,1.0
12,2sy3ig,1421685000.0,Hippopotamuscles,t2_g2obg,"James 11/14, blighted ovum 06/16 - Infertile.",https://www.reddit.com/r/ttcafterloss/comments...,Depression after miscarriage - NY Times article,I thought that [this](http://parenting.blogs.n...,1.0,6,11,test,"[0.02627167, 0.04744719, 0.02305106, 0.0434214...",9.0,9.0,11.0


# Word2Vec + LSTM Post Embeddings

In [43]:
# Word2Vec + LSTM post embeddings; the "embedding" column is read as text here
# and parsed into Python lists in the next cell.
df = pd.read_csv(
    "data/post_embeddings.csv",
    index_col=0,
)

In [44]:
# ast.literal_eval parses the stored text into a Python list of numbers without
# executing arbitrary code (unlike eval); pass it directly instead of via a lambda.
df["embedding"] = df["embedding"].apply(ast.literal_eval)

In [45]:
# Inner join on post id: only posts present in both the embedding table and
# the topic-feature table are kept.
df = df.merge(full_df, on="id", how="inner")

## Nearest Neighbors

In [48]:
# Pick a random sample of distinct posts whose nearest neighbors are inspected
# below. A seeded Generator makes the sample (and the printed outputs) repeatable
# on Restart & Run All; replace=False avoids inspecting the same post twice.
rng = np.random.default_rng(0)
post_idxs = rng.choice(len(df), size=min(10, len(df)), replace=False)

In [92]:
show_similar_posts(
    df,
    target_post_idx=post_idxs[0],
    print_columns=["title", "selftext", "num_comments"],
)

target post 5796
title Anyone ever had a Matris US prior to FET? What's the deal?!
selftext Hey everyone! I have my 1st FET coming up later this week  I have a Matris US scheduled for Tuesday. I've been pretty calm so far through this FET process, but don't know much about this US, and can't find much about it online. 
I know it looks at the cells in there uterus and rates it 1-10 and anything 7.5 and over means there are cells present that would be receptive to an embryo. If it's less they will cancel the cycle! I think this is why I'm getting a little nervous! I'm freaking that i won't score high enough and they'll cancel my transfer... and I've already faced so many delays, not sure how I'd handle another one. 

Just wondering if anyone else has been required to do this prior to FET and what your results and thoughts on them are?! I asked the nurse at the clinic how often they don't come back good. And she said it was 50-50. Ugh didn't make me feel better!

I like the idea of it, as

## K-Means Clustering

In [98]:
k_means = get_clusters(
    df,
    print_columns=["title", "selftext", "num_comments"],
)

cluster num 0
examples:
dist to cluster center 0.6345985652361779
title Intro & my doctor's appointment yesterday
selftext Been lurking here for a couple weeks, but wanted to make a proper introduction. Husband and I found out we were pregnant last July. We weren't TTC, but we weren't *not* TTC either. Needless to say, it was very unexpected. Ended up MC in August. Have been TTC since then with no luck. Every day, every cycle is a struggle.

Yesterday I had my first appt. with an OBGYN since my D&C over the summer. I recently moved to a completely different state, so this was a doctor I hadn't seen before. She was very nice. I told her about all my various concerns since the MC... ovulating late in my cycle, irregular cycles, super heavy bleeding and clots, severe cramps, etc.

I'm getting my progesterone levels tested next month. Very happy about that. Interestingly, she said that having bad cramps during your period is actually a good sign of having sufficient progesterone.

I feel l

In [99]:
get_cluster_stats(df, stat_columns=["num_comments"], k_means=k_means)

cluster num 0
stats:
mean num_comments: 14.280296784830998
std num_comments: 19.088864497651315
median num_comments: 10.0
cluster num 1
stats:
mean num_comments: 12.816520467836257
std num_comments: 13.33023392258619
median num_comments: 9.0
cluster num 2
stats:
mean num_comments: 12.589216944801027
std num_comments: 12.718667580965057
median num_comments: 10.0
cluster num 3
stats:
mean num_comments: 14.527876631079478
std num_comments: 24.4524399755568
median num_comments: 10.0
cluster num 4
stats:
mean num_comments: 14.592557251908397
std num_comments: 17.369882187306874
median num_comments: 10.0
cluster num 5
stats:
mean num_comments: 12.525735294117647
std num_comments: 11.790377142264006
median num_comments: 9.0
cluster num 6
stats:
mean num_comments: 13.160342717258262
std num_comments: 13.629742618563021
median num_comments: 10.0
cluster num 7
stats:
mean num_comments: 13.361875637104994
std num_comments: 16.65762856310614
median num_comments: 9.0
cluster num 8
stats:
mean num_c