# Setup

In [None]:
!pip install --upgrade pylatexenc anytree openai

In [14]:
import anytree
from datetime import datetime
import json
import openai
from pylatexenc import latexwalker
from pylatexenc.latexwalker import LatexEnvironmentNode
from pylatexenc.latexwalker import LatexGroupNode
from pylatexenc.latexwalker import LatexMacroNode
from pylatexenc.latexwalker import LatexWalker
from pylatexenc import latex2text
from pylatexenc.latex2text import EnvironmentTextSpec
from pylatexenc.latex2text import fmt_matrix_environment_node
from pylatexenc.latex2text import LatexNodes2Text
from pylatexenc.latex2text import MacroTextSpec
from pylatexenc.latex2text import SpecialsTextSpec
from pylatexenc.macrospec import EnvironmentSpec
from pylatexenc.macrospec import MacroSpec
from pylatexenc.macrospec import MacroStandardArgsParser
from pylatexenc.macrospec import SpecialsSpec
import re
import requests
import subprocess

In [None]:
# Fetch the repository containing the corrected LaTeX source of the
# Tractatus (wittgenstein/tractatus.tex), which the next cell reads.
!git clone https://github.com/maxbbraun/wittgenstein.git

# Data Preparation

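Source: Project Gutenberg eBook #5740, Wittgenstein's *Tractatus Logico-Philosophicus* (LaTeX edition, with corrections): https://www.gutenberg.org/ebooks/5740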

In [15]:
# Load the (corrected) LaTeX version of the Tractatus into memory.
# Decode explicitly as UTF-8 instead of the platform-default encoding, which
# varies between systems; the training file below is also written as UTF-8.
# NOTE(review): assumes tractatus.tex is UTF-8 (or plain ASCII).
with open('wittgenstein/tractatus.tex', encoding='utf-8') as f:
  tractatus_latex = f.read()

In [16]:
# Define parsing rules for turning LaTeX into Unicode text with minimal HTML.
def init_pylatexenc(translate_to_english, include_footnotes=False):
  """Build the LaTeX walker and text converter for the Tractatus source.

  The walker context extends pylatexenc's defaults with the argument
  signatures of the document's custom macros (for example PropositionE and
  PropositionG take two mandatory arguments: number and content). The text
  context maps those macros to Unicode text plus a little HTML (<em>, <i>,
  <sup>, <sub>).

  Args:
    translate_to_english: If true, use English typography and wording
      (English quotation marks, '(a choose b)', '[Figure]', 'Footnote:').
      Otherwise use the German equivalents.
    include_footnotes: If true, footnotes are inlined in parentheses after
      the text that references them. Otherwise they are dropped.

  Returns:
    A (latex_walker, nodes2text) tuple. latex_walker is a LatexWalker over
    the module-level `tractatus_latex` string. nodes2text is a
    LatexNodes2Text converter configured with the rules above.
  """
  walker_context = latexwalker.get_default_latex_context_db()
  walker_context.add_context_category(
      'tractatus',
      prepend=True,
      macros=[
          MacroSpec('binom', '{{'),
          MacroSpec('BookTitle', '{'),
          MacroSpec('discretionary', '{{{'),
          MacroSpec('DPtypo', '{{'),
          MacroSpec('emph', '{'),
          MacroSpec('Emph', '{'),
          MacroSpec('EmphPart', '{'),
          MacroSpec('enlargethispage', '{'),
          MacroSpec('glq', '{'),
          MacroSpec('glqq', '{'),
          MacroSpec('grq', '{'),
          MacroSpec('grqq', '{'),
          MacroSpec('hspace', '{'),
          MacroSpec('Illustration', '[{'),
          MacroSpec('mbox', '{'),
          MacroSpec('Not', '{'),
          MacroSpec('overline', '{'),
          MacroSpec('phantom', '{'),
          MacroSpec('PropERef', '{'),
          MacroSpec('PropGRef', '{'),
          MacroSpec('PropositionE', '{{'),
          MacroSpec('PropositionG', '{{'),
          MacroSpec('raisebox', '{[[{'),
          MacroSpec('smash', '[[{'),
          MacroSpec('text', '{'),
          MacroSpec('textit', '{'),
          MacroSpec('vspace', '{')],
      environments=[
          EnvironmentSpec('array', '[{', is_math_mode=True),
          EnvironmentSpec('split', '', is_math_mode=True),
          EnvironmentSpec('tabular', '[{')],
      specials=[
          SpecialsSpec('`'),
          SpecialsSpec("'"),
          SpecialsSpec('``'),
          SpecialsSpec("''"),
          SpecialsSpec('&'),
          SpecialsSpec('^', args_parser=MacroStandardArgsParser('{')),
          SpecialsSpec('_', args_parser=MacroStandardArgsParser('{'))])
  latex_walker = LatexWalker(tractatus_latex, latex_context=walker_context)

  # Language-dependent replacements. An empty footnote pattern drops footnotes.
  if translate_to_english:
    footnote_pattern = r' (Footnote: %(2)s)' if include_footnotes else ''
    binomial_pattern = r'(%(1)s choose %(2)s)'
    single_open_quote = '‘'
    single_close_quote = '’'
    double_open_quote = '“'
    double_close_quote = '”'
    illustration_placeholder = '[Figure]'
  else:
    footnote_pattern = r' (Fußnote: %(2)s)' if include_footnotes else ''
    binomial_pattern = r'(%(1)s über %(2)s)'
    single_open_quote = '‚'
    single_close_quote = '‘'
    double_open_quote = '„'
    double_close_quote = '“'
    illustration_placeholder = '[Abbildung]'

  # %(N)s in a simplify_repl refers to the N-th macro argument as declared in
  # the walker context above (optional '[' arguments count too).
  nodes2text_context = latex2text.get_default_latex_context_db()
  nodes2text_context.add_context_category(
      'tractatus',
      prepend=True,
      macros=[
          MacroTextSpec('-', discard=True),
          MacroTextSpec('AllowBreak', discard=True),
          MacroTextSpec('BarOp', simplify_repl=' | '),
          MacroTextSpec('binom', simplify_repl=binomial_pattern),
          MacroTextSpec('BookTitle', simplify_repl=r'%s'),
          MacroTextSpec('dasHeiszt', simplify_repl='d. h.'),
          MacroTextSpec('discretionary', simplify_repl=r'%(3)s'),
          MacroTextSpec('DittoInWords', simplify_repl='„'),
          MacroTextSpec('DittoInWorten', simplify_repl='„'),
          MacroTextSpec('DotOp', simplify_repl=' . '),
          MacroTextSpec('DPtypo', simplify_repl=r'%(2)s'),
          MacroTextSpec('emph', simplify_repl=r'<em>%s</em>'),
          MacroTextSpec('Emph', simplify_repl=r'<em>%s</em>'),
          MacroTextSpec('EmphPart', simplify_repl=r'<em>%s</em>'),
          MacroTextSpec('False', simplify_repl='F'),
          MacroTextSpec('enlargethispage', discard=True),
          MacroTextSpec('exempliGratia', simplify_repl='e.g.'),
          MacroTextSpec('ExempliGratia', simplify_repl='E.g.'),
          MacroTextSpec('fivedots', simplify_repl='.....'),
          MacroTextSpec('footnote', simplify_repl=footnote_pattern),
          MacroTextSpec('fourdots', simplify_repl='....'),
          MacroTextSpec('glq', simplify_repl=single_open_quote),
          MacroTextSpec('glqq', simplify_repl=double_open_quote),
          MacroTextSpec('grq', simplify_repl=single_close_quote),
          MacroTextSpec('grqq', simplify_repl=double_close_quote),
          MacroTextSpec('hline', simplify_repl='; '),
          MacroTextSpec('hspace', discard=True),
          MacroTextSpec('idEst', simplify_repl='i.e.'),
          MacroTextSpec('IdEst', simplify_repl='I.e.'),
          MacroTextSpec('Illustration',
                        simplify_repl=illustration_placeholder),
          MacroTextSpec('Implies', simplify_repl='⊃'),
          MacroTextSpec('lor', simplify_repl='v'),
          MacroTextSpec('mbox', simplify_repl='%s'),
          MacroTextSpec('Not', simplify_repl='~%s'),
          # Append a combining overline (U+0305) to the argument.
          MacroTextSpec('overline', simplify_repl='%s\u0305'),
          MacroTextSpec('phantom', discard=True),
          MacroTextSpec('PropERef', simplify_repl=r'%s'),
          MacroTextSpec('PropGRef', simplify_repl=r'%s'),
          MacroTextSpec('raisebox', simplify_repl=r'%(4)s'),
          MacroTextSpec('smash', simplify_repl=r'%(3)s'),
          MacroTextSpec('stretchyspace', discard=True),
          MacroTextSpec('text', simplify_repl=r'%s'),
          # The closing tag must be '</i>'. It was '<i>', which left every
          # italic span unclosed in the generated HTML.
          MacroTextSpec('textit', simplify_repl=r'<i>%s</i>'),
          MacroTextSpec('undAndere', simplify_repl='u. a.'),
          MacroTextSpec('undSoFort', simplify_repl='u. s. f.'),
          MacroTextSpec('UndSoWeiter', simplify_repl='U. s. w.'),
          MacroTextSpec('verystretchyspace', discard=True),
          MacroTextSpec('vspace', discard=True),
          MacroTextSpec('Wahr', simplify_repl='W'),
          MacroTextSpec('zumBeispiel', simplify_repl='z. B.'),
          MacroTextSpec('ZumBeispiel', simplify_repl='Z. B.')],
      environments=[
          EnvironmentTextSpec('array',
                              simplify_repl=fmt_matrix_environment_node),
          EnvironmentTextSpec('split',
                              simplify_repl=fmt_matrix_environment_node),
          EnvironmentTextSpec('tabular',
                              simplify_repl=fmt_matrix_environment_node)],
      specials=[
          SpecialsTextSpec('`', single_open_quote),
          SpecialsTextSpec("'", single_close_quote),
          SpecialsTextSpec('``', double_open_quote),
          SpecialsTextSpec("''", double_close_quote),
          SpecialsTextSpec('&', '|'),
          SpecialsTextSpec('^', simplify_repl='<sup>%s</sup>'),
          SpecialsTextSpec('_', simplify_repl='<sub>%s</sub>')])
  nodes2text = LatexNodes2Text(latex_context=nodes2text_context,
                               math_mode='text',
                               keep_braced_groups=False)

  return latex_walker, nodes2text

In [17]:
# Find the propositions in the structured text.
def gather_propositions(translate_to_english, nodes=None, latex_walker=None,
                        nodes2text=None):
  """Extract (number, content) pairs for every proposition in the text.

  Walks the LaTeX node tree depth-first, descending into groups and
  environments. The two arguments of each PropositionE macro (English) or
  PropositionG macro (German) are converted to text.

  Args:
    translate_to_english: Selects the English macro and typography when
      true, and the German ones otherwise.
    nodes: Nodes to search. When None (the default), the whole document is
      parsed using a fresh walker and converter from init_pylatexenc().
    latex_walker: Walker that produced `nodes`. Required when `nodes` is
      given.
    nodes2text: Converter used to render the macro arguments. Required when
      `nodes` is given.

  Returns:
    A list of (number, content) string tuples in document order. Whitespace
    runs in the content are collapsed to single spaces.
  """
  # Only parse the full document when no nodes were passed at all. An
  # explicitly passed empty list simply contains no propositions; the old
  # truthiness check re-parsed the entire document in that case.
  if nodes is None:
    latex_walker, nodes2text = init_pylatexenc(
        translate_to_english=translate_to_english)
    nodes, _, _ = latex_walker.get_latex_nodes()
  assert latex_walker
  assert nodes2text

  # Pick the proposition macro based on the selected language.
  if translate_to_english:
    proposition_macro = 'PropositionE'
  else:
    proposition_macro = 'PropositionG'

  # Walk the nodes tree recursively and extract all propositions.
  propositions = []
  for node in nodes:
    if node.isNodeType(LatexGroupNode) or node.isNodeType(LatexEnvironmentNode):
      # Recurse into groups of nodes.
      child_nodes = node.nodelist
      if child_nodes:
        child_propositions = gather_propositions(
            translate_to_english=translate_to_english, nodes=child_nodes,
            latex_walker=latex_walker, nodes2text=nodes2text)
        propositions.extend(child_propositions)
    elif (node.isNodeType(LatexMacroNode) and
          node.macroname == proposition_macro):
      # Expect the proposition's number and content as the node arguments.
      proposition_nodes = node.nodeargd.argnlist
      assert len(proposition_nodes) == 2
      number_node = proposition_nodes[0]
      content_node = proposition_nodes[1]

      # Convert the number and content to text.
      number = nodes2text.node_to_text(number_node)
      content = nodes2text.node_to_text(content_node)

      # Remove any repeated and trailing whitespace.
      content = re.sub(r'\s+', ' ', content).strip()

      propositions.append((number, content))

  return propositions

In [215]:
# Extract the propositions in both languages (German first, then English)
# and pair them up by position as ((number, german), (number, english)).
german_propositions, english_propositions = (
    gather_propositions(translate_to_english=to_english)
    for to_english in (False, True))
assert len(english_propositions) == len(german_propositions)
propositions = [*zip(german_propositions, english_propositions)]

In [None]:
# Put the propositions into a tree structure.
def parent_proposition_number(number):
  """Return the number of the proposition that `number` comments on.

  Drops the last digit, then any trailing zeros, then a dangling decimal
  point. For example '2.0121' -> '2.012', '2.01' -> '2', '1.1' -> '1', and
  '1' -> '' (the root).
  """
  parent_number = number[:-1].rstrip('0')
  if parent_number.endswith('.'):
    parent_number = parent_number[:-1]
  return parent_number


propositions_tree = anytree.AnyNode(number='')
# Index of nodes by number. This gives O(1) parent lookup instead of a
# full-tree search per proposition.
proposition_nodes = {propositions_tree.number: propositions_tree}
for german_proposition, english_proposition in propositions:
  number, german_content = german_proposition
  english_number, english_content = english_proposition
  assert number == english_number

  # Parents appear before their children in document order. A missing parent
  # used to make the node a detached root silently, dropping it from the
  # rendered tree and from the training data, so fail loudly instead.
  parent_number = parent_proposition_number(number)
  parent_node = proposition_nodes.get(parent_number)
  if parent_node is None:
    raise ValueError('Proposition %s has no parent proposition %r'
                     % (number, parent_number))
  if number in proposition_nodes:
    raise ValueError('Duplicate proposition number %s' % number)

  # Create the new node and attach it to the parent.
  proposition_nodes[number] = anytree.AnyNode(
      number=number, parent=parent_node,
      german_content=german_content,
      english_content=english_content)

print(anytree.RenderTree(propositions_tree))

# Data Inspection

In [20]:
def proposition_importance(proposition_number):
  """Return how deeply nested a proposition is, with 0 the most important.

  The importance is the count of digits after the decimal point in the first
  number found in `proposition_number`, e.g. '1' -> 0, '1.1' -> 1,
  '2.0121' -> 4.
  """
  decimals = re.search(r'(\d+)(\.(\d+))?', proposition_number).group(3)
  return 0 if decimals is None else len(decimals)

In [21]:
def proposition_html(number, german_content, english_content):
  """Render one proposition as a bilingual HTML table row.

  Rows whose importance is above 0 get the "expandable" class and start out
  hidden. The number and contents are inserted verbatim, because the contents
  may already contain HTML markup such as <em>. The CSS classes are defined
  in frontend/templates/index.html.
  """
  if proposition_importance(number) > 0:
    opening_tag = '<tr class="expandable" style="display: none;">'
  else:
    opening_tag = '<tr>'

  cells = [
      '  <td class="english number mobile">%s</td>' % number,
      '  <td class="english proposition">%s</td>' % english_content,
      '  <td class="german number nomobile">%s</td>' % number,
      '  <td class="german proposition nomobile">%s</td>' % german_content,
  ]
  return '\n'.join([opening_tag] + cells + ['</tr>'])

In [None]:
importance_limit = 5  # @param {type:"slider", min:0, max:5, step:1}
proposition_prefix = ''  # @param{type:"string"}
as_html = True  # @param {type:"boolean"}
reverse = True  # @param {type:"boolean"}

# Print every proposition whose number starts with the prefix and whose
# importance does not exceed the limit. Output is either HTML table rows or
# plain text (German line, then English line).
ordered = reversed(propositions) if reverse else propositions

for german_proposition, english_proposition in ordered:
  number, german_content = german_proposition
  english_number, english_content = english_proposition
  assert number == english_number

  if number.startswith(proposition_prefix):
    importance = proposition_importance(number)
    if importance <= importance_limit:
      if not as_html:
        print('%s %s\n%s %s' % (number, german_content,
                                number, english_content))
      else:
        print(proposition_html(number, german_content, english_content))

# Training

Documentation (legacy fine-tunes API, as used by the pre-1.0 `openai` package): https://beta.openai.com/docs/guides/fine-tuning

In [None]:
# Colab form field for the OpenAI API key. It is exported as the
# OPENAI_API_KEY environment variable, presumably for the `openai` CLI
# commands below. The inference section also assigns it to openai.api_key.
# NOTE: once entered, the key is stored in plain text in the notebook; clear
# it before saving or sharing.
openai_api_key = ''  # @param {type:"string"}
%env OPENAI_API_KEY=$openai_api_key

In [24]:
# Formatting of the fine-tuning examples (Colab form fields):
# - prompt_separator is appended to each prompt (the proposition number),
#   both in training and at inference time.
# - language_separator joins the German and English text in a completion.
# - proposition_separator appears unused by the cells below.
# - stop_sequence ends every training completion and is passed as the stop
#   sequence at inference time.
prompt_separator = ' ->'  # @param{type:"string"}
language_separator = '==='  # @param{type:"string"}
proposition_separator = '\n'  # @param{type:"string"}
stop_sequence = '###'  # @param{type:"string"}

In [25]:
training_filename = 'tractatus-%d.jsonl' % datetime.now().timestamp()

# Write one fine-tuning example per numbered proposition as a JSON line, in
# pre-order (the root node has an empty number and is skipped):
#   prompt:     '<number><prompt_separator>'
#   completion: ' <german><language_separator><english><stop_sequence>'
with open(training_filename, 'w', encoding='utf-8') as training_file:
  for node in anytree.PreOrderIter(
      propositions_tree, filter_=lambda tree_node: tree_node.number):
    example = dict(
        prompt='%s%s' % (node.number, prompt_separator),
        completion=' %s%s%s%s' % (node.german_content, language_separator,
                                  node.english_content, stop_sequence))
    training_file.write(json.dumps(example, ensure_ascii=False) + '\n')

In [None]:
# Check the JSONL training file with the legacy (pre-1.0) `openai` CLI data
# preparation tool. NOTE(review): the tool may ask interactive questions.
!openai tools fine_tunes.prepare_data --file $training_filename

In [None]:
# Fine-tuning settings (Colab form fields): base model and hyperparameters.
model = 'davinci'  # @param ["ada", "babbage", "curie", "davinci"]
batch_size = 1  # @param{type:"integer"}
n_epochs = 4  # @param{type:"integer"}
learning_rate_multiplier = 0.02  # @param{type:"number"}
prompt_loss_weight = 0.1  # @param{type:"number"}

# Start a fine-tuning job on the training file written above, using the
# legacy `openai api fine_tunes.create` CLI command.
!openai api fine_tunes.create --training_file $training_filename --model $model --batch_size $batch_size --n_epochs $n_epochs --learning_rate_multiplier $learning_rate_multiplier --prompt_loss_weight $prompt_loss_weight

# Inference

Documentation (legacy Completions API, as used by the pre-1.0 `openai` package): https://beta.openai.com/docs/api-reference

In [202]:
def complete(model, prompt, prompt_separator, num_completions, max_tokens,
             top_p, stop_sequence, presence_penalty, frequency_penalty):
  """Request completions for a prompt from the legacy Completions API.

  The prompt is formatted like the training prompts: the proposition number
  followed by the prompt separator. Returns the openai.Completion.create
  response unchanged.
  """
  request = {
      'model': model,
      'prompt': '%s%s' % (prompt, prompt_separator),
      'n': num_completions,
      'max_tokens': max_tokens,
      'top_p': top_p,
      'stop': stop_sequence,
      'presence_penalty': presence_penalty,
      'frequency_penalty': frequency_penalty,
  }
  return openai.Completion.create(**request)

In [314]:
# Inference settings (Colab form fields).
# - openai.organization and openai.api_key configure the legacy client; the
#   key comes from the form field in the training section.
# - fine_tuned_model: name of the model produced by the fine-tuning job.
# - prompt: the proposition number to generate content for.
# - max_tokens, top_p, presence_penalty, frequency_penalty: sampling
#   parameters passed through to complete().
openai.organization = ""  # @param {type:"string"}
openai.api_key = openai_api_key
fine_tuned_model = ''  # @param {type:"string"}
prompt = '8'  # @param {type:"string"}
max_tokens = 1024  # @param {type:"integer"}
top_p = 0.7  # @param {type:"slider", min:0, max:1, step:0.1}
presence_penalty = 1  # @param {type:"slider", min:-2.0, max:2.0, step:0.1}
frequency_penalty = 1  # @param {type:"slider", min:-2.0, max:2.0, step:0.1}

In [None]:
# Sample a single proposition.
completions = complete(model=fine_tuned_model,
                       prompt=prompt,
                       prompt_separator=prompt_separator,
                       num_completions=1,
                       max_tokens=max_tokens,
                       top_p=top_p,
                       stop_sequence=stop_sequence,
                       presence_penalty=presence_penalty,
                       frequency_penalty=frequency_penalty)
choice = completions.choices[0]
if choice.finish_reason != 'stop':
  # The stop sequence was not reached (e.g. max_tokens ran out), so the
  # proposition may be cut off.
  print('Warning: finish_reason is %r; the completion may be incomplete.'
        % choice.finish_reason)

# Drop the leading space that every training completion starts with.
proposition = choice.text[1:]
number = prompt

# A well-formed completion has exactly one language separator. Report a
# malformed one clearly instead of failing with an opaque unpacking error.
content_parts = proposition.split(language_separator)
if len(content_parts) != 2:
  raise ValueError('Expected exactly one %r separator in the completion: %r'
                   % (language_separator, proposition))
german_content, english_content = content_parts

print('%s %s' % (number, german_content))
print('%s %s' % (number, english_content))

In [316]:
def plagiarism(new_german, new_english, old_propositions):
  """Return True if a new proposition copies text from an existing one.

  A proposition counts as plagiarized when its German text is a substring of
  any existing German text, or its English text is a substring of any
  existing English text. `old_propositions` holds
  ((number, german), (number, english)) pairs.
  """
  return any(
      new_german in old_german or new_english in old_english
      for (_, old_german), (_, old_english) in old_propositions)

In [None]:
# Collect many propositions.
num_completions = 12  # @param {type:"integer"}

completions = complete(model=fine_tuned_model,
                       prompt=prompt,
                       prompt_separator=prompt_separator,
                       num_completions=num_completions,
                       max_tokens=max_tokens,
                       top_p=top_p,
                       stop_sequence=stop_sequence,
                       presence_penalty=presence_penalty,
                       frequency_penalty=frequency_penalty)

new_propositions = []
num_malformed = 0
for choice in completions.choices:
  if choice.finish_reason != 'stop':
    # Incomplete proposition (e.g. cut off by max_tokens).
    continue

  # Drop the leading space from training and split into the two languages.
  # Skip malformed completions instead of raising, so that a single bad
  # sample does not throw away the rest of the batch.
  content_parts = choice.text[1:].split(language_separator)
  if len(content_parts) != 2:
    num_malformed += 1
    continue
  german_content, english_content = content_parts

  if plagiarism(new_german=german_content,
                new_english=english_content,
                old_propositions=propositions):
    # Proposition already exists in the original.
    continue

  if plagiarism(new_german=german_content,
                new_english=english_content,
                old_propositions=new_propositions):
    # Proposition already exists in the current set.
    continue

  number = prompt

  new_propositions.append((
      (number, german_content),
      (number, english_content)))

if num_malformed:
  print('%d malformed completions skipped' % num_malformed)
print('%d new propositions' % len(new_propositions))

In [318]:
def bigquery_string_literal(value):
  """Return str(value) as a double-quoted BigQuery string literal.

  The values include model-generated text, which may contain quotes,
  backslashes or newlines. json.dumps escapes exactly these characters (as
  \\", \\\\, \\n, \\uXXXX, ...), and all of those escapes are valid in
  BigQuery string literals. The generated text therefore cannot break or
  inject into the SQL statement.
  """
  return json.dumps(str(value), ensure_ascii=False)


insert_query = """INSERT
  tractatus.propositions (id,
    number,
    german,
    english,
    model,
    prompt,
    top_p,
    presence_penalty,
    frequency_penalty,
    completion_id,
    timestamp)
"""

insert_values = []
for german_proposition, english_proposition in new_propositions:
  german_number, german_content = german_proposition
  english_number, english_content = english_proposition
  assert german_number == english_number

  insert_value = (
      # The id is the last 12-hex-digit group of a freshly generated UUID.
      "REGEXP_EXTRACT(GENERATE_UUID(), r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-([0-9a-f]{12})$')",
      bigquery_string_literal(german_number),
      bigquery_string_literal(german_content),
      bigquery_string_literal(english_content),
      bigquery_string_literal(completions.model),
      bigquery_string_literal(prompt),
      "%s" % top_p,
      "%s" % presence_penalty,
      "%s" % frequency_penalty,
      bigquery_string_literal(completions.id),
      'CURRENT_TIMESTAMP()')
  insert_values.append('  (%s)' % ', '.join(insert_value))

if insert_values:
  insert_query += 'VALUES\n%s' % ',\n'.join(insert_values)
  print(insert_query)
else:
  # An INSERT with an empty VALUES list is not valid SQL.
  print('No new propositions to insert.')