In [1]:
import os
import json
from data.dataset import NERDataset
from models.networks import GlobalContextualDeepTransition
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback

In [2]:
# Input files all live under one directory; every path below is derived from it.
base = 'data/conll03'
sourceName = base + '/eng.train.src'
targetName = base + '/eng.train.trg'
gloveFile = base + '/trimmed.300d.Cased.txt'
symbFile = base + '/sym.glove'

# Passed to the Trainer as `resume_from_checkpoint`. Keep None to start fresh,
# or point it at a saved checkpoint, e.g.
# 'lightning_logs/version_7/epoch=502-step=24938.ckpt'
prevCheckpoint = None

# Build the dataset from the four files, then ask it for a loader.
# NOTE(review): 4096 is presumably the batch size -- confirm against NERDataset.getLoader.
data = NERDataset(sourceName, targetName, gloveFile, symbFile)
loader = data.getLoader(4096)

In [3]:
# Read the constructor arguments for the network from the JSON config and echo them.
with open('config_small.json', 'r') as file:
    kwargs = json.loads(file.read())
print("Init model params =", json.dumps(kwargs, indent=4))

# Instantiate the model from the config, then pass it the dataset's embedding weights.
model = GlobalContextualDeepTransition(**kwargs)
model.init_weights(data.embeddingWeights)

# Count only the parameters that have requires_grad set.
numParams = sum(
    param.numel()
    for param in model.parameters()
    if param.requires_grad
)
print(f"Trainable parameters: {numParams:,}")  # 7,114,173 in the recorded run

Init model params = {
    "numChars": 100,
    "charEmbedding": 128,
    "numWords": 21388,
    "wordEmbedding": 300,
    "contextOutputUnits": 128,
    "contextTransitionNumber": 2,
    "encoderUnits": 256,
    "decoderUnits": 256,
    "transitionNumber": 4,
    "numTags": 17
}
Trainable parameters: 7,114,173


In [4]:
class SaveEachEpoch(Callback):
    """Callback that saves a trainer checkpoint every ``period`` epochs.

    The checkpoint path is ``os.path.join(dirpath, filename)`` with ``{epoch}``
    (optionally with a format spec such as ``{epoch:02d}``) replaced by
    ``trainer.current_epoch``.

    Args:
        dirpath: Directory the checkpoints are written to. It is created on
            the first save if it does not exist yet.
        filename: File-name template, formatted with ``epoch=<current epoch>``.
        period: A checkpoint is written whenever ``trainer.current_epoch`` is
            a multiple of this value.

    Raises:
        ValueError: If ``period`` is zero (the modulo test in ``on_epoch_end``
            would otherwise fail only after the first epoch has been trained).
    """

    def __init__(self, dirpath, filename, period):
        super().__init__()
        if period == 0:
            raise ValueError("period must be non-zero")
        self.dirpath = dirpath
        self.filename = filename
        self.period = period

    # NOTE(review): newer PyTorch Lightning releases deprecate `on_epoch_end` in
    # favour of `on_train_epoch_end` -- confirm against the installed version
    # before upgrading the library.
    def on_epoch_end(self, trainer, pl_module):
        """Save a checkpoint when the current epoch is a multiple of ``period``.

        Args:
            trainer: The running trainer; supplies ``current_epoch`` and
                performs the save via ``save_checkpoint``.
            pl_module: The module being trained (unused here).
        """
        if trainer.current_epoch % self.period != 0:
            return

        path = os.path.join(self.dirpath, self.filename).format(epoch=trainer.current_epoch)

        # Make sure the target directory exists so the save cannot fail on a
        # fresh checkout; exist_ok keeps this safe when it is already there.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

        print("Saving at", path)
        trainer.save_checkpoint(path)

In [None]:
# Periodic backup: one checkpoint every 5 epochs under lightning_logs/backup/.
ckpt = SaveEachEpoch(dirpath='lightning_logs/backup/', filename='ckpt-small{epoch:02d}', period=5)

# Single-GPU run for up to 500 epochs with gradient clipping at 5.0;
# resumes from `prevCheckpoint` when that is not None.
trainer = pl.Trainer(
    max_epochs=500,
    gpus=1,
    gradient_clip_val=5.,
    callbacks=[ckpt],
    resume_from_checkpoint=prevCheckpoint,
)
trainer.fit(model, loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name             | Type                           | Params
--------------------------------------------------------------------
0 | contextEncoder   | GlobalContextualEncoder        | 7 M   
1 | sequenceLabeller | SequenceLabelingEncoderDecoder | 6 M   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…



Saving at lightning_logs/backup/ckpt-small00
Saving at lightning_logs/backup/ckpt-small05
Saving at lightning_logs/backup/ckpt-small10
Saving at lightning_logs/backup/ckpt-small15
Saving at lightning_logs/backup/ckpt-small20
Saving at lightning_logs/backup/ckpt-small25
Saving at lightning_logs/backup/ckpt-small30
Saving at lightning_logs/backup/ckpt-small35
Saving at lightning_logs/backup/ckpt-small40
Saving at lightning_logs/backup/ckpt-small45
Saving at lightning_logs/backup/ckpt-small50
Saving at lightning_logs/backup/ckpt-small55
Saving at lightning_logs/backup/ckpt-small60
Saving at lightning_logs/backup/ckpt-small65
Saving at lightning_logs/backup/ckpt-small70
Saving at lightning_logs/backup/ckpt-small75
Saving at lightning_logs/backup/ckpt-small80
Saving at lightning_logs/backup/ckpt-small85
