In [1]:
from sqlalchemy import create_engine
from datetime import datetime
from datetime import timedelta
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from dotenv import load_dotenv
from toolkits.datapreparing import download_monthly_tables, collect_data
from toolkits.datasets import CNNDataset, train_test_split, load_next_5min
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
import feather
import os
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


load_dotenv('./.env')

# The DB connection is only needed to rebuild the raw monthly files from the
# database; the training cells below never reference `engine`. Creating the engine
# from an unset DB_ENGINE would raise and abort Restart & Run All, so in that case
# `engine` is left as None instead.
DB_URL = os.getenv('DB_ENGINE')
if DB_URL:
    engine = create_engine(DB_URL)
else:
    engine = None
    print("DB_ENGINE is not set; database access is disabled (engine = None).")


# def getEndDate(startDate: str, days: int) -> str:
#     startDate += ' 00:00:00'
#     end = str((datetime.strptime(startDate, '%Y-%m-%d %H:%M:%S') + timedelta(days=days)).replace(microsecond=0))
#     return end

# def getRollingMeanDaily(selectDate: str) -> pd.DataFrame:
#     sql  = " SELECT "
#     sql += " 	STAC.VDID, STAC.RoadName, STAC.`Start`, STAC.`End`, STAC.RoadDirection, "
#     sql += "    CASE "
#     sql += "        WHEN DYMC.Occupancy = 0 AND DYMC.Volume = 0 THEN 100 "
#     sql += "        ELSE DYMC.Speed "
#     sql += " 	END AS Speed, "
#     sql += "    DYMC.Occupancy, DYMC.Volume, "
#     sql += " 	STAC.ActualLaneNum, STAC.LocationMile, STAC.isTunnel, DYMC.DataCollectTime "
#     sql += " FROM ( "
#     sql += " 	SELECT "
#     sql += " 		VDSTC.id, VDSTC.VDID, ROAD.RoadName, SEC.`Start`, SEC.`End`, "
#     sql += " 		VDSTC.ActualLaneNum, VDSTC.RoadDirection, VDSTC.LocationMile, "
#     sql += "        CASE "
#     sql += " 	        WHEN VDSTC.RoadDirection = 'S' AND VDSTC.LocationMile BETWEEN 0.238 AND 0.694 THEN 1 "
#     sql += " 	        WHEN VDSTC.RoadDirection = 'N' AND VDSTC.LocationMile BETWEEN 0.235 AND 0.690 THEN 1 "
#     sql += " 	        WHEN VDSTC.RoadDirection = 'S' AND VDSTC.LocationMile BETWEEN 0.694 AND 3.481 THEN 1 "
#     sql += " 	        WHEN VDSTC.RoadDirection = 'N' AND VDSTC.LocationMile BETWEEN 0.795 AND 3.515 THEN 1 "
#     sql += " 	        WHEN VDSTC.RoadDirection = 'S' AND VDSTC.LocationMile BETWEEN 7.677 AND 7.893 THEN 1 "
#     sql += " 	        WHEN VDSTC.RoadDirection = 'N' AND VDSTC.LocationMile BETWEEN 7.646 AND 7.894 THEN 1 "
#     sql += " 	        WHEN VDSTC.RoadDirection = 'S' AND VDSTC.LocationMile BETWEEN 9.442 AND 13.303 THEN 1 "
#     sql += " 	        WHEN VDSTC.RoadDirection = 'N' AND VDSTC.LocationMile BETWEEN 9.457 AND 13.263 THEN 1 "
#     sql += " 	        WHEN VDSTC.RoadDirection = 'S' AND VDSTC.LocationMile BETWEEN 15.203 AND 28.128 THEN 1 "
#     sql += " 	        WHEN VDSTC.RoadDirection = 'N' AND VDSTC.LocationMile BETWEEN 15.179 AND 28.134 THEN 1 "
#     sql += " 	        ELSE 0 "
#     sql += "        END AS isTunnel "
#     sql += " 	FROM fwy_n5.vd_static_2023 VDSTC "
#     sql += " 	JOIN transport.road_info ROAD ON VDSTC.RoadInfoID = ROAD.id "
#     sql += " 	JOIN transport.section_info SEC ON ROAD.id = SEC.RoadInfoID "
#     sql += " 	AND VDSTC.LocationMile >= SEC.StartKM "
#     sql += " 	AND VDSTC.LocationMile <= SEC.EndKM "
#     sql += " 	WHERE VDSTC.Mainlane = 1 "
#     sql += " ) STAC JOIN ( "
#     sql += " 	SELECT "
#     sql += " 		VdStaticID, "
#     sql += " 		CASE "
#     sql += " 			WHEN MIN(Speed) = -99 THEN -99 "
#     sql += " 			ELSE AVG(Speed) "
#     sql += " 		END AS Speed,  "
#     sql += " 		CASE "
#     sql += " 			WHEN MIN(Occupancy) = -99 THEN -99 "
#     sql += " 			ELSE AVG(Occupancy) "
#     sql += " 		END AS Occupancy,  "
#     sql += " 		CASE "
#     sql += " 			WHEN MIN(Volume) = -99 THEN -99 "
#     sql += " 			ELSE AVG(Volume) "
#     sql += " 		END AS Volume, "
#     sql += " 		MAX(DataCollectTime) AS DataCollectTime, "
#     sql += " 		(UNIX_TIMESTAMP(DataCollectTime)-UNIX_TIMESTAMP(%(selectDate)s)) DIV 300 "
#     sql += " 	FROM fwy_n5.vd_dynamic_detail_{} ".format(selectDate.replace('-',''))
#     sql += " 	GROUP BY VdStaticID, (UNIX_TIMESTAMP(DataCollectTime)-UNIX_TIMESTAMP(%(selectDate)s)) DIV 300 "
#     sql += " ) DYMC ON STAC.id = DYMC.VdStaticID "
#     sql += " ORDER BY STAC.RoadDirection, STAC.LocationMile, DYMC.DataCollectTime; "

#     df = pd.read_sql(sql, con=engine, params={'selectDate': selectDate})
#     engine.dispose()
#     return df.sort_values(by=['RoadDirection','DataCollectTime','LocationMile']).reset_index(drop=True)

# def groupVDs(df: pd.DataFrame, each: int) -> dict:
#     """ Get the dict of VD groups
#         ```text
#         ---
#         @Params
#         df: DataFrame which is referenced by.
#         each: The quantity of VDs would be considered as a group.

#         ---
#         @Returns
#         vdGroups: The keys are the VDs we focus on, and the values are the collections of VDs which are correlated corresponding to the keys.
#         ```
#     """
#     vdGroups = {}
#     lb = each // 2
#     ub = each - (each // 2)
#     for vdid in df['VDID'].unique():
#         vdGroups.setdefault(f"{vdid}", [])
#     for no, vdid in enumerate(df['VDID'].unique()):
#         startIdx = max(no-lb, 0)
#         endIdx = min(no+ub, len(df['VDID'].unique())-1)
#         vdGroups[f"{vdid}"] += list(df['VDID'].unique()[startIdx:no]) + list(df['VDID'].unique()[no:endIdx])

#     delList = []
#     for k in vdGroups.keys():
#         if (len(vdGroups[k]) != each):
#             delList.append(k)
#     for k in delList:
#         del vdGroups[k]
    
#     return vdGroups

# def genSamples(df: pd.DataFrame, vdGroups: dict, groupKey: str, each: int, timeWindow: int = 30) -> tuple:
#     """ Generate samples for each traffic data (speed, volume, and occupancy)
#         ```text
#         ---
#         @Params
#         df: 
#         vdGroups: The outpur of groupVDs(),
#         groupKey: The key of vdGroups,
#         each: The quantity of VDs would be considered as a group,
#         timeWindow: The length of period we consider, and the default value is 30 (minutes).

#         ---
#         @Returns
#         speeds: list with each item as a tuple, all of them are represented (X,y).
#         vols: list with each item as a tuple, all of them are represented (X,y).
#         occs: list with each item as a tuple, all of them are represented (X,y).
#         ```
#     """
#     speeds, vols, occs, lanes, tunnels = [], [], [], [], []
#     tmpDf = df.loc[(df['VDID'].isin(vdGroups[f"{groupKey}"]))].sort_values(by=['LocationMile', 'DataCollectTime'])

#     indices = [x for x in range(0, tmpDf.shape[0]+1, tmpDf.shape[0]//each)]
#     speedMatx = np.zeros((each, tmpDf.shape[0]//each))
#     volMatx = np.zeros((each, tmpDf.shape[0]//each))
#     occMatx = np.zeros((each, tmpDf.shape[0]//each))
#     laneMatx = np.zeros((each, tmpDf.shape[0]//each))
#     tunnelMatx = np.zeros((each, tmpDf.shape[0]//each))
#     for i, j, k in zip(range(each), indices[:-1], indices[1:]):
#         speedMatx[i] += tmpDf.iloc[j:k,:]['Speed'].to_numpy()
#         volMatx[i] += tmpDf.iloc[j:k,:]['Volume'].to_numpy()
#         occMatx[i] += tmpDf.iloc[j:k,:]['Occupancy'].to_numpy()
#         laneMatx[i] += tmpDf.iloc[j:k,:]['ActualLaneNum'].to_numpy()
#         tunnelMatx[i] += tmpDf.iloc[j:k,:]['isTunnel'].to_numpy()

#     sliceLen = int((timeWindow / 5) + 1)
#     for x in range(speedMatx.shape[1]//sliceLen*sliceLen-(sliceLen-1)):
#         speeds.append((speedMatx[:,x:x+sliceLen][:,:-1], speedMatx[:,x:x+sliceLen][:,[-1]]))
#         vols.append((volMatx[:,x:x+sliceLen][:,:-1], volMatx[:,x:x+sliceLen][:,[-1]]))
#         occs.append((occMatx[:,x:x+sliceLen][:,:-1], occMatx[:,x:x+sliceLen][:,[-1]]))
#         lanes.append((laneMatx[:,x:x+sliceLen][:,:-1], laneMatx[:,x:x+sliceLen][:,[-1]]))
#         tunnels.append((occMatx[:,x:x+sliceLen][:,:-1], tunnelMatx[:,x:x+sliceLen][:,[-1]]))
    
#     return speeds, vols, occs, lanes, tunnels

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
## main

# download_monthly_tables(start='2023-01-01', end='2023-12-31', dest_dir='./nfb2023', file_format='feather')
# Build sliding-window samples for every traffic quantity via the toolkits helpers.
speedCollection, volCollection, occCollection, laneCollection, tunnelCollection = collect_data()

# Split every collection into train/test parts (20% held out).
(
    trainSpeed, trainVol, trainOcc, trainNumLane, trainTunnel,
    testSpeed, testVol, testOcc, testNumLane, testTunnel,
) = train_test_split(
    speedCollection, volCollection, occCollection, laneCollection, tunnelCollection,
    test_size=0.2,
)

# Wrap each split in the toolkits CNNDataset, building it fresh (no checkpoint).
trainDataset = CNNDataset(
    speed_data=trainSpeed, volume_data=trainVol, occupy_data=trainOcc,
    lane_data=trainNumLane, tunnel_data=trainTunnel,
    load_ckpt=False, mode='train',
)
testDataset = CNNDataset(
    speed_data=testSpeed, volume_data=testVol, occupy_data=testOcc,
    lane_data=testNumLane, tunnel_data=testTunnel,
    load_ckpt=False, mode='test',
)

In [None]:
# Inspect the January 2023 monthly file.
pd.read_feather(path='./nfb2023/202301.feather')

In [4]:
# Shape of the first training sample's feature tensor.
trainDataset[0][0].size()

torch.Size([5, 3, 6])

In [None]:
# Fetch one year of data: one feather file per month, built from per-day DB queries.
# NOTE(review): getRollingMeanDaily only exists as commented-out code in the first
# cell; it must be restored (and DB_ENGINE configured) before this cell can run.
firstDate = list(pd.date_range('2023-01-01', '2023-12-31', freq='MS').strftime('%Y-%m-%d'))
lastDate = list(pd.date_range('2023-01-01', '2023-12-31', freq='ME').strftime('%Y-%m-%d'))
for first, last in zip(firstDate, lastDate):
    dailyFrames = []
    dateList = list(pd.date_range(first, last).strftime('%Y-%m-%d'))
    for date in dateList:
        print(date)
        dailyFrames.append(getRollingMeanDaily(date))
    dataframes = pd.concat(dailyFrames).reset_index(drop=True)
    # display(dataframes)
    # Name the file after the month being processed (`first`), not after the
    # inner-loop variable `date` that merely happens to leak out of the loop.
    feather.write_dataframe(dataframes, dest=f"./nfb2023/{first[:7].replace('-','')}.feather")

In [None]:
# Rebuild only February 2023 (one feather file per month in the range).
monthlyStarts = [stamp.strftime('%Y-%m-%d') for stamp in pd.date_range('2023-02-01', '2023-02-28', freq='MS')]
monthlyEnds = [stamp.strftime('%Y-%m-%d') for stamp in pd.date_range('2023-02-01', '2023-02-28', freq='ME')]

for start, end in zip(monthlyStarts, monthlyEnds):
    monthTag = start[:7].replace('-','')
    dateList = [stamp.strftime('%Y-%m-%d') for stamp in pd.date_range(start, end)]
    print(monthTag)
    dailyFrames = []
    for date in dateList:
        print(date)
        dailyFrames.append(getRollingMeanDaily(date))
    dataframes = pd.concat(dailyFrames).reset_index(drop=True)
    # display(dataframes)
    feather.write_dataframe(dataframes, dest=f"./nfb2023/{monthTag}.feather")

In [None]:
# Files currently present in the monthly data directory.
os.listdir(path='./nfb2023')

In [None]:
# Concatenate the monthly files in chronological order (names are YYYYMM.feather).
# Later months are restricted to the VDs present in the first month, so every
# month covers the same detectors.
monthlyFiles = sorted(name for name in os.listdir('./nfb2023') if name.endswith('.feather'))
monthlyDfs = [feather.read_dataframe(f"./nfb2023/{name}") for name in monthlyFiles]
if monthlyDfs:
    # Filtered months only ever contain these VDs, so the reference set never
    # changes; compute it once instead of re-concatenating on every iteration.
    referenceVDs = set(monthlyDfs[0]['VDID'])
    monthlyDfs = [monthlyDfs[0]] + [m.loc[m['VDID'].isin(referenceVDs)] for m in monthlyDfs[1:]]
df = pd.concat(monthlyDfs).reset_index(drop=True)
df

In [None]:
# Export the combined frame with a UTF-8 BOM so spreadsheet tools detect the encoding.
df.to_csv(path_or_buf='./df.csv', encoding='utf_8_sig', index=False)

In [None]:
# # TODO: main
# if __name__ == '__main__':
#     # # read feather files to get dataframes
#     # startDate = '2023-01-01'
#     # endDate = getEndDate(startDate, days=10)
#     # # df = getRollingMean(startDate, endDate)
#     # df = feather.read_dataframe('./20230101-20230110.feather').sort_values(by=['RoadDirection','DataCollectTime','LocationMile']).reset_index(drop=True)
    
#     # Northbound data
#     northDf = df.loc[df['RoadDirection']=='N'].reset_index(drop=True)
#     each = 3
#     vdGroups = groupVDs(northDf, each)    
#     speedDataset, volDataset, occDataset = [], [], []
#     for groupKey in vdGroups.keys():
#         speeds, vols, occs = genSamples(northDf, vdGroups, groupKey, each, timeWindow=30)
#         speedDataset.append(speeds)
#         volDataset.append(vols)
#         occDataset.append(occs)

#     # Southbound data
#     southDf = df.loc[df['RoadDirection']=='S'].reset_index(drop=True)
#     each = 3
#     vdGroups = groupVDs(southDf, each)    
#     speedDataset, volDataset, occDataset = [], [], []
#     for groupKey in vdGroups.keys():
#         speeds, vols, occs = genSamples(southDf, vdGroups, groupKey, each, timeWindow=30)
#         speedDataset.append(speeds)
#         volDataset.append(vols)
#         occDataset.append(occs)

In [None]:
## test cell for missing data
# This part is genSamples()
each = 3
timeWindow = 30

# df = feather.read_dataframe("./nfb2023/202305.feather")
northDf = df.loc[df['RoadDirection']=='N'].reset_index(drop=True)
vdGroups = groupVDs(northDf, each)
groupKey = 'VD-N5-N-1.068-M-LOOP'

speeds, vols, occs = [], [], []
# Select from northDf, the frame the groups were built from, not the full df.
tmpDf = northDf.loc[northDf['VDID'].isin(vdGroups[groupKey])].sort_values(by=['LocationMile', 'DataCollectTime'])

# The matrices below assume every VD contributes the same number of rows. If a VD
# is missing records, the equal-size slices straddle neighbouring VDs, so report it.
rowsPerVD = tmpDf['VDID'].value_counts()
if rowsPerVD.nunique() > 1:
    print(f"Unequal row counts per VD (missing data):\n{rowsPerVD}")

nCols = tmpDf.shape[0] // each
if nCols == 0:
    # range() below would get a zero step.
    raise ValueError(f"Group '{groupKey}' has fewer than {each} rows")
indices = list(range(0, tmpDf.shape[0]+1, nCols))
mileMatx = np.zeros((each, nCols))
speedMatx = np.zeros((each, nCols))
volMatx = np.zeros((each, nCols))
occMatx = np.zeros((each, nCols))
for i, j, k in zip(range(each), indices[:-1], indices[1:]):
    rows = tmpDf.iloc[j:k]
    mileMatx[i] += rows['LocationMile'].to_numpy()
    speedMatx[i] += rows['Speed'].to_numpy()
    volMatx[i] += rows['Volume'].to_numpy()
    occMatx[i] += rows['Occupancy'].to_numpy()

# sliceLen = int((timeWindow / 5) + 1)
# for x in range(speedMatx.shape[1]//sliceLen*sliceLen-(sliceLen-1)):
#     speeds.append((speedMatx[:,x:x+sliceLen][:,:-1], speedMatx[:,x:x+sliceLen][:,[-1]]))
#     vols.append((volMatx[:,x:x+sliceLen][:,:-1], volMatx[:,x:x+sliceLen][:,[-1]]))
#     occs.append((occMatx[:,x:x+sliceLen][:,:-1], occMatx[:,x:x+sliceLen][:,[-1]]))

## Prepare dataset

### Create class `CNNDataset`, a subclass of `torch.utils.data.Dataset`

In [None]:
# class CNNDataset(Dataset):
#     def __init__(
#             self,
#             speed_data: list = None,
#             volume_data: list = None,
#             load_ckpt: bool = None,
#             mode: str = None,
#             ckpt_dir: str = './datasets/cnndataset'
#     ) -> None:
#         if (speed_data):
#             self.speedFeature = [speed_data[x][0] for x in range(len(speed_data))]
#             self.volFeature = [volume_data[x][0] for x in range(len(volume_data))]
#             self.speedLabels = [speed_data[x][1][[1],:] for x in range(len(speed_data))]
#             self.volLabels = [volume_data[x][1][[1],:] for x in range(len(volume_data))]
        
#         else:
#             if (load_ckpt) and (mode == 'train'):
#                 with h5py.File(f"{ckpt_dir}/{mode}/{mode}_speed_feature.h5", 'r') as file:
#                     self.speedFeature = file[f"{mode}_speed_feature"][:]
#                 with h5py.File(f"{ckpt_dir}/{mode}/{mode}_volume_feature.h5", 'r') as file:
#                     self.volFeature = file[f"{mode}_volume_feature"][:]
#                 with h5py.File(f"{ckpt_dir}/{mode}/{mode}_speed_label.h5", 'r') as file:
#                     self.speedLabels = file[f"{mode}_speed_label"][:]
#                 with h5py.File(f"{ckpt_dir}/{mode}/{mode}_volume_label.h5", 'r') as file:
#                     self.volLabels = file[f"{mode}_volume_label"][:]
            
#             elif (load_ckpt) and (mode == 'test'):
#                 with h5py.File(f"{ckpt_dir}/{mode}/{mode}_speed_feature.h5", 'r') as file:
#                     self.speedFeature = file[f"{mode}_speed_feature"][:]
#                 with h5py.File(f"{ckpt_dir}/{mode}/{mode}_volume_feature.h5", 'r') as file:
#                     self.volFeature = file[f"{mode}_volume_feature"][:]
#                 with h5py.File(f"{ckpt_dir}/{mode}/{mode}_speed_label.h5", 'r') as file:
#                     self.speedLabels = file[f"{mode}_speed_label"][:]
#                 with h5py.File(f"{ckpt_dir}/{mode}/{mode}_volume_label.h5", 'r') as file:
#                     self.volLabels = file[f"{mode}_volume_label"][:]

#     def __len__(self) -> int:
#         return len(self.speedFeature)
    
#     def __getitem__(self, idx: int) -> torch.Tensor:
#         f1 = torch.tensor(self.speedFeature[idx], dtype=torch.float).unsqueeze(0)
#         f2 = torch.tensor(self.volFeature[idx], dtype=torch.float).unsqueeze(0)
#         l1 = torch.tensor(self.speedLabels[idx], dtype=torch.float).squeeze(0)
#         l2 = torch.tensor(self.volLabels[idx], dtype=torch.float).squeeze(0)
#         feature = torch.cat([f1, f2])
#         label = torch.cat([l1, l2])
#         return feature, label

### Create datasets

First, collect the training samples from the raw-data DataFrame `df`, processing each travel direction separately.

In [None]:
EACH = 3
speedCollection, volCollection, occCollection, laneCollection, tunnelCollection =\
      [], [], [], [], []


def _now_str() -> str:
    """Current local time as 'YYYY-mm-dd HH:MM:SS' for progress messages."""
    return datetime.strftime(datetime.now(), '%Y-%m-%d %H:%M:%S')


def _collect_direction(dirDf: pd.DataFrame, label: str) -> dict:
    """Group one direction's VDs and append their samples to the global collections.

    The frame and the VD groups handed to genSamples both come from `dirDf`,
    so keys and data always belong to the same direction.

    Args:
        dirDf: rows of a single RoadDirection.
        label: prefix used in the progress messages.

    Returns:
        The VD-group dict produced by groupVDs for this direction.
    """
    print(f"{label} start grouping: {_now_str()}")
    vdGroups = groupVDs(dirDf, each=EACH)
    print(f"{label} end grouping: {_now_str()}")
    for groupKey in vdGroups.keys():
        print(groupKey)
        speeds, vols, occs, lanes, tunnels = genSamples(dirDf, vdGroups, groupKey, each=EACH, timeWindow=30)
        speedCollection.extend(speeds)
        volCollection.extend(vols)
        occCollection.extend(occs)
        laneCollection.extend(lanes)
        tunnelCollection.extend(tunnels)
    return vdGroups


# Northbound data
northDf = df.loc[df['RoadDirection']=='N'].reset_index(drop=True)
northVDGrps = _collect_direction(northDf, 'northDf')

# Southbound data (previously sampled from northDf/northVDGrps with southbound keys)
southDf = df.loc[df['RoadDirection']=='S'].reset_index(drop=True)
southVDGrps = _collect_direction(southDf, 'southDf')

In [None]:
# Hold out 20% of every sample collection for testing.
(
    trainSpeed, trainVol, trainOcc, trainNumLane, trainTunnel,
    testSpeed, testVol, testOcc, testNumLane, testTunnel,
) = train_test_split(
    speedCollection, volCollection, occCollection, laneCollection, tunnelCollection,
    test_size=0.2,
)

In [None]:
# Sizes of the train/test speed and volume splits.
tuple(map(len, (trainSpeed, trainVol, testSpeed, testVol)))

In [None]:
def min_max_scaler(arr: np.ndarray, feature: str) -> np.ndarray:
    """Scale a raw traffic measurement into [0, 1], marking invalid values as -1.

    Values are clipped to the feature's upper bound (100 for 'speed' and 'occ',
    600 for 'volume') and divided by that bound; negative inputs (invalid
    readings) become -1.

    Args:
        arr: array (or scalar) of raw values.
        feature: one of 'speed', 'occ', 'volume'.

    Returns:
        Array of the same shape with scaled values.

    Raises:
        ValueError: if ``feature`` is not a supported name.
    """
    bounds = {'speed': 100, 'occ': 100, 'volume': 600}
    if feature not in bounds:
        raise ValueError(f"Unsupported feature '{feature}'; expected one of {sorted(bounds)}")
    upper = bounds[feature]
    clipped = np.where(arr >= upper, upper, arr)
    return np.where(clipped < 0, -1, clipped / upper)


class CNNDataset(Dataset):
    """Speed/volume windows of neighbouring VDs, with HDF5 checkpointing.

    Each item is ``(feature, label)``. ``feature`` stacks the scaled speed and
    volume windows as two channels. ``label`` concatenates the scaled speed and
    volume targets taken from row 1 of each sample's ``y`` (the middle VD when a
    group holds three VDs).

    Checkpoints are stored as ``{ckpt_dir}/{mode}/{mode}_<name>.h5``; each file
    holds one dataset named ``{mode}_<name>``.
    """

    # Checkpoint <name> -> attribute holding that data; always saved and loaded.
    _CORE_FIELDS = {
        'speed_feature': 'speedFeature',
        'volume_feature': 'volFeature',
        'speed_label': 'speedLabels',
        'volume_label': 'volLabels',
    }
    # Occupancy is not used by __getitem__. It is saved only when provided and
    # loaded only when its files exist (older checkpoints do not contain it).
    _OCC_FIELDS = {
        'occupancy_feature': 'occFeature',
        'occupancy_label': 'occLabels',
    }

    def __init__(
            self,
            speed_data: list = None,
            volume_data: list = None,
            occupy_data: list = None,
            load_ckpt: bool = None,
            mode: str = None,
            ckpt_dir: str = 'C:/Users/Home/PythonProjects/hwyTrafficPred/toolkits/cnndataset'
    ) -> None:
        """Build the dataset from (X, y) sample lists, or load it from a checkpoint.

        Args:
            speed_data, volume_data: index-aligned lists of ``(X, y)`` pairs;
                required unless ``load_ckpt`` is truthy.
            occupy_data: optional index-aligned occupancy ``(X, y)`` pairs; when
                None, occupancy is neither kept nor saved.
            load_ckpt: if truthy, read the samples from the HDF5 checkpoint.
            mode: subset name (e.g. 'train'/'test'); selects the checkpoint sub-directory.
            ckpt_dir: root directory of the checkpoints.

        Raises:
            ValueError: if building from data without speed or volume samples.
        """
        modeDir = f"{ckpt_dir}/{mode}"
        if (load_ckpt):
            self._load(modeDir, mode)
        else:
            if speed_data is None or volume_data is None:
                raise ValueError("speed_data and volume_data are required when load_ckpt is false")
            self._build(speed_data, volume_data, occupy_data)
            self._save(modeDir, mode, save_occupancy=occupy_data is not None)

    def _build(self, speed_data: list, volume_data: list, occupy_data: list = None) -> None:
        """Keep only samples whose speed and volume targets are valid (>= 0)."""
        self.speedFeature, self.volFeature, self.occFeature = [], [], []
        self.speedLabels, self.volLabels, self.occLabels = [], [], []
        for x in range(len(speed_data)):
            # Negative targets mark invalid readings (e.g. a -99 sentinel); drop those samples.
            if (speed_data[x][1][1][0] >= 0) and (volume_data[x][1][1][0] >= 0):
                self.speedFeature.append(speed_data[x][0])
                self.volFeature.append(volume_data[x][0])
                self.speedLabels.append(speed_data[x][1][[1],:])
                self.volLabels.append(volume_data[x][1][[1],:])
                if occupy_data is not None:
                    self.occFeature.append(occupy_data[x][0])
                    self.occLabels.append(occupy_data[x][1][[1],:])

    def _save(self, modeDir: str, mode: str, save_occupancy: bool) -> None:
        """Write each field to its own HDF5 file, creating the directory if needed."""
        os.makedirs(modeDir, exist_ok=True)
        fields = dict(self._CORE_FIELDS)
        if save_occupancy:
            fields.update(self._OCC_FIELDS)
        for name, attr in fields.items():
            with h5py.File(f"{modeDir}/{mode}_{name}.h5", 'w') as f:
                f.create_dataset(f"{mode}_{name}", data=getattr(self, attr))

    def _load(self, modeDir: str, mode: str) -> None:
        """Read all fields back; missing occupancy files yield empty lists."""
        for name, attr in {**self._CORE_FIELDS, **self._OCC_FIELDS}.items():
            path = f"{modeDir}/{mode}_{name}.h5"
            if name in self._OCC_FIELDS and not os.path.exists(path):
                setattr(self, attr, [])
                continue
            with h5py.File(path, 'r') as file:
                setattr(self, attr, file[f"{mode}_{name}"][:])

    def __len__(self) -> int:
        """Number of kept samples."""
        return len(self.speedFeature)

    def __getitem__(self, idx: int) -> tuple:
        """Return ``(feature, label)`` float tensors for sample ``idx``."""
        speedX = torch.tensor(min_max_scaler(self.speedFeature[idx], 'speed'), dtype=torch.float)
        volX = torch.tensor(min_max_scaler(self.volFeature[idx], 'volume'), dtype=torch.float)
        speedY = torch.tensor(min_max_scaler(self.speedLabels[idx], 'speed'), dtype=torch.float)
        volY = torch.tensor(min_max_scaler(self.volLabels[idx], 'volume'), dtype=torch.float)
        # Features gain a leading channel axis; labels drop their leading singleton row.
        feature = torch.cat([speedX.unsqueeze(0), volX.unsqueeze(0)])
        label = torch.cat([speedY.squeeze(0), volY.squeeze(0)])
        return feature, label

In [None]:
# Pass occupancy too: the dataset's build loop indexes occupy_data for every kept sample.
trainDataset = CNNDataset(speed_data=trainSpeed, volume_data=trainVol, occupy_data=trainOcc, load_ckpt=False, mode='train')
testDataset = CNNDataset(speed_data=testSpeed, volume_data=testVol, occupy_data=testOcc, load_ckpt=False, mode='test')

In [None]:
# Flat index of the largest training volume label.
np.argmax(trainDataset.volLabels)

In [None]:
# Row 1 of the first speed sample's target, kept two-dimensional.
speedCollection[0][1][1:2, :]

In [None]:
# 1x1 reference array holding 100.0.
np.full((1, 1), 100.)

In [None]:
# NOTE(review): 5006179 is presumably the index reported by the argmax cell above; confirm.
trainDataset.speedLabels[5_006_179]

In [None]:
# NOTE(review): same hand-copied sample index as the previous cell.
trainDataset.volLabels[5_006_179]

In [None]:
# Flat index of the smallest training speed label.
np.argmin(trainDataset.speedLabels)

In [None]:
# Speed and volume labels of one hand-picked sample (index presumably from the argmin cell above).
(trainDataset.speedLabels[457_543], trainDataset.volLabels[457_543])

### Directly load dataset

Alternatively, load the datasets from their `.h5` checkpoint files if they have been saved before.

In [None]:
# Restore both splits from their saved .h5 checkpoints (train first, then test).
trainDataset, testDataset = (
    CNNDataset(load_ckpt=True, mode=split, ckpt_dir='./toolkits/cnndataset')
    for split in ('train', 'test')
)

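### Create dataloaders

The dataloaders are created in the training-setup cell below, together with the hyperparameters.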

## Define Neural Network Architecture

In [None]:
class CNNRegression(nn.Module):
    """Small CNN regressor mapping a (C, H, W) window to ``n_outputs`` values.

    Two 2x2 convolutions (stride 1, no padding) each shrink the window by one
    row and one column, so a (H, W) input becomes (H-2, W-2) before flattening.
    With the defaults (2 channels, 3x6 input, 2 outputs) this is exactly the
    original architecture: 32 * 1 * 4 flattened features.

    Args:
        in_channels: number of input channels (stacked feature maps).
        input_hw: (height, width) of each input window; both must be >= 3.
        n_outputs: number of regression targets.

    Raises:
        ValueError: if ``input_hw`` is too small for the two convolutions.
    """

    def __init__(self, in_channels: int = 2, input_hw: tuple = (3, 6), n_outputs: int = 2) -> None:
        super().__init__()
        height, width = input_hw
        outH, outW = height - 2, width - 2
        if outH < 1 or outW < 1:
            raise ValueError(f"input_hw must be at least (3, 3), got {input_hw}")

        self.cnnLayer = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=(2,2), stride=1, padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            # MaxPool2d(1, 1, 0) is an identity op; kept so state_dict keys
            # (e.g. cnnLayer.4.weight) match previously saved checkpoints.
            nn.MaxPool2d(1, 1, 0),

            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(2,2), stride=1, padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(1, 1, 0),
        )

        self.fcLayer = nn.Sequential(
            nn.Linear(32 * outH * outW, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, n_outputs),
        )

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.to(self.device)

    def forward(self, x) -> torch.Tensor:
        """Map a (batch, in_channels, H, W) tensor to (batch, n_outputs)."""
        x = self.cnnLayer(x)
        x = x.flatten(1)
        x = self.fcLayer(x)
        return x

In [None]:
# Hyperparams for training
batch_size = 256
lr = 1e-3
n_epochs = 200

# Prepare datasets and dataloaders
trainDataset, testDataset = load_next_5min()
trainLoader = DataLoader(trainDataset, batch_size=batch_size, shuffle=True)
testLoader = DataLoader(testDataset, batch_size=batch_size, shuffle=False)

model = CNNRegression()
optimizer = optim.Adam(params=model.parameters(), lr=lr, weight_decay=1e-8)

In [None]:
for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_loss = []

    # leave=False: avoid stacking one finished progress bar per epoch.
    for X, y in tqdm(trainLoader, desc=f"train {epoch + 1}/{n_epochs}", leave=False):
        logits = model(X.to(model.device))
        loss = F.mse_loss(logits, y.to(model.device))

        # Compute gradients and update model params
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())

    # Mean of per-batch losses; NaN instead of ZeroDivisionError on an empty loader.
    train_loss = sum(train_loss) / len(train_loss) if train_loss else float('nan')
    print(f"[ Train | {epoch + 1:d}/{n_epochs:d} ] loss = {train_loss:.5f}")

    # ---------- Evaluation on the held-out split ----------
    model.eval()  # BatchNorm switches to running statistics
    valid_loss = []
    with torch.no_grad():
        for X, y in testLoader:
            logits = model(X.to(model.device))
            valid_loss.append(F.mse_loss(logits, y.to(model.device)).item())

    valid_loss = sum(valid_loss) / len(valid_loss) if valid_loss else float('nan')
    print(f"[ Valid | {epoch + 1:d}/{n_epochs:d} ] loss = {valid_loss:.5f}")