In [1]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import LabelEncoder

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


class TabularDataset(Dataset):
  def __init__(self, data, cat_cols=None, output_col=None):
    """
    Characterizes a Dataset for PyTorch
    Parameters
    ----------
    data: pandas data frame
      The data frame object for the input data. It must
      contain all the continuous, categorical and the
      output columns to be used.
    cat_cols: List of strings
      The names of the categorical columns in the data.
      These columns will be passed through the embedding
      layers in the model. These columns must be
      label encoded beforehand. 
    output_col: string
      The name of the output variable column in the data
      provided.
    """

    self.n = data.shape[0]

    if output_col:
      self.y = data[output_col].astype(np.float32).values.reshape(-1, 1)
    else:
      self.y =  np.zeros((self.n, 1))

    self.cat_cols = cat_cols if cat_cols else []
    self.cont_cols = [col for col in data.columns
                      if col not in self.cat_cols + [output_col]]

    if self.cont_cols:
      self.cont_X = data[self.cont_cols].astype(np.float32).values
    else:
      self.cont_X = np.zeros((self.n, 1))

    if self.cat_cols:
      self.cat_X = data[cat_cols].astype(np.int64).values
    else:
      self.cat_X =  np.zeros((self.n, 1))

  def __len__(self):
    """
    Denotes the total number of samples.
    """
    return self.n

  def __getitem__(self, idx):
    """
    Generates one sample of data.
    """
    return [self.y[idx], self.cont_X[idx], self.cat_X[idx]]


class FeedForwardNN(nn.Module):

  def __init__(self, emb_dims, no_of_cont, lin_layer_sizes,
               output_size, emb_dropout, lin_layer_dropouts):

    """
    Parameters
    ----------
    emb_dims: List of two element tuples
      This list will contain a two element tuple for each
      categorical feature. The first element of a tuple will
      denote the number of unique values of the categorical
      feature. The second element will denote the embedding
      dimension to be used for that feature.
    no_of_cont: Integer
      The number of continuous features in the data.
    lin_layer_sizes: List of integers.
      The size of each linear layer. The length will be equal
      to the total number
      of linear layers in the network.
    output_size: Integer
      The size of the final output.
    emb_dropout: Float
      The dropout to be used after the embedding layers.
    lin_layer_dropouts: List of floats
      The dropouts to be used after each linear layer.
    """

    super().__init__()

    # Embedding layers
    self.emb_layers = nn.ModuleList([nn.Embedding(x, y)
                                     for x, y in emb_dims])

    no_of_embs = sum([y for x, y in emb_dims])
    self.no_of_embs = no_of_embs
    self.no_of_cont = no_of_cont

    # Linear Layers
    first_lin_layer = nn.Linear(self.no_of_embs + self.no_of_cont,
                                lin_layer_sizes[0])

    self.lin_layers =\
     nn.ModuleList([first_lin_layer] +\
          [nn.Linear(lin_layer_sizes[i], lin_layer_sizes[i + 1])
           for i in range(len(lin_layer_sizes) - 1)])
    
    for lin_layer in self.lin_layers:
      nn.init.kaiming_normal_(lin_layer.weight.data)

    # Output Layer
    self.output_layer = nn.Linear(lin_layer_sizes[-1],
                                  output_size)
    nn.init.kaiming_normal_(self.output_layer.weight.data)

    # Batch Norm Layers
    self.first_bn_layer = nn.BatchNorm1d(self.no_of_cont)
    self.bn_layers = nn.ModuleList([nn.BatchNorm1d(size)
                                    for size in lin_layer_sizes])

    # Dropout Layers
    self.emb_dropout_layer = nn.Dropout(emb_dropout)
    self.droput_layers = nn.ModuleList([nn.Dropout(size)
                                  for size in lin_layer_dropouts])

  def forward(self, cont_data, cat_data):

    if self.no_of_embs != 0:
      x = [emb_layer(cat_data[:, i])
           for i,emb_layer in enumerate(self.emb_layers)]
      x = torch.cat(x, 1)
      x = self.emb_dropout_layer(x)

    if self.no_of_cont != 0:
      normalized_cont_data = self.first_bn_layer(cont_data)

      if self.no_of_embs != 0:
        x = torch.cat([x, normalized_cont_data], 1) 
      else:
        x = normalized_cont_data

    for lin_layer, dropout_layer, bn_layer in\
        zip(self.lin_layers, self.droput_layers, self.bn_layers):
      
      x = F.relu(lin_layer(x))
      x = bn_layer(x)
      x = dropout_layer(x)

    x = self.output_layer(x)

    return x

In [2]:
data = pd.read_csv('data.csv')

In [3]:
data['date_in'] = pd.to_datetime(data['date_in'])
data['year'] = data['date_in'].map(lambda x: x.strftime('%Y'))
data['month'] = data['date_in'].map(lambda x: x.strftime('%m'))
data['day'] = data['date_in'].map(lambda x: x.strftime('%d'))

In [4]:
dr = ['date_in','house_pk']

data = data.drop(dr,axis = 1)
data.head().T

Unnamed: 0,0,1,2,3,4
agency_id,90.0,90.0,90.0,90.0,90.0
price,532.0,588.0,588.0,588.0,588.0
dis_water_real,0.261,0.261,0.261,0.261,0.261
dis_shopping,3.0,3.0,3.0,3.0,3.0
no_bedrooms,3.0,3.0,3.0,3.0,3.0
max_persons,4.0,4.0,4.0,4.0,4.0
house_size,140.0,140.0,140.0,140.0,140.0
land_size,726.0,726.0,726.0,726.0,726.0
build_year,1953.0,1953.0,1953.0,1953.0,1953.0
renovation_year,2014.0,2014.0,2014.0,2014.0,2014.0


In [10]:
categorical_features = ['agency_id', 'apartment', 'indoor_pool', 'spa', 'internet', 'pets_allowed', 'water_view', 'fire_stove', 'year', 'month', 'day']
output_columns = 'price'

In [11]:
data.shape

(85195, 21)

In [12]:
label_encoders = {}
for cat_col in categorical_features:
        label_encoders[cat_col] = LabelEncoder()
        data[cat_col] = label_encoders[cat_col].fit_transform(data[cat_col])

In [13]:
dataset = TabularDataset(data=data, cat_cols=categorical_features, output_col=output_columns)

In [14]:
batchsize = 64
dataloader = DataLoader(dataset, batchsize, shuffle=True)

In [15]:
cat_dims = [int(data[col].nunique()) for col in categorical_features]
cat_dims

[4, 2, 2, 2, 2, 2, 2, 2, 4, 12, 31]

In [16]:
emb_dims = [(x, min(50, (x + 1) // 2)) for x in cat_dims]
emb_dims

[(4, 2),
 (2, 1),
 (2, 1),
 (2, 1),
 (2, 1),
 (2, 1),
 (2, 1),
 (2, 1),
 (4, 2),
 (12, 6),
 (31, 16)]

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [18]:
model = FeedForwardNN(emb_dims, no_of_cont=9, lin_layer_sizes=[50, 100],
                          output_size=1, emb_dropout=0.04,
                          lin_layer_dropouts=[0.001,0.01]).to(device)

In [19]:
print(model)

FeedForwardNN(
  (emb_layers): ModuleList(
    (0): Embedding(4, 2)
    (1): Embedding(2, 1)
    (2): Embedding(2, 1)
    (3): Embedding(2, 1)
    (4): Embedding(2, 1)
    (5): Embedding(2, 1)
    (6): Embedding(2, 1)
    (7): Embedding(2, 1)
    (8): Embedding(4, 2)
    (9): Embedding(12, 6)
    (10): Embedding(31, 16)
  )
  (lin_layers): ModuleList(
    (0): Linear(in_features=42, out_features=50, bias=True)
    (1): Linear(in_features=50, out_features=100, bias=True)
  )
  (output_layer): Linear(in_features=100, out_features=1, bias=True)
  (first_bn_layer): BatchNorm1d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn_layers): ModuleList(
    (0): BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): BatchNorm1d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (emb_dropout_layer): Dropout(p=0.04)
  (droput_layers): ModuleList(
    (0): Dropout(p=0.001)
    (1): Dropout(p=0.01)
  )
)


In [20]:
epochs = 5

criterion = nn.MSELoss()

optimizer = torch.optim.Adam(model.parameters(), lr = 0.1)

for epoch in range(epochs):
    
    for y, cont_x, cat_x in dataloader:
        
        optimizer.zero_grad()
        
        cont_x = cont_x.to(device)
        cat_x = cat_x.to(device)
        y = y.to(device)
        
        output = model(cont_x, cat_x)
        
        loss = criterion(output, y)
        
        print('loss: {:.3f}'.format(loss.item()))
        
        loss.backward()
        
        optimizer.step()

loss: 416854.438
loss: 523421.219
loss: 287584.312
loss: 395829.750
loss: 456363.000
loss: 333489.812
loss: 439342.438
loss: 363844.656
loss: 295457.969
loss: 263772.812
loss: 361082.094
loss: 275852.938
loss: 234036.969
loss: 179036.703
loss: 163964.281
loss: 202244.062
loss: 176618.609
loss: 81841.117
loss: 119143.312
loss: 83646.258
loss: 61241.926
loss: 69151.516
loss: 36613.031
loss: 33725.961
loss: 36084.637
loss: 43632.996
loss: 27973.756
loss: 33030.262
loss: 29720.777
loss: 41068.914
loss: 51074.746
loss: 39600.691
loss: 48630.828
loss: 46995.758
loss: 45445.980
loss: 49229.949
loss: 60104.863
loss: 24711.416
loss: 20696.936
loss: 33244.793
loss: 32229.004
loss: 19708.168
loss: 22504.754
loss: 29445.252
loss: 41899.715
loss: 47722.352
loss: 22882.344
loss: 31127.914
loss: 17391.279
loss: 31750.375
loss: 15429.450
loss: 37866.328
loss: 18371.932
loss: 40367.316
loss: 21935.051
loss: 35214.410
loss: 20298.609
loss: 14842.490
loss: 34841.555
loss: 28405.678
loss: 23309.348
loss: 

loss: 11936.598
loss: 13853.999
loss: 9166.742
loss: 19783.619
loss: 17529.047
loss: 16276.135
loss: 9499.007
loss: 20907.426
loss: 19007.705
loss: 7327.719
loss: 18673.395
loss: 16781.941
loss: 24750.166
loss: 19024.975
loss: 15911.874
loss: 21674.207
loss: 17284.352
loss: 10513.167
loss: 6834.040
loss: 9338.033
loss: 13058.647
loss: 22220.174
loss: 10896.471
loss: 24338.086
loss: 20817.674
loss: 19079.549
loss: 15770.633
loss: 18146.395
loss: 21370.789
loss: 15561.690
loss: 13628.843
loss: 18426.635
loss: 24289.598
loss: 33493.863
loss: 26331.896
loss: 17541.809
loss: 16495.303
loss: 10678.614
loss: 17446.119
loss: 22233.785
loss: 43506.500
loss: 12870.308
loss: 16699.494
loss: 12435.603
loss: 17418.977
loss: 16267.722
loss: 18506.070
loss: 18840.312
loss: 8196.446
loss: 11701.642
loss: 17116.000
loss: 15220.720
loss: 15317.194
loss: 13238.052
loss: 10911.582
loss: 19044.562
loss: 17073.137
loss: 16729.033
loss: 4851.831
loss: 17110.271
loss: 13704.869
loss: 8940.118
loss: 13196.627


loss: 12550.074
loss: 8367.521
loss: 10813.107
loss: 12123.097
loss: 13114.134
loss: 7900.851
loss: 8952.268
loss: 10247.700
loss: 9706.075
loss: 10432.555
loss: 8876.635
loss: 8494.070
loss: 14558.873
loss: 7606.004
loss: 6914.565
loss: 23858.230
loss: 20385.631
loss: 8125.335
loss: 11535.518
loss: 10798.517
loss: 10773.902
loss: 18240.242
loss: 26684.006
loss: 13755.210
loss: 11587.394
loss: 17083.648
loss: 9288.322
loss: 13385.331
loss: 11619.389
loss: 12726.199
loss: 9535.024
loss: 7509.204
loss: 9976.345
loss: 10523.298
loss: 8906.546
loss: 14632.915
loss: 12632.153
loss: 9251.402
loss: 9056.527
loss: 24535.084
loss: 8780.908
loss: 11658.242
loss: 8384.514
loss: 13429.374
loss: 8923.059
loss: 11280.351
loss: 27179.828
loss: 11406.011
loss: 11261.264
loss: 10616.976
loss: 17492.135
loss: 10328.238
loss: 17282.357
loss: 12114.916
loss: 10636.437
loss: 11554.411
loss: 11683.469
loss: 7113.260
loss: 13639.473
loss: 11953.521
loss: 5298.244
loss: 11622.560
loss: 12292.180
loss: 13102.1

loss: 8491.586
loss: 10526.258
loss: 6229.564
loss: 16549.211
loss: 29999.543
loss: 15693.120
loss: 24671.623
loss: 12070.269
loss: 19123.811
loss: 11130.639
loss: 11534.384
loss: 13468.994
loss: 12434.079
loss: 10064.729
loss: 11900.912
loss: 17122.471
loss: 17448.658
loss: 13995.791
loss: 30578.418
loss: 10456.648
loss: 8476.193
loss: 9803.597
loss: 12506.493
loss: 11672.891
loss: 11194.303
loss: 10529.631
loss: 15740.380
loss: 11995.398
loss: 10308.249
loss: 11299.942
loss: 7763.214
loss: 9589.853
loss: 13822.090
loss: 6825.870
loss: 9013.098
loss: 12462.562
loss: 12368.338
loss: 8523.690
loss: 11910.081
loss: 6759.502
loss: 9423.238
loss: 9879.360
loss: 9288.531
loss: 8693.138
loss: 18136.111
loss: 7704.622
loss: 9795.285
loss: 15111.488
loss: 14313.971
loss: 19735.578
loss: 11855.675
loss: 18365.047
loss: 10398.751
loss: 25139.580
loss: 20278.508
loss: 17159.217
loss: 13329.096
loss: 12934.810
loss: 9699.061
loss: 9860.055
loss: 12140.400
loss: 12925.351
loss: 14105.479
loss: 9478

loss: 13132.456
loss: 9047.455
loss: 9688.868
loss: 12225.228
loss: 14394.947
loss: 8653.009
loss: 14108.597
loss: 22210.521
loss: 17890.490
loss: 20285.371
loss: 9254.951
loss: 9266.140
loss: 9258.171
loss: 13927.908
loss: 14293.869
loss: 6444.556
loss: 12690.363
loss: 20165.273
loss: 7040.963
loss: 6982.096
loss: 12862.607
loss: 5192.812
loss: 12126.546
loss: 12404.530
loss: 8793.006
loss: 10617.542
loss: 6828.200
loss: 12725.690
loss: 7039.394
loss: 11268.130
loss: 9684.365
loss: 8023.301
loss: 4497.845
loss: 6336.371
loss: 7467.970
loss: 17813.080
loss: 15583.849
loss: 6802.128
loss: 23187.900
loss: 7802.488
loss: 9882.544
loss: 8853.438
loss: 6714.980
loss: 15806.395
loss: 7928.168
loss: 10743.240
loss: 12699.509
loss: 5580.798
loss: 11380.671
loss: 6246.934
loss: 5816.381
loss: 8015.483
loss: 11023.779
loss: 6723.165
loss: 9314.028
loss: 9192.663
loss: 13126.233
loss: 6558.261
loss: 10152.095
loss: 9901.223
loss: 15696.090
loss: 12920.716
loss: 13743.734
loss: 9774.758
loss: 1766

loss: 11281.419
loss: 10614.427
loss: 7469.890
loss: 9220.164
loss: 9054.802
loss: 18308.908
loss: 29242.273
loss: 22667.807
loss: 7397.590
loss: 9344.586
loss: 11242.523
loss: 11503.321
loss: 11499.354
loss: 11460.174
loss: 15651.404
loss: 6749.617
loss: 9087.685
loss: 14592.008
loss: 13519.293
loss: 16814.154
loss: 15507.019
loss: 12897.615
loss: 31050.199
loss: 11504.962
loss: 6589.248
loss: 13530.911
loss: 10686.382
loss: 19567.633
loss: 9565.860
loss: 8445.280
loss: 12205.860
loss: 7598.921
loss: 11465.264
loss: 11347.952
loss: 16250.022
loss: 11217.549
loss: 10327.406
loss: 11929.195
loss: 16000.599
loss: 10474.456
loss: 10311.113
loss: 16313.301
loss: 7570.007
loss: 22403.203
loss: 12415.421
loss: 6630.941
loss: 8478.588
loss: 8873.964
loss: 15165.086
loss: 8684.635
loss: 4034.612
loss: 7231.755
loss: 10535.102
loss: 9276.681
loss: 15025.219
loss: 11213.751
loss: 16497.820
loss: 18784.775
loss: 9848.243
loss: 15555.086
loss: 9898.228
loss: 7784.754
loss: 12501.757
loss: 16131.27

loss: 7825.228
loss: 9027.308
loss: 11809.860
loss: 13865.280
loss: 15915.803
loss: 21201.676
loss: 5616.937
loss: 16086.062
loss: 14071.909
loss: 15753.884
loss: 7057.942
loss: 9552.236
loss: 10085.808
loss: 11993.687
loss: 6787.849
loss: 10598.609
loss: 15985.175
loss: 14373.052
loss: 21087.318
loss: 9476.037
loss: 14028.776
loss: 6099.116
loss: 10526.162
loss: 6001.936
loss: 21106.715
loss: 21264.291
loss: 12043.356
loss: 17515.893
loss: 7174.410
loss: 9490.240
loss: 11883.189
loss: 8901.770
loss: 13830.975
loss: 5663.126
loss: 10915.118
loss: 4908.554
loss: 13502.995
loss: 5151.656
loss: 9608.491
loss: 6214.186
loss: 15801.284
loss: 9529.037
loss: 14987.665
loss: 12650.564
loss: 9667.048
loss: 26327.314
loss: 11147.327
loss: 8238.355
loss: 8587.661
loss: 9295.708
loss: 8786.492
loss: 13319.609
loss: 13075.075
loss: 16680.211
loss: 5280.928
loss: 9156.396
loss: 7760.609
loss: 6461.667
loss: 6450.555
loss: 9919.120
loss: 16742.371
loss: 17439.826
loss: 12793.752
loss: 22270.479
loss:

loss: 7196.142
loss: 12636.556
loss: 21826.684
loss: 7080.076
loss: 10099.045
loss: 7729.569
loss: 7377.779
loss: 8146.364
loss: 7388.775
loss: 8818.557
loss: 10136.876
loss: 9178.282
loss: 14280.896
loss: 7291.739
loss: 6670.782
loss: 8084.439
loss: 12901.729
loss: 6882.994
loss: 10068.955
loss: 19129.951
loss: 6061.005
loss: 3994.288
loss: 5626.745
loss: 12431.749
loss: 10461.726
loss: 5978.096
loss: 29367.285
loss: 9486.389
loss: 8138.775
loss: 12693.228
loss: 10570.643
loss: 9587.390
loss: 15893.524
loss: 9421.765
loss: 5342.334
loss: 13496.524
loss: 5749.269
loss: 8302.550
loss: 10742.896
loss: 13071.231
loss: 10192.337
loss: 9816.496
loss: 10090.558
loss: 6372.603
loss: 10691.702
loss: 10596.015
loss: 11248.831
loss: 11358.843
loss: 17075.197
loss: 10117.444
loss: 13885.597
loss: 10745.205
loss: 8905.786
loss: 6566.942
loss: 9823.692
loss: 10380.800
loss: 6947.317
loss: 12835.078
loss: 10599.854
loss: 9468.413
loss: 11663.817
loss: 10897.534
loss: 9248.080
loss: 7251.874
loss: 61

loss: 6970.575
loss: 5324.929
loss: 7014.426
loss: 20574.121
loss: 18458.988
loss: 14355.061
loss: 16658.176
loss: 7632.724
loss: 11008.063
loss: 8201.536
loss: 22776.488
loss: 7175.386
loss: 10420.900
loss: 6157.996
loss: 5356.570
loss: 13014.786
loss: 7039.361
loss: 6831.808
loss: 10377.906
loss: 6819.354
loss: 10581.511
loss: 3623.209
loss: 4725.665
loss: 5020.787
loss: 4733.217
loss: 13278.717
loss: 29184.008
loss: 6824.318
loss: 7109.019
loss: 12674.274
loss: 6733.685
loss: 6934.032
loss: 10462.476
loss: 5971.314
loss: 8615.001
loss: 8555.274
loss: 6106.111
loss: 20527.129
loss: 6153.048
loss: 11603.288
loss: 14132.607
loss: 5282.478
loss: 13885.162
loss: 5684.182
loss: 6503.036
loss: 9222.873
loss: 7056.663
loss: 13370.906
loss: 10186.382
loss: 6216.137
loss: 8877.215
loss: 10280.502
loss: 9106.179
loss: 5974.416
loss: 7184.418
loss: 7640.088
loss: 6926.032
loss: 10538.644
loss: 8152.256
loss: 11038.940
loss: 13527.759
loss: 12814.993
loss: 9233.531
loss: 13149.317
loss: 13227.18

loss: 10178.616
loss: 18458.182
loss: 19558.367
loss: 5046.676
loss: 7249.390
loss: 7827.068
loss: 6655.898
loss: 12283.085
loss: 12805.173
loss: 5013.856
loss: 5909.931
loss: 7923.986
loss: 7601.243
loss: 7030.914
loss: 5921.928
loss: 7773.200
loss: 5951.256
loss: 17312.436
loss: 6023.236
loss: 8130.653
loss: 11486.685
loss: 4071.560
loss: 10340.889
loss: 7687.223
loss: 11358.538
loss: 11016.313
loss: 4460.456
loss: 7440.201
loss: 27964.039
loss: 10004.854
loss: 11805.540
loss: 8080.970
loss: 9468.896
loss: 11381.439
loss: 6354.629
loss: 10272.309
loss: 18291.787
loss: 13298.783
loss: 6042.806
loss: 6711.370
loss: 5823.505
loss: 7414.605
loss: 8312.501
loss: 11858.982
loss: 7421.775
loss: 9922.945
loss: 5606.386
loss: 8604.521
loss: 5242.466
loss: 9545.206
loss: 12668.172
loss: 10294.503
loss: 7251.791
loss: 14782.632
loss: 20530.336
loss: 4390.866
loss: 15155.935
loss: 6605.956
loss: 7464.694
loss: 11276.079
loss: 11496.285
loss: 7791.493
loss: 14274.278
loss: 14005.978
loss: 8199.55

KeyboardInterrupt: 