In [1]:
import json
from collections import Counter
import toolz
from IPython.display import Image, HTML

In [2]:
images = json.load(open('/Users/kcarnold/src/ImageCaptioning.pytorch/data/dataset_coco.json'))['images']

In [3]:
Counter(img['split'] for img in images)

Counter({'restval': 30504, 'test': 5000, 'train': 82783, 'val': 5000})

In [4]:
images_by_split = toolz.groupby('split', images)

In [5]:
valid_images = images_by_split['val']#[img for img in images if img['split'] == 'val']
len(valid_images)

5000

In [6]:
perm = np.random.RandomState(0).permutation(len(valid_images))
examples = [valid_images[idx] for idx in perm]

In [7]:
def coco_url(cocoid):
    return f'http://images.cocodataset.org/train2017/{cocoid:012d}.jpg'

In [8]:
coco_url(examples[0]['cocoid'])

'http://images.cocodataset.org/train2017/000000133707.jpg'

In [9]:
import pandas as pd

In [86]:
df = pd.DataFrame([dict(
    coco_id=img['cocoid'],
    url=coco_url(img['cocoid']),
    cap0=img['sentences'][0]['raw'],
    cap1=img['sentences'][1]['raw'],
    cap2=img['sentences'][2]['raw'],
    cap3=img['sentences'][3]['raw'],
    cap4=img['sentences'][4]['raw'],
) for img in examples[:50]])
df.to_clipboard(index=False)

In [10]:
from textrec import lang_model

In [11]:
from importlib import reload
reload(lang_model)

<module 'textrec.lang_model' from '/Users/kcarnold/code/textrec/src/textrec/lang_model.py'>

In [27]:
lang_model.dump_kenlm(
    'coco_train',
    [' '.join(sentence['tokens']) for img in images_by_split['train'] for sentence in img['sentences']])

Running /Users/kcarnold/code/kenlm/build/bin/lmplz -o 5 --prune 2 --verbose_header < /Users/kcarnold/code/textrec/models/coco_train.txt > /Users/kcarnold/code/textrec/models/coco_train.arpa
Running /Users/kcarnold/code/kenlm/build/bin/build_binary /Users/kcarnold/code/textrec/models/coco_train.arpa /Users/kcarnold/code/textrec/models/coco_train.kenlm
Done


In [12]:
model = lang_model.Model.get_or_load_model('coco_train')

Loading model coco_train ... reading raw ARPA data ...  Encoding bigrams to indices... Loaded.


In [30]:
model.score_seq(model.bos_state, 'a person standing inside of a phone booth')

(-398.472120447575, State([], []))

In [31]:
model.score_seq(model.bos_state, 'a group of people')

(-174.28278306181613, State([], []))

In [71]:
examples = [
    dict(
        name="car",
        bonuses='blue retro antique car surfboard'.split(),
        avoids='two parking lot hand train'.split(),
        good_recs=[
            "a: blue",
            "a: retro",
            "a blue: antique"
        ],
        bad_recs = [
            "a: car",
        ]
    ),
    dict(
        name="kitchen",
        bonuses="white narrow galley kitchen sink stove refrigerator pantry doorway".split(),
        avoids="street bathroom".split(),
        good_recs=[
            'a: white',
            'a: narrow',
            'a narrow: white'
        ],
        bad_recs=[
            'a: street',
            'a: bathroom'
        ]
    )
]

def get_bonuses(context, to_bonus, to_avoid, amt):
    bonus_words = {}
    for word in to_bonus:
        bonus_words[word] = amt if word not in context else -amt
    for word in to_avoid:
        bonus_words[word] = -amt
    return bonus_words

def eval_bonus_amt(amt):
    num_good_recs = 0
    num_bad_recs = 0
    for example in examples:
        for ref_rec in example['good_recs']:
            context, rest = ref_rec.split(':')
            context = context.split()
            rest = rest.strip()
            bonus_words = get_bonuses(context, example['bonuses'], example['avoids'], amt)
            ents = lang_model.beam_search_phrases(model, context, length_after_first=10, beam_width=3, bonus_words=bonus_words)
            generated_recs = [' '.join(ent.words) for ent in ents]
            print(' '.join(context), '::', ','.join(generated_recs))
#eval_bonus_amt(5.)

for example in examples:
    print(f'\n{example["name"]}')
    context = ['<s>']
    while len(context) < 15:
        bonus_words = get_bonuses(context, example['bonuses'], example['avoids'], 10.)
        ents = lang_model.beam_search_phrases(model, context, length_after_first=10, beam_width=5, bonus_words=bonus_words)
        generated_recs = [' '.join(ent.words) for ent in ents]
        print('{}: {}'.format(' '.join(context[1:]), ', '.join(generated_recs)))
        context.append(ents[0].words[0])


car
: a blue car parked, a blue car with, a blue surfboard, a blue car driving, a blue car and
a: blue car parked, blue surfboard, blue car driving, blue car with a, blue car and a
a blue: car parked in, car driving down, train car sitting, car driving on, car with a cat
a blue car: parked in an antique, parked in front of, driving down a street, driving down a road, parked in an airport
a blue car parked: in an antique, in front of a, in an airport, in front of the, in front of an
a blue car parked in: an antique style, an antique store, an antique truck, an antique motorcycle, front of a building
a blue car parked in an: antique motorcycle, antique style pizza, antique store with, antique style kitchen, antique style bathroom
a blue car parked in an antique: style kitchen with, style pizza with, style bathroom with, style on a table, store with a large
a blue car parked in an antique style: building with a clock, pizza with cheese, kitchen with a large, bathroom with a toilet, kitch

In [46]:
def get_bonuses(context):
    amt = 5.
    return {word: (amt if word not in context else -amt) for word in bonus_words}
context = 'a blue and'.split()
ents = lang_model.beam_search_phrases(model, context, length_after_first=1, beam_width=50, bonus_words=get_bonuses(context))
[' '.join(ent.words) for ent in ents]

['white',
 'yellow',
 'red',
 'silver',
 'black',
 'green',
 'gray',
 'orange',
 'grey',
 'gold',
 'pink',
 'a',
 'purple',
 'brown',
 'car',
 'the',
 'is',
 'surfboard',
 'antique',
 'some',
 'one',
 'two',
 'other',
 'an',
 'holding',
 'people',
 'another',
 'on',
 'in',
 'looking',
 'several',
 'has',
 'three',
 'various',
 'trees',
 'sitting',
 'small',
 'many',
 'and',
 'with',
 'eating',
 'standing',
 'large',
 'water',
 'wearing',
 'his',
 'smiling',
 'vegetables',
 'lots',
 'broccoli']

In [17]:
example_captions = [
    'plate with pancakes topped with banana slices , bacon , and blackberries , in front of a mug and maple syrup .',
    'a plate with blueberries , bacon , and pancakes topped with bananas .',
]
' '.join(set(tok for cap in example_captions for tok in cap.split()))

'front slices syrup . bacon and blackberries of a maple bananas plate with banana in mug blueberries , topped pancakes'

In [18]:
import nltk

In [94]:
def tokenize(caption):
    # FIXME: Karpathy seems to have killed commas and periods.
    return ["<s>"] + nltk.word_tokenize(caption.replace(',', ' ').replace('.', ' '))

In [95]:
from collections import defaultdict

In [96]:
dataset = defaultdict(dict)
for cap in example_captions:
    toks = tokenize(cap)
    for idx in range(1, len(toks)):
        context = ' '.join(toks[:idx])
        tok = toks[idx]
        dataset[context][tok] = 1.
dataset

defaultdict(dict,
            {'<s>': {'a': 1.0, 'plate': 1.0},
             '<s> a': {'plate': 1.0},
             '<s> a plate': {'with': 1.0},
             '<s> a plate with': {'blueberries': 1.0},
             '<s> a plate with blueberries': {'bacon': 1.0},
             '<s> a plate with blueberries bacon': {'and': 1.0},
             '<s> a plate with blueberries bacon and': {'pancakes': 1.0},
             '<s> a plate with blueberries bacon and pancakes': {'topped': 1.0},
             '<s> a plate with blueberries bacon and pancakes topped': {'with': 1.0},
             '<s> a plate with blueberries bacon and pancakes topped with': {'bananas': 1.0},
             '<s> plate': {'with': 1.0},
             '<s> plate with': {'pancakes': 1.0},
             '<s> plate with pancakes': {'topped': 1.0},
             '<s> plate with pancakes topped': {'with': 1.0},
             '<s> plate with pancakes topped with': {'banana': 1.0},
             '<s> plate with pancakes topped with banana': {

1. Pop a context
2. Generate 10 possible suggestions
3. Have annotator pick all the ones that are good suggestions.
4. Record the raw results (for later playing with ranking learning)
5. Record all good suggestions as 1, bad suggestions as 0.


In [97]:
# fake the first step
context = '<s>'

In [98]:
# Generate 10 possible suggestions.

In [99]:
assert model.id2str[:3] == ['<unk>', '<s>', '</s>']

In [100]:
from scipy.special import logsumexp

In [101]:
def next_word_distribution_ngram(model, toks):
    state = model.get_state(toks, bos=True)[0]
    logprobs = model.eval_logprobs_for_words(state, range(len(model.id2str)))
    logprobs[:3] = -1e99
    logprobs -= logsumexp(logprobs)
    return logprobs

In [102]:
dist = next_word_distribution_ngram(model, context.split())

In [103]:
logsumexp(dist)

0.0

In [104]:
scores = dist # for now

In [105]:
recs = [model.id2str[idx] for idx in np.argsort(scores)[-10:]]
recs

['this',
 'some',
 'several',
 'people',
 'three',
 'there',
 'an',
 'the',
 'two',
 'a']

"Which of these is a good suggestion?"

In [106]:
is_good = [0, 1, 0, 0, 0, 1, 0, 0, 0, 1]
assert len(is_good) == len(recs)
{word: label for word, label in zip(recs, is_good)}

{'a': 1,
 'an': 0,
 'people': 0,
 'several': 0,
 'some': 1,
 'the': 0,
 'there': 1,
 'this': 0,
 'three': 0,
 'two': 0}

In [107]:
for word, label in zip(recs, is_good):
    dataset[context][word] = label

Ok now learn us a classifier.

In [258]:
words_for_unigram_feats = sorted(word for words in dataset.values() for word in words.keys())
word2unigram_feat_idx = {word: idx for idx, word in enumerate(words_for_unigram_feats)}

In [259]:
one_hot_words = np.diag(np.ones(len(words_for_unigram_feats)))
one_hot_words

array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]])

In [260]:
no_hot = np.zeros(len(words_for_unigram_feats))

In [261]:
def featurize(ngram_dist, word):
    word_idx = model.model.vocab_index(word)
    assert word_idx != 0, word
    if word in word2unigram_feat_idx:
        unigram_feat = one_hot_words[word2unigram_feat_idx[word]]
    else:
        unigram_feat = no_hot
    return np.r_[
        ngram_dist[word_idx],
        unigram_feat
    ]
    

In [262]:
X = []
y = []
examples = []
for context, words in dataset.items():
    ngram_dist = next_word_distribution_ngram(model, context.split())
    for word, label in words.items():
        examples.append((context, word, label))
        X.append(featurize(ngram_dist, word))
        y.append(int(label))

In [263]:
X = np.array(X)
y = np.array(y)

In [264]:
X.shape

(57, 58)

In [265]:
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC

In [266]:
clf = LogisticRegression().fit(X, y)

In [267]:
np.argsort(clf.predict_log_proba(X)[:,1])

array([43, 42, 41, 40, 39, 45, 44, 19,  2,  4,  5, 46,  6, 29, 31, 32, 33,
       24,  8, 47, 35,  9, 36, 37, 26, 10, 48, 38, 17, 50, 20, 51, 15, 53,
        0,  3, 30,  7, 12, 34, 21, 22, 23, 13, 56, 28, 54,  1, 18, 52, 25,
       27, 16, 11, 49, 14, 55])

In [276]:
examples

[('<s>', 'plate', 1.0),
 ('<s>', 'a', 1),
 ('<s>', 'this', 0),
 ('<s>', 'some', 1),
 ('<s>', 'several', 0),
 ('<s>', 'people', 0),
 ('<s>', 'three', 0),
 ('<s>', 'there', 1),
 ('<s>', 'an', 0),
 ('<s>', 'the', 0),
 ('<s>', 'two', 0),
 ('<s> plate', 'with', 1.0),
 ('<s> plate with', 'pancakes', 1.0),
 ('<s> plate with pancakes', 'topped', 1.0),
 ('<s> plate with pancakes topped', 'with', 1.0),
 ('<s> plate with pancakes topped with', 'banana', 1.0),
 ('<s> plate with pancakes topped with banana', 'slices', 1.0),
 ('<s> plate with pancakes topped with banana slices', 'bacon', 1.0),
 ('<s> plate with pancakes topped with banana slices bacon', 'and', 1.0),
 ('<s> plate with pancakes topped with banana slices bacon and',
  'blackberries',
  1.0),
 ('<s> plate with pancakes topped with banana slices bacon and blackberries',
  'in',
  1.0),
 ('<s> plate with pancakes topped with banana slices bacon and blackberries in',
  'front',
  1.0),
 ('<s> plate with pancakes topped with banana slices b

In [244]:
clf.predict_proba(X[38:39])

array([[0.51957722, 0.48042278]])

In [272]:
word2unigram_feat_idx['with']

54

In [273]:
clf.intercept_

array([0.77776883])

In [275]:
clf.coef_[0,54]

0.0

In [269]:
context = '<s> a'

In [270]:
ngram_dist = next_word_distribution_ngram(model, context.split())

In [271]:
candidates = [model.id2str[id] for id in model.filtered_bigrams[model.model.vocab_index('a')]]

In [250]:
recs = [candidates[i] for i in np.argsort(clf.predict_log_proba([featurize(ngram_dist, word) for word in candidates])[:,1])[-10:]]
' '.join(recs)

'sitting fence on wave is in of a and with'

In [220]:
ngram_dist[model.model.vocab_index('')]

-8.889380755798495

In [249]:
clf.predict_proba([featurize(ngram_dist, 'man')])

array([[0.51957722, 0.48042278]])

In [255]:
is_good = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
assert len(is_good) == len(recs)

In [256]:
for word, label in zip(recs, is_good):
    dataset[context][word] = label

In [257]:
dataset

defaultdict(dict,
            {'<s>': {'a': 1,
              'an': 0,
              'people': 0,
              'plate': 1.0,
              'several': 0,
              'some': 1,
              'the': 0,
              'there': 1,
              'this': 0,
              'three': 0,
              'two': 0},
             '<s> a': {'a': 0,
              'and': 0,
              'cat': 0.0,
              'couple': 0.0,
              'fence': 0,
              'group': 0.0,
              'in': 0,
              'is': 0,
              'large': 1.0,
              'man': 0.0,
              'of': 0,
              'on': 0,
              'person': 0.0,
              'plate': 1.0,
              'sitting': 0,
              'small': 0.0,
              'wave': 0,
              'white': 1.0,
              'with': 0,
              'woman': 0.0,
              'young': 0.0},
             '<s> a plate': {'with': 1.0},
             '<s> a plate with': {'blueberries': 1.0},
             '<s> a plate with blueberri