In [1]:
from typing import Iterator, List, Dict
import torch
import torch.optim as optim
import numpy as np
from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField,LabelField,ArrayField, ListField
from allennlp.data.dataset_readers import DatasetReader
from allennlp.common.file_utils import cached_path
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer, TokenCharactersIndexer
from allennlp.data.tokenizers import Token
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper
from allennlp.nn.util import get_text_field_mask, masked_log_softmax
from allennlp.training.metrics import Average
from allennlp.data.iterators import BucketIterator, BasicIterator
from allennlp.training.trainer import Trainer
from allennlp.predictors import SentenceTaggerPredictor
from torch.nn.modules import NLLLoss
from torch.nn import LogSoftmax
torch.manual_seed(1)

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


<torch._C.Generator at 0x7f9926d4db70>

In [2]:
from io import open
import glob
import os
from sklearn.model_selection import train_test_split
import unicodedata
import string

def findFiles(path): return glob.glob(path)


all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)

# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )

name_category_list = []

# Read a file and split into lines
def readLines(filename):
    lines = open(filename, encoding='utf-8').read().strip().split('\n')
    return [unicodeToAscii(line) for line in lines]

for filename in findFiles('data/names/*.txt'):
    category = os.path.splitext(os.path.basename(filename))[0]
    lines = readLines(filename)
    for name in lines:
        name_category_list.append((name,category))

X = [x[0] for x in name_category_list]
y = [x[1] for x in name_category_list]
x_train, x_val, y_train, y_val = train_test_split(X,y,test_size=.2, stratify=y)
with open('train_data.txt','w+') as f:
    f.write('\n'.join('{} {}'.format(x[0],x[1]) for x in zip(x_train,y_train)))
with open('val_data.txt','w+') as f:
    f.write('\n'.join('{} {}'.format(x[0],x[1]) for x in zip(x_val,y_val)))

In [3]:
class NameLangDatasetReader(DatasetReader):

    def __init__(self, token_indexers: Dict[str, TokenIndexer] = None) -> None:
        super().__init__(lazy=False)
        self.token_indexers = token_indexers or {"characters": SingleIdTokenIndexer()}
        
    def text_to_instance(self, name: str, label: str=None) -> Instance:
        tokens = [Token(ch) for ch in name]
        char_field = TextField(tokens, self.token_indexers)
        fields = {"name": char_field}
        if label is None:
            return Instance(fields)
        
        label_field = LabelField(label=label)
        fields["label"] = label_field
        return Instance(fields)
    
    def _read(self, file_path: str) -> Iterator[Instance]:
        with open(file_path) as f:
            for line in f:
                line = line.strip().split()
                language = line[-1]
                name = ' '.join(line[:-1])
                yield self.text_to_instance(name, language)


In [4]:
reader = NameLangDatasetReader()
train_dataset = reader.read('train_data.txt')
validation_dataset = reader.read('val_data.txt')
vocab = Vocabulary.from_instances(train_dataset + validation_dataset)

16059it [00:00, 43296.89it/s]
4015it [00:00, 71876.37it/s]
100%|██████████| 20074/20074 [00:00<00:00, 172801.07it/s]


In [5]:
train_dataset[0].fields['name'].tokens, train_dataset[0].fields['label'].label

([M, u, h, l, f, e, l, d], 'German')

In [6]:
class NamesClassifier(Model):
    def __init__(self,
                 char_embeddings: TextFieldEmbedder,
                 encoder: Seq2VecEncoder,
                 vocab: Vocabulary) -> None:
        super().__init__(vocab)
        self.char_embeddings = char_embeddings
        self.encoder = encoder
        self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                          out_features=vocab.get_vocab_size('labels'))
        self.accuracy = Average()
        self.m = LogSoftmax() # Softmax try and see
        self.loss = NLLLoss()
        
    def forward(self,
                name: Dict[str, torch.Tensor],
                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        mask = get_text_field_mask(name)
        embeddings = self.char_embeddings(name)
        encoder_out = self.encoder(embeddings, mask)
        tag_logits = self.hidden2tag(encoder_out)
        output = {"tag_logits": tag_logits}
        if label is not None:
            output["loss"] = self.loss(self.m(tag_logits), label)
            prediction = tag_logits.max(1)[1]
            self.accuracy(prediction.eq(label).double().mean())
        return output
    
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"accuracy": self.accuracy.get_metric(reset)}

In [25]:
from overrides import overrides

from allennlp.common.util import JsonDict
from allennlp.data import DatasetReader, Instance
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.models import Model
from allennlp.predictors.predictor import Predictor

class NamePredictor(Predictor):

    def __init__(self, model: Model, dataset_reader: DatasetReader, language: str = 'en_core_web_sm') -> None:
        super().__init__(model, dataset_reader)

    def predict(self, name: str) -> JsonDict:
        tag_logits = self.predict_json({"name" : name})['tag_logits']
        print(tag_logits)
        max_id = np.argmax(tag_logits, axis=-1)
        return model.vocab.get_token_from_index(max_id, 'labels')
        
    @overrides
    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        name = json_dict["name"]
        return self._dataset_reader.text_to_instance(name)

In [9]:
EMBEDDING_DIM = 6
HIDDEN_DIM = 6
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'), embedding_dim=EMBEDDING_DIM)
char_embeddings = BasicTextFieldEmbedder({"characters": token_embedding})
lstm = PytorchSeq2VecWrapper(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))
model = NamesClassifier(char_embeddings, lstm, vocab)
if torch.cuda.is_available():
    cuda_device = 0
    model = model.cuda(cuda_device)
else:
    cuda_device = -1
optimizer = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
iterator = BasicIterator(batch_size=2)
iterator.index_with(vocab)
trainer = Trainer(model=model,
                  optimizer=optimizer,
                  iterator=iterator,
                  train_dataset=train_dataset,
                  validation_dataset=validation_dataset,
                  patience=10,
                  num_epochs=5,
                  cuda_device=cuda_device)
trainer.train()

accuracy: 0.6150, loss: 1.2857 ||: 100%|██████████| 8030/8030 [00:18<00:00, 431.42it/s]
accuracy: 0.6877, loss: 1.0660 ||: 100%|██████████| 2008/2008 [00:01<00:00, 1082.39it/s]
accuracy: 0.6937, loss: 1.0317 ||: 100%|██████████| 8030/8030 [00:18<00:00, 435.26it/s]
accuracy: 0.7099, loss: 0.9925 ||: 100%|██████████| 2008/2008 [00:01<00:00, 1119.62it/s]
accuracy: 0.7152, loss: 0.9632 ||: 100%|██████████| 8030/8030 [00:18<00:00, 442.81it/s]
accuracy: 0.7171, loss: 0.9763 ||: 100%|██████████| 2008/2008 [00:01<00:00, 1089.31it/s]
accuracy: 0.7238, loss: 0.9338 ||: 100%|██████████| 8030/8030 [00:18<00:00, 440.35it/s]
accuracy: 0.7179, loss: 0.9615 ||: 100%|██████████| 2008/2008 [00:01<00:00, 1119.47it/s]
accuracy: 0.7315, loss: 0.9171 ||: 100%|██████████| 8030/8030 [00:18<00:00, 444.52it/s]
accuracy: 0.7328, loss: 0.9254 ||: 100%|██████████| 2008/2008 [00:01<00:00, 1121.81it/s]


{'best_epoch': 4,
 'peak_cpu_memory_MB': 2354.14,
 'peak_gpu_0_memory_MB': 752,
 'training_duration': '00:01:40',
 'training_start_epoch': 0,
 'training_epochs': 4,
 'epoch': 4,
 'training_accuracy': tensor(0.7315, dtype=torch.float64),
 'training_loss': 0.9171358023307392,
 'training_cpu_memory_MB': 2354.14,
 'training_gpu_0_memory_MB': 752,
 'validation_accuracy': tensor(0.7328, dtype=torch.float64),
 'validation_loss': 0.9253557659240358,
 'best_validation_accuracy': tensor(0.7328, dtype=torch.float64),
 'best_validation_loss': 0.9253557659240358}

In [29]:
predictor = NamePredictor(model, dataset_reader=reader)
predictor.predict("Vikhorev")

[7.55558443069458, -1.1911295652389526, -6.244527816772461, -2.878906726837158, -1.9350042343139648, -2.196650743484497, 0.3870698809623718, -1.3938158750534058, -2.520087957382202, -2.6504149436950684, -6.677775859832764, -2.172342538833618, -2.7314114570617676, -0.6685163378715515, -3.9820199012756348, -5.869513511657715, -3.290464162826538, -4.328010559082031]


'Russian'

In [32]:
vocab.print_statistics()



----Vocabulary Statistics----


Top 10 most frequent tokens in namespace 'tokens':
	Token: a		Frequency: 14767
	Token: o		Frequency: 10802
	Token: e		Frequency: 10317
	Token: i		Frequency: 10202
	Token: n		Frequency: 9348
	Token: r		Frequency: 7511
	Token: h		Frequency: 6547
	Token: s		Frequency: 6511
	Token: l		Frequency: 5950
	Token: k		Frequency: 5894

Top 10 longest tokens in namespace 'tokens':
	Token: a		length: 1	Frequency: 14767
	Token: o		length: 1	Frequency: 10802
	Token: e		length: 1	Frequency: 10317
	Token: i		length: 1	Frequency: 10202
	Token: n		length: 1	Frequency: 9348
	Token: r		length: 1	Frequency: 7511
	Token: h		length: 1	Frequency: 6547
	Token: s		length: 1	Frequency: 6511
	Token: l		length: 1	Frequency: 5950
	Token: k		length: 1	Frequency: 5894

Top 10 shortest tokens in namespace 'tokens':
	Token: ,		length: 1	Frequency: 3
	Token: X		length: 1	Frequency: 14
	Token: q		length: 1	Frequency: 38
	Token: x		length: 1	Frequency: 59
	Token: Q		length: 1	Frequency: 60
