-
Notifications
You must be signed in to change notification settings - Fork 5
/
ec_data_topic_utils.py
executable file
·81 lines (67 loc) · 3.28 KB
/
ec_data_topic_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# -*- coding: utf-8 -*-
# @Time : 2017/1/12 16:04
import math
import pickle
import six
class Data(object):
    """Hold pickled dataset splits plus vocabulary and cut them into batches.

    The constructor unpickles three files (dataset, vocabulary, pretrained
    object) and exposes their contents as attributes. The ``gen_*`` helpers
    slice one split (a dict of array-like fields that support ``.shape`` and
    slice indexing) into consecutive mini-batches of ``batch_size`` rows.
    """

    def __init__(self, data_path, vocab_path, pretrained, batch_size):
        """Load all resources from disk.

        Args:
            data_path: path of a pickle holding a dict with the keys
                'train', 'valid', 'test', 'train2', 'valid2' and 'test2'.
            vocab_path: path of a pickle holding a dict with the keys
                'word2id', 'max_sent_len' and 'max_topic_len'.
            pretrained: path of a pickle whose content is stored unchanged
                in ``self.pretrained``.
            batch_size: number of rows per batch used by the ``gen_*``
                methods.
        """
        self.batch_size = batch_size
        data, vocab, pretrained_obj = self.load_vocab_data(
            data_path, vocab_path, pretrained)

        self.train = data['train']
        self.valid = data['valid']
        self.test = data['test']
        self.train2 = data['train2']
        self.valid2 = data['valid2']
        self.test2 = data['test2']

        word2id = vocab['word2id']
        # One more than the number of known words.
        # NOTE(review): presumably reserves an extra id (e.g. padding) --
        # confirm against the code that built the vocabulary.
        self.word_size = len(word2id) + 1
        self.max_sent_len = vocab['max_sent_len']
        self.max_topic_len = vocab['max_topic_len']
        self.word2id = word2id
        # Reverse mapping; ``.items()`` works on both Python 2 and 3.
        self.id2word = {idx: word for word, idx in word2id.items()}
        self.pretrained = pretrained_obj

    def _batch_rows(self, data, i):
        """Return the row slice of batch ``i``, clamped to the data size."""
        data_size = data['input'].shape[0]
        begin = i * self.batch_size
        end = min((i + 1) * self.batch_size, data_size)
        return slice(begin, end)

    @staticmethod
    def _common_fields(data, rows):
        """Slice the fields shared by both batch generators.

        Returned in the order (input, target, weight, length, topic_input,
        topic_weight, topic_length).
        """
        # NOTE: an earlier variant used data['label'] as the target; the
        # original author's comment says that cost at least 3 points with
        # a sigmoid.
        return (
            data['input'][rows],
            data['target'][rows],
            data['weight'][rows],
            data['length'][rows],
            data['topic_input'][rows],
            data['topic_weight'][rows],
            data['topic_length'][rows],
        )

    def gen_batch(self, data, i):
        """Return batch ``i`` of ``data`` including its labels.

        Returns:
            Tuple ``(input, target, label, weight, length, topic_input,
            topic_weight, topic_length)``. ``label`` is ``data['label']``
            for the batch rows with its last column dropped.
        """
        rows = self._batch_rows(data, i)
        (inputs, target, weight, length,
         topic_input, topic_weight, topic_length) = self._common_fields(
             data, rows)
        label = data['label'][rows, :-1]
        return (inputs, target, label, weight, length,
                topic_input, topic_weight, topic_length)

    def gen_sent_batch(self, data, i):
        """Return batch ``i`` of ``data`` without labels.

        Unlike ``gen_batch`` this never reads ``data['label']``.

        Returns:
            Tuple ``(input, target, weight, length, topic_input,
            topic_weight, topic_length)``.
        """
        return self._common_fields(data, self._batch_rows(data, i))

    def load_vocab_data(self, data_path, vocab_path, pretrained):
        """Unpickle the dataset, vocabulary and pretrained files.

        Args:
            data_path: path of the dataset pickle.
            vocab_path: path of the vocabulary pickle.
            pretrained: path of the pretrained pickle.

        Returns:
            Tuple ``(data, vocab, pretrained_obj)`` with the unpickled
            objects.
        """
        # SECURITY: pickle.load can execute arbitrary code; only point these
        # paths at files from a trusted source.
        with open(data_path, 'rb') as fdata, open(vocab_path, 'rb') as fvocab:
            data = pickle.load(fdata)
            vocab = pickle.load(fvocab)
        with open(pretrained, 'rb') as fin:
            pretrained_obj = pickle.load(fin)
        return data, vocab, pretrained_obj

    def gen_batch_num(self, data):
        """Return how many batches cover ``data`` (last one may be short)."""
        data_size = data['input'].shape[0]
        # float() keeps true division on Python 2 as well.
        return int(math.ceil(data_size / float(self.batch_size)))