ginofft/BertClassifier

BERT Intent Classifier for CLINC150 Dataset

PyTorch and Hugging Face implementation of a multi-label intent classifier with BERT as the encoder and an MLP as the classification head.

This repo currently works with two datasets: CLINC150 and MixSNIPs.

Users are encouraged to write custom scripts that load their own data into scr.dataset.SentenceLabelDataset; everything else should work as is.

Features:

  • Multiple LMs as backbone: BERT, RoBERTa and DistilBERT.
  • MLP with BCELoss: an independent binary classifier for each class.
  • A multi-label evaluator: includes a method that finds the threshold value maximizing macro F1.
  • Dynamic padding: for faster training and inference.
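
The threshold selection mentioned in the feature list can be illustrated with a minimal sketch. This is not the repository's evaluator: the function names `macro_f1` and `best_threshold`, the candidate grid, and the plain-list inputs are all assumptions made for illustration. It assumes each sample has a multi-hot ground-truth vector and a vector of per-class probabilities, and that a single global threshold is applied to every class.

```python
from typing import List, Optional, Sequence


def macro_f1(y_true: Sequence[Sequence[int]],
             y_prob: Sequence[Sequence[float]],
             threshold: float) -> float:
    """Average the per-class F1 scores, giving every class equal weight."""
    n_classes = len(y_true[0])
    f1_scores: List[float] = []
    for c in range(n_classes):
        tp = fp = fn = 0
        for truth, prob in zip(y_true, y_prob):
            predicted = prob[c] >= threshold
            if predicted and truth[c]:
                tp += 1
            elif predicted and not truth[c]:
                fp += 1
            elif not predicted and truth[c]:
                fn += 1
        denom = 2 * tp + fp + fn
        # A class with no positives and no predictions scores 0 here (a convention).
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return sum(f1_scores) / n_classes


def best_threshold(y_true: Sequence[Sequence[int]],
                   y_prob: Sequence[Sequence[float]],
                   candidates: Optional[Sequence[float]] = None) -> float:
    """Return the candidate threshold with the highest macro F1 (first one on ties)."""
    if candidates is None:
        # Hypothetical grid: 0.05, 0.10, ..., 0.95
        candidates = [i / 100 for i in range(5, 100, 5)]
    return max(candidates, key=lambda t: macro_f1(y_true, y_prob, t))
```

In a setup like this, the threshold would normally be chosen on the validation split and then reused unchanged on the test split, so that the test score is not tuned on test data.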

References

Quick Start

To train:

python main.py --mode train --batch_size 512 \
    --nEpochs 500 --saveEvery 10 \
    --datasetPath data/CLINC150 \
    --metrics 'macro f1' \
    --savePath output

To run inference:

python main.py --mode inference \
    --datasetPath data/CLINC150 \
    --metrics 'macro f1' \
    --loadPath output/best.pth.tar

Customization

You may need to change the data loading to suit your own dataset. The code to modify is in main:

    if opt.dataFormat.lower() == 'clinc150':
        dataDict = read_CLINC150_file(opt.datasetPath)
        trainList = dataDict['train'] + dataDict['oos_train']
        valList = dataDict['val'] + dataDict['oos_val']
        testList = dataDict['test'] + dataDict['oos_test']
        turn_single_label_to_multilabels(trainList, valList, testList)
    
    if opt.dataFormat.lower() == 'mixsnips':
        trainPath = opt.datasetPath + '/train.txt'
        valPath = opt.datasetPath + '/dev.txt'
        testPath = opt.datasetPath + '/test.txt'

        trainList = read_MixSNIPs_file(trainPath)
        valList = read_MixSNIPs_file(valPath)
        testList = read_MixSNIPs_file(testPath)

where read_CLINC150_file() and read_MixSNIPs_file() are custom reader functions defined in this repo.
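
For a custom dataset, the single-label to multi-label step is the part most likely to need adapting. The sketch below shows one way such a conversion could work. It is an assumption, not the repository's turn_single_label_to_multilabels(): it assumes each sample is a mutable `[sentence, intent]` pair (the layout of the CLINC150 JSON splits), and the helper names `to_multilabel`, `build_label_index`, and `multi_hot` are hypothetical.

```python
from typing import Dict, List


def to_multilabel(*splits: List[list]) -> None:
    """Wrap each single intent string in a list, modifying the samples in place."""
    for split in splits:
        for sample in split:
            # sample is [sentence, intent]; samples already holding a list are left alone
            if isinstance(sample[1], str):
                sample[1] = [sample[1]]


def build_label_index(*splits: List[list]) -> Dict[str, int]:
    """Map every label seen in the given splits to a stable integer index."""
    labels = sorted({label
                     for split in splits
                     for _, sample_labels in split
                     for label in sample_labels})
    return {label: i for i, label in enumerate(labels)}


def multi_hot(sample_labels: List[str], label_to_index: Dict[str, int]) -> List[float]:
    """Encode a list of labels as a multi-hot float vector, the target shape BCELoss expects."""
    vector = [0.0] * len(label_to_index)
    for label in sample_labels:
        vector[label_to_index[label]] = 1.0
    return vector
```

A custom reader would then only need to return a list of `[sentence, labels]` pairs per split; whether SentenceLabelDataset accepts exactly this layout should be checked against the source.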
