# Dataset Customization

*Modify the filters which questions are chosen for the dataset.*


In [1]:
import sqlite3
import pandas as pd
import logging
import pickle

In [2]:
logging.basicConfig()
logging.getLogger().setLevel(logging.INFO)
log = logging.getLogger("training-set-builder")

Connect to database, set number of questions to retreive

In [3]:
# Create your connection.
db = sqlite3.connect(r'C:\Users\liamb\Documents\graph4stackoverflow\stackoverflow.db')
cursor = db.cursor()

In [4]:
# SMALL: 10000
# LARGE: 70000
QUESTIONS_RETREIVED = 70000

# Following Xu et al. (2018);
1. Retreive all questions with the python tag
2. Questions must have an accepted answer
3. Questions must have at least 4 answers

In [5]:
# OLD VERSION:
# valid_questions = pd.read_sql_query(f"""
#         SELECT Q.PostId, Q.Body, Q.Title, Q.OwnerUserId FROM Post Q
#         INNER JOIN Post A ON Q.PostId = A.ParentId
#         WHERE (Q.Tags LIKE '%<python>%')
#         GROUP BY A.ParentId
#         HAVING SUM(A.Score) > 15
#         LIMIT {QUESTIONS_RETREIVED}
# """, db)
# valid_questions.columns = ['post_id', 'question_body', 'question_title', 'question_user_id']
# valid_questions


In [6]:
# NEW VERSION:
"""
# Following Xu et al. (2018);
1. Retreive all questions with the python tag
2. Questions must have an accepted answer
3. Questions must have at least 4 answers
"""

valid_questions = pd.read_sql_query(f"""
        SELECT Q.PostId, Q.Body, Q.Title, Q.OwnerUserId FROM Post Q
        INNER JOIN Post A ON Q.PostId = A.ParentId
        WHERE (Q.Tags LIKE '%<python>%') and (Q.AcceptedAnswerId IS NOT NULL) AND Q.AnswerCount >= 4
        LIMIT {QUESTIONS_RETREIVED}
""", db)
valid_questions.columns = ['post_id', 'question_body', 'question_title', 'question_user_id']
valid_questions


Unnamed: 0,post_id,question_body,question_title,question_user_id
0,337,<p>I am about to build a piece of a project th...,XML Processing in Python,111.0
1,337,<p>I am about to build a piece of a project th...,XML Processing in Python,111.0
2,337,<p>I am about to build a piece of a project th...,XML Processing in Python,111.0
3,337,<p>I am about to build a piece of a project th...,XML Processing in Python,111.0
4,337,<p>I am about to build a piece of a project th...,XML Processing in Python,111.0
...,...,...,...,...
69995,3761124,<p>I have a dictionary with either a integer o...,Finding maximum value in a dictionary containi...,441337.0
69996,3761124,<p>I have a dictionary with either a integer o...,Finding maximum value in a dictionary containi...,441337.0
69997,3761124,<p>I have a dictionary with either a integer o...,Finding maximum value in a dictionary containi...,441337.0
69998,3761124,<p>I have a dictionary with either a integer o...,Finding maximum value in a dictionary containi...,441337.0


In [7]:
"Number of unique users in training set: {}".format(valid_questions['question_user_id'].nunique())

'Number of unique users in training set: 4553'

## Optional: Reduce database to just the training set

- Only include users who are needed
- Only include posts which are needed
- Only include comments which are needed
- Only include badges which are needed

This is particulary useful when training on an external server where the full database is too large to transfer.

In [8]:
# Get target question ids
target_question_ids = valid_questions['post_id'].tolist()
target_question_ids

[337,
 337,
 337,
 337,
 337,
 337,
 337,
 337,
 337,
 337,
 337,
 337,
 469,
 469,
 469,
 469,
 535,
 535,
 535,
 535,
 535,
 535,
 535,
 683,
 683,
 683,
 683,
 683,
 683,
 683,
 683,
 742,
 742,
 742,
 742,
 742,
 742,
 742,
 742,
 742,
 766,
 766,
 766,
 766,
 766,
 766,
 773,
 773,
 773,
 773,
 773,
 773,
 773,
 773,
 773,
 773,
 773,
 773,
 773,
 773,
 972,
 972,
 972,
 972,
 972,
 972,
 972,
 972,
 972,
 972,
 972,
 972,
 972,
 972,
 972,
 972,
 972,
 972,
 972,
 1171,
 1171,
 1171,
 1171,
 1171,
 1171,
 1171,
 1476,
 1476,
 1476,
 1476,
 1476,
 1476,
 1476,
 1476,
 1734,
 1734,
 1734,
 1734,
 1829,
 1829,
 1829,
 1829,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1854,
 1983,
 1983,
 1983,
 1983,
 1983,
 1983,
 1983,
 1983,
 1983,
 1983,
 1983,
 1983,
 1983,
 1983,
 1983,
 1983,
 1983,
 1983,
 2933,
 2933,
 2933,
 2933,
 2933,
 2933,
 2933,
 2933,
 

In [9]:
# Get comment ids from target questions
target_question_comment_ids = pd.read_sql_query(f"""
        SELECT C.CommentId FROM Comment C
        INNER JOIN Post P ON C.PostId = P.PostId
        WHERE P.PostId IN ({','.join([str(x) for x in target_question_ids])})
""", db)['CommentId'].tolist()
target_question_comment_ids

[91959108,
 101289352,
 95886610,
 100148644,
 100153092,
 100274880,
 11143825,
 56074548,
 122445392,
 123977389,
 10878993,
 14212856,
 66426829,
 66784385,
 84768515,
 121073638,
 20573549,
 20573550,
 1710850,
 1713035,
 71404286,
 52400945,
 88822720,
 96175951,
 223296,
 3406081,
 45935119,
 56986982,
 126902510,
 59624877,
 66759015,
 104336703,
 59157900,
 74388769,
 85205636,
 103439080,
 121194397,
 128666478,
 15969994,
 68844278,
 121829805,
 1003615,
 14942726,
 40406609,
 51624123,
 51624504,
 51624712,
 51635228,
 51635763,
 51637194,
 53598589,
 53601045,
 88144354,
 128961455,
 22804126,
 51356379,
 4115788,
 19511761,
 59595836,
 102718302,
 119565993,
 798754,
 79219046,
 75414627,
 30275959,
 38538847,
 4093715,
 19787276,
 19787277,
 57781299,
 124929555,
 19723,
 85096162,
 35165095,
 1236449,
 85121837,
 97343204,
 115318596,
 127051533,
 128406992,
 128407014,
 7459421,
 33788191,
 249184,
 5483402,
 15761265,
 122152689,
 125164810,
 3999455,
 11005215,
 71005

In [10]:
# Get post ids from answers to target questions
target_question_answer_post_ids = pd.read_sql_query(f"""
        SELECT A.PostId FROM Post A
        INNER JOIN Post Q ON A.ParentId = Q.PostId
        WHERE Q.PostId IN ({','.join([str(x) for x in target_question_ids])})
""", db)['PostId'].tolist()
target_question_answer_post_ids

[342,
 471,
 525,
 635,
 69410,
 69772,
 123307,
 199213,
 202259,
 7954780,
 13832269,
 23143835,
 497,
 518,
 3040,
 195170,
 538,
 541,
 660,
 61746,
 74452,
 117712,
 9120453,
 701,
 735,
 745,
 750,
 31126,
 31188,
 57833,
 4905822,
 764,
 4572,
 8320,
 27780,
 27792,
 33957,
 123090,
 2921293,
 26250049,
 777,
 802,
 1619,
 8332,
 4813530,
 33058407,
 783,
 7286,
 37252,
 1573195,
 14443477,
 16427674,
 20013133,
 31660194,
 44617583,
 45431237,
 45873519,
 61048516,
 68091577,
 69783850,
 982,
 2982,
 4600,
 22525,
 959064,
 8961717,
 9041763,
 9636303,
 16240409,
 24748849,
 24865663,
 28060251,
 32076685,
 34404761,
 43703054,
 45341362,
 64950870,
 70662971,
 73486831,
 1174,
 1191,
 3107,
 28705,
 29836,
 4292022,
 34850877,
 1478,
 1479,
 1484,
 13107,
 37226387,
 37955839,
 63297421,
 65057066,
 1780,
 6161,
 123093,
 1418610,
 1840,
 1852,
 1870,
 1885,
 1857,
 1871,
 1879,
 28426,
 3021004,
 7587420,
 7707465,
 14231316,
 14477954,
 14885455,
 15674751,
 25863224,
 26643

In [11]:
# Get comment ids from answers to target questions
target_question_answer_comment_ids = pd.read_sql_query(f"""
        SELECT C.CommentId FROM Comment C
        INNER JOIN Post P ON C.PostId = P.PostId
        WHERE P.PostId IN ({','.join([str(x) for x in target_question_answer_post_ids])})
""", db)['CommentId'].tolist()
target_question_answer_comment_ids

[19687,
 28358,
 1004799,
 2569142,
 5310505,
 5452460,
 9149586,
 112239409,
 62004485,
 73639411,
 86092343,
 110824458,
 37009654,
 127309525,
 1809583,
 78482232,
 68179135,
 75906884,
 83436441,
 95472956,
 95689581,
 101273641,
 107580230,
 112885164,
 112925832,
 117750612,
 47627995,
 75871843,
 84055960,
 85877248,
 105165851,
 112328782,
 6516989,
 77942,
 380033,
 4761427,
 11740045,
 92154367,
 109487682,
 111684189,
 127602520,
 890358,
 909442,
 1460422,
 1462370,
 3312421,
 26524901,
 66379065,
 74807413,
 74807910,
 79414054,
 116673347,
 125980162,
 37017015,
 37627178,
 39510577,
 49930520,
 52358701,
 52455643,
 54869282,
 66400024,
 69699187,
 80708617,
 81488346,
 81488519,
 104825053,
 116473648,
 116835087,
 121612422,
 125651476,
 16083866,
 9138090,
 26573661,
 60359673,
 13509643,
 26441697,
 27731542,
 80139805,
 80145532,
 121930875,
 19826518,
 38337772,
 38346495,
 52239891,
 62979012,
 63956147,
 38085905,
 77809809,
 99953037,
 6867868,
 8563618,
 171110

In [12]:
# Get user ids from answers to target questions
target_question_answer_user_ids = pd.read_sql_query(f"""
        SELECT A.OwnerUserId FROM Post A
        INNER JOIN Post Q ON A.ParentId = Q.PostId
        WHERE Q.PostId IN ({','.join([str(x) for x in target_question_ids])})
""", db)['OwnerUserId'].dropna().tolist()
target_question_answer_user_ids

[59.0,
 147.0,
 154.0,
 188.0,
 11072.0,
 9510.0,
 27642.0,
 21106.0,
 232485.0,
 1527852.0,
 346478.0,
 50.0,
 153.0,
 457.0,
 745.0,
 156.0,
 157.0,
 197.0,
 6372.0,
 8450.0,
 218681.0,
 111.0,
 145.0,
 154.0,
 199.0,
 3119.0,
 2147.0,
 4702.0,
 572606.0,
 612.0,
 1057.0,
 2990.0,
 2384.0,
 3207.0,
 15687.0,
 351981.0,
 3408904.0,
 150.0,
 1384652.0,
 92.0,
 1057.0,
 499257.0,
 546822.0,
 189.0,
 207.0,
 3926.0,
 83284.0,
 650654.0,
 1141493.0,
 2237635.0,
 541136.0,
 7933904.0,
 8137464.0,
 4531270.0,
 9726459.0,
 14558.0,
 7864006.0,
 200.0,
 99.0,
 618.0,
 2482.0,
 110274.0,
 176186.0,
 841337.0,
 650551.0,
 1099876.0,
 1640404.0,
 3748584.0,
 541136.0,
 5010481.0,
 4548106.0,
 3781929.0,
 4475534.0,
 1190453.0,
 7452220.0,
 15096247.0,
 267.0,
 188.0,
 101.0,
 620.0,
 2963.0,
 180962.0,
 2639344.0,
 305.0,
 269.0,
 2089740.0,
 1531.0,
 541136.0,
 6496481.0,
 14033284.0,
 12415637.0,
 77.0,
 758.0,
 20832.0,
 116.0,
 116.0,
 30.0,
 50.0,
 2089740.0,
 216.0,
 116.0,
 3051.0,
 36430

In [13]:
# USER INFO: Get badge ids from users who answered target questions
target_question_answer_badge_ids = pd.read_sql_query(f"""
        SELECT B.UserId FROM Badge B
        WHERE B.UserId IN ({','.join([str(x) for x in target_question_answer_user_ids])})
""", db)['UserId'].to_list()
target_question_answer_badge_ids

[1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,


In [14]:
# USER GRAPH: Get comment ids from users who answered target questions
target_question_answer_user_comment_ids = pd.read_sql_query(f"""
        SELECT C.CommentId FROM Comment C
        WHERE C.UserId IN ({','.join([str(x) for x in target_question_answer_user_ids])})
""", db)['CommentId'].tolist()
target_question_answer_user_comment_ids

[720,
 721,
 723,
 4039,
 4124,
 4376,
 4731,
 4860,
 6440,
 9568,
 10081,
 10083,
 13865,
 16901,
 20623,
 58991,
 61925,
 69443,
 69654,
 69695,
 69759,
 72521,
 73815,
 73924,
 79415,
 84361,
 85447,
 86392,
 98277,
 98298,
 98321,
 135847,
 136397,
 142778,
 158959,
 160657,
 160663,
 160779,
 160782,
 167457,
 168918,
 170664,
 171613,
 173834,
 174558,
 189699,
 204831,
 208139,
 229492,
 233762,
 234013,
 235669,
 237324,
 237352,
 237707,
 248106,
 248135,
 248145,
 248148,
 248189,
 250045,
 250207,
 250415,
 250570,
 250927,
 276561,
 278996,
 283249,
 298632,
 298634,
 306041,
 329402,
 329406,
 338687,
 339538,
 339604,
 362282,
 362464,
 363421,
 367991,
 383072,
 389800,
 406382,
 409969,
 417717,
 422664,
 441303,
 445639,
 454050,
 454052,
 459091,
 459233,
 460695,
 478760,
 507302,
 507629,
 509434,
 521668,
 521723,
 529015,
 530431,
 577205,
 632104,
 682122,
 684290,
 685893,
 688286,
 752042,
 767822,
 767825,
 767846,
 767963,
 770198,
 770207,
 773299,
 773304,


In [15]:
# USER GRAPH: Get post ids from users who answered target questions
target_question_answer_user_post_ids = pd.read_sql_query(f"""
        SELECT P.PostId FROM Post P
        WHERE P.OwnerUserId IN ({','.join([str(x) for x in target_question_answer_user_ids])}) OR P.LastEditorUserId IN ({','.join([str(x) for x in target_question_answer_user_ids])})
""", db)['PostId'].tolist()
target_question_answer_user_post_ids

[9,
 11,
 12,
 19,
 22,
 30,
 31,
 33,
 44,
 48,
 56,
 58,
 65,
 71,
 73,
 76,
 77,
 80,
 85,
 87,
 88,
 92,
 98,
 99,
 103,
 107,
 108,
 109,
 124,
 126,
 127,
 128,
 133,
 134,
 135,
 139,
 141,
 142,
 146,
 148,
 153,
 159,
 167,
 170,
 173,
 174,
 175,
 180,
 194,
 197,
 199,
 206,
 207,
 212,
 229,
 233,
 243,
 246,
 263,
 268,
 269,
 274,
 297,
 298,
 304,
 308,
 328,
 329,
 332,
 335,
 336,
 337,
 339,
 342,
 347,
 354,
 356,
 361,
 364,
 367,
 382,
 384,
 411,
 412,
 427,
 469,
 470,
 471,
 482,
 483,
 497,
 502,
 516,
 518,
 521,
 522,
 525,
 531,
 535,
 537,
 538,
 539,
 540,
 541,
 551,
 561,
 566,
 589,
 590,
 591,
 594,
 595,
 602,
 605,
 607,
 608,
 622,
 623,
 629,
 635,
 651,
 657,
 660,
 665,
 667,
 676,
 679,
 681,
 683,
 691,
 699,
 701,
 704,
 709,
 712,
 713,
 718,
 723,
 735,
 742,
 745,
 750,
 751,
 754,
 761,
 762,
 766,
 768,
 771,
 773,
 774,
 777,
 783,
 791,
 794,
 795,
 797,
 798,
 802,
 805,
 817,
 826,
 829,
 833,
 834,
 840,
 845,
 871,
 873,
 876,
 884,

**Insert into new DB**

In [16]:
NEW_DB_NAME = "..\database\g4so.db"

new_db = sqlite3.connect(NEW_DB_NAME)
new_db_cursor = new_db.cursor()

In [17]:

# Delete old DB if it exists
new_db_cursor.execute(f"""DROP TABLE IF EXISTS post;""")
new_db_cursor.execute(f"""DROP TABLE IF EXISTS comment;""")
new_db_cursor.execute(f"""DROP TABLE IF EXISTS badge;""")
new_db_cursor.execute(f"""DROP TABLE IF EXISTS user;""")
new_db_cursor.execute(f"""DROP TABLE IF EXISTS tag;""")
new_db.commit()

In [18]:
with open("../database/create.sql", "r") as sql_file:
    new_db_cursor.executescript(sql_file.read())
new_db.commit()

In [19]:
# Get list of column names in post table
cursor.execute(f"""SELECT * FROM post LIMIT 1;""")
post_columns = [description[0] for description in cursor.description]

In [20]:
# Attach original DB to new DB
cursor.execute(f"""ATTACH '{NEW_DB_NAME}' AS NEW_DB;""")
db.commit()

In [21]:
# Insert chosen posts
insert_posts_query = f"""INSERT INTO NEW_DB.post({', '.join(post_columns)}) SELECT {', '.join(post_columns)} FROM main.post WHERE post.PostId IN ({', '.join([str(x) for x in target_question_ids + target_question_answer_post_ids + target_question_answer_user_post_ids])})"""
print(insert_posts_query)
cursor.execute(insert_posts_query)
db.commit()

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [22]:
# Get list of column names in comment table
cursor.execute(f"""SELECT * FROM comment LIMIT 1;""")
comment_columns = [description[0] for description in cursor.description]

# Insert chosen comments
insert_comments_query = f"""INSERT INTO NEW_DB.comment({', '.join(comment_columns)}) SELECT {', '.join(comment_columns)} FROM main.comment WHERE comment.CommentId IN ({', '.join([str(x) for x in target_question_comment_ids + target_question_answer_comment_ids + target_question_answer_user_comment_ids])})"""
cursor.execute(insert_comments_query)
db.commit()

In [23]:
# Get list of column names in badge table
cursor.execute(f"""SELECT * FROM badge LIMIT 1;""")
badge_columns = [description[0] for description in cursor.description]

# Insert chosen badges
insert_badges_query = f"""INSERT INTO NEW_DB.badge({', '.join(badge_columns)}) SELECT {', '.join(badge_columns)} FROM main.badge WHERE badge.UserId IN ({', '.join([str(x) for x in target_question_answer_badge_ids])})"""
cursor.execute(insert_badges_query)
db.commit()

In [24]:
# Get list of column names in user table
cursor.execute(f"""SELECT * FROM user LIMIT 1;""")
user_columns = [description[0] for description in cursor.description]

# Insert chosen users
insert_users_query = f"""INSERT INTO NEW_DB.user({', '.join(user_columns)}) SELECT {', '.join(user_columns)} FROM main.user WHERE user.UserId IN ({', '.join([str(x) for x in target_question_answer_user_ids])})"""
cursor.execute(insert_users_query)
db.commit()

In [25]:
# Insert tag table
insert_tags_query = f"""INSERT INTO NEW_DB.tag(TagId, TagName, Count) SELECT TagId, TagName, Count FROM main.tag"""
cursor.execute(insert_tags_query)
db.commit()