# DBert train

From <https://github.com/huggingface/transformers/blob/master/examples/run_tf_glue.py>

## Commands

In [None]:
# Titanv tf install:
# !pip freeze | grep flow
# !pip install --upgrade pip
# !pip uninstall --y tensorboard tensorflow-estimator tensorflow tensorflow-gpu
# !pip install --upgrade tensorflow==2.0.0
# !pip install --upgrade tensorflow-gpu==2.0.0
# !pip install --upgrade transformers==2.4.1
# !pip freeze | grep flow

In [None]:
# titanv 1:
# screen -S dbert-train1
# source ~/.bash_profile ; source ~/.bash_aliases ; cd ~/dbert-train-logs
# DOCKER_PORT=9961 nn -o nohup-dbert-train-$HOSTNAME-1.out ~/docker/keras/run-jupython.sh ~/notebooks/asa/train/dbert-train.ipynb titanv
# observe ~/dbert-train-logs/nohup-dbert-train-$HOSTNAME-1.out

In [None]:
# titanv 2:
# screen -S dbert-train2
# source ~/.bash_profile ; source ~/.bash_aliases ; cd ~/dbert-train-logs
# DOCKER_PORT=9962 nn -o nohup-dbert-train-$HOSTNAME-2.out ~/docker/keras/run-jupython.sh ~/notebooks/asa/train/dbert-train.ipynb titanv
# observe ~/dbert-train-logs/nohup-dbert-train-$HOSTNAME-2.out

In [None]:
# cd ; archive-notebooks ; cd ~/logs ; ./mv-old-logs.sh # optional
# sbatch ~/slurm/run-notebook.sh ~/tmp/archives/notebooks/asa/train/dbert-train.ipynb
# observe ~/logs/*.out

In [None]:
# cd ; archive-notebooks ; cd ~/logs
# sbatch ~/slurm/run-notebook.sh ~/tmp/archives/notebooks/asa/train/dbert-train.ipynb
# observe ~/logs/*.out

## Imports

In [None]:
# True when no `__file__` name is visible in the current scope (e.g. when run interactively
# in a notebook); used further down to shrink the configuration for quick interactive runs.
isNotebook = '__file__' not in locals()

In [None]:
import os
# Restrict CUDA to device 0. This is set before the TensorFlow import that follows;
# presumably so that TensorFlow only sees that GPU — confirm on multi-GPU hosts.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

In [None]:
import logging
import math
import tensorflow as tf
from tensorflow.keras import callbacks
from transformers import \
(
    DistilBertConfig,
    DistilBertTokenizer,
    TFDistilBertForSequenceClassification,
)

## Functions

In [None]:
def ksetGen\
(
    train=True,
    ksetRoot=dataDir() + '/Asa2/detok-kset' if lri() else homeDir() + "/asa/asa2-data/detok-kset",
    maxFiles=None,
    **kwargs,
):
    """
    Build the sample generator for one split of the kset.

    :param train: when true, read the `train` split; otherwise read `validation`.
    :param ksetRoot: root directory containing the `train/` and `validation/` sub-directories.
    :param maxFiles: optional cap; only the first `maxFiles` files returned by `sortedGlob` are kept.
    :param kwargs: forwarded unchanged to `genFunct`.
    :return: the generator returned by `genFunct`.
    """
    # Pick the split directory, then list its compressed files:
    splitName = 'train' if train else 'validation'
    splitFiles = sortedGlob(ksetRoot + '/' + splitName + '/*.bz2')
    # Optionally keep only the first files:
    if maxFiles is not None:
        splitFiles = splitFiles[:maxFiles]
    # The actual reading and encoding is delegated to genFunct:
    return genFunct(splitFiles, ksetRoot=ksetRoot, **kwargs)

In [None]:
def genFunct\
(
    files,
    
    ksetRoot=dataDir() + '/Asa2/detok-kset' if lri() else homeDir() + "/asa/asa2-data/detok-kset",
    dataCol="filtered_detokenized_sentences",
    labelField='label',
    
    labelEncoding='index',
    labelEncoder=None,
    
    maxSamples=None,
    maxSentences=None,
    
    preventTokenizerWarnings=True,
    loggerName="transformers.tokenization_utils",
    
    logger=None,
    verbose=True,
    
    showProgress=False,
    
    multiSamplage=False,
    **encodeKwargs,
):
    """
    Generate `(encodedPart, encodedLabel)` samples from kset files.

    Each row of each file is read with `NDJson`, its sentences are encoded with
    `tf2utils.distilBertEncode`, and one sample is yielded per encoded part.

    :param files: a file path or a list of file paths.
    :param ksetRoot: dataset root; only used to load `validation/labels.pickle`
        when `labelEncoder` is not given.
    :param dataCol: row field holding the list of sentences.
    :param labelField: row field holding the label.
    :param labelEncoding: encoding passed to `encodeMulticlassLabels`.
    :param labelEncoder: optional ready-made mapping from label to encoded label.
    :param maxSamples: optional cap on yielded samples. It is checked after each row,
        so with `multiSamplage` the last row may push the total slightly above the cap.
    :param maxSentences: optional cap on the number of sentences kept per row.
    :param preventTokenizerWarnings: when true, the logger named `loggerName` is set to
        ERROR while the generator is alive and restored afterwards.
    :param loggerName: name of the logger to silence.
    :param logger: logger forwarded to the helpers.
    :param verbose: verbosity forwarded to the helpers.
    :param showProgress: when true, a `ProgressBar` is ticked once per file.
    :param multiSamplage: forwarded to the encoder; when true the encoder result is
        iterated as several parts, otherwise it is treated as a single part.
    :param encodeKwargs: extra keyword arguments forwarded to `tf2utils.distilBertEncode`.
    :raises Exception: when a row's `dataCol` is not a list of more than one string.
    """
    # Handling unique file:
    if not isinstance(files, list):
        files = [files]
    # Misc init:
    samplesCount = 0
    # We set the logger level:
    if preventTokenizerWarnings:
        tokenizerLogger = logging.getLogger(loggerName)
        previousLoggerLevel = tokenizerLogger.level
        tokenizerLogger.setLevel(logging.ERROR)
    # Everything below runs under try/finally so the logger level is restored even when
    # the consumer stops iterating early (generator closed) or an exception is raised:
    try:
        if showProgress:
            pbar = ProgressBar(len(files), logger=logger, verbose=verbose)
        # We get labels and encode labels:
        if labelEncoder is None:
            labels = sorted(list(deserialize(ksetRoot + '/validation/labels.pickle')))
            (classes, labels) = encodeMulticlassLabels(labels, encoding=labelEncoding)
            # zip would silently truncate on a length mismatch, hence the explicit check:
            assert len(classes) == len(labels)
            labelEncoder = dict(zip(classes, labels))
        # For each file:
        for file in files:
            for row in NDJson(file):
                # We get sentences:
                sentences = row[dataCol]
                # NOTE(review): `> 1` rejects single-sentence documents — confirm this is intended.
                if not (isinstance(sentences, list) and len(sentences) > 1 and isinstance(sentences[0], str)):
                    raise Exception("All row[dataCol] must be a list of strings (sentences)")
                if maxSentences is not None:
                    sentences = sentences[:maxSentences]
                # We encode the document (`proxies` is a module-level name defined outside this function):
                parts = tf2utils.distilBertEncode\
                (
                    sentences,
                    multiSamplage=multiSamplage,
                    preventTokenizerWarnings=False,
                    proxies=proxies,
                    logger=logger, verbose=verbose,
                    **encodeKwargs,
                )
                if not multiSamplage:
                    parts = [parts]
                # We yield all parts:
                for part in parts:
                    yield (np.array(part), labelEncoder[row[labelField]])
                    samplesCount += 1
                if maxSamples is not None and samplesCount >= maxSamples:
                    break
            if showProgress:
                pbar.tic(file)
            if maxSamples is not None and samplesCount >= maxSamples:
                break
    finally:
        # We reset the logger:
        if preventTokenizerWarnings:
            tokenizerLogger.setLevel(previousLoggerLevel)

In [None]:
def saveFunct(model, directory, **kwargs):
    """
    Save `model` into `directory` through its `save_pretrained` method.

    :param model: object exposing `save_pretrained(directory)`.
    :param directory: destination directory.
    :param kwargs: accepted for compatibility with the caller, not used.
    """
    savePretrained = model.save_pretrained
    savePretrained(directory)

In [None]:
def _countKsetSamples(train, genParams, cache, logger=None, verbose=True):
    """
    Return the number of samples `ksetGen` yields for one split, using `cache` when possible.

    :param train: True for the train split, False for the validation split.
    :param genParams: keyword arguments forwarded to `ksetGen` (without `train`).
    :param cache: dict-like object mapping a parameters hash to a samples count.
    :param logger: logger forwarded to `log` and `ksetGen`.
    :param verbose: verbosity forwarded to `log` and `ksetGen`.
    """
    # The cache key is the hash of the generator parameters plus the split flag:
    cacheKey = objectToHash(mergeDicts(genParams, {'train': train}))
    if cacheKey in cache:
        return cache[cacheKey]
    splitName = 'train' if train else 'validation'
    log("Starting to count samples in the " + splitName + " set...", logger, verbose=verbose)
    # Counting requires running the whole generator once:
    count = sum(1 for _ in ksetGen\
    (
        train=train,
        **genParams,
        showProgress=True,
        logger=logger,
        verbose=verbose,
    ))
    cache[cacheKey] = count
    return count

def getSamplesCount(logger=None, verbose=True):
    """
    Return `(trainSamplesCount, validationSamplesCount)` for the current `config`.

    Counts are stored in a MongoDB-backed `SerializableDict` named 'samples-count',
    keyed by a hash of the generator parameters, so a split is only iterated when
    its count is not cached yet.

    Note that `ksetRoot` is not part of the parameters: `ksetGen` uses its default
    root and the cache key does not distinguish between dataset roots.

    :param logger: logger used for messages and progress.
    :param verbose: verbosity of messages and progress.
    """
    (user, password, host) = getOctodsMongoAuth()
    samplesCountCache = SerializableDict('samples-count', user=user, host=host, password=password, useMongodb=True)
    samplesCountParams = \
    {
        'maxFiles': config['maxFiles'],
        'maxSamples': config['maxSamples'],
        'multiSamplage': config['multiSamplage'],
        'maxLength': config['maxLength'],
        'dataCol': config['dataCol'],
    }
    trainSamplesCount = _countKsetSamples(True, samplesCountParams, samplesCountCache, logger=logger, verbose=verbose)
    validationSamplesCount = _countKsetSamples(False, samplesCountParams, samplesCountCache, logger=logger, verbose=verbose)
    return (trainSamplesCount, validationSamplesCount)

## Config

In [None]:
# Run configuration. It is hashed (`objectToHash(config)`) to name the output directory,
# so changing any value selects a different output directory.
config = \
{
    # Row field holding the list of sentences of each document:
    'dataCol': 'filtered_detokenized_sentences',
    # Dataset root (its location depends on `lri()`):
    'ksetRoot': dataDir() + '/Asa2/detok-kset' if lri() else homeDir() + "/asa/asa2-data/detok-kset",
    # Forwarded to the encoder; when true one document may yield several samples:
    'multiSamplage': True,
    # Caps applied in notebook runs only (None means no cap):
    'maxFiles': 30 if isNotebook else None,
    'maxSamples': 5000 if isNotebook else None,
    # Passed to the DistilBert config (`max_length`) and to the sample generator:
    'maxLength': 512,
    'batchSize': 16,
    
    # Adam optimizer settings:
    'learningRate': 3e-5,
    'epsilon': 1e-08,
    'clipnorm': 1.0,
    
    # Each Keras epoch covers 1/trainStepDivider of the train batches:
    'trainStepDivider': 2 if isNotebook else 30,
    # Passed to `InfiniteBatcher` as `shuffle` and `queueSize`:
    'shuffle': 0 if isNotebook else 100,
    'queueSize': 100,
    
    # Selects the `MLIterator`-based pipeline instead of the plain `ksetGen` one:
    'useMLIterator': True,
}

In [None]:
# Shortcut to the dataset root, used by the cells below:
ksetRoot = config['ksetRoot']

In [None]:
# Each configuration gets its own run directory, named after the first 5 characters of its hash:
outputDirRoot = homeDir() + '/asa/dbert-train'
configHashPrefix = objectToHash(config)[:5]
outputDir = '/'.join([outputDirRoot, configHashPrefix])
mkdir(outputDir)

In [None]:
# Manual reset, disabled by default (`if False`): when enabled it removes the output directory.
# The two assertions restrict it to a notebook run with `maxFiles == 3`.
# NOTE(review): the notebook config above sets `maxFiles` to 30, so the first assertion
# would fail as written — confirm the intended value before enabling.
if False:
    assert config['maxFiles'] == 3
    assert isNotebook
    remove(outputDir)

In [None]:
# File logger living inside the run directory:
logPath = outputDir + '/dbert-train.log'
logger = Logger(logPath)
log("outputDir: " + str(outputDir), logger)

In [None]:
# List both splits, then optionally keep only the first `maxFiles` files of each:
trainFiles = sortedGlob(ksetRoot + '/train/*.bz2')
validationFiles = sortedGlob(ksetRoot + '/validation/*.bz2')
filesCap = config['maxFiles']
if filesCap is not None:
    # Both messages report the sizes before truncation:
    log("Reducing amount of train files from " + str(len(trainFiles)) + " to " + str(filesCap), logger)
    log("Reducing amount of validation files from " + str(len(validationFiles)) + " to " + str(filesCap), logger)
    trainFiles = trainFiles[:filesCap]
    validationFiles = validationFiles[:filesCap]
# Print the retained files of each split:
for splitFiles in (trainFiles, validationFiles):
    bp(splitFiles, logger)

## Model

In [None]:
# In case we reume a previous train:
batchesPassed = 0
initialEpoch = 0
lastEpochPath = None
if len(sortedGlob(outputDir + "/epochs/ep*")) > 0:
    lastEpochPath = sortedGlob(outputDir + "/epochs/ep*")[-1]
    batchesPassedPath = lastEpochPath + "/batchesPassed.txt"
    assert isFile(batchesPassedPath)
    assert not isFile(outputDir + "/finished")
    assert not isFile(outputDir + "/stop")
    initialEpoch = getFirstNumber(decomposePath(lastEpochPath)[1]) + 1
    batchesPassed = int(fileToStr(batchesPassedPath))
    log("We found an epoch to resume: " + lastEpochPath, logger)
    logWarning("We will skip " + str(batchesPassed) + " batches because we resume a previous train", logger)

In [None]:
# Either start from the pretrained checkpoint or reload the last saved epoch:
if lastEpochPath is None:
    log("Loading a new model from distilbert-base-uncased...", logger)
    pretrainedName = "distilbert-base-uncased"
    # The number of output classes is the number of labels stored with the validation set:
    numLabels = len(deserialize(ksetRoot + '/validation/labels.pickle'))
    dbertConfig = DistilBertConfig.from_pretrained\
    (
        pretrainedName,
        num_labels=numLabels,
        max_length=config['maxLength'],
        proxies=proxies,
    )
    model = TFDistilBertForSequenceClassification.from_pretrained\
    (
        pretrainedName,
        config=dbertConfig,
        proxies=proxies,
    )
else:
    log("Loading previous model...", logger)
    # Both the configuration and the weights come from the last epoch directory:
    dbertConfig = DistilBertConfig.from_pretrained(lastEpochPath + '/config.json')
    model = TFDistilBertForSequenceClassification.from_pretrained\
    (
        lastEpochPath + '/tf_model.h5',
        config=dbertConfig,
    )
log("Model loaded.", logger)

In [None]:
# Optimizer:
optKwargs = dict()
if dictContains(config, 'clipnorm'): optKwargs['clipnorm'] = config['clipnorm']
if dictContains(config, 'learningRate'): optKwargs['learning_rate'] = config['learningRate']
if dictContains(config, 'epsilon'): optKwargs['epsilon'] = config['epsilon']
opt = tf.keras.optimizers.Adam(**optKwargs)
# Loss:
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Metric:
metric = tf.keras.metrics.SparseCategoricalAccuracy("accuracy")
# Compilation:
model.compile(optimizer=opt, loss=loss, metrics=[metric])

In [None]:
# Print the Keras summary of the compiled model:
model.summary()

## Training

In [None]:
# Get the number of samples of both splits (getSamplesCount consults its cache first), then report them:
trainSamplesCount, validationSamplesCount = getSamplesCount(logger=logger)
samplesCountMessage = 'trainSamplesCount: ' + str(trainSamplesCount) + ', validationSamplesCount: ' + str(validationSamplesCount)
log(samplesCountMessage, logger)

In [None]:
# Number of batches needed to cover each split once:
trainBatchesAmount = math.ceil(trainSamplesCount / config['batchSize'])
validationBatchesAmount = math.ceil(validationSamplesCount / config['batchSize'])
# A Keras epoch only covers 1/trainStepDivider of the train batches,
# while validation always covers the whole validation split:
trainSteps = math.ceil(trainBatchesAmount / config["trainStepDivider"])
validationSteps = validationBatchesAmount
# Report the four values:
for (valueName, value) in \
(
    ('trainBatchesAmount', trainBatchesAmount),
    ('validationBatchesAmount', validationBatchesAmount),
    ('trainSteps', trainSteps),
    ('validationSteps', validationSteps),
):
    log(valueName + ': ' + str(value), logger)

In [None]:
# Training callback from `tf2utils`. It receives the model, the run directory, the save
# function and the resume state (`initialEpoch`, `batchesPassed`) computed earlier.
# NOTE(review): `val_top_k_categorical_accuracy` is monitored below, but the model is compiled
# with only the "accuracy" metric — confirm how `KerasCallback` treats a monitor that is never reported.
callback = tf2utils.KerasCallback\
(
    model,
    outputDir,
    saveFunct=saveFunct,
    # Graphs are only shown in interactive notebook runs:
    showGraphs=isNotebook,
    # One early-stop rule per monitored quantity:
    earlyStopMonitor=
    {
        'val_loss': {'patience': 10, 'mode': 'auto'},
        'val_accuracy': {'patience': 10, 'mode': 'auto'},
        'val_top_k_categorical_accuracy': {'patience': 10, 'mode': 'auto'},
    },
    initialEpoch=initialEpoch,
    # Number of batches of the whole train split (not of one Keras epoch):
    batchesAmount=trainBatchesAmount,
    batchesPassed=batchesPassed,
    removeEpochs=True,
    logger=logger,
)

In [None]:
# Keyword arguments shared by both pipelines. This dict is also passed as-is to `genFunct`
# (which forwards unknown keys to the encoder), so it must only hold keys `genFunct` understands.
ksetGenKwargs = \
{
    'ksetRoot': ksetRoot,
    'dataCol': config['dataCol'],
    'maxLength': config['maxLength'],
    'multiSamplage': config['multiSamplage'],
    'maxSamples': config['maxSamples'],
}
# Arguments specific to `ksetGen`. `maxFiles` is added here (and not above) so that the
# plain-generator pipeline reads the same capped file set as the one used to count samples;
# without it `ksetGen` would read every file while the step counts assume `config['maxFiles']`.
ksetGenTrainKwargs = mergeDicts(ksetGenKwargs, {'train': True, 'maxFiles': config['maxFiles']})
ksetGenValidationKwargs = mergeDicts(ksetGenKwargs, {'train': False, 'maxFiles': config['maxFiles']})

In [None]:
 if config['useMLIterator']:
    train = IteratorToGenerator\
    (
        InfiniteBatcher\
        (
            AgainAndAgain\
            (
                MLIterator,
                trainFiles,
                genFunct,
                genKwargs=ksetGenKwargs,
                queuesMaxSize=100,
                parallelProcesses=cpuCount(),
                useFlushTimer=False,
                flushTimeout=300,
                logger=logger,
            ),
            batchSize=config['batchSize'],
            shuffle=config['shuffle'],
            queueSize=config['queueSize'],
            skip=batchesPassed,
            logger=logger,
        )
    )
    validation = IteratorToGenerator\
    (
        InfiniteBatcher\
        (
            AgainAndAgain\
            (
                MLIterator,
                validationFiles,
                genFunct,
                genKwargs=ksetGenKwargs,
                queuesMaxSize=100,
                parallelProcesses=cpuCount(),
                useFlushTimer=False,
                flushTimeout=300,
                logger=logger,
            ),
            batchSize=config['batchSize'],
            shuffle=config['shuffle'],
            queueSize=config['queueSize'],
            skip=batchesPassed,
            logger=logger,
        )
    )
else:
    train = IteratorToGenerator(InfiniteBatcher\
    (
        AgainAndAgain(ksetGen, **ksetGenTrainKwargs),
        batchSize=config['batchSize'],
        shuffle=config['shuffle'],
        queueSize=config['queueSize'],
        skip=batchesPassed,
        logger=logger,
    ))
    validation = IteratorToGenerator(InfiniteBatcher\
    (
        AgainAndAgain(ksetGen, **ksetGenValidationKwargs),
        batchSize=config['batchSize'],
        shuffle=0,
        queueSize=100,
        skip=0,
        logger=logger,
    ))

In [None]:
# Launch the training. `train` and `validation` are built on `InfiniteBatcher` (presumably
# endless), so the length of an epoch is fixed by `steps_per_epoch` and `validation_steps`.
history = model.fit\
(
    x=train,
    # Each Keras epoch covers 1/trainStepDivider of the train batches (see `trainSteps`),
    # so the epoch budget is scaled by the same factor:
    epochs=100 * config["trainStepDivider"],
    validation_data=validation,
    callbacks=[callback, callbacks.TerminateOnNaN()],
    # Non-zero when a previous run is resumed:
    initial_epoch=initialEpoch,
    steps_per_epoch=trainSteps,
    validation_steps=validationSteps,
)