In [1]:
from helpers import load_data

In [42]:
data = load_data('wsmp_train.json')
eval_data = load_data('wsmp_dev.json')

In [3]:
data[0]

('ghaṭa stha yogam yoga īśa tattva jñānasya kāraṇam',
 [('ghaṭa', 'iic.'),
  ('stha', 'iic.'),
  ('yoga', 'm. sg. acc.'),
  ('yoga', 'iic.'),
  ('īśa', 'm. sg. voc.'),
  ('tattva', 'iic.'),
  ('jñāna', 'n. sg. g.'),
  ('kāraṇa', 'n. sg. acc.')])

In [4]:
import numpy as np
from collections import defaultdict

rules = defaultdict(int)

"""
Find rules = operations how to transform token -> stem
"""

for sentence, labels in data:
    tokens = sentence.split()
    stems, _ = zip(*labels)
    
    # Skip malformed data
    if len(tokens) != len(stems):
        continue
    
    # Find rule for each token, stem pair
    for token, stem in zip(tokens, stems):
        # Find possible starting indices for overlapping
        # sequences for chars
        indices = []
        for i, char in enumerate(token):
            if char == stem[0]:
                indices.append(i)
        
        # If no overlap, no rule
        if len(indices) == 0:
            continue
        
        # Find length of overlapping char segments
        match_lens = []
        for idx in indices:
            current_length = 0
            for k in range(0, min(len(stem), len(token) - idx)):
                if stem[k] == token[idx + k]:
                    current_length += 1
                else:
                    break
            match_lens.append(current_length)
        
        # Take longest overlapping char segment
        # as 'root' (may be different from linguistic root)
        best_idx = np.argmin(match_lens)
        best_length = match_lens[best_idx]
        best_idx = indices[best_idx]
        
        # If no overlap, no rule
        if best_length == 0:
            continue
        
        # Prefix = part before 'root'
        prefix = token[:best_idx]
        # Suffix = part after 'root'
        suffix = token[best_idx + best_length:]
        # Replaced suffix
        stem_suffix = stem[best_length:]
        
        # Save rule
        rules[(prefix, suffix, stem_suffix)] += 1

In [6]:
token_stem_pairs = set()

for sentence, labels in data:
    tokens = sentence.split()
    stems, _ = zip(*labels)
    
    if len(tokens) != len(stems):
        continue
        
    for token, stem in zip(tokens, stems):
        token_stem_pairs.add((token, stem))

token_stem_pairs = list(token_stem_pairs)

In [43]:
eval_token_stem_pairs = set()

for sentence, labels in eval_data:
    tokens = sentence.split()
    stems, _ = zip(*labels)
    
    if len(tokens) != len(stems):
        continue
        
    for token, stem in zip(tokens, stems):
        eval_token_stem_pairs.add((token, stem))

eval_token_stem_pairs = list(eval_token_stem_pairs)

In [44]:
len(eval_token_stem_pairs)

20011

In [85]:
from tqdm.notebook import tqdm

success = 0

valid_rules = [rule for rule, count in rules.items() if count > 5]
candidate_lengths = []

for token, stem in tqdm(eval_token_stem_pairs):
    possible_stems = []
    
    # Reconstruct candidates
    for prefix, suffix, stem_suffix in valid_rules:
        possible_stem = token[:]
        if possible_stem.startswith(prefix):
            possible_stem = possible_stem[len(prefix):]
        else:
            continue
        
        if possible_stem.endswith(suffix):
            possible_stem = possible_stem[:len(possible_stem) - len(suffix)]
            possible_stem += stem_suffix
        else:
            continue
        
        possible_stems.append(possible_stem)
    
    #possible_stems = list(sorted(set(possible_stems)))
    possible_stems = list(possible_stems)
    
    if stem in possible_stems:
        success += 1
    
    candidate_lengths.append(len(possible_stems))

  0%|          | 0/20011 [00:00<?, ?it/s]

In [86]:
print(f"Perc. of reconstructed stems: {success / len(eval_token_stem_pairs)}")
print(f"Avg. number of candidate stems: {np.mean(candidate_lengths)}")

Perc. of reconstructed stems: 0.9715656388986058
Avg. number of candidate stems: 8.560091949427814


In [87]:
possible_stems

['aśāntau',
 'aśāntaun',
 'aśānta',
 'aśāntad',
 'aśānti',
 'aśāntu',
 'aśānt',
 'aśāt',
 'aśāntā']

In [88]:
len(valid_rules)

2730

In [89]:
valid_rules[20:40]

[('maṇḍala', '', 'aṇḍala'),
 ('abhyās', '', 'bhyāsa'),
 ('', 'ena', 'a'),
 ('', 'yate', ''),
 ('', 'āyate', 'an'),
 ('', 't', 'd'),
 ('', '', 'n'),
 ('ādh', 'ra', 'dhāra'),
 ('', 'iḥ', ''),
 ('', 'cchet', 'm'),
 ('', 'uvoḥ', 'ū'),
 ('', 'yati', ''),
 ('', 'avati', 'ū'),
 ('āpnuy', 't', 'p'),
 ('', 'ā', 'an'),
 ('', 'āya', 'a'),
 ('mudrā', '', 'udrā'),
 ('', 'ayaḥ', 'i'),
 ('', 'anti', ''),
 ('amb', 'ra', 'mbara')]