In [263]:
from onmt.utils.parse import ArgumentParser
import onmt.opts as opts
from onmt.translate.translator import build_translator
from onmt.utils.misc import split_corpus
from itertools import repeat

In [264]:
def _get_parser():
    parser = ArgumentParser(description='translate.py')

    opts.config_opts(parser)
    opts.translate_opts(parser)
    return parser

In [267]:
parser = _get_parser()
opt = parser.parse_args("--model bilstm_0.5M_step_110000.pt --replace_unk -tgt ../data/python_data/test_5000.labels -src ../data/python_data/test_5000.data --beam_size 50 --attn_debug")
ArgumentParser.validate_translate_opts(opt)

In [268]:
fields, model, model_opt =  onmt.model_builder.load_test_model(opt)

In [269]:
src_reader = onmt.inputters.str2reader[opt.data_type].from_opt(opt)
tgt_reader = onmt.inputters.str2reader["text"].from_opt(opt)

In [270]:
data = onmt.inputters.Dataset(fields,
            readers=([src_reader, tgt_reader]
                     if opt.tgt else [src_reader]),
            data=[("src", opt.src), ("tgt", opt.tgt)] if opt.tgt else [("src", opt.src)],
            dirs=[opt.src_dir, None] if opt.tgt else [opt.src_dir],
            sort_key=onmt.inputters.str2sortkey[opt.data_type],
            filter_pred=None
        )

In [271]:
data_iter = onmt.inputters.OrderedIterator(
        dataset=data, device='cpu',
        batch_size=1, train=False, sort=False,
        sort_within_batch=True, shuffle=False)

In [272]:
translator = onmt.translate.Translator(model, fields,
                                       beam_size=50,
                                       n_best=1,
                                       global_scorer=onmt.translate.GNMTGlobalScorer(0, 0, "none", "none"),
                                       gpu=-1, 
                                      src_reader=src_reader, 
                                      tgt_reader=tgt_reader, 
                                      replace_unk=True)

In [273]:
builder = onmt.translate.TranslationBuilder(
        data, translator.fields,
        1, True, True)

In [274]:
all_attns = []

for j, batch in enumerate(data_iter):
    batch_data = translator.translate_batch(batch, data, attn_debug=True)
    translations = builder.from_batch(batch_data)
    print("src:", " ".join(translations[0].src_raw))
    print("pred:", " ".join(translations[0].pred_sents[0]))
    print('tgt:',' '.join(translations[0].gold_sent))
    print("idx:",str(j))
    print("-----")
    
    attns = [t.numpy() for t in translations[0].attns]
    all_attns.append((translations[0].src_raw, translations[0].pred_sents[0], ' '.join(translations[0].gold_sent), attns[0][:-1]))
    
    if j==49:
        break

src: filename return [ line . strip ( ) for line in open ( filename , ' r ' ) if line . strip ( ) and not line . strip ( ) . startswith ( ' ' ) ]
pred: read comments
tgt: requirements from file
idx: 0
-----
src: fname , url , url image readme = open ( path . join ( path . dirname ( file ) , fname ) ) . read ( ) if hasattr ( readme , ' decode ' ) : # in python 3 , turn bytes into str . readme = readme . decode ( ' utf8 ' ) readme = re . sub ( r ' `<([^>]*)>` ' , r ' `\\1 < ' + url + r " /blob/master/\\1>` " , readme ) readme = re . sub ( r " \\ . \\ . image:: / " , " . . image:: " + url image + " / " , readme ) return readme
pred: read
tgt: read
idx: 1
-----
src: self , * args , ** kwargs response = self . session . post ( * args , ** kwargs ) browser . add soup ( response , self . soup config ) return response
pred: post
tgt: post
idx: 2
-----
src: self if self . session is not none : self . session . cookies . clear ( ) self . session . close ( ) self . session = none
pred: close sess

src: computer , name , values if values == ' none ' : return none else : type , key = values if type == ' attr() ' : return computer [ ' element ' ] . get ( key ) or none elif type == ' string ' : return key
pred: color
tgt: lang
idx: 30
-----
src: computer , name , value result = [ ] for function , args in value : if function == ' translate ' : args = length or percentage tuple ( computer , name , args ) result . append ( ( function , args ) ) return tuple ( result )
pred: background
tgt: transform
idx: 31
-----
src: tokens parts = [ ] for split part in split on comma ( tokens ) : if not split part : # happens when there ' s a comma at the beginning , at the end , or # when two commas are next to each other . return for part in split part : parts . append ( part ) return parts
pred: split parts
tgt: split on optional comma
idx: 32
-----
src: token if token . type == ' dimension ' : factor = angle to radians . get ( token . unit ) if factor is not none : return token . value * factor
p

In [275]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mplc
import matplotlib
%matplotlib inline
from matplotlib.backends.backend_pdf import PdfPages

In [276]:
def visualize_attn(src, pred, tgt, attn):
    '''
    src and target are lists of strings
    attn is a numpy matrix of shape (len(tgt), len(src))
    '''
    if len(pred)==0:
        return
    
    plt.figure()

    plt.matshow(attn.transpose(), cmap='gray', origin='upper', aspect=0.5/len(pred), norm=mplc.LogNorm(vmin=attn.min(), vmax=attn.max()))

    plt.tick_params(axis='both', which='major', labelsize=8)
    plt.tick_params(axis='both', which='minor', labelsize=8)
    
    plt.title('Ground truth:'+tgt)
    plt.xticks(np.arange(0, len(pred), 1), pred)
    plt.yticks(np.arange(0, len(src), 1), src)
    
    pp.savefig()
    plt.close();
    

In [277]:
from tqdm import tqdm_notebook as tqdm

In [278]:
pp = PdfPages('attn_wts.pdf')
plt.ioff()
for src, pred, tgt, attn in tqdm(all_attns):
#     print(src, pred, tgt, attn)
    visualize_attn(src, pred, tgt, attn)
#     break
pp.close()

HBox(children=(IntProgress(value=0, max=50), HTML(value='')))






<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

<Figure size 432x288 with 0 Axes>

In [279]:
pwd

'/home/ubuntu/adversarial-ml-on-code-stuff/models/onmt'

In [296]:
cd ..

/home/ubuntu/adversarial-ml-on-code-stuff/models/onmt


In [281]:
import torch

In [299]:
m = torch.load('bilstm_adv_0.5M_step_115000.pt', map_location='cpu')

In [300]:
sum(m['model'][p].numel() for p in m['model'])

21269000