# T-GCN GSL

**Graph Structure Learning for Traffic Prediction**

This repository provides the implementation of the paper "Graph Structure Learning for Traffic Prediction" by Mahmood Amintoosi.

The adjacency matrix is estimated with graph structure learning (GSL). The estimated matrices for each dataset and prediction length are precomputed and saved in the `data` folder (for example, `data/W_est_shenzhen_pre_len1.npy`). If an estimated matrix is deleted, the proposed method recomputes it automatically before prediction.
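
The load-or-recompute behaviour described above can be sketched as follows. This is a minimal illustration, not the repository's code: `estimate_adjacency` is a hypothetical placeholder (absolute Pearson correlation) standing in for the actual GSL estimator, `load_or_estimate` is a hypothetical helper name, and only the file naming pattern `W_est_<dataset>_pre_len<N>.npy` is taken from the logs in this notebook.

```python
import os
import tempfile

import numpy as np


def estimate_adjacency(train_data: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for the GSL estimator (NOT the paper's method).

    Uses the absolute Pearson correlation between sensors as a placeholder.
    """
    w = np.abs(np.corrcoef(train_data, rowvar=False))
    np.fill_diagonal(w, 0.0)  # no self-loops
    return w


def load_or_estimate(train_data, dataset, pre_len, data_dir="data"):
    """Load the cached matrix if it exists, otherwise estimate and cache it.

    Returns the matrix and a flag telling whether it came from the cache.
    """
    path = os.path.join(data_dir, f"W_est_{dataset}_pre_len{pre_len}.npy")
    if os.path.exists(path):
        return np.load(path), True
    w_est = estimate_adjacency(train_data)
    os.makedirs(data_dir, exist_ok=True)
    np.save(path, w_est)
    return w_est, False


# Synthetic demo: 200 time steps, 5 sensors, written to a temporary folder
rng = np.random.default_rng(0)
demo_data = rng.normal(size=(200, 5))
demo_dir = tempfile.mkdtemp()
w_first, cached_first = load_or_estimate(demo_data, "demo", 1, data_dir=demo_dir)
w_second, cached_second = load_or_estimate(demo_data, "demo", 1, data_dir=demo_dir)
print(w_first.shape, cached_first, cached_second)
```

The first call estimates and saves the matrix; the second call finds the file and loads it, which mirrors the "File ... found. Loading existing adjacency matrix" messages in the output below.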

In [1]:
import os

# On Google Colab, clone the repository and install dependencies on the first run only
if 'google.colab' in str(get_ipython()) and not os.path.exists('/content/TGCN-PyTorch'):
    !git clone https://github.com/mamintoosi-papers-codes/TGCN-PyTorch.git
    !pip install -q torchmetrics
    %cd TGCN-PyTorch

In [2]:
%%time

datasets = ['sz', 'los']  # sz = shenzhen, los = losloop
pred_list = [1, 2, 3, 4]

for dataset in datasets:
    for pre_len in pred_list:
        # Baseline T-GCN (use_gsl: 0)
        %run main.py --config configs/tgcn-{dataset}-pre_len{pre_len}.yaml
        # T-GCN with the GSL-estimated adjacency matrix only ("Only GSL")
        %run main.py --config configs/tgcn-{dataset}-gsl-pre_len{pre_len}.yaml
        # T-GCN with GSL plus the adjacency matrix ("GSL+Adj")
        %run main.py --config configs/tgcn-{dataset}-gsl-adj-pre_len{pre_len}.yaml

[2025-03-06 14:20:29,993 INFO] Loaded config from configs/tgcn-sz-pre_len1.yaml: {'fit': {'trainer': {'max_epochs': 50, 'accelerator': 'cuda', 'devices': 1}, 'data': {'dataset_name': 'shenzhen', 'batch_size': 64, 'seq_len': 12, 'pre_len': 1}, 'model': {'model': {'class_path': 'models.TGCN', 'init_args': {'hidden_dim': 100, 'use_gsl': 0}}, 'learning_rate': 0.001, 'weight_decay': 0, 'loss': 'mse_with_regularizer'}}}
[2025-03-06 14:20:30,229 INFO] Using device: cuda


TGCN


[2025-03-06 14:20:32,228 INFO] Starting training for 50 epochs...
[2025-03-06 14:20:33,939 INFO] [Epoch 1/50] Train Loss: 109.665068, Val Loss: 1853920.250000, RMSE: 6.385045, MAE: 5.018596, Accuracy: 0.5551, R2: 0.6289, Expl.Var: 0.6332
[2025-03-06 14:20:47,919 INFO] [Epoch 11/50] Train Loss: 36.290581, Val Loss: 1486394.125000, RMSE: 5.717226, MAE: 4.225228, Accuracy: 0.6016, R2: 0.7005, Expl.Var: 0.7007
[2025-03-06 14:21:02,033 INFO] [Epoch 21/50] Train Loss: 35.463298, Val Loss: 1449891.750000, RMSE: 5.646589, MAE: 4.167087, Accuracy: 0.6066, R2: 0.7079, Expl.Var: 0.7082
[2025-03-06 14:21:17,958 INFO] [Epoch 31/50] Train Loss: 34.118202, Val Loss: 1378880.250000, RMSE: 5.506576, MAE: 4.082464, Accuracy: 0.6163, R2: 0.7222, Expl.Var: 0.7226
[2025-03-06 14:21:33,602 INFO] [Epoch 41/50] Train Loss: 31.722393, Val Loss: 1255893.625000, RMSE: 5.255267, MAE: 3.927185, Accuracy: 0.6338, R2: 0.7471, Expl.Var: 0.7476
[2025-03-06 14:21:47,

File data/W_est_shenzhen_pre_len1.npy found. Loading existing adjacency matrix estimated by GSL from training data.
GSL computed: Only GSL


[2025-03-06 14:21:49,776 INFO] [Epoch 1/50] Train Loss: 90.770370, Val Loss: 838165.687500, RMSE: 4.293221, MAE: 3.054961, Accuracy: 0.7009, R2: 0.8319, Expl.Var: 0.8361
[2025-03-06 14:22:05,490 INFO] [Epoch 11/50] Train Loss: 22.155054, Val Loss: 779778.812500, RMSE: 4.140989, MAE: 2.733996, Accuracy: 0.7115, R2: 0.8429, Expl.Var: 0.8429
[2025-03-06 14:22:21,162 INFO] [Epoch 21/50] Train Loss: 21.780855, Val Loss: 770396.000000, RMSE: 4.116000, MAE: 2.724192, Accuracy: 0.7132, R2: 0.8447, Expl.Var: 0.8447
[2025-03-06 14:22:36,603 INFO] [Epoch 31/50] Train Loss: 21.686204, Val Loss: 769832.250000, RMSE: 4.114494, MAE: 2.729474, Accuracy: 0.7133, R2: 0.8448, Expl.Var: 0.8449
[2025-03-06 14:22:52,245 INFO] [Epoch 41/50] Train Loss: 21.667007, Val Loss: 776797.187500, RMSE: 4.133065, MAE: 2.809196, Accuracy: 0.7120, R2: 0.8437, Expl.Var: 0.8449
[2025-03-06 14:23:06,172 INFO] [Epoch 50/50] Train Loss: 21.632922, Val Loss: 769739.312500, RMSE:

File data/W_est_shenzhen_pre_len1.npy found. Loading existing adjacency matrix estimated by GSL from training data.
GSL computed: GSL+Adj


[2025-03-06 14:23:08,127 INFO] [Epoch 1/50] Train Loss: 109.665068, Val Loss: 1853920.250000, RMSE: 6.385045, MAE: 5.018596, Accuracy: 0.5551, R2: 0.6289, Expl.Var: 0.6332
[2025-03-06 14:23:23,665 INFO] [Epoch 11/50] Train Loss: 36.290581, Val Loss: 1486394.125000, RMSE: 5.717226, MAE: 4.225228, Accuracy: 0.6016, R2: 0.7005, Expl.Var: 0.7007
[2025-03-06 14:23:39,103 INFO] [Epoch 21/50] Train Loss: 35.463298, Val Loss: 1449891.750000, RMSE: 5.646589, MAE: 4.167087, Accuracy: 0.6066, R2: 0.7079, Expl.Var: 0.7082
[2025-03-06 14:23:54,756 INFO] [Epoch 31/50] Train Loss: 34.118202, Val Loss: 1378880.250000, RMSE: 5.506576, MAE: 4.082464, Accuracy: 0.6163, R2: 0.7222, Expl.Var: 0.7226
[2025-03-06 14:24:10,327 INFO] [Epoch 41/50] Train Loss: 31.722393, Val Loss: 1255893.625000, RMSE: 5.255267, MAE: 3.927185, Accuracy: 0.6338, R2: 0.7471, Expl.Var: 0.7476
[2025-03-06 14:24:23,919 INFO] [Epoch 50/50] Train Loss: 28.078746, Val Loss: 1076578.625000

TGCN


[2025-03-06 14:24:25,573 INFO] [Epoch 1/50] Train Loss: 155.166981, Val Loss: 3551141.500000, RMSE: 6.254033, MAE: 4.811000, Accuracy: 0.5642, R2: 0.6447, Expl.Var: 0.6504
[2025-03-06 14:24:39,704 INFO] [Epoch 11/50] Train Loss: 72.975424, Val Loss: 2988711.500000, RMSE: 5.737441, MAE: 4.250663, Accuracy: 0.6002, R2: 0.6983, Expl.Var: 0.6984
[2025-03-06 14:24:53,299 INFO] [Epoch 21/50] Train Loss: 70.632839, Val Loss: 2871645.500000, RMSE: 5.623952, MAE: 4.179071, Accuracy: 0.6081, R2: 0.7102, Expl.Var: 0.7105
[2025-03-06 14:25:06,736 INFO] [Epoch 31/50] Train Loss: 64.776980, Val Loss: 2567981.750000, RMSE: 5.318291, MAE: 4.001699, Accuracy: 0.6294, R2: 0.7409, Expl.Var: 0.7412
[2025-03-06 14:25:20,157 INFO] [Epoch 41/50] Train Loss: 53.897459, Val Loss: 2046158.000000, RMSE: 4.747289, MAE: 3.493054, Accuracy: 0.6692, R2: 0.7936, Expl.Var: 0.7939
[2025-03-06 14:25:32,353 INFO] [Epoch 50/50] Train Loss: 49.402700, Val Loss: 1843754.250000

File data/W_est_shenzhen_pre_len2.npy found. Loading existing adjacency matrix estimated by GSL from training data.
GSL computed: Only GSL


[2025-03-06 14:25:33,870 INFO] [Epoch 1/50] Train Loss: 119.366915, Val Loss: 1802368.750000, RMSE: 4.455516, MAE: 2.962124, Accuracy: 0.6895, R2: 0.8191, Expl.Var: 0.8236
[2025-03-06 14:25:47,528 INFO] [Epoch 11/50] Train Loss: 44.558536, Val Loss: 1580794.625000, RMSE: 4.172668, MAE: 2.809899, Accuracy: 0.7092, R2: 0.8405, Expl.Var: 0.8408
[2025-03-06 14:26:01,153 INFO] [Epoch 21/50] Train Loss: 44.104408, Val Loss: 1575103.500000, RMSE: 4.165151, MAE: 2.822248, Accuracy: 0.7098, R2: 0.8411, Expl.Var: 0.8419
[2025-03-06 14:26:14,953 INFO] [Epoch 31/50] Train Loss: 44.015043, Val Loss: 1569953.000000, RMSE: 4.158335, MAE: 2.765785, Accuracy: 0.7102, R2: 0.8415, Expl.Var: 0.8416
[2025-03-06 14:26:28,453 INFO] [Epoch 41/50] Train Loss: 43.997246, Val Loss: 1577810.375000, RMSE: 4.168728, MAE: 2.807904, Accuracy: 0.7095, R2: 0.8409, Expl.Var: 0.8417
[2025-03-06 14:26:40,622 INFO] [Epoch 50/50] Train Loss: 43.974694, Val Loss: 1572582.750000

File data/W_est_shenzhen_pre_len2.npy found. Loading existing adjacency matrix estimated by GSL from training data.
GSL computed: GSL+Adj


[2025-03-06 14:26:42,204 INFO] [Epoch 1/50] Train Loss: 155.166981, Val Loss: 3551141.500000, RMSE: 6.254033, MAE: 4.811000, Accuracy: 0.5642, R2: 0.6447, Expl.Var: 0.6504
[2025-03-06 14:26:55,840 INFO] [Epoch 11/50] Train Loss: 72.975424, Val Loss: 2988711.500000, RMSE: 5.737441, MAE: 4.250663, Accuracy: 0.6002, R2: 0.6983, Expl.Var: 0.6984
[2025-03-06 14:27:09,293 INFO] [Epoch 21/50] Train Loss: 70.632839, Val Loss: 2871645.500000, RMSE: 5.623952, MAE: 4.179071, Accuracy: 0.6081, R2: 0.7102, Expl.Var: 0.7105
[2025-03-06 14:27:22,793 INFO] [Epoch 31/50] Train Loss: 64.776980, Val Loss: 2567981.750000, RMSE: 5.318291, MAE: 4.001699, Accuracy: 0.6294, R2: 0.7409, Expl.Var: 0.7412
[2025-03-06 14:27:36,716 INFO] [Epoch 41/50] Train Loss: 53.897459, Val Loss: 2046158.000000, RMSE: 4.747289, MAE: 3.493054, Accuracy: 0.6692, R2: 0.7936, Expl.Var: 0.7939
[2025-03-06 14:27:48,940 INFO] [Epoch 50/50] Train Loss: 49.402700, Val Loss: 1843754.250000

TGCN


[2025-03-06 14:27:50,437 INFO] [Epoch 1/50] Train Loss: 247.953554, Val Loss: 5588889.500000, RMSE: 6.411605, MAE: 5.065490, Accuracy: 0.5532, R2: 0.6309, Expl.Var: 0.6438
[2025-03-06 14:28:04,070 INFO] [Epoch 11/50] Train Loss: 110.309745, Val Loss: 4517411.500000, RMSE: 5.764330, MAE: 4.265648, Accuracy: 0.5983, R2: 0.6956, Expl.Var: 0.6957
[2025-03-06 14:28:17,563 INFO] [Epoch 21/50] Train Loss: 107.294537, Val Loss: 4367923.500000, RMSE: 5.668152, MAE: 4.197238, Accuracy: 0.6050, R2: 0.7058, Expl.Var: 0.7063
[2025-03-06 14:28:31,219 INFO] [Epoch 31/50] Train Loss: 99.921669, Val Loss: 4040442.000000, RMSE: 5.451530, MAE: 4.064967, Accuracy: 0.6201, R2: 0.7279, Expl.Var: 0.7286
[2025-03-06 14:28:44,569 INFO] [Epoch 41/50] Train Loss: 89.787479, Val Loss: 3584244.750000, RMSE: 5.134555, MAE: 3.869759, Accuracy: 0.6422, R2: 0.7596, Expl.Var: 0.7632
[2025-03-06 14:28:56,821 INFO] [Epoch 50/50] Train Loss: 78.863149, Val Loss: 2983710.5000

File data/W_est_shenzhen_pre_len3.npy found. Loading existing adjacency matrix estimated by GSL from training data.
GSL computed: Only GSL


[2025-03-06 14:28:58,380 INFO] [Epoch 1/50] Train Loss: 194.376562, Val Loss: 2768185.750000, RMSE: 4.512338, MAE: 3.168136, Accuracy: 0.6856, R2: 0.8159, Expl.Var: 0.8271
[2025-03-06 14:29:12,076 INFO] [Epoch 11/50] Train Loss: 68.065120, Val Loss: 2403940.750000, RMSE: 4.204999, MAE: 2.839746, Accuracy: 0.7070, R2: 0.8380, Expl.Var: 0.8382
[2025-03-06 14:29:25,663 INFO] [Epoch 21/50] Train Loss: 67.464025, Val Loss: 2400594.250000, RMSE: 4.202071, MAE: 2.865467, Accuracy: 0.7072, R2: 0.8384, Expl.Var: 0.8392
[2025-03-06 14:29:39,111 INFO] [Epoch 31/50] Train Loss: 67.363994, Val Loss: 2394406.500000, RMSE: 4.196651, MAE: 2.823927, Accuracy: 0.7076, R2: 0.8387, Expl.Var: 0.8392
[2025-03-06 14:29:52,397 INFO] [Epoch 41/50] Train Loss: 67.213229, Val Loss: 2387370.250000, RMSE: 4.190481, MAE: 2.769196, Accuracy: 0.7080, R2: 0.8391, Expl.Var: 0.8393
[2025-03-06 14:30:04,654 INFO] [Epoch 50/50] Train Loss: 67.137550, Val Loss: 2390015.750000

File data/W_est_shenzhen_pre_len3.npy found. Loading existing adjacency matrix estimated by GSL from training data.
GSL computed: GSL+Adj


[2025-03-06 14:30:06,221 INFO] [Epoch 1/50] Train Loss: 247.953554, Val Loss: 5588889.500000, RMSE: 6.411605, MAE: 5.065490, Accuracy: 0.5532, R2: 0.6309, Expl.Var: 0.6438
[2025-03-06 14:30:19,654 INFO] [Epoch 11/50] Train Loss: 110.309745, Val Loss: 4517411.500000, RMSE: 5.764330, MAE: 4.265648, Accuracy: 0.5983, R2: 0.6956, Expl.Var: 0.6957
[2025-03-06 14:30:33,221 INFO] [Epoch 21/50] Train Loss: 107.294537, Val Loss: 4367923.500000, RMSE: 5.668152, MAE: 4.197238, Accuracy: 0.6050, R2: 0.7058, Expl.Var: 0.7063
[2025-03-06 14:30:46,621 INFO] [Epoch 31/50] Train Loss: 99.921669, Val Loss: 4040442.000000, RMSE: 5.451530, MAE: 4.064967, Accuracy: 0.6201, R2: 0.7279, Expl.Var: 0.7286
[2025-03-06 14:31:00,124 INFO] [Epoch 41/50] Train Loss: 89.787479, Val Loss: 3584244.750000, RMSE: 5.134555, MAE: 3.869759, Accuracy: 0.6422, R2: 0.7596, Expl.Var: 0.7632
[2025-03-06 14:31:12,266 INFO] [Epoch 50/50] Train Loss: 78.863149, Val Loss: 2983710.5000

TGCN


[2025-03-06 14:31:13,757 INFO] [Epoch 1/50] Train Loss: 490.803538, Val Loss: 8192973.500000, RMSE: 6.728673, MAE: 5.409897, Accuracy: 0.5311, R2: 0.5863, Expl.Var: 0.5879
[2025-03-06 14:31:27,122 INFO] [Epoch 11/50] Train Loss: 148.926267, Val Loss: 6059755.000000, RMSE: 5.786770, MAE: 4.304471, Accuracy: 0.5967, R2: 0.6933, Expl.Var: 0.6938
[2025-03-06 14:31:40,593 INFO] [Epoch 21/50] Train Loss: 145.812555, Val Loss: 5929333.000000, RMSE: 5.724158, MAE: 4.245355, Accuracy: 0.6011, R2: 0.6999, Expl.Var: 0.7002
[2025-03-06 14:31:54,322 INFO] [Epoch 31/50] Train Loss: 140.221965, Val Loss: 5676330.500000, RMSE: 5.600703, MAE: 4.164546, Accuracy: 0.6097, R2: 0.7127, Expl.Var: 0.7131
[2025-03-06 14:32:07,905 INFO] [Epoch 41/50] Train Loss: 130.101241, Val Loss: 5150568.500000, RMSE: 5.335022, MAE: 4.021696, Accuracy: 0.6282, R2: 0.7394, Expl.Var: 0.7401
[2025-03-06 14:32:19,971 INFO] [Epoch 50/50] Train Loss: 114.482389, Val Loss: 4405340.0

File data/W_est_shenzhen_pre_len4.npy found. Loading existing adjacency matrix estimated by GSL from training data.
GSL computed: Only GSL


[2025-03-06 14:32:21,546 INFO] [Epoch 1/50] Train Loss: 423.295415, Val Loss: 3908408.750000, RMSE: 4.647385, MAE: 3.432168, Accuracy: 0.6761, R2: 0.8024, Expl.Var: 0.8035
[2025-03-06 14:32:35,132 INFO] [Epoch 11/50] Train Loss: 92.869057, Val Loss: 3254922.000000, RMSE: 4.241105, MAE: 2.832589, Accuracy: 0.7044, R2: 0.8352, Expl.Var: 0.8352
[2025-03-06 14:32:48,455 INFO] [Epoch 21/50] Train Loss: 91.545186, Val Loss: 3228456.750000, RMSE: 4.223827, MAE: 2.810040, Accuracy: 0.7057, R2: 0.8365, Expl.Var: 0.8366
[2025-03-06 14:33:02,155 INFO] [Epoch 31/50] Train Loss: 91.181948, Val Loss: 3222035.750000, RMSE: 4.219625, MAE: 2.794372, Accuracy: 0.7059, R2: 0.8369, Expl.Var: 0.8370
[2025-03-06 14:33:15,646 INFO] [Epoch 41/50] Train Loss: 91.163458, Val Loss: 3217517.750000, RMSE: 4.216666, MAE: 2.780404, Accuracy: 0.7062, R2: 0.8371, Expl.Var: 0.8371
[2025-03-06 14:33:27,658 INFO] [Epoch 50/50] Train Loss: 90.995489, Val Loss: 3217259.500000

File data/W_est_shenzhen_pre_len4.npy found. Loading existing adjacency matrix estimated by GSL from training data.
GSL computed: GSL+Adj


[2025-03-06 14:33:29,189 INFO] [Epoch 1/50] Train Loss: 490.803538, Val Loss: 8192973.500000, RMSE: 6.728673, MAE: 5.409897, Accuracy: 0.5311, R2: 0.5863, Expl.Var: 0.5879
[2025-03-06 14:33:42,658 INFO] [Epoch 11/50] Train Loss: 148.926267, Val Loss: 6059755.000000, RMSE: 5.786770, MAE: 4.304471, Accuracy: 0.5967, R2: 0.6933, Expl.Var: 0.6938
[2025-03-06 14:33:56,270 INFO] [Epoch 21/50] Train Loss: 145.812555, Val Loss: 5929333.000000, RMSE: 5.724158, MAE: 4.245355, Accuracy: 0.6011, R2: 0.6999, Expl.Var: 0.7002
[2025-03-06 14:34:09,915 INFO] [Epoch 31/50] Train Loss: 140.221965, Val Loss: 5676330.500000, RMSE: 5.600703, MAE: 4.164546, Accuracy: 0.6097, R2: 0.7127, Expl.Var: 0.7131
[2025-03-06 14:34:23,589 INFO] [Epoch 41/50] Train Loss: 130.101241, Val Loss: 5150568.500000, RMSE: 5.335022, MAE: 4.021696, Accuracy: 0.6282, R2: 0.7394, Expl.Var: 0.7401
[2025-03-06 14:34:35,806 INFO] [Epoch 50/50] Train Loss: 114.482389, Val Loss: 4405340.0

TGCN


[2025-03-06 14:34:36,857 INFO] [Epoch 1/50] Train Loss: 1112.484312, Val Loss: 9420844.000000, RMSE: 15.257605, MAE: 11.646675, Accuracy: 0.7403, R2: 0.2511, Expl.Var: 0.4069
[2025-03-06 14:34:45,847 INFO] [Epoch 11/50] Train Loss: 109.476194, Val Loss: 3683299.250000, RMSE: 9.540255, MAE: 6.738713, Accuracy: 0.8376, R2: 0.5267, Expl.Var: 0.5279
[2025-03-06 14:34:54,783 INFO] [Epoch 21/50] Train Loss: 80.579441, Val Loss: 2612263.500000, RMSE: 8.034335, MAE: 5.562919, Accuracy: 0.8633, R2: 0.6637, Expl.Var: 0.6638
[2025-03-06 14:35:03,839 INFO] [Epoch 31/50] Train Loss: 68.158829, Val Loss: 2192631.500000, RMSE: 7.360787, MAE: 5.131146, Accuracy: 0.8747, R2: 0.7178, Expl.Var: 0.7181
[2025-03-06 14:35:12,911 INFO] [Epoch 41/50] Train Loss: 60.225385, Val Loss: 1933874.375000, RMSE: 6.912825, MAE: 4.765222, Accuracy: 0.8823, R2: 0.7511, Expl.Var: 0.7516
[2025-03-06 14:35:21,106 INFO] [Epoch 50/50] Train Loss: 55.029498, Val Loss: 1756643.50

File data/W_est_losloop_pre_len1.npy found. Loading existing adjacency matrix estimated by GSL from training data.
GSL computed: Only GSL


[2025-03-06 14:35:22,144 INFO] [Epoch 1/50] Train Loss: 1146.168615, Val Loss: 11481282.000000, RMSE: 16.843666, MAE: 11.732319, Accuracy: 0.7133, R2: -0.0914, Expl.Var: -0.1238
[2025-03-06 14:35:30,921 INFO] [Epoch 11/50] Train Loss: 57.479963, Val Loss: 1864368.750000, RMSE: 6.787461, MAE: 4.454289, Accuracy: 0.8845, R2: 0.7600, Expl.Var: 0.7601
[2025-03-06 14:35:39,725 INFO] [Epoch 21/50] Train Loss: 43.044410, Val Loss: 1411129.250000, RMSE: 5.905066, MAE: 3.752765, Accuracy: 0.8995, R2: 0.8183, Expl.Var: 0.8184
[2025-03-06 14:35:48,555 INFO] [Epoch 31/50] Train Loss: 34.213752, Val Loss: 1131582.750000, RMSE: 5.287916, MAE: 3.317061, Accuracy: 0.9100, R2: 0.8543, Expl.Var: 0.8545
[2025-03-06 14:35:57,410 INFO] [Epoch 41/50] Train Loss: 30.017911, Val Loss: 995363.625000, RMSE: 4.959436, MAE: 3.127564, Accuracy: 0.9156, R2: 0.8718, Expl.Var: 0.8719
[2025-03-06 14:36:05,323 INFO] [Epoch 50/50] Train Loss: 28.278445, Val Loss: 939243.37

File data/W_est_losloop_pre_len1.npy found. Loading existing adjacency matrix estimated by GSL from training data.
GSL computed: GSL+Adj


[2025-03-06 14:36:06,408 INFO] [Epoch 1/50] Train Loss: 1113.143479, Val Loss: 9483019.000000, RMSE: 15.307870, MAE: 11.642679, Accuracy: 0.7395, R2: 0.2441, Expl.Var: 0.3943
[2025-03-06 14:36:15,352 INFO] [Epoch 11/50] Train Loss: 110.741545, Val Loss: 3703649.750000, RMSE: 9.566573, MAE: 6.777622, Accuracy: 0.8372, R2: 0.5240, Expl.Var: 0.5251
[2025-03-06 14:36:24,233 INFO] [Epoch 21/50] Train Loss: 79.833123, Val Loss: 2576744.750000, RMSE: 7.979527, MAE: 5.568526, Accuracy: 0.8642, R2: 0.6682, Expl.Var: 0.6682
[2025-03-06 14:36:33,123 INFO] [Epoch 31/50] Train Loss: 67.526821, Val Loss: 2171137.500000, RMSE: 7.324620, MAE: 5.126070, Accuracy: 0.8753, R2: 0.7206, Expl.Var: 0.7210
[2025-03-06 14:36:42,124 INFO] [Epoch 41/50] Train Loss: 59.867563, Val Loss: 1920271.750000, RMSE: 6.888470, MAE: 4.765722, Accuracy: 0.8828, R2: 0.7529, Expl.Var: 0.7534
[2025-03-06 14:36:50,165 INFO] [Epoch 50/50] Train Loss: 54.434653, Val Loss: 1736057.50

TGCN


[2025-03-06 14:36:51,256 INFO] [Epoch 1/50] Train Loss: 2551.107982, Val Loss: 22101220.000000, RMSE: 16.545910, MAE: 12.872609, Accuracy: 0.7184, R2: 0.2063, Expl.Var: 0.3704
[2025-03-06 14:37:00,026 INFO] [Epoch 11/50] Train Loss: 235.458262, Val Loss: 8016692.000000, RMSE: 9.965065, MAE: 7.084861, Accuracy: 0.8304, R2: 0.4837, Expl.Var: 0.4843
[2025-03-06 14:37:08,856 INFO] [Epoch 21/50] Train Loss: 173.613949, Val Loss: 5665826.500000, RMSE: 8.377494, MAE: 5.836537, Accuracy: 0.8574, R2: 0.6347, Expl.Var: 0.6347
[2025-03-06 14:37:17,690 INFO] [Epoch 31/50] Train Loss: 147.563591, Val Loss: 4777217.500000, RMSE: 7.692544, MAE: 5.333383, Accuracy: 0.8691, R2: 0.6921, Expl.Var: 0.6922
[2025-03-06 14:37:26,656 INFO] [Epoch 41/50] Train Loss: 131.010524, Val Loss: 4241721.000000, RMSE: 7.248590, MAE: 4.976371, Accuracy: 0.8766, R2: 0.7267, Expl.Var: 0.7270
[2025-03-06 14:37:34,642 INFO] [Epoch 50/50] Train Loss: 120.788007, Val Loss: 39104

File data/W_est_losloop_pre_len2.npy found. Loading existing adjacency matrix estimated by GSL from training data.
GSL computed: Only GSL


[2025-03-06 14:37:35,731 INFO] [Epoch 1/50] Train Loss: 2593.785552, Val Loss: 25614992.000000, RMSE: 17.812695, MAE: 13.286761, Accuracy: 0.6968, R2: -0.0906, Expl.Var: -0.1372
[2025-03-06 14:37:44,528 INFO] [Epoch 11/50] Train Loss: 139.387148, Val Loss: 4480797.000000, RMSE: 7.450067, MAE: 4.830361, Accuracy: 0.8732, R2: 0.7119, Expl.Var: 0.7138
[2025-03-06 14:37:53,373 INFO] [Epoch 21/50] Train Loss: 102.562290, Val Loss: 3377663.250000, RMSE: 6.468307, MAE: 4.020167, Accuracy: 0.8899, R2: 0.7825, Expl.Var: 0.7834
[2025-03-06 14:38:02,210 INFO] [Epoch 31/50] Train Loss: 83.732713, Val Loss: 2778798.500000, RMSE: 5.866931, MAE: 3.646541, Accuracy: 0.9001, R2: 0.8209, Expl.Var: 0.8209
[2025-03-06 14:38:11,056 INFO] [Epoch 41/50] Train Loss: 73.292917, Val Loss: 2485625.250000, RMSE: 5.548816, MAE: 3.349482, Accuracy: 0.9055, R2: 0.8400, Expl.Var: 0.8414
[2025-03-06 14:38:19,055 INFO] [Epoch 50/50] Train Loss: 68.910686, Val Loss: 235448

File data/W_est_losloop_pre_len2.npy found. Loading existing adjacency matrix estimated by GSL from training data.
GSL computed: GSL+Adj


[2025-03-06 14:38:20,145 INFO] [Epoch 1/50] Train Loss: 2553.527185, Val Loss: 22265304.000000, RMSE: 16.607216, MAE: 12.866159, Accuracy: 0.7173, R2: 0.1972, Expl.Var: 0.3525
[2025-03-06 14:38:28,994 INFO] [Epoch 11/50] Train Loss: 240.868618, Val Loss: 8131115.500000, RMSE: 10.035929, MAE: 7.154685, Accuracy: 0.8292, R2: 0.4763, Expl.Var: 0.4768
[2025-03-06 14:38:37,911 INFO] [Epoch 21/50] Train Loss: 173.264291, Val Loss: 5622172.500000, RMSE: 8.345159, MAE: 5.824179, Accuracy: 0.8579, R2: 0.6376, Expl.Var: 0.6377
[2025-03-06 14:38:46,762 INFO] [Epoch 31/50] Train Loss: 146.576822, Val Loss: 4726560.500000, RMSE: 7.651650, MAE: 5.325992, Accuracy: 0.8698, R2: 0.6953, Expl.Var: 0.6954
[2025-03-06 14:38:55,559 INFO] [Epoch 41/50] Train Loss: 130.041123, Val Loss: 4195291.500000, RMSE: 7.208810, MAE: 4.979330, Accuracy: 0.8773, R2: 0.7296, Expl.Var: 0.7299
[2025-03-06 14:39:03,511 INFO] [Epoch 50/50] Train Loss: 119.444212, Val Loss: 3860

TGCN


[2025-03-06 14:39:04,611 INFO] [Epoch 1/50] Train Loss: 5975.527014, Val Loss: 32249282.000000, RMSE: 16.340090, MAE: 12.598639, Accuracy: 0.7218, R2: 0.1883, Expl.Var: 0.3220
[2025-03-06 14:39:13,473 INFO] [Epoch 11/50] Train Loss: 373.395498, Val Loss: 12700180.000000, RMSE: 10.254142, MAE: 7.378126, Accuracy: 0.8254, R2: 0.4535, Expl.Var: 0.4535
[2025-03-06 14:39:22,373 INFO] [Epoch 21/50] Train Loss: 294.635658, Val Loss: 9718177.000000, RMSE: 8.969883, MAE: 6.308595, Accuracy: 0.8473, R2: 0.5818, Expl.Var: 0.5818
[2025-03-06 14:39:31,365 INFO] [Epoch 31/50] Train Loss: 244.597288, Val Loss: 8024870.000000, RMSE: 8.151047, MAE: 5.598724, Accuracy: 0.8612, R2: 0.6547, Expl.Var: 0.6547
[2025-03-06 14:39:40,356 INFO] [Epoch 41/50] Train Loss: 217.218102, Val Loss: 7140357.000000, RMSE: 7.688725, MAE: 5.342566, Accuracy: 0.8691, R2: 0.6934, Expl.Var: 0.6949
[2025-03-06 14:39:48,333 INFO] [Epoch 50/50] Train Loss: 200.972167, Val Loss: 654

File data/W_est_losloop_pre_len3.npy found. Loading existing adjacency matrix estimated by GSL from training data.
GSL computed: Only GSL


[2025-03-06 14:39:49,411 INFO] [Epoch 1/50] Train Loss: 6043.095129, Val Loss: 43058568.000000, RMSE: 18.880964, MAE: 13.684561, Accuracy: 0.6786, R2: -0.2729, Expl.Var: -0.3973
[2025-03-06 14:39:58,106 INFO] [Epoch 11/50] Train Loss: 271.827138, Val Loss: 8445623.000000, RMSE: 8.362001, MAE: 5.437482, Accuracy: 0.8576, R2: 0.6368, Expl.Var: 0.6373
[2025-03-06 14:40:06,946 INFO] [Epoch 21/50] Train Loss: 177.040565, Val Loss: 5871422.500000, RMSE: 6.972139, MAE: 4.330657, Accuracy: 0.8813, R2: 0.7474, Expl.Var: 0.7478
[2025-03-06 14:40:16,007 INFO] [Epoch 31/50] Train Loss: 145.135429, Val Loss: 4925343.000000, RMSE: 6.385761, MAE: 3.883577, Accuracy: 0.8913, R2: 0.7883, Expl.Var: 0.7892
[2025-03-06 14:40:24,775 INFO] [Epoch 41/50] Train Loss: 126.569740, Val Loss: 4375548.000000, RMSE: 6.018810, MAE: 3.731156, Accuracy: 0.8975, R2: 0.8118, Expl.Var: 0.8125
[2025-03-06 14:40:32,875 INFO] [Epoch 50/50] Train Loss: 118.852611, Val Loss: 412

File data/W_est_losloop_pre_len3.npy found. Loading existing adjacency matrix estimated by GSL from training data.
GSL computed: GSL+Adj


[2025-03-06 14:40:33,975 INFO] [Epoch 1/50] Train Loss: 5979.250416, Val Loss: 32744192.000000, RMSE: 16.464991, MAE: 12.663162, Accuracy: 0.7197, R2: 0.1735, Expl.Var: 0.2959
[2025-03-06 14:40:42,993 INFO] [Epoch 11/50] Train Loss: 386.297914, Val Loss: 13017356.000000, RMSE: 10.381396, MAE: 7.477658, Accuracy: 0.8233, R2: 0.4398, Expl.Var: 0.4398
[2025-03-06 14:40:51,965 INFO] [Epoch 21/50] Train Loss: 294.756357, Val Loss: 9638410.000000, RMSE: 8.932995, MAE: 6.317088, Accuracy: 0.8479, R2: 0.5852, Expl.Var: 0.5853
[2025-03-06 14:41:01,044 INFO] [Epoch 31/50] Train Loss: 242.084971, Val Loss: 7939920.500000, RMSE: 8.107790, MAE: 5.569749, Accuracy: 0.8620, R2: 0.6584, Expl.Var: 0.6585
[2025-03-06 14:41:10,078 INFO] [Epoch 41/50] Train Loss: 215.930045, Val Loss: 7111023.000000, RMSE: 7.672915, MAE: 5.370074, Accuracy: 0.8694, R2: 0.6949, Expl.Var: 0.6969
[2025-03-06 14:41:18,165 INFO] [Epoch 50/50] Train Loss: 199.531939, Val Loss: 649

TGCN


[2025-03-06 14:41:19,276 INFO] [Epoch 1/50] Train Loss: 6778.211318, Val Loss: 47248048.000000, RMSE: 17.150465, MAE: 13.151768, Accuracy: 0.7080, R2: 0.1443, Expl.Var: 0.2576
[2025-03-06 14:41:28,297 INFO] [Epoch 11/50] Train Loss: 499.725946, Val Loss: 17182340.000000, RMSE: 10.342490, MAE: 7.362521, Accuracy: 0.8239, R2: 0.4450, Expl.Var: 0.4452
[2025-03-06 14:41:37,435 INFO] [Epoch 21/50] Train Loss: 379.989653, Val Loss: 12657210.000000, RMSE: 8.876729, MAE: 6.145161, Accuracy: 0.8489, R2: 0.5909, Expl.Var: 0.5909
[2025-03-06 14:41:46,359 INFO] [Epoch 31/50] Train Loss: 330.156089, Val Loss: 10905120.000000, RMSE: 8.239469, MAE: 5.664052, Accuracy: 0.8597, R2: 0.6476, Expl.Var: 0.6478
[2025-03-06 14:41:55,389 INFO] [Epoch 41/50] Train Loss: 299.326721, Val Loss: 9868780.000000, RMSE: 7.838189, MAE: 5.322256, Accuracy: 0.8665, R2: 0.6811, Expl.Var: 0.6812
[2025-03-06 14:42:03,528 INFO] [Epoch 50/50] Train Loss: 279.900671, Val Loss: 9

File data/W_est_losloop_pre_len4.npy found. Loading existing adjacency matrix estimated by GSL from training data.
GSL computed: Only GSL


[2025-03-06 14:42:04,589 INFO] [Epoch 1/50] Train Loss: 6929.411973, Val Loss: 61636232.000000, RMSE: 19.588537, MAE: 14.194361, Accuracy: 0.6665, R2: -0.3686, Expl.Var: -0.5365
[2025-03-06 14:42:13,725 INFO] [Epoch 11/50] Train Loss: 387.919924, Val Loss: 12234973.000000, RMSE: 8.727412, MAE: 5.769040, Accuracy: 0.8514, R2: 0.6048, Expl.Var: 0.6051
[2025-03-06 14:42:22,776 INFO] [Epoch 21/50] Train Loss: 248.716331, Val Loss: 8405551.000000, RMSE: 7.233809, MAE: 4.468880, Accuracy: 0.8768, R2: 0.7287, Expl.Var: 0.7297
[2025-03-06 14:42:31,760 INFO] [Epoch 31/50] Train Loss: 203.998822, Val Loss: 7065628.000000, RMSE: 6.632227, MAE: 4.090346, Accuracy: 0.8871, R2: 0.7717, Expl.Var: 0.7717
[2025-03-06 14:42:40,763 INFO] [Epoch 41/50] Train Loss: 186.509908, Val Loss: 6552608.500000, RMSE: 6.386914, MAE: 3.811663, Accuracy: 0.8913, R2: 0.7883, Expl.Var: 0.7886
[2025-03-06 14:42:48,826 INFO] [Epoch 50/50] Train Loss: 179.453769, Val Loss: 62

File data/W_est_losloop_pre_len4.npy found. Loading existing adjacency matrix estimated by GSL from training data.
GSL computed: GSL+Adj


[2025-03-06 14:42:49,909 INFO] [Epoch 1/50] Train Loss: 6786.746448, Val Loss: 48238924.000000, RMSE: 17.329370, MAE: 13.221812, Accuracy: 0.7050, R2: 0.1199, Expl.Var: 0.2125
[2025-03-06 14:42:58,936 INFO] [Epoch 11/50] Train Loss: 525.462556, Val Loss: 17812280.000000, RMSE: 10.530373, MAE: 7.492685, Accuracy: 0.8207, R2: 0.4245, Expl.Var: 0.4246
[2025-03-06 14:43:07,783 INFO] [Epoch 21/50] Train Loss: 375.161494, Val Loss: 12433444.000000, RMSE: 8.797914, MAE: 6.113583, Accuracy: 0.8502, R2: 0.5983, Expl.Var: 0.5984
[2025-03-06 14:43:16,910 INFO] [Epoch 31/50] Train Loss: 328.410177, Val Loss: 10836886.000000, RMSE: 8.213651, MAE: 5.681320, Accuracy: 0.8602, R2: 0.6500, Expl.Var: 0.6504
[2025-03-06 14:43:26,080 INFO] [Epoch 41/50] Train Loss: 298.577363, Val Loss: 9834196.000000, RMSE: 7.824442, MAE: 5.324487, Accuracy: 0.8668, R2: 0.6822, Expl.Var: 0.6824
[2025-03-06 14:43:34,146 INFO] [Epoch 50/50] Train Loss: 278.112495, Val Loss: 9

CPU times: total: 20min 59s
Wall time: 23min 8s
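
The log above reports RMSE, MAE, Accuracy, R2 and explained variance for each run. The sketch below shows one common way to compute these five quantities with NumPy, assuming the definitions used in the original T-GCN paper (Accuracy as one minus the ratio of Frobenius norms). The repository may compute them differently, for example with `torchmetrics` or with per-sensor averaging, so the function `evaluate` and its output should be read as an illustration only.

```python
import numpy as np


def evaluate(y_true: np.ndarray, y_pred: np.ndarray) -> dict:
    """Illustrative metric definitions (assumed, not taken from the repository)."""
    err = y_true - y_pred
    rmse = float(np.sqrt(np.mean(err ** 2)))
    mae = float(np.mean(np.abs(err)))
    # Accuracy: 1 - ||Y - Y_hat||_F / ||Y||_F (norm of a 2-D array defaults to Frobenius)
    accuracy = float(1.0 - np.linalg.norm(err) / np.linalg.norm(y_true))
    # R2: 1 - residual sum of squares / total sum of squares
    r2 = float(1.0 - np.sum(err ** 2) / np.sum((y_true - y_true.mean()) ** 2))
    # Explained variance: 1 - Var(residual) / Var(target)
    expl_var = float(1.0 - np.var(err) / np.var(y_true))
    return {"RMSE": rmse, "MAE": mae, "Accuracy": accuracy,
            "R2": r2, "Expl.Var": expl_var}


# Toy example: every prediction is too high by a constant 0.5
y_true = np.array([[1.0, 2.0], [3.0, 4.0]])
y_pred = y_true + 0.5
toy_metrics = evaluate(y_true, y_pred)
print(toy_metrics)
```

In this toy case the constant offset lowers R2 to 0.8 while the explained variance stays at 1.0, because under these definitions the explained variance ignores a constant bias in the residuals.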


In [3]:
import pandas as pd
import matplotlib.pyplot as plt
import os

model_cls_name = 'TGCN'
# Define the prediction lengths and measures
pred_list = [1, 2, 3, 4]
measures = ["RMSE", "MAE", "Accuracy", "R2"]
datasets = ['shenzhen', 'losloop']

# Define a color palette for each method
colors = plt.cm.tab10.colors  # Using a predefined color palette (10 distinct colors)
method_colors = {
    "T-GCN": colors[0],           # T-GCN (Default)
    "T-GCN (GSL Only)": colors[1],  # T-GCN (GSL Only)
    "T-GCN (GSL + Adj)": colors[2], # T-GCN (GSL + Adj)
}

# Define line styles for each method
line_styles = {
    "T-GCN": "--",                # Dashed for T-GCN (Default)
    "T-GCN (GSL Only)": "-",       # Solid for T-GCN (GSL Only)
    "T-GCN (GSL + Adj)": ":",      # Dotted for T-GCN (GSL + Adj)
}

# Define markers for each method
markers = {
    "T-GCN": "o",                 # Circle for T-GCN (Default)
    "T-GCN (GSL Only)": "s",       # Square for T-GCN (GSL Only)
    "T-GCN (GSL + Adj)": "D",      # Diamond for T-GCN (GSL + Adj)
}

# Ensure the results directory for this model exists (figures are saved there)
os.makedirs(os.path.join("results", model_cls_name), exist_ok=True)

# Loop through each dataset
for dataset in datasets:
    # Loop through each measure
    for measure in measures:
        # Loop through each pre_len in pred_list
        for pre_len in pred_list:
            # Create a new figure for each pre_len and measure
            plt.figure(figsize=(12, 8))

            # Load the metrics for T-GCN (Default) for comparison
            metrics_file_tgcn = f"results/{model_cls_name}/metrics_{dataset}_seq12_pre{pre_len}_gsl0.csv"
            metrics_df_tgcn = pd.read_csv(metrics_file_tgcn)

            # Loop through each method
            for method, color in method_colors.items():
                # Map method names to their corresponding file suffixes
                if method == "T-GCN":
                    file_suffix = "gsl0"
                elif method == "T-GCN (GSL Only)":
                    file_suffix = "gsl1"
                elif method == "T-GCN (GSL + Adj)":
                    file_suffix = "gsl2"

                # Generate the file path dynamically based on pre_len and method
                metrics_file = f"results/{model_cls_name}/metrics_{dataset}_seq12_pre{pre_len}_{file_suffix}.csv"

                # Load the metrics CSV file
                metrics_df = pd.read_csv(metrics_file)

                # Plot the current measure for the current method
                plt.plot(
                    metrics_df["Epoch"],
                    metrics_df[measure],
                    linestyle=line_styles[method],
                    marker=markers[method],
                    color=color,
                    label=f"{method}", # (pre_len={pre_len})
                )

                # Highlight the first epoch where the current method is better than T-GCN (Default)
                if method != "T-GCN":  # Skip comparison for T-GCN (Default)
                    better_epoch = None
                    # Compare only the rows present in both metrics files
                    for epoch in range(min(len(metrics_df), len(metrics_df_tgcn))):
                        tgcn_value = metrics_df_tgcn.loc[epoch, measure]
                        current_value = metrics_df.loc[epoch, measure]

                        # Check if the current method is better than T-GCN (Default)
                        if measure in ["RMSE", "MAE"]:
                            is_better = current_value < tgcn_value  # Lower is better
                        else:
                            is_better = current_value > tgcn_value  # Higher is better

                        if is_better:
                            # Use the logged epoch number so the marker matches the x-axis
                            better_epoch = metrics_df.loc[epoch, "Epoch"]
                            break

                    # Highlight the first epoch where the current method is better
                    if better_epoch is not None:
                        plt.axvline(
                            x=better_epoch,
                            color=color,
                            linestyle=":",
                            alpha=0.5,
                            # label=f"{method} better at epoch {better_epoch}", # (pre_len={pre_len})
                        )

            # Add labels, title, legend, and grid
            plt.xlabel("Epoch")
            plt.ylabel(measure)
            plt.title(f"{measure} over Epochs for {dataset} (pre_len={pre_len})")
            plt.legend()
            # plt.grid(True)

            # Save the figure as a PNG file
            filename = f"results/{model_cls_name}/{dataset}_{measure}_pre{pre_len}.png"
            plt.savefig(filename, bbox_inches="tight", dpi=300)
            plt.close()  # Close the figure to free up memory
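
The first-better-epoch search in the plotting cell can also be written as a small standalone helper. The following is a minimal sketch, not part of the repository: `first_better_epoch` and `LOWER_IS_BETTER` are hypothetical names, and it assumes each metrics CSV has an `Epoch` column plus one column per measure, as the plotting code reads them. It aligns the two runs on `Epoch` rather than on row position.

```python
import pandas as pd

# Assumption: RMSE and MAE are errors (lower is better); all other measures are scores.
LOWER_IS_BETTER = {"RMSE", "MAE"}


def first_better_epoch(baseline_df, method_df, measure):
    """Return the first epoch at which method_df beats baseline_df, or None."""
    # Inner-join on Epoch so that runs with different lengths compare like with like.
    merged = (
        baseline_df[["Epoch", measure]]
        .merge(method_df[["Epoch", measure]], on="Epoch", suffixes=("_base", "_new"))
        .sort_values("Epoch")
        .reset_index(drop=True)
    )
    if measure in LOWER_IS_BETTER:
        better = merged[f"{measure}_new"] < merged[f"{measure}_base"]
    else:
        better = merged[f"{measure}_new"] > merged[f"{measure}_base"]
    if not better.any():
        return None
    # idxmax on a boolean Series gives the label of the first True entry.
    return int(merged.loc[better.idxmax(), "Epoch"])


# Toy values for illustration only (not results from the paper).
baseline = pd.DataFrame({"Epoch": [1, 2, 3], "RMSE": [5.0, 4.5, 4.0], "R2": [0.5, 0.6, 0.7]})
candidate = pd.DataFrame({"Epoch": [1, 2, 3], "RMSE": [5.5, 4.4, 3.9], "R2": [0.4, 0.5, 0.6]})

print(first_better_epoch(baseline, candidate, "RMSE"))  # RMSE drops below the baseline at epoch 2
print(first_better_epoch(baseline, candidate, "R2"))    # never better, so None
```

Returning the value from the `Epoch` column keeps the vertical marker on the same scale as the plotted x-axis, even if the CSV does not log every epoch.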


In [4]:
import pandas as pd
import os

# Define the prediction lengths and measures
pred_list = [1, 2, 3, 4]
measures = ["RMSE", "MAE", "Accuracy", "R2"]
datasets = ['shenzhen', 'losloop']

# Ensure the results directory exists
os.makedirs("results", exist_ok=True)

# Loop through each dataset
for dataset in datasets:
    # Initialize a dictionary to store the results of the last epoch
    results = {}

    # Loop through each pre_len and measure to collect the last epoch results
    for pre_len in pred_list:
        # Generate the file paths dynamically based on pre_len and method
        metrics_file_gsl0 = f"results/{model_cls_name}/metrics_{dataset}_seq12_pre{pre_len}_gsl0.csv"
        metrics_file_gsl1 = f"results/{model_cls_name}/metrics_{dataset}_seq12_pre{pre_len}_gsl1.csv"
        metrics_file_gsl2 = f"results/{model_cls_name}/metrics_{dataset}_seq12_pre{pre_len}_gsl2.csv"

        # Load the metrics CSV files
        metrics_df_gsl0 = pd.read_csv(metrics_file_gsl0)
        metrics_df_gsl1 = pd.read_csv(metrics_file_gsl1)
        metrics_df_gsl2 = pd.read_csv(metrics_file_gsl2)

        # Get the last epoch results for all methods
        last_epoch_gsl0 = metrics_df_gsl0.iloc[-1]  # Last row for gsl0 (T-GCN)
        last_epoch_gsl1 = metrics_df_gsl1.iloc[-1]  # Last row for gsl1 (T-GCN GSL Only)
        last_epoch_gsl2 = metrics_df_gsl2.iloc[-1]  # Last row for gsl2 (T-GCN GSL + Adj)

        # Store the results in the dictionary with pre_len as part of the key
        results[f"T-GCN (pre_len={pre_len})"] = last_epoch_gsl0[measures]
        results[f"T-GCN (GSL Only) (pre_len={pre_len})"] = last_epoch_gsl1[measures]
        results[f"T-GCN (GSL + Adj) (pre_len={pre_len})"] = last_epoch_gsl2[measures]

    # Convert the results dictionary to a DataFrame
    results_df = pd.DataFrame(results).T

    # Add a column for pre_len to results_df
    results_df["pre\\_len"] = [pre_len for pre_len in pred_list for _ in range(3)]

    # Function to highlight the winner for each pre_len and measure
    def highlight_winner(df):
        # Copy as object dtype so formatted strings can replace the float values
        highlighted_df = df.astype(object)
        for pre_len in pred_list:
            for measure in measures:
                # Get the values for the three methods for the current pre_len
                tgcn_value = df.loc[f"T-GCN (pre_len={pre_len})", measure]
                gsl_only_value = df.loc[f"T-GCN (GSL Only) (pre_len={pre_len})", measure]
                gsl_adj_value = df.loc[f"T-GCN (GSL + Adj) (pre_len={pre_len})", measure]

                # Determine the winner based on the measure
                if measure in ["RMSE", "MAE"]:
                    winner_value = min(tgcn_value, gsl_only_value, gsl_adj_value)  # Lower is better
                else:
                    winner_value = max(tgcn_value, gsl_only_value, gsl_adj_value)  # Higher is better

                # Highlight the winner
                if tgcn_value == winner_value:
                    highlighted_df.loc[f"T-GCN (pre_len={pre_len})", measure] = f"\\textbf{{{tgcn_value:.4f}}}"
                else:
                    highlighted_df.loc[f"T-GCN (pre_len={pre_len})", measure] = f"{tgcn_value:.4f}"

                if gsl_only_value == winner_value:
                    highlighted_df.loc[f"T-GCN (GSL Only) (pre_len={pre_len})", measure] = f"\\textbf{{{gsl_only_value:.4f}}}"
                else:
                    highlighted_df.loc[f"T-GCN (GSL Only) (pre_len={pre_len})", measure] = f"{gsl_only_value:.4f}"

                if gsl_adj_value == winner_value:
                    highlighted_df.loc[f"T-GCN (GSL + Adj) (pre_len={pre_len})", measure] = f"\\textbf{{{gsl_adj_value:.4f}}}"
                else:
                    highlighted_df.loc[f"T-GCN (GSL + Adj) (pre_len={pre_len})", measure] = f"{gsl_adj_value:.4f}"
        return highlighted_df

    # Apply the highlight function to the results DataFrame
    highlighted_results = highlight_winner(results_df)

    # Remove (pre_len={pre_len}) from method names
    highlighted_results.index = highlighted_results.index.str.replace(r" \(pre_len=\d+\)", "", regex=True)

    # Add a column for method names
    highlighted_results.insert(0, "Method", highlighted_results.index)

    # Reorder columns to make pre_len the first column and method-name the second column
    highlighted_results = highlighted_results[["pre\\_len", "Method"] + measures]

    # Generate the LaTeX table
    latex_table = highlighted_results.to_latex(
        escape=False,
        column_format="cl" + "c" * len(measures),  # pre_len, Method, then one centered column per measure
        multicolumn_format="c",
        index=False  # Do not include the index in the LaTeX table
    )

    # Save the LaTeX table to a file
    table_filename = f"results/{model_cls_name}/{dataset}_results_table.tex"
    with open(table_filename, "w") as f:
        f.write(latex_table)

    # Print the LaTeX table
    print(f"LaTeX table for {dataset}:")
    print(latex_table)

LaTeX table for shenzhen:
\begin{tabular}{clcccc}
\toprule
 pre\_len &            Method &            RMSE &             MAE &        Accuracy &              R2 \\
\midrule
        1 &             T-GCN &          4.8657 &          3.6060 &          0.6610 &          0.7833 \\
        1 &  T-GCN (GSL Only) & \textbf{4.1142} & \textbf{2.7479} & \textbf{0.7133} & \textbf{0.8449} \\
        1 & T-GCN (GSL + Adj) &          4.8657 &          3.6060 &          0.6610 &          0.7833 \\
        2 &             T-GCN &          4.5064 &          3.2126 &          0.6860 &          0.8139 \\
        2 &  T-GCN (GSL Only) & \textbf{4.1618} & \textbf{2.8162} & \textbf{0.7100} & \textbf{0.8414} \\
        2 & T-GCN (GSL + Adj) &          4.5064 &          3.2126 &          0.6860 &          0.8139 \\
        3 &             T-GCN &          4.6847 &          3.4161 &          0.6735 &          0.7989 \\
        3 &  T-GCN (GSL Only) & \textbf{4.1928} & \textbf{2.8060} & \textbf{0.7078} & \textb
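
The three near-identical `if`/`else` branches in `highlight_winner` can be collapsed by grouping rows per prediction length. The following is a minimal sketch under assumptions, not the repository's code: `bold_best` is a hypothetical helper, and it expects a long-format table with a unique index, one row per method, and a plain `pre_len` column. The sample values are the `pre_len=1` RMSE and R2 figures from the Shenzhen table above.

```python
import pandas as pd

# Assumption: RMSE and MAE are errors (lower is better); all other measures are scores.
LOWER_IS_BETTER = {"RMSE", "MAE"}


def bold_best(df, measures, group_col):
    """Format measures to 4 decimals and wrap the best value per group in \\textbf{}."""
    # Object dtype lets formatted strings replace the numeric values.
    out = df.astype(object)
    for _, group in df.groupby(group_col):
        for measure in measures:
            column = group[measure]
            best = column.min() if measure in LOWER_IS_BETTER else column.max()
            for idx, value in column.items():
                text = f"{value:.4f}"
                out.loc[idx, measure] = f"\\textbf{{{text}}}" if value == best else text
    return out


# pre_len=1 rows taken from the Shenzhen table printed above.
demo = pd.DataFrame(
    {
        "pre_len": [1, 1, 1],
        "Method": ["T-GCN", "T-GCN (GSL Only)", "T-GCN (GSL + Adj)"],
        "RMSE": [4.8657, 4.1142, 4.8657],
        "R2": [0.7833, 0.8449, 0.7833],
    }
)
formatted = bold_best(demo, ["RMSE", "R2"], "pre_len")
print(formatted.to_string(index=False))
```

As in the original function, ties are all highlighted, because every value equal to the best one is wrapped in `\textbf{}`.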

In [5]:
import os

# Define the prediction lengths and measures
pred_list = [1, 2, 3, 4]
measures = ["RMSE", "MAE", "Accuracy", "R2"]
datasets = ['shenzhen', 'losloop']

# Ensure the results directory exists
os.makedirs("results", exist_ok=True)

# Function to generate LaTeX table for a dataset
def generate_latex_table(dataset):
    latex_code = f"""
\\begin{{table}}[htbp]
\\centering
\\caption{{Results for {dataset}: RMSE, MAE, Accuracy, and R2 for pre\\_len=1, 2, 3, 4.}}
\\begin{{tabular}}{{|c|{'c' * len(measures)}|}}
\\hline
pre\\_len & {' & '.join(measures)} \\\\
\\hline
"""

    for pre_len in pred_list:
        row = f"{pre_len} "
        for measure in measures:
            image_path = f"{dataset}_{measure}_pre{pre_len}.png"
            row += f"& \\includegraphics[width=0.2\\textwidth]{{{image_path}}} "
        row += "\\\\ \\hline"
        latex_code += row + "\n"

    latex_code += """
\\end{tabular}
\\end{table}
"""
    return latex_code

# Generate LaTeX tables for each dataset
for dataset in datasets:
    latex_table = generate_latex_table(dataset)
    print(f"LaTeX table for {dataset}:")
    print(latex_table)

    # Save the LaTeX table next to the figures and the results table
    table_filename = f"results/{model_cls_name}/{dataset}_images_table.tex"
    os.makedirs(os.path.dirname(table_filename), exist_ok=True)
    with open(table_filename, "w") as f:
        f.write(latex_table)

LaTeX table for shenzhen:

\begin{table}[htbp]
\centering
\caption{Results for shenzhen: RMSE, MAE, Accuracy, and R2 for pre\_len=1, 2, 3, 4.}
\begin{tabular}{|c|cccc|}
\hline
pre\_len & RMSE & MAE & Accuracy & R2 \\
\hline
1 & \includegraphics[width=0.2\textwidth]{shenzhen_RMSE_pre1.png} & \includegraphics[width=0.2\textwidth]{shenzhen_MAE_pre1.png} & \includegraphics[width=0.2\textwidth]{shenzhen_Accuracy_pre1.png} & \includegraphics[width=0.2\textwidth]{shenzhen_R2_pre1.png} \\ \hline
2 & \includegraphics[width=0.2\textwidth]{shenzhen_RMSE_pre2.png} & \includegraphics[width=0.2\textwidth]{shenzhen_MAE_pre2.png} & \includegraphics[width=0.2\textwidth]{shenzhen_Accuracy_pre2.png} & \includegraphics[width=0.2\textwidth]{shenzhen_R2_pre2.png} \\ \hline
3 & \includegraphics[width=0.2\textwidth]{shenzhen_RMSE_pre3.png} & \includegraphics[width=0.2\textwidth]{shenzhen_MAE_pre3.png} & \includegraphics[width=0.2\textwidth]{shenzhen_Accuracy_pre3.png} & \includegraphics[width=0.2\textwidth]{s
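
The image table refers to each plot by bare file name, so LaTeX compilation will fail if any of the PNG files has not been generated. The following is a minimal sketch of a pre-flight check, assuming the file-name pattern used for `image_path` in `generate_latex_table`; `missing_images` is a hypothetical helper and is not part of the repository.

```python
import tempfile
from pathlib import Path


def missing_images(image_dir, datasets, measures, pred_list):
    """Return the expected plot file names that are absent from image_dir."""
    image_dir = Path(image_dir)
    expected = [
        f"{dataset}_{measure}_pre{pre_len}.png"
        for dataset in datasets
        for pre_len in pred_list
        for measure in measures
    ]
    return [name for name in expected if not (image_dir / name).exists()]


# Demonstration in a throwaway directory containing a single empty placeholder file.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "shenzhen_RMSE_pre1.png").touch()
    missing = missing_images(tmp, ["shenzhen"], ["RMSE", "MAE"], [1])

print(missing)  # ['shenzhen_MAE_pre1.png']
```

In the notebook, the directory to check would be the one the plotting cell saves into. The generated table relies on `\includegraphics`, so the LaTeX document that includes it needs the `graphicx` package.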

In [6]:
import shutil

if 'google.colab' in str(get_ipython()):
    from google.colab import files
    # Zip the results folder
    shutil.make_archive('results', 'zip', 'results')
    # Download the zipped file
    files.download('results.zip')