In [1]:
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
import json
import argparse
from torch.utils.data import DataLoader, IterableDataset
import torch
from utils.config import NNConfig
import torch.nn as nn

In [2]:
# Project configuration object; the dataset class in this notebook reads
# `flist`, `idlist` and `meta` from it.
cfg = NNConfig()

In [8]:
class MyIterableDataset2(torch.utils.data.IterableDataset):
    """Stream a tab-separated file chunk by chunk and yield tensors per chunk.

    Every yielded item is a dict holding one whole chunk (up to ``chunksize``
    rows), not a single row:

    * ``"finputs"``  -- ``torch.FloatTensor`` of the columns in ``cfg.flist``
      after ``(value - to_sub) / to_div`` scaling.
    * ``"idinputs"`` -- ``torch.LongTensor`` of the columns in ``cfg.idlist``
      after mapping ``str(value)`` through the per-column dictionary.
    * ``"labels"``   -- ``torch.LongTensor`` of the ``"m:Click"`` column.

    Args:
        cfg: object exposing ``flist``, ``idlist`` and a ``meta`` mapping with
            the keys ``all_features``, ``to_minus``, ``to_div``, ``all_ids``
            and ``dicts``.
        headerp: path of a tab-separated file whose columns name the columns
            of the data file.
        filep: path of the tab-separated data file (read with ``names=`` set
            to the header columns).
        epoch: stored on the instance; not used by this class itself.
        chunksize: number of rows read (and yielded) per chunk.

    Raises:
        ValueError: from ``list.index`` if a name in ``cfg.flist`` /
            ``cfg.idlist`` is missing from the corresponding ``meta`` list.
        KeyError: during iteration if an id value has no dictionary entry.
    """

    def __init__(self, cfg, headerp, filep, epoch=2, chunksize=48):
        # Zero-argument super() really runs the parent initializer; the
        # one-argument form ``super(Class)`` builds an unbound proxy, so the
        # parent ``__init__`` was silently skipped.
        super().__init__()

        self.header = pd.read_csv(headerp, sep='\t')
        self.filep = filep
        self.epoch = epoch
        self.chunksize = chunksize
        self.flist = cfg.flist
        self.idlist = cfg.idlist

        # Look up each selected feature's normalisation constants by its
        # position in the full feature list.
        all_features = cfg.meta['all_features']
        feature_pos = [all_features.index(fname) for fname in self.flist]
        self.to_sub = [cfg.meta['to_minus'][pos] for pos in feature_pos]
        self.to_div = [cfg.meta['to_div'][pos] for pos in feature_pos]

        # Same idea for id columns: one value -> index dictionary per column.
        all_ids = cfg.meta['all_ids']
        self.dicts = [cfg.meta['dicts'][all_ids.index(idname)] for idname in self.idlist]

    def init_reader(self):
        """(Re)open the data file as a chunked reader stored in ``self.dfiter``."""
        self.dfiter = iter(pd.read_csv(
            self.filep,
            sep='\t',
            names=self.header.columns,
            iterator=True,
            chunksize=self.chunksize,
        ))

    def __iter__(self):
        """Yield one dict of tensors per chunk of the data file."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            # Running inside a DataLoader worker process. Every worker would
            # read the whole file, so more than one worker duplicates data.
            assert worker_info.num_workers == 1, (
                "MyIterableDataset2 does not shard the file across workers; "
                "use num_workers <= 1"
            )

        self.init_reader()
        reader = self.dfiter
        try:
            # A plain for-loop stops on StopIteration only; parser and I/O
            # errors now propagate instead of silently ending the epoch.
            for cur_chunk in reader:
                for findex, fname in enumerate(self.flist):
                    cur_chunk[fname] = (cur_chunk[fname] - self.to_sub[findex]) / self.to_div[findex]

                for dindex, idname in enumerate(self.idlist):
                    mapping = self.dicts[dindex]
                    # Keys are looked up by the string form of each value;
                    # an unknown value raises KeyError.
                    cur_chunk[idname] = cur_chunk[idname].apply(lambda x, m=mapping: m[str(x)])

                finputs = cur_chunk[self.flist].values
                idinputs = cur_chunk[self.idlist].values
                targets = cur_chunk["m:Click"].values

                yield {
                    "finputs": torch.FloatTensor(finputs),
                    "idinputs": torch.LongTensor(idinputs),
                    'labels': torch.LongTensor(targets)
                }
        finally:
            # Release the underlying file handle even when the consumer stops
            # iterating early or an error is raised.
            close = getattr(reader, 'close', None)
            if callable(close):
                close()

In [9]:
def collate_fn2(batch):
    """Merge a list of chunk dicts into one dict by concatenating along dim 0.

    Each element of ``batch`` must provide the keys ``'finputs'``,
    ``'idinputs'`` and ``'labels'``; the result has the same three keys, each
    holding the row-wise concatenation of the corresponding tensors.
    """
    merged = {}
    for key in ('finputs', 'idinputs', 'labels'):
        parts = [item[key] for item in batch]
        merged[key] = torch.cat(parts, dim=0)
    return merged

In [10]:
# Build the streaming dataset from the header file and the sample data file.
dataset2 = MyIterableDataset2(
    cfg,
    headerp='data/header.tsv',
    filep='data/sample.tsv',
)

In [11]:
# Each dataset item is already a whole chunk of rows, so a "batch" of 4 items
# is four chunks concatenated by collate_fn2.
data_loader = DataLoader(
    dataset=dataset2,
    batch_size=4,
    num_workers=1,
    pin_memory=True,
    collate_fn=collate_fn2,
)

In [14]:
# Sanity check: print the shape of every tensor in each collated batch.
for batch in data_loader:
    shapes = (batch['finputs'].size(), batch['idinputs'].size(), batch['labels'].size())
    print(*shapes)

torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) to

In [13]:
# Second pass over the loader: iterating again re-reads the file from the start.
for batch in data_loader:
    shapes = (batch['finputs'].size(), batch['idinputs'].size(), batch['labels'].size())
    print(*shapes)

torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) torch.Size([192, 8]) torch.Size([192])
torch.Size([192, 58]) to