Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasjinreal committed May 10, 2017
0 parents commit bae381a
Show file tree
Hide file tree
Showing 8 changed files with 142,018 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@
.idea/
__pycache__/
checkpoints/
60 changes: 60 additions & 0 deletions README.md
@@ -0,0 +1,60 @@
# PyTorch Seq2Seq Model

![PicName](http://ofwzcunzi.bkt.clouddn.com/EItTIwqpcrAsrexq.png)


> Aim to build a Marvelous ChatBot

# Synopsis

This is an open-source **ChatBot** project. I will continue to update this repo personally, aiming to build an intelligent ChatBot as the next version of Jarvis.

This repo will maintain to build a **Marvelous ChatBot** based on PyTorch,
welcome star and submit PR.

![PicName](http://ofwzcunzi.bkt.clouddn.com/FmLv0IpsmiMkLGiQ.png)
# Already Done

Currently this repo has done the following work:

* based on the official tutorial, this repo will move on to develop a seq2seq chatbot and QA system;
* re-constructed whole project, separate mess code into `data`, `model`, `train logic`;
* the model can be saved locally and reloaded from a previously saved directory, which is lacking in the official tutorial;
* just replace the dataset you can train your own data!

Last but not least, this project may be maintained in, or moved to, another repo in the future, but we will
continue implementing a practical seq2seq-based project to build anything you want: **Translation Machine**,
**ChatBot**, **QA System**... anything you want.


# Requirements

```
PyTorch
python3
Any Ubuntu version
Works on both CPU and GPU
```

# Usage

Before diving into this repo, you may want to glance at the whole structure. We have these components:

* `config`: contains the config params, which is global in this project, you can change a global param here;
* `datasets`: contains data and data_loader; to use your own dataset you should implement your own data_loader, but that requires only a little change to this one;
* `models`: contains seq2seq model definition;
* `utils`: this folder is very helpful; it contains code that may save you from annoying chores, such as saving a model, catching a KeyboardInterrupt exception, or loading a previous model — all of that can be done here.

to train model is also straightforward, just type:
```
python3 train.py
```

# Contribute

Welcome to submit PRs! Let's build the ChatBot together!

# Contact

If you have any questions, you can find me via WeChat: `jintianiloveu`
29 changes: 29 additions & 0 deletions config/global_config.py
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# file: global_config.py
# author: JinTian
# time: 10/05/2017 4:51 PM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
import torch
import os


# Filename prefix used when naming saved model checkpoint files.
MODEL_PREFIX = 'seq2seq_translate'
# Directory where training checkpoints are written to and restored from.
CHECKPOINT_DIR = './checkpoints'
# Sentence pairs longer than this many words are filtered out by the data loader.
MAX_LENGTH = 10


# Run tensors on the GPU when one is available.
use_cuda = torch.cuda.is_available()
# Probability of feeding the ground-truth token back in (teacher forcing) during training.
teacher_forcing_ratio = 0.5
147 changes: 147 additions & 0 deletions datasets/data_loader.py
@@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-
# file: data_loader.py
# author: JinTian
# time: 10/05/2017 6:27 PM
# Copyright 2017 JinTian. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""
this file load pair data into seq2seq model
raw file contains:
sequenceA sequenceB
....
"""
import torch
from torch.autograd import Variable
import math
import random
import re
import time
import unicodedata
from io import open
from config.global_config import *


class PairDataLoader(object):
    """Load a tab-separated bilingual corpus and serve (input, target) pairs.

    Raw file format (one pair per line):
        sequenceA<TAB>sequenceB

    On construction the eng-fra corpus is read, normalized, filtered by
    length and English prefix, and indexed into two ``Lang`` vocabularies
    (``self.input_lang`` / ``self.output_lang``); the surviving pairs are
    kept in ``self.pairs``.
    """

    def __init__(self):
        # Special token indices; must agree with Lang.index2word below.
        self.SOS_token = 0
        self.EOS_token = 1
        # Only keep pairs whose English side starts with one of these
        # prefixes — trims the corpus down to simple, trainable sentences.
        self.eng_prefixes = (
            "i am ", "i m ",
            "he is", "he s ",
            "she is", "she s",
            "you are", "you re ",
            "we are", "we re ",
            "they are", "they re "
        )

        self._prepare_data('eng', 'fra')

    class Lang(object):
        """Vocabulary for one language: word <-> index mappings plus counts."""

        def __init__(self, name):
            self.name = name
            self.word2index = {}
            self.word2count = {}
            self.index2word = {0: "SOS", 1: "EOS"}
            self.n_words = 2  # Count SOS and EOS

        def add_sentence(self, sentence):
            """Index every space-separated word of *sentence*."""
            for word in sentence.split(' '):
                self.add_word(word)

        def add_word(self, word):
            """Register *word* with the next free index, or bump its count."""
            if word not in self.word2index:
                self.word2index[word] = self.n_words
                self.word2count[word] = 1
                self.index2word[self.n_words] = word
                self.n_words += 1
            else:
                self.word2count[word] += 1

    def filter_pair(self, p):
        """Return True for pairs short enough (< MAX_LENGTH words per side)
        whose source sentence starts with one of ``eng_prefixes``."""
        return len(p[0].split(' ')) < MAX_LENGTH and \
            len(p[1].split(' ')) < MAX_LENGTH and \
            p[0].startswith(self.eng_prefixes)

    def filter_pairs(self, pairs):
        """Return only the pairs accepted by ``filter_pair``."""
        return [pair for pair in pairs if self.filter_pair(pair)]

    @staticmethod
    def unicode_to_ascii(s):
        """Strip accents: decompose to NFD, then drop combining marks (Mn)."""
        return ''.join(
            c for c in unicodedata.normalize('NFD', s)
            if unicodedata.category(c) != 'Mn'
        )

    def normalize_string(self, s):
        """Lowercase and de-accent *s*, space-separate .!? and collapse
        every other non-letter run to a single space."""
        s = self.unicode_to_ascii(s).lower().strip()
        s = re.sub(r"([.!?])", r" \1", s)
        s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
        return s

    def read_lang(self, lang1, lang2, reverse=False):
        """Read ./datasets/<lang1>-<lang2>.txt and return
        (input_lang, output_lang, normalized_pairs).

        With reverse=True each pair is flipped and the Lang roles swap.
        """
        print("Reading lines...")
        lines = open('./datasets/%s-%s.txt' % (lang1, lang2), encoding='utf-8'). \
            read().strip().split('\n')
        pairs = [[self.normalize_string(s) for s in l.split('\t')] for l in lines]
        if reverse:
            pairs = [list(reversed(p)) for p in pairs]
            input_lang = self.Lang(lang2)
            output_lang = self.Lang(lang1)
        else:
            input_lang = self.Lang(lang1)
            output_lang = self.Lang(lang2)

        return input_lang, output_lang, pairs

    @staticmethod
    def indexes_from_sentence(lang, sentence):
        """Map a sentence to its list of word indices (no EOS appended).

        Raises KeyError for out-of-vocabulary words.
        """
        return [lang.word2index[word] for word in sentence.split(' ')]

    def variable_from_sentence(self, lang, sentence):
        """Wrap a sentence's indices plus EOS in a (seq_len, 1) LongTensor
        Variable, moved to the GPU when ``use_cuda`` is set."""
        indexes = self.indexes_from_sentence(lang, sentence)
        indexes.append(self.EOS_token)
        result = Variable(torch.LongTensor(indexes).view(-1, 1))
        if use_cuda:
            return result.cuda()
        else:
            return result

    def _prepare_data(self, lang1, lang2, reverse=False):
        """Read, filter and index the corpus; populate ``self.pairs``,
        ``self.input_lang`` and ``self.output_lang``."""
        input_lang, output_lang, pairs = self.read_lang(lang1, lang2, reverse)
        print("Read %s sentence pairs" % len(pairs))
        self.pairs = self.filter_pairs(pairs)
        print("Trimmed to %s sentence pairs" % len(self.pairs))
        print("Counting words...")
        for pair in self.pairs:
            input_lang.add_sentence(pair[0])
            output_lang.add_sentence(pair[1])
        self.input_lang = input_lang
        self.output_lang = output_lang
        print("Counted words:")
        print(input_lang.name, input_lang.n_words)
        print(output_lang.name, output_lang.n_words)

    def get_pair_variable(self):
        """Sample ONE random pair and return its (input, target) variables.

        Bug fix: the original drew two independent random pairs, so the
        input and target sentences were almost always unrelated — the model
        was being trained on nonsense alignments. Both sides now come from
        the same sampled pair.
        """
        pair = random.choice(self.pairs)
        input_variable = self.variable_from_sentence(self.input_lang, pair[0])
        target_variable = self.variable_from_sentence(self.output_lang, pair[1])
        return input_variable, target_variable

0 comments on commit bae381a

Please sign in to comment.