In [41]:
from collections import defaultdict
import json

# Inventory of the tags that must appear in the final distribution. Classes
# that never occur in the annotation file are still given a count of 1 below
# so that every class ends up with a non-zero share.
ERROR_CLASSES = {
    'OH', 'OT', 'OA', 'OW', 'ON', 'OS', 'OG', 'OC', 'OR', 'OD', 'OM', # orthographic errors
    'MI', 'MT', # morphological errors
    'XC', 'XF', 'XG', 'XN', 'XT', 'XM', # syntactic errors
    'SW', 'SF', # semantic errors
    'PC', 'PT', 'PM', # punctuation errors
    'MG', 'SP', # incorrectly merged/split words
    'UC' # No error
}

# Explicit encoding so the read does not depend on the platform default.
# NOTE(review): assumes the annotation file is UTF-8 -- confirm.
with open('../../data/real/clean/qalb-14+qalb-15+ZAEBUC/annotations/qalb-14+qalb-15+ZAEBUC_train.areta.txt', encoding='utf-8') as f:
  tags = f.readlines()

# Count every tag occurrence. The last tab-separated field of each line is
# read as the tag field; several tags on one line are joined with '+'.
error_freqs = defaultdict(int)
for line in tags:
  for tag in line.strip().split('\t')[-1].split('+'):
    error_freqs[tag] += 1

# Drop tags that are not part of the two-letter inventory above. pop() with a
# default is used instead of `del` because `del` on a defaultdict raises
# KeyError when the tag never occurred in the data.
for ignored_tag in ('UNK', 'X', 'O', 'M', 'S', ''):
  error_freqs.pop(ignored_tag, None)

# Give unseen classes a count of 1. Iterating in sorted order keeps the key
# order of the resulting dict deterministic (set iteration order is not).
for error_class in sorted(ERROR_CLASSES):
  error_freqs.setdefault(error_class, 1)

# Convert raw counts to percentages that sum to 100.
total = sum(error_freqs.values())
error_freqs = {tag: (count / total) * 100 for tag, count in error_freqs.items()}

# json.dump(error_freqs, open('qalb-14+qalb-15+ZAEBUC_error-distribution.json', 'w'))

In [42]:
# Show the percentage distribution followed by its total on the next line.
print(json.dumps(error_freqs, indent=2), sum(error_freqs.values()), sep='\n')

{
  "UC": 71.41570966898146,
  "OH": 8.319846767865128,
  "MG": 0.8653902947355987,
  "XG": 0.1425160109017576,
  "OT": 1.4568022073527798,
  "MI": 0.5618792875043245,
  "OD": 0.9527224861407609,
  "XM": 0.7512256039421836,
  "XC": 0.6641465493237029,
  "OM": 0.5779956629230549,
  "OR": 1.1644292187354974,
  "PT": 8.60850708361108,
  "XN": 0.1850429910642714,
  "SW": 0.5003670483406883,
  "PC": 1.0523740011644291,
  "PM": 0.5340342409693452,
  "XT": 0.21229738509699358,
  "SF": 0.136778243736974,
  "OC": 0.046324031962738266,
  "SP": 0.7119894020065309,
  "OA": 0.7504661935821386,
  "OG": 0.17837705568165518,
  "XF": 0.15829487060491254,
  "MT": 0.022360416156877307,
  "OW": 0.029785761899538448,
  "ON": 0.00025313678668162983,
  "OS": 8.437892889387661e-05
}
100.0


In [35]:
# Share of tokens tagged 'UC' (no error) as a fraction of 1, and the
# complementary share covered by all remaining tags together.
keep = error_freqs['UC'] / 100
err = 1 - keep
# 'UC' is intentionally left in error_freqs. Deleting it in place made this
# cell fail with KeyError when run a second time, and the recorded outputs of
# the later cells (27 values, 'UC' included) were produced with 'UC' still
# present, so removing it here would change the results of a fresh run.


In [43]:
import torch

# Split the distribution into a list of tags and a tensor of their
# percentages, both in the dict's iteration order so positions line up.
keys = [tag for tag in error_freqs]
values = torch.tensor([error_freqs[tag] for tag in keys])

In [44]:
# Show the percentage tensor together with its total.
print(values, values.sum())

tensor([7.1416e+01, 8.3198e+00, 8.6539e-01, 1.4252e-01, 1.4568e+00, 5.6188e-01,
        9.5272e-01, 7.5123e-01, 6.6415e-01, 5.7800e-01, 1.1644e+00, 8.6085e+00,
        1.8504e-01, 5.0037e-01, 1.0524e+00, 5.3403e-01, 2.1230e-01, 1.3678e-01,
        4.6324e-02, 7.1199e-01, 7.5047e-01, 1.7838e-01, 1.5829e-01, 2.2360e-02,
        2.9786e-02, 2.5314e-04, 8.4379e-05]) tensor(100.0000)


In [45]:
# Softmax temperature: the percentages are divided by T before the softmax.
T = 25
# Re-normalise with a softmax and scale back to percentages (sum is 100).
# Disabled alternative: additionally multiply the result by `err`.
new_values = 100 * torch.softmax(values / T, dim=0)
print(new_values, new_values.sum())

tensor([38.9474,  3.1217,  2.3168,  2.2508,  2.3723,  2.2889,  2.3249,  2.3063,
         2.2983,  2.2903,  2.3447,  3.1579,  2.2546,  2.2832,  2.3342,  2.2863,
         2.2571,  2.2503,  2.2422,  2.3027,  2.3062,  2.2540,  2.2522,  2.2400,
         2.2407,  2.2380,  2.2380]) tensor(100.0000)


In [46]:
# Pair each tag with its re-weighted percentage as a plain Python float.
new_err_freqs = dict(zip(keys, new_values.tolist()))
# Disabled alternative: new_err_freqs['UC'] = keep * 100
print(json.dumps(new_err_freqs, indent=2))

{
  "UC": 38.947357177734375,
  "OH": 3.1216959953308105,
  "MG": 2.3168258666992188,
  "XG": 2.2507941722869873,
  "OT": 2.3722875118255615,
  "MI": 2.2888689041137695,
  "OD": 2.3249332904815674,
  "XM": 2.306270122528076,
  "XC": 2.298250675201416,
  "OM": 2.2903449535369873,
  "OR": 2.34470534324646,
  "PT": 3.1579489707946777,
  "XN": 2.2546262741088867,
  "SW": 2.2832441329956055,
  "PC": 2.334219217300415,
  "PM": 2.286320686340332,
  "XT": 2.2570858001708984,
  "SF": 2.2502779960632324,
  "OC": 2.2421507835388184,
  "SP": 2.3026533126831055,
  "OA": 2.306199789047241,
  "OG": 2.254025459289551,
  "XF": 2.2522151470184326,
  "MT": 2.240002155303955,
  "OW": 2.2406675815582275,
  "ON": 2.2380223274230957,
  "OS": 2.2380073070526123
}


In [47]:
# Save the re-weighted distribution as indented JSON next to the annotations.
with open('../../data/real/clean/qalb-14+qalb-15+ZAEBUC/annotations/qalb-14+qalb-15+ZAEBUC_error-distribution_temp_25.json', mode='w') as out_file:
    out_file.write(json.dumps(new_err_freqs, indent=2))