In [1]:
from collections import namedtuple

# other packages
import pandas as pd
# pytorch
import torch
import torch.optim as optim
import tqdm
from allennlp.data.dataset_readers.seq2seq import Seq2SeqDatasetReader
from allennlp.data.iterators import BucketIterator
from allennlp.data.token_indexers import SingleIdTokenIndexer
from allennlp.data.tokenizers.character_tokenizer import CharacterTokenizer
from allennlp.data.vocabulary import Vocabulary
# model
from allennlp.models.encoder_decoders.simple_seq2seq import SimpleSeq2Seq
# attention
from allennlp.modules.attention import AdditiveAttention
from allennlp.modules.seq2seq_encoders import StackedSelfAttentionEncoder
# encoder
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.predictors import SimpleSeq2SeqPredictor
from allennlp.training.trainer import Trainer
# preprocessing
from imblearn.under_sampling import RandomUnderSampler
from sklearn.model_selection import train_test_split

In [2]:
# Sanity check that PyTorch can see a CUDA device; the config below selects device 0 when it can.
torch.cuda.is_available()

True

In [3]:


# Immutable bundle of run settings for the character-level seq2seq model.
# Field order is significant because Config can be constructed positionally.
Config = namedtuple(
    'Config',
    (
        'lazy max_vocab_size batch_size epochs max_seq_len '
        'IN_EMBEDDING_DIM HIDDEN_DIM OUT_EMBEDDING_DIM CUDA_DEVICE n_samples'
    ),
)



# Create the run configuration

In [4]:
# Concrete settings for this run; keyword arguments make each value's role explicit.
config = Config(
    lazy=False,
    max_vocab_size=10000,
    batch_size=64,
    epochs=20,
    max_seq_len=100,
    IN_EMBEDDING_DIM=64,
    HIDDEN_DIM=32,
    OUT_EMBEDDING_DIM=64,
    # GPU 0 when CUDA is available, otherwise -1 (no GPU).
    CUDA_DEVICE=0 if torch.cuda.is_available() else -1,
    # Per-class cap used when undersampling the training data.
    n_samples=60000,
)

# prepare data

In [5]:
# Load the Russian training data, reading the text columns as literal strings.
# With pandas' default NA handling, genuine tokens such as "NA", "null" or
# "nan" are parsed as NaN, and a later .astype(str) would turn them into the
# string "nan". That silently corrupts real (before, after) training pairs.
raw_dataset = pd.read_csv(
    'data/ru_train.csv',
    dtype={'before': str, 'after': str},
    keep_default_na=False,
)

# Target row count per class for the undersampler: each class's count is
# capped at config.n_samples (classes at or below the cap keep their count).
d = {
    cls: min(count, config.n_samples)
    for cls, count in raw_dataset['class'].value_counts().to_dict().items()
}

In [6]:
# Display the per-class counts that will be passed to the undersampler.
d

{'PLAIN': 60000,
 'PUNCT': 60000,
 'CARDINAL': 60000,
 'LETTERS': 60000,
 'DATE': 60000,
 'VERBATIM': 60000,
 'ORDINAL': 46738,
 'MEASURE': 40534,
 'TELEPHONE': 10088,
 'DECIMAL': 7297,
 'ELECTRONIC': 5832,
 'MONEY': 2690,
 'FRACTION': 2460,
 'DIGIT': 2012,
 'TIME': 1945}

In [7]:


# Undersample the large classes down to the counts in `d`, using the 'class'
# column as labels. A fixed random_state makes the kept rows reproducible.
rus = RandomUnderSampler(random_state=0, sampling_strategy=d)
raw_dataset_resampled = rus.fit_resample(raw_dataset, raw_dataset['class'])[0]



In [8]:
# Keep only the (source, target) text pairs and hold out 10% for validation.
# random_state pins the split (matching the seed used for undersampling) so
# that reruns train and validate on the same rows.
df = raw_dataset_resampled[['before', 'after']]

df_train, df_test = train_test_split(df, test_size=0.1, random_state=0)

# Headerless, tab-separated files with the columns in (before, after) order,
# consumed by the Seq2SeqDatasetReader in the next cell.
df_train.to_csv('train_dataset.tsv', index=False, header=False, sep='\t')
df_test.to_csv('test_dataset.tsv', index=False, header=False, sep='\t')

# create reader

In [9]:
# Both the source ('before') and target ('after') strings are tokenised with
# CharacterTokenizer. Target tokens are indexed into their own
# 'target_tokens' namespace so the two vocabularies stay separate.
reader = Seq2SeqDatasetReader(
    source_tokenizer=CharacterTokenizer(),
    source_token_indexers={'tokens': SingleIdTokenIndexer()},
    target_tokenizer=CharacterTokenizer(),
    target_token_indexers={'tokens': SingleIdTokenIndexer(namespace='target_tokens')},
    lazy=config.lazy,
)

# Read the training file first, then the held-out file used for validation.
train_dataset, validation_dataset = (
    reader.read(path) for path in ('train_dataset.tsv', 'test_dataset.tsv')
)

431636it [00:21, 20165.96it/s]
47960it [00:01, 41660.54it/s]


# prepare vocabulary

In [10]:
# Build both vocabularies from the training instances only; min_count=3
# filters out rarely seen characters in each namespace.
# NOTE(review): config.max_vocab_size is not passed here.
vocab = Vocabulary.from_instances(
    train_dataset,
    min_count={'tokens': 3, 'target_tokens': 3},
)

# Source-side character embeddings, sized to the 'tokens' vocabulary.
in_embedding = Embedding(
    embedding_dim=config.IN_EMBEDDING_DIM,
    num_embeddings=vocab.get_vocab_size('tokens'),
)
source_embedder = BasicTextFieldEmbedder({"tokens": in_embedding})

# Batch instances bucketed by source length to limit padding.
iterator = BucketIterator(
    sorting_keys=[("source_tokens", "num_tokens")],
    batch_size=config.batch_size,
)
iterator.index_with(vocab)

100%|██████████| 431636/431636 [00:05<00:00, 84312.93it/s]


# prepare model

In [11]:
# Encoder: a 3-layer stacked self-attention (8 heads) over the source
# character embeddings. Its output width is presumably config.HIDDEN_DIM,
# which is what the attention dims below assume; confirm if changing it.
encoder = StackedSelfAttentionEncoder(
    input_dim=config.IN_EMBEDDING_DIM,
    hidden_dim=config.HIDDEN_DIM,
    num_layers=3,
    num_attention_heads=8,
    projection_dim=64,
    feedforward_hidden_dim=64,
)

# Additive attention with vector and matrix dims both set to config.HIDDEN_DIM.
attention = AdditiveAttention(config.HIDDEN_DIM, config.HIDDEN_DIM)

In [12]:
# Upper bound on decoding steps, i.e. on the length of a generated output.
# It is taken from the run config so the sequence-length limit lives in one place.
max_decoding_steps = config.max_seq_len

# Character-level encoder-decoder:
# - beam search of width 8 at prediction time;
# - BLEU tracking disabled;
# - scheduled sampling ratio 0.15 during training.
model = SimpleSeq2Seq(vocab,
                      source_embedder,
                      encoder,
                      max_decoding_steps,
                      target_embedding_dim=config.OUT_EMBEDDING_DIM,
                      target_namespace='target_tokens',
                      beam_size=8,
                      use_bleu=False,
                      attention=attention,
                      scheduled_sampling_ratio=0.15)

In [13]:
# Move the model's parameters to the configured GPU when one is available.
# Otherwise the model stays where it was created (config.CUDA_DEVICE is -1 in that case).
if torch.cuda.is_available():
    model.cuda(config.CUDA_DEVICE)

In [14]:
# Adam over every model parameter, with PyTorch's default hyperparameters (no lr given).
optimizer = optim.Adam(params=model.parameters())

# training

In [15]:
# Re-check GPU visibility right before building the trainer.
torch.cuda.is_available()

True

In [16]:
# The trainer runs ONE epoch per train() call. The epoch count is driven by
# the external loop over config.epochs in the next cell.
# NOTE(review): because every call is capped at a single epoch and the
# outer loop ignores early stopping, patience=2 presumably never ends
# training early. Confirm, or pass num_epochs=config.epochs and call
# train() once if early stopping is wanted.
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    cuda_device=config.CUDA_DEVICE,
    train_dataset=train_dataset,
    validation_dataset=validation_dataset,
    patience=2,
    num_epochs=1,
)

In [17]:
# Train one epoch per iteration; each trainer.train() call trains and then validates.
print(f'Will train for {config.epochs} epochs')
for epoch_number in range(1, config.epochs + 1):
    print(f'Epoch: {epoch_number}')
    trainer.train()

  0%|          | 0/6745 [00:00<?, ?it/s]

Will train for 20 epochs
Epoch: 1


loss: 1.4819 ||: 100%|██████████| 6745/6745 [07:33<00:00, 14.89it/s]
loss: 0.7548 ||: 100%|██████████| 750/750 [01:12<00:00, 10.35it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 2


loss: 0.7479 ||: 100%|██████████| 6745/6745 [07:17<00:00, 15.42it/s]
loss: 0.4563 ||: 100%|██████████| 750/750 [01:06<00:00, 11.26it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 3


loss: 0.5524 ||: 100%|██████████| 6745/6745 [07:18<00:00, 15.37it/s]
loss: 0.3196 ||: 100%|██████████| 750/750 [01:10<00:00, 10.71it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 4


loss: 0.4351 ||: 100%|██████████| 6745/6745 [07:19<00:00, 15.35it/s]
loss: 0.2452 ||: 100%|██████████| 750/750 [01:06<00:00, 11.26it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 5


loss: 0.3685 ||: 100%|██████████| 6745/6745 [07:20<00:00, 15.31it/s]
loss: 0.2127 ||: 100%|██████████| 750/750 [01:07<00:00, 11.19it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 6


loss: 0.3284 ||: 100%|██████████| 6745/6745 [07:19<00:00, 15.35it/s]
loss: 0.1756 ||: 100%|██████████| 750/750 [01:07<00:00, 11.10it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 7


loss: 0.3003 ||: 100%|██████████| 6745/6745 [07:42<00:00, 14.57it/s]
loss: 0.1609 ||: 100%|██████████| 750/750 [01:21<00:00,  9.20it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 8


loss: 0.2797 ||: 100%|██████████| 6745/6745 [07:58<00:00, 14.09it/s]
loss: 0.1464 ||: 100%|██████████| 750/750 [01:15<00:00,  9.87it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 9


loss: 0.2641 ||: 100%|██████████| 6745/6745 [08:44<00:00, 12.85it/s]
loss: 0.1422 ||: 100%|██████████| 750/750 [01:19<00:00,  9.44it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 10


loss: 0.2494 ||: 100%|██████████| 6745/6745 [07:28<00:00, 15.05it/s]
loss: 0.1342 ||: 100%|██████████| 750/750 [00:56<00:00, 13.30it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 11


loss: 0.2379 ||: 100%|██████████| 6745/6745 [07:17<00:00, 15.42it/s]
loss: 0.1382 ||: 100%|██████████| 750/750 [01:00<00:00, 12.38it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 12


loss: 0.2299 ||: 100%|██████████| 6745/6745 [07:18<00:00, 15.39it/s]
loss: 0.1230 ||: 100%|██████████| 750/750 [00:58<00:00, 12.84it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 13


loss: 0.2209 ||: 100%|██████████| 6745/6745 [07:25<00:00, 15.13it/s]
loss: 0.1270 ||: 100%|██████████| 750/750 [01:01<00:00, 12.23it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 14


loss: 0.2163 ||: 100%|██████████| 6745/6745 [07:23<00:00, 15.19it/s]
loss: 0.1148 ||: 100%|██████████| 750/750 [01:01<00:00, 12.25it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 15


loss: 0.2084 ||: 100%|██████████| 6745/6745 [07:35<00:00, 14.81it/s]
loss: 0.1143 ||: 100%|██████████| 750/750 [01:01<00:00, 12.28it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 16


loss: 0.2052 ||: 100%|██████████| 6745/6745 [07:32<00:00, 14.90it/s]
loss: 0.1103 ||: 100%|██████████| 750/750 [01:01<00:00, 12.21it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 17


loss: 0.2001 ||: 100%|██████████| 6745/6745 [07:18<00:00, 15.37it/s]
loss: 0.1272 ||: 100%|██████████| 750/750 [00:59<00:00, 12.68it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 18


loss: 0.1956 ||: 100%|██████████| 6745/6745 [07:24<00:00, 15.19it/s]
loss: 0.1071 ||: 100%|██████████| 750/750 [00:53<00:00, 13.91it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 19


loss: 0.1914 ||: 100%|██████████| 6745/6745 [07:24<00:00, 15.16it/s]
loss: 0.1015 ||: 100%|██████████| 750/750 [00:55<00:00, 13.59it/s]
  0%|          | 0/6745 [00:00<?, ?it/s]

Epoch: 20


loss: 0.1886 ||: 100%|██████████| 6745/6745 [07:21<00:00, 15.29it/s]
loss: 0.0977 ||: 100%|██████████| 750/750 [00:58<00:00, 12.81it/s]


In [18]:
# Persist the trained weights.
with open("Transformer.th", 'wb') as f:
    torch.save(model.state_dict(), f)

# Put the model in inference mode explicitly instead of relying on whatever
# mode training/validation left it in. This disables dropout, and
# SimpleSeq2Seq presumably only runs its beam-search decoding outside
# training mode (confirm against the installed AllenNLP version).
model.eval()

# Register pandas' progress_apply so the prediction loop shows a progress bar.
tqdm.tqdm.pandas()

predictor = SimpleSeq2SeqPredictor(model, reader)

# Release the training data and trainer; prediction only needs model, reader and predictor.
del raw_dataset, df, train_dataset, trainer

  from pandas import Panel


In [19]:
# Load the Kaggle test tokens, reading 'before' literally. With pandas'
# default NA handling, tokens such as "NA", "null" or "nan" would become NaN
# and then the string "nan", so the model would be asked to normalise the
# wrong input.
kaggle_test = pd.read_csv(
    'data/ru_test_2.csv',
    dtype={'before': str},
    keep_default_na=False,
)

# Predict each distinct token only once; predictions are joined back onto
# every test row afterwards.
small_kaggle = pd.DataFrame(kaggle_test['before'].unique(), columns=['before'])

In [20]:
def predict_normalized(token: str) -> str:
    """Return the model's normalised form of one raw input token.

    The token is passed to the model unchanged. Training inputs (the 'before'
    column) were never lowercased, so lowercasing here would feed the model
    text unlike anything it was trained on and would drop capitalisation
    that the expected output preserves. The predicted character tokens are
    concatenated back into a single string.
    """
    return ''.join(predictor.predict(token)['predicted_tokens'])


small_kaggle['after'] = small_kaggle['before'].progress_apply(predict_normalized)

100%|██████████| 175991/175991 [1:08:20<00:00, 42.92it/s]


In [21]:
# Index the predictions by raw token so each test row can look up its normalisation.
small_kaggle = small_kaggle.set_index(keys='before')

In [22]:
# Attach each row's predicted 'after' by matching 'before' against the prediction index.
kaggle_test = kaggle_test.join(small_kaggle, on='before')
# Submission id has the form "<sentence_id>_<token_id>".
kaggle_test['id'] = kaggle_test['sentence_id'].astype(str).str.cat(
    kaggle_test['token_id'].astype(str), sep='_'
)

In [23]:
# Write the submission file: the id and after columns only, with no index column.
kaggle_test.loc[:, ['id', 'after']].to_csv('sub.csv', index=False)