In [1]:
import os
import json
from data.dataset import NERDataset
from models.networks import GlobalContextualDeepTransition
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback

In [2]:
# Load the model hyper-parameters from disk and build the network from them.
# JSON text is specified as UTF-8 (RFC 8259), so decode it explicitly instead of
# relying on open()'s platform/locale-dependent default encoding.
CONFIG_PATH = 'config.json'
with open(CONFIG_PATH, 'r', encoding='utf-8') as file:
    kwargs = json.load(file)
print("Init model params =", json.dumps(kwargs, indent=4))
# Every top-level key of the config is forwarded unchanged as a keyword argument
# to the model constructor.
model = GlobalContextualDeepTransition(**kwargs)

Init model params = {
    "numChars": 100,
    "charEmbedding": 128,
    "numWords": 21388,
    "wordEmbedding": 300,
    "contextOutputUnits": 128,
    "contextTransitionNumber": 4,
    "encoderUnits": 256,
    "decoderUnits": 256,
    "transitionNumber": 4,
    "numTags": 17
}


In [3]:
# Total number of scalar elements across all parameter tensors with
# requires_grad=True; parameters with requires_grad=False are excluded.
# The total depends on the values in config.json, so it is read from this cell's
# output rather than hard-coded in a comment (a previously pinned figure had
# drifted out of sync with the actual count).
numParams = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {numParams:,}")

Trainable parameters: 7,443,797


In [4]:
# Inspect the tensors of model.sequenceLabeller.output: one line per state_dict
# entry showing its name (with '.' rendered as '/'), its shape, and its element count.
# tuple(v.shape) prints the same text as v.numpy().shape, e.g. "(256, 556)" or
# "(17,)", but skips the NumPy conversion. That conversion raises for tensors on a
# CUDA device and for dtypes NumPy does not support (e.g. bfloat16).
for k, v in model.sequenceLabeller.output.state_dict().items():
    print(k.ljust(63).replace('.', '/'), 'shape', str(tuple(v.shape)).ljust(12), v.numel())

0/weight                                                        shape (256, 556)   142336
0/bias                                                          shape (256,)       256
3/weight                                                        shape (17, 256)    4352
3/bias                                                          shape (17,)        17
