In [22]:
import math
from collections import OrderedDict

from petastorm import make_batch_reader, TransformSpec
from petastorm.pytorch import DataLoader as PetaDataLoader
from sklearn import preprocessing
import torch
from torch import tensor
from torch.utils.data import TensorDataset, DataLoader as TorchDataLoader, IterableDataset

In [23]:
class MyIterableDataset(IterableDataset):
    """Stream per-row (features, target) samples from a Parquet dataset via petastorm.

    Each yielded sample is a dict:
      - 'features': tensor([sell_price (NaN replaced by 0.), 1. if sell_price was NaN else 0.])
      - 'targets':  tensor([sales_dollars])

    Only single-process loading is supported: use ``num_workers=0`` in the torch
    ``DataLoader``; iterating inside a DataLoader worker raises ``ValueError``.
    """

    def __init__(self, filename, length=1300000):
        """
        Args:
            filename: dataset URL handed to ``petastorm.make_batch_reader``
                (e.g. ``'file:./trn.parquet'``).
            length: value reported by ``__len__`` (default: the previously
                hard-coded 1,300,000). NOTE(review): presumably the row count of
                the file -- it is not checked against the data; confirm.
        """
        # The original `super(MyIterableDataset).__init__()` built an unbound
        # super object and never initialised the parent on `self`.
        super().__init__()
        self.filename = filename
        self.length = length

    def _init_petaloader(self):
        """Open a fresh petastorm reader over ``self.filename`` and wrap it in a loader.

        Returns a petastorm ``DataLoader`` yielding dict batches of up to 10000
        rows, drawn through a 100000-row shuffling queue. The caller owns the
        loader and should enter it as a context manager so the underlying
        reader is stopped and joined when iteration ends.
        """
        reader = make_batch_reader(
            self.filename,
            schema_fields=['id', 'item_id', 'dept_id', 'cat_id', 'day_id',
                           'sales', 'day_date_str', 'month_id', 'date', 'wm_yr_wk',
                           'snap_flag', 'sell_price', 'sales_dollars', 'store_id', 'state_id'],
        )
        return PetaDataLoader(reader=reader, batch_size=10000, shuffling_queue_capacity=100000)

    def __len__(self):
        """Return the configured nominal length (see ``__init__``)."""
        return self.length

    def __iter__(self):
        """Yield one sample dict per Parquet row (see class docstring for the layout).

        Raises:
            ValueError: when called inside a torch DataLoader worker process.
                Rows are not sharded across workers, so each worker would
                presumably re-read the whole file.
        """
        print("Iterator created")
        if torch.utils.data.get_worker_info() is not None:
            raise ValueError("Not implemented for multithreading")

        # The `with` block stops/joins the petastorm reader when the generator is
        # exhausted or closed early; previously every epoch leaked a reader.
        with self._init_petaloader() as loader:
            for batch in loader:
                for price, sales_dollars in zip(batch['sell_price'], batch['sales_dollars']):
                    price = float(price)
                    price_is_nan = math.isnan(price)
                    yield {'features': tensor([0.0 if price_is_nan else price, float(price_is_nan)]),
                           'targets': tensor([sales_dollars])}

In [25]:
import torch
from torch import nn
import torch.nn.functional as F

from catalyst.dl import SupervisedRunner
from catalyst.utils import set_global_seed

In [26]:
# Fix the global random seed through catalyst's helper
# (presumably covers python/numpy/torch RNGs -- see catalyst docs).
SEED = 42
set_global_seed(SEED)

In [27]:
batch = 128  # samples per optimisation step

TRAIN_PATH = 'file:./trn.parquet'
# TODO: point VALID_PATH at a held-out split. While it equals TRAIN_PATH the
# "valid" loss reported by the runner is measured on training rows, so it says
# nothing about generalisation (and load_best_on_end selects on it).
VALID_PATH = TRAIN_PATH

train_ds = MyIterableDataset(TRAIN_PATH)
valid_ds = MyIterableDataset(VALID_PATH)

# These are IterableDatasets: shuffle must stay False (row order comes from
# petastorm's shuffling queue), and num_workers must stay 0 because
# MyIterableDataset.__iter__ raises inside DataLoader worker processes.
train_dl = TorchDataLoader(train_ds, batch_size=batch, shuffle=False, num_workers=0, drop_last=False)
valid_dl = TorchDataLoader(valid_ds, batch_size=batch, shuffle=False, num_workers=0, drop_last=False)

# Loader mapping passed to the catalyst runner; keys name the stages.
data = OrderedDict([("train", train_dl), ("valid", valid_dl)])

In [30]:
class Net(nn.Module):
    """Three-layer MLP regressor: num_features -> h1 -> h2 -> 1, ReLU after every layer.

    All Linear weights are Xavier-uniform initialised and all biases zeroed.

    NOTE(review): the ReLU on the final layer clamps predictions to >= 0. That
    suits a non-negative target, but samples whose final pre-activation is
    negative contribute no gradient through the output -- confirm intended.
    """

    def __init__(self, num_features, hidden_sizes=(40, 20)):
        """
        Args:
            num_features: size of the last dimension of the input.
            hidden_sizes: exactly two hidden-layer widths (default (40, 20),
                the previously hard-coded values).
        """
        super(Net, self).__init__()
        first, second = hidden_sizes
        # Each layer is built and initialised before the next one is created,
        # matching the original order, so the RNG stream is unchanged. The
        # L1/L2/L3 attribute names are also kept because they are the
        # state_dict keys stored in checkpoints.
        self.L1 = self._make_linear(num_features, first)
        self.L2 = self._make_linear(first, second)
        self.L3 = self._make_linear(second, 1)

    @staticmethod
    def _make_linear(in_features, out_features):
        """Build an nn.Linear with Xavier-uniform weights and a zero bias."""
        layer = nn.Linear(in_features, out_features)
        torch.nn.init.xavier_uniform_(layer.weight)
        torch.nn.init.zeros_(layer.bias)
        return layer

    def forward(self, x):
        """Map inputs with last dim num_features to non-negative outputs with last dim 1."""
        x = F.relu(self.L1(x))
        x = F.relu(self.L2(x))
        x = F.relu(self.L3(x))
        return x

class MyLoss(nn.MSELoss):
    """Mean-squared-error loss: a thin pass-through to ``nn.MSELoss`` with its defaults.

    Presumably kept as a separate class as a hook for a future custom objective.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inp, target):
        """Return the parent ``nn.MSELoss`` result for (inp, target) unchanged."""
        loss = super().forward(inp, target)
        return loss

In [31]:
# Input width must match the feature vector built by MyIterableDataset:
# [sell price with NaN -> 0, 1.0 if the price was NaN else 0.0].
N_FEATURES = 2
model = Net(num_features=N_FEATURES)

In [32]:
# Mean-squared-error objective (MyLoss delegates to nn.MSELoss).
criterion = MyLoss()
# Adam over every model parameter with learning rate 1e-2.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-2)

In [33]:
runner = SupervisedRunner()
# One epoch; logs and checkpoints are written under ./run. With
# load_best_on_end=True the best checkpoint (ranked by the "valid" loader's
# loss) is reloaded into the model after training.
runner.train(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    loaders=data,
    num_epochs=1,
    logdir="run",
    load_best_on_end=True,
)

Iterator created
Iterator created
[2020-05-31 10:35:32,477] 
1/1 * Epoch 1 (_base): lr=0.0100 | momentum=0.9000
1/1 * Epoch 1 (train): loss=215.8209
1/1 * Epoch 1 (valid): loss=202.4572


INFO:metrics_logger:
1/1 * Epoch 1 (_base): lr=0.0100 | momentum=0.9000
1/1 * Epoch 1 (train): loss=215.8209
1/1 * Epoch 1 (valid): loss=202.4572


Top best models:
run/checkpoints/train.1.pth	202.4572
=> Loading checkpoint run/checkpoints/best_full.pth
loaded state checkpoint run/checkpoints/best_full.pth (global epoch 1, epoch 1, stage train)


In [34]:
# Removed a leftover `make_batch_reader??` IPython source-inspection call:
# it is an interactive debugging aid and not part of the training pipeline.