In [1]:
import math
import os
import re
from collections import OrderedDict

from petastorm import make_batch_reader, TransformSpec
from petastorm.pytorch import DataLoader as PetaDataLoader
from sklearn import preprocessing
from torch import tensor
from torch.utils.data import TensorDataset, DataLoader as TorchDataLoader, IterableDataset, get_worker_info

In [2]:
# URL scheme used for local parquet paths; stripped before walking a
# directory and re-added to every matched file in MyIterableDataset.
FILE_PREFIX = "file:"
# Binary-mode opener installed by patch_leaking_fd(); stays None until then.
pre_open_fds = None
def patch_leaking_fd():
    """Monkey-patch pyarrow's ``ParquetFile`` so it closes file objects it was given.

    After patching, ``ParquetFile.__init__`` remembers its ``source`` argument,
    and both ``__exit__`` and ``__del__`` call ``source.close()`` when the
    source has a ``close`` method (plain path strings are left alone). The
    class is patched at most once; the guard is the ``__old_init__`` attribute
    that stores the original ``__init__``.

    Always (re)binds the module-level ``pre_open_fds`` to a function that opens
    a path with the builtin ``open`` (binary read by default), meant to stand
    in for a filesystem's ``open`` method.

    Prints "Patching" on the first call and "Already patched" afterwards.
    """
    global pre_open_fds
    from pyarrow.parquet import ParquetFile

    def _patched_init(self, source, *args, **kwargs):
        # Remember the source so __exit__/__del__ can close it; forward every
        # other argument (positional or keyword) to the original __init__.
        self.source = source
        return ParquetFile.__old_init__(self, source, *args, **kwargs)

    def _exit(self, *args, **kwargs):
        # Installed as both __exit__ and __del__. __exit__ deletes `source`,
        # and __del__ later runs on the same object, so the attribute may
        # already be gone. Reading it via getattr avoids an AttributeError.
        source = getattr(self, 'source', None)
        if hasattr(source, 'close'):
            source.close()
            del self.source

    def _bopen(fn, mode='rb'):
        # Drop-in for a filesystem `open(path, mode='rb')`. Single-argument
        # calls behave exactly as before (binary read).
        return open(fn, mode)

    pre_open_fds = _bopen
    if not hasattr(ParquetFile, '__old_init__'):
        print("Patching")
        ParquetFile.__old_init__ = ParquetFile.__init__

        ParquetFile.__init__ = _patched_init
        ParquetFile.__exit__ = _exit
        ParquetFile.__del__ = _exit

    else:
        print("Already patched")

# Apply the ParquetFile patch when this cell runs; re-running the cell prints "Already patched".
patch_leaking_fd()



class MyIterableDataset(IterableDataset):
    """Stream one training example per parquet row through a petastorm reader.

    Each yielded item is a dict with
      * ``'features'``: ``tensor([sell_price (0. when NaN), sell_price-is-NaN flag])``
      * ``'targets'``:  ``tensor([sales_dollars])``

    Args:
        filename: dataset URL such as ``'file:./trn.parquet'``. Passed to
            ``make_batch_reader`` unchanged when ``rex`` is None.
        rex: optional regular expression. When given, ``filename`` must point
            to a local directory, and only files whose full path matches
            ``rex`` (``re.match``, i.e. anchored at the start) are read.
        length: value reported by ``len()``. The default keeps the previously
            hard-coded 1300000. NOTE(review): the notebook uses this default for
            both train and valid, so it is presumably only an estimate; pass
            the real row count where known.
    """

    # Columns requested from the parquet files.
    SCHEMA_FIELDS = ['id', 'item_id', 'dept_id', 'cat_id', 'day_id',
                     'sales', 'day_date_str', 'month_id', 'date', 'wm_yr_wk',
                     'snap_flag', 'sell_price', 'sales_dollars', 'store_id', 'state_id']

    def __init__(self, filename, rex=None, length=1300000):
        super().__init__()
        self._filename_param = filename
        self._length = length
        self.filename = self._init_filenames(filename, rex)

    def _init_filenames(self, filename, rex):
        """Resolve ``filename`` into what ``make_batch_reader`` should read.

        Returns ``filename`` unchanged when ``rex`` is None. Otherwise returns
        a sorted list of ``FILE_PREFIX``-prefixed paths of every file under the
        directory whose path matches ``rex``.

        Raises:
            ValueError: if the path is not a directory, or no file matches.
        """
        if rex is None:
            return filename

        # Strip the URL scheme only when it is actually present. Slicing
        # unconditionally would eat the first characters of a bare path.
        if filename.startswith(FILE_PREFIX):
            local_dir = filename[len(FILE_PREFIX):]
        else:
            local_dir = filename
        if not os.path.isdir(local_dir):
            raise ValueError(f"Filtering only possible for dirs, {local_dir} is not a one")

        pattern = re.compile(rex)
        # Sorted so the file list (and hence the run) does not depend on
        # os.walk's filesystem-specific ordering.
        paths = sorted(os.path.join(dirpath, name)
                       for dirpath, _dirnames, filenames in os.walk(local_dir)
                       for name in filenames)
        matched = [FILE_PREFIX + path for path in paths if pattern.match(path)]
        if not matched:
            raise ValueError(f"0 files remained out of {len(paths)} - seems regex is too restrictive")
        return matched

    def _init_petaloader(self):
        """Build a petastorm DataLoader over ``self.filename``.

        Uses one reader worker, 10000-row batches and a 100000-row shuffling
        queue. The returned loader is used as a context manager by ``__iter__``.
        """
        reader = make_batch_reader(self.filename,
                                   schema_fields=list(self.SCHEMA_FIELDS),
                                   workers_count=1)
        return PetaDataLoader(reader=reader, batch_size=10000, shuffling_queue_capacity=100000)

    def __len__(self):
        """Return the configured ``length``, not a count of the actual rows."""
        return self._length

    def __iter__(self):
        """Yield one ``{'features': ..., 'targets': ...}`` dict per parquet row.

        Raises:
            ValueError: when iterated inside a torch DataLoader worker
                (``num_workers > 0``), which is not supported.
        """
        print(f"Iterator created on {self._filename_param}")
        if get_worker_info() is not None:
            raise ValueError("Not implemented for multithreading")

        with self._init_petaloader() as loader:
            if pre_open_fds:
                # Route the dataset filesystem's open() through plain open(),
                # presumably so the resulting file objects become ParquetFile
                # sources that the patch in patch_leaking_fd() closes.
                # NOTE(review): the reader was already constructed above, so
                # anything opened during construction used the original open().
                # The "(<class 'str'>)" lines in the training output may be
                # that. Confirm.
                loader.reader.dataset.fs.open = pre_open_fds
            for batch in loader:
                for price, sales_dollars in zip(batch['sell_price'], batch['sales_dollars']):
                    price_is_nan = math.isnan(price)
                    price_or_zero = 0. if price_is_nan else price
                    yield {'features': tensor([price_or_zero, price_is_nan]),
                           'targets': tensor([sales_dollars])}

Patching


In [3]:
import torch
from torch import nn
import torch.nn.functional as F

from catalyst.dl import SupervisedRunner
from catalyst.utils import set_global_seed

  from pandas import Panel

numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject



In [4]:
# Fix the global seed via catalyst before the model is created further down,
# so weight initialisation is repeatable across runs.
SEED = 42
set_global_seed(SEED)

In [5]:
batch = 128  # examples per optimisation step

# Train on partition 2 (files picked by regex), validate on partition 1.
train_ds = MyIterableDataset('file:./trn.parquet', rex='.*/parquet_partition=2/.*')
valid_ds = MyIterableDataset('file:./trn.parquet/parquet_partition=1')

# Same settings for both loaders. Row shuffling is left to petastorm's
# shuffling queue inside MyIterableDataset, and iteration stays in-process.
loader_kwargs = dict(batch_size=batch, shuffle=False, num_workers=0, drop_last=False)
train_dl = TorchDataLoader(train_ds, **loader_kwargs)
valid_dl = TorchDataLoader(valid_ds, **loader_kwargs)

# Catalyst expects an ordered mapping of loader name -> DataLoader.
data = OrderedDict([("train", train_dl), ("valid", valid_dl)])

In [6]:
class Net(nn.Module):
    def __init__(self, num_features):
        super(Net,self).__init__()
        layers = [40, 20]
        self.L1 = nn.Linear(num_features, layers[0])
        torch.nn.init.xavier_uniform_(self.L1.weight) 
        torch.nn.init.zeros_(self.L1.bias)
        
        self.L2 = nn.Linear(layers[0], layers[1])
        torch.nn.init.xavier_uniform_(self.L2.weight) 
        torch.nn.init.zeros_(self.L2.bias)
        
        self.L3 = nn.Linear(layers[1], 1)
        torch.nn.init.xavier_uniform_(self.L3.weight) 
        torch.nn.init.zeros_(self.L3.bias)
    def forward(self, x):
        x = F.relu(self.L1(x))
        x = F.relu(self.L2(x))
        x = F.relu(self.L3(x))
        return x

class MyLoss(nn.MSELoss):
    """Mean-squared-error loss with ``nn.MSELoss`` defaults.

    Currently a pure pass-through. Presumably kept as a subclass so a custom
    loss can be dropped in later.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inp, target):
        return nn.MSELoss.forward(self, inp, target)

In [7]:
# Two inputs per row, matching the 'features' tensor built in MyIterableDataset:
# price (NaN replaced by 0.) and a price-is-NaN flag.
model = Net(num_features=2)

In [8]:
# MSE objective (MyLoss) minimised with Adam at a fixed learning rate of 0.01.
criterion = MyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

In [9]:
runner = SupervisedRunner()
# NOTE(review): the "(<class ...>)" lines in this cell's output are not printed
# by the notebook. They presumably come from a local debug print added at
# /usr/local/lib/python3.7/site-packages/petastorm/arrow_reader_worker.py:53.
# Remove it once the fd-leak investigation is done.
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=data,
    num_epochs=1,
    logdir="run",
    load_best_on_end=True,
)

Iterator created on file:./trn.parquet
./trn.parquet/parquet_partition=2/0f9c844fe0ec46cca3a4b7a53ab2c9b3.parquet (<class 'str'>)
./trn.parquet/parquet_partition=2/0f9c844fe0ec46cca3a4b7a53ab2c9b3.parquet (<class 'str'>)
./trn.parquet/parquet_partition=2/d1144e0650e746ce9532427ab42da435.parquet (<class 'str'>)
./trn.parquet/parquet_partition=2/78472bf1717044c597caf14f92f4b59e.parquet (<class 'str'>)
./trn.parquet/parquet_partition=2/9c903d22fdb146959b44f651a6214189.parquet (<class 'str'>)
./trn.parquet/parquet_partition=2/c2d359ff86a3440791bbad1f86cab185.parquet (<class 'str'>)
./trn.parquet/parquet_partition=2/20c4cc252048478c82090e525e265b77.parquet (<class 'str'>)
./trn.parquet/parquet_partition=2/d90c33a6296c491e88402308dc3f831d.parquet (<class 'str'>)
<_io.BufferedReader name='./trn.parquet/parquet_partition=2/78472bf1717044c597caf14f92f4b59e.parquet'> (<class '_io.BufferedReader'>)
<_io.BufferedReader name='./trn.parquet/parquet_partition=2/78472bf1717044c597caf14f92f4b59e.parque


Calling .data on ChunkedArray is provided for compatibility after Column was removed, simply drop this attribute



<_io.BufferedReader name='./trn.parquet/parquet_partition=2/20c4cc252048478c82090e525e265b77.parquet'> (<class '_io.BufferedReader'>)
<_io.BufferedReader name='./trn.parquet/parquet_partition=2/9c903d22fdb146959b44f651a6214189.parquet'> (<class '_io.BufferedReader'>)
<_io.BufferedReader name='./trn.parquet/parquet_partition=2/9c903d22fdb146959b44f651a6214189.parquet'> (<class '_io.BufferedReader'>)
<_io.BufferedReader name='./trn.parquet/parquet_partition=2/d1144e0650e746ce9532427ab42da435.parquet'> (<class '_io.BufferedReader'>)
<_io.BufferedReader name='./trn.parquet/parquet_partition=2/d1144e0650e746ce9532427ab42da435.parquet'> (<class '_io.BufferedReader'>)
<_io.BufferedReader name='./trn.parquet/parquet_partition=2/0f9c844fe0ec46cca3a4b7a53ab2c9b3.parquet'> (<class '_io.BufferedReader'>)
<_io.BufferedReader name='./trn.parquet/parquet_partition=2/0f9c844fe0ec46cca3a4b7a53ab2c9b3.parquet'> (<class '_io.BufferedReader'>)
<_io.BufferedReader name='./trn.parquet/parquet_partition=2/d9


isAlive() is deprecated, use is_alive() instead



Iterator created on file:./trn.parquet/parquet_partition=1
./trn.parquet/parquet_partition=1/35730d08c7134d2b8a5af25394cd5278.parquet (<class 'str'>)
./trn.parquet/parquet_partition=1/35730d08c7134d2b8a5af25394cd5278.parquet (<class 'str'>)./trn.parquet/parquet_partition=1/379b0072a72a43cb8d5b4797d58f3d86.parquet (<class 'str'>)

./trn.parquet/parquet_partition=1/8727bd4f99bb4ec3b5e914164844e676.parquet (<class 'str'>)
./trn.parquet/parquet_partition=1/b20c30ebd03145e19ac0604d6760bd44.parquet (<class 'str'>)
./trn.parquet/parquet_partition=1/c25924cd9c2e45a081db3820d93e08bd.parquet (<class 'str'>)./trn.parquet/parquet_partition=1/c541778b11574df5949271d3e55b912d.parquet (<class 'str'>)

./trn.parquet/parquet_partition=1/de5475c8c74d42828022cd21a5fc8da4.parquet (<class 'str'>)
<_io.BufferedReader name='./trn.parquet/parquet_partition=1/8727bd4f99bb4ec3b5e914164844e676.parquet'> (<class '_io.BufferedReader'>)
<_io.BufferedReader name='./trn.parquet/parquet_partition=1/8727bd4f99bb4ec3b5e

In [10]:
# Show the installed pyarrow version; the ParquetFile monkey-patch above relies on its internals.
import pyarrow
pyarrow.__version__

'0.16.0'