In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
import numpy as np
import xml.etree.ElementTree as ET

In [0]:
f = open('drive/My Drive/AFC/embeddings/glove.6B.100d.txt', 'r', encoding='utf8')

In [0]:
lines = f.readlines()

In [0]:
word2index = {}
embeddings = np.zeros(shape=(len(lines) + 1, len(lines[0].split()[1:])), dtype=np.float32)
for idx, line in enumerate(lines):
    line = line.split()
    word2index[line[0]] = len(word2index) + 1
    embeddings[idx + 1] = np.array(line[1:], dtype=np.float32)

In [0]:
root = ET.parse('drive/My Drive/AFC/dataset/semcor.data.xml').getroot()

In [0]:
#loading training set
pos2index = {}
lemma2index = {'unk': 1}
f = open('drive/My Drive/AFC/dataset/semcor.gold.key.bnids.txt', 'r', encoding='utf8')
sentences = []
lemmas = []
pos = []
for sentence in root.findall('text/sentence'):
    s = []
    l = []
    p = []
    for word in sentence:
        w = word.text.lower()
        lemma = word.attrib['lemma'].lower()
        tag = word.attrib['pos'].lower()
        s.append(word2index[w]) if w in word2index else s.append(word2index['unk'])
        if lemma not in lemma2index:
          lemma2index[lemma] = len(lemma2index) + 1
        l.append(lemma2index[lemma])
        if tag not in pos2index:
            pos2index[tag] = len(pos2index) + 1
        p.append(pos2index[tag])

    sentences.append(s)
    lemmas.append(l)
    pos.append(p)
    

In [0]:
root = ET.parse('drive/My Drive/AFC/dataset/ALL.data.xml').getroot()

In [0]:
#loading test set
f = open('drive/My Drive/AFC/dataset/ALL.gold.key.bnids.txt', 'r', encoding='utf8')
test = {}
for sentence in root.findall('text/sentence'):
    s = []
    l = []
    p = []  
    dataset = sentence.attrib['id'].split('.')[0]
    if dataset not in test:
        test[dataset] = []
    for word in sentence:
        w = word.text.lower()
        lemma = word.attrib['lemma'].lower()
        tag = word.attrib['pos'].lower()
        s.append(word2index[w]) if w in word2index else s.append(word2index['unk'])
        l.append(lemma2index[lemma]) if lemma in lemma2index else l.append(lemma2index['unk'])
        p.append(pos2index[tag])
    test[dataset].append((s, l, p))

In [0]:
import tensorflow as tf

In [11]:
tf.__version__

'2.1.0'

In [0]:
import tqdm

In [0]:
def accuracy(y_true, y_pred):
    assert(len(y_true) == len(y_pred))
    c = 0
    for true, pred in zip(y_true, y_pred):
        if true == pred:
            c += 1
        
    return c / len(y_true)

In [0]:
batch_size = 16
hidden_size = 100

word_ids = tf.keras.Input([None], dtype=tf.int32)

pretrained_emb = tf.keras.layers.Embedding(embeddings.shape[0], embeddings.shape[1], weights=[embeddings], mask_zero=True, trainable=False)(word_ids)

bid = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(hidden_size, return_sequences=True))(pretrained_emb)

lemmas_scores = tf.keras.layers.Dense(len(lemma2index) + 1)(bid)
pos_scores = tf.keras.layers.Dense(len(pos2index) + 1)(bid)

In [0]:
model = tf.keras.Model(inputs=word_ids, outputs=[lemmas_scores, pos_scores])

In [0]:
def masked_loss(x):
    
    def loss(labels, logits):
        #print('logits ',logits.get_shape())
        #print('labels ', labels.get_shape())
        #output_shape = logits.get_shape()
        #sequence = tf.count_nonzero(x, axis=-1, dtype=tf.int32)
        not_zeros = tf.cast(tf.not_equal(x, 0), tf.int32)
        sequence = tf.reduce_sum(not_zeros, axis=-1)
        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
        mask = tf.sequence_mask(sequence)
        #masked_losses = tf.where(mask, losses, tf.zeros_like(losses))
        masked_losses = tf.multiply(losses, tf.cast(not_zeros, tf.float32))
        
        return masked_losses
    
    return loss
    
model.compile(optimizer=tf.keras.optimizers.Adam(0.001), loss=masked_loss(x=word_ids), 
              target_tensors=[tf.keras.Input([None], dtype=tf.int32), 
                              tf.keras.Input([None], dtype=tf.int32)])

In [17]:
steps = len(sentences) // batch_size
epochs = 5
for epoch in range(epochs):
  avg_loss = 0
  for step in tqdm.tqdm(range(steps), desc='Epoch ' + str(epoch + 1) + '/' + str(epochs)):
      
      l = model.train_on_batch(x=tf.keras.preprocessing.sequence.pad_sequences(sentences[step * batch_size: (step + 1) * batch_size], padding='post'), 
                               y=[tf.keras.preprocessing.sequence.pad_sequences(lemmas[step * batch_size: (step + 1) * batch_size], padding='post'), 
                                  tf.keras.preprocessing.sequence.pad_sequences(pos[step * batch_size: (step + 1) * batch_size], padding='post')])
          
      avg_loss += l[0]
      
      if (step > 0) and step % 500 == 0 or (step == steps - 1):
          print('Loss:', avg_loss / step)
              
  for dataset in test:
          y_true_lemma = []
          y_true_pos = []
          y_pred_lemma = []
          y_pred_pos = []
          for x, y1, y2 in test[dataset]:
              pred_lemma, pred_pos = model.predict(x)
              pred_lemma = np.argmax(pred_lemma, axis=-1).reshape(-1)
              pred_pos = np.argmax(pred_pos, axis=-1).reshape(-1)
              
              y_true_lemma += y1
              y_pred_lemma += pred_lemma.tolist()
              
              y_true_pos += y2
              y_pred_pos += pred_pos.tolist()
          
          acc_lemma, acc_pos = accuracy(y_true_lemma, y_pred_lemma), accuracy(y_true_pos, y_pred_pos)
              
          print(dataset + ' results ')
          print('lemma accuracy: ' + str(acc_lemma) + '\t' + 'pos accuracy: ' + str(acc_pos))

Epoch 1/5:  22%|██▏       | 501/2323 [07:42<29:07,  1.04it/s]

Loss: 2.7472714449167253


Epoch 1/5:  43%|████▎     | 1001/2323 [15:01<18:14,  1.21it/s]

Loss: 2.142255606532097


Epoch 1/5:  65%|██████▍   | 1501/2323 [22:52<11:56,  1.15it/s]

Loss: 1.815761889676253


Epoch 1/5:  86%|████████▌ | 2001/2323 [30:36<04:48,  1.12it/s]

Loss: 1.5698080129027367


Epoch 1/5: 100%|██████████| 2323/2323 [35:49<00:00,  1.08it/s]

Loss: 1.4469674629424791





senseval2 results 
lemma accuracy: 0.8156434269857787	pos accuracy: 0.7481789802289281
senseval3 results 
lemma accuracy: 0.8265656018769175	pos accuracy: 0.734885399747338
semeval2007 results 
lemma accuracy: 0.7981880662293034	pos accuracy: 0.7244611059044048
semeval2013 results 
lemma accuracy: 0.8020498152782743	pos accuracy: 0.6931235847932309


Epoch 2/5:   0%|          | 0/2323 [00:00<?, ?it/s]

semeval2015 results 
lemma accuracy: 0.8095238095238095	pos accuracy: 0.7142857142857143


Epoch 2/5:  22%|██▏       | 501/2323 [07:36<28:43,  1.06it/s]

Loss: 0.5592514248490333


Epoch 2/5:  43%|████▎     | 1001/2323 [14:50<18:45,  1.17it/s]

Loss: 0.5370316023826599


Epoch 2/5:  65%|██████▍   | 1501/2323 [22:43<12:03,  1.14it/s]

Loss: 0.5057698522259791


Epoch 2/5:  86%|████████▌ | 2001/2323 [30:32<04:51,  1.10it/s]

Loss: 0.4668887651599944


Epoch 2/5: 100%|██████████| 2323/2323 [35:44<00:00,  1.08it/s]

Loss: 0.4469694127998508





senseval2 results 
lemma accuracy: 0.8836281651057926	pos accuracy: 0.807145334720777
senseval3 results 
lemma accuracy: 0.8922577152138603	pos accuracy: 0.7922757624977441
semeval2007 results 
lemma accuracy: 0.8637925648234926	pos accuracy: 0.788191190253046
semeval2013 results 
lemma accuracy: 0.8647360266952687	pos accuracy: 0.7565248480514837


Epoch 3/5:   0%|          | 0/2323 [00:00<?, ?it/s]

semeval2015 results 
lemma accuracy: 0.8609831029185868	pos accuracy: 0.7922427035330261


Epoch 3/5:  22%|██▏       | 501/2323 [07:33<28:39,  1.06it/s]

Loss: 0.32033428648114204


Epoch 3/5:  43%|████▎     | 1001/2323 [14:45<17:52,  1.23it/s]

Loss: 0.315428119301796


Epoch 3/5:  65%|██████▍   | 1501/2323 [22:30<11:50,  1.16it/s]

Loss: 0.2992503169327974


Epoch 3/5:  86%|████████▌ | 2001/2323 [30:18<04:51,  1.10it/s]

Loss: 0.2745787087250501


Epoch 3/5: 100%|██████████| 2323/2323 [35:34<00:00,  1.09it/s]

Loss: 0.2625655102014901





senseval2 results 
lemma accuracy: 0.9046132500867152	pos accuracy: 0.8222337842525147
senseval3 results 
lemma accuracy: 0.907236960837394	pos accuracy: 0.8184443241292185
semeval2007 results 
lemma accuracy: 0.8856607310215557	pos accuracy: 0.8131833801936895
semeval2013 results 
lemma accuracy: 0.8870218090811584	pos accuracy: 0.787033726611846


Epoch 4/5:   0%|          | 0/2323 [00:00<?, ?it/s]

semeval2015 results 
lemma accuracy: 0.8740399385560675	pos accuracy: 0.8133640552995391


Epoch 4/5:  22%|██▏       | 501/2323 [07:41<28:51,  1.05it/s]

Loss: 0.22250076867640017


Epoch 4/5:  43%|████▎     | 1001/2323 [14:53<17:55,  1.23it/s]

Loss: 0.21868399899825453


Epoch 4/5:  65%|██████▍   | 1501/2323 [22:35<11:41,  1.17it/s]

Loss: 0.2051831921065847


Epoch 4/5:  86%|████████▌ | 2001/2323 [30:15<04:44,  1.13it/s]

Loss: 0.18535663773119448


Epoch 4/5: 100%|██████████| 2323/2323 [35:27<00:00,  1.09it/s]

Loss: 0.1758997002929372





senseval2 results 
lemma accuracy: 0.9172736732570239	pos accuracy: 0.8329864724245577
senseval3 results 
lemma accuracy: 0.9126511460025266	pos accuracy: 0.8296336401371593
semeval2007 results 
lemma accuracy: 0.8978444236176195	pos accuracy: 0.823805060918463
semeval2013 results 
lemma accuracy: 0.8959599570968895	pos accuracy: 0.800500536288881


Epoch 5/5:   0%|          | 0/2323 [00:00<?, ?it/s]

semeval2015 results 
lemma accuracy: 0.881336405529954	pos accuracy: 0.8221966205837173


Epoch 5/5:  22%|██▏       | 501/2323 [07:33<28:45,  1.06it/s]

Loss: 0.16792218124866484


Epoch 5/5:  43%|████▎     | 1001/2323 [14:45<17:51,  1.23it/s]

Loss: 0.16494994891062378


Epoch 5/5:  65%|██████▍   | 1501/2323 [22:29<11:45,  1.16it/s]

Loss: 0.1532962877874573


Epoch 5/5:  86%|████████▌ | 2001/2323 [30:12<04:44,  1.13it/s]

Loss: 0.13644281321018933


Epoch 5/5: 100%|██████████| 2323/2323 [35:23<00:00,  1.09it/s]

Loss: 0.12845378027046475





senseval2 results 
lemma accuracy: 0.9198751300728408	pos accuracy: 0.8459937565036421
senseval3 results 
lemma accuracy: 0.9166215484569572	pos accuracy: 0.8442519400830175
semeval2007 results 
lemma accuracy: 0.9019056544829741	pos accuracy: 0.8288034989065917
semeval2013 results 
lemma accuracy: 0.9021570730544631	pos accuracy: 0.8296984864736027
semeval2015 results 
lemma accuracy: 0.8859447004608295	pos accuracy: 0.8452380952380952


In [0]:
model.save('drive/My Drive/AFC/models/multitask.h5')