In [None]:
# Execute main.py in the notebook namespace; presumably it defines the config
# constants (DATA_DIR, NAVEC, SEED, ...) and helpers used below — TODO confirm
%run main.py
# Reload edited project modules automatically before each cell executes
%load_ext autoreload
%autoreload 2

# Create the local data / embedding / model directories if they are missing
!mkdir -p {DATA_DIR} {NAVEC_DIR} {MODEL_DIR}
# S3 client used by the download cell below (and the optional upload cell at the end)
s3 = S3()

In [None]:
if not exists(NERUS):
    s3.download(S3_NERUS, NERUS)
    s3.download(S3_RELS_VOCAB, RELS_VOCAB)
    
if not exists(NAVEC):
    !wget {NAVEC_URL} -O {NAVEC}

In [None]:
# Load the Navec word embeddings from the local file fetched in the download cell
navec = Navec.load(NAVEC)

In [None]:
# Word vocab reuses the Navec word list as-is (presumably so word ids match
# the rows NavecEmbedding looks up — confirm).
words_vocab = Vocab(navec.vocab.words)

# Shape vocab: the padding token comes first, followed by every known shape
shapes_vocab = Vocab([PAD, *SHAPES])

In [None]:
def build_rels_vocab():
    """Scan the Nerus corpus and collect every relation label (`token.rel`).

    Returns a Vocab whose first entry is PAD, followed by the distinct
    relation labels in sorted order. This is a full pass over the corpus.
    """
    lines = load_gz_lines(NERUS)
    lines = log_progress(lines, total=NERUS_TOTAL)
    items = parse_jl(lines)
    markups = (SyntaxMarkup.from_json(_) for _ in items)
    rels = {
        token.rel
        for markup in markups
        for token in markup.tokens
    }
    return Vocab([PAD] + sorted(rels))


# Prefer the cached vocab (downloaded from S3 above or built on an earlier
# run); otherwise rebuild it from the corpus and cache it for next time.
if exists(RELS_VOCAB):
    rels_vocab = Vocab.load(RELS_VOCAB)
else:
    rels_vocab = build_rels_vocab()
    rels_vocab.dump(RELS_VOCAB)

In [None]:
# Seed the RNGs before the model cell below creates any weights.
# `seed` is presumably random.seed brought in by main.py — confirm.
seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Input layer: Navec word vectors combined with a shape embedding
word = NavecEmbedding(navec)
shape = Embedding(
    pad_id=shapes_vocab.pad_id,
    dim=SHAPE_DIM,
    vocab_size=len(shapes_vocab),
)
emb = SyntaxEmbedding(word, shape)

# Encoder stack configured by LAYER_DIMS and KERNEL_SIZE
encoder = SyntaxEncoder(
    kernel_size=KERNEL_SIZE,
    layer_dims=LAYER_DIMS,
    input_dim=emb.dim,
)

# Both output modules read the encoder output through a hidden layer of half
# its width; the relation module's output size matches the rels vocab.
head = SyntaxHead(
    hidden_dim=encoder.dim // 2,
    input_dim=encoder.dim,
)
rel = SyntaxRel(
    rel_dim=len(rels_vocab),
    hidden_dim=encoder.dim // 2,
    input_dim=encoder.dim,
)

# Build the full model and place it on DEVICE in one expression
model = Syntax(emb, encoder, head, rel).to(DEVICE)

criterion = masked_flatten_cross_entropy

In [None]:
# Stream the Nerus corpus and parse each JSON line into a SyntaxMarkup
lines = load_gz_lines(NERUS)
lines = log_progress(lines, total=NERUS_TOTAL)
items = parse_jl(lines)
markups = (SyntaxMarkup.from_json(_) for _ in items)

encode = SyntaxTrainEncoder(
    words_vocab, shapes_vocab, rels_vocab,
    batch_size=64,
    sort_size=1000,
)
# Materialize every encoded batch on DEVICE once; the training loop then
# re-iterates these lists every epoch.
all_batches = [_.to(DEVICE) for _ in encode(markups)]

# The first `size` batches form a fixed held-out test split; the rest train.
size = 25
assert len(all_batches) > size, 'corpus too small: no batches left for training'
batches = {
    TEST: all_batches[:size],
    TRAIN: all_batches[size:]
}

In [None]:
# Combine a TensorBoard writer (under RUNS_DIR) with a LogBoard
board = MultiBoard([TensorBoard(BOARD_NAME, RUNS_DIR), LogBoard()])

# One board section per data split, created in TRAIN-then-TEST order
boards = {
    split: board.section(section_name)
    for split, section_name in [(TRAIN, TRAIN_BOARD), (TEST, TEST_BOARD)]
}

In [None]:
# Adam at the base learning rate; the LR is multiplied by LR_GAMMA on every
# scheduler.step() call, which the training loop makes once per epoch.
optimizer = optim.Adam(params=model.parameters(), lr=LR)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=LR_GAMMA)

In [None]:
# Score accumulators per split; each is written to its board section and
# reset at the end of every epoch.
meters = {
    TRAIN: SyntaxScoreMeter(),
    TEST: SyntaxScoreMeter(),
}

for epoch in log_progress(range(EPOCHS)):
    # Training pass: one optimizer step per batch
    model.train()
    for batch in log_progress(batches[TRAIN], leave=False, desc=TRAIN):
        optimizer.zero_grad()
        # process_batch returns a batch carrying the `.loss` used below
        batch = process_batch(model, criterion, batch)
        batch.loss.backward()
        optimizer.step()

        score = score_syntax_batch(batch)
        meters[TRAIN].add(score)

    meters[TRAIN].write(boards[TRAIN])
    meters[TRAIN].reset()

    # Evaluation pass over the held-out split, without gradient tracking
    model.eval()
    with torch.no_grad():
        for batch in log_progress(batches[TEST], leave=False, desc=TEST):
            batch = process_batch(model, criterion, batch)
            score = score_syntax_batch(batch)
            meters[TEST].add(score)

        meters[TEST].write(boards[TEST])
        meters[TEST].reset()

    # Decay the learning rate and advance the board step once per epoch
    scheduler.step()
    board.step()

In [None]:
# Log of a previous training run (timestamp, epoch, value, board section/metric):
# [2020-04-23 17:03:43]    0 0.4045 01_train/01_loss
# [2020-04-23 17:03:43]    0 0.8770 01_train/02_uas
# [2020-04-23 17:03:43]    0 0.8595 01_train/03_las
# [2020-04-23 17:03:44]    0 0.2512 02_test/01_loss
# [2020-04-23 17:03:44]    0 0.9231 02_test/02_uas
# [2020-04-23 17:03:44]    0 0.9103 02_test/03_las
# [2020-04-23 17:39:45]    1 0.3287 01_train/01_loss
# [2020-04-23 17:39:45]    1 0.8975 01_train/02_uas
# [2020-04-23 17:39:45]    1 0.8827 01_train/03_las
# [2020-04-23 17:39:45]    1 0.2286 02_test/01_loss
# [2020-04-23 17:39:45]    1 0.9289 02_test/02_uas
# [2020-04-23 17:39:45]    1 0.9172 02_test/03_las
# [2020-04-23 18:15:48]    2 0.3106 01_train/01_loss
# [2020-04-23 18:15:48]    2 0.9025 01_train/02_uas
# [2020-04-23 18:15:48]    2 0.8883 01_train/03_las
# [2020-04-23 18:15:48]    2 0.2158 02_test/01_loss
# [2020-04-23 18:15:48]    2 0.9316 02_test/02_uas
# [2020-04-23 18:15:48]    2 0.9208 02_test/03_las

In [None]:
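# Uncomment to dump the trained model components and the rels vocab locally,
# then upload them to S3.
# model.emb.shape.dump(MODEL_SHAPE)
# model.encoder.dump(MODEL_ENCODER)
# model.head.dump(MODEL_HEAD)
# model.rel.dump(MODEL_REL)
# rels_vocab.dump(RELS_VOCAB)
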
# s3.upload(MODEL_SHAPE, S3_MODEL_SHAPE)
# s3.upload(MODEL_ENCODER, S3_MODEL_ENCODER)
# s3.upload(MODEL_HEAD, S3_MODEL_HEAD)
# s3.upload(MODEL_REL, S3_MODEL_REL)
# s3.upload(RELS_VOCAB, S3_RELS_VOCAB)