In [None]:
%run main.py
%load_ext autoreload
%autoreload 2

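# Data

Build the news corpus (one-time, commented out) and fetch the resulting `data/train.txt` / `data/test.txt` splits.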

In [None]:
# One-time download of the raw news corpora (Taiga Fontanka, ODS Gazeta/Interfax,
# Lenta.ru, Buriy news 2014-2015) into data/raw. Disabled: the prepared splits
# are fetched directly further below.
# !mkdir -p data/raw
# !wget https://storage.yandexcloud.net/natasha-corus/taiga/Fontanka.tar.gz -P data/raw
# !wget https://storage.yandexcloud.net/natasha-corus/ods/gazeta_v1.csv.zip -P data/raw
# !wget https://storage.yandexcloud.net/natasha-corus/ods/interfax_v1.csv.zip -P data/raw
# !wget https://storage.yandexcloud.net/natasha-corus/lenta-ru-news.csv.gz -P data/raw
# !wget https://storage.yandexcloud.net/natasha-corus/buriy/news-articles-2014.tar.bz2 -P data/raw
# !wget https://storage.yandexcloud.net/natasha-corus/buriy/news-articles-2015-part1.tar.bz2 -P data/raw
# !wget https://storage.yandexcloud.net/natasha-corus/buriy/news-articles-2015-part2.tar.bz2 -P data/raw

In [None]:
# One-time corpus build: parse every raw archive with its loader and collapse
# all whitespace, so each record becomes exactly one line. Disabled because it
# needs ~15Gb RAM.
# Iterate archives in sorted order: listdir order is filesystem-dependent, and
# the seeded shuffle below is only reproducible if `lines` is built in a fixed
# order. The regex is a raw string because '\s' is an invalid escape in a plain
# string literal.
# LOADS = {
#     'gazeta_v1.csv.zip': load_ods_gazeta,
#     'interfax_v1.csv.zip': load_ods_interfax,
#     'Fontanka.tar.gz': load_taiga_fontanka,
#     'lenta-ru-news.csv.gz': load_lenta,
#     'news-articles-2015-part1.tar.bz2': load_buriy_news,
#     'news-articles-2015-part2.tar.bz2': load_buriy_news,
#     'news-articles-2014.tar.bz2': load_buriy_news,
# }


# lines = []  # Requires 15Gb RAM
# for name in sorted(listdir('data/raw')):
#     path = 'data/raw/' + name
#     records = LOADS[name](path)
#     for record in log_progress(records, desc=name):
#         line = re.sub(r'\s+', ' ', record.text)
#         lines.append(line)

In [None]:
# Shuffle with a fixed seed so the test/train split below is reproducible.
# seed(1)
# shuffle(lines)

In [None]:
# Hold out the first `cap` shuffled lines as the test set; the rest is train.
# cap = 1000
# dump_lines(lines[:cap], 'data/test.txt')
# dump_lines(log_progress(lines[cap:]), 'data/train.txt')

In [None]:
# Publish the prepared splits so later runs can `download` them (next cell)
# instead of rebuilding the corpus.
# upload('data/test.txt')
# upload('data/train.txt')

In [None]:
# Fetch the prepared train/test splits. Each file is checked on its own, so an
# interrupted earlier run (e.g. test.txt present but train.txt missing) is
# repaired instead of silently skipped.
for path in ['data/test.txt', 'data/train.txt']:
    if not exists(path):
        download(path)

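# Model

Load the pretrained RuBERT weights and continue masked-language-model training on the prepared splits.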

In [None]:
# Fetch the pretrained RuBERT vocab and weights. Each file is checked on its
# own, so a partially completed earlier download does not leave weights missing.
for name in ['vocab.txt', 'emb.pt', 'encoder.pt', 'mlm.pt']:
    path = 'rubert/' + name
    if not exists(path):
        download(path)

In [None]:
# Device that the model and all batches are moved to (get_device is presumably
# provided by main.py via %run at the top).
device = get_device()

In [None]:
# Token list of the pretrained RuBERT vocabulary (presumably one token per
# line of vocab.txt), wrapped in the BERT vocab helper.
vocab_path = 'rubert/vocab.txt'
items = list(load_lines(vocab_path))
vocab = BERTVocab(items)

In [None]:
# BERT-base sized config; it has to match the checkpoint that load_model
# reads from rubert/ below.
config = BERTConfig(
    vocab_size=50106,
    seq_len=512,
    emb_dim=768,
    layers_num=12,
    heads_num=12,
    hidden_dim=3072,
    dropout=0.1,
    norm_eps=1e-12
)
emb = BERTEmbedding(
    config.vocab_size, config.seq_len, config.emb_dim,
    config.dropout, config.norm_eps
)
# Fix pos emb to train on short seqs (training below uses seq_len=128).
# requires_grad_(False) freezes correctly whether `position` is a Parameter or
# an nn.Module such as an Embedding. Assigning `.requires_grad = False` on a
# Module only sets a plain Python attribute and leaves its weights trainable.
emb.position.requires_grad_(False)
encoder = BERTEncoder(
    config.layers_num, config.emb_dim, config.heads_num, config.hidden_dim,
    config.dropout, config.norm_eps
)
mlm = BERTMLMHead(config.emb_dim, config.vocab_size)
model = BERTMLM(emb, encoder, mlm)

# Load pretrained weights, then move to the device (before amp.initialize,
# which runs in a later cell).
load_model(model, 'rubert')
model = model.to(device)

criterion = flatten_cross_entropy

In [None]:
# Seed torch first, then `seed` (presumably random.seed from main.py), so the
# rest of the run is reproducible.
SEED = 1
for seed_fn in (torch.manual_seed, seed):
    seed_fn(SEED)

In [None]:
# MLM batch encoder: 128-token sequences, 32 per batch (shuffle_size is
# presumably the size of a shuffle buffer).
encode = BERTMLMEncoder(
    vocab,
    seq_len=128,
    batch_size=32,
    shuffle_size=10000
)

# Test set: materialized on the device once, because every eval re-scores it.
lines = load_lines('data/test.txt')
batches = encode(lines)
test_batches = list(batch.to(device) for batch in batches)

# Train set: a lazy generator that moves one batch to the device at a time.
lines = load_lines('data/train.txt')
batches = encode(lines)
train_batches = (batch.to(device) for batch in batches)

In [None]:
# Logging board for run '01' (second argument presumably the log directory),
# with separate sections for train and test metrics.
board = Board('01', 'runs')
train_board, test_board = [
    board.section(section_name)
    for section_name in ('01_train', '02_test')
]

In [None]:
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# Wrap model and optimizer for apex mixed precision; the scheduler must be
# built on the optimizer that amp.initialize returns.
model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
# ExponentialLR multiplies the lr by gamma on every step; gamma=1 keeps it
# constant (presumably a placeholder for a real decay schedule).
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=1)

In [None]:
train_meter = MLMScoreMeter()
test_meter = MLMScoreMeter()

accum_steps = 64  # 64 micro-batches x batch_size 32 = 2K sequences per update
log_steps = 256
eval_steps = 512
save_steps = eval_steps * 10

model.train()
optimizer.zero_grad()

for step, batch in log_progress(enumerate(train_batches)):
    batch = process_batch(model, criterion, batch)

    # Score on the *unscaled* loss so the train curve is on the same scale as
    # the test curve (the eval path below never divides the loss).
    score = score_batch(batch, ks=())
    train_meter.add(score)

    # Gradient accumulation: scale each micro-batch loss by 1/accum_steps so
    # the gradients summed over the window match the mean loss of the whole
    # accumulated batch. Kept separate from batch.loss to leave metrics intact.
    loss = batch.loss / accum_steps
    with amp.scale_loss(loss, optimizer) as scaled:
        scaled.backward()

    if every(step, log_steps):
        train_meter.write(train_board)
        train_meter.reset()

    if every(step, accum_steps):
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        if every(step, eval_steps):
            batches = infer_batches(model, criterion, test_batches)
            scores = score_batches(batches)
            test_meter.extend(scores)
            test_meter.write(test_board)
            test_meter.reset()

    # save_steps is a multiple of accum_steps, so checkpoints are taken right
    # after an optimizer step.
    if every(step, save_steps):
        dump_model(model, 'model')
        for name in ['emb.pt', 'encoder.pt', 'mlm.pt']:
            upload('model/' + name)

    board.step()