In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()
torch.rand(10, device=device)
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

Using device: cuda

Tesla T4
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [5]:
import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf
# from models import PunctuationDomainModel
import hydra
from icecream import install
install()
def toString(obj):
    if isinstance(obj, np.ndarray):
        return 'array shape: '+obj.shape.__str__() + obj.__str__()
    if isinstance(obj, torch.Tensor):
        return 'tensor shape: '+obj.shape.__str__()
    if isinstance(obj, dict):
        return {_[0]:toString(_[1]) for _ in obj.items()}.__str__()
    return repr(obj)
ic.configureOutput(argToStringFunction=toString)

In [6]:
## Using hydra
from hydra.experimental import initialize, initialize_config_module, initialize_config_dir, compose
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path="core/config")
cfg=compose(
    config_name="config.yaml", 
)
cfg

{'base_path': '/home/nxingyu2/data', 'trainer': {'gpus': 1, 'num_nodes': 1, 'max_epochs': 3, 'max_steps': None, 'accumulate_grad_batches': 1, 'gradient_clip_val': 0.0, 'amp_level': 'O0', 'precision': 16, 'accelerator': 'ddp', 'checkpoint_callback': False, 'logger': True, 'log_every_n_steps': 1, 'val_check_interval': 1.0, 'resume_from_checkpoint': None}, 'exp_manager': {'exp_dir': None, 'name': 'Punctuation_with_Domain_discriminator', 'create_tensorboard_logger': True, 'create_checkpoint_callback': True}, 'model': {'transformer_path': '${base_path}/electra-base-discriminator', 'punct_label_ids': ['', '!', ',', '-', '.', ':', ';', '?', '—', '…'], 'dataset': {'data_dir': '${base_path}', 'labelled': ['${base_path}/ted_talks_processed'], 'unlabelled': ['${base_path}/open_subtitles_processed'], 'max_seq_length': 128, 'pad_label': '', 'ignore_extra_tokens': False, 'ignore_start_end': False, 'use_cache': True, 'num_workers': 2, 'pin_memory': False, 'drop_last': False}, 'train_ds': {'shuffle': 

In [8]:
from torch.utils.data import Dataset

from nemo.core.neural_types import ChannelType, LabelsType, MaskType, NeuralType
from transformers import AutoTokenizer
import numpy as np
from typing import List, Optional, Dict
import pandas as pd
import os
import torch
import subprocess

class PunctuationDomainDataset(Dataset):

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.
               """
        return {
            "input_ids": NeuralType(('B', 'T'), ChannelType()),
            "attention_mask": NeuralType(('B', 'T'), ChannelType()),
            "labels": NeuralType(('B', 'T'), ChannelType()),
            "domain": NeuralType(('B'), ChannelType()),
        }

    def __init__(self, 
        csv_file:str, 
        tokenizer,
        num_samples:int=256,
        max_seq_length:int=256,
        punct_label_ids: Dict[str, int] = None,
        domain=0,
        labelled=True,
    ):
        if not (os.path.exists(csv_file)):
            raise FileNotFoundError(
                f'{csv_file} not found. The data should be joined in 1 csv file.\
                    Each line of the file contains the subword token ids, masks and class labels per row.'
            )

        data_dir = os.path.dirname(csv_file)
        filename = os.path.basename(csv_file)

        if not filename.endswith('.csv'):
            raise ValueError("{text_file} should have extension .csv")
        # filename = filename[:-4]
        
        self.csv_file = csv_file
        self.max_seq_length = max_seq_length
        self.set_num_samples(csv_file, num_samples)
        self.domain=domain
        self.labelled=labelled
        self.tokenizer=tokenizer

    def __getitem__(self, idx):
        x = next(
            pd.read_csv(
                self.csv_file,
                skiprows=(idx % self.len)*self.num_samples,
                chunksize=self.num_samples,
                header=None,
                delimiter=' '))
        x = torch.from_numpy(x.values).reshape(-1,3,self.max_seq_length) #x.shape[-1]//3
        return {'input_ids': torch.as_tensor(x[:,0,:], dtype=torch.long),
                'attention_mask': torch.as_tensor(x[:,1,:],dtype=torch.bool)if self.labelled else torch.zeros_like(x[:,1,:],dtype=torch.bool),
                'labels': torch.as_tensor(x[:,2,:],dtype=torch.long),
                'domain':self.domain*torch.ones(x.shape[0],1,dtype=torch.long)}

    def set_num_samples(self,csv_file,num_samples):
        self.num_samples = num_samples
        self.total_samples=int(subprocess.Popen(['wc', '-l', csv_file], stdout=subprocess.PIPE, stderr=subprocess.STDOUT).communicate()[0].split()[0])
        self.len = int(self.total_samples / self.num_samples)
        

    def __len__(self):
        return self.len
    
    def view(d)->list:
        """:param d(dictionary): returns readable format of single input_ids and labels in the form of readable text"""
        a,_,c=d.values()
        return [' '.join([_[0]+_[1] for _ in list(zip(self.tokenizer.convert_ids_to_tokens(_[0]),[id2tag[id] for id in _[1].tolist()]))]) for _ in zip(a,c)]
    
    def shuffle(self, sorted=False, seed=42):
        os.system('bash data/shuffle.sh -i {} -o {} -a {} -s {}'.format(self.csv_file, self.csv_file, ['false','true'][sorted], seed))

class PunctuationDomainDatasets(Dataset):
    
    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports. """
        return {
            "input_ids": NeuralType(('B', 'T'), ChannelType()),
            "attention_mask": NeuralType(('B', 'T'), ChannelType()),
            "labels": NeuralType(('B', 'T'), ChannelType()),
            "domain": NeuralType(('B'), ChannelType()),
        }

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        b={k:torch.vstack([d[i][k] for d in self.datasets]) for k in ['input_ids','attention_mask','labels','domain']}
        rand=torch.randperm(b['labels'].size()[0])
        return {k:v[rand] for k,v in b.items()}

    def __len__(self):
        return max(len(d) for d in self.datasets)

class PunctuationInferenceDataset(Dataset):
    """
    Creates dataset to use during inference for punctuation and capitalization tasks with a pretrained model.
    For dataset to use during training with labels, see BertPunctuationCapitalizationDataset.
    Args:
        queries file to sequences, each line should a sentence, no header.
        max_seq_length: max sequence length minus 2 for [CLS] and [SEP]
        tokenizer: such as AutoTokenizer
    """

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.
               """
        return {
            'input_ids': NeuralType(('B', 'T'), ChannelType()),
            'attention_mask': NeuralType(('B', 'T'), MaskType()),
        }

    def __init__(self, queries: List[str], max_seq_length: int, tokenizer):
        """ Initializes BertPunctuationInferDataset. """
        features = get_features(queries=queries, max_seq_length=max_seq_length, tokenizer=tokenizer)
        self.all_input_ids = features['input_ids']
        self.all_attention_mask = features['attention_mask']

    def __len__(self):
        return len(self.all_input_ids)

    def __getitem__(self, idx):
        return {'input_ids':self.all_input_ids[idx],
            'attention_mask':self.all_attention_mask[idx],}
        

def get_features(
    queries:str, 
    max_seq_length:int,
    tokenizer,
    punct_label_ids: dict = None,):

    def flatten(list_of_lists):
        for list in list_of_lists:
            for item in list:
                yield item

    def pad_to_len(max_length,ids):
        o=np.zeros(max_length, dtype=np.int)
        o[:len(ids)]=np.array(ids)
        return o

    def position_to_mask(max_length,indices):
        o=np.zeros(max_length,dtype=np.int)
        o[indices%(max_length-2)+1]=1
        return o

    batch_ids=[]
    batch_masks=[]
    for query in queries:
        wordlist=re.split('[^a-zA-Z0-9]+',query)
        subwords=list(map(tokenizer.tokenize,wordlist))
        subword_lengths=list(map(len,subwords))
        subwords=list(flatten(subwords))
        token_end_idxs=np.cumsum([0]+subword_lengths[:-1])+np.array(subword_lengths)-1
        teim=token_end_idxs%(max_seq_length-2)
        split_token_end_idxs=np.array_split(token_end_idxs,(np.argwhere((teim[1:])<teim[:-1]).flatten()+1).tolist())
        split_subwords=np.array_split(subwords,np.arange(max_length-2,len(subwords),max_seq_length-2)) 
        ids=torch.tensor([pad_to_len(max_seq_length,tokenizer.convert_tokens_to_ids(['[CLS]']+list(_)+['[SEP]'])) for _ in split_subwords], dtype=torch.long)
        masks=[position_to_mask(max_length,_) for _ in split_token_end_idxs]
        batch_ids.append(ids)
        batch_masks.append(masks)
    return {'input_ids': torch.as_tensor(batch_ids, dtype=torch.long),
            'attention_mask': torch.as_tensor(batch_masks,dtype=torch.bool)}

In [23]:
cfg.model.punct_label_ids=OmegaConf.create(sorted(cfg.model.punct_label_ids))
ids_to_labels = {_[0]:_[1] for _ in enumerate(cfg.model.punct_label_ids)}
labels_to_ids = {_[1]:_[0] for _ in enumerate(cfg.model.punct_label_ids)}
cfg.base_path='/home/nxingyu2/data' #/home/nxingyu/data
# cfg.base_path
cfg

{'base_path': '/home/nxingyu2/data', 'trainer': {'gpus': 1, 'num_nodes': 1, 'max_epochs': 3, 'max_steps': None, 'accumulate_grad_batches': 1, 'gradient_clip_val': 0.0, 'amp_level': 'O0', 'precision': 16, 'accelerator': 'ddp', 'checkpoint_callback': False, 'logger': True, 'log_every_n_steps': 1, 'val_check_interval': 1.0, 'resume_from_checkpoint': None}, 'exp_manager': {'exp_dir': None, 'name': 'Punctuation_with_Domain_discriminator', 'create_tensorboard_logger': True, 'create_checkpoint_callback': True}, 'model': {'transformer_path': '${base_path}/electra-base-discriminator', 'punct_label_ids': ['', '!', ',', '-', '.', ':', ';', '?', '—', '…'], 'dataset': {'data_dir': '${base_path}', 'labelled': ['${base_path}/ted_talks_processed'], 'unlabelled': ['${base_path}/open_subtitles_processed'], 'max_seq_length': 128, 'pad_label': '', 'ignore_extra_tokens': False, 'ignore_start_end': False, 'use_cache': True, 'num_workers': 2, 'pin_memory': False, 'drop_last': False}, 'train_ds': {'shuffle': 

In [332]:
class PunctuationDomainDataset(Dataset):

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.
               """
        return {
            "input_ids": NeuralType(('B', 'T'), ChannelType()),
            "attention_mask": NeuralType(('B', 'T'), ChannelType()),
            "labels": NeuralType(('B', 'T'), ChannelType()),
            "domain": NeuralType(('B'), ChannelType()),
        }

    def __init__(self, 
        csv_file:str, 
        tokenizer,
        num_samples:int=256,
        max_seq_length:int=256,
        punct_label_ids: Dict[str, int] = None,
        domain=0,
        labelled=True,
    ):
        if not (os.path.exists(csv_file)):
            raise FileNotFoundError(
                f'{csv_file} not found. The data should be joined in 1 csv file.\
                    Each line of the file contains the subword token ids, masks and class labels per row.'
            )

        data_dir = os.path.dirname(csv_file)
        filename = os.path.basename(csv_file)

        if not filename.endswith('.csv'):
            raise ValueError("{text_file} should have extension .csv")
        # filename = filename[:-4]
        
        self.csv_file = ic(  csv_file)
        self.max_seq_length = ic(  max_seq_length)
        self.set_num_samples(csv_file, num_samples)
        self.domain= ic( domain)
        self.labelled= ic( labelled)
        self.tokenizer= ic( tokenizer)

    def __getitem__(self, idx):
        x = next(
            pd.read_csv(
                self.csv_file,
                skiprows=(idx % self.len)*self.num_samples,
                header=None,
                dtype=str,
                chunksize=self.num_samples,
                ))[1]
        chunked=chunk_examples_with_degree(0)(x)
        batched=chunk_to_len_batch(self.max_seq_length,self.tokenizer,chunked['texts'],chunked['tags'],self.labelled)
        batched['domain']=self.domain*torch.ones(batched['input_ids'].shape[0],1,dtype=torch.long)
        rand=torch.randperm(batched['domain'].size()[0])
        return {k:v[rand] for k,v in batched.items()}
#        {'input_ids': torch.as_tensor(x[:,0,:], dtype=torch.long),
#         'attention_mask': torch.as_tensor(x[:,1,:],dtype=torch.bool)if self.labelled else torch.zeros_like(x[:,1,:],dtype=torch.bool),
#         'labels': torch.as_tensor(x[:,2,:],dtype=torch.long),
#         'domain':self.domain*torch.ones(x.shape[0],1,dtype=torch.long)}

    def set_num_samples(self,csv_file,num_samples):
        self.num_samples = num_samples
        self.total_samples=int(subprocess.Popen(['wc', '-l', csv_file], stdout=subprocess.PIPE, stderr=subprocess.STDOUT).communicate()[0].split()[0])
        self.len = int(self.total_samples / self.num_samples)
        

    def __len__(self):
        return self.len
    
    def view(d)->list:
        """:param d(dictionary): returns readable format of single input_ids and labels in the form of readable text"""
        a,_,c=d.values()
        return [' '.join([_[0]+_[1] for _ in list(zip(self.tokenizer.convert_ids_to_tokens(_[0]),[id2tag[id] for id in _[1].tolist()]))]) for _ in zip(a,c)]
    
    def shuffle(self, sorted=False, seed=42):
        os.system('bash data/shuffle.sh -i {} -o {} -a {} -s {}'.format(self.csv_file, self.csv_file, ['false','true'][sorted], seed))


In [333]:
ds=PunctuationDomainDataset( 
    csv_file=cfg.model.dataset.labelled[0]+'.train.csv', 
    tokenizer=AutoTokenizer.from_pretrained(cfg.model.transformer_path),
    num_samples=16,
    max_seq_length=128,
    punct_label_ids=labels_to_ids,
    domain=0,
    labelled=True,
)
# ds.shuffle(sorted=True)
# ds.shuffle()

ic| csv_file: '/home/nxingyu2/data/ted_talks_processed.train.csv'
ic| max_seq_length: 128
ic| domain: 0
ic| labelled: True
ic| tokenizer: PreTrainedTokenizerFast(name_or_path='/home/nxingyu2/data/electra-base-discriminator', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})


In [352]:
# ds0=ds[0]
ds0['domain'].size()

torch.Size([596, 1])

In [340]:
class PunctuationDomainDatasets(Dataset):
    
    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports. """
        return {
            "input_ids": NeuralType(('B', 'T'), ChannelType()),
            "attention_mask": NeuralType(('B', 'T'), ChannelType()),
            "subtoken_mask": NeuralType(('B', 'T'), ChannelType()),
            "labels": NeuralType(('B', 'T'), ChannelType()),
            "domain": NeuralType(('B'), ChannelType()),
        }

    def __init__(self, 
                 split:str,
                 num_samples:int,
                 max_seq_length:int,
                 punct_label_ids: Dict[str, int],
                 labelled: List[str],
                 unlabelled: List[str],
                 tokenizer):
        
        self.datasets = []
        for i,path in enumerate(labelled):
            self.datasets.append(PunctuationDomainDataset(
                    csv_file=f'{path}.{split}.csv', tokenizer=tokenizer,
                    num_samples=num_samples,max_seq_length=max_seq_length,
                    punct_label_ids=punct_label_ids,domain=i,labelled=True,))
            
        for i,path in enumerate(unlabelled):
            self.datasets.append(PunctuationDomainDataset(
                    csv_file=f'{path}.{split}.csv', tokenizer=tokenizer,
                    num_samples=num_samples,max_seq_length=max_seq_length,
                    punct_label_ids=punct_label_ids,domain=len(labelled)+i,labelled=False,))

    def __getitem__(self, i):
        ds=[d[i] for d in self.datasets]
        min_batch=1000000
        for d in ds:
            size=d['domain'].size()[0]
            if size<min_batch:
                min_batch=size
        #Ensure all domains are evenly represented
        b={k:torch.vstack([d[k][:min_batch] for d in ds]) for k in ['input_ids','attention_mask','subtoken_mask','labels','domain']}
        rand=torch.randperm(b['labels'].size()[0])
        return {k:v[rand] for k,v in b.items()}

    def __len__(self):
        return max(len(d) for d in self.datasets)

In [301]:
cfg.model.dataset

{'data_dir': '${base_path}', 'labelled': ['${base_path}/ted_talks_processed'], 'unlabelled': ['${base_path}/open_subtitles_processed'], 'max_seq_length': 128, 'pad_label': '', 'ignore_extra_tokens': False, 'ignore_start_end': False, 'use_cache': True, 'num_workers': 2, 'pin_memory': False, 'drop_last': False}

In [348]:
dstrain=PunctuationDomainDatasets(
        split='train',
        tokenizer=AutoTokenizer.from_pretrained(cfg.model.transformer_path),
        num_samples=8,
        max_seq_length=128,
        punct_label_ids=labels_to_ids,
        labelled=list(cfg.model.dataset.labelled),
        unlabelled=list(cfg.model.dataset.unlabelled)
    )


ic| csv_file: '/home/nxingyu2/data/ted_talks_processed.train.csv'
ic| max_seq_length: 128
ic| domain: 0
ic| labelled: True
ic| tokenizer: PreTrainedTokenizerFast(name_or_path='/home/nxingyu2/data/electra-base-discriminator', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})
ic| csv_file: '/home/nxingyu2/data/open_subtitles_processed.train.csv'
ic| max_seq_length: 128
ic| domain: 1
ic| labelled: False
ic| tokenizer: PreTrainedTokenizerFast(name_or_path='/home/nxingyu2/data/electra-base-discriminator', vocab_size=30522, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})


In [349]:
dstrain0=dstrain[0]
# dstrain0['input_ids'].shape
# r=torch.randperm(414)
sum(dstrain0['domain'])

tensor([199])

In [None]:
from pytorch_lightning import LightningDataModule
from torch import dtype
from data import PunctuationDomainDataset, PunctuationDomainDatasets
from typing import List
import pandas as pd
import os
import torch
from nemo.utils import logging

class PunctuationDataModule(LightningDataModule):
    def __init__(self, 
            tokenizer,
            labelled: List[str], 
            unlabelled: List[str], 
            train_batch_size: int,
            max_seq_length:int = 256,
            val_batch_size:int = 256, 
            num_workers:int = 1,
            pin_memory:bool = False,
            drop_last:bool = False
            ):
        #unlabelled=[], batch_size = 256, max_seq_length = 256, num_workers=1):
        super().__init__()
        self.labelled=labelled
        self.tokenizer=tokenizer
        self.unlabelled=unlabelled
        self.num_domains=len(labelled)+len(unlabelled)
        self.train_batch_size = max(1,train_batch_size//self.num_domains)
        logging.info(f"using training batch_size of {self.train_batch_size} for each domain")
        self.val_batch_size = max(1,val_batch_size//self.num_domains)
        logging.info(f"using dev batch_size of {self.train_batch_size} for each domain")
        self.max_seq_length = max_seq_length
        self.num_workers=num_workers
        self.pin_memory = pin_memory
        self.drop_last = drop_last

        self.train_dataset={}
        self.dev_dataset={}
        self.test_dataset={}

    def setup(self, stage=None):
        for unlabelled,l in enumerate([self.labelled,self.unlabelled]):
            for i,p in enumerate(l):
                domain=i+unlabelled*len(self.labelled) #unlabelled domain is increasing after labelled
                try:
                    with open("{}.train-stride.csv".format(p),'r') as f:
                        s=len(f.readline().split(' '))//3
                except IOError:
                    s=0
                if (s!=self.max_seq_length):
                    logging.info(f"copying train file from {p}.train-batched.csv to {p}.train-stride.csv")
                    os.system("cp {} {}".format(p+'.train-batched.csv',p+'.train-stride.csv'))
                    if (self.max_seq_length!=256):
                        logging.info(f'generating training strides: {self.max_seq_length}')
                        n=np.loadtxt(open(p+".train-stride.csv", "rb"))
                        np.savetxt(p+".train-stride.csv", self.with_stride_split(n,self.max_seq_length),fmt='%d')

                if stage=='fit' or None:
                    self.train_dataset[domain] = PunctuationDomainDataset(p+'.train-stride.csv', num_samples=self.train_batch_size, max_seq_length=self.max_seq_length, domain = domain, labelled=bool(1-unlabelled), tokenizer=self.tokenizer)
                    self.dev_dataset[domain] =  PunctuationDomainDataset(p+'.dev-batched.csv', num_samples=self.val_batch_size, max_seq_length=self.max_seq_length, domain = domain, labelled=bool(1-unlabelled), tokenizer=self.tokenizer)
                    ic(self.train_dataset[domain].shuffle(sorted=True))
                    ic(self.train_dataset[domain].shuffle())

                if stage == 'test' or stage is None:
                    self.test_dataset[domain] =  PunctuationDomainDataset(p+'.test-batched.csv', num_samples=self.val_batch_size, max_seq_length=self.max_seq_length, domain = domain, labelled=bool(1-unlabelled), tokenizer=self.tokenizer)

    def shuffle(self):
        for dataset in self.train_dataset.values():
            dataset.shuffle()

    def train_dataloader(self):
        return DataLoader(PunctuationDomainDatasets(*self.train_dataset.values()),batch_size=None,num_workers=self.num_workers,pin_memory=self.pin_memory,drop_last=self.drop_last)

    def val_dataloader(self):
        return DataLoader(PunctuationDomainDatasets(*self.dev_dataset.values()),batch_size=None,num_workers=self.num_workers,pin_memory=self.pin_memory,drop_last=self.drop_last)

    def test_dataloader(self):
        return DataLoader(PunctuationDomainDatasets(*self.test_dataset.values()),batch_size=None,num_workers=self.num_workers,pin_memory=self.pin_memory,drop_last=self.drop_last)

    def with_stride_split(n,l):
        def with_stride(t,l):
            a=t[0,0]
            z=t[0,-1]
            t=t[:,1:-1].flatten()
            t=np.trim_zeros(t,'b')
            s=t.shape[0]
            nh=-(-s//(l-2))
            f=np.zeros((nh*(l-2),1))  
            f[:s,0]=t
            return np.hstack([np.ones((nh,1))*a,np.reshape(f,(-1,l-2)),np.ones((nh,1))*z])
        s=n.shape[1]
        a,b,c=n[:,:s//3],n[:,s//3:2*s//3],n[:,2*s//3:]
        a,b,c=with_stride(a,l), with_stride(b,l), with_stride(c,l)
        c1=np.zeros(a.shape)
        c1[:c.shape[0],:]=c
        return np.hstack([a,b,c1])



In [30]:
#helper functions
def flatten(list_of_lists):
    for l in list_of_lists:
        for item in l:
            yield item

def pad_to_len(max_seq_length,ids):
    '''[0, 1, 2] -> array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0])'''
    o=np.zeros(max_seq_length, dtype=np.int)
    o[:len(ids)]=np.array(ids)
    return o

def position_to_mask(max_seq_length:int,indices:list):
    '''[0, 2, 5] -> array([0, 1, 0, 1, 0, 0, 1, 0, 0, 0])'''
    o=np.zeros(max_seq_length,dtype=np.int)
    o[np.array(indices)%(max_seq_length-2)+1]=1
    return o

def align_labels_to_mask(mask,labels):
    '''[0,1,0],[2] -> [0,2,0]'''
    assert(sum(mask)==len(labels))
    mask[mask>0]=torch.tensor(labels)
    return mask.tolist()


In [357]:
class PunctuationInferenceDataset(Dataset):
    """
    Creates dataset to use during inference for punctuation and capitalization tasks with a pretrained model.
    For dataset to use during training with labels, see BertPunctuationCapitalizationDataset.
    Args:
        queries file to sequences, each line should a sentence, no header.
        max_seq_length: max sequence length minus 2 for [CLS] and [SEP]
        tokenizer: such as AutoTokenizer
    """

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.
               """
        return {
            'input_ids': NeuralType(('B', 'T'), ChannelType()),
            'attention_mask': NeuralType(('B', 'T'), MaskType()),
            'subtoken_mask': NeuralType(('B', 'T'), MaskType()),
            "labels": NeuralType(('B', 'T'), ChannelType()),
        }

    def __init__(self, tokenizer, queries: List[str], max_seq_length: int, degree:int=0,):
        """ Initializes BertPunctuationInferDataset. """
        self.degree=degree
        chunked=chunk_examples_with_degree(self.degree)(queries)
        features = chunk_to_len_batch(max_seq_length=max_seq_length, tokenizer=tokenizer,tokens=chunked['texts'],labelled=False)
        self.all_input_ids = ic(features['input_ids'])
        self.all_attention_mask = ic(features['attention_mask'])
        self.all_subtoken_mask = ic(features['subtoken_mask'])

    def __len__(self):
        return len(self.all_input_ids)

    def __getitem__(self, idx):
        return {'input_ids':self.all_input_ids[idx],
            'attention_mask':self.all_attention_mask[idx],
               'subtoken_mask':self.all_subtoken_mask[idx]}


In [27]:
import regex as re
def text2masks(n):
    def text2masks(text):
        '''Converts single paragraph of text into a list of words and corresponding punctuation based on the degree requested.'''
        if n==0: 
            refilter="(?<=[.?!,;:\-—… ])(?=[^.?!,;:\-—… ])|$"
        else:
            refilter="[.?!,;:\-—…]{1,%d}(?= *[^.?!,;:\-—…]+|$)|(?<=[^.?!,;:\-—…]) +(?=[^.?!,;:\-—…])"%(n)
        text=re.sub(r'^[_\W]*','',text)
        word=re.split(refilter,text, flags=re.V1)
        punct=re.findall(refilter,text, flags=re.V1)
        wordlist,punctlist=([] for _ in range(2))
        if word[-1]=='': # ensures text aligns
            word.pop()
        else:
            punct.append('')
        
        for i in zip(word,punct): #+[''] to correspond to the last word or '' after the last punctuation.
            w,p=i[0].strip(),i[1].strip()
            if w!='':
                wordlist.append(re.sub(r'[.?!,;:\-—… ]','',w))
                punctlist.append(0 if not w[-1] in '.?!,;:-—…' else labels_to_ids[w[-1]])
            if p!='':
                wordlist.append(p)
                punctlist.append(0)
        return(wordlist,punctlist)
    return text2masks
def chunk_examples_with_degree(n):
    '''Ensure batched=True if using dataset.map or ensure the examples are wrapped in lists.'''
    def chunk_examples(examples):
        output={}
        output['texts']=[]
        output['tags']=[]
        for sentence in examples:
            text,tag=text2masks(n)(sentence)
            output['texts'].append(text)
            output['tags'].append(tag)
            # output['tags'].append([0]+tag if text[0]!='' else tag) # [0]+tag so that in all case, the first tag refers to [CLS]
            # not necessary since all the leading punctuations are stripped
        return output
    return chunk_examples
assert(chunk_examples_with_degree(0)(['Hello!Bye…'])=={'texts': [['Hello', 'Bye']], 'tags': [[1, 9]]})

def subword_tokenize(tokenizer,tokens):
    subwords = list(map(tokenizer.tokenize, tokens))
    subword_lengths = list(map(len, subwords))
    subwords = list(flatten(subwords))
    token_end_idxs = np.cumsum([0]+subword_lengths[:-1])+np.array(subword_lengths)-1
    return subwords, token_end_idxs

def chunk_to_len(max_seq_length,tokenizer,tokens,labels=None):
    subwords,token_end_idxs = subword_tokenize(tokenizer,tokens)
    teim=token_end_idxs%(max_seq_length-2)
    breakpoints=(np.argwhere(teim[1:]<teim[:-1]).flatten()+1).tolist()
    split_token_end_idxs=np.array_split(token_end_idxs,breakpoints)
    split_subwords=np.array_split(subwords,np.arange(max_seq_length-2,len(subwords),max_seq_length-2))
    ids=[pad_to_len(max_seq_length,tokenizer.convert_tokens_to_ids(['[CLS]']+list(_)+['[SEP]'])) for _ in split_subwords]
    masks=[position_to_mask(max_seq_length,_) for _ in split_token_end_idxs]
    padded_labels=None
    if labels!=None:
        split_labels=np.array_split(labels,breakpoints)
        padded_labels=[pad_to_len(max_seq_length,align_labels_to_mask(*_)) for _ in zip(masks,split_labels)]
    return ids,masks,padded_labels
    
def chunk_to_len_batch(max_seq_length,tokenizer,tokens,labels=None,labelled=True):
    batch_ids=[]
    batch_masks=[]
    batch_labels=[]
    for i,_ in enumerate(zip(tokens,tokens) if labels==None else zip(tokens,labels)):
        a,b,c=chunk_to_len(max_seq_length,tokenizer,*_) if labels else chunk_to_len(max_seq_length,tokenizer,_[0])
        batch_ids.extend(a)
        batch_masks.extend(b)
        if labelled==True:
            batch_labels.extend(c)
    output = {'input_ids': torch.as_tensor(batch_ids, dtype=torch.long),
              'attention_mask': torch.as_tensor(batch_ids, dtype=torch.bool),
              'subtoken_mask': torch.as_tensor(batch_masks,dtype=torch.bool)*labelled}
    output['labels']=torch.as_tensor(batch_labels,dtype=torch.short) if labelled==True else torch.zeros_like(output['input_ids'],dtype=torch.short)
    return output

In [363]:
# # split='train'
# # o=pd.read_csv(f'{cfg.model.dataset.labelled[0]}.{split}.csv',
# #                   dtype='str',
# #                   header=None,
# #                   chunksize=10)
# # t=next(iter(o))
# sample_out = chunk_examples_with_degree(0)(t[1])
# # tokenizer=AutoTokenizer.from_pretrained(cfg.model.transformer_path)
# # subword_tokenize(sample_out['texts'][0])
# sample_out
# sample_out['texts'],sample_out['tags']
# chunk_to_len_batch(1000,tokenizer,sample_out['texts'][:10],sample_out['tags'][:10])
# chunk_examples_with_degree(0)(t[1])
# chunk_examples_with_degree(0)(['!!Hellooooo! Yay! Bye Enddd.',"Hello"])
inferData=PunctuationInferenceDataset(tokenizer=tokenizer, queries=['!!Hellooooo! Yay! Bye Enddd.',"Hello"], max_seq_length=5,degree=1)
inferData[:]

ic| features['input_ids']: tensor shape: torch.Size([5, 5])
ic| features['attention_mask']: tensor shape: torch.Size([5, 5])
ic| features['subtoken_mask']: tensor shape: torch.Size([5, 5])


{'input_ids': tensor([[  101,  7592,  9541,  9541,   102],
         [  101,   999,  8038,  2100,   102],
         [  101,   999,  9061,  2203,   102],
         [  101, 14141,  1012,   102,     0],
         [  101,  7592,   102,     0,     0]]),
 'attention_mask': tensor([[ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True, False],
         [ True,  True,  True, False, False]]),
 'subtoken_mask': tensor([[False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False]])}

In [364]:
inferData=PunctuationInferenceDataset(tokenizer=tokenizer, queries=['!!Hellooooo! Yay! Bye Enddd.',"Hello"], max_seq_length=5,degree=0)
inferData[:]

ic| features['input_ids']: tensor shape: torch.Size([4, 5])
ic| features['attention_mask']: tensor shape: torch.Size([4, 5])
ic| features['subtoken_mask']: tensor shape: torch.Size([4, 5])


{'input_ids': tensor([[  101,  7592,  9541,  9541,   102],
         [  101,  8038,  2100,  9061,   102],
         [  101,  2203, 14141,   102,     0],
         [  101,  7592,   102,     0,     0]]),
 'attention_mask': tensor([[ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True, False],
         [ True,  True,  True, False, False]]),
 'subtoken_mask': tensor([[False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False]])}

In [50]:
class PunctuationInferDataset(Dataset):
    """
    Creates dataset to use during inference for punctuation and capitalization tasks with a pretrained model.
    For dataset to use during training with labels, see BertPunctuationCapitalizationDataset.
    Args:
        queries file to sequences, each line should a sentence, no header.
        max_seq_length: max sequence length minus 2 for [CLS] and [SEP]
        tokenizer: such as AutoTokenizer
    """

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.
               """
        return {
            'input_ids': NeuralType(('B', 'T'), ChannelType()),
            'attention_mask': NeuralType(('B', 'T'), MaskType()),
        }

    def __init__(self, queries: List[str], max_seq_length: int, tokenizer):
        """ Initializes BertPunctuationInferDataset. """
        features = ic(get_features(queries=queries, max_seq_length=max_seq_length, tokenizer=tokenizer))
        self.all_input_ids = ic(features['input_ids'])
        self.all_attention_mask = ic(features['attention_mask'])

    def __len__(self):
        return len(self.all_input_ids)

    def __getitem__(self, idx):
        return {'input_ids':self.all_input_ids[idx],
            'attention_mask':self.all_attention_mask[idx]}

def get_features(
    queries:str, 
    tokenizer,
    max_seq_length:int,
    degree:int=0,
    punct_label_ids: dict = None,):

    batch_ids=[]
    batch_masks=[]
    for query in queries: #
        #'Hellooooo! Yay! Bye Endd.'
        wordlist=ic(re.split('[^a-zA-Z0-9]+',query,flags=re.V1)) #If end with punctuation, this includes a trailing ''
        if wordlist[-1]=='': #Not necessary since the masks would ignore repeated end idxs.
            wordlist=wordlist[:-1] 
        #['Hellooooo', 'Yay', 'Bye', 'Endd', '']
        subwords=ic(list(map(tokenizer.tokenize,wordlist))) # [['hello', '##oo', '##oo'], ['ya', '##y'], ['bye'], ['end', '##d']]
        subword_lengths=ic(list(map(len,subwords))) # [3, 2, 1, 1]
        subwords=ic(list(flatten(subwords))) # ['hello', '##oo', '##oo', 'ya', '##y', 'bye', 'end', '##d']
        token_end_idxs=ic(np.cumsum([0]+subword_lengths[:-1])+np.array(subword_lengths)-1) #'[2 4 5 6]'
        teim=ic(token_end_idxs%(max_seq_length-2)) #'[2 0 1 2]'
        ic(np.argwhere(teim[1:]<teim[:-1]).flatten()) #[0] returns last labels for each chunk.
        split_token_end_idxs=np.array_split(token_end_idxs,(np.argwhere(teim[1:]<teim[:-1]).flatten()+1).tolist())
        #[array([2]), array([4, 5, 6])]
        ic(split_token_end_idxs)
        split_subwords=ic(np.array_split(subwords,np.arange(max_seq_length-2,len(subwords),max_seq_length-2)))
        #[array(['hello', '##oo', '##oo', 'ya'], dtype='<U5'), array(['##y', 'bye', 'end'], dtype='<U5')]
        ids=ic([pad_to_len(max_seq_length,tokenizer.convert_tokens_to_ids(['[CLS]']+list(_)+['[SEP]'])) for _ in split_subwords])
        #[array([ 101, 7592, 9541, 9541, 8038,  102]), array([ 101, 2100, 9061, 2203,  102,    0])]
        masks=ic([position_to_mask(max_seq_length,_) for _ in split_token_end_idxs])
        batch_ids.append(ids) #[[array([ 101, 7592, 9541, 9541, 8038,  102]), array([ 101, 2100, 9061, 2203, 2094,  102])]]
        batch_masks.append(masks) #[[array([0, 0, 0, 1, 0, 0]), array([0, 1, 1, 1, 0, 0])]]
    
    return ic({'input_ids': torch.as_tensor(batch_ids, dtype=torch.long),
            'attention_mask': torch.as_tensor(batch_masks,dtype=torch.bool)})

In [236]:
import regex as re
ifds=PunctuationInferDataset(queries=['Hellooooo! Yay! Bye Enddd.'], max_seq_length=6, tokenizer=AutoTokenizer.from_pretrained(cfg.model.transformer_path))

ic| re.split('[^a-zA-Z0-9]+',query,flags=re.V1): ['Hellooooo', 'Yay', 'Bye', 'Enddd', '']
ic| list(map(tokenizer.tokenize,wordlist)): [['hello', '##oo', '##oo'], ['ya', '##y'], ['bye'], ['end', '##dd']]
ic| list(map(len,subwords)): [3, 2, 1, 2]
ic| list(flatten(subwords)): ['hello', '##oo', '##oo', 'ya', '##y', 'bye', 'end', '##dd']
ic| np.cumsum([0]+subword_lengths[:-1])+np.array(subword_lengths)-1: array shape: (4,)[2 4 5 7]
ic| token_end_idxs%(max_seq_length-2): array shape: (4,)[2 0 1 3]
ic| np.argwhere(teim[1:]<teim[:-1]).flatten(): array shape: (1,)[0]
ic| split_token_end_idxs: [array([2]), array([4, 5, 7])]
ic| np.array_split(subwords,np.arange(max_seq_length-2,len(subwords),max_seq_length-2)): [array(['hello', '##oo', '##oo', 'ya'], dtype='<U5'), array(['##y', 'bye', 'end', '##dd'], dtype='<U5')]
ic| [pad_ids_to_len(max_seq_length,tokenizer.convert_tokens_to_ids(['[CLS]']+list(_)+['[SEP]'])) for _ in split_subwords]: [array([ 101, 7592, 9541, 9541, 8038,  102]), array([  101,  

In [36]:
def view_aligned(texts,tags,tokenizer,labels_to_ids):
        return [re.sub(' ##','',' '.join([_[0]+_[1] for _ in list(zip(tokenizer.convert_ids_to_tokens(ic(_[0])),
                                                      [labels_to_ids[id] for id in _[1].tolist()
                                                      ]
                                                     )
                                                 )
                         ]
                        )
                      ) for _ in zip(texts,tags)
               ]

In [37]:
# labels_to_ids = {_[1]:_[0] for _ in enumerate(cfg.model.punct_label_ids)}
ids_to_labels = {_[0]:_[1] for _ in enumerate(cfg.model.punct_label_ids)}
t=chunk_examples_with_degree(0)(['Hellooooo!Bye…'])
tokenizer=AutoTokenizer.from_pretrained(cfg.model.transformer_path)
t=chunk_to_len_batch(5,tokenizer,t['texts'],t['tags'])
view_aligned(t['input_ids'],np.array(t['labels']),tokenizer,ids_to_labels)

ic| _[0]: tensor shape: torch.Size([5])
ic| _[0]: tensor shape: torch.Size([5])


['[CLS] hellooooo! [SEP]', '[CLS] bye… [SEP] [PAD] [PAD]']

In [188]:
def flatten(list_of_lists):
    for l in list_of_lists:
        for item in l:
            yield item
list(flatten([[0],[0,1],[0,1,2],[],[],[1]]))

[0, 0, 1, 0, 1, 2, 1]