lstm update
coffee-cup committed Apr 8, 2018
1 parent e043b38 commit d1b2158
Showing 4 changed files with 158 additions and 25 deletions.
98 changes: 77 additions & 21 deletions lstm.py
@@ -2,11 +2,13 @@
import random

import numpy as np
import pandas as pd
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
@@ -37,7 +39,7 @@ def __init__(self, config, embedding_dim, hidden_dim, label_size):
self.label_size = label_size
self.hidden_dim = hidden_dim
# self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
self.lstm = nn.LSTM(embedding_dim, hidden_dim, dropout=0.1)
self.lstm = nn.LSTM(embedding_dim, hidden_dim, dropout=0.2)
self.hidden2label = nn.Linear(hidden_dim, label_size)
self.hidden = self.init_hidden()

@@ -60,15 +62,6 @@ def forward(self, embeds):
return log_probs


def get_accuracy(truth, pred):
assert len(truth) == len(pred)
right = 0
for i in range(len(truth)):
if truth[i] == pred[i]:
right += 1.0
return right / len(truth)


def train_epoch(model, dataloader, loss_fn, optimizer, epoch):
'''Train a single epoch.'''
model.train()
@@ -102,8 +95,8 @@ def train_epoch(model, dataloader, loss_fn, optimizer, epoch):
count += 1

if count % 100 == 0:
print('\tBatch: {} Iteration: {} Loss: {}'.format(
epoch, count, loss.data[0]))
            print('\tIteration: {} Loss: {}'.format(count, loss.data[0]))

loss.backward()
optimizer.step()
@@ -137,14 +130,50 @@ def evaluate(model, dataloader):
correct += (predict.data.numpy() == labels.data.numpy()).sum()
total_samples += labels.size()[0]

truth_res += labels.data.numpy().tolist()
pred_res += predict.data.numpy().tolist()

acc = correct / total_samples
return acc
metrics = precision_recall_fscore_support(
truth_res, pred_res, average='micro')
return acc, metrics


def evenly_distribute(X, y):
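    '''Undersample so that both classes appear an equal number of times.'''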
counts = [0, 0]

for l in y:
counts[l[0]] += 1

new_X = []
new_y = []
min_count = min(counts[0], counts[1])
print('Min sample count: {}'.format(min_count))

new_counts = [0, 0]
for i in range(0, len(X)):
l = y[i][0]
if new_counts[l] <= min_count:
new_X.append(X[i])
new_y.append(y[i])
new_counts[l] += 1

if new_counts[0] >= min_count and new_counts[1] >= min_count:
break

return new_X, new_y


def lstm(config, embedding_data, code):
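    '''Train an LSTM classifier for one MBTI axis (code) and pickle the best model and metrics.'''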
X = [row[0] for row in embedding_data]
y = [row[1] for row in embedding_data]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Evenly distribute across classes
X, y = evenly_distribute(X, y)

print('Total samples: {}'.format(len(X)))

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2)

train_dataset = MbtiDataset(X_train, y_train)
train_dataloader = DataLoader(
@@ -153,7 +182,7 @@ def lstm(config, embedding_data, code):
shuffle=True,
num_workers=4)

test_dataset = MbtiDataset(X_test, y_test)
test_dataset = MbtiDataset(X_val, y_val)
test_dataloader = DataLoader(
test_dataset,
batch_size=config.batch_size,
@@ -173,13 +202,14 @@ def lstm(config, embedding_data, code):

parameters = filter(lambda p: p.requires_grad, model.parameters())
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(parameters, lr=1e-4)
optimizer = optim.Adam(parameters, lr=1e-3)

losses = []
train_accs = []
test_accs = []

best_model = None
best_metrics = None
for i in range(config.epochs):
avg_loss = 0.0

@@ -188,18 +218,20 @@ def lstm(config, embedding_data, code):
losses.append(train_loss)
train_accs.append(train_acc)

acc = evaluate(model, test_dataloader)
acc, metrics = evaluate(model, test_dataloader)
test_accs.append(acc)

print('Epoch #{} Test Acc: {:.2f}%'.format(i, acc * 100))
print('Epoch #{} Val Acc: {:.2f}%'.format(i, acc * 100))
print('')

if acc > best_acc:
best_acc = acc
best_model = model.state_dict()
best_metrics = metrics

save_data = {
'best_acc': best_acc,
'best_metrics': best_metrics,
'losses': losses,
'train_accs': train_accs,
'test_accs': test_accs,
@@ -218,9 +250,33 @@ def lstm(config, embedding_data, code):
pickle.dump(save_data, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_model(config, code):
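    '''Load a trained LSTMClassifier for the given axis from saves/{code}_model.'''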
model_file = 'saves/{}_model'.format(code)
model = LSTMClassifier(
config,
embedding_dim=config.feature_size,
hidden_dim=128,
label_size=2)
model.load_state_dict(torch.load(model_file))
return model


if __name__ == '__main__':
config = get_config()
pre_data = pd.read_csv(config.pre_save_file).values
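    # Hold out the last 10% of the preprocessed rows as a final test set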
split = int(len(pre_data) * 0.9)

trainval = pre_data[:split]
test = pre_data[split:]

# Save trainval and test datasets
with open('trainval_set', 'wb') as f:
pickle.dump(trainval, f, protocol=pickle.HIGHEST_PROTOCOL)

with open('test_set', 'wb') as f:
pickle.dump(test, f, protocol=pickle.HIGHEST_PROTOCOL)

code = THIRD
embedding_data = word2vec(config, code=code, batch=False)
lstm(config, embedding_data, code)
for code in [FIRST, SECOND, THIRD, FOURTH]:
embedding_data = word2vec(
config, code=code, batch=False, pre_data=trainval)
lstm(config, embedding_data, code)
76 changes: 76 additions & 0 deletions lstm_predict.py
@@ -0,0 +1,76 @@
import os
import pickle
import sys

import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

from config import parse_config
from lstm import LSTMClassifier, MbtiDataset
from preprocess import preprocess_text
from utils import FIRST, FOURTH, SECOND, THIRD, codes, get_char_for_binary
from word2vec import load_word2vec, word2vec


def np_sentence_to_list(L_sent):
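    '''Convert a sequence of numpy word vectors into a nested list of floats.'''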
newsent = []
    for sentence in L_sent:
        temp = []
        for word in sentence:
temp.append(word.tolist())
newsent.append(temp)
return newsent


def load_model(config, code):
model_file = 'saves/{}_model'.format(code)
model = LSTMClassifier(
config,
embedding_dim=config.feature_size,
hidden_dim=128,
label_size=2)
model.load_state_dict(torch.load(model_file))
return model


def predict(config, text, code, model=None):
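    '''Predict the personality letter for one MBTI axis from raw text.'''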
if model is None:
model = load_model(config, code)

preprocessed = preprocess_text(text)

word_model = load_word2vec(config.embeddings_model)
    embedding = []
for word in preprocessed.split(' '):
if word in word_model.wv.index2word:
vec = word_model.wv[word]
embedding.append(vec)

input = Variable(torch.Tensor(np_sentence_to_list(embedding)))

pred = model(input)
pred_label = pred.data.max(1)[1].numpy()[0]
pred_char = get_char_for_binary(code, pred_label)
return pred_char


if __name__ == '__main__':
    config = parse_config()

if sys.stdin.isatty():
text = raw_input('Enter some text: ')
else:
text = sys.stdin.read()

personality = ''
codes = [FIRST, SECOND, THIRD, FOURTH]
for code in codes:
        personality += predict(config, text, code)

print('Prediction is {}'.format(personality))
2 changes: 1 addition & 1 deletion preprocess.py
@@ -24,7 +24,7 @@

def filter_text(post):
"""Decide whether or not we want to use the post."""
return len(post) > 5
return len(post.split(' ')) >= 7


def preprocess_text(post):
7 changes: 4 additions & 3 deletions word2vec.py
@@ -6,7 +6,7 @@

import numpy as np
import pandas as pd
from gensim.models import Word2Vec, word2vec
from gensim.models import Word2Vec

from tqdm import trange
from utils import (ALL, FIRST, FOURTH, SECOND, THIRD, get_binary_for_code,
@@ -141,15 +141,16 @@ def get_one_hot_data(embedding_data):
return newdata


def word2vec(config, code=ALL, batch=True):
def word2vec(config, code=ALL, batch=True, pre_data=None):
"""Create word2vec embeddings
:config user configuration
"""
print('\n--- Creating word embeddings')

if pre_data is None:
pre_data = pd.read_csv(config.pre_save_file).values
embedding_data = None
pre_data = pd.read_csv(config.pre_save_file).values
if os.path.isfile(config.embeddings_model) and not config.force_word2vec:
# Load model from file
model = load_word2vec(config.embeddings_model)
