Once you have scraped a Tumblr corpus that you want to fine-tune GPT-2 on, this notebook prepares that data for training.

To get a corpus, use a tumblr scraper such as https://github.com/bbolli/tumblr-utils

After prepping the data, use `gpt-2/train.py` to fine-tune.

In [1]:
import numpy as np, pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
%config Completer.use_jedi = False

In [24]:
import os

import bs4
from bs4 import BeautifulSoup

In [25]:
def lprint(s, prefix=""):
    """Print `s` followed by a dashed divider, putting `prefix` before both."""
    text = "{}{}".format(prefix, s)
    divider = "\n\n{}---------\n\n".format(prefix)
    print(text, end=divider)

In [26]:
# Directory holding the scraped posts. It is used as the default `posts_dir` of
# get_all_posts, and that default is bound when the `def` cell runs, so set this first.
POSTS_DIR = ""  # wherever the posts are, as .html files

In [27]:
# Single-character sentinel strings used as delimiters inside processed post text.
# NOTE(review): `from reblogs_v5 import *` runs after this cell and may rebind
# any of these names -- confirm the two definitions agree.
Q_CHAR = "会"  # closes a username marker: screen_for_inclusion builds UNAME_CHAR + name + Q_CHAR
A_CHAR = "域"  # screen_for_inclusion counts text from this marker onward as written by "me"
T_CHAR = "职"  # screen_for_inclusion partitions on this to separate the body from the tags

UNAME_CHAR = "友"  # opens a username marker
ORIG_POST_CHAR = "翰"  # screen_for_inclusion takes the text after this marker as "me" text

EOT_FULL = "<|endoftext|>"  # presumably GPT-2's end-of-text token; not referenced in the cells shown here

In [28]:
from reblogs_v5 import *

In [29]:
# a lot of hardcoded magic numbers and stuff here -- i tuned a lot of stuff here "to my taste"
# eg excluding reblog chains that have too much writing by others and not much by me

def _split_me_and_other(body, nost_uname):
    """Split `body` into (me_body, other_body), or return None if it has no usable marker.

    - If ORIG_POST_CHAR is present, everything between the first and second
      ORIG_POST_CHAR is "me" text and there is no "other" text.
    - Otherwise, if A_CHAR is present, the body is walked character by character:
      a Q_CHAR that completes `nost_uname` switches to "me", any other Q_CHAR
      switches to "other", and an A_CHAR switches to "me".  The marker characters
      themselves are attributed to whichever side they switch to.
    - Otherwise None is returned.
    """
    if ORIG_POST_CHAR in body:
        return body.split(ORIG_POST_CHAR)[1], ""
    if A_CHAR not in body:
        return None

    me_chars, other_chars = [], []
    in_other = True
    for ix, char in enumerate(body):
        if char == Q_CHAR:
            # endswith() on the prefix body[:ix + 1] avoids the negative slice
            # start a manual look-behind produces near the start of the string
            in_other = not body.endswith(nost_uname, 0, ix + 1)
        if char == A_CHAR:
            in_other = False
        (other_chars if in_other else me_chars).append(char)
    return "".join(me_chars), "".join(other_chars)


def screen_for_inclusion(processed, post_metadata,
                         skip_autoresponder_mention=False,
                         is_v5_extension=False,
                         has_v5_extension_permissiveness=False,
                         suppress_logs=False):
    """Decide whether a processed post belongs in the training corpus.

    Args:
        processed: processed post text; the part before T_CHAR is the body.
        post_metadata: mapping for the post; only "is_quotes" is read here.
        skip_autoresponder_mention: if true, reject posts containing "Frank"
            or the mapped autoresponder username marker.
        is_v5_extension: use the looser cutoffs AND accept only posts that the
            base screen (run with skip_autoresponder_mention=True) rejects.
        has_v5_extension_permissiveness: use the looser cutoffs only.
        suppress_logs: if true, print nothing.

    Returns:
        (passed, words_me, words_other).  The word counts are whitespace-split
        counts of the "me" and "other" text; they are 0 for posts rejected
        before the me/other split is computed.
    """
    body, _, _ = processed.partition(T_CHAR)

    ar_uname = UNAME_CHAR + map_uname("nostalgebraist-autoresponder", uname_config="frank_v5_train") + Q_CHAR
    nost_uname = UNAME_CHAR + map_uname("nostalgebraist", uname_config="frank_v5_train") + Q_CHAR

    if skip_autoresponder_mention:
        if "Frank" in processed:
            if not suppress_logs:
                f_ix = processed.index('Frank')
                # clamp the start so a mention in the first 10 chars still shows context
                print(f"rejecting for Frank mention: {processed[max(0, f_ix-10):f_ix+10]}")
            return False, 0, 0
        if ar_uname in processed:
            if not suppress_logs:
                print("rejecting for AR post")
            return False, 0, 0

    # kept as an equality test (not truthiness) to preserve the original screening rule
    if post_metadata["is_quotes"] == True:  # noqa: E712
        return False, 0, 0

    split = _split_me_and_other(body, nost_uname)
    if split is None:
        return False, 0, 0
    me_body, other_body = split

    # hardcoded, hand-tuned thresholds; the v5-extension modes are more permissive
    if is_v5_extension or has_v5_extension_permissiveness:
        ratio_cutoff, word_cutoff = 3, 50
    else:
        ratio_cutoff, word_cutoff = 2, 250

    words_me = len(me_body.split())
    words_other = len(other_body.split())

    if len(me_body) < 10:
        return False, words_me, words_other

    # reject chains that are mostly other people's writing with little of mine
    # (the ratio is over characters, the floor is over words)
    if (len(other_body) / len(me_body)) > ratio_cutoff and words_me < word_cutoff:
        if not suppress_logs:
            print(f"rejecting other_body {words_other} words, me_body {words_me} words")
        return False, words_me, words_other

    if is_v5_extension:
        # the extension only adds posts that the base screen would have rejected
        passes_base_screen = screen_for_inclusion(
            processed, post_metadata, skip_autoresponder_mention=True, is_v5_extension=False, suppress_logs=True
        )[0]
        if passes_base_screen:
            return False, words_me, words_other
        if not suppress_logs:
            print(f"accepting other_body {words_other} words, me_body {words_me} words")

    return True, words_me, words_other

In [None]:
import re

def fix_p_in_h2_bug(raw_html):
    """Remove a `<p>`/`</p>` tag pair sitting inside an `<h2>...</h2>`, keeping its text.

    `.` does not cross newlines, so the whole `<h2>...</h2>` must be on one
    line; one tag pair is removed per regex match.
    """
    p_inside_h2 = r"(<h2>.*)<p>(.*)</p>(.*</h2>)"
    # re-emit the three captured pieces, dropping only the <p> and </p> tags
    return re.sub(p_inside_h2, r"\1\2\3", raw_html)

In [30]:
from collections import Counter, defaultdict

def get_all_posts(posts_dir=POSTS_DIR,
                  existing_posts_dir=None,
                  limit=None,
                  is_v5_extension=False,
                  has_v5_extension_permissiveness=False,
                  is_reward=False,
                  suppress_screener_logs=False,
                  do_image_analysis=False,
                  get_image_urls=False,
                  use_cached_images=False,
                  USE_A_CHAR_ALWAYS=True):
    """Process every .html post in `posts_dir` and collect the ones that pass screening.

    Args:
        posts_dir: directory of scraped posts; only names ending in ".html" are read.
        existing_posts_dir: if given, .html filenames present there are skipped.
        limit: if given, stop after the first processed file whose index in the
            sorted listing is >= limit.
        is_v5_extension, has_v5_extension_permissiveness: forwarded to
            screen_for_inclusion.
        is_reward: use the "frank_v5_operate" username config, keep every post
            regardless of the screen, and silence screener logs.
        suppress_screener_logs: forwarded to screen_for_inclusion as suppress_logs.
        do_image_analysis, get_image_urls: forwarded to process_post;
            get_image_urls also changes the return shape (see below).
        use_cached_images: currently unused (the code that read it is disabled).
        USE_A_CHAR_ALWAYS: replace `nost_uname` with A_CHAR before screening.

    Returns:
        (posts, meta_counts, post_fns, reply_urls_to_fns), or
        (posts, meta_counts, post_fns, image_urls, reply_urls_to_fns) when
        get_image_urls is true.  `posts` and `post_fns` are parallel lists.
    """
    posts = []
    post_fns = []
    meta_counts = Counter()
    all_meta_counts = Counter()
    # NOTE(review): all_image_urls is filled in but never returned or printed
    all_image_urls = set()
    image_urls = set()
    reply_urls_to_fns = defaultdict(set)

    # the next line refers to a (tiny) feature i haven't copied to the public github yet
    # to include OCR'd images in training corpus using a cache to avoid repeat calls to rekognition
    # so it's commented out
    #
    # user_defined_image_analysis = cached_image_analysis_fn if use_cached_images else IMAGE_ANALYSIS_FN

    user_defined_image_analysis = IMAGE_ANALYSIS_FN

    # loop-invariant, so computed once up front
    uname_config = "frank_v5_operate" if is_reward else "frank_v5_train"

    all_fns = os.listdir(posts_dir)
    if existing_posts_dir is not None:
        all_existing_fns = {fn for fn in os.listdir(existing_posts_dir) if fn.endswith(".html")}
        all_fns = [fn for fn in all_fns if fn not in all_existing_fns]

    for ix, fn in enumerate(sorted(all_fns)):
        if not fn.endswith(".html"):
            continue

        # explicit UTF-8: the platform default encoding would garble non-ASCII
        # posts on non-UTF-8 locales (the corpus is written back out as UTF-8)
        with open(os.path.join(posts_dir, fn), "r", encoding="utf-8") as f:
            raw_html = f.read()
            fixed_html = fix_p_in_h2_bug(raw_html)
            soup = BeautifulSoup(fixed_html)

        try:
            # print(os.path.join(posts_dir, fn))
            processed, post_metadata = process_post(soup,
                                                    uname_config=uname_config,
                                                    do_image_analysis=do_image_analysis,
                                                    get_image_urls=get_image_urls,
                                                    user_defined_image_analysis=user_defined_image_analysis,
                                                    debug=False)
        except Exception as e:
            # best-effort: report the file that failed to parse and move on
            print(f"hit {e} on {fn}")
            continue

        # tallies over every parsed post, whether or not it passes the screen
        for key in sorted(post_metadata.keys()):
            if key not in {"image_urls", "reply_post_url"}:
                all_meta_counts[key] += post_metadata[key]
            if key == "image_urls":
                all_image_urls.update(post_metadata[key])

        if USE_A_CHAR_ALWAYS:
            # NOTE(review): nost_uname is not defined in this function; presumably
            # it comes from `from reblogs_v5 import *` -- verify
            processed = processed.replace(nost_uname, A_CHAR)

        passed_screen, words_me, words_other = screen_for_inclusion(
            processed, post_metadata,
            is_v5_extension=is_v5_extension,
            has_v5_extension_permissiveness=has_v5_extension_permissiveness,
            suppress_logs=suppress_screener_logs or is_reward)

        all_meta_counts["words_me"] += words_me
        all_meta_counts["words_other"] += words_other

        if passed_screen or is_reward:
            # tallies over included posts only
            for key in sorted(post_metadata.keys()):
                if key not in {"image_urls", "reply_post_url"}:
                    meta_counts[key] += post_metadata[key]
                if key == "image_urls":
                    image_urls.update(post_metadata[key])

            meta_counts["words_me"] += words_me
            meta_counts["words_other"] += words_other

            posts.append(processed)
            post_fns.append(fn)

            if post_metadata["reply_post_url"] is not None:
                reply_urls_to_fns[post_metadata["reply_post_url"]].add(fn)

        # periodic progress report
        if ix % 500 == 0:
            print(f"{ix}/{len(all_fns)}\n")
            for k in meta_counts.keys():
                print(f"incl_meta_counts[{k}]:\t{meta_counts[k]}\nall__meta_counts[{k}]:\t{all_meta_counts[k]}\n")
            if get_image_urls:
                print(f"n_images: {len(image_urls)}")
            print()

        if limit is not None:
            if ix >= limit:
                break

    if get_image_urls:
        return posts, meta_counts, post_fns, image_urls, reply_urls_to_fns
    else:
        return posts, meta_counts, post_fns, reply_urls_to_fns

part 1: making a train corpus for the generator

In [1]:
# get_all_posts returns four values when get_image_urls is False (the default),
# so all four must be unpacked here
posts, meta_counts, post_fns, reply_urls_to_fns = get_all_posts()

In [2]:
# word totals over the included posts, split by author
total_me = meta_counts["words_me"]
total_other = meta_counts["words_other"]

for word_total, author in ((total_me, "me"), (total_other, "others")):
    print(f"{word_total//1000}K words from {author}")
print()
# share of the corpus written by me
print(f"{total_me / (total_me + total_other):.1%} me")

In [3]:
# review examples
from textwrap import fill

# eyeball ten random posts (sampled with replacement) that do not start with "翰"
subset_review = list(filter(lambda post: not post.startswith("翰"), posts))
separator = "\n\n" + "~~~~~" * 20 + "\n\n"

for p in np.random.choice(subset_review, 10):
    print(p)
    print(separator)

In [45]:
# Concatenate every included post into one training string; no separator is added here.
posts_string = "".join(posts)

In [46]:
TRAIN_DATA_PATH = ""  # fill in
# write the whole corpus out as UTF-8 text
with open(TRAIN_DATA_PATH, mode="w", encoding="utf-8") as f:
    f.write(posts_string)

part 2: making a train corpus for the selector

In [1]:
# Passed as `posts_dir` to get_all_posts in the next cell.
AR_POSTS_DIR = ""  # wherever the scraped _bot_ posts are, as .html files

In [None]:
# get_all_posts only returns the five-value form (with image_urls) when
# get_image_urls=True; without it this unpacking raises ValueError.
# NOTE(review): get_image_urls is also forwarded to process_post -- confirm
# that has no unwanted effect on the processed text.
ar_posts, meta_counts, post_fns, image_urls, reply_urls_to_fns = get_all_posts(
    posts_dir=AR_POSTS_DIR, is_reward=True, get_image_urls=True,
)

In [None]:
# this cell assumes you already have a "reward" file and are just adding to it.
# if you don't, start from scratch with an empty dict for `ids_to_reward_data`
# (note the save cell below also reads data["offset"])

import pickle

# the file is read as a plain pickle: despite the ".gz" name, nothing here decompresses it
with open("reward/reward.pkl.gz", mode="rb") as f:
    data = pickle.load(f)
ids_to_reward_data = data["ids_to_reward_data"]

In [None]:
from tqdm.notebook import tqdm
from reward_data import get_prompt_and_continuation_from_processed

def post_from_id(id_: int, post_list: list):
    """Return the entry of `post_list` at the first position where `post_ids` equals `id_`.

    Relies on the global `post_ids` being aligned index-for-index with
    `post_list`.  Raises IndexError if `id_` is not in `post_ids`.
    """
    matching_positions = [pos for pos, candidate in enumerate(post_ids) if candidate == id_]
    first_match = matching_positions[0]
    return post_list[first_match]

n_prompt_same = 0
n_cont_same = 0

new_ids_to_reward_data = {}

# NOTE(review): `post_ids` is not defined in any cell shown here; presumably it is
# a sequence of post ids aligned index-for-index with `ar_posts` -- confirm.

# posts that already have reward data: keep their note_count, refresh the text
for id_ in tqdm(set(ids_to_reward_data.keys()).intersection(post_ids)):
    processed = post_from_id(id_, ar_posts)
    prompt, continuation = get_prompt_and_continuation_from_processed(processed)

    new_row = {"note_count": ids_to_reward_data[id_]["note_count"]}
    new_row["prompt"] = prompt
    new_row["continuation"] = continuation
    new_ids_to_reward_data[id_] = new_row

# posts with no reward data yet: add them with an unknown note_count
for id_ in tqdm(set(post_ids).difference(ids_to_reward_data.keys())):
    processed = post_from_id(id_, ar_posts)
    try:
        prompt, continuation = get_prompt_and_continuation_from_processed(processed)
    except Exception as e:
        print(f"skipping {id_}: {e}")
        # actually skip: without this the row would be built from the previous
        # iteration's prompt/continuation (or raise NameError on the first one)
        continue

    new_row = {"note_count": None}
    new_row["prompt"] = prompt
    new_row["continuation"] = continuation
    new_ids_to_reward_data[id_] = new_row

save the updated reward data

In [None]:
# NOTE(review): `new_ids_to_reward_data_bad_parses_removed` is not defined in any
# cell shown here; presumably a filtered copy of `new_ids_to_reward_data` -- confirm.
new_data = {
    "ids_to_reward_data": new_ids_to_reward_data_bad_parses_removed,
    "offset": data["offset"],
}

# overwrites the reward file that was loaded earlier
with open("reward/reward.pkl.gz", mode="wb") as f:
    pickle.dump(new_data, f)