In [1]:
!cp drive/My\ Drive/query_popularity.csv.gz query_popularity.csv.gz

In [2]:
!gzip -d query_popularity.csv.gz

In [3]:
import pandas as pd
import numpy as np

In [4]:
data = pd.read_csv('query_popularity.csv')

In [5]:
# отбираем только самые частые запросы для избегания опечаток
train_data = data[data['query_popularity'] > 6]['query'].dropna()

In [6]:
train_data = train_data.apply(lambda x: x.lower()).apply(list)

In [7]:
# убираем данные со слишком большой и маленькой длиной
train_data = train_data[(train_data.apply(len) >= 3) & (train_data.apply(len) < 100)]

In [8]:
max_seq_len = 100

In [9]:
# составляем словари токенов
all_tokens = '1234567890абвгдеёжзийклмнопрстуфхцчшщъыьэяюabcdefghijkolmnpqrstuuvwxyz '
token_dict = {'<pad>' : 0, '<unk>': 1, '<eos>': 2, '<sos>': 3}
idx2token = {0: '<pad>', 1: '<unk>', 2: '<eos>', 3: '<sos>'}
for i, token in enumerate(all_tokens):
  token_dict[token] = i+4
  idx2token[i+4] = token

In [10]:
def process(x):
  return np.array([3]+[token_dict[i] if i in token_dict else 0 for i in x] + [2])

In [11]:
train_data = train_data.apply(process)

In [12]:
train_data0 = np.zeros((train_data.shape[0], 101), dtype=int)

In [13]:
train_data = np.array(train_data)

In [14]:
for i, seq in enumerate(train_data):
  train_data0[i][:len(seq)] += train_data[i]
train_data = train_data0

In [15]:
import torch
train_data = torch.tensor(train_data)

In [16]:
dataset = torch.utils.data.TensorDataset(train_data)

In [17]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

In [18]:
from torch import nn
class LSTM(nn.Module):
    def __init__(self, vocab_size, emb_size, hidden_size):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_size)
        self.lstm = nn.LSTM(emb_size, hidden_size, batch_first=True)
        self.pred = nn.Linear(hidden_size, vocab_size)
        
    def forward(self, input_ids):
        embs = self.emb(input_ids)
        output, _ = self.lstm(embs)
        return self.pred(output)

In [19]:
import torch
device = torch.device('cuda:0' if torch.cuda.is_available else 'cpu')

In [20]:
model = LSTM(len(token_dict), 128, 128).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [21]:
criterion = nn.CrossEntropyLoss(ignore_index=0)

In [22]:
epochs = 35
for epoch in range(epochs):
  model.train()
  for [batch] in train_dataloader:
    optimizer.zero_grad()
    X = batch.to(device)
    predictions = model(X[:, :-1])
    loss = criterion(
    predictions.reshape(-1, predictions.size(-1)),
    X[:, 1:].reshape(-1))
    loss.backward()
    optimizer.step()
  model.eval()
  mean_loss = 0
  num_iter=0
  for [batch] in train_dataloader:
    X = batch.to(device)
    with torch.no_grad():
      predictions = model(X[:, :-1])
      loss = criterion(
          predictions.reshape(-1, predictions.size(-1)),
          X[:, 1:].reshape(-1)
      )
      mean_loss += loss.item()
      num_iter += 1
  mean_loss /= num_iter
  print(f"Epoch: {epoch}; mean loss: {mean_loss}; perplexity: {np.exp(mean_loss)}")            

Epoch: 0; mean loss: 1.6276372534580557; perplexity: 5.091829791353483
Epoch: 1; mean loss: 1.5029859511684556; perplexity: 4.49509117411035
Epoch: 2; mean loss: 1.4426678734700489; perplexity: 4.2319711332627135
Epoch: 3; mean loss: 1.4074723210624869; perplexity: 4.085615217708377
Epoch: 4; mean loss: 1.3819433818978502; perplexity: 3.982633890299581
Epoch: 5; mean loss: 1.3658097755206402; perplexity: 3.9188951928535
Epoch: 6; mean loss: 1.3507953709975848; perplexity: 3.8604948355456563
Epoch: 7; mean loss: 1.3408454711559608; perplexity: 3.8222737618454006
Epoch: 8; mean loss: 1.3297981005657762; perplexity: 3.780280074106926
Epoch: 9; mean loss: 1.3215261548190969; perplexity: 3.7491387795464814
Epoch: 10; mean loss: 1.312867822500916; perplexity: 3.71681761569235
Epoch: 11; mean loss: 1.307631321902964; perplexity: 3.697405368505781
Epoch: 12; mean loss: 1.302411126489875; perplexity: 3.6781544803759774
Epoch: 13; mean loss: 1.2979079632818018; perplexity: 3.6616283882184186
Epo

In [25]:
# Функция для дополнения слова
def predict(x, one_word = True):
  x0 = x
  x = [3] + [token_dict[i] for i in x]
  ans = x + [0]*100
  for i in range(len(x), 100):
    pred = model(torch.tensor(ans, dtype=int).to(device).unsqueeze(0)).argmax(2).squeeze(0)
    ans[i] += int(pred[i-1])
  ans = [idx2token[i] for i in ans]
  if '<eos>' in ans:
    eos = ans.index('<eos>')
    ans = ans[:eos]
  if one_word:
    init_len = len(x0.split())
    ans = ''.join(ans)
    ans = ans.split()
    ans = ans[:init_len]
    ans = ' '.join(ans)
  else:
    ans = ''.join(ans)
  return ans

In [26]:
predict('кур', False)

'<sos>куртка женская с капюшоном'

In [43]:
# Функция для генерации нескольких возможных слов (в финальном решении не использовалась)
def predict(x, one_word = True):
  x0 = x
  x = [3] + [token_dict[i] for i in x]
  ans = np.zeros((10, 100), dtype=int)
  ans [:, :len(x)] += x
  pred0 = model(torch.tensor(ans[0], dtype=int).to(device).unsqueeze(0)).squeeze(0)[len(x)-1, :]
  pred0 = pred0.argsort(descending=True)[:10]
  ans[:, len(x)] += pred0.detach().cpu().numpy()
  for i in range(len(x)+1, 100):
    pred = model(torch.tensor(ans, dtype=int).to(device)).argmax(2)
    ans[:, i] += pred[:, i-1].detach().cpu().numpy()
  new_ans = np.zeros(ans.shape, dtype=np.object)
  for i in range(len(ans)):
    for j in range(len(ans[i])):
      new_ans[i, j] = idx2token[ans[i, j]]
  new_ans = new_ans[:, 1:]
  new_ans = new_ans.tolist()
  for i in range(len(new_ans)):
    if '<eos>' in new_ans[i]:
      eos = new_ans[i].index('<eos>')
      ans = new_ans[i][:eos]
    if one_word:
      init_len = len(x0.split())
      new_ans[i] = ''.join(new_ans[i])
      new_ans[i] = new_ans[i].split()
      new_ans[i] = new_ans[i][:init_len]
      new_ans[i] = ' '.join(new_ans[i])
    else:
      new_ans[i] = ''.join(new_ans[i])
  return new_ans

In [45]:
predict('кур')

['куртка',
 'куральная',
 'куро',
 'куринатор',
 'курка',
 'курчка',
 'курбория',
 'курх',
 'курница',
 'курм']

In [46]:
predict('iph')

['iphone',
 'iphane',
 'iphin',
 'iphen',
 'iphland',
 'iphda',
 'iphune',
 'iphynex',
 'iph',
 'iph400<eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos><eos>']

In [48]:
predict('но')

['носки',
 'нож',
 'ночная',
 'новогодние',
 'норки',
 'ногтей',
 'номера',
 'ноутбук',
 'нотушка',
 'ноевочная']

In [50]:
torch.save(model.state_dict(), 'wildhack_lstm1.pth')