In [1]:
import os
# Hide every CUDA device from this process. This is set before torch is
# imported in the next cell.
# NOTE(review): presumably intended to force CPU-only inference — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [2]:
import pandas as pd
import mmh3
import re
from tqdm import tqdm
from datasets import load_metric
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer, DataCollatorForTokenClassification

# Exclusive upper bound on sentence length, in characters. Each corpus-loading
# cell below keeps only rows whose text is strictly shorter than this.
# NOTE(review): 256*3 presumably approximates a 256-token budget at ~3
# characters per token — confirm.
MAX_LEN_CHARS = 256*3

In [3]:
# Tatoeba sentences: keep rows under the character limit, then order them by a
# 64-bit MurmurHash3 key. Sorting on the hash acts as a deterministic shuffle;
# a later cell takes the tail of this frame as the evaluation sample.
df_t = (
    pd.read_csv("data/eng_sentences.tsv", sep="\t", names=["id", "lang", "text"])
    .loc[lambda frame: frame["text"].str.len() < MAX_LEN_CHARS]
    # The key is hashed from the text as loaded, i.e. before any cleaning.
    .assign(id=lambda frame: frame["text"].map(lambda raw: mmh3.hash64(raw.encode('utf8'))[0]))
    # Remove non-ASCII runs, collapse repeated spaces, trim both ends.
    .assign(text=lambda frame: frame["text"].map(
        lambda raw: re.sub(' +', ' ', re.sub(r'[^\x00-\x7F]+', '', raw)).strip()
    ))
    .sort_values("id")
)
len(df_t)

1582094

In [4]:
# OSS-domain sentences: keep rows under the character limit, then order them by
# a 64-bit MurmurHash3 key. Sorting on the hash acts as a deterministic shuffle;
# a later cell takes the tail of this frame as the evaluation sample.
df_o = (
    pd.read_csv("data/oss.tsv", sep="\t", names=["text"])
    .loc[lambda frame: frame["text"].str.len() < MAX_LEN_CHARS]
    # The key is hashed from the text as loaded, i.e. before any cleaning.
    .assign(id=lambda frame: frame["text"].map(lambda raw: mmh3.hash64(raw.encode('utf8'))[0]))
    # Remove non-ASCII runs, collapse repeated spaces, trim both ends.
    .assign(text=lambda frame: frame["text"].map(
        lambda raw: re.sub(' +', ' ', re.sub(r'[^\x00-\x7F]+', '', raw)).strip()
    ))
    .sort_values("id")
)
len(df_o)

66939

In [5]:
# Gutenberg sentences: keep rows under the character limit, then order them by
# a 64-bit MurmurHash3 key. Sorting on the hash acts as a deterministic shuffle;
# a later cell takes the tail of this frame as the evaluation sample.
df_g = (
    pd.read_csv("data/gutenberg.tsv", sep="\t", names=["text"])
    .loc[lambda frame: frame["text"].str.len() < MAX_LEN_CHARS]
    # The key is hashed from the text as loaded, i.e. before any cleaning.
    .assign(id=lambda frame: frame["text"].map(lambda raw: mmh3.hash64(raw.encode('utf8'))[0]))
    # Remove non-ASCII runs, collapse repeated spaces, trim both ends.
    .assign(text=lambda frame: frame["text"].map(
        lambda raw: re.sub(' +', ' ', re.sub(r'[^\x00-\x7F]+', '', raw)).strip()
    ))
    .sort_values("id")
)
len(df_g)

4102516

In [6]:
test_size = 10000
# Evaluation sample for Tatoeba: the last `test_size` rows of the hash-sorted
# frame, restricted to pure-ASCII strings.
texts_t = [
    sentence
    for sentence in df_t["text"].iloc[-test_size:].tolist()
    if sentence.isascii()
]
len(texts_t)

10000

In [7]:
test_size = 10000
# Evaluation sample for the OSS domain: the last `test_size` rows of the
# hash-sorted frame, restricted to pure-ASCII strings.
texts_o = [
    sentence
    for sentence in df_o["text"].iloc[-test_size:].tolist()
    if sentence.isascii()
]
len(texts_o)

10000

In [8]:
test_size = 10000
# Evaluation sample for Gutenberg: the last `test_size` rows of the hash-sorted
# frame, restricted to pure-ASCII strings.
texts_g = [
    sentence
    for sentence in df_g["text"].iloc[-test_size:].tolist()
    if sentence.isascii()
]
len(texts_g)

10000

In [9]:
# Spot-check one held-out Tatoeba sentence.
texts_t[1]

"Alice didn't see the dog."

In [10]:
# Spot-check one held-out OSS-domain sentence.
texts_o[1]

"Since the JavaScript support doesn't understand ES7 constructs, features like IntelliSense might not be fully accurate."

In [11]:
# Spot-check one held-out Gutenberg sentence.
texts_g[1]

'"I don\'t fear him, anyway he comes," replied Will Banion. "I don\'t like it, but all of this was forced on me."'

In [12]:
# Combined evaluation set. The order (tatoeba, oss, gutenberg) matters: later
# cells slice predictions and labels by position to report per-domain scores.
texts = [*texts_t, *texts_o, *texts_g]

In [13]:
# Label inventory of the token-classification head: index 0 is "no comma",
# index 1 marks a word that carries a comma.
label_list = ['O', 'B-COMMA']

# First checkpoint under evaluation, loaded from a local directory.
tokenizer = AutoTokenizer.from_pretrained('./comma-distilroberta-base-3domains/')
model = AutoModelForTokenClassification.from_pretrained(
    './comma-distilroberta-base-3domains/',
    num_labels=len(label_list),
)

In [14]:
def predict_labels(text):
    """Predict a comma label for every space-separated word of ``text``.

    Commas are stripped from each word before tokenization, so the model only
    ever sees comma-free input. Sub-token predictions are folded back onto
    words: a word keeps ``label_list[0]`` unless at least one of its sub-tokens
    is predicted as a non-zero class, in which case it takes that class's label.

    Uses the notebook-level ``tokenizer``, ``model`` and ``label_list``.

    Args:
        text: Sentence to label; words are obtained with ``text.split(" ")``.

    Returns:
        A list of label strings, one per word of ``text``. Words that fall
        outside the tokenizer's truncation window keep ``label_list[0]``.
    """
    words = text.split(" ")
    words_without_comma = [word.replace(",", "") for word in words]
    tokens = tokenizer(words_without_comma, truncation=True, is_split_into_words=True)

    # The model expects a leading batch dimension, hence unsqueeze(0).
    input_ids = torch.tensor(tokens['input_ids']).unsqueeze(0)
    attention_mask = torch.tensor(tokens['attention_mask']).unsqueeze(0)

    # Inference only: skip autograd bookkeeping, and call the module itself
    # (not .forward) so registered hooks are honoured.
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    # Index the single batch row explicitly instead of a dimension-less
    # squeeze(), which would also drop any other size-1 dimension.
    token_preds = torch.argmax(logits[0], dim=-1).tolist()

    word_preds = [label_list[0]] * len(words)
    for pred, word_id in zip(token_preds, tokens.word_ids()):
        # word_id is None for special tokens, which belong to no word.
        if word_id is not None and pred != 0:
            word_preds[word_id] = label_list[pred]
    return word_preds

def predict_labels_batch(l_texts):
    """Predict comma labels for a batch of sentences in one forward pass.

    Each sentence is split on single spaces and commas are stripped from every
    word before tokenization. Sequences are right-padded to the longest one in
    the batch, with the attention mask set to 0 on the padding. Sub-token
    predictions are folded back onto words: a word keeps ``label_list[0]``
    unless at least one of its sub-tokens is predicted as a non-zero class.

    Uses the notebook-level ``tokenizer``, ``model`` and ``label_list``.

    Args:
        l_texts: List of sentences to label.

    Returns:
        A list with one entry per input sentence; each entry is a list of
        label strings, one per word. An empty input yields an empty list.
    """
    if not l_texts:
        return []

    words_without_comma = [[word.replace(",", "") for word in text.split(" ")] for text in l_texts]
    tokens = tokenizer(words_without_comma, truncation=True, is_split_into_words=True)

    # Pad with the tokenizer's own pad id rather than a literal 0, falling back
    # to 0 only if the tokenizer defines none. Padded copies are built so the
    # tokenizer output is not mutated.
    pad_id = getattr(tokenizer, "pad_token_id", None)
    if pad_id is None:
        pad_id = 0
    max_len = max(len(ids) for ids in tokens["input_ids"])
    input_ids = torch.tensor(
        [ids + [pad_id] * (max_len - len(ids)) for ids in tokens["input_ids"]]
    )
    attention_mask = torch.tensor(
        [mask + [0] * (max_len - len(mask)) for mask in tokens["attention_mask"]]
    )

    # Inference only: skip autograd bookkeeping, and call the module itself
    # (not .forward) so registered hooks are honoured.
    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    # No squeeze here: a dimension-less squeeze() would remove the batch
    # dimension for a single-sentence batch and break the per-row indexing below.
    predictions = torch.argmax(logits, dim=-1)

    word_preds = []
    for i, words in enumerate(words_without_comma):
        sentence_preds = [label_list[0]] * len(words)
        # word_ids covers only the unpadded tokens, so zip() stops before the padding.
        for pred, word_id in zip(predictions[i].tolist(), tokens.word_ids(batch_index=i)):
            # word_id is None for special tokens, which belong to no word.
            if word_id is not None and pred != 0:
                sentence_preds[word_id] = label_list[pred]
        word_preds.append(sentence_preds)
    return word_preds

In [15]:
# seqeval metric; the scoring cells below report per-label precision/recall/F1
# plus overall accuracy from it.
# NOTE(review): `datasets.load_metric` is deprecated in recent `datasets`
# releases in favour of the separate `evaluate` package — confirm against the
# installed version before re-running.
metric = load_metric("seqeval")

In [16]:
BATCH_SIZE = 64
# Label every evaluation sentence in chunks of BATCH_SIZE. Results are appended
# in input order, so `predictions` stays aligned index-for-index with `texts`.
predictions = []
for batch_start in tqdm(range(0, len(texts), BATCH_SIZE)):
    batch_texts = texts[batch_start:batch_start + BATCH_SIZE]
    predictions += predict_labels_batch(batch_texts)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [08:57<00:00,  1.15s/it]


In [17]:
# Gold labels are read straight off the original text: a space-separated word
# is tagged "B-COMMA" when it contains a comma anywhere, otherwise "O".
labels = [
    ["O" if "," not in word else "B-COMMA" for word in sentence.split(" ")]
    for sentence in texts
]

In [18]:
# Overall scores across the combined three-domain evaluation set.
results = metric.compute(
    references=labels,
    predictions=predictions,
)
results

{'COMMA': {'precision': 0.8378378378378378,
  'recall': 0.8166786036807523,
  'f1': 0.8271229204446897,
  'number': 32113},
 'overall_precision': 0.8378378378378378,
 'overall_recall': 0.8166786036807523,
 'overall_f1': 0.8271229204446897,
 'overall_accuracy': 0.9798254347109993}

In [19]:
# tatoeba
# `texts` starts with the Tatoeba sentences, so their rows are the first
# len(texts_t) entries. The bound is derived from the list instead of a literal
# 10000 so the slice stays correct if the ASCII filter dropped any sentence.
results = metric.compute(
    predictions=predictions[:len(texts_t)],
    references=labels[:len(texts_t)],
)
results

{'COMMA': {'precision': 0.8464730290456431,
  'recall': 0.845771144278607,
  'f1': 0.8461219411032767,
  'number': 1206},
 'overall_precision': 0.8464730290456431,
 'overall_recall': 0.845771144278607,
 'overall_f1': 0.8461219411032767,
 'overall_accuracy': 0.9951700883964953}

In [20]:
# oss
# In `texts` the OSS sentences sit directly after the Tatoeba ones. Boundaries
# are derived from the list lengths instead of literal 10000/20000 so the slice
# stays correct if the ASCII filter dropped any sentence.
results = metric.compute(
    predictions=predictions[len(texts_t):len(texts_t) + len(texts_o)],
    references=labels[len(texts_t):len(texts_t) + len(texts_o)],
)
results

{'COMMA': {'precision': 0.8275974025974026,
  'recall': 0.8210661942341761,
  'f1': 0.8243188616703047,
  'number': 6209},
 'overall_precision': 0.8275974025974026,
 'overall_recall': 0.8210661942341761,
 'overall_f1': 0.8243188616703047,
 'overall_accuracy': 0.9860451973772935}

In [21]:
# gutenberg
# The Gutenberg sentences are the final segment of `texts`, after the Tatoeba
# and OSS ones. The start is derived from the list lengths instead of a literal
# 20000 so the slice stays correct if the ASCII filter dropped any sentence.
results = metric.compute(
    predictions=predictions[len(texts_t) + len(texts_o):],
    references=labels[len(texts_t) + len(texts_o):],
)
results

{'COMMA': {'precision': 0.8400384342231692,
  'recall': 0.8141549923070694,
  'f1': 0.8268942119872521,
  'number': 24698},
 'overall_precision': 0.8400384342231692,
 'overall_recall': 0.8141549923070694,
 'overall_f1': 0.8268942119872521,
 'overall_accuracy': 0.9729185497801381}

In [22]:
# Second checkpoint under evaluation. It rebinds the notebook-level `tokenizer`
# and `model`, so every prediction made after this cell comes from this one.
tokenizer = AutoTokenizer.from_pretrained('./comma-roberta-base-3domains-more-data/')
model = AutoModelForTokenClassification.from_pretrained(
    './comma-roberta-base-3domains-more-data/',
    num_labels=len(label_list),
)

In [23]:
BATCH_SIZE = 64
# Re-label the same evaluation sentences with the currently loaded model,
# replacing the earlier `predictions`. Output order follows `texts`.
predictions = []
for batch_start in tqdm(range(0, len(texts), BATCH_SIZE)):
    batch_texts = texts[batch_start:batch_start + BATCH_SIZE]
    predictions += predict_labels_batch(batch_texts)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 469/469 [17:55<00:00,  2.29s/it]


In [24]:
# Overall scores across the combined evaluation set for the current `predictions`.
results = metric.compute(
    references=labels,
    predictions=predictions,
)
results

{'COMMA': {'precision': 0.8662496463709805,
  'recall': 0.8581571326254165,
  'f1': 0.8621844007133249,
  'number': 32113},
 'overall_precision': 0.8662496463709805,
 'overall_recall': 0.8581571326254165,
 'overall_f1': 0.8621844007133249,
 'overall_accuracy': 0.9837874742136189}

In [25]:
# tatoeba
# `texts` starts with the Tatoeba sentences, so their rows are the first
# len(texts_t) entries. The bound is derived from the list instead of a literal
# 10000 so the slice stays correct if the ASCII filter dropped any sentence.
results = metric.compute(
    predictions=predictions[:len(texts_t)],
    references=labels[:len(texts_t)],
)
results

{'COMMA': {'precision': 0.8662262592898431,
  'recall': 0.8698175787728026,
  'f1': 0.868018204385602,
  'number': 1206},
 'overall_precision': 0.8662262592898431,
 'overall_recall': 0.8698175787728026,
 'overall_f1': 0.868018204385602,
 'overall_accuracy': 0.9958470571387655}

In [26]:
# oss
# In `texts` the OSS sentences sit directly after the Tatoeba ones. Boundaries
# are derived from the list lengths instead of literal 10000/20000 so the slice
# stays correct if the ASCII filter dropped any sentence.
results = metric.compute(
    predictions=predictions[len(texts_t):len(texts_t) + len(texts_o)],
    references=labels[len(texts_t):len(texts_t) + len(texts_o)],
)
results

{'COMMA': {'precision': 0.8464598249801114,
  'recall': 0.8568207440811725,
  'f1': 0.8516087722106611,
  'number': 6209},
 'overall_precision': 0.8464598249801114,
 'overall_recall': 0.8568207440811725,
 'overall_f1': 0.8516087722106611,
 'overall_accuracy': 0.9880937855211698}

In [27]:
# gutenberg
# The Gutenberg sentences are the final segment of `texts`, after the Tatoeba
# and OSS ones. The start is derived from the list lengths instead of a literal
# 20000 so the slice stays correct if the ASCII filter dropped any sentence.
results = metric.compute(
    predictions=predictions[len(texts_t) + len(texts_o):],
    references=labels[len(texts_t) + len(texts_o):],
)
results

{'COMMA': {'precision': 0.8713657112308262,
  'recall': 0.8579237185197182,
  'f1': 0.8645924716923391,
  'number': 24698},
 'overall_precision': 0.8713657112308262,
 'overall_recall': 0.8579237185197182,
 'overall_f1': 0.8645924716923391,
 'overall_accuracy': 0.9786507203813727}