In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [2]:
import pandas as pd
import mmh3
import re
from tqdm import tqdm

MAX_LEN_CHARS = 256*3

In [3]:
from datasets import load_metric

In [4]:
from deepmultilingualpunctuation import PunctuationModel

model = PunctuationModel(model="oliverguhr/fullstop-punctuation-multilang-large")



In [5]:
df_t = pd.read_csv("data/eng_sentences.tsv", sep="\t", names=["id", "lang", "text"])
df_t = df_t[df_t["text"].str.len() < MAX_LEN_CHARS]

# shuffle
df_t["id"] = df_t["text"].map(lambda x: mmh3.hash64(x.encode('utf8'))[0])
df_t["text"] = df_t["text"].map(lambda x: re.sub(r'[^\x00-\x7F]+', '', x))
df_t["text"] = df_t["text"].map(lambda x: re.sub(' +', ' ', x).strip())
df_t = df_t.sort_values("id")
len(df_t)

1582094

In [6]:
df_o = pd.read_csv("data/oss.tsv", sep="\t", names=["text"])
df_o = df_o[df_o["text"].str.len() < MAX_LEN_CHARS]

# shuffle
df_o["id"] = df_o["text"].map(lambda x: mmh3.hash64(x.encode('utf8'))[0])
df_o["text"] = df_o["text"].map(lambda x: re.sub(r'[^\x00-\x7F]+', '', x))
df_o["text"] = df_o["text"].map(lambda x: re.sub(' +', ' ', x).strip())
df_o = df_o.sort_values("id")
len(df_o)

66939

In [7]:
df_g = pd.read_csv("data/gutenberg.tsv", sep="\t", names=["text"])
df_g = df_g[df_g["text"].str.len() < MAX_LEN_CHARS]

# shuffle
df_g["id"] = df_g["text"].map(lambda x: mmh3.hash64(x.encode('utf8'))[0])
df_g["text"] = df_g["text"].map(lambda x: re.sub(r'[^\x00-\x7F]+', '', x))
df_g["text"] = df_g["text"].map(lambda x: re.sub(' +', ' ', x).strip())
df_g = df_g.sort_values("id")
len(df_g)

4102516

In [8]:
test_size = 10000
texts_t = list(df_t["text"])[-test_size:]
texts_t = [text for text in texts_t if text.isascii()]
len(texts_t)

10000

In [9]:
test_size = 10000
texts_o = list(df_o["text"])[-test_size:]
texts_o = [text for text in texts_o if text.isascii()]
len(texts_o)

10000

In [10]:
test_size = 10000
texts_g = list(df_g["text"])[-test_size:]
texts_g = [text for text in texts_g if text.isascii()]
len(texts_g)

10000

In [11]:
texts_t[1]

"Alice didn't see the dog."

In [12]:
texts_o[1]

"Since the JavaScript support doesn't understand ES7 constructs, features like IntelliSense might not be fully accurate."

In [13]:
texts_g[1]

'"I don\'t fear him, anyway he comes," replied Will Banion. "I don\'t like it, but all of this was forced on me."'

In [14]:
texts = texts_t + texts_o + texts_g

In [15]:
clean_text = model.preprocess(texts_o[1])
labled_words = model.predict(clean_text)
print(labled_words)

[['Since', '0', 0.9999497], ['the', '0', 0.9999912], ['JavaScript', '0', 0.9999894], ['support', '0', 0.999967], ["doesn't", '0', 0.99999106], ['understand', '0', 0.9999101], ['ES7', '0', 0.9999515], ['constructs', ',', 0.99361897], ['features', '0', 0.9998851], ['like', '0', 0.999982], ['IntelliSense', '0', 0.99992466], ['might', '0', 0.9999914], ['not', '0', 0.99999213], ['be', '0', 0.99974865], ['fully', '0', 0.99999154], ['accurate', '.', 0.99954885]]


In [16]:
metric = load_metric("seqeval")

In [35]:
def clean_punctuation(s):
    return s.replace(",", "").replace(".", "").replace("!", "").replace("?", "").replace(":", "")

In [47]:
texts_without_punctuation = [
    model.preprocess(text) for text in texts
]
preds = []
for clean_text in tqdm(texts_without_punctuation):
    if clean_text:
        preds.append(model.predict(clean_text))
    else:
        preds.append([])

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14095/14095 [44:38<00:00,  5.26it/s]


In [48]:
predictions = [
    ["B-COMMA" if "," == token_pred[1] else "O" for token_pred in pred]
    for pred in preds
]

In [49]:
labels = [
    ["B-COMMA" if "," in token else "O" for token in text.split(" ") if token == "" or clean_punctuation(token)]
    for text in texts
]

In [54]:
# Quickfix for different tokens count
predictions[15905] = ['O']
predictions[23594] = ['O', 'O', 'O', 'O']

In [55]:
results = metric.compute(predictions=predictions, references=labels)
results

{'COMMA': {'precision': 0.720698444097843,
  'recall': 0.5900950007786949,
  'f1': 0.6488902589395807,
  'number': 32105},
 'overall_precision': 0.720698444097843,
 'overall_recall': 0.5900950007786949,
 'overall_f1': 0.6488902589395807,
 'overall_accuracy': 0.9622585936162497}

In [56]:
# tatoeba
results = metric.compute(predictions=predictions[0:10000], references=labels[0:10000])
results

{'COMMA': {'precision': 0.7761324041811847,
  'recall': 0.7388059701492538,
  'f1': 0.7570093457943925,
  'number': 1206},
 'overall_precision': 0.7761324041811847,
 'overall_recall': 0.7388059701492538,
 'overall_f1': 0.7570093457943925,
 'overall_accuracy': 0.9925531499394618}

In [57]:
# oss
results = metric.compute(predictions=predictions[10000:20000], references=labels[10000:20000])
results

{'COMMA': {'precision': 0.7812441621520643,
  'recall': 0.6744073536526367,
  'f1': 0.7239051410766835,
  'number': 6201},
 'overall_precision': 0.7812441621520643,
 'overall_recall': 0.6744073536526367,
 'overall_f1': 0.7239051410766835,
 'overall_accuracy': 0.9794956838093034}

In [58]:
# gutenberg
results = metric.compute(predictions=predictions[20000:30000], references=labels[20000:30000])
results

{'COMMA': {'precision': 0.701101789143839,
  'recall': 0.5616649121386347,
  'f1': 0.6236849204208254,
  'number': 24698},
 'overall_precision': 0.701101789143839,
 'overall_recall': 0.5616649121386347,
 'overall_f1': 0.6236849204208254,
 'overall_accuracy': 0.946145060884392}