In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
#Export
from uti.interface import *
from pytorch_pretrained_bert import BertTokenizer

In [3]:
#Export
class FastaiBertTokenizer(BaseTokenizer):
    """Adapter that lets a pretrained BERT WordPiece tokenizer be used as a fastai ``tok_func``.

    Args:
        tokenizer: pretrained tokenizer exposing ``tokenize(str) -> list[str]``
            (e.g. a ``BertTokenizer``).
        max_seq: maximum output length in tokens, *including* the ``[CLS]`` and
            ``[SEP]`` markers. The two markers are always emitted, so values below 2
            still yield exactly ``["[CLS]", "[SEP]"]``.
        **kwargs: accepted for call compatibility and ignored.
    """
    def __init__(self, tokenizer, max_seq=128, **kwargs):
        self._pretrained_tokenizer = tokenizer
        self.max_seq_length = max_seq

    def __call__(self, *args, **kwargs):
        # NOTE(review): presumably fastai's Tokenizer calls tok_func(...) to obtain a
        # tokenizer object; this instance is already configured, so ignore the args.
        return self

    def tokenizer(self, t):
        """Return ``[CLS] + wordpieces(t) + [SEP]``, truncated so the result fits ``max_seq_length``.

        Wordpieces are truncated *before* the markers are added so both special tokens survive.
        """
        # Clamp at 0: with max_seq < 2 a negative slice bound would count from the end
        # and keep almost every wordpiece instead of none.
        budget = max(self.max_seq_length - 2, 0)
        return ["[CLS]"] + self._pretrained_tokenizer.tokenize(t)[:budget] + ["[SEP]"]
    
class BERT_Interface(Interface):
    """``Interface`` implementation that prepares text-classification data with a BERT vocabulary.

    Builds a fastai processing pipeline (open file -> BERT WordPiece tokenization ->
    numericalization with the BERT vocabulary). In eval mode it turns every ``.csv`` file
    in ``self.csv_path`` into a DataBunch.

    Args:
        model_tokenizer: name or path passed to ``BertTokenizer.from_pretrained``
            (e.g. ``"bert-base-uncased"``).
        path: forwarded unchanged to ``Interface.__init__``.
        eval_mode: forwarded to ``Interface.__init__``. ``pre_processing`` only supports True.
        max_seq: maximum token-sequence length, including ``[CLS]``/``[SEP]``.
    """

    # Fraction of each CSV held out for validation, and the seed that fixes the split.
    VALID_PCT = 0.3
    SPLIT_SEED = 42

    def __init__(self, model_tokenizer, path, eval_mode=True, max_seq=128):
        bert_tok = BertTokenizer.from_pretrained(model_tokenizer)
        # Forward the caller's max_seq (it used to be hard-coded to 128 here).
        fastai_tokenizer = Tokenizer(tok_func=FastaiBertTokenizer(bert_tok, max_seq=max_seq),
                                     pre_rules=[fix_html],
                                     post_rules=[])
        # Use BERT's own token list as the fastai vocab.
        # NOTE(review): relies on bert_tok.vocab iterating in token-id order.
        self.fastai_bert_vocab = Vocab(list(bert_tok.vocab.keys()))
        # FastaiBertTokenizer already adds [CLS]/[SEP], so fastai's BOS/EOS markers are off.
        self.processor = [OpenFileProcessor(),
                          TokenizeProcessor(tokenizer=fastai_tokenizer, include_bos=False, include_eos=False),
                          NumericalizeProcessor(vocab=self.fastai_bert_vocab)]
        super().__init__(path, eval_mode=eval_mode)

    def _get_individual_data(self, filepath):
        """Build a DataBunch from one CSV with a ``Body`` text column and a ``Label`` column.

        Rows are split randomly into train/validation using ``VALID_PCT`` and ``SPLIT_SEED``.

        Args:
            filepath: path of the CSV file to read.

        Returns:
            The object returned by fastai's ``.databunch(...)``.
        """
        df = pd.read_csv(filepath)
        # Cap the batch size at the (approximate) training-split size. Never let it
        # reach 0, which tiny CSVs (0-1 rows) previously produced.
        n_train = int(len(df) * (1 - self.VALID_PCT))
        bs = max(1, min(self.bs, n_train))
        data = (TextList
                .from_df(df=df, path=self.csv_path, cols='Body',
                         vocab=self.fastai_bert_vocab, processor=self.processor)
                .split_by_rand_pct(self.VALID_PCT, seed=self.SPLIT_SEED)
                .label_from_df(cols='Label')
                .databunch(bs=bs, num_workers=2))
        return data

    def pre_processing(self, skip_convert_json=False, test=False, **kwargs):
        """Optionally convert JSON inputs to CSV, then build one DataBunch per CSV file.

        Args:
            skip_convert_json: if True, skip ``self._convert_json_to_csv``.
            test: forwarded to ``self._convert_json_to_csv``.
            **kwargs: accepted for interface compatibility; unused.

        Side effects:
            Appends one entry per CSV (in sorted filename order) to ``self.data_list``
            and ``self.dataset_name``.

        Raises:
            NotImplementedError: if ``self.eval_mode`` is False (fine-tuning the BERT
                language model is not supported).
        """
        # TODO: reject non-Basilica input before converting (cf. self._check_data_format()).
        if not skip_convert_json:
            self._convert_json_to_csv(test)
        if not self.eval_mode:
            raise NotImplementedError('Fine-tuning the BERT language model is not supported')
        # Match on the real suffix so "a.b.csv" is kept and "a.csv.bak" is skipped.
        # Sort so the dataset order is deterministic.
        for csv_file in sorted(Path(f) for f in self.csv_path.ls()):
            if csv_file.suffix != '.csv':
                continue
            self.data_list.append(self._get_individual_data(csv_file))
            self.dataset_name.append(csv_file.stem)

In [4]:
# Directory holding the preprocessed data. Set PROJECT_M_DATA_DIR to point elsewhere
# instead of editing this machine-specific default.
import os
path = Path(os.environ.get('PROJECT_M_DATA_DIR',
                           '/home/jupyter/insight_project/Project-M/data/preprocessed/'))
path.ls()

[PosixPath('/home/jupyter/insight_project/Project-M/data/preprocessed/ed91c398-31c6-437f-a9d1-462e3ccfb6fa.json'),
 PosixPath('/home/jupyter/insight_project/Project-M/data/preprocessed/example.txt'),
 PosixPath('/home/jupyter/insight_project/Project-M/data/preprocessed/ff063ea9-62b8-4f29-9faa-04a09cb5fba2.json'),
 PosixPath('/home/jupyter/insight_project/Project-M/data/preprocessed/.ipynb_checkpoints'),
 PosixPath('/home/jupyter/insight_project/Project-M/data/preprocessed/delete_test'),
 PosixPath('/home/jupyter/insight_project/Project-M/data/preprocessed/fabab216-0767-4aa5-85fa-bb8852eb30d3.json'),
 PosixPath('/home/jupyter/insight_project/Project-M/data/preprocessed/csv'),
 PosixPath('/home/jupyter/insight_project/Project-M/data/preprocessed/f6d2081a-0f79-4abf-9021-c4d254859890.json')]

In [5]:
# Build the interface with the uncased BERT-base tokenizer, handing it every entry of the
# data directory (the next cell shows which paths the interface actually stored).
test = BERT_Interface(model_tokenizer="bert-base-uncased", path=path.ls())

In [6]:
# Inspect the input paths stored on the interface (the output lists only the .json files).
test.path

[PosixPath('/home/jupyter/insight_project/Project-M/data/preprocessed/ed91c398-31c6-437f-a9d1-462e3ccfb6fa.json'),
 PosixPath('/home/jupyter/insight_project/Project-M/data/preprocessed/ff063ea9-62b8-4f29-9faa-04a09cb5fba2.json'),
 PosixPath('/home/jupyter/insight_project/Project-M/data/preprocessed/fabab216-0767-4aa5-85fa-bb8852eb30d3.json'),
 PosixPath('/home/jupyter/insight_project/Project-M/data/preprocessed/f6d2081a-0f79-4abf-9021-c4d254859890.json')]

In [7]:
# Run _convert_json_to_csv(test=True), then build one DataBunch per CSV in csv_path.
test.pre_processing(skip_convert_json=False, test=True)

In [8]:
# Directory the CSV files were read from during pre_processing.
test.csv_path

PosixPath('/home/jupyter/insight_project/Project-M/data/preprocessed/delete_test')

In [9]:
# Export this notebook to uti/bert_interface.py (per the output below).
# NOTE(review): presumably only the `#Export`-tagged cells are exported, and the .ipynb is
# read from disk, so save the notebook before running this cell.
from notebook2script import notebook2script
notebook2script('Test_BERT.ipynb', 'bert_interface')

Converted Test_BERT.ipynb to uti/bert_interface.py
