In [197]:
from sqlalchemy import create_engine
from datetime import datetime, timedelta
import pandas as pd
import numpy as np
import feather
import torch


# connect to db
user = 'root'
pswd = 'Curry5566'
host = '127.0.0.1'
port = '3306'
db = 'transport'
engine = create_engine(f"mysql+pymysql://{user}:{pswd}@{host}:{port}/{db}?charset=utf8")


def getEndDate(startDate: str, days: int) -> str:
    startDate += ' 00:00:00'
    end = str((datetime.strptime(startDate, '%Y-%m-%d %H:%M:%S') + timedelta(days=days)).replace(microsecond=0))
    return end

def getRollingMean(startDate: str, endDate: str) -> pd.DataFrame:
    """ Get the rolling mean from db
        ```text
        @Params
        startDate: The date for start, format='%Y-%m-%d'
        endDate: The date for end, format='%Y-%m-%d'

        @Returns
        DataFrame
        ```
    """
    sql  = " SELECT "
    sql += " 	STAC.VDID, STAC.RoadName, STAC.`Start`, STAC.`End`, "
    sql += " 	STAC.RoadDirection, DYMC.Speed, DYMC.Occupancy, DYMC.Volume, "
    sql += " 	STAC.LocationMile, DYMC.DataCollectTime "
    sql += " FROM ( "
    sql += " 	SELECT "
    sql += " 		VDSTC.id, VDSTC.VDID, ROAD.RoadName, SEC.`Start`, SEC.`End`, "
    sql += " 		VDSTC.RoadDirection, VDSTC.LocationMile "
    sql += " 	FROM vd_static_n5 VDSTC "
    sql += " 	JOIN road_info ROAD ON VDSTC.RoadInfoID = ROAD.id "
    sql += " 	JOIN section_info SEC ON ROAD.id = SEC.RoadInfoID "
    sql += " 	AND VDSTC.LocationMile >= SEC.StartKM "
    sql += " 	AND VDSTC.LocationMile <= SEC.EndKM "
    sql += " 	WHERE VDSTC.Mainlane = 1 "
    sql += " ) STAC JOIN ( "
    sql += " 	SELECT "
    sql += " 		VdStaticID, "
    sql += " 		CASE "
    sql += " 			WHEN MIN(Speed) = -99 THEN -99 "
    sql += " 			ELSE AVG(Speed) "
    sql += " 		END AS Speed,  "
    sql += " 		CASE "
    sql += " 			WHEN MIN(Occupancy) = -99 THEN -99 "
    sql += " 			ELSE AVG(Occupancy) "
    sql += " 		END AS Occupancy,  "
    sql += " 		CASE "
    sql += " 			WHEN MIN(Volume) = -99 THEN -99 "
    sql += " 			ELSE AVG(Volume) "
    sql += " 		END AS Volume, "
    sql += " 		MAX(DataCollectTime) AS DataCollectTime, "
    sql += " 		(UNIX_TIMESTAMP(DataCollectTime)-UNIX_TIMESTAMP(%(start)s)) DIV 300 "
    sql += " 	FROM vd_dynamic_detail_n5_202301 "
    sql += " 	WHERE id BETWEEN ( "
    sql += " 		SELECT id FROM vd_dynamic_detail_n5_202301 "
    sql += " 		WHERE DataCollectTime = %(start)s "
    sql += " 		ORDER BY id LIMIT 1 "
    sql += " 	) AND ( "
    sql += " 		SELECT id FROM vd_dynamic_detail_n5_202301 "
    sql += " 		WHERE DataCollectTime < %(end)s "
    sql += " 		ORDER BY id DESC LIMIT 1 "
    sql += " 	) "
    sql += " 	GROUP BY VdStaticID, (UNIX_TIMESTAMP(DataCollectTime)-UNIX_TIMESTAMP(%(start)s)) DIV 300 "
    sql += " ) DYMC ON STAC.id = DYMC.VdStaticID "
    sql += " ORDER BY STAC.RoadDirection, STAC.LocationMile, DYMC.DataCollectTime; "

    df = pd.read_sql(sql, con=engine, params={'start': startDate, 'end': endDate})
    engine.dispose()
    return df.sort_values(by=['RoadDirection','DataCollectTime','LocationMile']).reset_index(drop=True)

def groupVDs(df: pd.DataFrame, each: int) -> dict:
    """ Get the dict of VD groups
        ```text
        @Params
        df: DataFrame which is referenced by.
        each: The quantity of VDs would be considered as a group.

        @Returns
        vdGroups: The keys are the VDs we focus on, and the values are the collections of VDs which are correlated corresponding to the keys.
        ```
    """
    vdGroups = {}
    lb = each // 2
    ub = each - (each // 2)
    for vdid in df['VDID'].unique():
        vdGroups.setdefault(f"{vdid}", [])
    for no, vdid in enumerate(df['VDID'].unique()):
        startIdx = max(no-lb, 0)
        endIdx = min(no+ub, len(df['VDID'].unique())-1)
        vdGroups[f"{vdid}"] += list(df['VDID'].unique()[startIdx:no]) + list(df['VDID'].unique()[no:endIdx])

    delList = []
    for k in vdGroups.keys():
        if (len(vdGroups[k]) != each):
            delList.append(k)
    for k in delList:
        del vdGroups[k]
    
    return vdGroups

def genArrLists(df: pd.DataFrame, startDate: str, endDate: str, vdGroups: dict, groupKey: str,
                each: int, timeWindow: int = 30) -> tuple:
    """ Generate array lists for each traffic flow data (speed, volume, and occupancy)
        ```text
        @Params
        startDate: The date for start, format='%Y-%m-%d'
        endDate: The date for end, format='%Y-%m-%d'
        vdGroups: Can get it from groupVDs(),
        groupKey: The key of vdGroups,
        each: The quantity of VDs would be considered as a group,
        timeWindow: The length of period we consider, and the default value is 30 (minutes).

        @Returns
        speeds: list,
        vols: list,
        occs: list
        ```
    """
    freq5 = pd.date_range(startDate, endDate, freq='5min')
    speeds, vols, occs = [], [], []
    speed, vol, occ = [], [], []
    for dtStart, dtEnd in zip(freq5[:-1], freq5[1:]):
        tmpDf = df.loc[(df['VDID'].isin(vdGroups[f"{groupKey}"])) &\
                       (df['DataCollectTime']>dtStart) &\
                       (df['DataCollectTime']<dtEnd)].sort_values(by='LocationMile')
        if (len(speed) < timeWindow//5) and (len(vol) < timeWindow//5):
            if (tmpDf[['Speed']].shape[0]>0) and (tmpDf[['Volume']].shape[0]>0):
                speed.append(tmpDf[['Speed']].to_numpy())
                vol.append(tmpDf[['Volume']].to_numpy())
                occ.append(tmpDf[['Occupancy']].to_numpy())
            else:
                speed.append(np.array([[-99.] for _ in range(each)]))
                vol.append(np.array([[-99.] for _ in range(each)]))
                occ.append(np.array([[-99.] for _ in range(each)]))
        else:
            speeds.append(np.concatenate(speed, axis=1))
            vols.append(np.concatenate(vol, axis=1))
            occs.append(np.concatenate(occ, axis=1))
            
            speed.clear()
            vol.clear()
            occ.clear()
    
    return speeds, vols, occs

def genTensors(speeds: list, vols: list) -> list:
    """ Generate torch.Tensors """
    dataCollection = []
    for s, v in zip(speeds, vols):
        s = torch.tensor(s, dtype=torch.float).unsqueeze(0).unsqueeze(0)
        v = torch.tensor(v, dtype=torch.float).unsqueeze(0).unsqueeze(0)
        dataCollection.append(torch.concat([s, v], dim=1))
    return dataCollection

In [None]:
# TODO: main
if __name__ == '__main__':
    startDate = '2023-01-01'
    endDate = getEndDate(startDate, days=10)
    # df = getRollingMean(startDate, endDate)
    df = feather.read_dataframe('./20230101-20230110.feather').sort_values(by=['RoadDirection','DataCollectTime','LocationMile']).reset_index(drop=True)
    northDf = df.loc[df['RoadDirection']=='N'].reset_index(drop=True)
    
    each = 3
    vdGroups = groupVDs(northDf, each)
    # TODO: Declare lists for collecting speeds, vols, and occs
    for groupKey in vdGroups.keys():
        speeds, vols, occs = genArrLists(northDf, startDate, endDate, vdGroups, groupKey, each)
        tensors = genTensors(speeds, vols)
    pass

In [206]:
speeds[0]

array([[ 94.6, -99. ,  87.3,  89.1,  89.9,  75.3],
       [ 90.8, -99. ,  86.5,  87.4,  85. ,  75. ],
       [-99. ,  84.1,  90.5,  88.3,  75.8,  86.6]])

In [202]:
vols[0]

array([[  5.3, -99. ,   4.1,   3.5,   3.7,   3.2],
       [  5.1, -99. ,   2.7,   4.2,   3.7,   3.2],
       [-99. ,   3.2,   4.4,   4. ,   2.2,   4.3]])

In [203]:
occs[0]

array([[  4.1, -99. ,   3.2,   2.5,   2.7,   2.5],
       [  4.8, -99. ,   2.1,   3.2,   2.7,   2.3],
       [-99. ,   2.4,   3.2,   2.7,   1.6,   3.1]])

In [201]:
vdGroups['VD-N5-N-3.198-M-LOOP']

['VD-N5-N-2.068-M-PS-LOOP', 'VD-N5-N-3.198-M-LOOP', 'VD-N5-N-3.943-M-LOOP']

In [200]:
df = feather.read_dataframe('./20230101-20230110.feather').sort_values(by=['RoadDirection','DataCollectTime','LocationMile']).reset_index(drop=True)
df.head(50)

Unnamed: 0,VDID,RoadName,Start,End,RoadDirection,Speed,Occupancy,Volume,LocationMile,DataCollectTime
0,VD-N5-N-0.178-M-LOOP,國道5號,南港系統交流道,石碇交流道,N,-99.0,-99.0,-99.0,0.178,2023-01-01 00:04:00
1,VD-N5-N-0.706-M-LOOP,國道5號,南港系統交流道,石碇交流道,N,-99.0,-99.0,-99.0,0.706,2023-01-01 00:04:00
2,VD-N5-N-1.068-M-LOOP,國道5號,南港系統交流道,石碇交流道,N,-99.0,-99.0,-99.0,1.068,2023-01-01 00:04:00
3,VD-N5-N-2.068-M-PS-LOOP,國道5號,南港系統交流道,石碇交流道,N,94.6,4.1,5.3,2.068,2023-01-01 00:04:00
4,VD-N5-N-3.198-M-LOOP,國道5號,南港系統交流道,石碇交流道,N,90.8,4.8,5.1,3.198,2023-01-01 00:04:00
5,VD-N5-N-3.943-M-LOOP,國道5號,南港系統交流道,石碇交流道,N,-99.0,-99.0,-99.0,3.943,2023-01-01 00:04:00
6,VD-N5-N-5.883-M-LOOP,國道5號,石碇交流道,坪林交控交流道,N,-99.0,-99.0,-99.0,5.883,2023-01-01 00:04:00
7,VD-N5-N-7.107-M-LOOP,國道5號,石碇交流道,坪林交控交流道,N,-99.0,-99.0,-99.0,7.107,2023-01-01 00:04:00
8,VD-N5-N-8.011-M-LOOP,國道5號,石碇交流道,坪林交控交流道,N,-99.0,-99.0,-99.0,8.011,2023-01-01 00:04:00
9,VD-N5-N-9.840-M-LOOP,國道5號,石碇交流道,坪林交控交流道,N,86.1,2.7,3.8,9.84,2023-01-01 00:04:00


In [117]:
# This cell is actually same as df.sort_values()

# groupDf = df.groupby(['RoadDirection','DataCollectTime','LocationMile']).agg({
#     'VDID': 'max',
#     'RoadName': 'max',
#     'Start': 'max',
#     'End': 'max',
#     'Speed': 'max',
#     'Occupancy': 'max',
#     'Volume': 'max',
# }).reset_index().sort_values(by=['RoadDirection','DataCollectTime','LocationMile'])
# groupDf

In [118]:
groupDf = df.sort_values(by=['RoadDirection','DataCollectTime','LocationMile']).reset_index(drop=True)
groupDf

Unnamed: 0,VDID,RoadName,Start,End,RoadDirection,Speed,Occupancy,Volume,LocationMile,DataCollectTime
0,VD-N5-N-0.178-M-LOOP,國道5號,南港系統交流道,石碇交流道,N,-99.000,-99.000,-99.000,0.178,2023-01-01 00:04:00
1,VD-N5-N-0.706-M-LOOP,國道5號,南港系統交流道,石碇交流道,N,-99.000,-99.000,-99.000,0.706,2023-01-01 00:04:00
2,VD-N5-N-1.068-M-LOOP,國道5號,南港系統交流道,石碇交流道,N,-99.000,-99.000,-99.000,1.068,2023-01-01 00:04:00
3,VD-N5-N-2.068-M-PS-LOOP,國道5號,南港系統交流道,石碇交流道,N,94.600,4.100,5.300,2.068,2023-01-01 00:04:00
4,VD-N5-N-3.198-M-LOOP,國道5號,南港系統交流道,石碇交流道,N,90.800,4.800,5.100,3.198,2023-01-01 00:04:00
...,...,...,...,...,...,...,...,...,...,...
269681,VD-N5-S-41.298-M-LOOP,國道5號,宜蘭交流道,羅東交流道,S,76.875,1.250,1.500,41.298,2023-01-10 23:59:00
269682,VD-N5-S-44.202-M-LOOP,國道5號,宜蘭交流道,羅東交流道,S,90.875,2.250,3.250,44.202,2023-01-10 23:59:00
269683,VD-N5-S-46.566-M-LOOP,國道5號,宜蘭交流道,羅東交流道,S,35.750,0.375,0.500,46.566,2023-01-10 23:59:00
269684,VD-N5-S-48.040-M-LOOP,國道5號,羅東交流道,蘇澳交流道,S,57.750,0.625,0.625,48.040,2023-01-10 23:59:00


In [153]:
vdGroups = {}
for vdid in df['VDID'].unique():
    vdGroups.setdefault(f"{vdid}", [])
for no, vdid in enumerate(df['VDID'].unique()):
    startIdx = max(no-2, 0)
    endIdx = min(no+3, len(df['VDID'].unique())-1)
    vdGroups[f"{vdid}"] += list(df['VDID'].unique()[startIdx:no]) + list(df['VDID'].unique()[no:endIdx])
vdGroups

{'VD-N5-N-0.178-M-LOOP': ['VD-N5-N-0.178-M-LOOP',
  'VD-N5-N-0.706-M-LOOP',
  'VD-N5-N-1.068-M-LOOP'],
 'VD-N5-N-0.706-M-LOOP': ['VD-N5-N-0.178-M-LOOP',
  'VD-N5-N-0.706-M-LOOP',
  'VD-N5-N-1.068-M-LOOP',
  'VD-N5-N-2.068-M-PS-LOOP'],
 'VD-N5-N-1.068-M-LOOP': ['VD-N5-N-0.178-M-LOOP',
  'VD-N5-N-0.706-M-LOOP',
  'VD-N5-N-1.068-M-LOOP',
  'VD-N5-N-2.068-M-PS-LOOP',
  'VD-N5-N-3.198-M-LOOP'],
 'VD-N5-N-2.068-M-PS-LOOP': ['VD-N5-N-0.706-M-LOOP',
  'VD-N5-N-1.068-M-LOOP',
  'VD-N5-N-2.068-M-PS-LOOP',
  'VD-N5-N-3.198-M-LOOP',
  'VD-N5-N-3.943-M-LOOP'],
 'VD-N5-N-3.198-M-LOOP': ['VD-N5-N-1.068-M-LOOP',
  'VD-N5-N-2.068-M-PS-LOOP',
  'VD-N5-N-3.198-M-LOOP',
  'VD-N5-N-3.943-M-LOOP',
  'VD-N5-N-5.883-M-LOOP'],
 'VD-N5-N-3.943-M-LOOP': ['VD-N5-N-2.068-M-PS-LOOP',
  'VD-N5-N-3.198-M-LOOP',
  'VD-N5-N-3.943-M-LOOP',
  'VD-N5-N-5.883-M-LOOP',
  'VD-N5-N-7.107-M-LOOP'],
 'VD-N5-N-5.883-M-LOOP': ['VD-N5-N-3.198-M-LOOP',
  'VD-N5-N-3.943-M-LOOP',
  'VD-N5-N-5.883-M-LOOP',
  'VD-N5-N-7.107-M-LOOP',
  

In [114]:
northDf = groupDf.loc[groupDf['RoadDirection']=='N'].reset_index(drop=True)
northDf

Unnamed: 0,RoadDirection,DataCollectTime,LocationMile,VDID,RoadName,Start,End,Speed,Occupancy,Volume
0,N,2023-01-01 00:04:00,0.178,VD-N5-N-0.178-M-LOOP,國道5號,南港系統交流道,石碇交流道,-99.000,-99.000,-99.000
1,N,2023-01-01 00:04:00,0.706,VD-N5-N-0.706-M-LOOP,國道5號,南港系統交流道,石碇交流道,-99.000,-99.000,-99.000
2,N,2023-01-01 00:04:00,1.068,VD-N5-N-1.068-M-LOOP,國道5號,南港系統交流道,石碇交流道,-99.000,-99.000,-99.000
3,N,2023-01-01 00:04:00,2.068,VD-N5-N-2.068-M-PS-LOOP,國道5號,南港系統交流道,石碇交流道,94.600,4.100,5.300
4,N,2023-01-01 00:04:00,3.198,VD-N5-N-3.198-M-LOOP,國道5號,南港系統交流道,石碇交流道,90.800,4.800,5.100
...,...,...,...,...,...,...,...,...,...,...
140576,N,2023-01-10 23:59:00,37.225,VD-N5-N-37.225-M-LOOP,國道5號,頭城交流道,宜蘭交流道,74.500,1.125,1.750
140577,N,2023-01-10 23:59:00,42.359,VD-N5-N-42.359-M-LOOP,國道5號,宜蘭交流道,羅東交流道,94.125,1.875,2.375
140578,N,2023-01-10 23:59:00,45.230,VD-N5-N-45.230-M-LOOP,國道5號,宜蘭交流道,羅東交流道,63.250,1.000,1.250
140579,N,2023-01-10 23:59:00,49.070,VD-N5-N-49.070-M-LOOP,國道5號,羅東交流道,蘇澳交流道,56.500,0.875,1.000


In [162]:
vdGroups = groupVDs(northDf, each=5)
vdGroups

{'VD-N5-N-1.068-M-LOOP': ['VD-N5-N-0.178-M-LOOP',
  'VD-N5-N-0.706-M-LOOP',
  'VD-N5-N-1.068-M-LOOP',
  'VD-N5-N-2.068-M-PS-LOOP',
  'VD-N5-N-3.198-M-LOOP'],
 'VD-N5-N-2.068-M-PS-LOOP': ['VD-N5-N-0.706-M-LOOP',
  'VD-N5-N-1.068-M-LOOP',
  'VD-N5-N-2.068-M-PS-LOOP',
  'VD-N5-N-3.198-M-LOOP',
  'VD-N5-N-3.943-M-LOOP'],
 'VD-N5-N-3.198-M-LOOP': ['VD-N5-N-1.068-M-LOOP',
  'VD-N5-N-2.068-M-PS-LOOP',
  'VD-N5-N-3.198-M-LOOP',
  'VD-N5-N-3.943-M-LOOP',
  'VD-N5-N-5.883-M-LOOP'],
 'VD-N5-N-3.943-M-LOOP': ['VD-N5-N-2.068-M-PS-LOOP',
  'VD-N5-N-3.198-M-LOOP',
  'VD-N5-N-3.943-M-LOOP',
  'VD-N5-N-5.883-M-LOOP',
  'VD-N5-N-7.107-M-LOOP'],
 'VD-N5-N-5.883-M-LOOP': ['VD-N5-N-3.198-M-LOOP',
  'VD-N5-N-3.943-M-LOOP',
  'VD-N5-N-5.883-M-LOOP',
  'VD-N5-N-7.107-M-LOOP',
  'VD-N5-N-8.011-M-LOOP'],
 'VD-N5-N-7.107-M-LOOP': ['VD-N5-N-3.943-M-LOOP',
  'VD-N5-N-5.883-M-LOOP',
  'VD-N5-N-7.107-M-LOOP',
  'VD-N5-N-8.011-M-LOOP',
  'VD-N5-N-9.840-M-LOOP'],
 'VD-N5-N-8.011-M-LOOP': ['VD-N5-N-5.883-M-LOOP',
  'VD

In [172]:
groupVDs(northDf, each=3)

{'VD-N5-N-0.706-M-LOOP': ['VD-N5-N-0.178-M-LOOP',
  'VD-N5-N-0.706-M-LOOP',
  'VD-N5-N-1.068-M-LOOP'],
 'VD-N5-N-1.068-M-LOOP': ['VD-N5-N-0.706-M-LOOP',
  'VD-N5-N-1.068-M-LOOP',
  'VD-N5-N-2.068-M-PS-LOOP'],
 'VD-N5-N-2.068-M-PS-LOOP': ['VD-N5-N-1.068-M-LOOP',
  'VD-N5-N-2.068-M-PS-LOOP',
  'VD-N5-N-3.198-M-LOOP'],
 'VD-N5-N-3.198-M-LOOP': ['VD-N5-N-2.068-M-PS-LOOP',
  'VD-N5-N-3.198-M-LOOP',
  'VD-N5-N-3.943-M-LOOP'],
 'VD-N5-N-3.943-M-LOOP': ['VD-N5-N-3.198-M-LOOP',
  'VD-N5-N-3.943-M-LOOP',
  'VD-N5-N-5.883-M-LOOP'],
 'VD-N5-N-5.883-M-LOOP': ['VD-N5-N-3.943-M-LOOP',
  'VD-N5-N-5.883-M-LOOP',
  'VD-N5-N-7.107-M-LOOP'],
 'VD-N5-N-7.107-M-LOOP': ['VD-N5-N-5.883-M-LOOP',
  'VD-N5-N-7.107-M-LOOP',
  'VD-N5-N-8.011-M-LOOP'],
 'VD-N5-N-8.011-M-LOOP': ['VD-N5-N-7.107-M-LOOP',
  'VD-N5-N-8.011-M-LOOP',
  'VD-N5-N-9.840-M-LOOP'],
 'VD-N5-N-9.840-M-LOOP': ['VD-N5-N-8.011-M-LOOP',
  'VD-N5-N-9.840-M-LOOP',
  'VD-N5-N-10.866-M-PS-LOOP'],
 'VD-N5-N-10.866-M-PS-LOOP': ['VD-N5-N-9.840-M-LOOP',
  '

In [180]:
start = '2023-01-01'
end = '2023-01-11'
each = 3
vdGroups = groupVDs(northDf, each)
groupKey = 'VD-N5-N-3.198-M-LOOP'

speeds, vols, occs = genArrLists(northDf, start, end, vdGroups, groupKey, each)

# freq5 = pd.date_range(start, end, freq='5min')
# speeds, vols, occs = [], [], []
# speed, vol, occ = [], [], []
# for dtStart, dtEnd in zip(freq5[:-1], freq5[1:]):
#     tmpDf = df.loc[(df['VDID'].isin(vdGroups[f"{groupKey}"])) &\
#                    (df['DataCollectTime']>dtStart) &\
#                    (df['DataCollectTime']<dtEnd)].sort_values(by='LocationMile')
#     if (len(speed) < 6) and (len(vol) < 6):
#         if (tmpDf[['Speed']].shape[0]>0) and (tmpDf[['Volume']].shape[0]>0):
#             speed.append(tmpDf[['Speed']].to_numpy())
#             vol.append(tmpDf[['Volume']].to_numpy())
#             occ.append(tmpDf[['Occupancy']].to_numpy())
#         else:
#             speed.append(np.array([[-99.] for _ in range(each)]))
#             vol.append(np.array([[-99.] for _ in range(each)]))
#             occ.append(np.array([[-99.] for _ in range(each)]))
#     else:
#         speeds.append(np.concatenate(speed, axis=1))
#         vols.append(np.concatenate(vol, axis=1))
#         occs.append(np.concatenate(occ, axis=1))
        
#         speed.clear()
#         vol.clear()
#         occ.clear()

In [188]:
speeds

[array([[ 94.6, -99. ,  87.3,  89.1,  89.9,  75.3],
        [ 90.8, -99. ,  86.5,  87.4,  85. ,  75. ],
        [-99. ,  84.1,  90.5,  88.3,  75.8,  86.6]]),
 array([[ 87.4,  79.3,  85.3,  88.3,  89.4,  87. ],
        [ 86.4,  88.1,  83.1,  87. ,  96.9,  88.7],
        [ 90.5,  74.5, -99. ,  90.1,  90.6,  91.8]]),
 array([[86.4, 89.4, 92. , 91.4, 90. , 90.6],
        [87.9, 85.7, 89.5, 91.9, 75.9, 90.1],
        [69.9, 88.7, 88. , 80.9, 86.2, 93.1]]),
 array([[88.2, 87.4, 88.8, 92.3, 81.8, 89.3],
        [86.6, 86.8, 93.9, 91. , 91. , 86.8],
        [89. , 89.9, 93.6, 87.5, 88.6, 86.2]]),
 array([[90.6, 91.9, 83. , 87.8, 78.3, 81.1],
        [90.4, 90.6, 83.6, 86.1, 78.3, 71.3],
        [93. , 89.8, 82.9, 83. , 78.5, 80.4]]),
 array([[97. , 94.5, 79. , 91.7, 82.5, 91.1],
        [95.6, 94.6, 90. , 89.5, 94.1, 90.1],
        [97.5, 98.8, 79.9, 90.2, 95.2, 88.9]]),
 array([[ 78.2,  82. ,  91.5,  83. , -99. ,  91.4],
        [ 89.9,  90.3,  81.7,  88.2,  77.2,  87.9],
        [ 83.7,  95.

In [189]:
tensorList = genTensors(speeds, vols)
tensorList[0].shape

torch.Size([1, 2, 3, 6])