In [1]:
from sqlalchemy import create_engine
from datetime import datetime
from datetime import timedelta
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
import feather
import os
import torch
import h5py
import pickle


# connect to db
user = 'root'
pswd = 'Curry5566'
host = '127.0.0.1'
port = '3306'
db = 'transport'
engine = create_engine(f"mysql+pymysql://{user}:{pswd}@{host}:{port}/{db}?charset=utf8")


def getEndDate(startDate: str, days: int) -> str:
    startDate += ' 00:00:00'
    end = str((datetime.strptime(startDate, '%Y-%m-%d %H:%M:%S') + timedelta(days=days)).replace(microsecond=0))
    return end

def getRollingMean(startDate: str, endDate: str) -> pd.DataFrame:
    """ Get the rolling mean from db
        ```text
        ---
        @Params
        startDate: The date for start, format='%Y-%m-%d'
        endDate: The date for end, format='%Y-%m-%d'

        ---
        @Returns
        DataFrame
        ```
    """
    sql  = " SELECT "
    sql += " 	STAC.VDID, STAC.RoadName, STAC.`Start`, STAC.`End`, "
    sql += " 	STAC.RoadDirection, DYMC.Speed, DYMC.Occupancy, DYMC.Volume, "
    sql += " 	STAC.LocationMile, DYMC.DataCollectTime "
    sql += " FROM ( "
    sql += " 	SELECT "
    sql += " 		VDSTC.id, VDSTC.VDID, ROAD.RoadName, SEC.`Start`, SEC.`End`, "
    sql += " 		VDSTC.RoadDirection, VDSTC.LocationMile "
    sql += " 	FROM vd_static_n5 VDSTC "
    sql += " 	JOIN road_info ROAD ON VDSTC.RoadInfoID = ROAD.id "
    sql += " 	JOIN section_info SEC ON ROAD.id = SEC.RoadInfoID "
    sql += " 	AND VDSTC.LocationMile >= SEC.StartKM "
    sql += " 	AND VDSTC.LocationMile <= SEC.EndKM "
    sql += " 	WHERE VDSTC.Mainlane = 1 "
    sql += " ) STAC JOIN ( "
    sql += " 	SELECT "
    sql += " 		VdStaticID, "
    sql += " 		CASE "
    sql += " 			WHEN MIN(Speed) = -99 THEN -99 "
    sql += " 			ELSE AVG(Speed) "
    sql += " 		END AS Speed,  "
    sql += " 		CASE "
    sql += " 			WHEN MIN(Occupancy) = -99 THEN -99 "
    sql += " 			ELSE AVG(Occupancy) "
    sql += " 		END AS Occupancy,  "
    sql += " 		CASE "
    sql += " 			WHEN MIN(Volume) = -99 THEN -99 "
    sql += " 			ELSE AVG(Volume) "
    sql += " 		END AS Volume, "
    sql += " 		MAX(DataCollectTime) AS DataCollectTime, "
    sql += " 		(UNIX_TIMESTAMP(DataCollectTime)-UNIX_TIMESTAMP(%(start)s)) DIV 300 "
    sql += " 	FROM vd_dynamic_detail_n5_202301 "
    sql += " 	WHERE id BETWEEN ( "
    sql += " 		SELECT id FROM vd_dynamic_detail_n5_202301 "
    sql += " 		WHERE DataCollectTime = %(start)s "
    sql += " 		ORDER BY id LIMIT 1 "
    sql += " 	) AND ( "
    sql += " 		SELECT id FROM vd_dynamic_detail_n5_202301 "
    sql += " 		WHERE DataCollectTime < %(end)s "
    sql += " 		ORDER BY id DESC LIMIT 1 "
    sql += " 	) "
    sql += " 	GROUP BY VdStaticID, (UNIX_TIMESTAMP(DataCollectTime)-UNIX_TIMESTAMP(%(start)s)) DIV 300 "
    sql += " ) DYMC ON STAC.id = DYMC.VdStaticID "
    sql += " ORDER BY STAC.RoadDirection, STAC.LocationMile, DYMC.DataCollectTime; "

    df = pd.read_sql(sql, con=engine, params={'start': startDate, 'end': endDate})
    engine.dispose()
    return df.sort_values(by=['RoadDirection','DataCollectTime','LocationMile']).reset_index(drop=True)

def getRollingMeanDaily(selectDate: str) -> pd.DataFrame:
    sql  = " SELECT "
    sql += " 	STAC.VDID, STAC.RoadName, STAC.`Start`, STAC.`End`, "
    sql += " 	STAC.RoadDirection, DYMC.Speed, DYMC.Occupancy, DYMC.Volume, "
    sql += " 	STAC.ActualLaneNum, STAC.LocationMile, STAC.isTunnel, DYMC.DataCollectTime "
    sql += " FROM ( "
    sql += " 	SELECT "
    sql += " 		VDSTC.id, VDSTC.VDID, ROAD.RoadName, SEC.`Start`, SEC.`End`, "
    sql += " 		VDSTC.ActualLaneNum, VDSTC.RoadDirection, VDSTC.LocationMile, "
    sql += "        CASE "
    sql += " 	        WHEN VDSTC.RoadDirection = 'S' AND VDSTC.LocationMile BETWEEN 15.203 AND 28.128 THEN 1 "
    sql += " 	        WHEN VDSTC.RoadDirection = 'N' AND VDSTC.LocationMile BETWEEN 15.179 AND 28.134 THEN 1 "
    sql += " 	        ELSE 0 "
    sql += "        END AS isTunnel "
    sql += " 	FROM fwy_n5.vd_static_2023 VDSTC "
    sql += " 	JOIN transport.road_info ROAD ON VDSTC.RoadInfoID = ROAD.id "
    sql += " 	JOIN transport.section_info SEC ON ROAD.id = SEC.RoadInfoID "
    sql += " 	AND VDSTC.LocationMile >= SEC.StartKM "
    sql += " 	AND VDSTC.LocationMile <= SEC.EndKM "
    sql += " 	WHERE VDSTC.Mainlane = 1 "
    sql += " ) STAC JOIN ( "
    sql += " 	SELECT "
    sql += " 		VdStaticID, "
    sql += " 		CASE "
    sql += " 			WHEN MIN(Speed) = -99 THEN -99 "
    sql += " 			ELSE AVG(Speed) "
    sql += " 		END AS Speed,  "
    sql += " 		CASE "
    sql += " 			WHEN MIN(Occupancy) = -99 THEN -99 "
    sql += " 			ELSE AVG(Occupancy) "
    sql += " 		END AS Occupancy,  "
    sql += " 		CASE "
    sql += " 			WHEN MIN(Volume) = -99 THEN -99 "
    sql += " 			ELSE AVG(Volume) "
    sql += " 		END AS Volume, "
    sql += " 		MAX(DataCollectTime) AS DataCollectTime, "
    sql += " 		(UNIX_TIMESTAMP(DataCollectTime)-UNIX_TIMESTAMP(%(selectDate)s)) DIV 300 "
    sql += " 	FROM fwy_n5.vd_dynamic_detail_{} ".format(selectDate.replace('-',''))
    sql += " 	GROUP BY VdStaticID, (UNIX_TIMESTAMP(DataCollectTime)-UNIX_TIMESTAMP(%(selectDate)s)) DIV 300 "
    sql += " ) DYMC ON STAC.id = DYMC.VdStaticID "
    sql += " ORDER BY STAC.RoadDirection, STAC.LocationMile, DYMC.DataCollectTime; "

    df = pd.read_sql(sql, con=engine, params={'selectDate': selectDate})
    engine.dispose()
    return df.sort_values(by=['RoadDirection','DataCollectTime','LocationMile']).reset_index(drop=True)

def groupVDs(df: pd.DataFrame, each: int) -> dict:
    """ Get the dict of VD groups
        ```text
        ---
        @Params
        df: DataFrame which is referenced by.
        each: The quantity of VDs would be considered as a group.

        ---
        @Returns
        vdGroups: The keys are the VDs we focus on, and the values are the collections of VDs which are correlated corresponding to the keys.
        ```
    """
    vdGroups = {}
    lb = each // 2
    ub = each - (each // 2)
    for vdid in df['VDID'].unique():
        vdGroups.setdefault(f"{vdid}", [])
    for no, vdid in enumerate(df['VDID'].unique()):
        startIdx = max(no-lb, 0)
        endIdx = min(no+ub, len(df['VDID'].unique())-1)
        vdGroups[f"{vdid}"] += list(df['VDID'].unique()[startIdx:no]) + list(df['VDID'].unique()[no:endIdx])

    delList = []
    for k in vdGroups.keys():
        if (len(vdGroups[k]) != each):
            delList.append(k)
    for k in delList:
        del vdGroups[k]
    
    return vdGroups

def genSamples(df: pd.DataFrame, vdGroups: dict, groupKey: str, each: int, timeWindow: int = 30) -> tuple:
    """ Generate samples for each traffic data (speed, volume, and occupancy)
        ```text
        ---
        @Params
        df: 
        vdGroups: The outpur of groupVDs(),
        groupKey: The key of vdGroups,
        each: The quantity of VDs would be considered as a group,
        timeWindow: The length of period we consider, and the default value is 30 (minutes).

        ---
        @Returns
        speeds: list with each item as a tuple, all of them are represented (X,y).
        vols: list with each item as a tuple, all of them are represented (X,y).
        occs: list with each item as a tuple, all of them are represented (X,y).
        ```
    """
    speeds, vols, occs = [], [], []
    tmpDf = df.loc[(df['VDID'].isin(vdGroups[f"{groupKey}"]))].sort_values(by=['LocationMile', 'DataCollectTime'])

    indices = [x for x in range(0, tmpDf.shape[0]+1, tmpDf.shape[0]//each)]
    speedMatx = np.zeros((each, tmpDf.shape[0]//each))
    volMatx = np.zeros((each, tmpDf.shape[0]//each))
    occMatx = np.zeros((each, tmpDf.shape[0]//each))
    for i, j, k in zip(range(each), indices[:-1], indices[1:]):
        speedMatx[i] += tmpDf.iloc[j:k,:]['Speed'].to_numpy()
        volMatx[i] += tmpDf.iloc[j:k,:]['Volume'].to_numpy()
        occMatx[i] += tmpDf.iloc[j:k,:]['Occupancy'].to_numpy()

    sliceLen = int((timeWindow / 5) + 1)
    for x in range(speedMatx.shape[1]//sliceLen*sliceLen-(sliceLen-1)):
        speeds.append((speedMatx[:,x:x+sliceLen][:,:-1], speedMatx[:,x:x+sliceLen][:,[-1]]))
        vols.append((volMatx[:,x:x+sliceLen][:,:-1], volMatx[:,x:x+sliceLen][:,[-1]]))
        occs.append((occMatx[:,x:x+sliceLen][:,:-1], occMatx[:,x:x+sliceLen][:,[-1]]))
    
    return speeds, vols, occs

def genTensors(speeds: list, vols: list) -> list:
    """ Generate torch.Tensors.
        The sizes of the tensors are `[batch, 2, each, 6]`, and `each` depends on how many VDs regarded as a group.
    """
    dataCollection = []
    for s, v in zip(speeds, vols):
        s = torch.tensor(s, dtype=torch.float).unsqueeze(0).unsqueeze(0)
        v = torch.tensor(v, dtype=torch.float).unsqueeze(0).unsqueeze(0)
        dataCollection.append(torch.concat([s, v], dim=1))
    return dataCollection

def train_test_split(speedCollection, volCollection, train_size=None, test_size=None, random_number=42):
    np.random.seed(random_number)
    if train_size:
        trainDataIdx = np.random.choice(
            len(speedCollection),
            int(train_size * len(speedCollection)),
            replace=False
        )
        testDataIdx = set([i for i in range(len(speedCollection))]) -\
                      set(trainDataIdx)
    
    elif test_size:
        testDataIdx = np.random.choice(
            len(speedCollection),
            int(test_size * len(speedCollection)),
            replace=False
        )
        trainDataIdx = set([i for i in range(len(speedCollection))]) -\
                       set(testDataIdx)
        
    trainSpeed = list(pd.Series(speedCollection)[list(trainDataIdx)])
    trainVol = list(pd.Series(volCollection)[list(trainDataIdx)])
    testSpeed = list(pd.Series(speedCollection)[list(testDataIdx)])
    testVol = list(pd.Series(volCollection)[list(testDataIdx)])

    return trainSpeed, trainVol, testSpeed, testVol

In [None]:
# # 取得一年份資料
# firstDate = list(map(lambda x: datetime.strftime(x, '%Y-%m-%d'), list(pd.date_range('2023-01-01', '2023-12-31', freq='MS'))))
# lastDate = list(map(lambda x: datetime.strftime(x, '%Y-%m-%d'), list(pd.date_range('2023-01-01', '2023-12-31', freq='ME'))))
# for first, last in zip(firstDate, lastDate):
#     dataframes = []
#     dateList = list(map(lambda x: datetime.strftime(x, '%Y-%m-%d'), list(pd.date_range(first, last))))
#     for date in dateList:
#         print(date)
#         dataframes.append(getRollingMeanDaily(date))
#     dataframes = pd.concat(dataframes).reset_index(drop=True)
#     display(dataframes)
#     feather.write_dataframe(dataframes, dest=f"./nfb2023/{date[:7].replace('-','')}.feather")

In [None]:
monthlyStarts = list(map(lambda x: datetime.strftime(x, '%Y-%m-%d'), list(pd.date_range('2023-02-01', '2023-02-28', freq='MS'))))
monthlyEnds = list(map(lambda x: datetime.strftime(x, '%Y-%m-%d'), list(pd.date_range('2023-02-01', '2023-02-28', freq='ME'))))

for start, end in zip(monthlyStarts, monthlyEnds):
    dataframes = []
    dateList = list(map(lambda x: datetime.strftime(x, '%Y-%m-%d'), list(pd.date_range(start, end))))
    print(start[:7].replace('-',''))
    for date in dateList:
        print(date)
        dataframes.append(getRollingMeanDaily(date))
    dataframes = pd.concat(dataframes).reset_index(drop=True)
    # display(dataframes)
    feather.write_dataframe(dataframes, dest=f"./nfb2023/{start[:7].replace('-','')}.feather")

In [None]:
os.listdir('./nfb2023')

In [2]:
df = []
for filename in os.listdir('./nfb2023'):
    monthlyDf = feather.read_dataframe(f"./nfb2023/{filename}")
    if (len(df) == 0):
        df.append(monthlyDf)
    else:
        currDf = pd.concat(df).reset_index(drop=True)
        monthlyDf = monthlyDf.loc[monthlyDf['VDID'].isin(set(currDf['VDID']))]
        df.append(monthlyDf)
df = pd.concat(df).reset_index(drop=True)
df

Unnamed: 0,VDID,RoadName,Start,End,RoadDirection,Speed,Occupancy,Volume,ActualLaneNum,LocationMile,DataCollectTime
0,VD-N5-N-0.178-M-LOOP,國道5號,南港系統交流道,石碇交流道,N,-99.0000,-99.0000,-99.0000,2,0.178,2023-01-01 00:04:00
1,VD-N5-N-0.706-M-LOOP,國道5號,南港系統交流道,石碇交流道,N,-99.0000,-99.0000,-99.0000,2,0.706,2023-01-01 00:04:00
2,VD-N5-N-1.068-M-LOOP,國道5號,南港系統交流道,石碇交流道,N,-99.0000,-99.0000,-99.0000,2,1.068,2023-01-01 00:04:00
3,VD-N5-N-2.068-M-PS-LOOP,國道5號,南港系統交流道,石碇交流道,N,94.6000,4.1000,5.3000,2,2.068,2023-01-01 00:04:00
4,VD-N5-N-3.198-M-LOOP,國道5號,南港系統交流道,石碇交流道,N,90.8000,4.8000,5.1000,2,3.198,2023-01-01 00:04:00
...,...,...,...,...,...,...,...,...,...,...,...
9819329,VD-N5-S-41.298-M-LOOP,國道5號,宜蘭交流道,羅東交流道,S,92.5000,1.6667,2.6667,2,41.298,2023-12-31 23:57:00
9819330,VD-N5-S-44.202-M-LOOP,國道5號,宜蘭交流道,羅東交流道,S,84.3333,2.3333,3.3333,2,44.202,2023-12-31 23:57:00
9819331,VD-N5-S-46.566-M-LOOP,國道5號,宜蘭交流道,羅東交流道,S,65.8333,1.8333,1.5000,2,46.566,2023-12-31 23:57:00
9819332,VD-N5-S-48.040-M-LOOP,國道5號,羅東交流道,蘇澳交流道,S,62.6667,0.6667,1.0000,2,48.040,2023-12-31 23:57:00


In [3]:
df.to_csv('./df.csv', index=False, encoding='utf_8_sig')

In [None]:
# # TODO: main
# if __name__ == '__main__':
#     # # read feather files to get dataframes
#     # startDate = '2023-01-01'
#     # endDate = getEndDate(startDate, days=10)
#     # # df = getRollingMean(startDate, endDate)
#     # df = feather.read_dataframe('./20230101-20230110.feather').sort_values(by=['RoadDirection','DataCollectTime','LocationMile']).reset_index(drop=True)
    
#     # Northbound data
#     northDf = df.loc[df['RoadDirection']=='N'].reset_index(drop=True)
#     each = 3
#     vdGroups = groupVDs(northDf, each)    
#     speedDataset, volDataset, occDataset = [], [], []
#     for groupKey in vdGroups.keys():
#         speeds, vols, occs = genSamples(northDf, vdGroups, groupKey, each, timeWindow=30)
#         speedDataset.append(speeds)
#         volDataset.append(vols)
#         occDataset.append(occs)

#     # Southbound data
#     southDf = df.loc[df['RoadDirection']=='S'].reset_index(drop=True)
#     each = 3
#     vdGroups = groupVDs(southDf, each)    
#     speedDataset, volDataset, occDataset = [], [], []
#     for groupKey in vdGroups.keys():
#         speeds, vols, occs = genSamples(southDf, vdGroups, groupKey, each, timeWindow=30)
#         speedDataset.append(speeds)
#         volDataset.append(vols)
#         occDataset.append(occs)

In [None]:
## test cell for missing data
# This part is genSamples()
each = 3
timeWindow = 30

# df = feather.read_dataframe("./nfb2023/202305.feather")
northDf = df.loc[df['RoadDirection']=='N'].reset_index(drop=True)
vdGroups = groupVDs(northDf, each)
groupKey = 'VD-N5-N-1.068-M-LOOP'



speeds, vols, occs = [], [], []
tmpDf = df.loc[(df['VDID'].isin(vdGroups[f"{groupKey}"]))].sort_values(by=['LocationMile', 'DataCollectTime'])

indices = [x for x in range(0, tmpDf.shape[0]+1, tmpDf.shape[0]//each)]
mileMatx = np.zeros((each, tmpDf.shape[0]//each))
speedMatx = np.zeros((each, tmpDf.shape[0]//each))
volMatx = np.zeros((each, tmpDf.shape[0]//each))
occMatx = np.zeros((each, tmpDf.shape[0]//each))
for i, j, k in zip(range(each), indices[:-1], indices[1:]):
    mileMatx[i] += tmpDf.iloc[j:k,:]['LocationMile'].to_numpy()
    speedMatx[i] += tmpDf.iloc[j:k,:]['Speed'].to_numpy()
    volMatx[i] += tmpDf.iloc[j:k,:]['Volume'].to_numpy()
    occMatx[i] += tmpDf.iloc[j:k,:]['Occupancy'].to_numpy()

# sliceLen = int((timeWindow / 5) + 1)
# for x in range(speedMatx.shape[1]//sliceLen*sliceLen-(sliceLen-1)):
#     speeds.append((speedMatx[:,x:x+sliceLen][:,:-1], speedMatx[:,x:x+sliceLen][:,[-1]]))
#     vols.append((volMatx[:,x:x+sliceLen][:,:-1], volMatx[:,x:x+sliceLen][:,[-1]]))
#     occs.append((occMatx[:,x:x+sliceLen][:,:-1], occMatx[:,x:x+sliceLen][:,[-1]]))

In [None]:
mileMatx

In [None]:
testDf['LocationMile'].nunique()

In [None]:
testDf = northDf.sort_values(by=['LocationMile', 'DataCollectTime'])
n = testDf['LocationMile'].nunique()

indices__ = [x for x in range(0, testDf.shape[0]+1, testDf.shape[0]//n)]
mileMatx__ = np.zeros((n, testDf.shape[0]//n))
speedMatx__ = np.zeros((n, testDf.shape[0]//n))
volMatx__ = np.zeros((n, testDf.shape[0]//n))
occMatx__ = np.zeros((n, testDf.shape[0]//n))
for i, j, k in zip(range(n), indices__[:-1], indices__[1:]):
    mileMatx__[i] += testDf.iloc[j:k,:]['LocationMile'].to_numpy()
    speedMatx__[i] += testDf.iloc[j:k,:]['Speed'].to_numpy()
    volMatx__[i] += testDf.iloc[j:k,:]['Volume'].to_numpy()
    occMatx__[i] += testDf.iloc[j:k,:]['Occupancy'].to_numpy()

In [None]:
m = mileMatx__[:5,:5].copy()
m

In [None]:
m - np.ones(m.shape) * m[0]

In [None]:
speedMatx__[:5,:5]

In [None]:
volMatx__[:5,:5]

In [None]:
speedMatx[:,:5]

In [None]:
volMatx[:,:5]

In [None]:
occMatx[:,:5]

## Create Dataset for CNN (Short-term Prediction)

In [3]:
EACH = 3
speedCollection, volCollection, occCollection = [], [], []

# Northbound data
northDf = df.loc[df['RoadDirection']=='N'].reset_index(drop=True)
print(f"northDf start grouping: {datetime.strftime(datetime.now(), '%Y-%m-%d %H:%M:%S')}")
northVDGrps = groupVDs(northDf, each=EACH)
print(f"northDf end grouping: {datetime.strftime(datetime.now(), '%Y-%m-%d %H:%M:%S')}")
for groupKey in northVDGrps.keys():
    print(groupKey)
    speeds, vols, occs = genSamples(northDf, northVDGrps, groupKey, each=EACH, timeWindow=30)
    speedCollection += speeds
    volCollection += vols
    occCollection += occs

# Southbound data
southDf = df.loc[df['RoadDirection']=='S'].reset_index(drop=True)
print(f"southDf start grouping: {datetime.strftime(datetime.now(), '%Y-%m-%d %H:%M:%S')}")
southVDGrps = groupVDs(southDf, each=EACH)
print(f"southDf end grouping: {datetime.strftime(datetime.now(), '%Y-%m-%d %H:%M:%S')}")
# speedDataset, volDataset, occDataset = [], [], []
for groupKey in southVDGrps.keys():
    print(groupKey)
    speeds, vols, occs = genSamples(southDf, southVDGrps, groupKey, each=EACH, timeWindow=30)
    speedCollection += speeds
    volCollection += vols
    occCollection += occs

northDf start grouping: 2024-03-27 23:09:53
northDf end grouping: 2024-03-27 23:11:12
VD-N5-N-0.706-M-LOOP
VD-N5-N-1.068-M-LOOP
VD-N5-N-2.068-M-PS-LOOP
VD-N5-N-3.198-M-LOOP
VD-N5-N-3.943-M-LOOP
VD-N5-N-5.883-M-LOOP
VD-N5-N-7.107-M-LOOP
VD-N5-N-8.011-M-LOOP
VD-N5-N-9.840-M-LOOP
VD-N5-N-10.866-M-PS-LOOP
VD-N5-N-11.903-M-PS-LOOP
VD-N5-N-12.922-M-LOOP
VD-N5-N-13.707-M-LOOP
VD-N5-N-14.550-M-LOOP
VD-N5-N-15.488-M-LOOP
VD-N5-N-16.196-M-LOOP
VD-N5-N-16.900-M-PS-LOOP
VD-N5-N-17.608-M-LOOP
VD-N5-N-18.313-M-PS-LOOP
VD-N5-N-19.012-M-LOOP
VD-N5-N-19.689-M-PS-LOOP
VD-N5-N-20.412-M-LOOP
VD-N5-N-21.055-M-PS-LOOP
VD-N5-N-21.808-M-LOOP
VD-N5-N-22.510-M-PS-LOOP
VD-N5-N-23.209-M-LOOP
VD-N5-N-23.911-M-PS-LOOP
VD-N5-N-24.677-M-LOOP
VD-N5-N-25.310-M-PS-LOOP
VD-N5-N-26.007-M-LOOP
VD-N5-N-26.705-M-PS-LOOP
VD-N5-N-27.468-M-LOOP
VD-N5-N-27.779-M-LOOP
VD-N5-N-28.420-M-LOOP
VD-N5-N-29.000-M-LOOP
VD-N5-N-30.100-M-LOOP
VD-N5-N-30.551-M-LOOP
VD-N5-N-31.540-M-LOOP
VD-N5-N-32.120-M-LOOP
VD-N5-N-32.743-M-LOOP
VD-N5-N-33

In [25]:
class CNNDataset(Dataset):
    def __init__(
            self,
            speed_data: list = None,
            volume_data: list = None,
            load_ckpt: bool = None,
            mode: str = None,
            ckpt_dir: str = './datasets/cnndataset'
    ) -> None:
        if (speed_data):
            self.speedFeature = [speed_data[x][0] for x in range(len(speed_data))]
            self.volFeature = [volume_data[x][0] for x in range(len(volume_data))]
            self.speedLabels = [speed_data[x][1][[1],:] for x in range(len(speed_data))]
            self.volLabels = [volume_data[x][1][[1],:] for x in range(len(volume_data))]
        
        else:
            if (load_ckpt) and (mode == 'train'):
                with h5py.File(f"{ckpt_dir}/{mode}/{mode}_speed_feature.h5", 'r') as file:
                    self.speedFeature = file[f"{mode}_speed_feature"][:]
                with h5py.File(f"{ckpt_dir}/{mode}/{mode}_volume_feature.h5", 'r') as file:
                    self.volFeature = file[f"{mode}_volume_feature"][:]
                with h5py.File(f"{ckpt_dir}/{mode}/{mode}_speed_label.h5", 'r') as file:
                    self.speedLabels = file[f"{mode}_speed_label"][:]
                with h5py.File(f"{ckpt_dir}/{mode}/{mode}_volume_label.h5", 'r') as file:
                    self.volLabels = file[f"{mode}_volume_label"][:]
            
            elif (load_ckpt) and (mode == 'test'):
                with h5py.File(f"{ckpt_dir}/{mode}/{mode}_speed_feature.h5", 'r') as file:
                    self.speedFeature = file[f"{mode}_speed_feature"][:]
                with h5py.File(f"{ckpt_dir}/{mode}/{mode}_volume_feature.h5", 'r') as file:
                    self.volFeature = file[f"{mode}_volume_feature"][:]
                with h5py.File(f"{ckpt_dir}/{mode}/{mode}_speed_label.h5", 'r') as file:
                    self.speedLabels = file[f"{mode}_speed_label"][:]
                with h5py.File(f"{ckpt_dir}/{mode}/{mode}_volume_label.h5", 'r') as file:
                    self.volLabels = file[f"{mode}_volume_label"][:]

    def __len__(self) -> int:
        return len(self.speedFeature)
    
    def __getitem__(self, idx: int) -> torch.Tensor:
        f1 = torch.tensor(self.speedFeature[idx], dtype=torch.float).unsqueeze(0).unsqueeze(0)
        f2 = torch.tensor(self.volFeature[idx], dtype=torch.float).unsqueeze(0).unsqueeze(0)
        l1 = torch.tensor(self.speedLabels[idx], dtype=torch.float).unsqueeze(0)
        l2 = torch.tensor(self.volLabels[idx], dtype=torch.float).unsqueeze(0)
        feature = torch.concat([f1, f2], dim=1)
        label = torch.concat([l1, l2], dim=1)
        return feature, label

### Create datasets

In [5]:
trainSpeed, trainVol, testSpeed, testVol =\
    train_test_split(speedCollection, volCollection, test_size=0.2)

In [6]:
trainDataset = CNNDataset(speed_data=trainSpeed, volume_data=trainVol)
testDataset = CNNDataset(speed_data=testSpeed, volume_data=testVol)

In [27]:
tedo = CNNDataset(speed_data=testSpeed, volume_data=testVol)

In [28]:
len(tedo)

1838408

### Save dataset as `.h5`

In [15]:
with h5py.File('./datasets/cnndataset/train_speed_feature.h5', 'w') as f:
    f.create_dataset('train_speed_feature', data=trainDataset.speedFeature)

with h5py.File('./datasets/cnndataset/train_volume_feature.h5', 'w') as f:
    f.create_dataset('train_volume_feature', data=trainDataset.volFeature)

with h5py.File('./datasets/cnndataset/train_speed_label.h5', 'w') as f:
    f.create_dataset('train_speed_label', data=trainDataset.speedLabels)

with h5py.File('./datasets/cnndataset/train_volume_label.h5', 'w') as f:
    f.create_dataset('train_volume_label', data=trainDataset.volLabels)

In [16]:
with h5py.File('./datasets/cnndataset/test_speed_feature.h5', 'w') as f:
    f.create_dataset('test_speed_feature', data=testDataset.speedFeature)

with h5py.File('./datasets/cnndataset/test_volume_feature.h5', 'w') as f:
    f.create_dataset('test_volume_feature', data=testDataset.volFeature)

with h5py.File('./datasets/cnndataset/test_speed_label.h5', 'w') as f:
    f.create_dataset('test_speed_label', data=testDataset.speedLabels)

with h5py.File('./datasets/cnndataset/test_volume_label.h5', 'w') as f:
    f.create_dataset('test_volume_label', data=testDataset.volLabels)

### Directly load dataset

You can also load datasets from `.h5` file if you have saved them.

In [26]:
trainDataset = CNNDataset(load_ckpt=True, mode='train', ckpt_dir='./datasets/cnndataset')
testDataset = CNNDataset(load_ckpt=True, mode='test', ckpt_dir='./datasets/cnndataset')