In [None]:
# default_exp tbptt
# default_cls_lvl 3

In [None]:
#hide
%load_ext line_profiler
%load_ext snakeviz

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [None]:
#export
from seqdata.core import *
from seqdata.model import *
from fastai2.basics import *
from fastai2.callback.progress import *
from fastai2.callback.tracker import *

import math
import warnings

## TBPTT Dataloader
> Dataloader and callback for training stateful models on sequential data with truncated backpropagation through time (TBPTT)

The TBPTT dataloader splits every minibatch it creates along the time axis into several smaller minibatches of length `sub_seq_len`. These are returned one after another before the next full minibatch is created.

In [None]:
#export
@delegates()
class TbpttDl(TfmdDL):
    "`TfmdDL` that cuts every minibatch along the time axis into `sub_seq_len` long pieces for truncated backpropagation through time"
    def __init__(self, dataset, sub_seq_len=None, seq_len=None, shuffle=True, num_workers=0, **kwargs):
        "`sub_seq_len=None` disables splitting; `seq_len=None` infers the length from the first item"
        if sub_seq_len is not None and sub_seq_len <= 0:
            raise ValueError(f'sub_seq_len has to be positive, got {sub_seq_len}')
        # set before `super().__init__` so `__len__` works even if it is called during base initialisation
        self.n_sub_seq = None
        super().__init__(dataset=dataset, shuffle=shuffle, num_workers=num_workers, **kwargs)
        # length of the time axis (dim 0) of the first item's input
        if seq_len is None: seq_len = self.do_item(0)[0].shape[0]
        store_attr(self,'sub_seq_len,seq_len')

        if sub_seq_len is not None:
            # the last piece is shorter when seq_len is not a multiple of sub_seq_len
            self.n_sub_seq = math.ceil(seq_len / sub_seq_len)
            if num_workers != 0:
                # `rnn_reset` is set inside the batch generator; in worker processes neither the
                # flag nor the order of the pieces is reliable
                warnings.warn(f'TbpttDl requires num_workers=0, got num_workers={num_workers}; '
                              'minibatch order and the rnn_reset flag may be corrupted')
        # True while the current batch is the first piece of a new sequence (read by `TbpttResetCB`)
        self.rnn_reset = False

    def __len__(self):
        "Number of batches per epoch, counting every sub-sequence piece as its own batch"
        n_batches = super().__len__()
        return n_batches if self.n_sub_seq is None else n_batches*self.n_sub_seq

    def create_batches(self, samps):
        "Batches of the base class, split into time pieces when `sub_seq_len` is set"
        batch_iter = super().create_batches(samps)
        return batch_iter if self.n_sub_seq is None else self._tbptt_generator(batch_iter)

    def _tbptt_generator(self, batch_iter):
        "Yield `n_sub_seq` consecutive time slices (axis 1) of every tensor in each batch"
        for b in batch_iter:
            for i in range(self.n_sub_seq):
                self.rnn_reset = i == 0
                time_slice = slice(i*self.sub_seq_len, (i+1)*self.sub_seq_len)
                yield tuple(x[:,time_slice] for x in b)

In [None]:
# HDF-based sequence DataBlock; its dataloaders are `TbpttDl`s, so `sub_seq_len`
# given to `dataloaders` ends up in the TBPTT dataloader
seq = DataBlock(
    blocks=(
        SequenceBlock.from_hdf(['current','voltage'], TensorSequencesInput, clm_shift=[-1,-1]),
        SequenceBlock.from_hdf(['voltage'], TensorSequencesOutput, clm_shift=[1]),
    ),
    get_items=CreateDict([DfHDFCreateWindows(win_sz=100+1, stp_sz=1000, clm='current')]),
    splitter=ApplyToDict(ParentSplitter()),
)
seq.dl_type = TbpttDl
db = seq.dataloaders(get_hdf_files('test_data/'), sub_seq_len=90)

In [None]:
# input tensor of the first training batch: its time axis is cut to `sub_seq_len`
inp_batch = db.one_batch()[0]
inp_batch.shape

torch.Size([64, 90, 2])

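`num_workers` has to be 0. With parallel workers, the order of the minibatches would be corrupted. In addition, the `rnn_reset` flag, which is set inside the batch generator, would not be visible to the main process. When `seq_len` is not a multiple of `sub_seq_len`, the last piece of every sequence is shorter. Below, each 100-step sequence yields a 90-step piece followed by a 10-step piece.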

In [None]:
# each full minibatch appears as consecutive time pieces
for inp, *targets in db.train:
    print(inp.shape)

torch.Size([64, 90, 2])
torch.Size([64, 10, 2])
torch.Size([64, 90, 2])
torch.Size([64, 10, 2])
torch.Size([64, 90, 2])
torch.Size([64, 10, 2])
torch.Size([64, 90, 2])
torch.Size([64, 10, 2])
torch.Size([64, 90, 2])
torch.Size([64, 10, 2])
torch.Size([64, 90, 2])
torch.Size([64, 10, 2])
torch.Size([64, 90, 2])
torch.Size([64, 10, 2])
torch.Size([64, 90, 2])
torch.Size([64, 10, 2])


## TBPTT_Reset_Callback
A stateful model has to reset its hidden state whenever a new sequence begins. Before every batch, the callback reads the dataloader's `rnn_reset` flag and calls `model.reset()` when the flag is set.

In [None]:
#export
class TbpttResetCB(Callback):
    "`Callback` resets the rnn model with every new sequence for tbptt"

    def begin_batch(self):
        "Call `model.reset()` when the active dataloader flags the first piece of a new sequence"
        # `learn.dl` is the dataloader fastai2's Learner is iterating right now, which also covers
        # validation/prediction on dataloaders other than `data.valid`; fall back to the old lookup
        dl = getattr(self.learn, 'dl', None)
        if dl is None: dl = self.learn.data.train if self.training else self.learn.data.valid
        # dataloaders other than `TbpttDl` have no `rnn_reset` and never trigger a reset
        if getattr(dl, 'rnn_reset', False): self.model.reset()

In [None]:
#hide
# regenerate the library modules from all `#export` cells of the project's notebooks
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_model.ipynb.
Converted 02_learner.ipynb.
Converted 03_tbptt_dl.ipynb.
Converted 11_ProDiag.ipynb.
Converted 12_TensorQuaternions.ipynb.
Converted 13_PBT.ipynb.
Converted index.ipynb.
