In [1]:
from collections import defaultdict
import os
import random
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.iterators import BasicIterator
from allennlp.nn.util import move_to_device

from adat.utils import load_weights, calculate_wer
from adat.masker import SimpleMasker, MASK_TOKEN
from adat.models import get_basic_classification_model, get_basic_seq2seq_model
from adat.dataset import InsuranceReader, OneLangSeq2SeqReader

In [2]:
# Show the GPUs on this machine with their current memory use / utilisation.
! nvidia-smi

Fri Jan 10 21:51:31 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.87.00    Driver Version: 418.87.00    CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  GeForce GTX 108...  Off  | 00000000:02:00.0 Off |                  N/A |
|  0%   26C    P8     9W / 280W |     10MiB / 11178MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   1  GeForce GTX 108...  Off  | 00000000:03:00.0 Off |                  N/A |
|  0%   27C    P8     8W / 280W |     10MiB / 11178MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
|   2  GeForce GTX 108...  Off  | 00000000:83:00.0 Off |                  N/A |
|  0%   

In [3]:
# Index of the GPU to use (cf. the nvidia-smi listing above).
cuda_device: int = 1

In [4]:
# Sequence-length bounds, in whitespace-separated tokens, for kept examples.
max_length = 20
min_length = 2
data_path = '../data/full.csv'

data = pd.read_csv(data_path)

# Keep only the text and label columns. Rows with a missing treatment string
# have no token count, so they are dropped up front.
data = data[['treatments', 'target']].dropna(subset=['treatments'])
treatment_len = data.treatments.str.split().str.len()
# `between` is inclusive on both ends: min_length <= len <= max_length.
# `.copy()` detaches the filtered frame, so adding columns to `data` later
# does not raise SettingWithCopyWarning.
data = data[treatment_len.between(min_length, max_length)].copy()

In [5]:
# Token count per treatment sequence, used below to bucket examples by length.
# `assign` returns a new frame. Unlike `data['seq_len'] = ...` on a filtered
# slice, it cannot trigger SettingWithCopyWarning.
data = data.assign(seq_len=data.treatments.apply(lambda x: len(x.split())))

# Split by label: target == 0 -> negative, target == 1 -> positive.
negative = data[data.target == 0]
positive = data[data.target == 1]

In [6]:
def _bucket_by_length(frame):
    """Map each `seq_len` value to the list of stripped `treatments` strings in `frame`."""
    buckets = defaultdict(list)
    for seq_len, treatments in zip(frame['seq_len'], frame['treatments']):
        buckets[seq_len].append(treatments.strip())
    return buckets


negative_examples = _bucket_by_length(negative)
positive_examples = _bucket_by_length(positive)

# Models

In [7]:
# Generation model: vocabulary directory and weight file on disk.
seq2seq_vocab_dir = 'vocab_seq2seq_masked'
seq2seq_weights_path = 'model_seq2seq_masked.th'

# The reader is built without a masker.
seq2seq_reader = OneLangSeq2SeqReader(masker=None)
seq2seq_vocab = Vocabulary.from_files(seq2seq_vocab_dir)
seq2seq_model = get_basic_seq2seq_model(seq2seq_vocab)
# NOTE(review): the model is not explicitly put in eval mode or moved to
# `cuda_device` here; confirm that load_weights / MCMCSampler handle that.
load_weights(seq2seq_model, seq2seq_weights_path)

In [8]:
# Classification model: vocabulary directory and weight file on disk.
class_vocab_dir = 'vocab_classification'
class_weights_path = 'model_classification.th'

class_reader = InsuranceReader()
class_vocab = Vocabulary.from_files(class_vocab_dir)
class_model = get_basic_classification_model(class_vocab)
# NOTE(review): the model is not explicitly put in eval mode or moved to
# `cuda_device` here; confirm that load_weights / MCMCSampler handle that.
load_weights(class_model, class_weights_path)

# MCMC

In [9]:
from adat.mcmc import MCMCSampler, NormalProposal
from pprint import pprint

In [16]:
# Starting sequence for the sampler: one positive example with 6 tokens.
example_len, example_idx = 6, 3
example = positive_examples[example_len][example_idx]
print(example)

a_1737 a_1690 a_1690 a_1737 a_2001 a_1667


In [17]:
# MCMC sampler wiring the proposal distribution, the classification
# model/reader and the generation model/reader, started from `example`.
proposal = NormalProposal()
sampler_config = dict(
    proposal_distribution=proposal,
    classification_model=class_model,
    classification_reader=class_reader,
    generation_model=seq2seq_model,
    generation_reader=seq2seq_reader,
    initial_sequence=example,
    l2_norm=False,
)
sampler = MCMCSampler(**sampler_config)

In [18]:
# MCMC sampling is stochastic. Seed every RNG it could plausibly draw from so
# that Restart & Run All can reproduce the same chain. Which RNG
# NormalProposal / MCMCSampler actually use is not visible here — TODO confirm.
seed = 42
num_steps = 1000

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the CPU and all CUDA devices

history = sampler.sample(num_steps=num_steps)



In [23]:
# Print only the sampler steps whose `prob_diff` exceeds 0.1,
# each followed by a blank line.
notable_steps = [step for step in history if step['prob_diff'] > 0.1]
for step in notable_steps:
    pprint(step)
    print()

{'bleu': 0.6900655593423543,
 'bleu_diff': -0.30993444065764575,
 'generated_sequence': 'a_1690 a_1690 a_1690 a_1737 a_2001 a_1667 a_1667',
 'l2_norm': 2.3845975,
 'prob': 0.22387059,
 'prob_diff': 0.113484874}

{'bleu': 0.816496580927726,
 'bleu_diff': -0.18350341907227397,
 'generated_sequence': 'a_1690 a_1690 a_1690 a_1737 a_2001 a_1667',
 'l2_norm': 1.9094037,
 'prob': 0.25439045,
 'prob_diff': 0.14400473}

{'bleu': 0.816496580927726,
 'bleu_diff': -0.18350341907227397,
 'generated_sequence': 'a_1690 a_1690 a_1690 a_1737 a_2001 a_1667',
 'l2_norm': 1.6861911,
 'prob': 0.25439045,
 'prob_diff': 0.14400473}

{'bleu': 0.816496580927726,
 'bleu_diff': 0.0,
 'generated_sequence': 'a_1731 a_1690 a_1690 a_1737 a_2001 a_1667',
 'l2_norm': 2.0001087,
 'prob': 0.120621905,
 'prob_diff': 0.101751}

{'bleu': 0.4364357804719847,
 'bleu_diff': -0.40871847425653185,
 'generated_sequence': 'a_375 a_1690 a_1690 a_1690 a_1737 a_1667 a_1667',
 'l2_norm': 1.906865,
 'prob': 0.25257355,
 'prob_diff': 0