This notebook transform conversation logs into a set of Q&A pairs. A pair is a list of consecutive messages from two different users. The Q (Question) is made of N consecutive message from an user X, the A (Answer) part is made of M consecutive message from an user Y.
Note that Q is not necessarily a Question, but just a prompt for a reply which makes sense with the A part outside the context, so a chatbot can be trained solely on the Q&A pairs.

In [1]:
import psycopg2
conn = psycopg2.connect(user="postgres", password="mysecretpassword", host="localhost")
# distinct connection to perform insertions while iterating over the SELECT results
insertion_conn = psycopg2.connect(user="postgres", password="mysecretpassword", host="localhost")
# should be UTF-8
print('connection encoding, used as default:', conn.encoding)

connection encoding, used as default: UTF8


In [2]:
cur = conn.cursor()
cur.execute("CREATE TABLE IF NOT EXISTS qa_pair (question TEXT, answer TEXT, context TEXT);")
# cur.execute("TRUNCATE TABLE qa_pair;")
cur.close()
conn.commit()

Use a named cursor `window_read_cursor` to read the conversation by channel sort by date, so that consecutive messages in the same channel come together.

The named cursor is necessary to iterate over an important amount of data, which would not easily fit into memory

In [3]:
from collections import deque
from itertools import takewhile, islice
import time
import random

named_cur = conn.cursor('window_read_cursor')
insertion_cur = insertion_conn.cursor()


# table schema:
# chat_messages (message TEXT, message_time TIMESTAMP, username TEXT, channel TEXT)
print("about to run query")
named_cur.execute("SELECT message, message_time, username, channel FROM chat_messages ORDER BY channel ASC, message_time ASC")


start_time = time.time()

window = deque()
window_size = 6
# keep track of the latest channel to detect when the conversation has changed
latest_channel = ''
print("iterating over query results")
total = 0
for record in named_cur:
    
    if len(window) > window_size:
        window.popleft()
    if record[3] != latest_channel:
        window.clear()
        latest_channel = record[3]
    window.append(record)
    # now the window is ready, can we extract a Q&A pair ?
    # do we have more than one utterance? At least one as (Q)uestion/prompt, another as (A)nswer and a third to know A wasn't multi-message
    if len(window) < 3:
        continue
    # first, check that we are not in the middle of an utterance made of consecutive messages from the same user
    if window[-1][2] == window[-2][2]:
        continue
    
    A_user = window[-2][2]
    window_list = list(reversed(window))
    
    # now we know there's an A, check there's a Q
    
    if not any(map(lambda r: r[2] != A_user, window_list[2:])):
        continue
    # extract the consecutive messages to form Q and A
    answer = list(takewhile(lambda r: r[2] == A_user, window_list[1:]))
    # take the last user writing a message in the window before A
    Q_user = window_list[len(answer) + 1][2]
    prompt = list(takewhile(lambda r: r[2] == Q_user, window_list[len(answer) + 1:]))
    if False:
        print('\n\nprompt:')
        for r in prompt:
            print(r[2], ':', r[0])
        print('answer:')
        for r in answer:
            print(r[2], ':', r[0])
        if random.randint(1,1000) == 4:
            break
    total += 1
    q_text = ""
    a_text = ""
    for r in reversed(prompt):
        q_text += '\n' + r[0]
    if len(q_text) < 10 or len(q_text) > 1000:
        continue
    for r in reversed(answer):
        a_text += '\n' + r[0]
    if len(a_text) > 1000 or len(a_text) < 2:
        continue
    context = r[3] + " " + str(r[1])
    insertion_cur.execute("INSERT INTO qa_pair (question, answer, context) VALUES (%s, %s, %s)", (q_text, a_text, context))
    if total % 2000 == 0:
        print(f'extracted {total} QA pairs, committing...')
        insertion_conn.commit()
    
conn.commit()
insertion_cur.close()


elapsed_time = round(time.time() - start_time)
print(f'{total} pairs inserted in {elapsed_time} seconds')

about to run query
iterating over query results
extracted 2000 QA pairs, committing...
extracted 4000 QA pairs, committing...
extracted 6000 QA pairs, committing...
extracted 8000 QA pairs, committing...
extracted 10000 QA pairs, committing...
extracted 12000 QA pairs, committing...
extracted 14000 QA pairs, committing...
extracted 16000 QA pairs, committing...
extracted 18000 QA pairs, committing...
extracted 20000 QA pairs, committing...
extracted 22000 QA pairs, committing...
extracted 24000 QA pairs, committing...
extracted 26000 QA pairs, committing...
extracted 28000 QA pairs, committing...
extracted 30000 QA pairs, committing...
extracted 32000 QA pairs, committing...
extracted 34000 QA pairs, committing...
extracted 36000 QA pairs, committing...
extracted 38000 QA pairs, committing...
extracted 40000 QA pairs, committing...
extracted 42000 QA pairs, committing...
extracted 44000 QA pairs, committing...
extracted 46000 QA pairs, committing...
extracted 50000 QA pairs, committing