<a href="https://colab.research.google.com/github/daveshap/QuestionDetector/blob/main/QuestionDetector.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Compile Training Data
Note: Before running this section, generate the raw sentence data with [this notebook](https://github.com/daveshap/QuestionDetector/blob/main/DownloadGutenbergTop100.ipynb).

In [1]:
import re
import random

# --- Configuration -----------------------------------------------------------
# All paths are under /content/drive, i.e. a mounted Google Drive in Colab.
datafile = '/content/drive/My Drive/Gutenberg/sentence_data.txt'  # raw input sentences
corpusfile = '/content/drive/My Drive/Gutenberg/corpus_data.txt'  # training corpus (output)
testfile = '/content/drive/My Drive/Gutenberg/test_data.txt'  # test set (output)
sample_cnt = 3000  # number of training sentences drawn per label

# --- Load and bucket the raw sentences ---------------------------------------
# Each sentence ends up in exactly one bucket, chosen by the punctuation it
# contains: '?' wins over '!', everything else is 'other'.
questions = []
exclamations = []
other = []

with open(datafile, 'r', encoding='utf-8') as infile:
  body = infile.read()

# Sentences are separated by a blank line in the raw file.
sentences = body.split('\n\n')

for sentence in sentences:
  # Splitting leaves empty/whitespace-only entries (e.g. after a trailing
  # blank line); without this guard they would be labelled 'other' and could
  # be sampled as empty training examples.
  if not sentence.strip():
    continue
  # Drop sentences containing these accented characters
  # (presumably a cheap filter for non-English text -- TODO confirm).
  if 'í' in sentence or 'á' in sentence:
    continue
  if '?' in sentence:
    questions.append(sentence)
  elif '!' in sentence:
    exclamations.append(sentence)
  else:
    other.append(sentence)

def flatten_sentence(text):
  """Lower-case ``text`` and strip everything except word characters and whitespace.

  This removes punctuation (including the '?' and '!' that determine the
  label), so the model cannot read the answer off the sentence. Whitespace,
  including newlines, is kept as-is.

  Args:
    text: the sentence to normalise.

  Returns:
    The lower-cased sentence containing only ``\\w`` and ``\\s`` characters.
  """
  # Raw string: '\w' / '\s' in a normal string literal are invalid escape
  # sequences and trigger a warning on current Python versions.
  return re.sub(r'[^\w\s]', '', text.lower())

test_cnt = 50  # number of held-out test sentences drawn per label


def _format_records(raw_sentences, label):
  """Render sentences as '<|SENTENCE|> ... <|LABEL|> label <|END|>' records, one per paragraph."""
  return ''.join(
      '<|SENTENCE|> %s <|LABEL|> %s <|END|>\n\n' % (flatten_sentence(s), label)
      for s in raw_sentences)


# Seed from OS entropy: every run draws a different sample.
random.seed()

train_parts = []
test_parts = []
for label, pool in (('question', questions),
                    ('exclamation', exclamations),
                    ('other', other)):
  # Draw train and test items in ONE sample without replacement and slice it,
  # so no pool entry can land in both files. (Sampling the two sets
  # independently lets test sentences leak into the training corpus and
  # inflates the measured accuracy.) random.sample returns items in selection
  # order, so each slice is itself a valid random sample.
  # Raises ValueError if the pool holds fewer than sample_cnt + test_cnt items.
  # Note: a sentence that occurs more than once in the pool can still end up
  # on both sides.
  picked = random.sample(pool, sample_cnt + test_cnt)
  train_parts.append(_format_records(picked[:sample_cnt], label))
  test_parts.append(_format_records(picked[sample_cnt:], label))

for path, parts in ((corpusfile, train_parts), (testfile, test_parts)):
  corpus = ''.join(parts)
  with open(path, 'w', encoding='utf-8') as outfile:
    outfile.write(corpus)
  print('Done!', path)

Done! /content/drive/My Drive/Gutenberg/corpus_data.txt
Done! /content/drive/My Drive/Gutenberg/test_data.txt


# Finetune Model
Finetune the GPT-2 124M model on the compiled training corpus using `gpt-2-simple`.

In [None]:
# Environment setup: install the pinned TensorFlow 1.15 GPU build and
# gpt-2-simple (unpinned) before importing the library below.
!pip install tensorflow-gpu==1.15.0 --quiet
!pip install gpt-2-simple --quiet

import gpt_2_simple as gpt2

# note: manually mount your Google Drive in the file explorer to the left;
# every path used below lives under /content/drive.

model_dir = '/content/drive/My Drive/GPT2/models'  # base model files are downloaded here
checkpoint_dir = '/content/drive/My Drive/GPT2/checkpoint'  # passed to finetune() as its checkpoint directory
model_name = '124M'

# Fetch the base model files (checkpoint, encoder, hparams, vocab) into model_dir.
gpt2.download_gpt2(model_name=model_name, model_dir=model_dir)
print('\n\nModel is ready!')

run_name = 'QuestionDetector'
step_cnt = 2000  # number of finetuning steps requested

# The session is created once here and handed to finetune() below.
sess = gpt2.start_tf_sess()

# Finetune on corpusfile, which is defined (and written) by the
# data-compilation cell above; that cell must have been run first.
gpt2.finetune(sess,
              dataset=corpusfile,
              model_name=model_name,
              model_dir=model_dir,
              checkpoint_dir=checkpoint_dir,
              steps=step_cnt,
              restore_from='fresh',  # start from the downloaded base model checkpoint
              #restore_from='latest',  # continue from last work
              run_name=run_name,
              print_every=50,  # log loss every 50 steps
              sample_every=1000,  # presumably: emit a text sample every 1000 steps -- confirm in gpt-2-simple docs
              save_every=1000  # presumably: write a checkpoint every 1000 steps -- confirm in gpt-2-simple docs
              )

Fetching checkpoint: 1.05Mit [00:00, 502Mit/s]                                                      
Fetching encoder.json: 1.05Mit [00:00, 96.3Mit/s]                                                   
Fetching hparams.json: 1.05Mit [00:00, 322Mit/s]                                                    
Fetching model.ckpt.data-00000-of-00001: 498Mit [00:03, 163Mit/s]                                   
Fetching model.ckpt.index: 1.05Mit [00:00, 237Mit/s]                                                
Fetching model.ckpt.meta: 1.05Mit [00:00, 172Mit/s]                                                 
Fetching vocab.bpe: 1.05Mit [00:00, 161Mit/s]                                                       




Model is ready!
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Loading checkpoint /content/drive/My Drive/GPT2/models/124M/model.ckpt
INFO:tensorflow:Restoring parameters from /content/drive/My Drive/GPT2/models/124M/model.ckpt


  0%|          | 0/1 [00:00<?, ?it/s]

Loading dataset...


100%|██████████| 1/1 [00:02<00:00,  2.51s/it]


dataset has 440110 tokens
Training...
[50 | 48.70] loss=2.31 avg=2.31
[100 | 89.53] loss=2.58 avg=2.44
[150 | 130.15] loss=2.51 avg=2.47
[200 | 170.86] loss=2.31 avg=2.43
[250 | 211.55] loss=1.92 avg=2.32
[300 | 252.27] loss=1.76 avg=2.23
[350 | 293.07] loss=1.80 avg=2.16
[400 | 333.74] loss=1.36 avg=2.06
[450 | 374.43] loss=1.47 avg=1.99
[500 | 414.98] loss=1.43 avg=1.93
[550 | 455.25] loss=1.06 avg=1.85
[600 | 495.61] loss=1.00 avg=1.78
[650 | 536.04] loss=1.06 avg=1.72
[700 | 576.33] loss=0.98 avg=1.66
[750 | 616.35] loss=0.83 avg=1.60
[800 | 656.75] loss=0.71 avg=1.54
[850 | 697.28] loss=0.45 avg=1.47
[900 | 737.44] loss=0.53 avg=1.41
[950 | 777.86] loss=0.53 avg=1.36


# Test Results

| Run | Model | Steps | Samples | Last Loss | Avg Loss | Accuracy |
|---|---|---|---|---|---|---|
| 01 | 124M | 2000 | 9000 |  |  |  |



In [None]:
# Evaluate the finetuned model: for every test record, prompt the model with
# the sentence up to and including '<|LABEL|>' and check whether the expected
# label appears in what it generates.
right = 0
wrong = 0

print('Loading test set...')
with open(testfile, 'r', encoding='utf-8') as file:
  raw_test = file.read()

# Split on the record terminator instead of on lines: flatten_sentence keeps
# newlines, so a record may span several lines and a line-based parse would
# hit lines without a '<|LABEL|>' marker (IndexError).
test_set = [record.strip() for record in raw_test.split('<|END|>')]

for t in test_set:
  if '<|LABEL|>' not in t:
    continue  # empty tail after the last '<|END|>' or a malformed record
  sentence_part, _, label_part = t.partition('<|LABEL|>')
  prompt = sentence_part + '<|LABEL|>'
  expect = label_part.strip()
  response = gpt2.generate(sess,
                           return_as_list=True,
                           length=30,  # prevent it from going too crazy
                           prefix=prompt,
                           model_name=model_name,
                           model_dir=model_dir,
                           truncate='\n',  # stop inferring here
                           include_prefix=False,
                           checkpoint_dir=checkpoint_dir)[0]
  response = response.strip()
  if expect in response:
    right += 1
  else:
    wrong += 1
  print('right:', right, '\twrong:', wrong, '\taccuracy:', right / (right+wrong))

total = right + wrong
print('\n\nModel:', model_name)
# 'max_samples' was never defined (NameError). The results table records the
# number of training samples, i.e. sample_cnt sentences for each of 3 labels.
print('Samples:', sample_cnt * 3)
print('Steps:', step_cnt)
print('Accuracy:', right / total if total else 'n/a (no test records found)')