In [1]:
import itertools

import pandas as pd
import numpy as np
import torch
from torch.autograd import Variable
import sklearn.preprocessing
import sklearn.metrics

from data_split_tune_utils import X_y_site_split
from CNN_utils import split_sizes_site, split_data, pad_stack_splits, get_monitorData_indices, r2, get_nonConst_vars, train_CNN
from CNN_architecture import CNN1, CNN2

np.random.seed(1)
torch.manual_seed(1)

# Set to True for a quick smoke-test run: drops the last site from EACH split.
DROP_LAST_SITE_FOR_TESTING = False


def _drop_last_site(df, site_var_name='site'):
    """Return `df` without the rows belonging to the site found on its last row.

    Debug helper only. Each split is filtered against its own last site; the
    earlier commented-out version mistakenly assigned filtered *train* rows to
    `val` and `test`.
    """
    last_site = df[site_var_name].values[-1]
    return df[~df[site_var_name].isin([last_site])]


def _site_tensors(x_std_nonConst, x_std_all, y, sites):
    """Split one standardized data split into per-site tensors.

    Parameters
    ----------
    x_std_nonConst : np.ndarray
        Standardized non-constant features, rows aligned with `sites`.
    x_std_all : np.ndarray
        Standardized features (all columns), rows aligned with `sites`.
    y : pandas object exposing `.values`
        Target values, rows aligned with `sites`.
    sites : pandas object exposing `.values`
        Site label of every row.

    Returns
    -------
    tuple
        (split_sizes, x_std_tuple_nonConst, x_std_tuple, y_tuple,
        x_std_stack_nonConst). `split_sizes` comes from `split_sizes_site`.
        The three tuples are `split_data(..., dim=0)` outputs. The stack is the
        padded non-constant tuples from `pad_stack_splits` with dims 1 and 2
        swapped, wrapped in a Variable.
        NOTE(review): presumably (sites, features, time) after the transpose —
        confirm against pad_stack_splits.
    """
    split_sizes = split_sizes_site(sites.values)
    x_std_tuple_nonConst = split_data(torch.from_numpy(x_std_nonConst).float(), split_sizes, dim = 0)
    x_std_tuple = split_data(torch.from_numpy(x_std_all).float(), split_sizes, dim = 0)
    y_tuple = split_data(torch.from_numpy(y.values), split_sizes, dim = 0)
    # stack the per-site sequences into one padded tensor for the CNN
    x_std_stack_nonConst = pad_stack_splits(x_std_tuple_nonConst, np.array(split_sizes), 'x')
    x_std_stack_nonConst = Variable(torch.transpose(x_std_stack_nonConst, 1, 2))
    return split_sizes, x_std_tuple_nonConst, x_std_tuple, y_tuple, x_std_stack_nonConst


### read in train, val, and test
train = pd.read_csv('../data/trainV_ridgeImp.csv')
val = pd.read_csv('../data/valV_ridgeImp.csv')
test = pd.read_csv('../data/testV_ridgeImp.csv')

if DROP_LAST_SITE_FOR_TESTING:
    train = _drop_last_site(train)
    val = _drop_last_site(val)
    test = _drop_last_site(test)

### split train, val, and test into x, y, and sites
train_x, train_y, train_sites = X_y_site_split(train, y_var_name='MonitorData', site_var_name='site')
val_x, val_y, val_sites = X_y_site_split(val, y_var_name='MonitorData', site_var_name='site')
test_x, test_y, test_sites = X_y_site_split(test, y_var_name='MonitorData', site_var_name='site')

### get dataframes with non-constant features only (selected on TRAIN only)
nonConst_vars = get_nonConst_vars(train, site_var_name='site', y_var_name='MonitorData', cutoff=1000)
train_x_nonConst = train_x.loc[:, nonConst_vars]
val_x_nonConst = val_x.loc[:, nonConst_vars]
test_x_nonConst = test_x.loc[:, nonConst_vars]

### standardize all features (scaler fit on TRAIN only, applied to val/test)
standardizer_all = sklearn.preprocessing.StandardScaler(with_mean = True, with_std = True)
train_x_std_all = standardizer_all.fit_transform(train_x)
val_x_std_all = standardizer_all.transform(val_x)
test_x_std_all = standardizer_all.transform(test_x)

### standardize non-constant features (scaler fit on TRAIN only, applied to val/test)
standardizer_nonConst = sklearn.preprocessing.StandardScaler(with_mean = True, with_std = True)
train_x_std_nonConst = standardizer_nonConst.fit_transform(train_x_nonConst)
val_x_std_nonConst = standardizer_nonConst.transform(val_x_nonConst)
test_x_std_nonConst = standardizer_nonConst.transform(test_x_nonConst)

### per-site tuples and padded CNN input for each split
(train_split_sizes, train_x_std_tuple_nonConst, train_x_std_tuple,
 train_y_tuple, train_x_std_stack_nonConst) = _site_tensors(
    train_x_std_nonConst, train_x_std_all, train_y, train_sites)

(val_split_sizes, val_x_std_tuple_nonConst, val_x_std_tuple,
 val_y_tuple, val_x_std_stack_nonConst) = _site_tensors(
    val_x_std_nonConst, val_x_std_all, val_y, val_sites)

(test_split_sizes, test_x_std_tuple_nonConst, test_x_std_tuple,
 test_y_tuple, test_x_std_stack_nonConst) = _site_tensors(
    test_x_std_nonConst, test_x_std_all, test_y, test_sites)

In [2]:
num_epochs = 2
batch_size = 32
# input widths: number of non-constant features and number of all features
input_size_conv = train_x_std_nonConst.shape[1]
input_size_full = train_x_std_all.shape[1]

# CNN and optimizer hyper-parameters to test
hidden_size_conv_list = [25, 50, 75]
kernel_size_list = [3, 5]
# The grid search zips padding with kernel size. Deriving it as (k - 1) // 2
# ("same" padding for odd k, assuming stride 1) keeps the two lists in lockstep.
# With hand-maintained parallel lists, zip() would silently drop any kernel size
# that lacked a matching padding entry.
assert all(k % 2 == 1 for k in kernel_size_list), 'same-padding requires odd kernel sizes'
padding_list = [(k - 1) // 2 for k in kernel_size_list]
hidden_size_full_list = [50, 100, 150]
dropout_full_list = [0.1, 0.3, 0.5]
hidden_size_combo_list = [50, 100, 150]
dropout_combo_list = [0.1, 0.3, 0.5]
lr_list = [0.1, 0.01]
weight_decay_list = [0.001, 0.0001, 0.00001]

# Loss function.
# NOTE: `size_average` is deprecated in newer torch releases (use reduction='mean').
# It is kept here because this notebook targets the older Variable-based API.
mse_loss = torch.nn.MSELoss(size_average=True)

In [3]:
# Exhaustive grid search over the hyper-parameter lists from the previous cell.
# For each configuration, a fresh CNN1 and Adam optimizer are trained and scored
# with validation R^2. The hyper-parameters of the best score are kept.
best_val_r2 = -np.inf
# Pre-initialise the best_* names so the summary below cannot raise NameError
# when no configuration improves on -inf (e.g. every R^2 is NaN: `nan > x` is False).
best_hidden_size_conv = None
best_kernel_size = None
best_padding = None
best_hidden_size_full = None
best_dropout_full = None
best_hidden_size_combo = None
best_dropout_combo = None
best_lr = None
best_weight_decay = None

# itertools.product varies the LAST iterable fastest, so the iteration order is
# identical to eight nested for-loops. Kernel size and padding stay paired via zip().
param_grid = itertools.product(
    hidden_size_conv_list,
    list(zip(kernel_size_list, padding_list)),
    hidden_size_full_list,
    dropout_full_list,
    hidden_size_combo_list,
    dropout_combo_list,
    lr_list,
    weight_decay_list,
)

for (hidden_size_conv, (kernel_size, padding), hidden_size_full, dropout_full,
     hidden_size_combo, dropout_combo, lr, weight_decay) in param_grid:

    # instantiate CNN
    cnn = CNN1(input_size_conv, hidden_size_conv, kernel_size, padding, input_size_full, hidden_size_full,
               dropout_full, hidden_size_combo, dropout_combo)

    # instantiate optimizer
    optimizer = torch.optim.Adam(cnn.parameters(), lr=lr, weight_decay=weight_decay)

    print('Hidden size conv: ' + str(hidden_size_conv))
    print('Kernel size: ' + str(kernel_size))
    print('Padding: ' + str(padding))
    print('Hidden size full: ' + str(hidden_size_full))
    print('Dropout full: ' + str(dropout_full))
    print('Hidden size combo: ' + str(hidden_size_combo))
    print('Dropout combo: ' + str(dropout_combo))
    print('Learning rate: ' + str(lr))
    print('Weight decay: ' + str(weight_decay))

    train_CNN(train_x_std_stack_nonConst, train_x_std_tuple, train_y_tuple, cnn, optimizer, mse_loss, num_epochs, batch_size)

    val_r2 = r2(cnn, batch_size, val_x_std_stack_nonConst, val_x_std_tuple, val_y_tuple)
    print('Validation R^2: ' + str(val_r2))
    print()
    print()

    # a NaN val_r2 (e.g. diverged training) never compares greater, so it is skipped
    if val_r2 > best_val_r2:
        best_val_r2 = val_r2
        best_hidden_size_conv = hidden_size_conv
        best_kernel_size = kernel_size
        best_padding = padding
        best_hidden_size_full = hidden_size_full
        best_dropout_full = dropout_full
        best_hidden_size_combo = hidden_size_combo
        best_dropout_combo = dropout_combo
        best_lr = lr
        best_weight_decay = weight_decay

if best_hidden_size_conv is None:
    print('No configuration improved on the initial validation R^2 of -inf (all NaN?).')

print('Best validation R^2: ' + str(best_val_r2))
print('Best hidden size conv: ' + str(best_hidden_size_conv))
print('Best kernel size: ' + str(best_kernel_size))
print('Best padding: ' + str(best_padding))
print('Best hidden size full: ' + str(best_hidden_size_full))
print('Best dropout full: ' + str(best_dropout_full))
print('Best hidden size combo: ' + str(best_hidden_size_combo))
print('Best dropout combo: ' + str(best_dropout_combo))
print('Best learning rate: ' + str(best_lr))
print('Best weight decay: ' + str(best_weight_decay))

Hidden size conv: 25
Kernel size: 3
Hidden size full: 50
Dropout full: 0.1
Hidden size combo: 50
Dropout combo: 0.1
Learning rate: 0.1
Weight decay: 0.001

 3590
 3591
 3592
  ⋮  
 6054
 6055
 6056
[torch.LongTensor of size 2367]


  366
  369
  372
  ⋮  
 6141
 6144
 6147
[torch.LongTensor of size 1440]


    0
    3
    6
  ⋮  
 6081
 6084
 6087
[torch.LongTensor of size 1247]


    0
    3
    6
    9
   12
   15
   18
   21
   24
   27
   30
   33
   36
   39
   42
   45
   48
   51
   54
   57
   60
   63
   66
   69
   72
   75
   78
   81
   84
   87
   90
   93
   96
  102
  105
  108
  111
  114
  117
  120
  123
  126
  129
  132
  135
  138
  141
  144
  147
  150
  153
  156
  159
  162
  165
  168
  171
  174
  177
  180
  183
  186
  189
  192
  195
  198
  201
  204
  207
  210
  213
  216
  219
  222
  225
  228
  231
  237
  240
  243
  246
  249
  252
  255
  258
  261
  264
  267
  270
  273
  276
  279
  282
  285
  288
  291
  294
  297
  300
  303
  306
  309
  31


 3708
 3709
 3710
  ⋮  
 6054
 6055
 6056
[torch.LongTensor of size 2292]


    0
    6
   12
  ⋮  
 6146
 6147
 6148
[torch.LongTensor of size 1294]


   57
   60
   63
  ⋮  
 6078
 6081
 6084
[torch.LongTensor of size 1915]


    0
    3
    6
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 1826]


    0
    3
    6
  ⋮  
 6090
 6102
 6115
[torch.LongTensor of size 1080]


 1008
 1009
 1010
  ⋮  
 6146
 6147
 6148
[torch.LongTensor of size 3701]


 5876
 5877
 5878
 5879
 5880
 5881
 5882
 5883
 5884
 5885
 5886
 5887
 5888
 5889
 5890
 5891
 5892
 5893
 5894
 5895
 5896
 5897
 5898
 5899
 5900
 5901
 5902
 5903
 5904
 5905
 5906
 5907
 5908
 5909
 5910
 5911
 5912
 5913
 5914
 5915
 5916
 5917
 5918
 5919
 5920
 5921
 5922
 5923
 5924
 5925
 5926
 5927
 5928
 5929
 5930
 5931
 5932
 5933
 5962
 5963
 5964
 5965
 5966
 5967
 5968
 5969
 5970
 5971
 5972
 5973
 5974
 5975
 5976
 5977
 5978
 5979
 5980
 5981
 5982
 5983
 5984
 5985
 5986
 5987
 5988
 5989
 5990
 5991
 5992
 5993
 59


    0
    3
    6
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 1618]


 2196
 2199
 2202
  ⋮  
 6023
 6024
 6025
[torch.LongTensor of size 2920]


 6047
 6048
 6049
 6050
 6051
 6052
 6053
 6054
 6055
 6056
 6060
 6089
 6090
 6091
 6092
 6093
 6094
 6095
 6096
 6097
 6098
 6099
 6100
 6101
 6102
 6103
 6104
 6105
 6106
 6107
 6108
 6109
 6110
 6111
 6112
 6113
 6114
 6115
 6116
 6117
 6118
 6119
 6120
 6121
 6122
 6123
 6124
 6125
 6126
 6127
 6128
 6129
 6130
 6131
 6132
 6133
 6134
 6135
 6136
 6137
 6138
 6139
 6140
 6141
 6142
 6143
 6144
 6145
 6146
 6147
 6148
[torch.LongTensor of size 71]


 3586
 3587
 3588
  ⋮  
 6115
 6116
 6117
[torch.LongTensor of size 2242]


 2550
 2553
 2556
 2559
 2562
 2565
 2568
 2574
 2577
 2580
 2583
 2586
 2592
 2595
 2598
 2601
 2604
 2607
 2610
 2613
 2616
 2619
 2622
 2625
 2628
 2631
 2634
 2637
 2640
 2643
 2646
 2652
 2658
 2664
 2670
 2679
 2682
 2688
 2694
 2700
 2706
 2712
 2718
 2724
 2730
 2736
 2742
 2748
 2754
 2760
 2766
 2772
 


    0
    3
    6
  ⋮  
 6176
 6177
 6178
[torch.LongTensor of size 2698]


 3356
 3357
 3359
  ⋮  
 6115
 6116
 6117
[torch.LongTensor of size 2395]


 4443
 4444
 4445
  ⋮  
 6085
 6086
 6087
[torch.LongTensor of size 1440]


 4383
 4384
 4385
  ⋮  
 6115
 6116
 6117
[torch.LongTensor of size 1723]


    0
    3
    6
  ⋮  
 6141
 6144
 6147
[torch.LongTensor of size 1999]


    0
    3
    6
  ⋮  
 6120
 6123
 6126
[torch.LongTensor of size 1989]


    0
    1
    2
  ⋮  
 6115
 6116
 6117
[torch.LongTensor of size 3790]


 1467
 1470
 1476
 1494
 1500
 1513
 1515
 1518
 1524
 1533
 1536
 1542
 1548
 1554
 1561
 1566
 1575
 1579
 1584
 1590
 1596
 1602
 1608
 1616
 1621
 1626
 1638
 1644
 1650
 1656
 1662
 1668
 1683
 1686
 1692
 1698
 1704
 1710
 1716
 1722
 1728
 1734
 1740
 1746
 1752
 1758
 1764
 1770
 1776
 1782
 1788
 1794
 1800
 1806
 1812
 1818
 1824
 1830
 1836
 1842
 1848
 1859
 1863
 1866
 1872
 1878
 1884
 1890
 1896
 1902
 1908
 1914
 1920
 1926
 1932
 1938
 1944
 1950


    0
    3
    6
  ⋮  
 6015
 6021
 6024
[torch.LongTensor of size 1782]


 1830
 1836
 1842
 1848
 1854
 1866
 1872
 1878
 1884
 1890
 1896
 1914
 1920
 1926
 1932
 1938
 1944
 1950
 1956
 1962
 1974
 1980
 1986
 1992
 2004
 2010
 2016
 2022
 2028
 2034
 2040
 2046
 2052
 2058
 2064
 2070
 2076
 2083
 2088
 2094
 2100
 2106
 2112
 2118
 2124
 2130
 2136
 2142
 2148
 2154
 2160
 2166
 2172
 2178
 2184
 2190
 2196
 2202
 2208
 2214
 2220
 2226
 2232
 2238
 2244
 2250
 2256
 2262
 2268
 2274
 2280
 2292
 2304
 2310
 2316
 2322
 2328
 2334
 2340
 2346
 2352
 2358
 2370
 2376
 2382
 2388
 2394
 2400
 2406
 2412
 2418
 2424
 2430
 2436
 2442
 2448
 2454
 2460
 2466
 2472
 2478
 2484
 2490
 2496
 2502
 2508
 2520
 2526
 2532
 2538
 2544
 2550
 2556
 2562
 2568
 2574
 2580
 2586
 2592
 2598
 2604
 2610
 2616
 2622
 2628
 2640
 2646
 2649
 2652
 2655
 2658
 2661
 2664
 2667
 2670
 2673
 2676
 2682
 2685
 2688
 2691
 2694
 2697
 2703
 2706
 2712
 2715
 2718
 2721
 2724
 2727
 2730
 2736
 2739



 5310
 5313
 5316
 5319
 5322
 5325
 5328
 5331
 5334
 5337
 5343
 5346
 5349
 5352
 5355
 5358
 5361
 5364
 5367
 5370
 5373
 5376
 5379
 5382
 5385
 5388
 5391
 5394
 5397
 5403
 5406
 5409
 5412
 5415
 5418
 5424
 5427
 5430
 5433
 5436
 5439
 5442
 5445
 5448
 5451
 5454
 5460
 5463
 5466
 5469
 5472
 5475
 5478
 5484
 5487
 5490
 5493
 5496
 5499
 5502
 5505
 5511
 5514
 5517
 5520
 5523
 5526
 5529
 5532
 5535
 5538
 5541
 5544
 5547
 5550
 5553
 5556
 5562
 5565
 5568
 5571
 5574
 5577
 5580
 5583
 5586
 5592
 5595
 5598
 5601
 5604
 5607
 5610
 5613
 5616
 5619
 5622
 5625
 5628
 5631
 5634
 5637
 5640
 5643
 5646
 5649
 5652
 5655
 5658
 5661
 5664
 5667
 5670
 5673
 5676
 5679
 5682
 5685
 5688
 5691
 5694
 5697
 5700
 5703
 5706
 5709
 5712
 5715
 5718
 5721
 5724
 5733
 5736
 5739
 5742
 5745
 5748
 5751
 5754
 5757
 5760
 5763
 5766
 5769
 5772
 5775
 5778
 5781
 5784
 5787
 5790
 5793
 5796
 5799
 5802
 5805
 5808
 5811
 5814
 5817
 5820
 5823
 5826
 5832
 5835
 5838
 5


 1827
 1830
 1833
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 1403]


 5997
 6000
 6003
 6006
 6009
 6012
 6015
 6018
 6021
 6024
 6026
 6027
 6028
 6029
 6030
 6031
 6032
 6033
 6034
 6035
 6036
 6037
 6038
 6039
 6040
 6041
 6042
 6043
 6044
 6045
 6046
 6047
 6048
 6049
 6050
 6051
 6052
 6053
 6054
 6055
 6056
 6057
 6058
 6059
 6060
 6061
 6062
 6063
 6064
 6065
 6066
 6067
 6068
 6069
 6070
 6071
 6072
 6073
 6074
 6075
 6076
 6077
 6080
 6081
 6082
 6083
 6084
 6085
 6086
 6087
 6089
 6090
 6091
 6092
 6093
 6094
 6095
 6096
 6097
 6098
 6099
 6100
 6101
 6102
 6107
 6108
 6109
 6110
 6111
 6112
 6113
 6114
 6115
 6116
 6117
[torch.LongTensor of size 95]


   19
   33
   36
  ⋮  
 6115
 6116
 6117
[torch.LongTensor of size 2701]


    0
    3
    6
  ⋮  
 6018
 6021
 6024
[torch.LongTensor of size 1859]


 4730
 4731
 4732
  ⋮  
 6146
 6147
 6148
[torch.LongTensor of size 1379]


    0
    3
    6
  ⋮  
 6115
 6116
 6117
[torch.LongTensor of size 3717]


    0
    3
    7

  if y_ind.numpy() != None:


Validation R^2: 0.400932478949


Hidden size conv: 25
Kernel size: 3
Hidden size full: 50
Dropout full: 0.1
Hidden size combo: 50
Dropout combo: 0.1
Learning rate: 0.1
Weight decay: 0.0001

    0
    6
   12
  ⋮  
 6146
 6147
 6148
[torch.LongTensor of size 1294]


 2862
 2866
 2868
  ⋮  
 6146
 6147
 6148
[torch.LongTensor of size 2744]


    0
    1
    2
  ⋮  
 6115
 6116
 6117
[torch.LongTensor of size 3790]


    0
    3
    6
  ⋮  
 6146
 6147
 6148
[torch.LongTensor of size 2893]


    6
    9
   12
  ⋮  
 6018
 6021
 6024
[torch.LongTensor of size 1842]


 2562
 2568
 2578
  ⋮  
 6115
 6116
 6117
[torch.LongTensor of size 2590]


 3309
 3312
 3318
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 1340]


    0
    3
    6
  ⋮  
 6102
 6105
 6108
[torch.LongTensor of size 2562]


    0
    3
    6
  ⋮  
 6015
 6021
 6024
[torch.LongTensor of size 1782]


   12
   13
   14
  ⋮  
 6132
 6138
 6144
[torch.LongTensor of size 1347]


 5310
 5316
 5322
 5328
 5345
 5346
 5352
 5358
 5


 4261
 4262
 4263
  ⋮  
 6085
 6086
 6087
[torch.LongTensor of size 1815]


    0
    3
    6
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 1930]


 5752
 5757
 5758
 5759
 5760
 5761
 5762
 5763
 5764
 5765
 5766
 5767
 5768
 5769
 5770
 5771
 5772
 5778
 5779
 5780
 5781
 5782
 5784
 5785
 5786
 5787
 5788
 5789
 5790
 5791
 5792
 5793
 5794
 5795
 5796
 5797
 5798
 5799
 5800
 5801
 5802
 5803
 5804
 5805
 5806
 5807
 5808
 5809
 5810
 5811
 5812
 5813
 5814
 5815
 5816
 5817
 5818
 5819
 5820
 5821
 5822
 5823
 5824
 5825
 5826
 5827
 5828
 5829
 5830
 5831
 5832
 5833
 5834
 5835
 5836
 5837
 5838
 5839
 5840
 5841
 5842
 5843
 5844
 5845
 5846
 5847
 5848
 5849
 5850
 5851
 5852
 5853
 5854
 5855
 5856
 5857
 5858
 5859
 5860
 5861
 5862
 5863
 5864
 5865
 5866
 5867
 5868
 5869
 5870
 5871
 5872
 5873
 5874
 5875
 5876
 5877
 5878
 5879
 5880
 5881
 5882
 5883
 5884
 5885
 5886
 5887
 5888
 5889
 5890
 5891
 5892
 5893
 5894
 5895
 5896
 5897
 5898
 5899
 5900
 5901
 5902
 


    0
    1
    2
  ⋮  
 6143
 6144
 6147
[torch.LongTensor of size 3948]


    0
    4
    7
  ⋮  
 6066
 6069
 6072
[torch.LongTensor of size 1905]


 5479
 5480
 5481
 5482
 5483
 5484
 5485
 5487
 5488
 5489
 5490
 5491
 5492
 5493
 5499
 5500
 5501
 5502
 5503
 5504
 5505
 5506
 5507
 5508
 5509
 5510
 5511
 5512
 5513
 5514
 5515
 5516
 5517
 5518
 5519
 5520
 5521
 5522
 5523
 5524
 5525
 5526
 5527
 5528
 5529
 5530
 5531
 5532
 5533
 5534
 5535
 5536
 5537
 5538
 5539
 5540
 5541
 5542
 5544
 5545
 5546
 5548
 5549
 5550
 5551
 5552
 5553
 5554
 5555
 5556
 5557
 5558
 5559
 5560
 5561
 5562
 5563
 5564
 5565
 5566
 5567
 5568
 5569
 5570
 5571
 5572
 5573
 5574
 5575
 5576
 5577
 5578
 5579
 5580
 5581
 5582
 5583
 5584
 5585
 5586
 5587
 5588
 5589
 5590
 5591
 5592
 5593
 5594
 5595
 5596
 5597
 5598
 5599
 5600
 5601
 5602
 5603
 5604
 5605
 5606
 5607
 5608
 5609
 5610
 5611
 5612
 5613
 5614
 5615
 5616
 5617
 5618
 5619
 5620
 5621
 5622
 5623
 5624
 5625
 5626
 5627
 


    0
    3
    4
  ⋮  
 6081
 6084
 6087
[torch.LongTensor of size 2455]


   12
   13
   14
  ⋮  
 6132
 6138
 6144
[torch.LongTensor of size 1347]


    0
    1
    2
  ⋮  
 6141
 6144
 6147
[torch.LongTensor of size 4933]


 3708
 3709
 3710
  ⋮  
 6054
 6055
 6056
[torch.LongTensor of size 2292]


 4383
 4384
 4385
  ⋮  
 6115
 6116
 6117
[torch.LongTensor of size 1723]


    0
    3
    6
  ⋮  
 6141
 6144
 6147
[torch.LongTensor of size 1054]


 5813
 5814
 5815
 5816
 5817
 5822
 5823
 5824
 5825
 5826
 5827
 5828
 5829
 5830
 5831
 5832
 5834
 5835
 5836
 5837
 5838
 5839
 5840
 5841
 5843
 5844
 5845
 5846
 5847
 5848
 5849
 5850
 5851
 5852
 5853
 5854
 5855
 5856
 5858
 5859
 5860
 5861
 5862
 5863
 5864
 5865
 5866
 5867
 5868
 5869
 5870
 5871
 5872
 5873
 5874
 5875
 5876
 5877
 5878
 5879
 5880
 5881
 5882
 5883
 5884
 5885
 5886
 5887
 5888
 5889
 5890
 5891
 5892
 5893
 5894
 5895
 5896
 5897
 5898
 5899
 5900
 5901
 5902
 5903
 5904
 5905
 5906
 5907
 5908
 5909
 59


 5479
 5480
 5481
 5482
 5483
 5484
 5488
 5489
 5490
 5491
 5500
 5506
 5507
 5508
 5510
 5511
 5512
 5516
 5517
 5518
 5519
 5522
 5524
 5525
 5526
 5527
 5528
 5529
 5530
 5531
 5532
 5533
 5534
 5535
 5536
 5537
 5538
 5539
 5540
 5541
 5542
 5543
 5544
 5545
 5546
 5547
 5548
 5549
 5550
 5551
 5552
 5553
 5554
 5555
 5556
 5557
 5558
 5559
 5560
 5561
 5562
 5563
 5564
 5565
 5566
 5567
 5568
 5569
 5570
 5571
 5572
 5573
 5574
 5576
 5577
 5578
 5579
 5580
 5581
 5582
 5583
 5584
 5585
 5586
 5587
 5588
 5589
 5590
 5591
 5592
 5593
 5594
 5595
 5596
 5597
 5598
 5599
 5600
 5601
 5602
 5603
 5604
 5605
 5606
 5607
 5608
 5609
 5610
 5611
 5612
 5613
 5614
 5615
 5616
 5617
 5618
 5619
 5620
 5621
 5622
 5623
 5624
 5625
 5626
 5627
 5628
 5629
 5630
 5631
 5632
 5633
 5634
 5635
 5636
 5637
 5638
 5639
 5640
 5641
 5642
 5643
 5644
 5645
 5646
 5647
 5648
 5649
 5650
 5651
 5652
 5653
 5654
 5655
 5656
 5657
 5658
 5659
 5660
 5661
 5662
 5663
 5664
 5665
 5666
 5667
 5668
 56


    0
    3
    6
  ⋮  
 6126
 6129
 6132
[torch.LongTensor of size 1961]


    0
    3
    6
  ⋮  
 6141
 6144
 6147
[torch.LongTensor of size 1993]


    0
    4
    7
  ⋮  
 6066
 6069
 6072
[torch.LongTensor of size 1905]


    0
    1
    2
  ⋮  
 6146
 6147
 6148
[torch.LongTensor of size 4283]


    0
    6
    9
  ⋮  
 6123
 6126
 6132
[torch.LongTensor of size 1963]


 5234
 5235
 5236
 5237
 5238
 5239
 5240
 5241
 5242
 5243
 5244
 5245
 5246
 5247
 5248
 5249
 5250
 5251
 5252
 5253
 5254
 5255
 5256
 5257
 5258
 5259
 5260
 5261
 5262
 5263
 5264
 5268
 5269
 5270
 5271
 5272
 5273
 5274
 5275
 5276
 5277
 5278
 5279
 5280
 5281
 5282
 5283
 5284
 5285
 5286
 5287
 5288
 5289
 5290
 5291
 5292
 5293
 5294
 5295
 5296
 5297
 5298
 5299
 5300
 5301
 5302
 5303
 5304
 5305
 5306
 5307
 5308
 5309
 5310
 5311
 5312
 5313
 5314
 5315
 5316
 5317
 5318
 5319
 5320
 5321
 5322
 5323
 5324
 5325
 5326
 5327
 5328
 5329
 5330
 5331
 5332
 5333
 5334
 5335
 5336
 5337
 5338
 5339
 


    0
    3
    6
  ⋮  
 6090
 6102
 6115
[torch.LongTensor of size 1080]


    0
    6
    7
  ⋮  
 6141
 6144
 6147
[torch.LongTensor of size 1401]


 3356
 3357
 3359
  ⋮  
 6115
 6116
 6117
[torch.LongTensor of size 2395]


 1827
 1830
 1833
  ⋮  
 6079
 6080
 6081
[torch.LongTensor of size 2891]


    0
    3
    6
    9
   12
   15
   18
   21
   24
   27
   30
   33
   36
   39
   42
   45
   48
   51
   54
   57
   60
   63
   66
   69
   72
   75
   78
   81
   84
   87
   90
   93
   96
  102
  105
  108
  111
  114
  117
  120
  123
  126
  129
  132
  135
  138
  141
  144
  147
  150
  153
  156
  159
  162
  165
  168
  171
  174
  177
  180
  183
  186
  189
  192
  195
  198
  201
  204
  207
  210
  213
  216
  219
  222
  225
  228
  231
  237
  240
  243
  246
  249
  252
  255
  258
  261
  264
  267
  270
  273
  276
  279
  282
  285
  288
  291
  294
  297
  300
  303
  306
  309
  312
  315
  318
  321
  324
  327
  330
  333
  336
  339
  342
  345
  348
  351

Validation R^2: 0.422411917476


Hidden size conv: 25
Kernel size: 3
Hidden size full: 50
Dropout full: 0.1
Hidden size combo: 50
Dropout combo: 0.1
Learning rate: 0.1
Weight decay: 1e-05

 4839
 4842
 4843
  ⋮  
 6054
 6055
 6056
[torch.LongTensor of size 1207]


    0
    3
    9
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 1841]


 4191
 4192
 4193
  ⋮  
 6018
 6021
 6024
[torch.LongTensor of size 1132]


 4653
 4654
 4656
 4659
 4662
 4665
 4668
 4671
 4674
 4677
 4680
 4683
 4686
 4689
 4692
 4695
 4698
 4701
 4704
 4707
 4710
 4713
 4716
 4719
 4722
 4725
 4728
 4731
 4734
 4735
 4740
 4743
 4746
 4749
 4752
 4755
 4758
 4761
 4764
 4767
 4770
 4773
 4776
 4779
 4782
 4785
 4788
 4791
 4794
 4797
 4800
 4803
 4806
 4809
 4812
 4815
 4818
 4821
 4824
 4827
 4830
 4833
 4836
 4839
 4842
 4845
 4848
 4851
 4854
 4857
 4860
 4863
 4866
 4869
 4872
 4875
 4878
 4881
 4884
 4887
 4890
 4893
 4896
 4899
 4902
 4905
 4908
 4911
 4914
 4917
 4920
 4923
 4926
 4929
 4932
 4935
 4938
 


    0
    3
   12
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 1780]


    0
    1
    2
  ⋮  
 6115
 6116
 6117
[torch.LongTensor of size 5309]


    0
   15
   18
  ⋮  
 6146
 6147
 6148
[torch.LongTensor of size 3394]


 3245
 3246
 3247
  ⋮  
 6085
 6086
 6087
[torch.LongTensor of size 2773]


 2862
 2866
 2868
  ⋮  
 6146
 6147
 6148
[torch.LongTensor of size 2744]


 1802
 1803
 1804
  ⋮  
 6018
 6021
 6024
[torch.LongTensor of size 2737]


    0
    3
    6
  ⋮  
 6146
 6147
 6148
[torch.LongTensor of size 2893]


 1827
 1830
 1833
  ⋮  
 6079
 6080
 6081
[torch.LongTensor of size 2891]


 2022
 2025
 2028
  ⋮  
 6126
 6129
 6132
[torch.LongTensor of size 1307]


    0
    3
    6
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 1789]


 4431
 4434
 4437
 4440
 4443
 4446
 4449
 4452
 4453
 4455
 4458
 4462
 4464
 4467
 4468
 4470
 4473
 4479
 4482
 4485
 4488
 4491
 4494
 4497
 4500
 4503
 4506
 4509
 4512
 4515
 4518
 4521
 4524
 4527
 4530
 4533
 4536
 4539
 4542
 4545


 2811
 2814
 2821
 2823
 2826
 2829
 2832
 2835
 2838
 2841
 2844
 2847
 2850
 2853
 2856
 2862
 2865
 2868
 2871
 2874
 2877
 2880
 2883
 2886
 2889
 2892
 2895
 2898
 2904
 2907
 2908
 2910
 2913
 2916
 2919
 2922
 2925
 2928
 2931
 2934
 2937
 2940
 2943
 2946
 2949
 2952
 2958
 2964
 2970
 2976
 2980
 2982
 2985
 2988
 2991
 2994
 2997
 3000
 3003
 3006
 3009
 3015
 3018
 3021
 3024
 3027
 3030
 3033
 3036
 3039
 3042
 3045
 3048
 3051
 3054
 3057
 3060
 3063
 3066
 3071
 3072
 3075
 3078
 3081
 3084
 3087
 3090
 3093
 3096
 3100
 3102
 3106
 3108
 3111
 3114
 3117
 3120
 3123
 3126
 3129
 3132
 3135
 3138
 3141
 3144
 3147
 3156
 3157
 3160
 3161
 3164
 3165
 3168
 3171
 3174
 3177
 3180
 3183
 3186
 3192
 3195
 3198
 3203
 3204
 3207
 3210
 3213
 3216
 3219
 3222
 3225
 3228
 3232
 3234
 3237
 3240
 3245
 3246
 3252
 3256
 3258
 3262
 3264
 3268
 3270
 3274
 3276
 3282
 3286
 3288
 3293
 3294
 3300
 3304
 3305
 3307
 3309
 3312
 3317
 3318
 3321
 3324
 3327
 3330
 3336
 3342
 33


 5752
 5757
 5758
 5759
 5760
 5761
 5762
 5763
 5764
 5765
 5766
 5767
 5768
 5769
 5770
 5771
 5772
 5778
 5779
 5780
 5781
 5782
 5784
 5785
 5786
 5787
 5788
 5789
 5790
 5791
 5792
 5793
 5794
 5795
 5796
 5797
 5798
 5799
 5800
 5801
 5802
 5803
 5804
 5805
 5806
 5807
 5808
 5809
 5810
 5811
 5812
 5813
 5814
 5815
 5816
 5817
 5818
 5819
 5820
 5821
 5822
 5823
 5824
 5825
 5826
 5827
 5828
 5829
 5830
 5831
 5832
 5833
 5834
 5835
 5836
 5837
 5838
 5839
 5840
 5841
 5842
 5843
 5844
 5845
 5846
 5847
 5848
 5849
 5850
 5851
 5852
 5853
 5854
 5855
 5856
 5857
 5858
 5859
 5860
 5861
 5862
 5863
 5864
 5865
 5866
 5867
 5868
 5869
 5870
 5871
 5872
 5873
 5874
 5875
 5876
 5877
 5878
 5879
 5880
 5881
 5882
 5883
 5884
 5885
 5886
 5887
 5888
 5889
 5890
 5891
 5892
 5893
 5894
 5895
 5896
 5897
 5898
 5899
 5900
 5901
 5902
 5903
 5904
 5905
 5906
 5907
 5908
 5909
 5910
 5911
 5912
 5913
 5914
 5915
 5916
 5917
 5918
 5919
 5920
 5921
 5922
 5923
 5924
 5925
 5926
 5927
 59


   12
   18
   24
  ⋮  
 6115
 6116
 6117
[torch.LongTensor of size 1996]


 1290
 1296
 1302
  ⋮  
 6023
 6024
 6025
[torch.LongTensor of size 1611]


 1830
 1836
 1842
 1848
 1854
 1860
 1866
 1872
 1878
 1884
 1890
 1896
 1902
 1908
 1914
 1920
 1926
 1932
 1938
 1944
 1950
 1956
 1962
 1968
 1974
 1980
 1986
 1992
 1998
 2004
 2010
 2022
 2028
 2034
 2040
 2049
 2052
 2058
 2064
 2070
 2076
 2082
 2088
 2094
 2100
 2106
 2112
 2118
 2124
 2130
 2136
 2142
 2148
 2154
 2160
 2166
 2172
 2178
 2184
 2190
 2196
 2202
 2208
 2214
 2220
 2226
 2232
 2238
 2244
 2250
 2256
 2262
 2268
 2274
 2280
 2286
 2292
 2298
 2304
 2310
 2316
 2322
 2328
 2334
 2340
 2346
 2352
 2358
 2364
 2370
 2376
 2382
 2388
 2394
 2400
 2406
 2412
 2418
 2424
 2430
 2436
 2442
 2448
 2454
 2460
 2466
 2472
 2478
 2484
 2490
 2496
 2508
 2514
 2520
 2526
 2538
 2544
 2550
 2556
 2562
 2568
 2580
 2586
 2592
 2598
 2604
 2610
 2616
 2622
 2628
 2634
 2640
 2658
 2664
 2673
 2677
 2682
 2688
 2694
 2700
 2706
 


    0
    3
    4
  ⋮  
 6170
 6171
 6172
[torch.LongTensor of size 5868]


 1827
 1830
 1833
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 1403]


    0
    4
    5
  ⋮  
 6138
 6141
 6147
[torch.LongTensor of size 2166]


 5118
 5124
 5130
 5136
 5142
 5148
 5154
 5160
 5166
 5172
 5178
 5184
 5190
 5196
 5202
 5208
 5214
 5220
 5226
 5232
 5238
 5244
 5250
 5256
 5262
 5268
 5274
 5280
 5286
 5292
 5298
 5304
 5310
 5316
 5322
 5328
 5334
 5340
 5346
 5352
 5358
 5364
 5370
 5376
 5382
 5388
 5394
 5400
 5406
 5412
 5418
 5424
 5430
 5436
 5442
 5448
 5454
 5460
 5466
 5472
 5478
 5490
 5496
 5502
 5508
 5514
 5520
 5526
 5532
 5538
 5544
 5550
 5556
 5562
 5568
 5574
 5580
 5586
 5592
 5598
 5604
 5610
 5616
 5622
 5628
 5634
 5640
 5646
 5652
 5658
 5664
 5670
 5676
 5682
 5688
 5694
 5700
 5706
 5712
 5718
 5724
 5730
 5736
 5742
 5748
 5754
 5760
 5766
 5772
 5778
 5784
 5790
 5796
 5802
 5808
 5814
 5820
 5826
 5832
 5838
 5844
 5850
 5856
 5862
 5868
 5874
 5880
 5886
 58

Validation R^2: 0.405358964479


Hidden size conv: 25
Kernel size: 3
Hidden size full: 50
Dropout full: 0.1
Hidden size combo: 50
Dropout combo: 0.1
Learning rate: 0.01
Weight decay: 0.001

    0
    3
    6
  ⋮  
 6132
 6138
 6144
[torch.LongTensor of size 1800]


    0
    6
    9
  ⋮  
 6123
 6126
 6132
[torch.LongTensor of size 1963]


    0
    3
    6
  ⋮  
 6018
 6021
 6024
[torch.LongTensor of size 1859]


    0
    4
    7
  ⋮  
 6066
 6069
 6072
[torch.LongTensor of size 1905]


 6026
 6027
 6028
 6029
 6030
 6031
 6032
 6033
 6034
 6035
 6036
 6037
 6038
 6039
 6040
 6041
 6042
 6043
 6044
 6045
 6046
 6047
 6048
 6049
 6050
 6051
 6052
 6053
 6054
 6055
 6056
 6057
 6058
 6059
 6060
 6061
 6062
 6063
 6064
 6065
 6066
 6067
 6068
 6069
 6070
 6071
 6072
 6073
 6074
 6075
 6076
 6077
 6078
 6079
 6080
 6081
 6082
 6083
 6084
 6085
 6086
 6087
 6088
 6089
 6090
 6091
 6092
 6093
 6094
 6095
 6096
 6097
 6098
 6099
 6103
 6104
 6105
 6106
 6110
 6111
 6112
 6113
 6114
 6115
 6


 5114
 5115
 5116
  ⋮  
 6146
 6147
 6148
[torch.LongTensor of size 1023]


    0
    3
    6
  ⋮  
 6141
 6144
 6147
[torch.LongTensor of size 2108]


   48
   51
   60
  ⋮  
 6078
 6084
 6087
[torch.LongTensor of size 1830]


 5576
 5577
 5578
 5579
 5580
 5581
 5582
 5583
 5584
 5585
 5586
 5587
 5588
 5589
 5590
 5591
 5592
 5593
 5594
 5595
 5596
 5597
 5598
 5599
 5600
 5601
 5602
 5603
 5604
 5605
 5606
 5607
 5608
 5609
 5610
 5611
 5612
 5613
 5614
 5615
 5616
 5617
 5618
 5619
 5620
 5621
 5622
 5623
 5624
 5625
 5626
 5627
 5628
 5629
 5630
 5631
 5632
 5633
 5634
 5635
 5636
 5637
 5638
 5639
 5640
 5641
 5642
 5643
 5644
 5645
 5646
 5647
 5648
 5649
 5650
 5651
 5652
 5653
 5654
 5655
 5656
 5657
 5658
 5659
 5660
 5661
 5662
 5663
 5664
 5665
 5666
 5667
 5668
 5669
 5670
 5671
 5672
 5673
 5674
 5675
 5676
 5677
 5678
 5679
 5680
 5681
 5682
 5683
 5684
 5685
 5686
 5687
 5688
 5689
 5690
 5691
 5692
 5693
 5694
 5695
 5696
 5697
 5698
 5699
 5700
 5701
 5702
 5703
 57


 5904
 5908
 5909
 5910
 5911
 5912
 5913
 5914
 5915
 5916
 5917
 5918
 5919
 5920
 5921
 5922
 5924
 5925
 5927
 5928
 5929
 5930
 5933
 5934
 5935
 5936
 5937
 5938
 5939
 5940
 5941
 5942
 5943
 5944
 5945
 5946
 5947
 5948
 5949
 5951
 5952
 5953
 5954
 5955
 5956
 5957
 5959
 5960
 5961
 5962
 5963
 5964
 5966
 5967
 5968
 5969
 5970
 5971
 5972
 5973
 5974
 5975
 5976
 5977
 5978
 5980
 5981
 5982
 5983
 5985
 5986
 5987
 5988
 5989
 5990
 5991
 5992
 5993
 5994
 5995
 5996
 5997
 5998
 5999
 6000
 6001
 6002
 6003
 6004
 6005
 6006
 6007
 6008
 6009
 6010
 6011
 6012
 6013
 6014
 6015
 6016
 6017
 6018
 6019
 6020
 6021
 6022
 6023
 6024
 6025
[torch.LongTensor of size 110]


    0
    1
    2
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 3172]


    0
    3
    6
  ⋮  
 6138
 6141
 6144
[torch.LongTensor of size 1151]


 1869
 1872
 1875
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 1220]


    0
    1
    2
  ⋮  
 6115
 6116
 6117
[torch.LongTensor of size 5911]


  


 5630
 5631
 5632
 5633
 5634
 5635
 5636
 5637
 5638
 5639
 5640
 5641
 5642
 5643
 5644
 5645
 5646
 5647
 5648
 5649
 5650
 5651
 5652
 5653
 5654
 5655
 5656
 5657
 5658
 5659
 5660
 5661
 5662
 5663
 5664
 5665
 5666
 5667
 5668
 5669
 5670
 5671
 5672
 5673
 5674
 5675
 5676
 5677
 5678
 5679
 5680
 5681
 5682
 5683
 5684
 5685
 5686
 5687
 5688
 5689
 5690
 5691
 5692
 5693
 5694
 5695
 5696
 5697
 5698
 5699
 5700
 5701
 5702
 5703
 5704
 5705
 5706
 5707
 5708
 5709
 5710
 5711
 5712
 5713
 5714
 5715
 5716
 5717
 5718
 5719
 5720
 5721
 5722
 5723
 5724
 5725
 5726
 5727
 5728
 5729
 5730
 5731
 5732
 5733
 5734
 5735
 5736
 5737
 5738
 5739
 5740
 5741
 5742
 5743
 5744
 5745
 5746
 5747
 5748
 5749
 5750
 5751
 5752
 5753
 5754
 5755
 5756
 5757
 5758
 5759
 5760
 5761
 5762
 5763
 5764
 5765
 5766
 5767
 5768
 5769
 5770
 5771
 5772
 5773
 5774
 5775
 5776
 5777
 5778
 5779
 5780
 5781
 5782
 5783
 5784
 5785
 5786
 5787
 5788
 5789
 5790
 5791
 5792
 5793
 5794
 5795
 57


    0
    3
    6
  ⋮  
 6146
 6147
 6148
[torch.LongTensor of size 3421]


 5752
 5757
 5758
 5759
 5760
 5761
 5762
 5763
 5764
 5765
 5766
 5767
 5768
 5769
 5770
 5771
 5772
 5778
 5779
 5780
 5781
 5782
 5784
 5785
 5786
 5787
 5788
 5789
 5790
 5791
 5792
 5793
 5794
 5795
 5796
 5797
 5798
 5799
 5800
 5801
 5802
 5803
 5804
 5805
 5806
 5807
 5808
 5809
 5810
 5811
 5812
 5813
 5814
 5815
 5816
 5817
 5818
 5819
 5820
 5821
 5822
 5823
 5824
 5825
 5826
 5827
 5828
 5829
 5830
 5831
 5832
 5833
 5834
 5835
 5836
 5837
 5838
 5839
 5840
 5841
 5842
 5843
 5844
 5845
 5846
 5847
 5848
 5849
 5850
 5851
 5852
 5853
 5854
 5855
 5856
 5857
 5858
 5859
 5860
 5861
 5862
 5863
 5864
 5865
 5866
 5867
 5868
 5869
 5870
 5871
 5872
 5873
 5874
 5875
 5876
 5877
 5878
 5879
 5880
 5881
 5882
 5883
 5884
 5885
 5886
 5887
 5888
 5889
 5890
 5891
 5892
 5893
 5894
 5895
 5896
 5897
 5898
 5899
 5900
 5901
 5902
 5903
 5904
 5905
 5906
 5907
 5908
 5909
 5910
 5911
 5912
 5913
 5914
 5915


 5869
 5870
 5871
 5872
 5873
 5874
 5875
 5876
 5877
 5878
 5879
 5880
 5881
 5882
 5883
 5884
 5885
 5886
 5887
 5888
 5889
 5890
 5891
 5892
 5893
 5894
 5895
 5896
 5897
 5898
 5899
 5900
 5901
 5902
 5903
 5904
 5905
 5906
 5907
 5909
 5910
 5911
 5912
 5913
 5914
 5915
 5916
 5917
 5918
 5919
 5920
 5921
 5922
 5923
 5924
 5925
 5926
 5927
 5928
 5929
 5930
 5931
 5932
 5933
 5934
 5935
 5936
 5937
 5938
 5939
 5940
 5941
 5942
 5948
 5949
 5950
 5951
 5952
 5955
 5956
 5957
 5958
 5959
 5960
 5961
 5962
 5963
 5964
 5965
 5966
 5967
 5968
 5969
 5970
 5973
 5974
 5975
 5976
 5977
 5978
 5979
 5980
 5981
 5984
 5985
 5986
 5987
 5988
 5989
 5990
 5991
 5992
 5993
 5994
 5995
 5996
 5997
 5998
 5999
 6000
 6001
 6002
 6003
 6004
 6005
 6006
 6007
 6008
 6009
 6010
 6011
 6012
 6013
 6014
 6015
 6016
 6017
 6018
 6019
 6020
 6021
 6022
 6023
 6024
 6025
 6026
 6027
 6028
 6029
 6030
 6031
 6032
 6033
 6034
 6035
 6036
 6037
 6038
 6039
 6040
 6041
 6042
 6043
 6044
 6045
 6046
 60

Validation R^2: -1.32311040387


Hidden size conv: 25
Kernel size: 3
Hidden size full: 50
Dropout full: 0.1
Hidden size combo: 50
Dropout combo: 0.1
Learning rate: 0.01
Weight decay: 0.0001

    0
    3
    6
  ⋮  
 6105
 6108
 6117
[torch.LongTensor of size 1842]


  585
  588
  591
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 2212]


    0
    6
   12
  ⋮  
 6023
 6024
 6025
[torch.LongTensor of size 1547]


 5904
 5908
 5909
 5910
 5911
 5912
 5913
 5914
 5915
 5916
 5917
 5918
 5919
 5920
 5921
 5922
 5924
 5925
 5927
 5928
 5929
 5930
 5933
 5934
 5935
 5936
 5937
 5938
 5939
 5940
 5941
 5942
 5943
 5944
 5945
 5946
 5947
 5948
 5949
 5951
 5952
 5953
 5954
 5955
 5956
 5957
 5959
 5960
 5961
 5962
 5963
 5964
 5966
 5967
 5968
 5969
 5970
 5971
 5972
 5973
 5974
 5975
 5976
 5977
 5978
 5980
 5981
 5982
 5983
 5985
 5986
 5987
 5988
 5989
 5990
 5991
 5992
 5993
 5994
 5995
 5996
 5997
 5998
 5999
 6000
 6001
 6002
 6003
 6004
 6005
 6006
 6007
 6008
 6009
 6010
 6011
 6012


    0
    3
    6
  ⋮  
 6141
 6144
 6147
[torch.LongTensor of size 1177]


    0
    3
    6
  ⋮  
 6146
 6147
 6148
[torch.LongTensor of size 2069]


 4389
 4390
 4391
  ⋮  
 6145
 6146
 6148
[torch.LongTensor of size 1661]


    0
    3
    6
  ⋮  
 6141
 6144
 6147
[torch.LongTensor of size 1999]


 1917
 1920
 1923
 1926
 1929
 1932
 1935
 1938
 1941
 1944
 1947
 1950
 1953
 1956
 1959
 1962
 1965
 1968
 1971
 1974
 1977
 1980
 1983
 1989
 1992
 1995
 1998
 2001
 2004
 2007
 2010
 2013
 2016
 2022
 2025
 2028
 2031
 2034
 2037
 2040
 2043
 2046
 2049
 2052
 2055
 2058
 2061
 2064
 2067
 2070
 2073
 2076
 2079
 2082
 2085
 2088
 2091
 2094
 2097
 2100
 2103
 2106
 2109
 2112
 2115
 2118
 2121
 2124
 2127
 2130
 2133
 2136
 2139
 2142
 2145
 2148
 2151
 2154
 2157
 2160
 2163
 2166
 2169
 2172
 2175
 2178
 2181
 2184
 2187
 2190
 2193
 2196
 2199
 2202
 2205
 2208
 2211
 2214
 2217
 2220
 2226
 2229
 2232
 2235
 2238
 2241
 2244
 2247
 2250
 2253
 2256
 2259
 2262
 2265
 2268
 2271


    0
    3
    6
  ⋮  
 6081
 6084
 6087
[torch.LongTensor of size 1490]


 2103
 2106
 2109
  ⋮  
 6018
 6021
 6024
[torch.LongTensor of size 1228]


    0
    3
    6
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 1930]


 1802
 1803
 1804
  ⋮  
 6018
 6021
 6024
[torch.LongTensor of size 2737]


    0
    3
    6
  ⋮  
 6176
 6177
 6178
[torch.LongTensor of size 2935]


    0
    3
    6
  ⋮  
 6115
 6116
 6117
[torch.LongTensor of size 3717]


    0
    1
    2
  ⋮  
 6023
 6024
 6025
[torch.LongTensor of size 5064]


 2550
 2553
 2556
 2559
 2562
 2565
 2568
 2574
 2577
 2580
 2583
 2586
 2592
 2595
 2598
 2601
 2604
 2607
 2610
 2613
 2616
 2619
 2622
 2625
 2628
 2631
 2634
 2637
 2640
 2643
 2646
 2652
 2658
 2664
 2670
 2679
 2682
 2688
 2694
 2700
 2706
 2712
 2718
 2724
 2730
 2736
 2742
 2748
 2754
 2760
 2766
 2772
 2778
 2784
 2790
 2796
 2808
 2820
 2826
 2832
 2835
 2838
 2844
 2847
 2850
 2853
 2856
 2859
 2868
 2871
 2874
 2877
 2880
 2883
 2886
 2889
 2892
 2895


    0
    3
    6
  ⋮  
 6081
 6084
 6087
[torch.LongTensor of size 1247]


 5869
 5870
 5871
 5872
 5873
 5874
 5875
 5876
 5877
 5878
 5879
 5880
 5881
 5882
 5883
 5884
 5885
 5886
 5887
 5888
 5889
 5890
 5891
 5892
 5893
 5894
 5895
 5896
 5897
 5898
 5899
 5900
 5901
 5902
 5903
 5904
 5905
 5906
 5907
 5909
 5910
 5911
 5912
 5913
 5914
 5915
 5916
 5917
 5918
 5919
 5920
 5921
 5922
 5923
 5924
 5925
 5926
 5927
 5928
 5929
 5930
 5931
 5932
 5933
 5934
 5935
 5936
 5937
 5938
 5939
 5940
 5941
 5942
 5948
 5949
 5950
 5951
 5952
 5955
 5956
 5957
 5958
 5959
 5960
 5961
 5962
 5963
 5964
 5965
 5966
 5967
 5968
 5969
 5970
 5973
 5974
 5975
 5976
 5977
 5978
 5979
 5980
 5981
 5984
 5985
 5986
 5987
 5988
 5989
 5990
 5991
 5992
 5993
 5994
 5995
 5996
 5997
 5998
 5999
 6000
 6001
 6002
 6003
 6004
 6005
 6006
 6007
 6008
 6009
 6010
 6011
 6012
 6013
 6014
 6015
 6016
 6017
 6018
 6019
 6020
 6021
 6022
 6023
 6024
 6025
 6026
 6027
 6028
 6029
 6030
 6031
 6032
 6033
 6034


 4394
 4395
 4396
 4397
 4398
 4399
 4401
 4404
 4407
 4410
 4413
 4416
 4419
 4422
 4425
 4428
 4431
 4434
 4437
 4440
 4443
 4446
 4449
 4452
 4455
 4458
 4461
 4464
 4467
 4470
 4473
 4476
 4479
 4485
 4488
 4491
 4494
 4497
 4500
 4503
 4506
 4509
 4512
 4515
 4518
 4521
 4530
 4533
 4536
 4539
 4542
 4545
 4548
 4551
 4554
 4557
 4560
 4563
 4566
 4569
 4572
 4575
 4578
 4581
 4584
 4587
 4590
 4593
 4602
 4605
 4608
 4611
 4617
 4620
 4623
 4626
 4629
 4632
 4635
 4638
 4641
 4644
 4647
 4650
 4653
 4656
 4659
 4662
 4665
 4668
 4671
 4674
 4677
 4680
 4683
 4686
 4689
 4692
 4695
 4698
 4701
 4704
 4707
 4710
 4713
 4716
 4719
 4722
 4725
 4728
 4731
 4734
 4737
 4740
 4743
 4746
 4749
 4752
 4755
 4758
 4761
 4764
 4767
 4770
 4773
 4776
 4779
 4782
 4785
 4788
 4791
 4794
 4797
 4800
 4803
 4806
 4809
 4812
 4815
 4818
 4821
 4824
 4827
 4830
 4833
 4836
 4839
 4842
 4845
 4848
 4851
 4854
 4857
 4860
 4863
 4866
 4869
 4872
 4875
 4878
 4881
 4884
 4887
 4890
 4893
 4896
 48


    0
    3
    6
  ⋮  
 6141
 6144
 6147
[torch.LongTensor of size 2108]


    0
    3
    6
  ⋮  
 6141
 6144
 6147
[torch.LongTensor of size 1879]


    0
    3
    6
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 2325]


    0
    3
    6
  ⋮  
 6085
 6086
 6087
[torch.LongTensor of size 1887]


    0
    3
    6
  ⋮  
 6115
 6116
 6117
[torch.LongTensor of size 2094]


    0
    3
    6
  ⋮  
 6141
 6144
 6147
[torch.LongTensor of size 1532]


 1869
 1872
 1875
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 1220]


    0
    3
    6
  ⋮  
 6146
 6147
 6148
[torch.LongTensor of size 2069]


 5752
 5757
 5758
 5759
 5760
 5761
 5762
 5763
 5764
 5765
 5766
 5767
 5768
 5769
 5770
 5771
 5772
 5778
 5779
 5780
 5781
 5782
 5784
 5785
 5786
 5787
 5788
 5789
 5790
 5791
 5792
 5793
 5794
 5795
 5796
 5797
 5798
 5799
 5800
 5801
 5802
 5803
 5804
 5805
 5806
 5807
 5808
 5809
 5810
 5811
 5812
 5813
 5814
 5815
 5816
 5817
 5818
 5819
 5820
 5821
 5822
 5823
 5824
 5825
 5826
 

Validation R^2: -1.45264034111


Hidden size conv: 25
Kernel size: 3
Hidden size full: 50
Dropout full: 0.1
Hidden size combo: 50
Dropout combo: 0.1
Learning rate: 0.01
Weight decay: 1e-05

 4443
 4444
 4445
  ⋮  
 6085
 6086
 6087
[torch.LongTensor of size 1440]


    0
    1
    2
  ⋮  
 6141
 6144
 6147
[torch.LongTensor of size 4933]


 3708
 3709
 3710
  ⋮  
 6054
 6055
 6056
[torch.LongTensor of size 2292]


    0
    1
    2
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 3172]


    0
    3
   12
  ⋮  
 6115
 6116
 6117
[torch.LongTensor of size 2301]


    0
    6
   12
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 1334]


    0
    3
    6
  ⋮  
 6015
 6018
 6021
[torch.LongTensor of size 1791]


    0
    3
    6
  ⋮  
 6115
 6116
 6117
[torch.LongTensor of size 2274]


    0
    3
    9
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 1841]


    6
    9
   12
  ⋮  
 6141
 6144
 6147
[torch.LongTensor of size 2259]


    0
    3
    6
  ⋮  
 6081
 6084
 6087
[torch.L


    0
    3
    6
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 3119]


 4850
 4980
 4981
 4982
 4983
 4984
 4985
 4988
 4989
 4993
 4995
 5008
 5009
 5010
 5011
 5137
 5138
 5139
 5140
 5141
 5142
 5143
 5144
 5145
 5146
 5147
 5148
 5149
 5150
 5151
 5152
 5153
 5154
 5155
 5156
 5157
 5158
 5159
 5160
 5161
 5162
 5163
 5164
 5165
 5166
 5167
 5168
 5169
 5170
 5171
 5172
 5173
 5174
 5175
 5176
 5177
 5178
 5179
 5180
 5181
 5182
 5183
 5184
 5185
 5186
 5187
 5188
 5189
 5190
 5191
 5192
 5193
 5194
 5195
 5196
 5197
 5198
 5199
 5200
 5201
 5202
 5203
 5204
 5205
 5206
 5207
 5208
 5209
 5210
 5211
 5212
 5213
 5214
 5215
 5216
 5217
 5218
 5219
 5220
 5221
 5222
 5223
 5224
 5225
 5226
 5227
 5228
 5229
 5230
 5231
 5232
 5233
 5234
 5235
 5236
 5237
 5238
 5239
 5240
 5241
 5242
 5243
 5244
 5245
 5246
 5247
 5248
 5249
 5250
 5251
 5252
 5253
 5254
 5255
 5256
 5257
 5258
 5259
 5260
 5261
 5262
 5263
 5264
 5265
 5266
 5267
 5268
 5269
 5270
 5271
 5272
 5273
 5274
 5275


    0
    3
    6
  ⋮  
 6176
 6177
 6178
[torch.LongTensor of size 2698]


    0
    3
    6
  ⋮  
 6146
 6147
 6148
[torch.LongTensor of size 3830]


 2196
 2199
 2202
  ⋮  
 6023
 6024
 6025
[torch.LongTensor of size 2920]


 5630
 5631
 5632
 5633
 5634
 5635
 5636
 5637
 5638
 5639
 5640
 5641
 5642
 5643
 5644
 5645
 5646
 5647
 5648
 5649
 5650
 5651
 5652
 5653
 5654
 5655
 5656
 5657
 5658
 5659
 5660
 5661
 5662
 5663
 5664
 5665
 5666
 5667
 5668
 5669
 5670
 5671
 5672
 5673
 5674
 5675
 5676
 5677
 5678
 5679
 5680
 5681
 5682
 5683
 5684
 5685
 5686
 5687
 5688
 5689
 5690
 5691
 5692
 5693
 5694
 5695
 5696
 5697
 5698
 5699
 5700
 5701
 5702
 5703
 5704
 5705
 5706
 5707
 5708
 5709
 5710
 5711
 5712
 5713
 5714
 5715
 5716
 5717
 5718
 5719
 5720
 5721
 5722
 5723
 5724
 5725
 5726
 5727
 5728
 5729
 5730
 5731
 5732
 5733
 5734
 5735
 5736
 5737
 5738
 5739
 5740
 5741
 5742
 5743
 5744
 5745
 5746
 5747
 5748
 5749
 5750
 5751
 5752
 5753
 5754
 5755
 5756
 5757
 57


    0
    3
    6
  ⋮  
 6115
 6116
 6117
[torch.LongTensor of size 2274]


    0
    5
    6
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 1850]


 1827
 1830
 1833
  ⋮  
 6079
 6080
 6081
[torch.LongTensor of size 2891]


    0
    3
    6
  ⋮  
 6111
 6114
 6117
[torch.LongTensor of size 1618]


    0
    6
   12
   18
   24
   36
   42
   48
   54
   60
   66
   72
   78
   84
   90
   96
  102
  108
  114
  120
  126
  138
  144
  150
  180
  186
  192
  198
  204
  210
  216
  222
  228
  234
  240
  246
  252
  258
  264
  270
  276
  282
  288
  294
  300
  306
  312
  318
  325
  330
  336
  342
  348
  354
  360
  366
  372
  378
  384
  390
  396
  402
  408
  414
  420
  426
  432
  438
  444
  450
  456
  462
  468
  474
  480
  486
  492
  498
  504
  510
  516
  522
  528
  534
  540
  546
  552
  558
  564
  570
  576
  582
  588
  594
  600
  606
  612
  618
  624
  626
  630
  636
  642
  648
  654
  660
  666
  672
  678
  684
  689
  696
  702
  708
  714
  726

KeyboardInterrupt: 

In [19]:
# Scratch sanity check: a 1-element torch tensor converts to a numpy array of length 1.
torch.zeros(1).numpy().shape[0]

1

In [4]:
# Training settings shared by every configuration in the grid search.
num_epochs = 21
batch_size = 64

# Input widths handed to CNN2: the standardized non-constant feature matrix
# (columns chosen by get_nonConst_vars) feeds input_size_conv, and the
# standardized full feature matrix feeds input_size_full.
input_size_conv = train_x_std_nonConst.shape[1]
input_size_full = train_x_std_all.shape[1]

# CNN and optimizer hyper-parameters to test
hidden_size_conv_list = [25, 50, 75]
kernel_size_list = [3, 5]
# Kernel sizes and paddings are zipped pairwise in the grid search, so derive
# the padding from the kernel size instead of maintaining a parallel list by
# hand. (k - 1) // 2 keeps the sequence length unchanged for odd k, assuming
# stride-1 convolutions inside CNN2 (TODO confirm). Yields [1, 2] here.
assert all(k % 2 == 1 for k in kernel_size_list), 'kernel sizes must be odd for symmetric padding'
padding_list = [(k - 1) // 2 for k in kernel_size_list]
hidden_size_full_list = [50, 100, 150]
dropout_full_list = [0.1, 0.3, 0.5]
hidden_size2_full_list = [50, 100, 150]
dropout2_full_list = [0.1, 0.3, 0.5]
lr_list = [0.1, 0.01]
weight_decay_list = [0.001, 0.0001, 0.00001]

# Loss function: mean squared error averaged over elements.
# NOTE: size_average is the pre-0.4.1 PyTorch spelling (this notebook still
# imports Variable); on newer PyTorch use MSELoss(reduction='mean') instead.
mse_loss = torch.nn.MSELoss(size_average=True)

In [None]:
import itertools

# Exhaustive grid search over CNN2 architecture and Adam optimizer settings.
# Each configuration is trained from scratch with train_CNN and scored on the
# validation set with r2; the configuration with the highest validation R^2
# is recorded in the module-level best_* variables.
#
# NaN scores (e.g. from a diverged run) never compare greater than the running
# best, so they are skipped. The best_* names are bound up front so the
# summary at the bottom cannot raise NameError when no configuration yields a
# comparable score.
best_val_r2 = -np.inf
best_hidden_size_conv = None
best_kernel_size = None
best_hidden_size_full = None
best_dropout_full = None
best_hidden_size2_full = None
best_dropout2_full = None
best_lr = None
best_weight_decay = None

# Same iteration order as the equivalent nested for-loops (rightmost axis
# varies fastest). Kernel size and padding are paired element-wise, not crossed.
hyperparameter_grid = itertools.product(
    hidden_size_conv_list,
    zip(kernel_size_list, padding_list),
    hidden_size_full_list,
    dropout_full_list,
    hidden_size2_full_list,
    dropout2_full_list,
    lr_list,
    weight_decay_list,
)

for (hidden_size_conv, (kernel_size, padding), hidden_size_full, dropout_full,
     hidden_size2_full, dropout2_full, lr, weight_decay) in hyperparameter_grid:

    # instantiate CNN
    cnn = CNN2(input_size_conv, hidden_size_conv, kernel_size, padding, input_size_full, hidden_size_full,
               dropout_full, hidden_size2_full, dropout2_full)

    # instantiate optimizer
    optimizer = torch.optim.Adam(cnn.parameters(), lr=lr, weight_decay=weight_decay)

    print('Hidden size conv: ' + str(hidden_size_conv))
    print('Kernel size: ' + str(kernel_size))
    print('Hidden size full: ' + str(hidden_size_full))
    print('Dropout full: ' + str(dropout_full))
    print('Hidden size 2 full: ' + str(hidden_size2_full))
    print('Dropout 2 full: ' + str(dropout2_full))
    print('Learning rate: ' + str(lr))
    print('Weight decay: ' + str(weight_decay))

    train_CNN(train_x_std_stack_nonConst, train_x_std_tuple, train_y_tuple, cnn, optimizer, mse_loss, num_epochs, batch_size)

    val_r2 = r2(cnn, batch_size, val_x_std_stack_nonConst, val_x_std_tuple, val_y_tuple)
    print('Validation R^2: ' + str(val_r2))
    print()
    print()

    if val_r2 > best_val_r2:
        best_val_r2 = val_r2
        best_hidden_size_conv = hidden_size_conv
        best_kernel_size = kernel_size
        best_hidden_size_full = hidden_size_full
        best_dropout_full = dropout_full
        best_hidden_size2_full = hidden_size2_full
        best_dropout2_full = dropout2_full
        best_lr = lr
        best_weight_decay = weight_decay

# Make an all-NaN / never-improved search visible instead of crashing.
if best_val_r2 == -np.inf:
    print('WARNING: no configuration produced a comparable validation R^2 (all NaN?).')

print('Best validation R^2: ' + str(best_val_r2))
print('Best hidden size conv: ' + str(best_hidden_size_conv))
print('Best kernel size: ' + str(best_kernel_size))
print('Best hidden size full: ' + str(best_hidden_size_full))
print('Best dropout full: ' + str(best_dropout_full))
print('Best hidden size 2 full: ' + str(best_hidden_size2_full))
print('Best dropout 2 full: ' + str(best_dropout2_full))
print('Best learning rate: ' + str(best_lr))
print('Best weight decay: ' + str(best_weight_decay))

Hidden size conv: 25
Kernel size: 3
Hidden size full: 50
Dropout full: 0.1
Hidden size 2 full: 50
Dropout 2 full: 0.1
Learning rate: 0.1
Weight decay: 0.001
Validation R^2: 0.723241878917


Hidden size conv: 25
Kernel size: 3
Hidden size full: 50
Dropout full: 0.1
Hidden size 2 full: 50
Dropout 2 full: 0.1
Learning rate: 0.1
Weight decay: 0.0001
Validation R^2: 0.738982157628


Hidden size conv: 25
Kernel size: 3
Hidden size full: 50
Dropout full: 0.1
Hidden size 2 full: 50
Dropout 2 full: 0.1
Learning rate: 0.1
Weight decay: 1e-05
Validation R^2: 0.735909094481


Hidden size conv: 25
Kernel size: 3
Hidden size full: 50
Dropout full: 0.1
Hidden size 2 full: 50
Dropout 2 full: 0.1
Learning rate: 0.01
Weight decay: 0.001
Validation R^2: 0.715693529723


Hidden size conv: 25
Kernel size: 3
Hidden size full: 50
Dropout full: 0.1
Hidden size 2 full: 50
Dropout 2 full: 0.1
Learning rate: 0.01
Weight decay: 0.0001
Validation R^2: 0.713598209591


Hidden size conv: 25
Kernel size: 3
Hidden siz