In [1]:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()
torch.rand(10, device=device)
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

Using device: cuda

GeForce GTX 1080 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [2]:
import hydra
import numpy as np
import pytorch_lightning as pl
import torch
import snoop
from omegaconf import DictConfig, OmegaConf
from transformers import AutoTokenizer

from data import PunctuationDataModule, PunctuationInferenceDataset

snoop.install()

from hydra.experimental import initialize, initialize_config_module, initialize_config_dir, compose
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize()
cfg=compose(
    config_name="config.yaml", 
)
cfg.model.punct_label_ids=OmegaConf.create(sorted(cfg.model.punct_label_ids))
labels_to_ids = {_[1]:_[0] for _ in enumerate(cfg.model.punct_label_ids)}
ids_to_labels = {_[0]:_[1] for _ in enumerate(cfg.model.punct_label_ids)}
cfg


{'seed': 42, 'trainer': {'gpus': 1, 'num_nodes': 1, 'max_epochs': 10, 'max_steps': None, 'accumulate_grad_batches': 4, 'gradient_clip_val': 0, 'amp_level': 'O1', 'precision': 16, 'accelerator': 'ddp', 'checkpoint_callback': False, 'logger': False, 'log_every_n_steps': 1, 'val_check_interval': 1.0, 'resume_from_checkpoint': None}, 'exp_manager': {'exp_dir': '/home/nxingyu/project/', 'name': 'Punctuation_with_Domain_discriminator', 'create_tensorboard_logger': True, 'create_checkpoint_callback': True}, 'base_path': '/home/nxingyu/data', 'tmp_path': '/home/nxingyu/data/tmp', 'model': {'nemo_path': None, 'transformer_path': 'google/electra-small-discriminator', 'unfrozen': 0, 'maximum_unfrozen': 1, 'unfreeze_step': 1, 'punct_label_ids': ['', '!', ',', '-', '.', ':', ';', '?', '—', '…'], 'punct_class_weights': False, 'dataset': {'data_dir': '/home/nxingyu/data', 'labelled': ['${base_path}/ted_talks_processed'], 'unlabelled': ['${base_path}/open_subtitles_processed'], 'max_seq_length': 128, 

In [3]:
from core.losses import FocalDiceLoss, CrossEntropyLoss, LinearChainCRF, AggregatorLoss, FocalLoss

# inp = pp(torch.randn(3, 4, 5, requires_grad=True))
# tar = pp(torch.empty(3, 4, dtype=torch.long).random_(5))
inp = torch.tensor([[[0,1,0],[0,0,1],[0,0,1],[1,0,0]]],dtype=torch.float, requires_grad=True)
tar = torch.tensor([[0,1,2,0]],dtype=torch.long)
mask=torch.tensor([[1,1,1,1]],dtype=torch.bool)

loss = CrossEntropyLoss(reduction='none')
output = loss(inp, tar)
# output.backward()
pp('cel none',output)

loss = CrossEntropyLoss(reduction='mean')
output = loss(inp, tar)
# output.backward()
pp('cel mean',output)

loss = CrossEntropyLoss(reduction='sum')
output = loss(inp, tar)
# output.backward()
pp('cel sum',output)

loss = FocalLoss(reduction='sum')
output = loss(inp, tar)
# output.backward()
pp('focal sum',loss(inp, tar))

loss = FocalLoss(reduction='mean')
output = loss(inp, tar)
# output.backward()
pp('focal mean',loss(inp, tar))

loss = FocalLoss(reduction='none')
output = loss(inp, tar)
# output.backward()
pp('focal none',loss(inp, tar))

loss = FocalLoss(reduction='none', gamma=5)
output = loss(inp, tar)
# output.backward()
pp('focal none',loss(inp, tar))

loss = LinearChainCRF(num_labels=3,reduction='none')
output = loss(inp, tar,mask)
# output.backward()
pp('crf,none',output)

loss = LinearChainCRF(num_labels=3,reduction='mean')
output = loss(inp, tar,mask)
# output.backward()
pp('crf,mean',output)

loss = LinearChainCRF(num_labels=3,reduction='sum')
output = loss(inp, tar,mask)
# output.backward()
pp('crf,sum',output)

loss = LinearChainCRF(num_labels=3,reduction='token_mean')
output = loss(inp, tar,mask)
# output.backward()
pp('crf,token_mean',output)

loss = FocalDiceLoss(reduction='none', macro_average=False)
output = loss(inp, tar)
# output.backward()
pp('dice none,micro',output)

loss = FocalDiceLoss(reduction='mean', macro_average=False)
output = loss(inp, tar)
# output.backward()
pp('dice mean,micro',output)

loss = FocalDiceLoss(reduction='sum', macro_average=False)
output = loss(inp, tar)
# output.backward()
pp('dice sum,micro',output)

loss = FocalDiceLoss(reduction='sum', alpha=3, macro_average=False)
output = loss(inp, tar)
# output.backward()
pp('dice sum,micro',output)

loss = FocalDiceLoss(reduction='none',macro_average=True,alpha=5.0, log_softmax=True)
output = loss(inp, tar)
# output.backward()
pp('dice none,macro',loss(inp, tar))

loss = FocalDiceLoss(reduction='mean',macro_average=True)
output = loss(inp, tar)
# output.backward()
pp('dice mean,macro',loss(inp, tar))

loss = FocalDiceLoss(reduction='sum',macro_average=True)
output = loss(inp, tar)
# output.backward()
pp('dice sum,macro',loss(inp, tar))

loss = FocalDiceLoss(reduction='none',macro_average=True,alpha=1.0)
output = loss(inp, tar)
# output.backward()
pp('dice none,macro',loss(inp, tar))

loss = FocalDiceLoss(reduction='none',macro_average=True, alpha=3)
output = loss(inp, tar)
# output.backward()
pp('dice none,macro',loss(inp, tar))

inp = torch.tensor([[[0,1,0],[1,0,1],[0,0,1],[0,1,0]]],dtype=torch.float, requires_grad=True)
tar = torch.tensor([[0,1,2,0]],dtype=torch.long)
mask=torch.tensor([[1,1,1,1]],dtype=torch.bool)

loss = FocalDiceLoss(reduction='none', alpha=1, macro_average=True)
output = loss(inp, tar)
# output.backward()
pp('dice sum,macro',output)

loss = FocalDiceLoss(reduction='none', alpha=3, macro_average=True)
output = loss(inp, tar)
# output.backward()
pp('dice sum,macro',output)

10:11:13.98 LOG:
10:11:14.02 .... 'cel none' = 'cel none'
10:11:14.02 .... output = tensor([1.5514, 1.5514, 0.5514, 0.5514], grad_fn=<NllLossBackward>)
10:11:14.02 LOG:
10:11:14.02 .... 'cel mean' = 'cel mean'
10:11:14.02 .... output = tensor(1.0514, grad_fn=<NllLossBackward>)
10:11:14.03 LOG:
10:11:14.03 .... 'cel sum' = 'cel sum'
10:11:14.03 .... output = tensor(4.2058, grad_fn=<NllLossBackward>)
10:11:14.08 LOG:
10:11:14.08 .... 'focal sum' = 'focal sum'
10:11:14.08 .... loss(inp, tar) = tensor(6.7352, grad_fn=<SumBackward0>)
10:11:14.08 LOG:
10:11:14.08 .... 'focal mean' = 'focal mean'
10:11:14.09 .... loss(inp, tar) = tensor(0.4210, grad_fn=<MeanBackward0>)
10:11:14.09 LOG:
10:11:14.09 .... 'focal none' = 'focal none'
10:11:14.10 .... loss(inp, tar) = tensor([0.9635, 0.9635, 0.0991, 0.0991], grad_fn=<MulBackward0>)
10:11:14.10 LOG:
10:11:14.10 .... 'focal none' = 'focal none'
10:11:14.10 .... loss(inp, tar) = tensor([0.4716, 0.4716, 0.0075, 0.0075], grad_fn=<MulBackward0>)
10:11:1

('dice sum,macro', tensor([0.2148, 0.4559, 0.0140], grad_fn=<MulBackward0>))

In [4]:
data_config = cfg.model.dataset
cfg.model.punct_label_ids=OmegaConf.create(sorted(cfg.model.punct_label_ids))
labels_to_ids = {_[1]:_[0] for _ in enumerate(cfg.model.punct_label_ids)}
data_config.num_labels=len(cfg.model.punct_label_ids)
data_config.labelled = OmegaConf.create([] if data_config.labelled==None else data_config.labelled)
data_config.unlabelled = OmegaConf.create([] if data_config.unlabelled==None else data_config.unlabelled)
data_config.num_domains = len(data_config.labelled)+len(data_config.unlabelled)
dm=PunctuationDataModule(
    tokenizer= cfg.model.transformer_path,
    labelled= list(data_config.labelled),
    unlabelled= list(data_config.unlabelled),
    punct_label_ids= labels_to_ids,
    train_batch_size= data_config.train_ds.batch_size,
    max_seq_length= data_config.max_seq_length,
    val_batch_size= data_config.validation_ds.batch_size,
    num_workers= data_config.num_workers,
    pin_memory= data_config.pin_memory,
    train_shuffle= data_config.train_ds.shuffle,
    val_shuffle= data_config.validation_ds.shuffle,
    seed=cfg.seed,
    data_id='0'
)
dm.setup()
len(dm.train_dataset)

KeyboardInterrupt: 

In [None]:
# it=dm.train_dataset
# ni=next(it)
# it=dm.train_dataset.datasets[0]
dm.train_dataset.__len__()#determine_class_weights()
# ct=torch.zeros(10)
# for _ in range(64):
#     print('.',end='')
#     ni=next(it)
#     ct+=torch.bincount(ni['labels'].view(-1))
# return ct/sum(ct)

In [19]:
# from core.optim import get_optimizer
optimizer=get_optimizer('adam')
optimizer=optimizer([torch.tensor(5)],**{'lr':0.01})

In [3]:
loss = FocalDiceLoss(reduction='mean',macro_average=True)
output1 = loss(inp, tar)
loss1 = LinearChainCRF(num_labels=5,reduction='token_mean')
output2 = loss1(inp, tar)
agg=AggregatorLoss(num_inputs=2,weights=[0.5,1])
agg(a=output1,b=output2)

tensor(2.4794, grad_fn=<AddBackward0>)

In [9]:
torch.nn.functional.one_hot(LinearChainCRF(num_labels=5,reduction='sum').decode(inp).flatten().long(),5)

# inp

tensor([[1, 0, 0, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 1, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 1, 0, 0]])

In [77]:
dm=PunctuationDataModule(
            tokenizer= cfg.model.transformer_path,
            labelled= list(cfg.model.dataset.labelled),
            unlabelled= list(cfg.model.dataset.unlabelled),
            punct_label_ids= labels_to_ids,
            train_batch_size= cfg.model.train_ds.batch_size,
            max_seq_length= cfg.model.dataset.max_seq_length,
            val_batch_size= cfg.model.validation_ds.batch_size,
            num_workers= cfg.model.dataset.num_workers,
            pin_memory= cfg.model.dataset.pin_memory,
            train_shuffle= cfg.model.train_ds.shuffle,
            val_shuffle= cfg.model.validation_ds.shuffle,
)
dm.setup('fit')
dl=dm.train_dataloader()

In [79]:
next(iter(dl))

{'input_ids': tensor([[  101,  1005,  1056,  ...,  2748,  2008,   102],
         [  101,  3818,  2000,  ...,  2202,  2115,   102],
         [  101,  2599, 19366,  ...,  8491, 23161,   102],
         ...,
         [  101,  2037, 15451,  ...,  2035,  1997,   102],
         [  101,  2041,  1997,  ...,  1997, 15451,   102],
         [  101,  2028,  2182,  ...,  2518,  7910,   102]]),
 'attention_mask': tensor([[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]),
 'subtoken_mask': tensor([[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False,  True, False],
         ...,
         [False, False, False,  ..., False, False, False]

In [10]:
# from transformers import AutoModel
# transformer=AutoModel.from_pretrained('google/electra-small-discriminator')
transformer.encoder.layer.__len__()

12

In [59]:
# torch.cat([torch.ones(10,1),torch.ones(10,1)],dim=-2)
# torch.cat([torch.ones(50,25,10),torch.ones(50,25,10)],dim=0).shape

torch.Size([100, 25, 10])

In [38]:
# import transformers
# t=transformers.AutoModel.from_pretrained('distilbert-base-uncased')
t.embeddings

Embeddings(
  (word_embeddings): Embedding(30522, 768, padding_idx=0)
  (position_embeddings): Embedding(512, 768)
  (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [None]:

class PunctuationDomainModel(pl.LightningModule):

    @property
    def input_types(self) -> Optional[Dict[str, NeuralType]]:
        return {
            "input_ids": NeuralType(('B', 'T'), ChannelType()),
            "attention_mask": NeuralType(('B', 'T'), ChannelType()),
            "subtoken_mask": NeuralType(('B', 'T'), ChannelType()),
            "labels": NeuralType(('B', 'T'), ChannelType()),
            "domain": NeuralType(('B'), ChannelType()),
        }
    
    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        return {
            "punct_logits": NeuralType(('B', 'T', 'D'), LogitsType()),
            "domain_logits": NeuralType(('B', 'D'), LogitsType()),
        }

    def __init__(self, cfg: DictConfig): 
        # num_labels: int = 10, 
        # embedding_dim: int = 768, 
        # lossfn: str = '', 
        # hidden_dropout_prob:float=0.1, 
        # base_model_path:str='google/electra-base-discriminator', 
        # reduction:str='mean',
        # stride:int=256,
        # unfrozen_layers=0,
        # alpha='0.8',
        # gamma='2',
        # lbd=1, #coefficient of gradient reversal.
        # domains: int = 1):
        # self.setup_tokenizer(cfg.tokenizer)
        super().__init__(label)

        self._cfg.punct_label_ids=OmegaConf.create(sorted(self._cfg.punct_label_ids))
        self.labels_to_ids = {_[0]:_[1] for _ in enumerate(self._cfg.punct_label_ids)}
        self.ids_to_labels = {_[1]:_[0] for _ in enumerate(self._cfg.punct_label_ids)}
        self.num_domains = len(self._cfg.dataset.labelled)+len(self._cfg.dataset.unlabelled)
        
        self.bert_model = get_lm_model(
            pretrained_model_name=cfg.language_model.pretrained_model_name,
            config_file=cfg.language_model.config_file,
            config_dict=OmegaConf.to_container(cfg.language_model.config) if cfg.language_model.config else None,
            checkpoint_file=cfg.language_model.lm_checkpoint,
        )

        self.punct_classifier = TokenClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=len(self._cfg.punct_label_ids),
            activation=cfg.punct_head.activation,
            log_softmax=False,
            dropout=cfg.punct_head.fc_dropout,
            num_layers=cfg.punct_head.punct_num_fc_layers,
            use_transformer_init=cfg.punct_head.use_transformer_init,
        )

        self.domain_classifier = SequenceClassifier(
            hidden_size=self.bert_model.config.hidden_size,
            num_classes=self.num_domains,
            num_layers=cfg.domain_head.domain_num_fc_layers,
            activation=cfg.domain_head.activation,
            log_softmax=False,
            dropout=cfg.domain_head.fc_dropout,
            use_transformer_init=cfg.domain_head.use_transformer_init,
        )

        self.punctuation_loss = CrossEntropyLoss(logits_ndim=3)
        self.domain_loss = CrossEntropyLoss(logits_ndim=2)
        self.agg_loss = AggregatorLoss(num_inputs=2)

        self.punct_class_report = ClassificationReport(
            num_classes=len(self._cfg.punct_label_ids),
            label_ids=self.labels_to_ids,
            mode='macro',
            dist_sync_on_step=True,
        )
        self.domain_class_report = ClassificationReport(
            num_classes=self.num_domains,
            label_ids=list(range(self.num_domains)),
            mode='macro',
            dist_sync_on_step=True,
        )


    @typecheck()
    def forward(self, input_ids, attention_mask, token_type_ids=None, domain_ids=None):
        """
        No special modification required for Lightning, define it as you normally would
        in the `nn.Module` in vanilla PyTorch.
        """
        hidden_states = self.bert_model(
            input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
        )
        punct_logits = self.punct_classifier(hidden_states=hidden_states)
        reverse_grad_hidden_states = self.grad_reverse.apply(hidden_states)
        domain_logits = self.domain_classifier(hidden_states=reverse_grad_hidden_states)
        return punct_logits, domain_logits

    def _make_step(self, batch):
        input_ids=batch['input_ids']
        attention_mask=batch['attention_mask']
        punct_labels=batch['labels']
        domain_labels=batch['domain'][:,0,:]
        # input_ids, input_type_ids, input_mask, subtokens_mask, loss_mask, punct_labels, domain_labels = batch
        punct_logits, domain_logits = self(
            input_ids=input_ids, attention_mask=attention_mask
        )

        punct_loss = self.punct_loss(logits=punct_logits, labels=punct_labels, loss_mask=attention_mask)
        domain_loss = self.domain_loss(logits=domain_logits, labels=domain_labels)
        loss = self.agg_loss(loss_1=punct_loss, loss_2=domain_loss)
        return loss, punct_logits, domain_logits
    
    def training_step(self, batch, batch_idx):
        """
        Lightning calls this inside the training loop with the data from the training dataloader
        passed in as `batch`.
        """
        loss, _, _ = self._make_step(batch)
        lr = self._optimizer.param_groups[0]['lr']

        self.log('lr', lr, prog_bar=True)
        self.log('train_loss', loss)

        return {'loss': loss, 'lr': lr}
    
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        input_ids=batch['input_ids']
        attention_mask=batch['attention_mask']
        punct_labels=batch['labels']
        domain_labels=batch['domain'][:,0,:]

        val_loss, punct_logits, domain_logits = self._make_step(batch)

        # attention_mask = attention_mask > 0.5
        punct_preds = torch.argmax(punct_logits, axis=-1)[attention_mask]
        punct_labels = punct_labels[attention_mask]
        self.punct_class_report.update(punct_preds, punct_labels)

        domain_preds = torch.argmax(domain_logits, axis=-1)[attention_mask]
        domain_labels = domain_labels[attention_mask]
        self.domain_class_report.update(domain_preds, domain_labels)

        return {
            'val_loss': val_loss,
            'punct_tp': self.punct_class_report.tp,
            'punct_fn': self.punct_class_report.fn,
            'punct_fp': self.punct_class_report.fp,
            'domain_tp': self.domain_class_report.tp,
            'domain_fn': self.domain_class_report.fn,
            'domain_fp': self.domain_class_report.fp,
        }

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """
        Lightning calls this inside the validation loop with the data from the validation dataloader
        passed in as `batch`.
        """
        input_ids=batch['input_ids']
        attention_mask=batch['attention_mask']
        punct_labels=batch['labels']
        domain_labels=batch['domain'][:,0,:]

        test_loss, punct_logits, domain_logits = self._make_step(batch)

        # attention_mask = attention_mask > 0.5
        punct_preds = torch.argmax(punct_logits, axis=-1)[attention_mask]
        punct_labels = punct_labels[attention_mask]
        self.punct_class_report.update(punct_preds, punct_labels)

        domain_preds = torch.argmax(domain_logits, axis=-1)[attention_mask]
        domain_labels = domain_labels[attention_mask]
        self.domain_class_report.update(domain_preds, domain_labels)

        return {
            'test_loss': test_loss,
            'punct_tp': self.punct_class_report.tp,
            'punct_fn': self.punct_class_report.fn,
            'punct_fp': self.punct_class_report.fp,
            'domain_tp': self.domain_class_report.tp,
            'domain_fn': self.domain_class_report.fn,
            'domain_fp': self.domain_class_report.fp,
        }
    
    def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
        """
        Called at the end of validation to aggregate outputs.
        outputs: list of individual outputs of each validation step.
        """
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        # calculate metrics and log classification report for Punctuation task
        punct_precision, punct_recall, punct_f1, punct_report = self.punct_class_report.compute()
        logging.info(f'Punctuation report: {punct_report}')

        # calculate metrics and log classification report for domainalization task
        domain_precision, domain_recall, domain_f1, domain_report = self.domain_class_report.compute()
        logging.info(f'Domain report: {domain_report}')

        self.log('val_loss', avg_loss, prog_bar=True)
        self.log('punct_precision', punct_precision)
        self.log('punct_f1', punct_f1)
        self.log('punct_recall', punct_recall)
        self.log('domain_precision', domain_precision)
        self.log('domain_f1', domain_f1)
        self.log('domain_recall', domain_recall)

    def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
        """
            Called at the end of test to aggregate outputs.
            outputs: list of individual outputs of each validation step.
        """
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()

        # calculate metrics and log classification report for Punctuation task
        punct_precision, punct_recall, punct_f1, punct_report = self.punct_class_report.compute()
        logging.info(f'Punctuation report: {punct_report}')

        # calculate metrics and log classification report for domainalization task
        domain_precision, domain_recall, domain_f1, domain_report = self.domain_class_report.compute()
        logging.info(f'Domain report: {domain_report}')

        self.log('test_loss', avg_loss, prog_bar=True)
        self.log('punct_precision', punct_precision)
        self.log('punct_f1', punct_f1)
        self.log('punct_recall', punct_recall)
        self.log('domain_precision', domain_precision)
        self.log('domain_f1', domain_f1)
        self.log('domain_recall', domain_recall)
    
    def update_data_dir(self, data_dir: str) -> None:
        """
        Update data directory
        Args:
            data_dir: path to data directory
        """
        if os.path.exists(data_dir):
            logging.info(f'Setting model.dataset.data_dir to {data_dir}.')
            self._cfg.dataset.data_dir = data_dir
        else:
            raise ValueError(f'{data_dir} not found')

    def setup_datamodule(self, cfg: Optional[DictConfig] = None):
        if cfg is None:
            cfg = self._cfg.train_ds
            
        self.data_module = PunctuationDataModule(
            labelled=list(cfg.model.dataset.labelled),
            unlabeled=list(cfg.model.dataset.unlabelled),
            train_batch_size=cfg.model.train_ds.batch_size,
            val_batch_size=cfg.model.validation_ds.batch_size,
            max_seq_length=self._cfg.max_seq_length,
            num_workers=cfg.model.dataset.num_workers,
            pin_memory=cfg.model.dataset.pin_memory,
            drop_last=cfg.model.dataset.drop_last,
            tokenizer=self.tokenizer,
        )
        self._train_dl=self.data_module.train_dataloader
        self._validation_dl=self.data_module.dev_dataloader
        self._test_dl=self.data_module.test_dataloader
    
    def _setup_infer_dataloader(self, queries: List[str], batch_size: int) -> 'torch.utils.data.DataLoader':
        """
        Setup function for a infer data loader.
        Args:
            queries: lower cased text without punctuation
            batch_size: batch size to use during inference
        Returns:
            A pytorch DataLoader.
        """

        dataset = BertPunctuationInferDataset(
            tokenizer=self.tokenizer, queries=queries, max_seq_length=self._cfg.dataset.max_seq_length
        )

        return torch.utils.data.DataLoader(
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            batch_size=batch_size,
            shuffle=False,
            num_workers=self._cfg.dataset.num_workers,
            pin_memory=self._cfg.dataset.pin_memory,
            drop_last=False,
        )

    def add_punctuation_capitalization(self, queries: List[str], batch_size: int = None) -> List[str]:
        """
        Adds punctuation and capitalization to the queries. Use this method for debugging and prototyping.
        Args:
            queries: Text
            batch_size: batch size to use during inference
        Returns:
            result: text with punctuation
        """

        if queries is None or len(queries) == 0:
            return []

        if batch_size is None:
            batch_size = len(queries)
            logging.info(f'Using batch size {batch_size} for inference')

        # We will store the output here
        result = []

        # Model's mode and device
        mode = self.training
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        try:
            # Switch model to evaluation mode
            self.eval()
            self = self.to(device)
            infer_datalayer = self._setup_infer_dataloader(queries, batch_size)

            # store predictions for all queries in a single list
            all_punct_preds = []

            for batch in infer_datalayer:
                input_ids = batch['input_ids']
                attention_mask = batch['attention_mask']

                punct_logits, _ = self.forward(
                    input_ids=input_ids.to(device),
                    attention_mask=input_mask.to(device),
                )
                punct_preds = tensor2list(torch.argmax(punct_logits, axis=-1)[subtokens_mask])
                all_punct_preds.extend(punct_preds)
            id2tag = {v: k for k, v in self._cfg.punct_label_ids.items()}
            result.extend([' '.join([_[0]+_[1] for _ in \
                list(zip(self.tokenizer.convert_ids_to_tokens(_[0]),
                            [id2tag[id] for id in _[1].tolist()])
                    )]) for _ in zip(infer_datalayer['input_ids'],all_punct_preds)])
        finally:
            # set mode back to its original value
            self.train(mode=mode)
        return result


#%%
'''
        self.num_labels=num_labels
        self.embedding_dim=embedding_dim
        self.domains = domains
        self.reduction=reduction
        self.unfrozen_layers=unfrozen_layers
        self.alpha=alpha
        self.gamma=gamma
        self.lossfn=lossfn
        self.stride=stride
        self.grad_reverse=GradientReverse
        self.grad_reverse.scale=lbd
        self.dropout = torch.nn.Dropout(hidden_dropout_prob)
        self.transformer = transformers.ElectraModel.from_pretrained(base_model_path)
        self.freeze()
        self.fcl = torch.nn.Linear(self.embedding_dim, self.num_labels)
        if lossfn == 'crf':
            self.loss=DiceCRF(self.num_labels,reduction=self.reduction)
        elif lossfn == 'dice':
            self.loss=DiceLoss(gamma=self.gamma,alpha=self.alpha, num_classes=self.num_labels, reduction=self.reduction)
        else:
            self.loss=CrossEntropyLoss(reduction=self.reduction)
        if self.domains>1:
            self.domainfcl=torch.nn.Linear(self.embedding_dim, self.domains)
            self.domain_loss=CrossEntropyLoss(reduction=self.reduction, punct_classifier=False)
            self.agg_loss=AggregatorLoss(weights=[1,0.5])
            
        self.punct_class_report = ClassificationReport(
            num_classes=self.num_labels,
            label_ids={'': 0, '!': 1, ',': 2, '-': 3, '.': 4, ':': 5, ';': 6, '?': 7, '—': 8, '…': 9},
            mode='macro',
            dist_sync_on_step=True,
        )
        if self.domains>1:
            self.domain_class_report = ClassificationReport(
                num_classes=self.domains,
                mode='macro',
                dist_sync_on_step=True)
        
        

    def forward(self, x):
        o1 = self.transformer(x['input_ids'],x['attention_mask'])[0]
        d1 = self.dropout(o1)
        p = self.fcl(d1)
        if self.domains>1: ##relook
            d1r= self.grad_reverse.apply(d1)
            d= self.domainfcl(d1r[:,0,:])
            return p, d
        return p

    def _make_step(self, batch):
        punct_logits, domain_logits = self(batch)
        print('make_step',punct_logits.shape,domain_logits.shape)
        punct_loss = self.loss(punct_logits, batch['labels'], batch['attention_mask'])
        if self.domains>1:
            domain_loss = self.domain_loss(domain_logits, batch['domain'])
        loss = punct_loss if self.domains==1 else self.agg_loss(loss_1=punct_loss, loss_2=domain_loss)
        return loss, punct_logits, domain_logits

    def training_step(self, batch, batch_idx):
        loss,_,_ = self._make_step(batch)
        lr = self._optimizer.param_groups[0]['lr']
        self.log('lr', lr, prog_bar=True)
        self.log('train_loss', loss)
        return {'loss': loss, 'lr': lr}

    def validation_step(self, batch, batch_idx):
        val_loss, punct_logits, domain_logits = self._make_step(batch)
        punct_preds = F.one_hot(self.loss.decode(punct_logits, batch['attention_mask']).flatten(),self.num_labels).to(device) if self.lossfn=='crf' else punct_logits.view(-1,self.num_labels)
        punct_labels = F.one_hot(batch['labels'].flatten(),self.num_labels)
        print('punct pred, labels',punct_preds.shape,punct_labels.shape)
        self.punct_class_report.update(punct_preds, punct_labels)
        if self.domains>1:
            domain_labels=F.one_hot(batch['domain'].flatten(),self.domains)
            domain_preds = domain_logits.view(-1,self.domains)
            print('domain pred,label,logits',domain_preds.shape,domain_labels.shape, domain_logits.shape)
            self.domain_class_report.update(domain_preds, domain_labels)
            return {
                'val_loss': val_loss,
                'punct_tp': self.punct_class_report.tp,
                'punct_fn': self.punct_class_report.fn,
                'punct_fp': self.punct_class_report.fp,
                'domain_tp': self.domain_class_report.tp,
                'domain_fn': self.domain_class_report.fn,
                'domain_fp': self.domain_class_report.fp,
            }
        return {
            'val_loss': val_loss,
            'punct_tp': self.punct_class_report.tp,
            'punct_fn': self.punct_class_report.fn,
            'punct_fp': self.punct_class_report.fp,
        }

    def test_step(self, batch, batch_idx):
        test_loss, punct_logits, domain_logits = self._make_step(batch)
        punct_preds = F.one_hot(self.loss.decode(punct_logits, batch['attention_mask']).flatten(),self.num_labels).to(device) if self.loss_fn=='crf' else punct_logits.view(-1,self.num_labels)
        punct_labels = F.one_hot(batch['labels'].flatten(),self.num_labels)
        self.punct_class_report.update(punct_preds, punct_labels)
        if self.domains>1:
            domain_labels=F.one_hot(batch['domain'],self.domains)
            domain_preds = domain_logits.view(-1,self.domains)
            self.domain_class_report.update(domain_preds, domain_labels)
            return {
                'test_loss': test_loss,
                'punct_tp': self.punct_class_report.tp,
                'punct_fn': self.punct_class_report.fn,
                'punct_fp': self.punct_class_report.fp,
                'domain_tp': self.domain_class_report.tp,
                'domain_fn': self.domain_class_report.fn,
                'domain_fp': self.domain_class_report.fp,
            }
        return {
                'test_loss': test_loss,
                'punct_tp': self.punct_class_report.tp,
                'punct_fn': self.punct_class_report.fn,
                'punct_fp': self.punct_class_report.fp,
            }
    #https://github.com/NVIDIA/NeMo/blob/bb86f88143c89231f970e5b6bd9f78999fc45a90/nemo/collections/nlp/models/token_classification/punctuation_domainalization_model.py#L42
    def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()

        # calculate metrics and log classification report for Punctuation task
        punct_precision, punct_recall, punct_f1, punct_report = self.punct_class_report.compute()
        logging.info(f'Punctuation report: {punct_report}')
        self.log('val_loss', avg_loss, prog_bar=True)
        self.log('punct_precision', punct_precision)
        self.log('punct_f1', punct_f1)
        self.log('punct_recall', punct_recall)
        if self.domains>1:
            # calculate metrics and log classification report for domainalization task
            domain_precision, domain_recall, domain_f1, domain_report = self.domain_class_report.compute()
            logging.info(f'Domain report: {domain_report}')
            self.log('domain_precision', domain_precision)
            self.log('domain_f1', domain_f1)
            self.log('domain_recall', domain_recall)
    def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        # calculate metrics and log classification report for Punctuation task
        punct_precision, punct_recall, punct_f1, punct_report = self.punct_class_report.compute()
        logging.info(f'Punctuation report: {punct_report}')
        # calculate metrics and log classification report for domainalization task
        self.log('test_loss', avg_loss, prog_bar=True)
        self.log('punct_precision', punct_precision)
        self.log('punct_f1', punct_f1)
        self.log('punct_recall', punct_recall)
        if self.domains>1:
            domain_precision, domain_recall, domain_f1, domain_report = self.domain_class_report.compute()
            logging.info(f'Domain report: {domain_report}')
            self.log('domain_precision', domain_precision)
            self.log('domain_f1', domain_f1)
            self.log('domain_recall', domain_recall)
        
    def freeze_transformer_to(self, n:int, exclude_types=(torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)) -> None:
        """Freeze layers up to layer group `n`.
        Look at each group, and freeze each paraemeter, except excluded types
        """
        print(f"freeze 1st {n} encoder layers of transformer")
        def set_requires_grad_for_module(module: torch.nn.Module, requires_grad: bool):
            "Sets each parameter in lthe module to the `requires_grad` value"
            params = list(module.parameters())
            for param in params: 
                param.requires_grad = requires_grad
            
        for layer in list(self.transformer.encoder.layer)[:n]:
            if not isinstance(layer, exclude_types): 
                set_requires_grad_for_module(layer, False)
        
        for layer in list(self.transformer.encoder.layer)[n:]:
            set_requires_grad_for_module(layer, True)

    def freeze(self) -> None:
        for param in self.transformer.embeddings.parameters():
            param.requires_grad=False

        self.frozen=len(self.transformer.encoder.layer)
        self.freeze_transformer_to(self.frozen)

    def unfreeze(self,i:int=1):
        self.freeze_transformer_to(max(0,self.frozen-i))
        self.frozen-=1;

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
'''


# DataLoading

In [2]:
from torch.utils.data import Dataset

from nemo.core.neural_types import ChannelType, LabelsType, MaskType, NeuralType
from transformers import AutoTokenizer
import numpy as np
from typing import List, Optional, Dict
import pandas as pd
import os
import torch
import subprocess

class PunctuationDomainDataset(Dataset):

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.
               """
        return {
            "input_ids": NeuralType(('B', 'T'), ChannelType()),
            "attention_mask": NeuralType(('B', 'T'), ChannelType()),
            "labels": NeuralType(('B', 'T'), ChannelType()),
            "domain": NeuralType(('B'), ChannelType()),
        }

    def __init__(self, 
        csv_file:str, 
        tokenizer,
        num_samples:int=256,
        max_seq_length:int=256,
        punct_label_ids: Dict[str, int] = None,
        domain=0,
        labelled=True,
    ):
        if not (os.path.exists(csv_file)):
            raise FileNotFoundError(
                f'{csv_file} not found. The data should be joined in 1 csv file.\
                    Each line of the file contains the subword token ids, masks and class labels per row.'
            )

        data_dir = os.path.dirname(csv_file)
        filename = os.path.basename(csv_file)

        if not filename.endswith('.csv'):
            raise ValueError("{text_file} should have extension .csv")
        # filename = filename[:-4]
        
        self.csv_file = csv_file
        self.max_seq_length = max_seq_length
        self.set_num_samples(csv_file, num_samples)
        self.domain=domain
        self.labelled=labelled
        self.tokenizer=tokenizer

    def __getitem__(self, idx):
        x = next(
            pd.read_csv(
                self.csv_file,
                skiprows=(idx % self.len)*self.num_samples,
                chunksize=self.num_samples,
                header=None,
                delimiter=' '))
        x = torch.from_numpy(x.values).reshape(-1,3,self.max_seq_length) #x.shape[-1]//3
        return {'input_ids': torch.as_tensor(x[:,0,:], dtype=torch.long),
                'attention_mask': torch.as_tensor(x[:,1,:],dtype=torch.bool)if self.labelled else torch.zeros_like(x[:,1,:],dtype=torch.bool),
                'labels': torch.as_tensor(x[:,2,:],dtype=torch.long),
                'domain':self.domain*torch.ones(x.shape[0],1,dtype=torch.long)}

    def set_num_samples(self,csv_file,num_samples):
        self.num_samples = num_samples
        self.total_samples=int(subprocess.Popen(['wc', '-l', csv_file], stdout=subprocess.PIPE, stderr=subprocess.STDOUT).communicate()[0].split()[0])
        self.len = int(self.total_samples / self.num_samples)
        

    def __len__(self):
        return self.len
    
    def view(d)->list:
        """:param d(dictionary): returns readable format of single input_ids and labels in the form of readable text"""
        a,_,c=d.values()
        return [' '.join([_[0]+_[1] for _ in list(zip(self.tokenizer.convert_ids_to_tokens(_[0]),[id2tag[id] for id in _[1].tolist()]))]) for _ in zip(a,c)]
    
    def shuffle(self, sorted=False, seed=42):
        os.system('bash data/shuffle.sh -i {} -o {} -a {} -s {}'.format(self.csv_file, self.csv_file, ['false','true'][sorted], seed))

class PunctuationDomainDatasets(Dataset):
    
    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports. """
        return {
            "input_ids": NeuralType(('B', 'T'), ChannelType()),
            "attention_mask": NeuralType(('B', 'T'), ChannelType()),
            "labels": NeuralType(('B', 'T'), ChannelType()),
            "domain": NeuralType(('B'), ChannelType()),
        }

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        b={k:torch.vstack([d[i][k] for d in self.datasets]) for k in ['input_ids','attention_mask','labels','domain']}
        rand=torch.randperm(b['labels'].size()[0])
        return {k:v[rand] for k,v in b.items()}

    def __len__(self):
        return max(len(d) for d in self.datasets)

class PunctuationInferenceDataset(Dataset):
    """
    Creates dataset to use during inference for punctuation and capitalization tasks with a pretrained model.
    For dataset to use during training with labels, see BertPunctuationCapitalizationDataset.
    Args:
        queries file to sequences, each line should a sentence, no header.
        max_seq_length: max sequence length minus 2 for [CLS] and [SEP]
        tokenizer: such as AutoTokenizer
    """

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.
               """
        return {
            'input_ids': NeuralType(('B', 'T'), ChannelType()),
            'attention_mask': NeuralType(('B', 'T'), MaskType()),
        }

    def __init__(self, queries: List[str], max_seq_length: int, tokenizer):
        """ Initializes BertPunctuationInferDataset. """
        features = get_features(queries=queries, max_seq_length=max_seq_length, tokenizer=tokenizer)
        self.all_input_ids = features['input_ids']
        self.all_attention_mask = features['attention_mask']

    def __len__(self):
        return len(self.all_input_ids)

    def __getitem__(self, idx):
        return {'input_ids':self.all_input_ids[idx],
            'attention_mask':self.all_attention_mask[idx],}
        

def get_features(
    queries:str, 
    max_seq_length:int,
    tokenizer,
    punct_label_ids: dict = None,):

    def flatten(list_of_lists):
        for list in list_of_lists:
            for item in list:
                yield item

    def pad_to_len(max_length,ids):
        o=np.zeros(max_length, dtype=np.int)
        o[:len(ids)]=np.array(ids)
        return o

    def position_to_mask(max_length,indices):
        o=np.zeros(max_length,dtype=np.int)
        o[indices%(max_length-2)+1]=1
        return o

    batch_ids=[]
    batch_masks=[]
    for query in queries:
        wordlist=re.split('[^a-zA-Z0-9]+',query)
        subwords=list(map(tokenizer.tokenize,wordlist))
        subword_lengths=list(map(len,subwords))
        subwords=list(flatten(subwords))
        token_end_idxs=np.cumsum([0]+subword_lengths[:-1])+np.array(subword_lengths)-1
        teim=token_end_idxs%(max_seq_length-2)
        split_token_end_idxs=np.array_split(token_end_idxs,(np.argwhere((teim[1:])<teim[:-1]).flatten()+1).tolist())
        split_subwords=np.array_split(subwords,np.arange(max_length-2,len(subwords),max_seq_length-2)) 
        ids=torch.tensor([pad_to_len(max_seq_length,tokenizer.convert_tokens_to_ids(['[CLS]']+list(_)+['[SEP]'])) for _ in split_subwords], dtype=torch.long)
        masks=[position_to_mask(max_length,_) for _ in split_token_end_idxs]
        batch_ids.append(ids)
        batch_masks.append(masks)
    return {'input_ids': torch.as_tensor(batch_ids, dtype=torch.long),
            'attention_mask': torch.as_tensor(batch_masks,dtype=torch.bool)}

In [16]:
cfg.model.punct_label_ids=OmegaConf.create(sorted(cfg.model.punct_label_ids))
ids_to_labels = {_[0]:_[1] for _ in enumerate(cfg.model.punct_label_ids)}
labels_to_ids = {_[1]:_[0] for _ in enumerate(cfg.model.punct_label_ids)}
cfg.base_path='/home/nxingyu2/data' #/home/nxingyu/data
# cfg.base_path
cfg

{'seed': 42, 'trainer': {'gpus': 1, 'num_nodes': 1, 'max_epochs': 3, 'max_steps': None, 'accumulate_grad_batches': 1, 'gradient_clip_val': 0.0, 'amp_level': 'O0', 'precision': 16, 'accelerator': 'ddp', 'checkpoint_callback': False, 'logger': False, 'log_every_n_steps': 1, 'val_check_interval': 1.0, 'resume_from_checkpoint': None}, 'exp_manager': {'exp_dir': None, 'name': 'Punctuation_with_Domain_discriminator', 'create_tensorboard_logger': True, 'create_checkpoint_callback': True}, 'base_path': '/home/nxingyu2/data', 'model': {'nemo_path': None, 'transformer_path': 'google/electra-small-discriminator', 'punct_label_ids': ['', '!', ',', '-', '.', ':', ';', '?', '—', '…'], 'dataset': {'data_dir': '${base_path}', 'labelled': ['${base_path}/ted_talks_processed'], 'unlabelled': ['${base_path}/open_subtitles_processed'], 'max_seq_length': 128, 'pad_label': '', 'ignore_extra_tokens': False, 'ignore_start_end': False, 'use_cache': True, 'num_workers': 4, 'pin_memory': False, 'drop_last': False

In [30]:
int(time())

1612232037

In [150]:
class PunctuationDomainDataset(Dataset):

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.
               """
        return {
            "input_ids": NeuralType(('B', 'T'), ChannelType()),
            "attention_mask": NeuralType(('B', 'T'), ChannelType()),
            "labels": NeuralType(('B', 'T'), ChannelType()),
            "domain": NeuralType(('B'), ChannelType()),
        }

    def __init__(self, 
        csv_file:str, 
        tokenizer,
        num_samples:int=256,
        max_seq_length:int=256,
        punct_label_ids: Dict[str, int] = None,
        domain=0,
        labelled=True,
    ):
        if not (os.path.exists(csv_file)):
            raise FileNotFoundError(
                f'{csv_file} not found. The data should be joined in 1 csv file.\
                    Each line of the file contains the subword token ids, masks and class labels per row.'
            )

        data_dir = os.path.dirname(csv_file)
        filename = os.path.basename(csv_file)

        if not filename.endswith('.csv'):
            raise ValueError("{text_file} should have extension .csv")
        # filename = filename[:-4]
        
        self.csv_file = pp(  csv_file)
        self.max_seq_length = pp(  max_seq_length)
        self.set_num_samples(csv_file, num_samples)
        self.domain= pp( domain)
        self.labelled= pp( labelled)
        self.tokenizer= pp( tokenizer)

    def __getitem__(self, idx):
        x = next(
            pd.read_csv(
                self.csv_file,
                skiprows=(idx % self.len)*self.num_samples,
                header=None,
                dtype=str,
                chunksize=self.num_samples,
                ))[1]
        chunked=chunk_examples_with_degree(0)(x)
        batched=chunk_to_len_batch(self.max_seq_length,self.tokenizer,chunked['texts'],chunked['tags'],self.labelled)
        batched['domain']=self.domain*torch.ones(batched['input_ids'].shape[0],1,dtype=torch.long)
        rand=torch.randperm(batched['domain'].size()[0])
        return {k:v[rand] for k,v in batched.items()}
#        {'input_ids': torch.as_tensor(x[:,0,:], dtype=torch.long),
#         'attention_mask': torch.as_tensor(x[:,1,:],dtype=torch.bool)if self.labelled else torch.zeros_like(x[:,1,:],dtype=torch.bool),
#         'labels': torch.as_tensor(x[:,2,:],dtype=torch.long),
#         'domain':self.domain*torch.ones(x.shape[0],1,dtype=torch.long)}

    def set_num_samples(self,csv_file,num_samples):
        self.num_samples = num_samples
        self.total_samples=int(subprocess.Popen(['wc', '-l', csv_file], stdout=subprocess.PIPE, stderr=subprocess.STDOUT).communicate()[0].split()[0])
        self.len = int(self.total_samples / self.num_samples)
        

    def __len__(self):
        return self.len
    
    def view(d)->list:
        """:param d(dictionary): returns readable format of single input_ids and labels in the form of readable text"""
        a,_,c=d.values()
        return [' '.join([_[0]+_[1] for _ in list(zip(self.tokenizer.convert_ids_to_tokens(_[0]),[id2tag[id] for id in _[1].tolist()]))]) for _ in zip(a,c)]
    
    def shuffle(self, sorted=False, seed=42):
        os.system('bash data/shuffle.sh -i {} -o {} -a {} -s {}'.format(self.csv_file, self.csv_file, ['false','true'][sorted], seed))


In [130]:
from torch.utils.data import IterableDataset
from itertools import cycle
class PunctuationDomainDataset(IterableDataset):
    def __init__(self, 
        csv_file:str, 
        tokenizer,
        num_samples:int=256,
        max_seq_length:int=256,
        punct_label_ids: Dict[str, int] = None,
        domain=0,
        labelled=True,
    ):
        if not (os.path.exists(csv_file)):
            raise FileNotFoundError(
                f'{csv_file} not found. The data should be joined in 1 csv file.\
                    Each line of the file contains the subword token ids, masks and class labels per row.'
            )

        data_dir = os.path.dirname(csv_file)
        filename = os.path.basename(csv_file)

        if not filename.endswith('.csv'):
            raise ValueError("{text_file} should have extension .csv")
        # filename = filename[:-4]
        
        self.csv_file = csv_file
        self.max_seq_length = max_seq_length
        self.set_num_samples(csv_file, num_samples)
        self.domain= domain
        self.labelled= labelled
        self.tokenizer= tokenizer
    def __iter__(self):
        self.dataset=iter(pd.read_csv(
                self.csv_file,
                skiprows=(0 % self.len)*self.num_samples,
                header=None,
                dtype=str,
                chunksize=self.num_samples,
                ))
        return self
        
    
    def __next__(self):
        x = next(self.dataset)[1]
        chunked=chunk_examples_with_degree(0)(x)
        batched=chunk_to_len_batch(self.max_seq_length,self.tokenizer,chunked['texts'],chunked['tags'],self.labelled)
        batched['domain']=self.domain*torch.ones(batched['input_ids'].shape[0],1,dtype=torch.long)
        rand=torch.randperm(batched['domain'].size()[0])
        return {k:v[rand] for k,v in batched.items()}
#        {'input_ids': torch.as_tensor(x[:,0,:], dtype=torch.long),
#         'attention_mask': torch.as_tensor(x[:,1,:],dtype=torch.bool)if self.labelled else torch.zeros_like(x[:,1,:],dtype=torch.bool),
#         'labels': torch.as_tensor(x[:,2,:],dtype=torch.long),
#         'domain':self.domain*torch.ones(x.shape[0],1,dtype=torch.long)}

    def set_num_samples(self,csv_file,num_samples):
        self.num_samples = num_samples
        self.total_samples=int(subprocess.Popen(['wc', '-l', csv_file], stdout=subprocess.PIPE, stderr=subprocess.STDOUT).communicate()[0].split()[0])
        self.len = int(self.total_samples / self.num_samples)
        

    def __len__(self):
        return self.len
    
    def view(d)->list:
        """:param d(dictionary): returns readable format of single input_ids and labels in the form of readable text"""
        a,_,c=d.values()
        return [' '.join([_[0]+_[1] for _ in list(zip(self.tokenizer.convert_ids_to_tokens(_[0]),[id2tag[id] for id in _[1].tolist()]))]) for _ in zip(a,c)]
    
    def shuffle(self, sorted=False, seed=42):
        os.system('bash data/shuffle.sh -i {} -o {} -a {} -s {}'.format(self.csv_file, self.csv_file, ['false','true'][sorted], seed))

class PunctuationDomainDatasets(IterableDataset):
    
    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports. """
        return {
            "input_ids": NeuralType(('B', 'T'), ChannelType()),
            "attention_mask": NeuralType(('B', 'T'), ChannelType()),
            "subtoken_mask": NeuralType(('B', 'T'), ChannelType()),
            "labels": NeuralType(('B', 'T'), ChannelType()),
            "domain": NeuralType(('B'), ChannelType()),
        }

    def __init__(self, 
                 split:str,
                 num_samples:int,
                 max_seq_length:int,
                 punct_label_ids: Dict[str, int],
                 labelled: List[str],
                 unlabelled: List[str],
                 tokenizer):
        
        self.datasets = []
        self.iterators=[]
        for i,path in enumerate(labelled):
            dataset=PunctuationDomainDataset(
                    csv_file=f'{path}.{split}.csv', tokenizer=tokenizer,
                    num_samples=num_samples,max_seq_length=max_seq_length,
                    punct_label_ids=punct_label_ids,domain=i,labelled=True,)
            self.datasets.append(dataset)
            self.iterators.append(cycle(dataset))
            
        for i,path in enumerate(unlabelled):
            dataset=PunctuationDomainDataset(
                    csv_file=f'{path}.{split}.csv', tokenizer=tokenizer,
                    num_samples=num_samples,max_seq_length=max_seq_length,
                    punct_label_ids=punct_label_ids,domain=len(labelled)+i,labelled=False,)
            self.datasets.append(dataset)
            self.iterators.append(cycle(dataset))

    def __iter__(self):
        self.iterators=[]
        for dataset in self.datasets:
            self.iterators.append(cycle(dataset))
        return self
            
    def __next__(self):
        ds=[next(d) for d in self.datasets]
        min_batch=1000000
        for d in ds:
            size=d['domain'].size()[0]
            if size<min_batch:
                min_batch=size
        #Ensure all domains are evenly represented
        b={k:torch.cat([d[k][:min_batch] for d in ds],axis=0) for k in ['input_ids','attention_mask','subtoken_mask','labels','domain']}
        rand=torch.randperm(b['labels'].size()[0])
        return {k:v[rand] for k,v in b.items()}

    def __len__(self):
        return max(len(d) for d in self.datasets)

In [131]:
dstrain=PunctuationDomainDatasets(
        split='train',
        tokenizer=AutoTokenizer.from_pretrained(cfg.model.transformer_path),
        num_samples=1,
        max_seq_length=128,
        punct_label_ids=labels_to_ids,
        labelled=['/home/nxingyu/data/ted_talks_processed'],
        unlabelled=[]
    )

# ds=PunctuationDomainDataset( 
#     csv_file='/home/nxingyu/data/ted_talks_processed.train.csv', 
#     tokenizer=AutoTokenizer.from_pretrained(cfg.model.transformer_path),
#     num_samples=16,
#     max_seq_length=128,
#     punct_label_ids=labels_to_ids,
#     domain=0,
#     labelled=True,
# )
# ds.shuffle(sorted=True)
# ds.shuffle()

In [170]:
# it=iter(dstrain)
next(it)
# len(dstrain)
# next(it)
# next(it)
# next(it)
# next(iter(dstrain))
# dstrain.__len__()

{'input_ids': tensor([[  101,  2064,  1005,  ...,  2056,  2024,   102],
         [  101,  1998,  1996,  ...,  2128,  2183,   102],
         [  101,  2512,  6299,  ...,   102,     0,     0],
         ...,
         [  101,  2036,  2000,  ...,  2001,  1037,   102],
         [  101,  2000,  2828,  ...,  2009,  1005,   102],
         [  101,  2009, 11901,  ...,  2000,  2030,   102]]),
 'attention_mask': tensor([[ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True, False, False],
         ...,
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True],
         [ True,  True,  True,  ...,  True,  True,  True]]),
 'subtoken_mask': tensor([[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, Fals

In [165]:
next(iter(ds))

{'input_ids': tensor([[  101,  5654,  7028,  ...,  2008,  1996,   102],
         [  101,  1040,  2185,  ...,  1045,  1005,   102],
         [  101,  2054,  2017,  ...,  1040,  3305,   102],
         ...,
         [  101, 13346,  2000,  ...,  2022,  2006,   102],
         [  101,  2054,  2065,  ...,  1005,  1055,   102],
         [  101,  1056,  2079,  ...,  1997,  2256,   102]]),
 'attention_mask': tensor([[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]),
 'subtoken_mask': tensor([[False, False,  True,  ..., False, False, False],
         [False, False,  True,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ...,  True, False, False]

In [352]:
# ds0=ds[0]
ds0['domain'].size()

torch.Size([596, 1])

In [63]:
class PunctuationDomainDatasets(Dataset):
    
    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports. """
        return {
            "input_ids": NeuralType(('B', 'T'), ChannelType()),
            "attention_mask": NeuralType(('B', 'T'), ChannelType()),
            "subtoken_mask": NeuralType(('B', 'T'), ChannelType()),
            "labels": NeuralType(('B', 'T'), ChannelType()),
            "domain": NeuralType(('B'), ChannelType()),
        }

    def __init__(self, 
                 split:str,
                 num_samples:int,
                 max_seq_length:int,
                 punct_label_ids: Dict[str, int],
                 labelled: List[str],
                 unlabelled: List[str],
                 tokenizer):
        
        self.datasets = []
        for i,path in enumerate(labelled):
            self.datasets.append(PunctuationDomainDataset(
                    csv_file=f'{path}.{split}.csv', tokenizer=tokenizer,
                    num_samples=num_samples,max_seq_length=max_seq_length,
                    punct_label_ids=punct_label_ids,domain=i,labelled=True,))
            
        for i,path in enumerate(unlabelled):
            self.datasets.append(PunctuationDomainDataset(
                    csv_file=f'{path}.{split}.csv', tokenizer=tokenizer,
                    num_samples=num_samples,max_seq_length=max_seq_length,
                    punct_label_ids=punct_label_ids,domain=len(labelled)+i,labelled=False,))

    def __getitem__(self, i):
        ds=[d[i] for d in self.datasets]
        min_batch=1000000
        for d in ds:
            size=d['domain'].size()[0]
            if size<min_batch:
                min_batch=size
        #Ensure all domains are evenly represented
        b={k:torch.vstack([d[k][:min_batch] for d in ds]) for k in ['input_ids','attention_mask','subtoken_mask','labels','domain']}
        rand=torch.randperm(b['labels'].size()[0])
        return {k:v[rand] for k,v in b.items()}

    def __len__(self):
        return max(len(d) for d in self.datasets)

In [301]:
cfg.model.dataset

{'data_dir': '${base_path}', 'labelled': ['${base_path}/ted_talks_processed'], 'unlabelled': ['${base_path}/open_subtitles_processed'], 'max_seq_length': 128, 'pad_label': '', 'ignore_extra_tokens': False, 'ignore_start_end': False, 'use_cache': True, 'num_workers': 2, 'pin_memory': False, 'drop_last': False}

In [68]:
# d=pd.read_csv(f'{cfg.model.dataset.labelled[0]}.csv',header=None)[1].str.split().map(len)
# d.describe()
# x=torch.utils.data.DataLoader(dstrain)
next(iter(x))

TypeError: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.

In [64]:
dstrain=PunctuationDomainDatasets(
        split='train',
        tokenizer=AutoTokenizer.from_pretrained(cfg.model.transformer_path),
        num_samples=1,
        max_seq_length=128,
        punct_label_ids=labels_to_ids,
        labelled=list(cfg.model.dataset.labelled),
        unlabelled=list(cfg.model.dataset.unlabelled)
    )


In [349]:
dstrain0=dstrain[0]
# dstrain0['input_ids'].shape
# r=torch.randperm(414)
sum(dstrain0['domain'])

tensor([199])

In [None]:
from pytorch_lightning import LightningDataModule
from torch import dtype
from data import PunctuationDomainDataset, PunctuationDomainDatasets
from typing import List
import pandas as pd
import os
import torch
from nemo.utils import logging

class PunctuationDataModule(LightningDataModule):
    def __init__(self, 
            tokenizer,
            labelled: List[str], 
            unlabelled: List[str], 
            train_batch_size: int,
            max_seq_length:int = 256,
            val_batch_size:int = 256, 
            num_workers:int = 1,
            pin_memory:bool = False,
            drop_last:bool = False
            ):
        #unlabelled=[], batch_size = 256, max_seq_length = 256, num_workers=1):
        super().__init__()
        self.labelled=labelled
        self.tokenizer=tokenizer
        self.unlabelled=unlabelled
        self.num_domains=len(labelled)+len(unlabelled)
        self.train_batch_size = max(1,train_batch_size//self.num_domains)
        logging.info(f"using training batch_size of {self.train_batch_size} for each domain")
        self.val_batch_size = max(1,val_batch_size//self.num_domains)
        logging.info(f"using dev batch_size of {self.train_batch_size} for each domain")
        self.max_seq_length = max_seq_length
        self.num_workers=num_workers
        self.pin_memory = pin_memory
        self.drop_last = drop_last

        self.train_dataset={}
        self.dev_dataset={}
        self.test_dataset={}

    def setup(self, stage=None):
        for unlabelled,l in enumerate([self.labelled,self.unlabelled]):
            for i,p in enumerate(l):
                domain=i+unlabelled*len(self.labelled) #unlabelled domain is increasing after labelled
                try:
                    with open("{}.train-stride.csv".format(p),'r') as f:
                        s=len(f.readline().split(' '))//3
                except IOError:
                    s=0
                if (s!=self.max_seq_length):
                    logging.info(f"copying train file from {p}.train-batched.csv to {p}.train-stride.csv")
                    os.system("cp {} {}".format(p+'.train-batched.csv',p+'.train-stride.csv'))
                    if (self.max_seq_length!=256):
                        logging.info(f'generating training strides: {self.max_seq_length}')
                        n=np.loadtxt(open(p+".train-stride.csv", "rb"))
                        np.savetxt(p+".train-stride.csv", self.with_stride_split(n,self.max_seq_length),fmt='%d')

                if stage=='fit' or None:
                    self.train_dataset[domain] = PunctuationDomainDataset(p+'.train-stride.csv', num_samples=self.train_batch_size, max_seq_length=self.max_seq_length, domain = domain, labelled=bool(1-unlabelled), tokenizer=self.tokenizer)
                    self.dev_dataset[domain] =  PunctuationDomainDataset(p+'.dev-batched.csv', num_samples=self.val_batch_size, max_seq_length=self.max_seq_length, domain = domain, labelled=bool(1-unlabelled), tokenizer=self.tokenizer)
                    pp(self.train_dataset[domain].shuffle(sorted=True))
                    pp(self.train_dataset[domain].shuffle())

                if stage == 'test' or stage is None:
                    self.test_dataset[domain] =  PunctuationDomainDataset(p+'.test-batched.csv', num_samples=self.val_batch_size, max_seq_length=self.max_seq_length, domain = domain, labelled=bool(1-unlabelled), tokenizer=self.tokenizer)

    def shuffle(self):
        for dataset in self.train_dataset.values():
            dataset.shuffle()

    def train_dataloader(self):
        return DataLoader(PunctuationDomainDatasets(*self.train_dataset.values()),batch_size=None,num_workers=self.num_workers,pin_memory=self.pin_memory,drop_last=self.drop_last)

    def val_dataloader(self):
        return DataLoader(PunctuationDomainDatasets(*self.dev_dataset.values()),batch_size=None,num_workers=self.num_workers,pin_memory=self.pin_memory,drop_last=self.drop_last)

    def test_dataloader(self):
        return DataLoader(PunctuationDomainDatasets(*self.test_dataset.values()),batch_size=None,num_workers=self.num_workers,pin_memory=self.pin_memory,drop_last=self.drop_last)

    def with_stride_split(n,l):
        def with_stride(t,l):
            a=t[0,0]
            z=t[0,-1]
            t=t[:,1:-1].flatten()
            t=np.trim_zeros(t,'b')
            s=t.shape[0]
            nh=-(-s//(l-2))
            f=np.zeros((nh*(l-2),1))  
            f[:s,0]=t
            return np.hstack([np.ones((nh,1))*a,np.reshape(f,(-1,l-2)),np.ones((nh,1))*z])
        s=n.shape[1]
        a,b,c=n[:,:s//3],n[:,s//3:2*s//3],n[:,2*s//3:]
        a,b,c=with_stride(a,l), with_stride(b,l), with_stride(c,l)
        c1=np.zeros(a.shape)
        c1[:c.shape[0],:]=c
        return np.hstack([a,b,c1])



In [8]:
#helper functions
def flatten(list_of_lists):
    for l in list_of_lists:
        for item in l:
            yield item

def pad_to_len(max_seq_length,ids):
    '''[0, 1, 2] -> array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0])'''
    o=np.zeros(max_seq_length, dtype=np.int)
    o[:len(ids)]=np.array(ids)
    return o

def position_to_mask(max_seq_length:int,indices:list):
    '''[0, 2, 5] -> array([0, 1, 0, 1, 0, 0, 1, 0, 0, 0])'''
    o=np.zeros(max_seq_length,dtype=np.int)
    o[np.array(indices)%(max_seq_length-2)+1]=1
    return o

def align_labels_to_mask(mask,labels):
    '''[0,1,0],[2] -> [0,2,0]'''
    assert(sum(mask)==len(labels))
    mask[mask>0]=torch.tensor(labels)
    return mask.tolist()


In [28]:
import regex as re
def text2masks(n):
    def text2masks(text):
        '''Converts single paragraph of text into a list of words and corresponding punctuation based on the degree requested.'''
        if n==0: 
            refilter="(?<=[.?!,;:\-—… ])(?=[^.?!,;:\-—… ])|$"
        else:
            refilter="[.?!,;:\-—…]{1,%d}(?= *[^.?!,;:\-—…]+|$)|(?<=[^.?!,;:\-—…]) +(?=[^.?!,;:\-—…])"%(n)
        text=re.sub(r'^[_\W]*','',text)
        word=re.split(refilter,text, flags=re.V1)
        punct=re.findall(refilter,text, flags=re.V1)
        wordlist,punctlist=([] for _ in range(2))
        if word[-1]=='': # ensures text aligns
            word.pop()
        else:
            punct.append('')
        
        for i in zip(word,punct): #+[''] to correspond to the last word or '' after the last punctuation.
            w,p=i[0].strip(),i[1].strip()
            if w!='':
                wordlist.append(re.sub(r'[.?!,;:\-—… ]','',w))
                punctlist.append(0 if not w[-1] in '.?!,;:-—…' else labels_to_ids[w[-1]])
            if p!='':
                wordlist.append(p)
                punctlist.append(0)
        return(wordlist,punctlist)
    return text2masks
def chunk_examples_with_degree(n):
    '''Ensure batched=True if using dataset.map or ensure the examples are wrapped in lists.'''
    def chunk_examples(examples):
        output={}
        output['texts']=[]
        output['tags']=[]
        for sentence in examples:
            text,tag=text2masks(n)(sentence)
            output['texts'].append(text)
            output['tags'].append(tag)
            # output['tags'].append([0]+tag if text[0]!='' else tag) # [0]+tag so that in all case, the first tag refers to [CLS]
            # not necessary since all the leading punctuations are stripped
        return output
    return chunk_examples
assert(chunk_examples_with_degree(0)(['Hello!Bye…'])=={'texts': [['Hello', 'Bye']], 'tags': [[1, 9]]})

def subword_tokenize(tokenizer,tokens):
    subwords = list(map(tokenizer.tokenize, tokens))
    subword_lengths = list(map(len, subwords))
    subwords = list(flatten(subwords))
    token_end_idxs = np.cumsum([0]+subword_lengths[:-1])+np.array(subword_lengths)-1
    return subwords, token_end_idxs

def chunk_to_len(max_seq_length,tokenizer,tokens,labels=None):
    subwords,token_end_idxs = subword_tokenize(tokenizer,tokens)
    teim=token_end_idxs%(max_seq_length-2)
    breakpoints=(np.argwhere(teim[1:]<teim[:-1]).flatten()+1).tolist()
    split_token_end_idxs=np.array_split(token_end_idxs,breakpoints)
    split_subwords=np.array_split(subwords,np.arange(max_seq_length-2,len(subwords),max_seq_length-2))
    ids=[pad_to_len(max_seq_length,tokenizer.convert_tokens_to_ids(['[CLS]']+list(_)+['[SEP]'])) for _ in split_subwords]
    masks=[position_to_mask(max_seq_length,_) for _ in split_token_end_idxs]
    padded_labels=None
    if labels!=None:
        split_labels=np.array_split(labels,breakpoints)
        padded_labels=[pad_to_len(max_seq_length,align_labels_to_mask(*_)) for _ in zip(masks,split_labels)]
    return ids,masks,padded_labels
    
def chunk_to_len_batch(max_seq_length,tokenizer,tokens,labels=None,labelled=True, ignore_index=-100):
    batch_ids=[]
    batch_masks=[]
    batch_labels=[]
    for i,_ in enumerate(zip(tokens,tokens) if labels==None else zip(tokens,labels)):
        a,b,c=chunk_to_len(max_seq_length,tokenizer,*_) if labels else chunk_to_len(max_seq_length,tokenizer,_[0])
        batch_ids.extend(a)
        batch_masks.extend(b)
        if labelled==True:
            batch_labels.extend(c)
    output = {'input_ids': torch.as_tensor(batch_ids, dtype=torch.long),
              'attention_mask': torch.as_tensor(batch_ids, dtype=torch.bool),
              'subtoken_mask': torch.as_tensor(batch_masks,dtype=torch.bool)}#*labelled
    output['subtoken_mask']|=(output['input_ids']==101)|(output['input_ids']==102)
    output['subtoken_mask']&=labelled
#     output['input_ids']+=ignore_index*(~output['subtoken_mask'])
    
    output['labels']=torch.as_tensor(batch_labels,dtype=torch.short) if labelled==True else torch.zeros_like(output['input_ids'],dtype=torch.short)
    return output

In [29]:
class PunctuationInferenceDataset(Dataset):
    """
    Creates dataset to use during inference for punctuation and capitalization tasks with a pretrained model.
    For dataset to use during training with labels, see BertPunctuationCapitalizationDataset.
    Args:
        queries file to sequences, each line should a sentence, no header.
        max_seq_length: max sequence length minus 2 for [CLS] and [SEP]
        tokenizer: such as AutoTokenizer
    """

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.
               """
        return {
            'input_ids': NeuralType(('B', 'T'), ChannelType()),
            'attention_mask': NeuralType(('B', 'T'), MaskType()),
            'subtoken_mask': NeuralType(('B', 'T'), MaskType()),
            "labels": NeuralType(('B', 'T'), ChannelType()),
        }

    def __init__(self, tokenizer, queries: List[str], max_seq_length: int, degree:int=0,):
        """ Initializes BertPunctuationInferDataset. """
        self.degree=degree
        chunked=chunk_examples_with_degree(self.degree)(queries)
        features = chunk_to_len_batch(max_seq_length=max_seq_length, tokenizer=tokenizer,tokens=chunked['texts'],labelled=False)
        self.all_input_ids = pp(features['input_ids'])
        self.all_attention_mask = pp(features['attention_mask'])
        self.all_subtoken_mask = pp(features['subtoken_mask'])

    def __len__(self):
        return len(self.all_input_ids)

    def __getitem__(self, idx):
        return {'input_ids':self.all_input_ids[idx],
            'attention_mask':self.all_attention_mask[idx],
               'subtoken_mask':self.all_subtoken_mask[idx]}


In [30]:
# # split='train'
# # o=pd.read_csv(f'{cfg.model.dataset.labelled[0]}.{split}.csv',
# #                   dtype='str',
# #                   header=None,
# #                   chunksize=10)
# # t=next(iter(o))
# sample_out = chunk_examples_with_degree(0)(t[1])
# tokenizer=AutoTokenizer.from_pretrained(cfg.model.transformer_path)
# # subword_tokenize(sample_out['texts'][0])
# sample_out
# sample_out['texts'],sample_out['tags']
# chunk_to_len_batch(1000,tokenizer,sample_out['texts'][:10],sample_out['tags'][:10])
# chunk_examples_with_degree(0)(t[1])
# chunk_examples_with_degree(0)(['!!Hellooooo! Yay! Bye Enddd.',"Hello"])
inferData=PunctuationInferenceDataset(tokenizer=tokenizer, queries=['!!Hellooooo! Yay! Bye Enddd.',"Hello"], max_seq_length=5,degree=1)
inferData[:]

ic| features['input_ids']: tensor shape torch.Size([5, 5]) type torch.int64
ic| features['attention_mask']: tensor shape torch.Size([5, 5]) type torch.bool
ic| features['subtoken_mask']: tensor shape torch.Size([5, 5]) type torch.bool


{'input_ids': tensor([[  101,  7592,  9541,  9541,   102],
         [  101,   999,  8038,  2100,   102],
         [  101,   999,  9061,  2203,   102],
         [  101, 14141,  1012,   102,     0],
         [  101,  7592,   102,     0,     0]]),
 'attention_mask': tensor([[ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True, False],
         [ True,  True,  True, False, False]]),
 'subtoken_mask': tensor([[False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False]])}

In [15]:
inferData=PunctuationInferenceDataset(tokenizer=tokenizer, queries=['!!Hellooooo! Yay! Bye Enddd.',"Hello"], max_seq_length=5,degree=0)
inferData[:]

ic| features['input_ids']: tensor tensor([[  101,  7592,  9541,  9541,   102],
                                   [  101,  8038,  2100,  9061,   102],
                                   [  101,  2203, 14141,   102,     0],
                                   [  101,  7592,   102,     0,     0]]) shape torch.Size([4, 5]) type torch.int64
ic| features['attention_mask']: tensor tensor([[ True,  True,  True,  True,  True],
                                        [ True,  True,  True,  True,  True],
                                        [ True,  True,  True,  True, False],
                                        [ True,  True,  True, False, False]]) shape torch.Size([4, 5]) type torch.bool
ic| features['subtoken_mask']: tensor tensor([[False, False, False,  True, False],
                                       [False, False,  True,  True, False],
                                       [False, False,  True, False, False],
                                       [False,  True, False, False, Fa

{'input_ids': tensor([[  101,  7592,  9541,  9541,   102],
         [  101,  8038,  2100,  9061,   102],
         [  101,  2203, 14141,   102,     0],
         [  101,  7592,   102,     0,     0]]),
 'attention_mask': tensor([[ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True, False],
         [ True,  True,  True, False, False]]),
 'subtoken_mask': tensor([[False, False, False,  True, False],
         [False, False,  True,  True, False],
         [False, False,  True, False, False],
         [False,  True, False, False, False]])}

In [50]:
class PunctuationInferDataset(Dataset):
    """
    Creates dataset to use during inference for punctuation and capitalization tasks with a pretrained model.
    For dataset to use during training with labels, see BertPunctuationCapitalizationDataset.
    Args:
        queries file to sequences, each line should a sentence, no header.
        max_seq_length: max sequence length minus 2 for [CLS] and [SEP]
        tokenizer: such as AutoTokenizer
    """

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns definitions of module output ports.
               """
        return {
            'input_ids': NeuralType(('B', 'T'), ChannelType()),
            'attention_mask': NeuralType(('B', 'T'), MaskType()),
        }

    def __init__(self, queries: List[str], max_seq_length: int, tokenizer):
        """ Initializes BertPunctuationInferDataset. """
        features = pp(get_features(queries=queries, max_seq_length=max_seq_length, tokenizer=tokenizer))
        self.all_input_ids = pp(features['input_ids'])
        self.all_attention_mask = pp(features['attention_mask'])

    def __len__(self):
        return len(self.all_input_ids)

    def __getitem__(self, idx):
        return {'input_ids':self.all_input_ids[idx],
            'attention_mask':self.all_attention_mask[idx]}

def get_features(
    queries:str, 
    tokenizer,
    max_seq_length:int,
    degree:int=0,
    punct_label_ids: dict = None,):

    batch_ids=[]
    batch_masks=[]
    for query in queries: #
        #'Hellooooo! Yay! Bye Endd.'
        wordlist=ic(re.split('[^a-zA-Z0-9]+',query,flags=re.V1)) #If end with punctuation, this includes a trailing ''
        if wordlist[-1]=='': #Not necessary since the masks would ignore repeated end idxs.
            wordlist=wordlist[:-1] 
        #['Hellooooo', 'Yay', 'Bye', 'Endd', '']
        subwords=ic(list(map(tokenizer.tokenize,wordlist))) # [['hello', '##oo', '##oo'], ['ya', '##y'], ['bye'], ['end', '##d']]
        subword_lengths=ic(list(map(len,subwords))) # [3, 2, 1, 1]
        subwords=ic(list(flatten(subwords))) # ['hello', '##oo', '##oo', 'ya', '##y', 'bye', 'end', '##d']
        token_end_idxs=ic(np.cumsum([0]+subword_lengths[:-1])+np.array(subword_lengths)-1) #'[2 4 5 6]'
        teim=ic(token_end_idxs%(max_seq_length-2)) #'[2 0 1 2]'
        pp(np.argwhere(teim[1:]<teim[:-1]).flatten()) #[0] returns last labels for each chunk.
        split_token_end_idxs=np.array_split(token_end_idxs,(np.argwhere(teim[1:]<teim[:-1]).flatten()+1).tolist())
        #[array([2]), array([4, 5, 6])]
        pp(split_token_end_idxs)
        split_subwords=ic(np.array_split(subwords,np.arange(max_seq_length-2,len(subwords),max_seq_length-2)))
        #[array(['hello', '##oo', '##oo', 'ya'], dtype='<U5'), array(['##y', 'bye', 'end'], dtype='<U5')]
        ids=ic([pad_to_len(max_seq_length,tokenizer.convert_tokens_to_ids(['[CLS]']+list(_)+['[SEP]'])) for _ in split_subwords])
        #[array([ 101, 7592, 9541, 9541, 8038,  102]), array([ 101, 2100, 9061, 2203,  102,    0])]
        masks=ic([position_to_mask(max_seq_length,_) for _ in split_token_end_idxs])
        batch_ids.append(ids) #[[array([ 101, 7592, 9541, 9541, 8038,  102]), array([ 101, 2100, 9061, 2203, 2094,  102])]]
        batch_masks.append(masks) #[[array([0, 0, 0, 1, 0, 0]), array([0, 1, 1, 1, 0, 0])]]
    
    return pp({'input_ids': torch.as_tensor(batch_ids, dtype=torch.long),
            'attention_mask': torch.as_tensor(batch_masks,dtype=torch.bool)})

In [236]:
import regex as re
ifds=PunctuationInferDataset(queries=['Hellooooo! Yay! Bye Enddd.'], max_seq_length=6, tokenizer=AutoTokenizer.from_pretrained(cfg.model.transformer_path))

ic| re.split('[^a-zA-Z0-9]+',query,flags=re.V1): ['Hellooooo', 'Yay', 'Bye', 'Enddd', '']
ic| list(map(tokenizer.tokenize,wordlist)): [['hello', '##oo', '##oo'], ['ya', '##y'], ['bye'], ['end', '##dd']]
ic| list(map(len,subwords)): [3, 2, 1, 2]
ic| list(flatten(subwords)): ['hello', '##oo', '##oo', 'ya', '##y', 'bye', 'end', '##dd']
ic| np.cumsum([0]+subword_lengths[:-1])+np.array(subword_lengths)-1: array shape: (4,)[2 4 5 7]
ic| token_end_idxs%(max_seq_length-2): array shape: (4,)[2 0 1 3]
ic| np.argwhere(teim[1:]<teim[:-1]).flatten(): array shape: (1,)[0]
ic| split_token_end_idxs: [array([2]), array([4, 5, 7])]
ic| np.array_split(subwords,np.arange(max_seq_length-2,len(subwords),max_seq_length-2)): [array(['hello', '##oo', '##oo', 'ya'], dtype='<U5'), array(['##y', 'bye', 'end', '##dd'], dtype='<U5')]
ic| [pad_ids_to_len(max_seq_length,tokenizer.convert_tokens_to_ids(['[CLS]']+list(_)+['[SEP]'])) for _ in split_subwords]: [array([ 101, 7592, 9541, 9541, 8038,  102]), array([  101,  

In [40]:
labels_to_ids

{'': 0, '!': 1, ',': 2, '-': 3, '.': 4, ':': 5, ';': 6, '?': 7, '—': 8, '…': 9}

In [39]:
# labels_to_ids = {_[1]:_[0] for _ in enumerate(cfg.model.punct_label_ids)}
ids_to_labels = {_[0]:_[1] for _ in enumerate(cfg.model.punct_label_ids)}
t=chunk_examples_with_degree(0)(['Hellooooo!Bye…'])
tokenizer=AutoTokenizer.from_pretrained(cfg.model.transformer_path)
t=chunk_to_len_batch(5,tokenizer,t['texts'],t['tags'])
view_aligned(t['input_ids'],np.array(t['labels']),tokenizer,ids_to_labels)

ic| _[0]: tensor shape: torch.Size([5])
ic| _[0]: tensor shape: torch.Size([5])


['[CLS] hellooooo! [SEP]', '[CLS] bye… [SEP] [PAD] [PAD]']

In [188]:
def flatten(list_of_lists):
    for l in list_of_lists:
        for item in l:
            yield item
list(flatten([[0],[0,1],[0,1,2],[],[],[1]]))

[0, 0, 1, 0, 1, 2, 1]