In [1]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


class TabularDataset(Dataset):
    def __init__(self, data, cat_cols=None, output_col=None):
 

        self.n = data.shape[0]

        if output_col:
            self.y = data[output_col].astype(np.float32).values.reshape(-1, 1)
        else:
            self.y =  np.zeros((self.n, 1))

        self.cat_cols = cat_cols if cat_cols else []
        self.cont_cols = [col for col in data.columns
                          if col not in self.cat_cols + [output_col]]

        if self.cont_cols:
            self.cont_X = data[self.cont_cols].astype(np.float32).values
        else:
            self.cont_X = np.zeros((self.n, 1))

        if self.cat_cols:
            self.cat_X = data[cat_cols].astype(np.int64).values
        else:
            self.cat_X =  np.zeros((self.n, 1))

    def __len__(self):
    
        return self.n

    def __getitem__(self, idx):
   
        return [self.y[idx], self.cont_X[idx], self.cat_X[idx]]


class FeedForwardNN(nn.Module):

    def __init__(self, emb_dims, no_of_cont, lin_layer_sizes,
               output_size, emb_dropout, lin_layer_dropouts):

    
        super().__init__()

        # Embedding layers
        self.emb_layers = nn.ModuleList([nn.Embedding(x, y)
                                         for x, y in emb_dims])

        no_of_embs = sum([y for x, y in emb_dims])
        self.no_of_embs = no_of_embs
        self.no_of_cont = no_of_cont

        # Linear Layers
        first_lin_layer = nn.Linear(self.no_of_embs + self.no_of_cont,
                                    lin_layer_sizes[0])

        self.lin_layers =\
         nn.ModuleList([first_lin_layer] +\
              [nn.Linear(lin_layer_sizes[i], lin_layer_sizes[i + 1])
               for i in range(len(lin_layer_sizes) - 1)])

        for lin_layer in self.lin_layers:
            nn.init.kaiming_normal_(lin_layer.weight.data)

        # Output Layer
        self.output_layer = nn.Linear(lin_layer_sizes[-1],
                                      output_size)
        nn.init.kaiming_normal_(self.output_layer.weight.data)

        # Batch Norm Layers
        self.first_bn_layer = nn.BatchNorm1d(self.no_of_cont)
        self.bn_layers = nn.ModuleList([nn.BatchNorm1d(size)
                                        for size in lin_layer_sizes])

        # Dropout Layers
        self.emb_dropout_layer = nn.Dropout(emb_dropout)
        self.droput_layers = nn.ModuleList([nn.Dropout(size)
                                      for size in lin_layer_dropouts])

    def forward(self, cont_data, cat_data):

        if self.no_of_embs != 0:
            x = [emb_layer(cat_data[:, i])
               for i,emb_layer in enumerate(self.emb_layers)]
            x = torch.cat(x, 1)
            x = self.emb_dropout_layer(x)

        if self.no_of_cont != 0:
            normalized_cont_data = self.first_bn_layer(cont_data)

            if self.no_of_embs != 0:
                x = torch.cat([x, normalized_cont_data], 1) 
            else:
                x = normalized_cont_data

        for lin_layer, dropout_layer, bn_layer in\
            zip(self.lin_layers, self.droput_layers, self.bn_layers):

            x = F.relu(lin_layer(x))
            x = bn_layer(x)
            x = dropout_layer(x)

        x = self.output_layer(x)

        return x

In [2]:
data = pd.read_csv('data.csv')

In [3]:
data['date_in'] = pd.to_datetime(data['date_in'])
data['year'] = data['date_in'].map(lambda x: x.strftime('%Y'))
data['month'] = data['date_in'].map(lambda x: x.strftime('%m'))
data['day'] = data['date_in'].map(lambda x: x.strftime('%d'))

In [4]:
dr = ['date_in','house_pk']

data = data.drop(dr,axis = 1)
data.head().T

Unnamed: 0,0,1,2,3,4
agency_id,90.0,90.0,90.0,90.0,90.0
price,532.0,588.0,588.0,588.0,588.0
dis_water_real,0.261,0.261,0.261,0.261,0.261
dis_shopping,3.0,3.0,3.0,3.0,3.0
no_bedrooms,3.0,3.0,3.0,3.0,3.0
max_persons,4.0,4.0,4.0,4.0,4.0
house_size,140.0,140.0,140.0,140.0,140.0
land_size,726.0,726.0,726.0,726.0,726.0
build_year,1953.0,1953.0,1953.0,1953.0,1953.0
renovation_year,2014.0,2014.0,2014.0,2014.0,2014.0


In [5]:
categorical_features = ['agency_id', 'apartment', 'indoor_pool', 'spa', 'internet', 'pets_allowed', 'water_view', 'fire_stove', 'year', 'month', 'day', 'build_year', 'renovation_year']
output_columns = 'price'

In [6]:
data.shape

(85195, 21)

In [7]:
label_encoders = {}
for cat_col in categorical_features:
        label_encoders[cat_col] = LabelEncoder()
        data[cat_col] = label_encoders[cat_col].fit_transform(data[cat_col])

In [8]:
print(data.head(), type(data))
dataset = TabularDataset(data=data, cat_cols=categorical_features, output_col=output_columns)
print('len of dataset: {}'.format(len(dataset)))

   agency_id  price  dis_water_real  dis_shopping  no_bedrooms  max_persons  \
0          0    532           0.261           3.0            3            4   
1          0    588           0.261           3.0            3            4   
2          0    588           0.261           3.0            3            4   
3          0    588           0.261           3.0            3            4   
4          0    588           0.261           3.0            3            4   

   house_size  land_size  build_year  renovation_year ...   indoor_pool  spa  \
0         140        726           4               21 ...             0    1   
1         140        726           4               21 ...             0    1   
2         140        726           4               21 ...             0    1   
3         140        726           4               21 ...             0    1   
4         140        726           4               21 ...             0    1   

   internet  pets_allowed  water_view  fire_

In [9]:
batchsize = 64
number_for_testing = int(len(dataset) * 0.05)
number_for_training = len(dataset) - number_for_testing
train, test = torch.utils.data.random_split(dataset,
    [number_for_training, number_for_testing])
trainloader = DataLoader(train, batchsize, shuffle=True)
testloader = DataLoader(test, batchsize, shuffle=True)

In [10]:
cat_dims = [int(data[col].nunique()) for col in categorical_features]
cat_dims

[4, 2, 2, 2, 2, 2, 2, 2, 4, 12, 31, 39, 26]

In [11]:
emb_dims = [(x, min(50, (x + 1) // 2)) for x in cat_dims]
emb_dims

[(4, 2),
 (2, 1),
 (2, 1),
 (2, 1),
 (2, 1),
 (2, 1),
 (2, 1),
 (2, 1),
 (4, 2),
 (12, 6),
 (31, 16),
 (39, 20),
 (26, 13)]

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [13]:
model = FeedForwardNN(emb_dims, no_of_cont=7, lin_layer_sizes=[50, 100],
                          output_size=1, emb_dropout=0.04,
                          lin_layer_dropouts=[0.001,0.01]).to(device)

In [14]:
print(model)

FeedForwardNN(
  (emb_layers): ModuleList(
    (0): Embedding(4, 2)
    (1): Embedding(2, 1)
    (2): Embedding(2, 1)
    (3): Embedding(2, 1)
    (4): Embedding(2, 1)
    (5): Embedding(2, 1)
    (6): Embedding(2, 1)
    (7): Embedding(2, 1)
    (8): Embedding(4, 2)
    (9): Embedding(12, 6)
    (10): Embedding(31, 16)
    (11): Embedding(39, 20)
    (12): Embedding(26, 13)
  )
  (lin_layers): ModuleList(
    (0): Linear(in_features=73, out_features=50, bias=True)
    (1): Linear(in_features=50, out_features=100, bias=True)
  )
  (output_layer): Linear(in_features=100, out_features=1, bias=True)
  (first_bn_layer): BatchNorm1d(7, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn_layers): ModuleList(
    (0): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (emb_dropout_layer): Dropout(p=0.04)
  (droput_layers): ModuleList(
    (0): Dropout(

In [15]:
epochs = 5

criterion = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)

for epoch in range(epochs):
    
    for y, cont_x, cat_x in trainloader:
        
        optimizer.zero_grad()
        
        cont_x = cont_x.to(device)
        cat_x = cat_x.to(device)
        y = y.to(device)
        
        output = model(cont_x, cat_x)
        
        loss = criterion(output, y)
        
        print('loss: {:.3f}'.format(loss.item()))
        
        loss.backward()
        
        optimizer.step()

loss: 325896.469
loss: 593129.500
loss: 440189.531
loss: 384359.500
loss: 426547.719
loss: 378009.812
loss: 411272.750
loss: 334718.344
loss: 405524.656
loss: 469122.688
loss: 323537.750
loss: 454524.219
loss: 342258.719
loss: 371802.656
loss: 437236.094
loss: 347364.156
loss: 389390.656
loss: 337736.281
loss: 385670.406
loss: 379690.531
loss: 292707.156
loss: 487754.562
loss: 360938.594
loss: 437141.625
loss: 358319.594
loss: 327392.219
loss: 291466.781
loss: 359655.094
loss: 380830.062
loss: 432346.875
loss: 348570.344
loss: 330687.219
loss: 306329.906
loss: 408541.406
loss: 282245.000
loss: 257727.500
loss: 329498.938
loss: 329075.375
loss: 341508.656
loss: 372458.719
loss: 304280.031
loss: 331221.031
loss: 356233.656
loss: 383697.344
loss: 377724.531
loss: 468093.094
loss: 343449.219
loss: 453975.969
loss: 358818.938
loss: 328304.469
loss: 285896.594
loss: 290858.062
loss: 260610.875
loss: 304441.281
loss: 258916.859
loss: 308425.094
loss: 225116.891
loss: 375713.000
loss: 315961.8

loss: 17431.041
loss: 10800.265
loss: 13205.960
loss: 16065.410
loss: 10422.528
loss: 16962.084
loss: 11521.527
loss: 12602.115
loss: 21668.000
loss: 15951.779
loss: 10109.607
loss: 12744.494
loss: 15347.073
loss: 11375.688
loss: 15706.103
loss: 19921.928
loss: 17056.475
loss: 18256.223
loss: 14400.442
loss: 13098.826
loss: 16925.445
loss: 11556.138
loss: 16897.715
loss: 24533.191
loss: 10849.153
loss: 12321.685
loss: 20392.223
loss: 12556.147
loss: 12209.834
loss: 10088.658
loss: 16922.662
loss: 15823.423
loss: 9302.969
loss: 16952.707
loss: 13528.066
loss: 9970.514
loss: 7625.941
loss: 15209.735
loss: 12175.468
loss: 15627.758
loss: 9800.972
loss: 12686.157
loss: 9573.947
loss: 14049.424
loss: 5323.963
loss: 31786.088
loss: 17128.061
loss: 11829.659
loss: 13484.585
loss: 16762.053
loss: 8657.439
loss: 15663.567
loss: 12485.398
loss: 8752.924
loss: 11730.962
loss: 15763.769
loss: 7946.458
loss: 18376.400
loss: 15446.162
loss: 20258.869
loss: 13614.852
loss: 23669.277
loss: 11578.659
l

loss: 8370.118
loss: 11284.765
loss: 8915.829
loss: 10366.705
loss: 9478.569
loss: 11035.619
loss: 20265.180
loss: 12348.489
loss: 10983.706
loss: 19533.115
loss: 6465.584
loss: 8018.638
loss: 9416.486
loss: 11694.576
loss: 11796.307
loss: 12128.047
loss: 10633.890
loss: 12470.162
loss: 12897.620
loss: 7807.216
loss: 7976.199
loss: 14339.655
loss: 8552.097
loss: 25022.691
loss: 8954.209
loss: 13393.015
loss: 7532.121
loss: 7956.266
loss: 9836.014
loss: 16895.582
loss: 8783.432
loss: 9643.281
loss: 18530.910
loss: 7975.878
loss: 12176.380
loss: 4949.607
loss: 10804.452
loss: 11490.089
loss: 12175.232
loss: 14907.976
loss: 27925.420
loss: 13328.368
loss: 13426.212
loss: 16780.805
loss: 8934.236
loss: 14989.645
loss: 7157.640
loss: 9143.455
loss: 11548.494
loss: 6380.872
loss: 15873.207
loss: 8352.504
loss: 24772.797
loss: 15383.398
loss: 11161.276
loss: 11057.617
loss: 9885.290
loss: 9193.397
loss: 13263.056
loss: 6003.032
loss: 7041.496
loss: 8275.327
loss: 12093.141
loss: 23470.020
los

loss: 11412.111
loss: 8772.518
loss: 11497.693
loss: 16763.770
loss: 7834.281
loss: 11013.925
loss: 14309.729
loss: 9897.455
loss: 10928.750
loss: 6063.385
loss: 10161.385
loss: 11006.534
loss: 6673.449
loss: 6745.374
loss: 8892.748
loss: 10221.878
loss: 15659.881
loss: 18929.268
loss: 31710.740
loss: 9422.160
loss: 6045.182
loss: 13552.299
loss: 17295.791
loss: 7427.602
loss: 14691.400
loss: 10938.568
loss: 7088.988
loss: 20776.898
loss: 9083.184
loss: 6479.015
loss: 9222.321
loss: 6839.114
loss: 11113.229
loss: 8017.230
loss: 13489.843
loss: 16398.506
loss: 10788.607
loss: 8173.342
loss: 11637.468
loss: 11932.373
loss: 9541.759
loss: 18800.107
loss: 14209.454
loss: 6058.743
loss: 7473.281
loss: 19536.918
loss: 9338.600
loss: 9944.407
loss: 9997.711
loss: 8796.116
loss: 4594.399
loss: 8151.386
loss: 11145.241
loss: 10028.978
loss: 12767.395
loss: 6894.836
loss: 7428.405
loss: 30910.002
loss: 6966.304
loss: 5395.645
loss: 7844.190
loss: 12364.650
loss: 16788.078
loss: 11226.982
loss: 7

loss: 12727.047
loss: 19896.400
loss: 12037.592
loss: 3560.248
loss: 8583.609
loss: 6707.812
loss: 7012.828
loss: 4954.158
loss: 19555.527
loss: 11355.042
loss: 14459.240
loss: 8150.178
loss: 12889.040
loss: 9743.143
loss: 6501.411
loss: 10685.641
loss: 9346.009
loss: 7412.537
loss: 7670.573
loss: 9647.165
loss: 5124.712
loss: 7083.049
loss: 5451.249
loss: 8051.063
loss: 9387.886
loss: 15004.161
loss: 7382.655
loss: 5840.478
loss: 13570.117
loss: 4885.987
loss: 8233.933
loss: 8143.005
loss: 8228.493
loss: 7249.768
loss: 6337.333
loss: 20210.602
loss: 7424.698
loss: 6564.598
loss: 19278.561
loss: 7849.667
loss: 8481.526
loss: 4867.607
loss: 8977.118
loss: 12814.516
loss: 7072.550
loss: 6465.623
loss: 14441.606
loss: 8949.627
loss: 10527.380
loss: 10890.473
loss: 14544.270
loss: 5345.196
loss: 13414.959
loss: 7205.708
loss: 4484.760
loss: 6198.606
loss: 7477.999
loss: 6071.229
loss: 6316.984
loss: 15984.321
loss: 9586.090
loss: 11026.504
loss: 7726.312
loss: 11740.819
loss: 10412.951
los

loss: 12109.338
loss: 7358.900
loss: 8447.930
loss: 6910.001
loss: 40932.680
loss: 8255.425
loss: 6323.579
loss: 8082.444
loss: 9132.372
loss: 6458.892
loss: 8364.890
loss: 6149.714
loss: 5172.733
loss: 4150.924
loss: 7354.696
loss: 6728.846
loss: 6839.878
loss: 13341.576
loss: 9462.743
loss: 8335.212
loss: 5849.471
loss: 10278.885
loss: 7140.354
loss: 7115.993
loss: 8613.940
loss: 22143.143
loss: 4584.551
loss: 14237.215
loss: 8019.943
loss: 9289.431
loss: 17469.889
loss: 6844.592
loss: 3779.551
loss: 9118.328
loss: 6231.710
loss: 8939.221
loss: 10010.270
loss: 8751.314
loss: 6834.958
loss: 12379.433
loss: 15163.909
loss: 13215.622
loss: 8634.243
loss: 8341.478
loss: 7589.859
loss: 6969.887
loss: 16207.217
loss: 6473.120
loss: 8033.114
loss: 8897.348
loss: 7516.122
loss: 7589.596
loss: 15425.475
loss: 8623.304
loss: 14529.213
loss: 10405.820
loss: 9836.971
loss: 13647.979
loss: 8713.181
loss: 6144.639
loss: 5144.576
loss: 12142.898
loss: 4812.200
loss: 4743.069
loss: 9103.587
loss: 16

loss: 6745.350
loss: 7426.704
loss: 5420.000
loss: 6969.675
loss: 7913.140
loss: 8131.514
loss: 9874.508
loss: 9308.700
loss: 6398.550
loss: 9962.521
loss: 7025.230
loss: 7290.793
loss: 13608.515
loss: 12515.281
loss: 15519.157
loss: 11488.535
loss: 4104.260
loss: 3954.452
loss: 13990.423
loss: 26426.561
loss: 9086.631
loss: 9168.442
loss: 7353.444
loss: 9831.049
loss: 4646.783
loss: 8618.236
loss: 6811.720
loss: 11238.833
loss: 14437.143
loss: 8697.797
loss: 6506.135
loss: 10501.269
loss: 10948.692
loss: 7909.291
loss: 5351.629
loss: 6781.282
loss: 4289.850
loss: 9963.123
loss: 10749.078
loss: 11786.744
loss: 5958.898
loss: 7949.095
loss: 7702.564
loss: 6776.296
loss: 12323.731
loss: 4683.380
loss: 6783.904
loss: 10980.058
loss: 4897.008
loss: 6675.971
loss: 6859.653
loss: 10209.125
loss: 26264.660
loss: 10980.360
loss: 11740.076
loss: 6329.743
loss: 6144.827
loss: 14084.223
loss: 20897.486
loss: 5790.898
loss: 11592.662
loss: 7042.657
loss: 5278.264
loss: 6149.598
loss: 5654.597
loss

loss: 4271.690
loss: 7019.472
loss: 6986.353
loss: 9773.758
loss: 8094.416
loss: 9656.854
loss: 13555.487
loss: 8838.923
loss: 13045.233
loss: 9401.281
loss: 8938.893
loss: 21344.455
loss: 4734.150
loss: 13037.038
loss: 9310.480
loss: 14370.892
loss: 6367.405
loss: 17507.711
loss: 8745.639
loss: 8481.917
loss: 6956.246
loss: 5686.507
loss: 9244.249
loss: 16586.055
loss: 10304.577
loss: 9714.509
loss: 6803.627
loss: 10710.977
loss: 7194.207
loss: 4588.414
loss: 5047.452
loss: 7637.875
loss: 6318.770
loss: 8043.012
loss: 8901.624
loss: 14921.448
loss: 9196.314
loss: 10888.305
loss: 13179.806
loss: 8959.323
loss: 6031.401
loss: 8697.823
loss: 7794.238
loss: 6006.461
loss: 7570.482
loss: 6066.062
loss: 13299.791
loss: 5919.434
loss: 7391.006
loss: 9649.646
loss: 4162.333
loss: 3833.891
loss: 9928.989
loss: 19555.709
loss: 5239.082
loss: 11503.520
loss: 10314.581
loss: 8621.800
loss: 8324.214
loss: 10476.521
loss: 14746.071
loss: 6225.872
loss: 7228.367
loss: 8264.062
loss: 11290.803
loss: 

loss: 7485.770
loss: 8861.299
loss: 6852.436
loss: 12908.622
loss: 11168.109
loss: 6883.432
loss: 7083.597
loss: 4379.337
loss: 12269.495
loss: 5871.740
loss: 11397.769
loss: 7150.886
loss: 11815.214
loss: 9681.470
loss: 7250.599
loss: 6322.732
loss: 12017.379
loss: 8158.835
loss: 7591.317
loss: 5237.705
loss: 10911.222
loss: 16713.986
loss: 7648.151
loss: 6551.996
loss: 5779.417
loss: 13610.770
loss: 4573.027
loss: 7221.200
loss: 6962.932
loss: 7000.305
loss: 19548.568
loss: 5542.545
loss: 9650.104
loss: 12254.964
loss: 10469.400
loss: 4872.426
loss: 12842.757
loss: 12841.830
loss: 5684.136
loss: 4715.793
loss: 6923.569
loss: 4314.342
loss: 7094.976
loss: 5418.318
loss: 13749.940
loss: 6973.210
loss: 5635.433
loss: 14911.069
loss: 5475.199
loss: 8321.657
loss: 5184.490
loss: 5807.054
loss: 6192.936
loss: 6859.977
loss: 5822.756
loss: 7258.860
loss: 9365.266
loss: 7121.121
loss: 8556.903
loss: 11654.687
loss: 5414.093
loss: 10955.161
loss: 6114.094
loss: 4519.034
loss: 11762.507
loss: 

loss: 8161.064
loss: 11614.124
loss: 3752.148
loss: 5687.078
loss: 6362.249
loss: 8660.656
loss: 5898.086
loss: 18148.420
loss: 4880.589
loss: 6509.342
loss: 10115.787
loss: 11460.388
loss: 6436.218
loss: 5258.149
loss: 11923.024
loss: 8304.688
loss: 4624.285
loss: 4685.852
loss: 4072.214
loss: 8237.578
loss: 6246.566
loss: 7124.294
loss: 12623.114
loss: 8486.971
loss: 6611.530
loss: 15115.116
loss: 8467.008
loss: 16238.189
loss: 4316.835
loss: 7929.987
loss: 5683.186
loss: 2849.084
loss: 13221.644
loss: 9924.405
loss: 8138.777
loss: 8784.200
loss: 4613.417
loss: 11512.572
loss: 7550.947
loss: 5901.416
loss: 6363.946
loss: 4619.859
loss: 7482.637
loss: 4997.884
loss: 6714.421
loss: 12581.894
loss: 7802.875
loss: 13285.045
loss: 8340.742
loss: 4918.994
loss: 6372.737
loss: 6848.030
loss: 6083.520
loss: 20061.035
loss: 6308.574
loss: 4474.171
loss: 7486.635
loss: 3726.804
loss: 7478.096
loss: 6187.925
loss: 4996.361
loss: 17737.109
loss: 2968.999
loss: 10497.314
loss: 8160.675
loss: 6557

loss: 24277.385
loss: 17067.711
loss: 14522.146
loss: 3259.990
loss: 7032.248
loss: 6905.083
loss: 8539.096
loss: 6068.885
loss: 6079.324
loss: 9489.364
loss: 19255.633
loss: 10501.088
loss: 6826.443
loss: 10208.284
loss: 9533.265
loss: 11965.819
loss: 12190.773
loss: 7869.957
loss: 6783.787
loss: 7478.317
loss: 7871.649
loss: 12301.527
loss: 6122.675
loss: 6651.844
loss: 15850.614
loss: 5484.667
loss: 7406.113
loss: 10915.446
loss: 20905.334
loss: 9590.548
loss: 7964.728
loss: 5574.447
loss: 8605.259
loss: 9638.444
loss: 8537.952
loss: 5695.337
loss: 7899.475
loss: 7664.226
loss: 9863.562
loss: 7729.900
loss: 5120.550
loss: 9227.953
loss: 6089.350
loss: 5101.598
loss: 14894.325
loss: 5149.758
loss: 6877.221
loss: 8802.424
loss: 7515.525
loss: 11801.791
loss: 4047.703
loss: 9398.733
loss: 9707.592
loss: 10510.797
loss: 4255.195
loss: 6407.553
loss: 6266.445
loss: 9885.597
loss: 10455.386
loss: 5677.454
loss: 7022.216
loss: 4449.128
loss: 6401.161
loss: 9707.403
loss: 8161.048
loss: 643

loss: 8852.275
loss: 5422.803
loss: 15484.446
loss: 4060.789
loss: 6770.877
loss: 9917.266
loss: 31992.363
loss: 6492.957
loss: 19438.512
loss: 9637.884
loss: 9886.997
loss: 7593.569
loss: 6263.317
loss: 5087.285
loss: 5924.167
loss: 6994.068
loss: 14470.276
loss: 4632.972
loss: 4808.145
loss: 11696.383
loss: 5944.764
loss: 4441.711
loss: 7293.301
loss: 4651.045
loss: 3975.357
loss: 9743.499
loss: 7868.716
loss: 5373.298
loss: 9676.339
loss: 4169.382
loss: 13762.658
loss: 12779.224
loss: 5589.424
loss: 7007.402
loss: 9704.453
loss: 8982.880
loss: 6968.692
loss: 11713.716
loss: 5223.395
loss: 7822.999
loss: 5168.070
loss: 17851.342
loss: 3856.888
loss: 11996.506
loss: 18720.623
loss: 5366.893
loss: 10863.204
loss: 7861.355
loss: 12554.497
loss: 7064.797
loss: 7574.180
loss: 19913.107
loss: 7190.259
loss: 14550.796
loss: 11880.498
loss: 9095.348
loss: 6156.067
loss: 10404.225
loss: 9164.036
loss: 9799.594
loss: 6802.515
loss: 8480.547
loss: 10033.612
loss: 9143.972
loss: 10829.711
loss: 

In [18]:
import sklearn.metrics

for y, cont, cat in testloader:
    predicted = model(cont, cat).detach().numpy()
    actual = y.numpy()
    print(sklearn.metrics.r2_score(actual, predicted))

0.9221674274291969
0.8710208504221149
0.9534621455681875
0.9272809557520366
0.9560730024109886
0.8833352012375387
0.921501543742779
0.9567070323974445
0.8428761584077237
0.8787653681578018
0.9440694926869649
0.956624659035702
0.9067444999531428
0.8404393191643615
0.9032950096846757
0.878275958633035
0.9297182605449544
0.9026584444163017
0.8931712271118732
0.9223696006707522
0.9570201714236349
0.8631705973956865
0.9016090905013768
0.9150695936616138
0.9250129014436059
0.9063051153533932
0.9607654032632704
0.9694800464242554
0.9266053694500233
0.9067031377700434
0.967463842740057
0.885065689368728
0.9544356762164832
0.943354126884768
0.9450190928959836
0.8711691162884486
0.9300055310820357
0.9230894562115437
0.9040337976424071
0.9293064717370234
0.8483623114156672
0.9412070139358882
0.9194523448891542
0.9574325716546748
0.8923978615056479
0.9389274996429855
0.9218447728647766
0.9598544758088944
0.9198479000213144
0.9077552811283658
0.945549895903359
0.9172218831717286
0.8891283682308828
