load/save functionality

commit 56f65244dc36e7f0063a901d40e7d2d76cda836f (1 parent: 64ff5f2)
authored by @andersjo
rungsted/feat_map.pyx (1 change)
@@ -88,6 +88,7 @@ cdef class DictFeatMap(FeatMap):
         def __set__(self, value):
             self.feat2index = value
+            self.next_i = len(value)
     def __init__(self, int n_labels):
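
Note on the one-line change above: when a saved feature dictionary is assigned back onto a fresh DictFeatMap, next_i must also be moved past the restored entries, otherwise newly seen features would be numbered from zero and collide with the loaded ones. A minimal pure-Python sketch of that invariant (class and method names here are illustrative, not the Cython implementation):

class DictFeatMapSketch(object):
    def __init__(self):
        self.feat2index = {}
        self.next_i = 0

    def feat_i(self, feat):
        # Look up the index of a feature, adding it if unseen.
        i = self.feat2index.get(feat)
        if i is None:
            i = self.feat2index[feat] = self.next_i
            self.next_i += 1
        return i

    def set_feat2index(self, value):
        # Mirrors the __set__ above: restore the dictionary and make sure
        # fresh features are numbered after the restored ones.
        self.feat2index = value
        self.next_i = len(value)
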
rungsted/runner.py (63 changes)
@@ -2,8 +2,10 @@
 import argparse
 import logging
 import random
+import cPickle
 import numpy as np
 import sys
+from os.path import exists
 from feat_map import HashingFeatMap, DictFeatMap
 from input import read_vw_seq
@@ -25,7 +27,8 @@
                     default=10, type=int)
 parser.add_argument('--shuffle', help="Shuffle examples after each iteration", action='store_true')
 parser.add_argument('--average', help="Average over all updates", action='store_true')
-
+parser.add_argument('--initial-model', '-i', help="Initial model from this file")
+parser.add_argument('--final-model', '-f', help="Save model here after training")
 args = parser.parse_args()
@@ -37,39 +40,56 @@
 else:
     feat_map = DictFeatMap(args.n_labels)
-train = read_vw_seq(args.train, args.n_labels, ignore=args.ignore, feat_map=feat_map)
-logging.info("Training data {} sentences".format(len(train)))
+if args.initial_model:
+    if not args.hash_bits and exists(args.initial_model + ".features"):
+        feat_map.feat2index_ = cPickle.load(open(args.initial_model + ".features"))
+
+train = None
+if args.train:
+    train = read_vw_seq(args.train, args.n_labels, ignore=args.ignore, feat_map=feat_map)
+    logging.info("Training data {} sentences".format(len(train)))
+
 # Prevents the addition of new features when loading the test set
 feat_map.freeze()
 test = read_vw_seq(args.test, args.n_labels, ignore=args.ignore, feat_map=feat_map)
 logging.info("Test data {} sentences".format(len(test)))
+logging.info("Weight vector size {}".format(feat_map.n_feats()))
+# Loading weights
 w = Weights(n_labels, feat_map.n_feats())
-logging.info("Weight vector size {}".format(feat_map.n_feats()))
+if args.initial_model:
+    w.load(open(args.initial_model))
+
+if not args.hash_bits and args.final_model:
+    cPickle.dump(feat_map.feat2index_, open(args.final_model + ".features", 'w'), protocol=2)
 n_updates = 0
-# Learning loop
-for epoch in range(1, args.passes+1):
-    learning_rate = 0.1 if epoch < args.decay_delay else epoch**args.decay_exp * 0.1
-    if args.shuffle:
-        random.shuffle(train)
-    for sent in train:
-        flattened_labels = [e.flat_label() for e in sent]
-        gold_seq = np.array(flattened_labels, dtype=np.int32)
-        pred_seq = np.array(viterbi(sent, n_labels, w, feat_map), dtype=np.int32)
-        assert len(gold_seq) == len(pred_seq)
-        update_weights(pred_seq, gold_seq, sent, w, n_updates, learning_rate, n_labels, feat_map)
-        n_updates += 1
-        if n_updates % 1000 == 0:
-            print >>sys.stderr, '\r{} k sentences total'.format(n_updates / 1000),
+# Training loop
+if args.train:
+    for epoch in range(1, args.passes+1):
+        learning_rate = 0.1 if epoch < args.decay_delay else epoch**args.decay_exp * 0.1
+        if args.shuffle:
+            random.shuffle(train)
+        for sent in train:
+            flattened_labels = [e.flat_label() for e in sent]
+            gold_seq = np.array(flattened_labels, dtype=np.int32)
+            pred_seq = np.array(viterbi(sent, n_labels, w, feat_map), dtype=np.int32)
+            assert len(gold_seq) == len(pred_seq)
+            update_weights(pred_seq, gold_seq, sent, w, n_updates, learning_rate, n_labels, feat_map)
+            n_updates += 1
+            if n_updates % 1000 == 0:
+                print >>sys.stderr, '\r{} k sentences total'.format(n_updates / 1000),
+    if args.average:
+        w.average_weights(n_updates)
+# Testing
 y_gold = []
 y_pred = []
@@ -77,9 +97,6 @@
 if args.predictions:
     out = open(args.predictions, 'w')
-if args.average:
-    w.average_weights(n_updates)
-
 for sent in test:
     y_pred_sent = viterbi(sent, n_labels, w, feat_map)
     y_gold += [e.flat_label() for e in sent]
@@ -96,3 +113,7 @@
 print >>sys.stderr, ''
 logging.info("Accuracy: {:.3f}".format(accuracy))
+
+# Save model
+if args.final_model:
+    w.save(open(args.final_model, 'w'))
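
A rough sketch of the feature-dictionary side of the persistence wired up above. The model path "tagger.model" and the helper names are hypothetical; only the ".features" suffix and the pickle protocol follow the diff:

import cPickle  # Python 2, as used in runner.py
from os.path import exists

def dump_features(model_path, feat2index):
    # Pickle the feature dictionary next to the model file, e.g. tagger.model.features.
    with open(model_path + ".features", 'wb') as f:
        cPickle.dump(feat2index, f, protocol=2)

def load_features(model_path):
    # Return the saved dictionary, or None when no dictionary was written
    # (e.g. when a hashing feature map was used and there is nothing to restore).
    feature_file = model_path + ".features"
    if not exists(feature_file):
        return None
    with open(feature_file, 'rb') as f:
        return cPickle.load(f)

Storing the dictionary separately keeps the weight file a plain NumPy archive, and the `not args.hash_bits` guard in the diff lets hashed feature maps skip this step entirely.
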
rungsted/struct_perceptron.pyx (12 changes)
@@ -68,6 +68,16 @@ cdef class Weights:
         self.t[label_i, label_j] += val
+    def load(self, file):
+        with np.load(file) as npz_file:
+            assert self.e.shape[0] == npz_file['e'].shape[0]
+            assert self.t.shape[0] == npz_file['t'].shape[0]
+            assert self.t.shape[1] == npz_file['t'].shape[1]
+            self.e = npz_file['e']
+            self.t = npz_file['t']
+
+    def save(self, file):
+        np.savez(file, e=self.e, t=self.t)
 def update_weights(int[:] pred_seq, int[:] gold_seq, list sent, Weights w, int n_updates, double alpha, int n_labels,
                    FeatMap feat_map):
@@ -128,7 +138,7 @@ def viterbi(list sent, int n_labels, Weights w, FeatMap feat_map):
     # Find best sequence from the trellis
     best_seq = [np.asarray(trellis)[-1].argmax()]
     for word_i in reversed(range(1, len(path))):
-        best_seq.append(path[word_i, best_seq[-1]])
+        best_seq.append(path[word_i, <int> best_seq[-1]])
     return [label + 1 for label in reversed(best_seq)]
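
The new load()/save() pair is a thin wrapper around NumPy's .npz archives. A standalone round-trip sketch of the same pattern, with illustrative array names and shapes rather than the repository's API:

import numpy as np

e = np.zeros(1000)        # emission weights, flattened
t = np.zeros((10, 10))    # label-to-label transition weights

# save: both arrays go into a single .npz archive
with open("weights.npz", "wb") as f:
    np.savez(f, e=e, t=t)

# load: check that the stored shapes match before overwriting anything
with np.load("weights.npz") as npz_file:
    assert e.shape[0] == npz_file['e'].shape[0]
    assert t.shape == npz_file['t'].shape
    e, t = npz_file['e'], npz_file['t']
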