-
Notifications
You must be signed in to change notification settings - Fork 101
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c9ec5b8
commit 9311763
Showing
4 changed files
with
181 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from utils import Utils | ||
from collections import Counter | ||
from sequences.rowReader import * | ||
|
||
class Metal():
    """Multi-task model holding one Layers object per task.

    ``model[0]`` contains the Layers shared between all tasks (if any);
    ``model[tid + 1]`` contains the Layers specific to task ``tid``.
    """

    def __init__(self, taskManager, parameters, modelOpt):
        """Store the collaborators and either reuse or build the model.

        Args:
            taskManager: provides ``taskCount``, ``indices`` and ``tasks``.
            parameters: passed through unchanged to every ``Layers``.
            modelOpt: an existing model to reuse; when falsy a new model is
                built with :meth:`initialize`.
        """
        # Both must be stored *before* initialize() runs, because
        # initialize() and mkVocabularies() read them from self.
        self.taskManager = taskManager
        self.parameters = parameters

        if modelOpt:
            self.model = modelOpt
        else:
            self.model = self.initialize()

    def initialize(self):
        """Build the shared Layers plus one Layers per task.

        Returns:
            A list of ``taskCount + 1`` Layers objects; index 0 is the
            shared one, index ``tid + 1`` belongs to task ``tid``.
        """
        # NOTE(review): `Layers` is not imported by name in the visible
        # imports - confirm which module is expected to provide it.
        taskManager = self.taskManager
        parameters = self.parameters

        taskWords, taskLabels = self.mkVocabularies()

        layersPerTask = [None for _ in range(taskManager.taskCount + 1)]

        # The shared Layers have no labels of their own.
        layersPerTask[0] = Layers(
            taskManager, "mtl.layers", parameters, taskWords[0], None,
            isDual=False, providedInputSize=None)

        # Task-specific Layers take the shared Layers' output dimension
        # as their provided input size.
        inputSize = layersPerTask[0].outDim

        for i in taskManager.indices:
            layersPerTask[i + 1] = Layers(
                taskManager, "mtl.task{}.layers".format(i + 1), parameters,
                taskWords[i + 1], taskLabels[i + 1],
                isDual=taskManager.tasks[i].isDual,
                providedInputSize=inputSize)

        for i, layers in enumerate(layersPerTask):
            print("Summary of layersPerTask({}):".format(i))
            print(layers)

        return layersPerTask

    def mkVocabularies(self):
        """Count the words and labels seen in every task's training data.

        Returns:
            A ``(words, labels)`` pair of lists of ``Counter``, each of
            length ``taskCount + 1``. Index 0 is reserved for the shared
            Layers; ``tid + 1`` corresponds to each task.
        """
        taskManager = self.taskManager

        # index 0 reserved for the shared Layers; tid + 1 corresponds to each task
        labels = [Counter() for _ in range(taskManager.taskCount + 1)]
        # labels[0] not used, since only task-specific layers have a final layer
        for i in range(1, len(labels)):
            labels[i][Utils.START_TAG] += 1
            labels[i][Utils.STOP_TAG] += 1

        words = [Counter() for _ in range(taskManager.taskCount + 1)]

        reader = MetalRowReader()

        for tid in taskManager.indices:
            for sentence in taskManager.tasks[tid].trainSentences:
                annotatedSentences = reader.toAnnotatedSentences(sentence)

                for annotatedSentence, sentenceLabels in annotatedSentences:
                    # Walk the words by position so each word is paired
                    # with the label at the same index.
                    for i, word in enumerate(annotatedSentence.words):
                        words[tid + 1][word] += 1
                        words[0][word] += 1
                        labels[tid + 1][sentenceLabels[i]] += 1

        return words, labels
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
|
||
class Utils:
    """Namespace of shared constants.

    The constants are class attributes so they can be read either as
    ``Utils.START_TAG`` or through an instance (``Utils().START_TAG``).
    """

    concatenateCount = 0

    UNK_WORD = "<UNK>"
    EOS_WORD = "<EOS>"

    UNK_EMBEDDING = 0

    START_TAG = "<START>"
    STOP_TAG = "<STOP>"

    # used for both DyNet, and the JVM seed for shuffling data
    # (plain int: the Python 2 ``L`` suffix is a syntax error in Python 3)
    RANDOM_SEED = 2522620396
    WEIGHT_DECAY = 1e-5

    LOG_MIN_VALUE = -10000.0

    DEFAULT_DROPOUT_PROBABILITY = 0.0  # no dropout by default

    IS_DYNET_INITIALIZED = False

    def __init__(self):
        # Each instance still starts with its own counter at zero.
        self.concatenateCount = 0
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
|
||
class AnnotatedSentence:
    """A sentence's words plus optional per-word annotations."""

    def __init__(self, words, posTags = None, neTags = None, headPositions = None):
        """Store the words and annotations and derive size and indices.

        Args:
            words: sequence of words; its length defines ``size``.
            posTags: optional annotation stored as given.
            neTags: optional annotation stored as given.
            headPositions: optional annotation stored as given.
        """
        self.words = words
        self.posTags = posTags
        self.neTags = neTags
        self.headPositions = headPositions
        self.size = len(words)
        self.indices = range(self.size)
        # Backward-compatible alias for the original misspelled name.
        self.indicies = self.indices
|
||
class RowReader(object):
    """Base interface for row readers; both methods must be overridden."""

    def __init__(self):
        # The base class cannot be instantiated directly.
        raise NotImplementedError

    def toAnnotatedSentences(self, rows):
        """Convert ``rows`` into annotated sentences; subclasses implement this."""
        raise NotImplementedError
|
||
class MetalRowReader(RowReader):
    """Row reader for the Metal format (2, 4, or 5+ columns per row).

    Every parser returns a list of ``(AnnotatedSentence, labels)`` pairs.
    Rows are read with ``row.get(position)`` and measured with ``len(row)``.
    """

    def __init__(self):
        # Deliberately does not call RowReader.__init__, which raises.
        self.WORD_POSITION = 0
        self.POS_TAG_POSITION = 1
        self.NE_LABEL_POSITION = 2
        self.LABEL_START_OFFSET = 3

    def toAnnotatedSentences(self, rows):
        """Pick a parser from the column count of the first row.

        Args:
            rows: non-empty sequence of rows for one sentence.

        Returns:
            A list of ``(AnnotatedSentence, labels)`` pairs.

        Raises:
            RuntimeError: if the first row has 0, 1 or 3 columns.
        """
        numColumns = len(rows[0])
        if numColumns == 2:
            return self.parseSimple(rows)
        elif numColumns == 4:
            return self.parseSimpleExtended(rows)
        elif numColumns >= 5:
            return self.parseFull(rows)
        else:
            raise RuntimeError("ERROR: the Metal format expects 2, 4, or 5+ columns!")

    # Parser for the simple format: word, label
    def parseSimple(self, rows):
        """Parse two-column rows into a single (sentence, labels) pair."""
        assert len(rows[0]) == 2
        words = list()
        labels = list()

        for row in rows:
            words.append(row.get(self.WORD_POSITION))
            labels.append(row.get(self.WORD_POSITION + 1))

        # One-element list, so the result has the same shape as parseFull's.
        return [(AnnotatedSentence(words), labels)]

    # Parser for the simple extended format: word, POS tag, NE label, label
    def parseSimpleExtended(self, rows):
        """Parse four-column rows into a single (sentence, labels) pair."""
        assert len(rows[0]) == 4
        words = list()
        posTags = list()
        neLabels = list()
        labels = list()

        for row in rows:
            words.append(row.get(self.WORD_POSITION))
            posTags.append(row.get(self.POS_TAG_POSITION))
            neLabels.append(row.get(self.NE_LABEL_POSITION))
            labels.append(row.get(self.LABEL_START_OFFSET))

        # POS tags and NE labels belong to the sentence; only the task
        # labels are returned next to it.
        return [(AnnotatedSentence(words, posTags, neLabels), labels)]

    # Parser for the full format: word, POS tag, NE label, (label head)+
    def parseFull(self, rows):
        """Parse rows with one or more (label, head) column pairs.

        Returns:
            One ``(AnnotatedSentence, labels)`` pair per (label, head)
            column pair. All sentences share the same words, POS tags and
            NE labels and differ in their head positions.

        Raises:
            RuntimeError: if a head-position cell cannot be converted to int.
        """
        assert len(rows[0]) >= 5
        # Integer division: the result is used with range() below.
        numSent = (len(rows[0]) - 3) // 2
        assert numSent >= 1

        words = list()
        posTags = list()
        neLabels = list()
        headPositions = [list() for _ in range(numSent)]
        labels = [list() for _ in range(numSent)]

        for row in rows:
            words.append(row.get(self.WORD_POSITION))
            posTags.append(row.get(self.POS_TAG_POSITION))
            neLabels.append(row.get(self.NE_LABEL_POSITION))

            for j in range(numSent):
                labelPosition = self.LABEL_START_OFFSET + (j * 2)
                labels[j].append(row.get(labelPosition))
                headCell = row.get(labelPosition + 1)
                try:
                    headPositions[j].append(int(headCell))
                except (TypeError, ValueError) as err:
                    raise RuntimeError(
                        "ERROR: invalid head position {!r} in column {}".format(
                            headCell, labelPosition + 1)) from err

        sentences = list()
        for i in range(numSent):
            annotatedSent = AnnotatedSentence(words, posTags, neLabels, headPositions[i])
            sentences.append((annotatedSent, labels[i]))

        return sentences
A more Pythonic way of writing this would be: