In [1]:
import os
import numpy as np
import pandas as pd

In [None]:
valid_df = pd.read_csv('/data/users/kyle.shaffer/dialog_data/cornell_movie_dialog_no_context_valid_retok.txt', sep='\t',
                      names=['left', 'right', 'conv_id'])
print(valid_df.shape)
valid_df.head()

In [None]:
def get_contexts(input_df):
    all_convos = []
    for grp_ix, grp in input_df.groupby('conv_id'):
        left, right = grp.left.tolist(), grp.right.tolist()
        unrolled_convo = []
        for lzip, rzip in zip(left, right):
            if not(lzip in unrolled_convo):
                unrolled_convo.append(lzip)
            if not(rzip in unrolled_convo):
                unrolled_convo.append(rzip)
                
        all_convos.append(unrolled_convo)
            
    return all_convos

In [None]:
all_convos = get_contexts(valid_df)
print(len(all_convos))
all_convos[0]

In [None]:
all_convos[7]

In [None]:
from tqdm import tqdm

def batch_convos(convos, n=3):
    batched_convos = []
    pad = '<PAD>'
    for c in tqdm(convos):
        if len(c) < n:
            diff = n - len(c)
            max_len = max([len(u.split()) for u in c])
            padded_sent = [pad] * max_len
            for _ in range(diff):
                c.insert(0, padded_sent)
            batched_convos.append(c)
        elif len(c) == 3:
            batched_convos.append(c)
        else:
            for u_ix in range(len(c) - (n - 1)):
                batched_convos.append(c[u_ix: (u_ix + n)])
                
    return batched_convos

In [None]:
batched_convos = batch_convos(all_convos)
len(batched_convos)

In [None]:
batched_convos[0]

In [None]:
batched_convos[1]

In [None]:
batched_convos[7]

In [None]:
batched_convos[8]

In [None]:
batched_convos[9]

In [None]:
batched_convos[10]

In [None]:
batched_convos[28]

In [None]:
all_convos[7]

In [None]:
with open('/data/users/kyle.shaffer/dialog_data/cornell_movie_context_train.txt', mode='w') as outfile:
    for bc in batched_convos:
        if isinstance(bc[0], str):
            outfile.write(bc[0])
        else:
            outfile.write(' '.join(bc[0]))
        outfile.write('\t')
        outfile.write(bc[1])
        outfile.write('\t')
        outfile.write(bc[2])
        outfile.write('\n')

In [None]:
# Double-checking and writing out new vocab
from collections import Counter

# vocab_cnt = Counter()

for l, r in zip(valid_df.left.tolist(), valid_df.right.tolist()):
    vocab_cnt.update(l.split())
    vocab_cnt.update(r.split())
    
print('Vocab Size:', len(vocab_cnt))
vocab_cnt.most_common(10)

In [None]:
with open('/data/users/kyle.shaffer/dialog_data/cornell_movie_vocab.txt', mode='w') as outfile:
    for w, c in vocab_cnt.most_common():
        outfile.write(w)
        outfile.write('\t')
        outfile.write(str(c))
        outfile.write('\n')

## Testing Generator

In [5]:
default_movie_file = '/data/users/kyle.shaffer/dialog_data/cornell_movie/dialogs_text.txt'
data_path = "/data/users/kyle.shaffer/dialog_data/cornell_movie/cornell_movie_dialog_no_context_train.txt"

def load_movie_text(input_file:str):
    movie_lines = []
    with open(input_file, 'r') as infile:
        for line in infile:
            movie_lines.append(line.strip())

    return movie_lines

In [6]:
movie_lines = load_movie_text(default_movie_file)
print(len(movie_lines))

304444


In [7]:
import tensorflow_datasets as tfds

tgt_vocab_size = 15000
tokenizer = tfds.features.text.SubwordTextEncoder.build_from_corpus(
                            movie_lines, target_vocab_size=tgt_vocab_size)

In [8]:
print('Discovered vocab size:', tokenizer.vocab_size)

Discovered vocab size: 14912


In [3]:
class DataProcessor(object):
    def __init__(self, max_len:int, tokenizer, train_file:str, valid_file:str, batch_size:int):
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.train_file = train_file
        self.valid_file = valid_file
        self.bos = self.tokenizer.vocab_size
        self.eos = self.tokenizer.vocab_size + 1
        self.vocab_size = self.tokenizer.vocab_size + 2
        self.batch_size = batch_size
        
    def pad_batch(self, encoder_batch, decoder_batch):
        max_enc_length = self.max_len # max([len(s) for s in encoder_batch])
        max_dec_length = self.max_len # max([len(s) for s in decoder_batch])

        if max_enc_length > self.max_len:
            max_enc_length = self.max_len
        if max_dec_length > self.max_len:
            meax_dec_length = self.max_len

        enc_container, dec_in_container, dec_out_container = [], [], []
        for enc_seq, dec_seq in zip(encoder_batch, decoder_batch):
            if len(enc_seq) >= max_enc_length:
                enc_seq = enc_seq[:max_enc_length]
            enc_seq.insert(0, self.bos)
            enc_seq.append(self.eos)
            enc_container.append(enc_seq)

            if len(dec_seq) >= max_dec_length:
                dec_seq = dec_seq[:max_dec_length]
            dec_out_seq = dec_seq[:]
            dec_out_seq.append(self.eos)
            dec_seq.insert(0, self.bos)

            dec_in_container.append(dec_seq)
            dec_out_container.append(dec_out_seq)

        enc_padded = tf.keras.preprocessing.sequence.pad_sequences(enc_container, padding='post', maxlen=self.max_len)
        dec_in_padded = tf.keras.preprocessing.sequence.pad_sequences(dec_in_container, padding='post', maxlen=self.max_len)
        dec_out_padded = tf.keras.preprocessing.sequence.pad_sequences(dec_out_container, padding='post', maxlen=self.max_len)

        return enc_padded, dec_in_padded, dec_out_padded

    def get_line(self, data_file):
        with open(data_file, mode='r') as infile:
            for line in infile:
                context, response, _ = line.strip().split('\t')
                # context_bpe, response_bpe = self.tokenizer.encode(context), self.tokenizer.encode(response)
                # yield context_bpe, response_bpe
                if len(context.strip().split()) > self.max_len:
                    context = context.strip().split()[:self.max_len]
                if len(response.strip().split()) > self.max_len:
                    response = response.strip().split()[:self.max_len]
                yield context, response

    def batch_generator(self, mode:str='train'):
        assert mode in {'train', 'valid'}, "Please select as valid mode from: {train, valid}!"
        data_file = self.train_file if mode == 'train' else self.valid_file
        
        while True:
            encoder_batch, decoder_batch = [], []
            for context, resp in self.get_line(data_file):
                encoder_batch.append(context)
                decoder_batch.append(resp)
                if len(encoder_batch) == self.batch_size:
                    enc_padded, dec_in_padded, dec_out_padded = self.pad_batch(encoder_batch, decoder_batch)
                    yield [enc_padded, dec_in_padded], dec_out_padded

                    encoder_batch, decoder_batch = [], []

            # Check for non-empty batches
            if len(encoder_batch) > 0:
                enc_padded, dec_in_padded, dec_out_padded = self.pad_batch(encoder_batch, decoder_batch)
                yield [enc_padded, dec_in_padded], dec_out_padded
                
    def batch_gen_test(self, mode:str='train'):
        data_file = self.train_file if mode == 'train' else self.valid_file
        
        while True:
            encoder_batch, decoder_batch = [], []
            for context, resp in self.get_line(data_file):
                encoder_batch.append(context)
                decoder_batch.append(resp)
                if len(encoder_batch) == self.batch_size:
                    yield encoder_batch, decoder_batch
                    encoder_batch, decoder_batch = [], []

            # Check for non-empty batches
            if len(encoder_batch) > 0:
                yield [encoder_batch, decoder_batch], dec_out_padded
                

In [10]:
data_processor = DataProcessor(max_len=100, tokenizer=tokenizer, train_file=data_path,
                                valid_file=data_path, batch_size=10)
train_datagen = data_processor.batch_generator(mode='train')

In [11]:
for _ in range(5):
    c, r = next(data_processor.get_line(data_path))
    print(c)
    print(r)
    print()

Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.
Well, I thought we'd start with pronunciation, if that's okay with you.

Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.
Well, I thought we'd start with pronunciation, if that's okay with you.

Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.
Well, I thought we'd start with pronunciation, if that's okay with you.

Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.
Well, I thought we'd start with pronunciation, if that's okay with you.

Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.
Well, I thought we'd start with pronun

In [12]:
x, y = next(train_datagen)

AttributeError: 'str' object has no attribute 'insert'