In [None]:
import requests

import pandas as pd

def load_data(url, timeout=30):
    """Download a tab-separated morphology table and return it as a DataFrame.

    Every non-blank line must hold at least three tab-separated fields:
    lemma, inflected form, and a ``;``-separated feature bundle. Columns
    beyond the third are ignored.

    Args:
        url: Address of the file to fetch.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        DataFrame with columns ``lemma``, ``inflected`` and ``features``,
        where ``features`` holds a list of tag strings per row.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        ValueError: If a non-blank line has fewer than three fields.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of parsing an error page as data.
    response.raise_for_status()

    data = []
    for line_no, raw_line in enumerate(response.text.strip().split("\n"), start=1):
        # Tolerate CRLF line endings so the last tag does not keep a "\r".
        line = raw_line.rstrip("\r")
        if not line.strip():
            # Blank separator lines carry no entry.
            continue
        parts = line.split("\t")
        if len(parts) < 3:
            raise ValueError(
                f"line {line_no}: expected at least 3 tab-separated fields, got {len(parts)}: {line!r}"
            )
        lemma, inflect, tags = parts[:3]
        data.append([lemma, inflect, tags.split(";")])

    return pd.DataFrame(data, columns=["lemma", "inflected", "features"])

# Fetch the table from GitHub (needs network access) and keep it as `df`.
df = load_data("https://raw.githubusercontent.com/unimorph/eng/master/eng")

# Plain-text dump of the frame; pandas elides the middle rows of a long frame.
print(df)

                  lemma        inflected          features
0             microtome       microtomes           [N, PL]
1             microtome       microtomes   [V, PRS, 3, SG]
2             microtome      microtoming  [V, V.PTCP, PRS]
3             microtome       microtomed          [V, PST]
4             microtome       microtomed  [V, V.PTCP, PST]
...                 ...              ...               ...
652472       myriadaire       myriadaire           [N, SG]
652473     dibridgehead     dibridgehead           [N, SG]
652474       Chicagoese       Chicagoese           [N, SG]
652475           Druzer           Druzer           [N, SG]
652476  electrosensible  electrosensible           [N, SG]

[652477 rows x 3 columns]


In [55]:
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class SurfMorphEmbeeding(Dataset):
    """Dataset pairing each surface form with fixed-length ids and a tag vector.

    Args:
        form2tags: Mapping from a surface form to an iterable of its tags.
        char2id: Mapping from character to integer id. Id 0 is used both for
            padding and for characters missing from the mapping.
        tag2id: Mapping from tag to its position in the target vector.
        max_len: Length every form is truncated or right-padded to.
    """

    def __init__(self, form2tags, char2id, tag2id, max_len=32):
        self.items = list(form2tags.items())
        self.char2id, self.tag2id = char2id, tag2id
        self.max_len = max_len
        self.num_tags = len(tag2id)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Return ``{"form", "x", "y"}`` for the idx-th (form, tags) pair."""
        form, tags = self.items[idx]
        return dict(
            form=form,
            x=self._encode_form(form),   # character ids
            y=self._encode_tags(tags),   # multi-hot atomic tags
        )

    def _encode_form(self, form):
        """Turn a form into a long tensor of exactly ``max_len`` character ids."""
        lookup = self.char2id.get
        ids = [lookup(ch, 0) for ch in form[: self.max_len]]
        # A non-positive count yields an empty list, so full-length forms pass through.
        ids.extend([0] * (self.max_len - len(ids)))
        return torch.tensor(ids, dtype=torch.long)

    def _encode_tags(self, tags):
        """Turn an iterable of tags into a float multi-hot vector of size ``num_tags``."""
        target = torch.zeros(self.num_tags, dtype=torch.float)
        for tag in tags:
            position = self.tag2id[tag]
            target[position] = 1.0
        return target
    

In [56]:
import torch.nn.functional as F

class CharCNNTagger(nn.Module):
    """Character-level CNN that emits one logit per tag for each input word.

    Character ids are embedded, fed through parallel 1-D convolutions of
    different widths, max-pooled over the position axis, concatenated,
    passed through dropout and projected to ``num_tags`` logits.

    Args:
        vocab_size: Number of character ids excluding id 0; the embedding
            table gets ``vocab_size + 1`` rows, with row 0 as padding.
        num_tags: Size of the output logit vector.
        emb_dim: Character embedding width.
        num_kernels: Output channels of each convolution.
        kernel_sizes: Widths of the parallel convolutions.
        max_len: Accepted in the signature but not used by any layer.
    """

    def __init__(self, vocab_size, num_tags, emb_dim=64, num_kernels=128, kernel_sizes=(3,4,5), max_len=32):
        super().__init__()
        self.char_emb = nn.Embedding(vocab_size + 1, emb_dim, padding_idx=0)
        branches = [nn.Conv1d(emb_dim, num_kernels, width) for width in kernel_sizes]
        self.convs = nn.ModuleList(branches)
        self.dropout = nn.Dropout(p=0.2)
        self.fc = nn.Linear(num_kernels * len(kernel_sizes), num_tags)

    @staticmethod
    def _pool_over_positions(feature_map):
        """ReLU, then max over the last axis: [B, K, L'] -> [B, K]."""
        activated = F.relu(feature_map)
        window = activated.size(-1)
        return F.max_pool1d(activated, kernel_size=window).squeeze(-1)

    def forward(self, x):
        """Map character ids ``x`` of shape [B, L] to logits of shape [B, num_tags]."""
        embedded = self.char_emb(x).transpose(1, 2)        # [B, L, E] -> [B, E, L]
        pooled = [self._pool_over_positions(conv(embedded)) for conv in self.convs]
        features = torch.cat(pooled, dim=-1)                # [B, K * len(kernel_sizes)]
        return self.fc(self.dropout(features))              # [B, num_tags]

In [57]:
import pandas as pd
from collections import Counter
from itertools import chain


# Tag inventory: every atomic feature seen in the data, sorted, numbered from 0.
all_tags = sorted({tag for bundle in df["features"] for tag in bundle})

tag2id = dict(zip(all_tags, range(len(all_tags))))
id2tag = dict(enumerate(all_tags))

num_tags = len(all_tags)


def _union_of_tags(bundles):
    """Return the sorted union of tags across all feature lists of one form."""
    merged = set()
    for bundle in bundles:
        merged.update(bundle)
    return sorted(merged)


# One entry per distinct surface form, carrying the tags of every row that shares it.
form2tags = df.groupby("inflected")["features"].apply(_union_of_tags).to_dict()

# Character vocabulary; ids start at 1 so that 0 stays free.
all_forms = list(form2tags)
chars = sorted(set(chain.from_iterable(all_forms)))
char2id = {ch: idx for idx, ch in enumerate(chars, start=1)}

In [58]:
dataset = SurfMorphEmbeeding(form2tags, char2id, tag2id, max_len=32)

# Seeded 90/10 random split. With a fixed generator the same examples land in
# validation every time this cell runs, so re-running it after training cannot
# move training examples into the validation set.
SPLIT_SEED = 42
n = len(dataset)
n_train = int(0.9*n)
train_set, val_set = torch.utils.data.random_split(
    dataset,
    [n_train, n-n_train],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training batches are shuffled; validation order is kept.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_set,   batch_size=64)

In [59]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda") if use_gpu else torch.device("cpu")

# Hyper-parameters of the tagger, gathered in one place.
model_config = dict(
    vocab_size=len(char2id),
    num_tags=len(tag2id),
    emb_dim=64,
    num_kernels=128,
    kernel_sizes=(3,4,5),
)
model = CharCNNTagger(**model_config).to(device)

# Multi-label objective on raw logits, optimised with AdamW.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=3e-4)

def run_epoch(loader, train=True):
    """Run one pass over ``loader`` and return the mean per-example loss.

    Relies on the module-level ``model``, ``criterion``, ``optimizer`` and
    ``device``.

    Args:
        loader: Iterable of batches; each batch is a mapping holding tensors
            under the keys ``"x"`` and ``"y"``.
        train: When true the model is put in training mode and the optimizer
            steps after every batch. When false the model is put in eval mode
            and no gradients are tracked.

    Returns:
        Batch losses weighted by batch size, divided by the number of
        examples seen (``0.0`` for an empty loader).
    """
    model.train(train)
    total_loss, total_items = 0.0, 0
    # Track gradients only when they will be used; validation then skips
    # building autograd graphs that would never be back-propagated.
    with torch.set_grad_enabled(train):
        for batch in loader:
            x = batch["x"].to(device)
            y = batch["y"].to(device)
            logits = model(x)
            loss = criterion(logits, y)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # criterion returns a batch mean; re-weight by batch size so a
            # smaller final batch does not skew the epoch average.
            total_loss += loss.item() * x.size(0)
            total_items += x.size(0)
    return total_loss / max(total_items, 1)

# Eight epochs: one optimisation pass over train_loader, then one evaluation
# pass over val_loader, reporting both mean losses (epochs shown 1-based).
for epoch in range(8):
    tr_loss = run_epoch(train_loader, train=True)
    va_loss = run_epoch(val_loader, train=False)
    print(f"epoch {epoch+1}  train_loss={tr_loss:.4f}  val_loss={va_loss:.4f}")


epoch 1  train_loss=0.1447  val_loss=0.1179
epoch 2  train_loss=0.1189  val_loss=0.1120
epoch 3  train_loss=0.1144  val_loss=0.1092
epoch 4  train_loss=0.1119  val_loss=0.1073
epoch 5  train_loss=0.1102  val_loss=0.1065
epoch 6  train_loss=0.1089  val_loss=0.1061
epoch 7  train_loss=0.1079  val_loss=0.1053
epoch 8  train_loss=0.1069  val_loss=0.1045


In [61]:
import torch

def encode_form(form, char2id, max_len=32):
    """Encode one word as a ``[1, max_len]`` long tensor of character ids.

    The word is cut to ``max_len`` characters and right-padded with 0;
    characters missing from ``char2id`` are also encoded as 0.
    """
    clipped = form[:max_len]
    ids = [char2id.get(ch, 0) for ch in clipped]
    # A non-positive count yields an empty list, so no padding is added then.
    ids.extend([0] * (max_len - len(ids)))
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0)

def predict_tags(word, threshold=0.5):
    """Return ``(tag, probability)`` pairs for ``word``, most probable first.

    Uses the module-level ``model``, ``char2id``, ``id2tag`` and ``device``.
    A tag is kept when its sigmoid probability is at least ``threshold``.
    """
    model.eval()
    with torch.no_grad():
        batch = encode_form(word, char2id).to(device)
        scores = torch.sigmoid(model(batch)).squeeze(0)

        # Positions of the tags whose probability clears the threshold.
        kept = torch.nonzero(scores >= threshold, as_tuple=True)[0].tolist()
        pairs = [(id2tag[k], scores[k].item()) for k in kept]
        pairs.sort(key=lambda pair: -pair[1])
        return pairs


# Examples: each word is tagged and printed; the first two use a lower threshold.
demo_queries = [
    ("eats", {"threshold": 0.3}),
    ("microtomed", {"threshold": 0.3}),
    ("rectangularizing", {}),
    ("wugs", {}),
    ("wugging", {}),
    ("wuggle", {}),
    ("pug", {}),
]
for word, options in demo_queries:
    print(predict_tags(word, **options))


[('N', 0.9069347977638245), ('PL', 0.8268887996673584), ('V', 0.7776874303817749), ('3', 0.7733673453330994), ('PRS', 0.7583993077278137), ('SG', 0.3986973464488983)]
[('V.PTCP', 0.4763296842575073), ('PST', 0.4608100354671478)]
[('V.PTCP', 0.9902050495147705), ('PRS', 0.9883169531822205), ('V', 0.9850516319274902)]
[('N', 0.9231547117233276), ('PL', 0.87114018201828), ('PRS', 0.676342248916626), ('3', 0.6716897487640381), ('V', 0.657807469367981), ('SG', 0.5164280533790588)]
[('V.PTCP', 0.9508506655693054), ('V', 0.8895914554595947), ('PRS', 0.864666223526001), ('N', 0.5620590448379517)]
[('SG', 0.8041651844978333), ('IMP+SBJV', 0.8020147681236267), ('NFIN', 0.8017383217811584), ('V', 0.7709524631500244), ('N', 0.715103805065155)]
[('SG', 0.8948113918304443), ('N', 0.8401877284049988), ('V', 0.6223856806755066)]
