In [None]:
import requests

import pandas as pd

def load_data(url, timeout=30.0):
    """Download a tab-separated morphology table and parse it into a DataFrame.

    Every non-blank line must hold exactly three tab-separated fields:
    lemma, inflected form, and a ``;``-separated list of feature tags.

    Args:
        url: Address of the raw TSV file.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        DataFrame with columns ``lemma``, ``inflected`` and ``features``;
        each ``features`` cell is a list of tag strings.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        ValueError: If a non-blank line does not have exactly three fields.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of parsing an error page as data.
    response.raise_for_status()

    data = []
    for line_no, raw_line in enumerate(response.text.split('\n'), start=1):
        line = raw_line.rstrip('\r')  # tolerate Windows line endings
        # Blank lines (including a trailing newline) carry no entry.
        if not line.strip():
            continue
        parts = line.split('\t')
        if len(parts) != 3:
            raise ValueError(
                f"line {line_no}: expected 3 tab-separated fields, got {len(parts)}: {line!r}"
            )
        lemma, inflect, tags = parts
        data.append([lemma, inflect, tags.split(";")])

    return pd.DataFrame(data, columns=["lemma", "inflected", "features"])

# Fetch the unimorph/eng table from GitHub (requires network access) and keep
# it in the notebook-level `df` that the following cells read.
df = load_data("https://raw.githubusercontent.com/unimorph/eng/master/eng")

# Plain-text preview of the parsed table.
print(df)

                  lemma        inflected          features
0             microtome       microtomes           [N, PL]
1             microtome       microtomes   [V, PRS, 3, SG]
2             microtome      microtoming  [V, V.PTCP, PRS]
3             microtome       microtomed          [V, PST]
4             microtome       microtomed  [V, V.PTCP, PST]
...                 ...              ...               ...
652472       myriadaire       myriadaire           [N, SG]
652473     dibridgehead     dibridgehead           [N, SG]
652474       Chicagoese       Chicagoese           [N, SG]
652475           Druzer           Druzer           [N, SG]
652476  electrosensible  electrosensible           [N, SG]

[652477 rows x 3 columns]


In [34]:
import requests

url = "https://raw.githubusercontent.com/unimorph/eng/master/eng"
# A timeout keeps the cell from hanging forever on a stalled connection.
response = requests.get(url, timeout=30)
# Abort on HTTP errors so an error page is never saved as the dataset.
response.raise_for_status()

# NOTE(review): a later cell reads "../data/eng_dataset.tsv", while this cell
# writes into the current working directory — confirm which location is intended.
with open("eng_dataset.tsv", "w", encoding="utf-8") as f:
    f.write(response.text)

In [None]:
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class SurfMorphEmbedding(Dataset):
    """Map-style dataset pairing each surface form with a multi-hot tag vector.

    ``form2tags`` maps a word form to an iterable of tag names. Each item is a
    dict with the raw ``form``, its fixed-length character id tensor ``x`` and
    its float target vector ``y`` of length ``len(tag2id)``.
    """

    def __init__(self, form2tags, char2id, tag2id, max_len=32):
        self.char2id = char2id
        self.tag2id = tag2id
        self.max_len = max_len
        self.num_tags = len(tag2id)
        # Materialise the (form, tags) pairs so they can be addressed by index.
        self.items = list(form2tags.items())

    def _encode_form(self, form):
        """Return a ``(max_len,)`` long tensor of character ids.

        Forms longer than ``max_len`` are truncated, shorter ones are padded
        with 0; characters missing from ``char2id`` are also encoded as 0.
        """
        clipped = form[:self.max_len]
        codes = [self.char2id.get(ch, 0) for ch in clipped]
        padding = [0] * (self.max_len - len(codes))
        return torch.tensor(codes + padding, dtype=torch.long)

    def _encode_tags(self, tags):
        """Return a float vector with 1.0 at the index of every tag in ``tags``.

        Raises ``KeyError`` for a tag that is not in ``tag2id``.
        """
        positions = [self.tag2id[tag] for tag in tags]
        target = torch.zeros(self.num_tags, dtype=torch.float)
        if positions:
            target[torch.tensor(positions, dtype=torch.long)] = 1.0
        return target

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        form, tags = self.items[idx]
        encoded = self._encode_form(form)
        target = self._encode_tags(tags)
        return {"form": form, "x": encoded, "y": target}
    

In [5]:
import torch.nn.functional as F

class CharCNNTagger(nn.Module):
    """Character-level CNN that scores every tag for a batch of encoded forms.

    Character ids are embedded, passed through one 1-D convolution per entry
    of ``kernel_sizes``, max-pooled over the sequence axis, concatenated,
    regularised with dropout and projected to ``num_tags`` logits.

    ``max_len`` is accepted but not used anywhere in this class.
    """

    def __init__(self, vocab_size, num_tags, emb_dim=64, num_kernels=128, kernel_sizes=(3,4,5), max_len=32):
        super().__init__()
        # One extra embedding row beyond vocab_size; index 0 is the padding row.
        self.char_emb = nn.Embedding(vocab_size + 1, emb_dim, padding_idx=0)
        conv_layers = []
        for width in kernel_sizes:
            conv_layers.append(nn.Conv1d(emb_dim, num_kernels, width))
        self.convs = nn.ModuleList(conv_layers)
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(len(kernel_sizes) * num_kernels, num_tags)

    def forward(self, x):
        """Return raw logits with one column per tag for id tensor ``x``."""
        # Conv1d wants channels before the sequence axis, so swap the last two dims.
        channels_first = self.char_emb(x).transpose(1, 2)
        pooled = []
        for conv in self.convs:
            activated = F.relu(conv(channels_first))
            # Pool over the whole remaining sequence axis, then drop it.
            span = activated.size(-1)
            pooled.append(F.max_pool1d(activated, kernel_size=span).squeeze(-1))
        features = self.dropout(torch.cat(pooled, dim=-1))
        return self.fc(features)

In [6]:
import pandas as pd
from collections import Counter
from itertools import chain

# Inventory of atomic feature tags, each mapped to a stable integer id
# (ids follow the sorted tag order).
all_tags = sorted({tag for tag_list in df["features"] for tag in tag_list})

tag2id = dict(zip(all_tags, range(len(all_tags))))
id2tag = dict(enumerate(all_tags))

num_tags = len(tag2id)

# Merge every analysis of the same surface form into one sorted, de-duplicated
# tag list, so an ambiguous form carries the union of its readings.
form2tags = {
    form: sorted(set(chain.from_iterable(rows)))
    for form, rows in df.groupby("inflected")["features"]
}

# Character vocabulary over all surface forms; ids start at 1 and 0 is left unassigned.
all_forms = list(form2tags)
chars = sorted({ch for form in all_forms for ch in form})
char2id = {ch: idx for idx, ch in enumerate(chars, start=1)}

In [7]:
dataset = SurfMorphEmbedding(form2tags, char2id, tag2id, max_len=32)

# Simple random split: 90% of the forms for training, the rest for validation.
n = len(dataset)
n_train = int(0.9*n)
# Seed the split so the same forms land in train/val on every run; otherwise
# validation losses are not comparable across kernel restarts.
split_generator = torch.Generator().manual_seed(42)
train_set, val_set = torch.utils.data.random_split(
    dataset, [n_train, n-n_train], generator=split_generator
)

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_set,   batch_size=64)

In [12]:
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tagger sized from the character vocabulary and tag inventory currently bound
# to `char2id` / `tag2id`.
model = CharCNNTagger(
    vocab_size=len(char2id),
    num_tags=len(tag2id),
    emb_dim=64,
    num_kernels=128,
    kernel_sizes=(3,4,5),
).to(device)

# Multi-label objective: BCEWithLogitsLoss applies a sigmoid per tag internally,
# so the model is expected to emit raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

def run_epoch(loader, train=True):
    """Run one pass over ``loader`` and return the mean per-example loss.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer`` and
    ``device``.

    Args:
        loader: Iterable of batches; each batch is a mapping with tensors
            under the keys ``"x"`` (inputs) and ``"y"`` (targets).
        train: When True the model is put in training mode and the optimizer
            is stepped after every batch; when False the model is put in
            evaluation mode and no parameters are updated.

    Returns:
        Sum of batch losses weighted by batch size, divided by the number of
        examples seen (0.0 for an empty loader).
    """
    model.train(train)
    total_loss, total_items = 0.0, 0
    # Track gradients only while training: an evaluation pass then skips
    # building the autograd graph, which saves memory and time.
    with torch.set_grad_enabled(train):
        for batch in loader:
            x = batch["x"].to(device)
            y = batch["y"].to(device)
            logits = model(x)
            loss = criterion(logits, y)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # The criterion's result is scaled by the batch size here, so the
            # final division is per example rather than per batch.
            total_loss += loss.item() * x.size(0)
            total_items += x.size(0)
    return total_loss / max(total_items, 1)

# Train for 8 epochs; after each training pass, measure the loss on the
# validation loader without updating the model, and log both values.
for epoch in range(8):
    tr_loss = run_epoch(train_loader, train=True)
    va_loss = run_epoch(val_loader, train=False)
    print(f"epoch {epoch+1}  train_loss={tr_loss:.4f}  val_loss={va_loss:.4f}")


epoch 1  train_loss=0.1436  val_loss=0.1168
epoch 2  train_loss=0.1189  val_loss=0.1122
epoch 2  train_loss=0.1189  val_loss=0.1122
epoch 3  train_loss=0.1146  val_loss=0.1094
epoch 3  train_loss=0.1146  val_loss=0.1094
epoch 4  train_loss=0.1118  val_loss=0.1082
epoch 4  train_loss=0.1118  val_loss=0.1082
epoch 5  train_loss=0.1100  val_loss=0.1066
epoch 5  train_loss=0.1100  val_loss=0.1066
epoch 6  train_loss=0.1088  val_loss=0.1053
epoch 6  train_loss=0.1088  val_loss=0.1053
epoch 7  train_loss=0.1078  val_loss=0.1050
epoch 7  train_loss=0.1078  val_loss=0.1050
epoch 8  train_loss=0.1069  val_loss=0.1044
epoch 8  train_loss=0.1069  val_loss=0.1044


In [9]:
import torch

def encode_form(form, char2id, max_len=32):
    """Encode ``form`` as a ``(1, max_len)`` long tensor of character ids.

    The form is truncated to ``max_len`` characters and right-padded with 0;
    characters absent from ``char2id`` are encoded as 0 as well.
    """
    clipped = form[:max_len]
    codes = [char2id.get(ch, 0) for ch in clipped]
    codes.extend([0] * (max_len - len(codes)))
    # Wrap in an outer list so the result carries a leading batch axis of 1.
    return torch.tensor([codes], dtype=torch.long)

def predict_tags(word, threshold=0.5):
    """Return ``(tag, probability)`` pairs for ``word``, most probable first.

    Uses the notebook-level ``model``, ``char2id``, ``id2tag`` and ``device``.
    Only tags whose sigmoid probability is at least ``threshold`` are kept.
    """
    model.eval()
    with torch.no_grad():
        encoded = encode_form(word, char2id).to(device)
        scores = torch.sigmoid(model(encoded)).squeeze(0)

    selected = (scores >= threshold).nonzero(as_tuple=True)[0].tolist()
    pairs = [(id2tag[tag_id], scores[tag_id].item()) for tag_id in selected]
    pairs.sort(key=lambda pair: -pair[1])
    return pairs

# Examples
# Spot-check attested and nonce words; each entry is (word, threshold).
for _word, _threshold in (
    ("eats", 0.5),
    ("microtomed", 0.3),
    ("rectangularizing", 0.5),
    ("wugs", 0.5),
    ("wugging", 0.5),
    ("wuggle", 0.5),
    ("pug", 0.5),
):
    print(predict_tags(_word, threshold=_threshold))


[('N', 0.8908878564834595), ('V', 0.725412905216217), ('3', 0.7114032506942749), ('PL', 0.6816283464431763), ('PRS', 0.6800766587257385)]
[('PST', 0.5687178373336792), ('V.PTCP', 0.4926902949810028), ('V', 0.3081898093223572)]
[('V', 0.9932036399841309), ('PRS', 0.9908037185668945), ('V.PTCP', 0.9884963035583496)]
[('N', 0.935979962348938), ('V', 0.8439024686813354), ('3', 0.7256053686141968), ('PL', 0.718870997428894), ('PRS', 0.7028244733810425), ('SG', 0.6043742299079895)]
[('V.PTCP', 0.9274927377700806), ('PRS', 0.9248952269554138), ('V', 0.895807147026062)]
[('IMP+SBJV', 0.8259298801422119), ('NFIN', 0.82508385181427), ('V', 0.784173309803009), ('SG', 0.7334151864051819), ('N', 0.6739606857299805)]
[('SG', 0.8506795167922974), ('N', 0.8333054184913635)]


In [47]:
import pandas as pd
from itertools import chain
from collections import defaultdict
import torch

# hierarchy
def build_hierarchy_from_tsv(path: str):
    """Read a lemma/form/features TSV and derive tag-to-POS constraints.

    A non-POS tag is tied to a part of speech when, across the whole file,
    it co-occurs with exactly one of ``N``, ``V`` or ``ADJ``.

    Args:
        path: Location of a header-less, tab-separated file whose third
            column holds ``;``-separated feature tags.

    Returns:
        A 5-tuple ``(df, atomic_tags, POS_EXCLUSIVE, POS_OF_TAG, constraints)``:
        the loaded DataFrame (columns ``lemma``, ``form``, ``features``), the
        sorted list of distinct tags, the POS tags present in the data (in
        the order N, V, ADJ), a dict mapping each POS-specific tag to its
        POS, and a dict mapping each such tag to a one-element list holding
        that POS.
    """
    import csv  # local import: only the QUOTE_NONE constant is needed

    df = pd.read_csv(
        path,
        sep="\t",
        header=None,
        names=["lemma", "form", "features"],
        # Keep every cell as literal text: pandas would otherwise read real
        # words such as "null", "nan" or "NA" as missing values.
        dtype=str,
        keep_default_na=False,
        na_filter=False,
        # The file is plain TSV; a double quote inside a word is data, not a
        # field delimiter.
        quoting=csv.QUOTE_NONE,
    )
    tag_lists = df["features"].astype(str).str.split(";")
    atomic_tags = sorted(set(chain.from_iterable(tag_lists)))

    POS = {"N", "V", "ADJ"}

    # For every tag, collect the set of POS tags it has been seen with.
    tag_pos = defaultdict(set)
    for tags in tag_lists:
        tset = set(tags)
        pos_in_row = POS.intersection(tset)
        for t in tset:
            tag_pos[t].update(pos_in_row)

    POS_EXCLUSIVE = [p for p in ["N", "V", "ADJ"] if p in atomic_tags]

    # Tags seen with several POS (or with none) stay unconstrained.
    POS_OF_TAG = {
        t: next(iter(tag_pos[t]))
        for t in atomic_tags
        if t not in POS and len(tag_pos[t]) == 1
    }

    constraints = {t: [p] for t, p in POS_OF_TAG.items()}

    return df, atomic_tags, POS_EXCLUSIVE, POS_OF_TAG, constraints


# constraints
def enforce_constraints(pred_ids, id2tag, tag2id, constraints):
    """Close a set of predicted tag ids under the ``constraints`` mapping.

    ``constraints`` maps a tag name to the tag names it requires. Required
    tags are added transitively until nothing new appears; the result is a
    sorted list of ids.
    """
    closed = set(pred_ids)
    # Worklist of ids whose requirements have not been expanded yet.
    pending = list(closed)
    while pending:
        tag = id2tag[pending.pop()]
        for required in constraints.get(tag, []):
            required_id = tag2id[required]
            if required_id not in closed:
                closed.add(required_id)
                pending.append(required_id)
    return sorted(closed)

def enforce_pos_exclusivity(pred_ids, probs, id2tag, POS_IDS):
    """Keep at most one part-of-speech id among the predictions.

    When several ids from ``POS_IDS`` are present, the one with the highest
    entry in ``probs`` survives and the others are dropped. ``id2tag`` is
    accepted but not consulted.

    Returns:
        ``(sorted_ids, chosen_pos_id)`` where ``chosen_pos_id`` is ``None``
        when no POS id was predicted.
    """
    kept = set(pred_ids)
    candidates = [tag_id for tag_id in kept if tag_id in POS_IDS]

    if not candidates:
        return sorted(kept), None

    if len(candidates) == 1:
        winner = candidates[0]
    else:
        winner = max(candidates, key=lambda tag_id: probs[tag_id].item())

    kept.difference_update([c for c in candidates if c != winner])
    return sorted(kept), winner

def filter_by_pos_hierarchy(pred_ids, chosen_pos_id, id2tag, POS_OF_TAG):
    """Drop predicted tags that belong to a different part of speech.

    A tag survives when ``POS_OF_TAG`` has no entry for it or maps it to the
    chosen POS; the chosen POS id itself is always included. With no chosen
    POS the ids are returned sorted and otherwise untouched.
    """
    if chosen_pos_id is None:
        return sorted(pred_ids)

    chosen_pos = id2tag[chosen_pos_id]

    def compatible(tag_id):
        owner = POS_OF_TAG.get(id2tag[tag_id])
        return owner is None or owner == chosen_pos

    survivors = {tag_id for tag_id in pred_ids if compatible(tag_id)}
    survivors.add(chosen_pos_id)
    return sorted(survivors)


# bundle tags

def make_bundle_from_tags(tags):
    """Join tags into one ``;``-separated string in a deterministic order.

    The POS tags ``N``, ``V`` and ``ADJ`` come first (alphabetically among
    themselves), followed by all remaining tags in alphabetical order.
    """
    pos_tags = ("N", "V", "ADJ")
    # False sorts before True, so POS tags lead; ties break on the tag text.
    ordered = sorted(tags, key=lambda tag: (tag not in pos_tags, tag))
    return ";".join(ordered)

def encode_form(form: str, char2id: dict, max_len: int = 32):
    """Turn ``form`` into a ``(1, max_len)`` long tensor of character ids.

    Characters beyond ``max_len`` are discarded, missing positions are filled
    with 0, and characters not found in ``char2id`` are encoded as 0 too.
    """
    codes = list(map(lambda ch: char2id.get(ch, 0), form[:max_len]))
    shortfall = max_len - len(codes)
    if shortfall > 0:
        codes = codes + [0] * shortfall
    # The extra list level gives the tensor a leading batch axis of size 1.
    return torch.tensor([codes], dtype=torch.long)

def predict_atomic_tags(
    word: str,
    model,
    device,
    char2id: dict,
    tag2id: dict,
    id2tag: dict,
    constraints: dict,
    POS_EXCLUSIVE: list,
    POS_OF_TAG: dict,
    threshold: float = 0.5,
    max_len: int = 32,
):
    """Predict tags for ``word`` and post-process them for consistency.

    Steps, in order: threshold the sigmoid probabilities, add tags required
    by ``constraints``, keep only the most probable POS from
    ``POS_EXCLUSIVE``, then drop tags that ``POS_OF_TAG`` assigns to another
    POS.

    Returns:
        List of ``(tag, probability)`` pairs sorted by descending probability.
    """
    model.eval()
    with torch.no_grad():
        encoded = encode_form(word, char2id, max_len=max_len).to(device)
        probs = torch.sigmoid(model(encoded)).squeeze(0)

    above_threshold = (probs >= threshold).nonzero(as_tuple=True)[0].tolist()
    closed_ids = enforce_constraints(above_threshold, id2tag, tag2id, constraints)

    # Only POS tags that exist in the id mapping can take part in the contest.
    pos_ids = [tag2id[pos] for pos in POS_EXCLUSIVE if pos in tag2id]
    single_pos_ids, chosen_pos_id = enforce_pos_exclusivity(closed_ids, probs, id2tag, pos_ids)

    final_ids = filter_by_pos_hierarchy(single_pos_ids, chosen_pos_id, id2tag, POS_OF_TAG)

    scored = [(id2tag[tag_id], probs[tag_id].item()) for tag_id in final_ids]
    return sorted(scored, key=lambda pair: -pair[1])  # list[(tag, prob)]

def predict_bundle(
    word: str,
    model,
    device,
    char2id: dict,
    tag2id: dict,
    id2tag: dict,
    constraints: dict,
    POS_EXCLUSIVE: list,
    POS_OF_TAG: dict,
    threshold: float = 0.5,
    max_len: int = 32,
):
    """Predict the tags of ``word`` and also return them as one bundle string.

    Returns:
        ``(bundle, tags_with_probs)``: the ``;``-joined bundle built by
        ``make_bundle_from_tags`` and the ``(tag, probability)`` list produced
        by ``predict_atomic_tags``.
    """
    scored = predict_atomic_tags(
        word=word,
        model=model,
        device=device,
        char2id=char2id,
        tag2id=tag2id,
        id2tag=id2tag,
        constraints=constraints,
        POS_EXCLUSIVE=POS_EXCLUSIVE,
        POS_OF_TAG=POS_OF_TAG,
        threshold=threshold,
        max_len=max_len,
    )
    distinct_tags = {tag for tag, _prob in scored}
    bundle = make_bundle_from_tags(distinct_tags)
    return bundle, scored


In [51]:
# Derive the POS hierarchy from the saved TSV. This rebinds the notebook-level
# `df` to the frame returned here, whose columns are lemma/form/features.
df, atomic_tags, POS_EXCLUSIVE, POS_OF_TAG, constraints = build_hierarchy_from_tsv("../data/eng_dataset.tsv")
# NOTE(review): tag2id/id2tag are re-created here, after the model was trained.
# Decoding is only correct if they match the mapping used for training — confirm.
tag2id = {t: i for i, t in enumerate(atomic_tags)}
id2tag = {i: t for t, i in tag2id.items()}


In [52]:
# Decode one word with a lowered threshold of 0.4; `bundle` is the joined tag
# string and `tags` the (tag, probability) pairs behind it.
bundle, tags = predict_bundle(
    "wuggle", model, device,
    char2id, tag2id, id2tag,
    constraints, POS_EXCLUSIVE, POS_OF_TAG,
    threshold=0.4
)
print(bundle)
print(tags)

V;IMP+SBJV;NFIN;SG
[('V', 0.8657882809638977), ('SG', 0.8472962975502014), ('IMP+SBJV', 0.845360279083252), ('NFIN', 0.8450121283531189)]
