In [1]:
import os
import sys
import time
import pickle
from collections import defaultdict

import re
import numpy as np

import tensorflow as tf
sess_opt = tf.ConfigProto(gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.95 , allow_growth=True)
                         ,allow_soft_placement=True
#                          ,log_device_placement=True
                         ,device_count={'GPU': 1})

data_root = "data/cornell_movie-dialogs/"

In [2]:
class MovieLine:
    def __init__(self,p0,p1,dialog):
        self.p0 = p0
        self.p1 = p1
        self.dia = None
        self.__process(dialog)
            
    def __process(self,d):
        d = re.sub(r'[\"\'+-=$#\^\\@*&><]' , '' , d)
        for k in [ ',' , '.' , '?' , ';' , ':' ]:
            d = re.sub(r'[{0}] '.format(k) , ' {0} '.format(k) , d)
        self.dia = d.strip()

In [6]:
class Tokenizer:
    def __init__(self , data):
        self.Word2Idx = {}
        self.Idx2Word = ['<NULL>' , '<OOV>' , '<SOS>' , '<EOS>']
        self.WordFreq = defaultdict(lambda:0)
        self.count(data)
        
    def count(self , data):
        for ss in data:
            x = ss[0].split()
            for s in x:
                self.WordFreq[s] += 1
            x = ss[1].split()
            for s in x:
                self.WordFreq[s] += 1
        self.Idx2Word.extend(list(self.WordFreq.keys()))
        self.Word2Idx = dict([(k,i) for i,k in enumerate(self.Idx2Word)])
        print("Total Words :" , len(self.Idx2Word))
        counts = list(self.WordFreq.values())
        for p in range(10,100,10):
            print("PR {0:>3d} count : {1:>4.0f}".format(p , np.percentile(counts,float(p))))
    
    def transform(self , data , min_count=5):
        xIdxList = []
        yIdxList = []
        x_MaxLen = 0
        y_MaxLen = 0
        for ss in data:
            x = ss[0].split()
            tmp = []
            for s in x:
                if(self.WordFreq[s] < min_count):
                    tmp.append(self.Word2Idx['<OOV>'])
                else:
                    tmp.append(self.Word2Idx[s])
            tmp.append(self.Word2Idx['<EOS>'])
            l = len(tmp)
            if(l > x_MaxLen):
                x_MaxLen = l
            xIdxList.append(tmp)
            
            y = ss[1].split()
            tmp = []
            for s in y:
                if(self.WordFreq[s] < min_count):
                    tmp.append(self.Word2Idx['<OOV>'])
                else:
                    tmp.append(self.Word2Idx[s])
            tmp.append(self.Word2Idx['<EOS>'])
            l = len(tmp)
            if(l > y_MaxLen):
                y_MaxLen = l
            yIdxList.append(tmp)
        
        return xIdxList , yIdxList , x_MaxLen , y_MaxLen
    
    def inverse(self,idxArr):
        tmp = []
        for i in idxArr:
            tmp.append(self.Idx2Word[i])
        return tmp

In [3]:
def get_movie_line(root,MovieLineDict):
    with open(os.path.join(root, "movie_lines.txt") , "r" , encoding="utf-8" , errors='ignore') as f:
        x = f.readline()
        c = 0
        while(x!=""):
            l = x.split(" +++$+++ ")
            c+=1
            MovieLineDict[l[0]] = MovieLine(l[1],l[2],l[4])
            x = f.readline()
            if c%1000 == 0:
                print(c,end="\r",flush=True)

MovieLineDict = {}
get_movie_line(data_root,MovieLineDict)

304000

In [4]:
def get_dialog(root,myList,LineDict):
    with open(os.path.join(root, "movie_conversations.txt") , "r" , encoding='utf-8' )as f:
        x = f.readline()
        c = 0
        while( x!="" ):
            l = x.split(" +++$+++ ")
            p0 = l[0]
            p1 = l[1]
            conversation = l[3].replace('[','').replace(']','').replace('\'','')
            conversation = conversation.strip().split(", ")
            if(len(conversation) <= 1):
                continue
            try:
                c+=1
                convList = [LineDict[k].dia for k in conversation]
                all_conversation = zip(convList[0:-1],convList[1::])
                myList.extend(all_conversation)
                x = f.readline()
            except KeyError:
                x = f.readline()
                continue
            if( c%100 == 0 ):
                print(c , end="\r")
                
dialogList = []
get_dialog(data_root,dialogList,MovieLineDict)

83000

In [7]:
myTok = Tokenizer(dialogList)

Total Words : 98818
PR  10 count :    1
PR  20 count :    1
PR  30 count :    1
PR  40 count :    2
PR  50 count :    2
PR  60 count :    3
PR  70 count :    4
PR  80 count :    8
PR  90 count :   20


### Build Seq2Seq Model