In [1]:
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
import json
import argparse
from torch.utils.data import DataLoader, IterableDataset
import torch
import torch.nn as nn

In [117]:
class MyIterableDataset(torch.utils.data.IterableDataset):
    """Stream a TSV file as numpy chunks, replaying it ``epoch`` times.

    Each yielded item is the ``.values`` array of one pandas chunk of up to
    ``chunksize`` rows read from ``path``; one full pass over the file is one
    epoch. The file is reopened at the start of every epoch and every call to
    ``__iter__``, so an iteration abandoned part-way does not make the next
    one resume mid-file.

    The file is not sharded across DataLoader workers, so this dataset must
    be used with ``num_workers`` of 0 or 1.

    Args:
        epoch: Number of passes over the file (a value <= 0 yields nothing).
        path: TSV file to stream.
        header_path: TSV whose parsed contents are stored on ``self.header``
            (it is not used when parsing ``path``).
        chunksize: Maximum number of rows per yielded chunk.
    """

    def __init__(self, epoch, path='data/sample2.tsv', header_path='data/header.tsv', chunksize=48):
        super().__init__()
        self.path = path
        self.chunksize = chunksize
        self.epoch = epoch
        self.header = pd.read_csv(header_path, sep='\t')
        self._reader = None
        self.dfiter = None
        # Open and immediately close once so a missing/unreadable file still
        # fails at construction, without keeping an open file handle on the
        # object (file handles are not picklable, and DataLoader pickles the
        # dataset when worker processes are started with `spawn`).
        self.init_reader()
        self.close()

    def init_reader(self):
        """(Re)open the chunked reader on ``self.path``, closing any previous one.

        Afterwards ``self.dfiter`` iterates over DataFrame chunks of up to
        ``self.chunksize`` rows.
        """
        self.close()
        self._reader = pd.read_csv(self.path, sep='\t', iterator=True, chunksize=self.chunksize)
        self.dfiter = iter(self._reader)

    def close(self):
        """Release the current reader, if any. Safe to call more than once."""
        reader, self._reader, self.dfiter = self._reader, None, None
        if reader is not None:
            reader.close()

    def __iter__(self):
        """Yield ``chunk.values`` for every chunk of the file, ``self.epoch`` times.

        Raises:
            AssertionError: when run inside a DataLoader with more than one worker.
            Any pandas parse/IO error encountered while reading the file.
        """
        worker_info = torch.utils.data.get_worker_info()
        # worker_info is None for single-process loading. Inside a worker,
        # every worker would replay the whole file and duplicate samples.
        if worker_info is not None:
            assert worker_info.num_workers == 1, (
                'MyIterableDataset does not shard its input; use num_workers <= 1')
        remaining = self.epoch
        try:
            while remaining > 0:
                self.init_reader()
                # Only genuine end-of-file ends an epoch; parse/IO errors
                # propagate instead of being swallowed as "epoch finished".
                for chunk in self.dfiter:
                    yield chunk.values
                remaining -= 1
        finally:
            # Runs on exhaustion, on error, and when the consumer stops early.
            self.close()

In [118]:
def collate_fn(batch):
    """Stack the arrays of one DataLoader batch into a single array.

    ``batch`` is the sequence of items the DataLoader gathered (here, the
    per-chunk ``.values`` arrays). They are joined along the first axis, so
    the result has as many rows as all the chunks together.
    """
    # np.concatenate joins along axis 0 unless an axis is given.
    merged = np.concatenate(batch)
    return merged

In [119]:
# Two full passes (epochs) over the data file; each item is one chunk's numpy array.
dataset = MyIterableDataset(epoch=2)

In [120]:
# DataLoader settings:
# - batch_size=4: collate_fn concatenates 4 consecutive chunks into one batch.
# - num_workers=1: MyIterableDataset asserts this when running inside a worker,
#   because it does not split the file between workers.
# - pin_memory=True: NOTE(review): collate_fn returns numpy arrays, while
#   pinning applies to torch tensors, so this is likely a no-op — confirm.
data_loader = DataLoader(
    dataset,
    batch_size=4,
    num_workers=1,
    pin_memory=True,
    collate_fn=collate_fn,
)

In [121]:
for data in data_loader:
    # One collated batch per step: print the array, then its shape.
    print(f"{data} {data.shape}")

[[    0 17307]
 [    1 10685]
 [    2 17929]
 [    3 13744]
 [    4 14940]
 [    5 16733]
 [    6 19340]
 [    7 17724]
 [    8 17912]
 [    9 12115]
 [   10 16601]
 [   11 11839]
 [   12 17908]
 [   13 14802]
 [   14 16551]
 [   15 16418]
 [   16 19340]
 [   17 18807]
 [   18 16137]
 [   19 18797]
 [   20 15484]
 [   21 17620]
 [   22 18552]
 [   23 18418]
 [   24 18797]
 [   25 20203]
 [   26 13247]
 [   27 20203]
 [   28  5872]
 [   29 17027]
 [   30 17385]
 [   31 19459]
 [   32 17225]
 [   33 16611]
 [   34 17154]
 [   35 16613]
 [   36 15526]
 [   37 14084]
 [   38  8792]
 [   39 18418]
 [   40 14413]
 [   41 16955]
 [   42 18807]
 [   43 14950]
 [   44 13334]
 [   45 11898]
 [   46  9576]
 [   47 10544]
 [   48 16066]
 [   49 15580]
 [   50 16766]
 [   51 17749]
 [   52 17154]
 [   53 17724]
 [   54 16601]
 [   55 18707]
 [   56 20203]
 [   57 18773]
 [   58 18521]
 [   59 15002]
 [   60 18864]
 [   61 17972]
 [   62 18747]
 [   63 16727]
 [   64 15154]
 [   65 17189]
 [   66 17

[[ 8256 16073]
 [ 8257 17972]
 [ 8258 18436]
 [ 8259 17494]
 [ 8260 19340]
 [ 8261 18864]
 [ 8262 14404]
 [ 8263 16178]
 [ 8264 15390]
 [ 8265 18773]
 [ 8266 17592]
 [ 8267 15655]
 [ 8268 16568]
 [ 8269 19459]
 [ 8270 17592]
 [ 8271 18432]
 [ 8272 16303]
 [ 8273 18818]
 [ 8274 17413]
 [ 8275 17154]
 [ 8276 17313]
 [ 8277 20203]
 [ 8278 16814]
 [ 8279 17912]
 [ 8280 18436]
 [ 8281 16592]
 [ 8282 16479]
 [ 8283 19459]
 [ 8284 19340]
 [ 8285 13556]
 [ 8286 15576]
 [ 8287 16568]
 [ 8288 14559]
 [ 8289 13981]
 [ 8290 19459]
 [ 8291 17762]
 [ 8292 17971]
 [ 8293 17971]
 [ 8294 13245]
 [ 8295 18818]
 [ 8296 10798]
 [ 8297 15265]
 [ 8298 13197]
 [ 8299 17045]
 [ 8300 19451]
 [ 8301 17762]
 [ 8302 16784]
 [ 8303 15892]
 [ 8304 18432]
 [ 8305 17718]
 [ 8306 15496]
 [ 8307  9148]
 [ 8308 16251]
 [ 8309 16238]
 [ 8310 16300]
 [ 8311 17032]
 [ 8312 15179]
 [ 8313 17963]
 [ 8314 18747]
 [ 8315 20203]
 [ 8316 18818]
 [ 8317 13752]
 [ 8318  8879]
 [ 8319 17963]
 [ 8320 10332]
 [ 8321 20203]
 [ 8322 16

[[ 4944 16530]
 [ 4945 18521]
 [ 4946 20203]
 [ 4947 15059]
 [ 4948 17589]
 [ 4949 17032]
 [ 4950 16601]
 [ 4951 14386]
 [ 4952 17189]
 [ 4953 18218]
 [ 4954 18797]
 [ 4955 12071]
 [ 4956 14044]
 [ 4957 12674]
 [ 4958 15644]
 [ 4959 18457]
 [ 4960 17553]
 [ 4961 12509]
 [ 4962 16841]
 [ 4963 16348]
 [ 4964 17718]
 [ 4965 20203]
 [ 4966 13810]
 [ 4967 17247]
 [ 4968 15149]
 [ 4969 17749]
 [ 4970 14432]
 [ 4971 17448]
 [ 4972 18658]
 [ 4973 15446]
 [ 4974 17313]
 [ 4975 16097]
 [ 4976 14954]
 [ 4977 18436]
 [ 4978 17971]
 [ 4979 18460]
 [ 4980 16814]
 [ 4981 18828]
 [ 4982 19459]
 [ 4983 10169]
 [ 4984 17592]
 [ 4985 15296]
 [ 4986  8266]
 [ 4987 18521]
 [ 4988 15240]
 [ 4989 14871]
 [ 4990 18818]
 [ 4991 17452]
 [ 4992 15522]
 [ 4993 17963]
 [ 4994 18552]
 [ 4995 18864]
 [ 4996 19459]
 [ 4997 18864]
 [ 4998 17718]
 [ 4999 12926]
 [ 5000 18432]
 [ 5001 17386]
 [ 5002 20203]
 [ 5003 17718]
 [ 5004 16592]
 [ 5005 11887]
 [ 5006 12436]
 [ 5007 16140]
 [ 5008 17908]
 [ 5009 17589]
 [ 5010 17

In [6]:
# Derive data/sample2.tsv — a row id ("index", created by reset_index) plus one
# feature column — which MyIterableDataset streams.
# Column names come from header.tsv (presumably sample.tsv has no header row).
# NOTE(review): this cell (In [6]) must run before the dataset cells above;
# move it up so Restart & Run All works on a fresh checkout.
header = pd.read_csv('data/header.tsv', sep='\t')
cdf = pd.read_csv('data/sample.tsv', sep='\t', names=header.columns).reset_index()
cdf[['index', 'Feature_7604_Impressions_log']].to_csv(
    'data/sample2.tsv', sep='\t', index=False)  # don't also write pandas' own index

In [None]:
class MyIterableDataset2(torch.utils.data.IterableDataset):
    """Stream a TSV file as numpy chunks, replaying it ``epoch`` times.

    TODO(review): this class duplicates MyIterableDataset; either make it
    diverge for a stated reason or delete it.

    Each yielded item is the ``.values`` array of one pandas chunk of up to
    ``chunksize`` rows read from ``path``; one full pass over the file is one
    epoch. The file is reopened at the start of every epoch and every call to
    ``__iter__``, so an iteration abandoned part-way does not make the next
    one resume mid-file. Use with ``num_workers`` of 0 or 1 only.

    Args:
        epoch: Number of passes over the file (a value <= 0 yields nothing).
        path: TSV file to stream.
        header_path: TSV whose parsed contents are stored on ``self.header``
            (it is not used when parsing ``path``).
        chunksize: Maximum number of rows per yielded chunk.
    """

    def __init__(self, epoch, path='data/sample2.tsv', header_path='data/header.tsv', chunksize=48):
        # Zero-argument super() binds to this class (the copy previously named
        # a different class and never called the parent initializer).
        super().__init__()
        self.path = path
        self.chunksize = chunksize
        self.epoch = epoch
        self.header = pd.read_csv(header_path, sep='\t')
        self._reader = None
        self.dfiter = None
        # Open and immediately close once so a missing/unreadable file still
        # fails at construction, without keeping an open (unpicklable) file
        # handle on the object.
        self.init_reader()
        self.close()

    def init_reader(self):
        """(Re)open the chunked reader on ``self.path``, closing any previous one."""
        self.close()
        self._reader = pd.read_csv(self.path, sep='\t', iterator=True, chunksize=self.chunksize)
        self.dfiter = iter(self._reader)

    def close(self):
        """Release the current reader, if any. Safe to call more than once."""
        reader, self._reader, self.dfiter = self._reader, None, None
        if reader is not None:
            reader.close()

    def __iter__(self):
        """Yield ``chunk.values`` for every chunk of the file, ``self.epoch`` times."""
        worker_info = torch.utils.data.get_worker_info()
        # worker_info is None for single-process loading. Inside a worker,
        # every worker would replay the whole file and duplicate samples.
        if worker_info is not None:
            assert worker_info.num_workers == 1, (
                'MyIterableDataset2 does not shard its input; use num_workers <= 1')
        remaining = self.epoch
        try:
            while remaining > 0:
                self.init_reader()
                # Only genuine end-of-file ends an epoch; parse/IO errors propagate.
                for chunk in self.dfiter:
                    yield chunk.values
                remaining -= 1
        finally:
            self.close()