In [None]:
from seligator.common.params import MetadataEncoding, Seq2VecEncoderType, BasisVectorConfiguration
from seligator.simple_demo import prepare_model, train_and_get
from seligator.tests.evaluate import *
from seligator.models.siamese import SiameseClassifier

# Experiment: lemma + lemma-character + morpho-syntactic features encoded by a
# MetadataLSTM; the same metadata categories are fed both as categorical reader
# tokens and as basis-vector categories.
METADATA_CATS = ("Century", "Textgroup", "WrittenType", "CitationTypes")
BVC = BasisVectorConfiguration(categories=METADATA_CATS)

# 1. Build the model, the dataset reader and the train/dev splits.
model, reader, train, dev = prepare_model(
    input_features=(
        "lemma_char", "lemma",
        # Morphological features; with agglomerate_msd=True the reader log shows
        # them exposed as a single "[msd]" category.
        "case", "numb", "gend", "mood", "tense", "voice", "person", "deg",
    ),
    seq2vec_encoder_type=Seq2VecEncoderType.MetadataLSTM,
    basis_vector_configuration=BVC,
    agglomerate_msd=True,
    reader_kwargs=dict(
        batch_size=4,
        metadata_encoding=MetadataEncoding.AS_CATEGORICAL,
        metadata_tokens_categories=METADATA_CATS,
    ),
    model_embedding_kwargs={
        "keep_all_vocab": True,
        # Only lemma vectors are loaded. Token embedding files tried previously:
        #   ~/Downloads/latin.embeddings
        #   ~/dev/these/notebooks/4 - Detection/data/embs_models/model.token.word2vec.kv
        "pretrained_embeddings": dict(
            lemma="~/dev/these/notebooks/4 - Detection/data/embs_models/model.lemma.word2vec.kv.header",
        ),
        # Pretrained vectors are kept frozen during training.
        "trainable_embeddings": dict(token=False, lemma=False),
        "pretrained_emb_dims": dict(token=200, lemma=200),
    },
    # Other options tried: batches_per_epoch=100, model_class=SiameseClassifier
    use_bert_higway=True,
    additional_model_kwargs=dict(metadata_linear=False),
)

# 2. Train with AdamW (lr 5e-4) for at most 20 epochs; patience=2 is presumably
#    the early-stopping patience echoed in the training log — confirm.
# Other options tried: optimizer_params=dict(rho=0.9, eps=1e-6), use_cpu=True
model = train_and_get(
    model, train, dev,
    patience=2, num_epochs=20,
    lr=5e-4, optimizer="AdamW",
)
print(model)

# 3. Evaluate on the held-out test split; `dump` names the CSV file run_tests
#    writes its results to.
data = run_tests(
    "dataset/split/test.txt",
    dataset_reader=reader,
    model=model,
    dump="test.lemma-msd-metadatacat-basis.csv",
)
print(data)

2021-07-28 17:53:55.689740: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
INFO:root:Dataset reader set with following categories: [msd], lemma_char, lemma
INFO:root:Indexer set for following categories: [msd], lemma_char, lemma
INFO:root:TSV READER uses following metadata encoding MetadataEncoding.AS_CATEGORICAL 
INFO:root:Reading data
INFO:root:Building the vocabulary
INFO:allennlp.data.vocabulary:Fitting token dictionary from dataset.


building vocab:   0%|          | 0/5427 [00:00<?, ?it/s]

INFO:root:Fitting the BasisVectorConfiguration
INFO:allennlp.modules.token_embedders.embedding:Reading pretrained embeddings from file
INFO:allennlp.modules.token_embedders.embedding:Recognized a header line in the embedding file with number of tokens: 155131


  0%|          | 0/155131 [00:00<?, ?it/s]

INFO:allennlp.modules.token_embedders.embedding:Initializing pre-trained embedding layer
INFO:allennlp.modules.token_embedders.embedding:Pretrained embeddings were found for 155131 out of 155381 tokens
INFO:allennlp.training.optimizers:Number of trainable parameters: 1787058
INFO:root:Current Optimizer: AdamWOptimizer (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.0005
    weight_decay: 0.01
) 
INFO:root:Num epochs: 20
INFO:root:Starting training
INFO:allennlp.training.gradient_descent_trainer:Beginning training.
INFO:allennlp.training.gradient_descent_trainer:Epoch 0/19
INFO:allennlp.training.gradient_descent_trainer:Worker 0 memory usage: 2.9G
INFO:allennlp.training.gradient_descent_trainer:GPU 0 memory usage: 125M
INFO:allennlp.training.gradient_descent_trainer:Training


---> Epochs:   20
---> Patience: 2


  0%|          | 0/1357 [00:00<?, ?it/s]

  probs = F.softmax(logits)
INFO:allennlp.training.callbacks.console_logger:Batch inputs
INFO:allennlp.training.callbacks.console_logger:batch_input/[msd] (Shape: 4 x 33 x 52)
tensor([[[0, 0, 1,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 1,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 0, 0],
         ...,
         [0, 0, 1,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 0, 0]],

        [[0, 0, 1,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 1,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0,

torch.Size([4, 33, 128])
torch.Size([4, 15, 128])
torch.Size([4, 17, 128])
torch.Size([4, 21, 128])
torch.Size([4, 20, 128])
torch.Size([4, 30, 128])
torch.Size([4, 29, 128])
torch.Size([4, 34, 128])
torch.Size([4, 70, 128])
torch.Size([4, 61, 128])
torch.Size([4, 69, 128])
torch.Size([4, 69, 128])
torch.Size([4, 28, 128])
torch.Size([4, 23, 128])
torch.Size([4, 18, 128])
torch.Size([4, 35, 128])
torch.Size([4, 61, 128])
torch.Size([4, 35, 128])
torch.Size([4, 25, 128])
torch.Size([4, 17, 128])
torch.Size([4, 25, 128])
torch.Size([4, 28, 128])
torch.Size([4, 24, 128])
torch.Size([4, 96, 128])
torch.Size([4, 41, 128])
torch.Size([4, 35, 128])
torch.Size([4, 29, 128])
torch.Size([4, 46, 128])
torch.Size([4, 32, 128])
torch.Size([4, 26, 128])
torch.Size([4, 13, 128])
torch.Size([4, 49, 128])
torch.Size([4, 47, 128])
torch.Size([4, 40, 128])
torch.Size([4, 18, 128])
torch.Size([4, 20, 128])
torch.Size([4, 29, 128])
torch.Size([4, 23, 128])
torch.Size([4, 23, 128])
torch.Size([4, 42, 128])
