In [1]:
import re
import random

import dask
import dask.array as da
import numpy as np
from nltk.corpus import stopwords
from nltk.corpus import wordnet as wn

import corpus
import utterances

### Basic Setup

In [2]:
big_dfs = corpus.get_corpus_dfs("refcocoplus")
applied = corpus.get_classifications(big_dfs, "refcocoplus", which_split="train")
region_data, all_refexp = corpus.get_region_info(big_dfs, "refcocoplus")
all_applied, word2ind, X_idx, wordlist = applied["refcocoplus"]
refexp_list = [exp.split() for k, v in all_refexp.items() for exp in v]

In [3]:
#basic spellcheck stuff
refexp_list = [re.sub(r"\st shirt\s", "tshirt", re.sub(r"\so clock\s", "o'clock", " ".join(x))) for x in refexp_list]

In [4]:
noun_syns_dict, noun_syns_key = utterances.vocab_synsets(word2ind)
verb_syns_dict, verb_syns_key = utterances.vocab_synsets(word2ind, noun=False)

In [5]:
noun_alt_dict = utterances.produce_noun_alt_dict(noun_syns_key, noun_syns_dict)

In [6]:
verb_alt_dict = utterances.produce_verb_alt_dict(verb_syns_key, verb_syns_dict, word2ind)

### Generate Utterances for Specific Refexp

In [25]:
ic, ii = random.choice(list(all_refexp.keys()))
ic, ii

(1, 226246)

In [29]:
ex_exp = random.choice(list(all_refexp[(ic, ii)].keys()))
ri = all_refexp[(ic,ii)][ex_exp]
ri, ex_exp

(169060, 'bus with crane sticking up near it')

In [30]:
split_str, patterns = utterances.generate_patterns(ex_exp, word2ind, noun_alt_dict, verb_alt_dict)
split_str

('bus', 'sticking', 'near')

In [65]:
utt_gen = utterances.generate_utterances(patterns)
utt_l = [[x for x in i if x] for i in utt_gen]
utt_l = [x for x in utt_l if x] + [list(split_str)]

### All the dask stuff that follows

In [43]:
def read_one_exp(row_seq, ind_seq):
    return all_applied[row_seq,:][:, ind_seq].prod(axis=1)

In [44]:
ids, rows = tuple(corpus.imageid2rows(region_data, X_idx, ic, ii))
len(ids), len(rows)

(8, 8)

In [71]:
read_one_exp(rows,corpus.exp2indseq(word2ind,list(split_str)))

Unnamed: 0,Array,Chunk
Bytes,64 B,16 B
Shape,"(8,)","(2,)"
Count,7906 Tasks,7 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 64 B 16 B Shape (8,) (2,) Count 7906 Tasks 7 Chunks Type float64 numpy.ndarray",8  1,

Unnamed: 0,Array,Chunk
Bytes,64 B,16 B
Shape,"(8,)","(2,)"
Count,7906 Tasks,7 Chunks
Type,float64,numpy.ndarray


In [87]:
read_them = dask.delayed(read_one_exp)
prod_list = [read_them(rows,corpus.exp2indseq(word2ind, x)) for x in utt_l]
sample = prod_list[0].compute()
arrays = [da.from_delayed(x, dtype=sample.dtype, shape=sample.shape) for x in prod_list]
stack = da.stack(arrays, axis=1)
stack = stack.compute()

In [88]:
col_sums = stack.sum(axis=0, keepdims=True)
new_stack = stack / col_sums
new_stack

Unnamed: 0,Array,Chunk
Bytes,55.12 kiB,16 B
Shape,"(8, 882)","(2, 1)"
Count,53749 Tasks,6174 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 55.12 kiB 16 B Shape (8, 882) (2, 1) Count 53749 Tasks 6174 Chunks Type float64 numpy.ndarray",882  8,

Unnamed: 0,Array,Chunk
Bytes,55.12 kiB,16 B
Shape,"(8, 882)","(2, 1)"
Count,53749 Tasks,6174 Chunks
Type,float64,numpy.ndarray


In [89]:
row_sums = new_stack.sum(axis=1, keepdims=True)
newer_stack = new_stack / row_sums
newer_stack

Unnamed: 0,Array,Chunk
Bytes,55.12 kiB,16 B
Shape,"(8, 882)","(2, 1)"
Count,68169 Tasks,6174 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 55.12 kiB 16 B Shape (8, 882) (2, 1) Count 68169 Tasks 6174 Chunks Type float64 numpy.ndarray",882  8,

Unnamed: 0,Array,Chunk
Bytes,55.12 kiB,16 B
Shape,"(8, 882)","(2, 1)"
Count,68169 Tasks,6174 Chunks
Type,float64,numpy.ndarray


In [90]:
final_col_sums = newer_stack.sum(axis=0, keepdims=True)
final_stack = newer_stack / final_col_sums
final_stack

Unnamed: 0,Array,Chunk
Bytes,55.12 kiB,16 B
Shape,"(8, 882)","(2, 1)"
Count,83163 Tasks,6174 Chunks
Type,float64,numpy.ndarray
"Array Chunk Bytes 55.12 kiB 16 B Shape (8, 882) (2, 1) Count 83163 Tasks 6174 Chunks Type float64 numpy.ndarray",882  8,

Unnamed: 0,Array,Chunk
Bytes,55.12 kiB,16 B
Shape,"(8, 882)","(2, 1)"
Count,83163 Tasks,6174 Chunks
Type,float64,numpy.ndarray


In [91]:
exp_idx = len(utt_l) - 1

In [92]:
answer = final_stack[:,exp_idx].argmax(axis=0)

In [94]:
answer_idx = answer.compute()

In [95]:
ids[answer_idx]

248133

In [96]:
ids

[163474, 167919, 169060, 248133, 257740, 258446, 1205507, 1365918]