In [9]:
import ast

In [23]:
ast.dump(ast.parse('if x >= 0:\n    x = 0'))

"Module(body=[If(test=Compare(left=Name(id='x', ctx=Load()), ops=[GtE()], comparators=[Num(n=0)]), body=[Assign(targets=[Name(id='x', ctx=Store())], value=Num(n=0))], orelse=[])])"

In [19]:
dir(ast.parse('test_var = test_list[0]').body[0])

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_attributes',
 '_fields',
 'col_offset',
 'lineno',
 'targets',
 'value']

## Импорты

In [2]:
%matplotlib inline
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import os
from tqdm import tqdm_notebook, trange, tqdm
import pickle
import numpy as np
from IPython.display import clear_output

from collections import defaultdict

In [3]:
base_path = '/Users/dentarasov/Yandex.Disk.localized/current/vkr'
data_path = base_path + '/after_preprocess/all_inputs'

with open(data_path, 'rb') as f:
    data = pickle.load(f)


In [4]:
data.shape

(695687, 100)

In [5]:
split_num = 600000
train, test = data[:split_num], data[split_num:]

In [6]:
N = 3

In [7]:
ngrams_dict = defaultdict(int)

In [8]:
for seq in tqdm_notebook(train):
    for i in range(len(seq) - N):
        ngrams_dict[tuple(seq[i:i+N])] += 1

HBox(children=(IntProgress(value=0, max=600000), HTML(value='')))




## Что стоит сделать
* Убрать `<PAD>`

In [34]:
def filter_most_popular_ngrams(ngrams_dict):
    # ngram_list: [[(1, 231, 12), 5], ...]
    ngrams_list = list(ngrams_dict.items())
    def sort_fun(ngram):
        return ngram[0][:2], ngram[1]
    ngrams_list = sorted(ngrams_list, key=sort_fun, reverse=True)
    prev = None
    new_ngrams_dict = {}
    for ngram in tqdm(ngrams_list):
        cur = ngram[0][:2]
        if prev != cur:
            new_ngrams_dict[tuple(cur)] = ngram[0][2]
        prev = cur
    return new_ngrams_dict

In [35]:
new_ngrams_dict = filter_most_popular_ngrams(ngrams_dict)

100%|██████████| 6566758/6566758 [00:06<00:00, 990502.79it/s] 


In [36]:
with open('./ngrams_dict.pickle', 'wb') as f:
    pickle.dump(new_ngrams_dict, f)

In [38]:
list(new_ngrams_dict.items())[:10]

[((79583, 14), 3),
 ((79583, 9), 2),
 ((79583, 5), 14),
 ((79583, 3), 2),
 ((79583, 2), 2),
 ((79581, 4), 1),
 ((79581, 3), 42671),
 ((79581, 2), 15),
 ((79580, 14), 7),
 ((79580, 9), 13)]

In [44]:
def calc_quality(ngrams_dict):
    tp = 0
    count = 0
    for seq in tqdm_notebook(test):
        for i in range(len(seq)-N):
            # to avoid considering empty tokens
            if seq[i+1] != 0 and seq[i+2] != 0:
                if (seq[i], seq[i+1]) in ngrams_dict \
                        and seq[i+2] == ngrams_dict[(seq[i], seq[i+1])]:
                    tp += 1
                count += 1
    return tp / count
        

In [45]:
calc_quality(new_ngrams_dict)

HBox(children=(IntProgress(value=0, max=95687), HTML(value='')))




0.43990673200777963

In [43]:
test.shape[0] * (test.shape[1] - N)

9281639

# Предобработка

In [28]:
def build_vocab():
    # load pre-computed vocab
    with open(base_path + '/mapping.map', 'rb') as f:
        word_to_id = pickle.load(f)
    id_to_word = dict([(v, k) for (k, v) in word_to_id.items()])
    return word_to_id, id_to_word

word_to_id, id_to_word = build_vocab()


In [98]:
base_path = '/Users/dentarasov/Yandex.Disk.localized/current/vkr'
code_batched = base_path + '/after_preprocess/after_preprocess.part0'
params = {
    'batch_size': 64,
    'emb_size': 150,
    'vocab_size': len(word_to_id),
}


In [4]:
def get_code_data(path):
    for _, _, files in os.walk(path):
        data = np.empty((0, 100))
        for file in tqdm_notebook(files):
            with open(path + file, 'rb') as f:
                array = pickle.load(f)
                for i in tqdm_notebook(range(len(array)), leave=False):
                    data = np.concatenate((data, array[i].inputs), axis=0)
    return data


In [5]:
data = get_code_data(base_path + '/after_preprocess/')

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, max=135), HTML(value='')))

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

HBox(children=(IntProgress(value=0, max=129), HTML(value='')))

HBox(children=(IntProgress(value=0, max=92), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, max=79), HTML(value='')))

HBox(children=(IntProgress(value=0, max=332), HTML(value='')))

HBox(children=(IntProgress(value=0, max=268), HTML(value='')))

HBox(children=(IntProgress(value=0, max=248), HTML(value='')))

HBox(children=(IntProgress(value=0, max=84), HTML(value='')))

HBox(children=(IntProgress(value=0, max=9), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

HBox(children=(IntProgress(value=0, max=87), HTML(value='')))

HBox(children=(IntProgress(value=0, max=124), HTML(value='')))

HBox(children=(IntProgress(value=0, max=141), HTML(value='')))

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

HBox(children=(IntProgress(value=0, max=238), HTML(value='')))

HBox(children=(IntProgress(value=0, max=259), HTML(value='')))

HBox(children=(IntProgress(value=0, max=348), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1595), HTML(value='')))

HBox(children=(IntProgress(value=0, max=6956), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1255), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1392), HTML(value='')))

HBox(children=(IntProgress(value=0, max=6956), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1739), HTML(value='')))

HBox(children=(IntProgress(value=0, max=983), HTML(value='')))

HBox(children=(IntProgress(value=0, max=773), HTML(value='')))

HBox(children=(IntProgress(value=0, max=482), HTML(value='')))

HBox(children=(IntProgress(value=0, max=464), HTML(value='')))

HBox(children=(IntProgress(value=0, max=63), HTML(value='')))

HBox(children=(IntProgress(value=0, max=24), HTML(value='')))

HBox(children=(IntProgress(value=0, max=167), HTML(value='')))

HBox(children=(IntProgress(value=0, max=27), HTML(value='')))

HBox(children=(IntProgress(value=0, max=43), HTML(value='')))

HBox(children=(IntProgress(value=0, max=221), HTML(value='')))

HBox(children=(IntProgress(value=0, max=38), HTML(value='')))

HBox(children=(IntProgress(value=0, max=182), HTML(value='')))

HBox(children=(IntProgress(value=0, max=435), HTML(value='')))

HBox(children=(IntProgress(value=0, max=502), HTML(value='')))

HBox(children=(IntProgress(value=0, max=828), HTML(value='')))

HBox(children=(IntProgress(value=0, max=870), HTML(value='')))

HBox(children=(IntProgress(value=0, max=41), HTML(value='')))

HBox(children=(IntProgress(value=0, max=175), HTML(value='')))

HBox(children=(IntProgress(value=0, max=229), HTML(value='')))

HBox(children=(IntProgress(value=0, max=160), HTML(value='')))

HBox(children=(IntProgress(value=0, max=45), HTML(value='')))

HBox(children=(IntProgress(value=0, max=30), HTML(value='')))

HBox(children=(IntProgress(value=0, max=22), HTML(value='')))

HBox(children=(IntProgress(value=0, max=60), HTML(value='')))

HBox(children=(IntProgress(value=0, max=279), HTML(value='')))

HBox(children=(IntProgress(value=0, max=318), HTML(value='')))

HBox(children=(IntProgress(value=0, max=114), HTML(value='')))

HBox(children=(IntProgress(value=0, max=69), HTML(value='')))

HBox(children=(IntProgress(value=0, max=153), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=75), HTML(value='')))

HBox(children=(IntProgress(value=0, max=104), HTML(value='')))

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

HBox(children=(IntProgress(value=0, max=303), HTML(value='')))

HBox(children=(IntProgress(value=0, max=291), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, max=73), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=108), HTML(value='')))

HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

HBox(children=(IntProgress(value=0, max=119), HTML(value='')))

HBox(children=(IntProgress(value=0, max=66), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3322), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2319), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2168), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3675), HTML(value='')))

HBox(children=(IntProgress(value=0, max=36), HTML(value='')))

HBox(children=(IntProgress(value=0, max=51), HTML(value='')))

HBox(children=(IntProgress(value=0, max=54), HTML(value='')))

HBox(children=(IntProgress(value=0, max=17), HTML(value='')))

HBox(children=(IntProgress(value=0, max=189), HTML(value='')))

HBox(children=(IntProgress(value=0, max=12), HTML(value='')))

HBox(children=(IntProgress(value=0, max=213), HTML(value='')))

HBox(children=(IntProgress(value=0, max=715), HTML(value='')))

HBox(children=(IntProgress(value=0, max=363), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1027), HTML(value='')))

HBox(children=(IntProgress(value=0, max=397), HTML(value='')))

HBox(children=(IntProgress(value=0, max=633), HTML(value='')))

HBox(children=(IntProgress(value=0, max=568), HTML(value='')))

HBox(children=(IntProgress(value=0, max=15), HTML(value='')))

HBox(children=(IntProgress(value=0, max=204), HTML(value='')))

HBox(children=(IntProgress(value=0, max=196), HTML(value='')))

HBox(children=(IntProgress(value=0, max=19), HTML(value='')))

HBox(children=(IntProgress(value=0, max=57), HTML(value='')))

HBox(children=(IntProgress(value=0, max=48), HTML(value='')))

HBox(children=(IntProgress(value=0, max=33), HTML(value='')))

HBox(children=(IntProgress(value=0, max=536), HTML(value='')))

HBox(children=(IntProgress(value=0, max=414), HTML(value='')))

HBox(children=(IntProgress(value=0, max=598), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1160), HTML(value='')))

HBox(children=(IntProgress(value=0, max=688), HTML(value='')))

HBox(children=(IntProgress(value=0, max=381), HTML(value='')))




In [7]:
data.shape

(695687, 100)

In [33]:
data.astype('int64')

array([[   22,    46,     2, ..., 14539,     6,     1],
       [    7,     8,     2, ...,    11,    10,     6],
       [14539,     6,   804, ...,    10,     5,     9],
       ...,
       [    6, 16827,     5, ...,     9,     2,    11],
       [ 3703,     6,   167, ...,  5240,     3,  2605],
       [    5,     2,  1001, ...,     0,     0,     0]])

In [34]:
# with open('./after_preprocess/all_inputs', 'wb') as f:
#     pickle.dump(data.astype('int64'), f)

In [53]:
data_path = base_path + '/after_preprocess/all_inputs'

In [113]:
class CodeDataset(Dataset):
    def __init__(self, path, params):
        with open(path, 'rb') as f:
            self.data = pickle.load(f)
#         self.batch_size = params['batch_size']

    def __len__(self):
#         # return number of batches
#         return int(np.ceil(self.data.shape[0] // self.batch_size))
        return self.data.shape[0]
    
    def __getitem__(self, index):
#         return self.data[index * self.batch_size : (index+1) * self.batch_size]
        return self.data[index]


In [114]:
train_loader = DataLoader(
    CodeDataset(data_path, params),
    batch_size=params['batch_size'],
    shuffle=True,
    num_workers=10
)

In [115]:
data[0:64].shape

(64, 100)

In [116]:
for i in train_loader:
    print(type(i))
    print(i.shape)
    break

<class 'torch.Tensor'>
torch.Size([64, 100])


# Инициализация модели

In [121]:
class LstmBaseline(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.params = params
        self.embedding = nn.Embedding(
            self.params['vocab_size'],
            self.params['emb_size']
        )
        self.lstm = nn.LSTM(
            self.params['emb_size'],
            self.params['emb_size'],
            2
        )
    
    def init_hidden(self, batch_size):
        pass

    def forward(self, inputs, hidden):
        embs = self.embedding(inputs)
        print(hidden.shape)
        print(embs.shape)
        output = self.lstm(embs, hidden)
        return output


# Обучение

In [122]:
def train_epoch(model, optimizer, lr):
    loss_log = []
    model.train()
    
    for batch_num, x in zip(trange(len(train_loader)), train_loader):
        optimizer.zero_grad()
        print(x.shape)
#         x, y = x.to(device), y.to(device)
        params['batch_size'] = x.shape[0]
#         p = torch.Tensor(generate_mask(params)).to(device)
        hidden = torch.zeros(2, params['batch_size'], params["emb_size"]) #, device=device)
        output = model.forward(x, hidden)
        loss_value = 0
        loss = nn.CrossEntropyLoss()
        for i in range(x.shape[1]):
            loss_value += loss(output.float()[:, i], y[:, i])
        loss_value.backward()
        torch.nn.utils.clip_grad_norm(model.parameters(), 0.5)
        for p in model.parameters():
            p.data.add_(-lr, p.grad.data)
        optimizer.step()
        
        loss_value = loss_value.item()
        loss_log.append(loss_value / x.shape[1])
    perp = np.exp(np.mean(loss_log))
    return loss_log, perp

def test(model, test_batches):
    loss_log = []
    model.eval()
    for batch_num, x in zip(trange(len(train_loader)), val_loader):        
        hidden = model.init_hidden(batch.shape[0])
        loss = 0
        output = model.forward(x, y)
        loss = F.cross_entropy(output.float(), y.float())
        loss = loss.item()
        loss_log.append(loss / nums.shape[0])
    return loss_log

def plot_history(train_history, title='loss'):
    plt.figure()
    plt.title('{}'.format(title))
    plt.plot(train_history, label='train', zorder=1)
    plt.xlabel('train steps')
    plt.legend(loc='best')
    plt.grid()
    plt.show()
    
def train(model, opt, n_epochs):
    train_log = []
    val_log = []
    perp_log = []
    lr = 0.05
    lr_decay_base = 1 / 1.15
    m_flat_lr = 20.0
    for epoch in range(n_epochs):
        lr_decay = lr_decay_base ** max(epoch - m_flat_lr, 0)
        lr = lr * lr_decay
        train_loss, perp = train_epoch(model, opt, lr)
        train_log.extend(train_loss)
        perp_log.append(perp)
        if (epoch + 1) % 10 == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': opt.state_dict(),
                    'loss': train_log[-1]
                },
                './model_checkpoints/lstm_baseline_checkpoint_{}.pt'.format(epoch)
            )
        clear_output()
        print("Epoch:{}".format(epoch))
        plot_history(train_log)
        plot_history(perp_log)
    torch.save(
        {
            'epoch': n_epochs,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'loss': train_log[-1]
        },
        './model_checkpoints/lstm_baseline_checkpoint_{}.pt'.format(n_epochs)
    )
#     np.save("/home/.../model_checkpoints..._logs.npy", np.array(train_log))
#     np.save("/home/.../model_checkpoints..._logs_perp.npy", np.array(perp_log))
    

In [123]:
model = LstmBaseline(params)

In [124]:
%%time

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
train(model, optimizer, 10)

  0%|          | 0/10871 [00:00<?, ?it/s]

torch.Size([64, 100])
torch.Size([2, 64, 150])
torch.Size([64, 100, 150])





RuntimeError: Expected hidden[0] size (2, 100, 150), got (64, 150)