# HDB DGCNN Regression

Our objective is to compare the performance of a MLP price regressor against a DGCNN based price regressor. We expect the DGCNN to work better by taking into account geospatial relationships with 1) nearby recent sales, 2) nearby malls, 3) nearby MRTs. This should give us a more localised prior to do inference. Instead of getting MLP to learn every single feature, we extract relevant priors to make inference easier. We hope to see the DGCNN perform better than the MLP. 

In [76]:
import pickle
import os
import pandas as pd
from tqdm.notebook import tqdm

### Load Dataset

In [77]:
with open('housing_data.pickle', 'rb') as f:
    housing_data = pickle.load(f)

# sort by descending date (most recent first)

housing_data = housing_data.sort_values(by=['year','town'],ascending=[False,True])

housing_data = housing_data.reset_index(drop=True)
print(housing_data.info())
display(housing_data.head())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 62214 entries, 0 to 62213
Data columns (total 14 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   town             62214 non-null  object 
 1   flat_type        62214 non-null  object 
 2   block            62214 non-null  object 
 3   street_name      62214 non-null  object 
 4   floor_area_sqm   62214 non-null  float64
 5   flat_model       62214 non-null  object 
 6   remaining_lease  62214 non-null  float64
 7   resale_price     62214 non-null  float64
 8   year             62214 non-null  int64  
 9   storey           62214 non-null  float64
 10  psm              62214 non-null  float64
 11  address          62214 non-null  object 
 12  latitude         62214 non-null  float64
 13  longitude        62214 non-null  float64
dtypes: float64(7), int64(1), object(6)
memory usage: 6.6+ MB
None


Unnamed: 0,town,flat_type,block,street_name,floor_area_sqm,flat_model,remaining_lease,resale_price,year,storey,psm,address,latitude,longitude
0,ANG MO KIO,2 ROOM,406,ANG MO KIO AVE 10,44.0,Improved,55.0,267000.0,2023,1.0,6068.181818,406 ANG MO KIO AVE 10,8.598078,0.314893
1,ANG MO KIO,2 ROOM,323,ANG MO KIO AVE 3,49.0,Improved,53.0,300000.0,2023,5.0,6122.44898,323 ANG MO KIO AVE 3,9.250907,-0.371289
2,ANG MO KIO,2 ROOM,314,ANG MO KIO AVE 3,44.0,Improved,54.0,280000.0,2023,5.0,6363.636364,314 ANG MO KIO AVE 3,9.064984,-0.10734
3,ANG MO KIO,2 ROOM,314,ANG MO KIO AVE 3,44.0,Improved,54.0,282000.0,2023,7.0,6409.090909,314 ANG MO KIO AVE 3,9.064984,-0.10734
4,ANG MO KIO,2 ROOM,170,ANG MO KIO AVE 4,45.0,Improved,62.0,289800.0,2023,1.0,6440.0,170 ANG MO KIO AVE 4,9.924554,-1.626898


### Load MRT/Mall Data

In [78]:
with open('mrt_data.pickle', 'rb') as f:
    mrt_data = pickle.load(f)
display(mrt_data.head())

with open('mall_data.pickle', 'rb') as f:
    mall_data = pickle.load(f)
display(mall_data.head())

# encode addresses
from sklearn.preprocessing import LabelEncoder
# mrt_encoder = LabelEncoder()
# df_mrt = mrt_data.apply(mrt_encoder.fit_transform).astype('int')
df_mrt = mrt_data.copy()
df_mrt[['latitude','longitude']] = mrt_data[['latitude','longitude']].astype('float')

# mall_encoder = LabelEncoder()
# df_mall = mall_data.apply(mall_encoder.fit_transform).astype('int')
df_mall = mall_data.copy()
df_mall[['latitude','longitude']] = mall_data[['latitude','longitude']].astype('float')

Unnamed: 0,address,latitude,longitude
0,Bukit Gombak,9.397626,-10.088609
1,Dakota,2.827397,4.205995
2,Marina Bay,0.045941,1.107969
3,Sembawang,14.968582,-3.333573
4,Tuas Link,-1.590055,-24.911934


Unnamed: 0,address,latitude,longitude
0,Beauty World Centre,6.431706,-8.292265
1,Anchorpoint,0.518399,-5.057376
2,600 @ Toa Payoh,5.505486,-0.008088
3,100 AM,-1.05751,-0.841594
4,Causeway Point,16.805408,-7.200186


In [79]:
# split numerical and categorical features

df_temp = housing_data.copy()
df_temp = df_temp.drop(columns=['block','address','street_name'])

#extract categorical features and encode them
house_encoder = LabelEncoder()
df_cat = df_temp.select_dtypes(include=['object'])
df_cat = df_cat.apply(house_encoder.fit_transform)
df_cat = df_cat.reset_index(drop=True).astype('int')
display(df_cat.head())

#extract numerical features
df_num = df_temp.select_dtypes(include=['float64'])
df_num = df_num.reset_index(drop=True)
df_num = df_num[['floor_area_sqm','remaining_lease','storey','resale_price','latitude','longitude']]
display(df_num.head())

Unnamed: 0,town,flat_type,flat_model
0,0,1,5
1,0,1,5
2,0,1,5
3,0,1,5
4,0,1,5


Unnamed: 0,floor_area_sqm,remaining_lease,storey,resale_price,latitude,longitude
0,44.0,55.0,1.0,267000.0,8.598078,0.314893
1,49.0,53.0,5.0,300000.0,9.250907,-0.371289
2,44.0,54.0,5.0,280000.0,9.064984,-0.10734
3,44.0,54.0,7.0,282000.0,9.064984,-0.10734
4,45.0,62.0,1.0,289800.0,9.924554,-1.626898


In [80]:
df_final = pd.concat([df_num, df_cat], axis=1)

df_Y = df_final[['resale_price']].copy() #.div(1000)
df_X = df_final.copy() #.drop(columns=['psm'])

display(df_Y.head())
display(df_X.head())
print(df_X.info())

Unnamed: 0,resale_price
0,267000.0
1,300000.0
2,280000.0
3,282000.0
4,289800.0


Unnamed: 0,floor_area_sqm,remaining_lease,storey,resale_price,latitude,longitude,town,flat_type,flat_model
0,44.0,55.0,1.0,267000.0,8.598078,0.314893,0,1,5
1,49.0,53.0,5.0,300000.0,9.250907,-0.371289,0,1,5
2,44.0,54.0,5.0,280000.0,9.064984,-0.10734,0,1,5
3,44.0,54.0,7.0,282000.0,9.064984,-0.10734,0,1,5
4,45.0,62.0,1.0,289800.0,9.924554,-1.626898,0,1,5


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 62214 entries, 0 to 62213
Data columns (total 9 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   floor_area_sqm   62214 non-null  float64
 1   remaining_lease  62214 non-null  float64
 2   storey           62214 non-null  float64
 3   resale_price     62214 non-null  float64
 4   latitude         62214 non-null  float64
 5   longitude        62214 non-null  float64
 6   town             62214 non-null  int64  
 7   flat_type        62214 non-null  int64  
 8   flat_model       62214 non-null  int64  
dtypes: float64(6), int64(3)
memory usage: 4.3 MB
None


### Train Test Split

Seperate numerical & categorical features, scale numerical features in training set

In [81]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import numpy as np

X_train, X_test, Y_train, Y_test = train_test_split(df_X,df_Y, test_size=0.05, shuffle=True)
# print(X_train.info())
# print(X_test.info())
# print(Y_train.info())
# print(Y_test.info())

#train set
X_train_num = X_train.select_dtypes(include=['float64'])
X_train_cat = X_train.select_dtypes(include=['int64'])
# X_train_cat = X_train.select_dtypes(include=['object'])
display(X_train_num,X_train_cat)

#test set
X_test_num = X_test.select_dtypes(include=['float64'])
X_test_cat = X_test.select_dtypes(include=['int64'])
# X_test_cat = X_test.select_dtypes(include=['object'])
display(X_test_num,X_test_cat)

#scaling for numerical features
for feature in X_train_num:
    if feature in ['longitude','latitude']:
        continue

    scaler = MinMaxScaler()
    scaler.fit(np.array(X_train_num[feature]).reshape(-1, 1))
    X_train_num[feature] = scaler.transform(np.array(X_train_num[feature]).reshape(-1, 1))
    X_test_num[feature] = scaler.transform(np.array(X_test_num[feature]).reshape(-1, 1))

    # scale psm results
    if feature == 'resale_price':
        Y_train[feature] = scaler.transform(np.array(Y_train[feature]).reshape(-1, 1))
        # Y_train[feature] = scaler.inverse_transform(np.array(Y_train[feature]).reshape(-1, 1))
        Y_test[feature] = scaler.transform(np.array(Y_test[feature]).reshape(-1, 1))
        # Y_test[feature] = scaler.inverse_transform(np.array(Y_test[feature]).reshape(-1, 1))
        price_scaler = scaler

Unnamed: 0,floor_area_sqm,remaining_lease,storey,resale_price,latitude,longitude
15663,91.0,61.0,5.0,478000.0,7.515076,3.240168
33996,99.0,56.0,8.0,478000.0,8.923331,0.333085
7306,67.0,61.0,2.0,342000.0,10.538971,-1.147900
38372,111.0,80.0,8.0,488000.0,10.903875,-9.909047
60893,74.0,65.0,5.0,313000.0,14.813871,-2.164379
...,...,...,...,...,...,...
51696,123.0,54.0,1.0,700000.0,2.815325,-6.894511
49223,93.0,93.0,5.0,472000.0,12.969300,7.125401
43734,120.0,75.0,10.0,580000.0,10.403890,3.951022
16059,138.0,61.0,1.0,690000.0,7.243946,3.979654


Unnamed: 0,town,flat_type,flat_model
15663,11,3,12
33996,0,3,12
7306,0,2,12
38372,5,4,5
60893,25,2,8
...,...,...,...
51696,18,4,17
49223,17,3,13
43734,11,4,5
16059,11,4,9


Unnamed: 0,floor_area_sqm,remaining_lease,storey,resale_price,latitude,longitude
42016,59.0,55.0,8.0,325000.0,4.367149,5.932859
31326,47.0,95.0,11.0,310000.0,14.872999,-1.075214
58046,92.0,87.0,8.0,750000.0,5.572333,-0.340559
2373,91.0,61.0,1.0,458000.0,9.929970,5.053452
47763,123.0,70.0,10.0,530000.0,9.200295,11.295701
...,...,...,...,...,...,...
48474,112.0,94.0,4.0,675000.0,13.830277,5.536642
33552,110.0,79.0,13.0,710000.0,8.805183,0.108866
19571,76.0,52.0,11.0,470000.0,2.177374,6.531166
48167,146.0,67.0,7.0,800000.0,9.704400,11.912120


Unnamed: 0,town,flat_type,flat_model
42016,10,2,5
31326,25,1,8
58046,23,3,8
2373,11,3,12
47763,16,4,5
...,...,...,...
48474,17,4,13
33552,0,4,5
19571,15,2,5
48167,16,5,7


### Baseline MLP Training & Evaluation
Baseline MLP regresses housing valuation based on apartment attributes only, without information on recent nearby transactions or proximity of amenities.

In [83]:
from torch.utils.data import Dataset, DataLoader
import torch

class TrainDataset(Dataset):
    def __init__(self,X_train_num, X_train_cat, Y_train,df_mrt,df_mall,psm_scaler):
        self.X_train_num = torch.Tensor(X_train_num.values) #(N,6)
        self.X_train_cat = torch.Tensor(X_train_cat.values) #(N,3)
        self.Y_train = torch.Tensor(Y_train.values) #(N,1)
        self.price_scaler = price_scaler
        self.mrt_data = torch.Tensor(df_mrt[['latitude','longitude']].values) #(M,2)
        self.mall_data = torch.Tensor(df_mall[['latitude','longitude']].values) #(K,2)
    
    def __len__(self):
        return len(self.X_train_num)
    
    def __getitem__(self,index):
        output = {'index': index,
                  'num_feat' : self.X_train_num[index], #(6,)
                  'cat_feat' : self.X_train_cat[index], #(3,)
                  'resale_price' : self.Y_train[index][0], #(1,)
        }
        return output


batch_size = 32
epochs = 100

dataset = TrainDataset(X_train_num, X_train_cat, Y_train, df_mrt, df_mall, price_scaler)
train_dataloader = DataLoader(dataset,batch_size=batch_size, shuffle=True,drop_last=True)

test_dataset = TrainDataset(X_test_num, X_test_cat, Y_test, df_mrt, df_mall, price_scaler)
test_dataloader = DataLoader(test_dataset,batch_size=batch_size, shuffle=False,drop_last=True)

# Define Model
import torch.nn as nn
class DGCNN(nn.Module):
    def __init__(self,dataset,k=8):
        super(DGCNN, self).__init__()
        self.k = k

        # cat embedders
        embed_dim = 8
        x_cat = dataset.X_train_cat
 
        num_towns = len(torch.unique(x_cat[:,0]))
        num_flat_type = len(torch.unique(x_cat[:,1]))
        num_flat_model = len(torch.unique(x_cat[:,2]))
        num_mrt = len(dataset.mrt_data[:,0])
        num_mall = len(dataset.mall_data[:,0])
  
        self.town_embedder = nn.Embedding(num_towns,embed_dim)
        self.flat_type_embedder = nn.Embedding(num_flat_type,embed_dim)
        self.flat_model_embedder = nn.Embedding(num_flat_model,embed_dim)
        self.mrt_embedder = nn.Embedding(num_mrt, 64)
        self.mall_embedder = nn.Embedding(num_mall, 64)

        # neighour numerical proj
        self.neighbour_num_proj = nn.Linear(4,24)

        # query numerical proj
        self.query_num_proj = nn.Linear(3,24)
        
        # neighbour mlp
        self.neighbour_mlp = nn.Sequential(
            nn.Linear(48, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        )

        # query mlp
        self.query_mlp = nn.Sequential(
            nn.Linear(48, 64),
            nn.ReLU(),
            # nn.Linear(64, 64),
            # nn.ReLU(),
        )  
        self.query_mlp2 = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            # nn.Linear(128, 128),
            # nn.ReLU(),
        )  

        # query and neighbour will create dim64 features from their own nodes
        # mrt and malls also create dim64 features from embedding

        self.unit_edgeconv = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
        )  

        self.mrt_edgeconv = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
        )  

        self.mall_edgeconv = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
        )

        self.decoder = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            # nn.Linear(256, 128),
            # nn.ReLU(),
            nn.Linear(128, 1),
            # nn.Sigmoid(), #(resale price)
        )  


    @staticmethod
    def knn(x,y,k=8,mask=None):
        # x has shape (N,2), y has shape (M,2)
        distances = torch.sum((x.unsqueeze(1) - y.unsqueeze(0)) ** 2, dim=-1)
        if mask is not None: 
            mask = torch.logical_not(mask) #invert mask
            distances[mask] += 1000.0

        _, indices = torch.topk(distances, k=k, dim=-1, largest=False,sorted=True)
        return indices, distances.gather(dim=-1, index=indices)

    def forward(self, self_num, self_cat, nn_num, nn_cat, nn_weights, mrt_idx, mrt_weights, mall_idx, mall_weights):

        # get embedding vectors for categorical variables
        self_cat_feats = self.cat_embedder(self_cat) #(B,1,24)
        nn_cat_feats = self.cat_embedder(nn_cat) #(B,N,24)

        # project numerical variables to 24 dim vectors
        self_num_feats = self.query_num_proj(self_num) #(B,1,24)
        nn_num_feats = self.neighbour_num_proj(nn_num) #(B,N,24)

        # stack cat and num features
        self_feats = torch.cat([self_num_feats, self_cat_feats],dim=-1) #(B,1,48)
        nn_feats = torch.cat([nn_num_feats, nn_cat_feats],dim=-1) #(B,N,48)

        # project self and nn feats with mlp
        self_feats = self.query_mlp(self_feats) #(B,1,64)
        nn_feats = self.neighbour_mlp(nn_feats) #(B,1,64)

        # get embedded vectors for mrt and mall
        mrt_feats = self.mrt_embedder(mrt_idx) #(B,4,64)
        mall_feats = self.mall_embedder(mall_idx) #(B,4,64)

        # create edge vectors (here we dont use feature diff unlike paper)
        query_nn_feats = torch.cat([self_feats.repeat(1,8,1),nn_feats],dim=-1) #(2,8,128)
        query_mrt_feats = torch.cat([self_feats.repeat(1,4,1),mrt_feats],dim=-1) #(2,4,128)
        query_mall_feats = torch.cat([self_feats.repeat(1,4,1),mall_feats],dim=-1) #(2,4,128)

        # run edgeconv on edge vectors
        unit_feats = self.unit_edgeconv(query_nn_feats) #(2,8,128)
        mrt_feats = self.mrt_edgeconv(query_mrt_feats) #(2,4,128)
        mall_feats = self.mall_edgeconv(query_mall_feats) #(2,4,128)

        # aggregate based on weights
        agg_unit_feats = torch.sum(nn_weights.unsqueeze(-1)*unit_feats,dim=1) #(2,128)
        agg_mrt_feats = torch.sum(mrt_weights.unsqueeze(-1)*mrt_feats,dim=1) #(2,128)
        agg_mall_feats = torch.sum(mall_weights.unsqueeze(-1)*mall_feats,dim=1) #(2,128)

        # project self query to 128 dim
        self_feats = self.query_mlp2(self_feats).squeeze(1) #(2,128)

        # combine all features
        out_feats = self_feats #+ agg_unit_feats # + agg_mrt_feats + agg_mall_feats #(2,128)

        # regress psm
        pred = self.decoder(out_feats).squeeze(-1)

        return pred

    def cat_embedder(self,cat_feat):
        town_feat = self.town_embedder(cat_feat[...,0].to(dtype=torch.int64))
        flat_type_feat = self.flat_type_embedder(cat_feat[...,1].to(dtype=torch.int64))
        flat_model_feat = self.flat_model_embedder(cat_feat[...,2].to(dtype=torch.int64))
        return torch.cat([town_feat, flat_type_feat, flat_model_feat],dim=-1) #(B,N,24)

k = 8 
mrt_locs = dataset.mrt_data
mall_locs = dataset.mall_data
unit_locs = dataset.X_train_num
unit_types = dataset.X_train_cat

if torch.cuda.is_available():
    device='cuda'
else:
    device='mps'
print(f"using device: {device}")

model = DGCNN(dataset).to(device)
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)

def idw(dists,thres=4.0):
    dist_mask = dists>thres
    dist_weights = 1/(dists+1)
    dist_weights[dist_mask] = 0.0
    num_neighbours = torch.logical_not(dist_mask).sum(dim=-1).unsqueeze(-1)+1 #(avoid division by 0)
    dist_weights /= num_neighbours# normalize weights by dist_mask
    return dist_weights, num_neighbours-1

for epoch_idx in range(epochs):

    # evaluation
    eval = True
    if eval:
        error_list = []
        preds_list = []
        gt_list = []
        pbar = tqdm(test_dataloader)
        for idx, data in enumerate(pbar):
            # get nearest mrt and malls
            mrt_idx, mrt_dists = DGCNN.knn(data['num_feat'][...,-2:],mrt_locs[...,-2:], k=4)
            mall_idx, mall_dists = DGCNN.knn(data['num_feat'][...,-2:],mall_locs[...,-2:], k=4)

            # restrict unit search to same housing type & town, create mask
            batch_size = data['cat_feat'].shape[0]
            unit_type_mask = unit_types.unsqueeze(0).repeat(batch_size,1,1)==data['cat_feat'].unsqueeze(1)
            unit_type_mask = torch.all(unit_type_mask,dim=-1)
            unit_idx, unit_dists = DGCNN.knn(data['num_feat'][...,-2:],unit_locs[...,-2:], k=k+1, mask=unit_type_mask)

            # remove self from topk retrievals
            batch_idx_list = []
            batch_dists_list = []
            for i in range(batch_size):
                if data['index'][i] in unit_idx[i]: # remove self
                    mask = unit_idx[i] != data['index'][i]
                    temp_idx = unit_idx[i][mask]
                    temp_dists = unit_dists[i][mask]

                else: # select first k
                    temp_idx = unit_idx[i][:-1]
                    temp_dists = unit_dists[i][:-1]
                batch_idx_list.append(temp_idx)
                batch_dists_list.append(temp_dists)
            unit_idx = torch.stack(batch_idx_list,dim=0)
            unit_dists = torch.stack(batch_dists_list,dim=0)

            # compute inverse distance weights
            mrt_weights, mrt_degree = idw(mrt_dists)
            mall_weights, mall_degree = idw(mall_dists)
            unit_weights, unit_degree = idw(unit_dists)

            # need to get node features for topk units and self unit
            nn_features_num = unit_locs.unsqueeze(0).repeat(batch_size,1,1).gather(dim=1,index=unit_idx.unsqueeze(-1).repeat(1,1,4)) # extract first 4 columns
            nn_features_cat = unit_types.unsqueeze(0).repeat(batch_size,1,1).gather(dim=1,index=unit_idx.unsqueeze(-1).repeat(1,1,3)) # extract first 3 columns

            # get self features - this is wrong for testing (unit locs dont contain test samples)
            self_features_num = data['num_feat'][...,:3].unsqueeze(1)
            self_features_cat = data['cat_feat'][...,:3].unsqueeze(1)
            
            # self_index = data['index'].unsqueeze(-1)
            # self_features_num = unit_locs.unsqueeze(0).repeat(batch_size,1,1).gather(dim=1,index=self_index.unsqueeze(-1).repeat(1,1,4)) # extract first 4 columns
            # self_features_cat = unit_types.unsqueeze(0).repeat(batch_size,1,1).gather(dim=1,index=self_index.unsqueeze(-1).repeat(1,1,3)) # extract first 3 columns
            # self_features_num = self_features_num[...,:3] # drop psm feature
            # print(self_features_num.shape)
            # print(self_features_cat.shape)
            # assert False

            self_features_num = self_features_num.to(device)
            self_features_cat = self_features_cat.to(device)
            nn_features_num = nn_features_num.to(device)
            nn_features_cat = nn_features_cat.to(device)
            nn_weights = unit_weights.to(device)
            mrt_idx = mrt_idx.to(device)
            mrt_weights = mrt_weights.to(device)
            mall_idx = mall_idx.to(device)
            mall_weights = mall_weights.to(device)

            pred = model(self_features_num,
                self_features_cat,
                nn_features_num,
                nn_features_cat,
                nn_weights,
                mrt_idx,
                mrt_weights,
                mall_idx,
                mall_weights)
            
            preds_list.append(pred)
            gt_list.append(data['resale_price'].to(device))

        preds = torch.cat(preds_list)
        gt = torch.cat(gt_list)
        preds_np = dataset.price_scaler.inverse_transform(preds.detach().cpu().numpy().reshape(-1,1)) 
        gt_np = dataset.price_scaler.inverse_transform(gt.detach().cpu().numpy().reshape(-1,1)) 
        errors_np = np.abs(preds_np-gt_np)
        print(f"price_error: {np.mean(errors_np)}")


    train = True
    if train:
        pbar = tqdm(train_dataloader)
        for idx, data in enumerate(pbar):

            optimizer.zero_grad()

            # get nearest mrt and malls
            mrt_idx, mrt_dists = DGCNN.knn(data['num_feat'][...,-2:],mrt_locs[...,-2:], k=4)
            mall_idx, mall_dists = DGCNN.knn(data['num_feat'][...,-2:],mall_locs[...,-2:], k=4)

            # restrict unit search to same housing type & town, create mask
            batch_size = data['cat_feat'].shape[0]
            unit_type_mask = unit_types.unsqueeze(0).repeat(batch_size,1,1)==data['cat_feat'].unsqueeze(1)
            unit_type_mask = torch.all(unit_type_mask,dim=-1)
            unit_idx, unit_dists = DGCNN.knn(data['num_feat'][...,-2:],unit_locs[...,-2:], k=k+1, mask=unit_type_mask)

            # remove self from topk retrievals
            batch_idx_list = []
            batch_dists_list = []
            for i in range(batch_size):
                if data['index'][i] in unit_idx[i]: # remove self
                    mask = unit_idx[i] != data['index'][i]
                    temp_idx = unit_idx[i][mask]
                    temp_dists = unit_dists[i][mask]

                else: # select first k
                    temp_idx = unit_idx[i][:-1]
                    temp_dists = unit_dists[i][:-1]
                batch_idx_list.append(temp_idx)
                batch_dists_list.append(temp_dists)
            unit_idx = torch.stack(batch_idx_list,dim=0)
            unit_dists = torch.stack(batch_dists_list,dim=0)


            # compute inverse distance weights
            mrt_weights, mrt_degree = idw(mrt_dists)
            mall_weights, mall_degree = idw(mall_dists)
            unit_weights, unit_degree = idw(unit_dists)

            # need to get node features for topk units and self unit
            nn_features_num = unit_locs.unsqueeze(0).repeat(batch_size,1,1).gather(dim=1,index=unit_idx.unsqueeze(-1).repeat(1,1,4)) # extract first 4 columns
            nn_features_cat = unit_types.unsqueeze(0).repeat(batch_size,1,1).gather(dim=1,index=unit_idx.unsqueeze(-1).repeat(1,1,3)) # extract first 3 columns

            # get self features
            self_index = data['index'].unsqueeze(-1)
            self_features_num = unit_locs.unsqueeze(0).repeat(batch_size,1,1).gather(dim=1,index=self_index.unsqueeze(-1).repeat(1,1,4)) # extract first 4 columns
            self_features_cat = unit_types.unsqueeze(0).repeat(batch_size,1,1).gather(dim=1,index=self_index.unsqueeze(-1).repeat(1,1,3)) # extract first 3 columns
            self_features_num = self_features_num[...,:3] # drop psm feature

            # place tensors on device
            self_features_num = self_features_num.to(device)
            self_features_cat = self_features_cat.to(device)
            nn_features_num = nn_features_num.to(device)
            nn_features_cat = nn_features_cat.to(device)
            nn_weights = unit_weights.to(device)
            mrt_idx = mrt_idx.to(device)
            mrt_weights = mrt_weights.to(device)
            mall_idx = mall_idx.to(device)
            mall_weights = mall_weights.to(device)

            # predict
            pred = model(self_features_num,
                self_features_cat,
                nn_features_num,
                nn_features_cat,
                nn_weights,
                mrt_idx,
                mrt_weights,
                mall_idx,
                mall_weights)
            
            # backward
            loss = torch.nn.MSELoss()(pred,data['resale_price'].to(device))
            pbar.set_postfix(epoch=epoch_idx, loss=loss.item())
            loss.backward()
            optimizer.step()



using device: mps


  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 391477.9375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 46969.44921875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 45181.61328125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 44413.48046875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 41564.26953125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 45605.30859375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 42970.18359375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 40787.1328125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 39937.70703125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 39495.54296875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 40292.1796875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 39732.09765625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 39106.3359375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 39172.76953125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 39698.55859375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 39060.76953125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 38747.1015625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 40586.46484375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 40388.48046875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 38641.96875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 38469.375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 38107.99609375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 38229.80859375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 38350.22265625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37866.63671875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37828.09765625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 38295.6171875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 38196.11328125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37728.65625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 39341.890625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37864.3203125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 38047.55859375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37368.8984375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37863.453125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37351.578125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 38271.14453125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37247.72265625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 39126.87109375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37191.5546875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 38777.421875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37495.3125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36921.1484375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37573.2578125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 38947.1796875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 38025.984375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37080.98046875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37375.46875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36611.26171875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 38711.96484375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36838.81640625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37596.390625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37590.56640625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37086.81640625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36776.98828125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36613.22265625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37248.390625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36727.69140625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37271.046875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36731.48046875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37115.03515625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36628.171875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36849.32421875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36248.16015625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36647.546875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37107.70703125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36961.05078125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36627.23828125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36611.51953125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37291.30078125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37831.875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37005.875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36477.6796875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36485.859375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36486.83984375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36792.3046875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36802.75390625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36536.6875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36373.40234375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37866.02734375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36282.8125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36377.046875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36875.06640625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36498.54296875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36138.17578125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36451.51171875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36600.40234375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 35565.9375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36263.59765625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36374.3984375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 35913.19140625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36265.4375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36589.90625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36570.5703125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36876.96484375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36502.90625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37702.359375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36020.3671875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36321.05859375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 37214.11328125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36089.5703125


  0%|          | 0/1846 [00:00<?, ?it/s]

### Geospatial DGCNN Training & Evaluation
Geospatial DGCNN constructs a unique graph per example. To create the geospatial graph, we search using K nearest neighbours to find up to 8 nearest recent transactions in the same town with the same flat type and model. We also find up to 4 nearest MRTs and malls within a 4km distance from the query unit. Nodes represent the apartment/mrt/mall, normalized inverse distance is used as the edge weights. Thus, nearer nodes have greater influence in the aggregated edge features. We apply EdgeConv to the graph structure, computing edge features as a function of adjacent nodes. The query node features are updated based on inverse distance weighted aggregation. A decoder MLP decodes the updated query node feature into a regressed valuation for the query apartment. 

In [87]:
from torch.utils.data import Dataset, DataLoader
import torch

class TrainDataset(Dataset):
    def __init__(self,X_train_num, X_train_cat, Y_train,df_mrt,df_mall,psm_scaler):
        self.X_train_num = torch.Tensor(X_train_num.values) #(N,6)
        self.X_train_cat = torch.Tensor(X_train_cat.values) #(N,3)
        self.Y_train = torch.Tensor(Y_train.values) #(N,1)
        self.price_scaler = price_scaler
        self.mrt_data = torch.Tensor(df_mrt[['latitude','longitude']].values) #(M,2)
        self.mall_data = torch.Tensor(df_mall[['latitude','longitude']].values) #(K,2)
    
    def __len__(self):
        return len(self.X_train_num)
    
    def __getitem__(self,index):
        output = {'index': index,
                  'num_feat' : self.X_train_num[index], #(6,)
                  'cat_feat' : self.X_train_cat[index], #(3,)
                  'resale_price' : self.Y_train[index][0], #(1,)
        }
        return output


batch_size = 32
epochs = 200

dataset = TrainDataset(X_train_num, X_train_cat, Y_train, df_mrt, df_mall, price_scaler)
train_dataloader = DataLoader(dataset,batch_size=batch_size, shuffle=True,drop_last=True)

test_dataset = TrainDataset(X_test_num, X_test_cat, Y_test, df_mrt, df_mall, price_scaler)
test_dataloader = DataLoader(test_dataset,batch_size=batch_size, shuffle=False,drop_last=True)

# Define Model
import torch.nn as nn
class DGCNN(nn.Module):
    def __init__(self,dataset,k=8):
        super(DGCNN, self).__init__()
        self.k = k

        # cat embedders
        embed_dim = 8
        x_cat = dataset.X_train_cat
 
        num_towns = len(torch.unique(x_cat[:,0]))
        num_flat_type = len(torch.unique(x_cat[:,1]))
        num_flat_model = len(torch.unique(x_cat[:,2]))
        num_mrt = len(dataset.mrt_data[:,0])
        num_mall = len(dataset.mall_data[:,0])
  
        self.town_embedder = nn.Embedding(num_towns,embed_dim)
        self.flat_type_embedder = nn.Embedding(num_flat_type,embed_dim)
        self.flat_model_embedder = nn.Embedding(num_flat_model,embed_dim)
        self.mrt_embedder = nn.Embedding(num_mrt, 64)
        self.mall_embedder = nn.Embedding(num_mall, 64)

        # neighour numerical proj
        self.neighbour_num_proj = nn.Linear(4,24)

        # query numerical proj
        self.query_num_proj = nn.Linear(3,24)
        
        # neighbour mlp
        self.neighbour_mlp = nn.Sequential(
            nn.Linear(48, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        )

        # query mlp
        self.query_mlp = nn.Sequential(
            nn.Linear(48, 64),
            nn.ReLU(),
            # nn.Linear(64, 64),
            # nn.ReLU(),
        )  
        self.query_mlp2 = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            # nn.Linear(128, 128),
            # nn.ReLU(),
        )  

        # query and neighbour will create dim64 features from their own nodes
        # mrt and malls also create dim64 features from embedding

        self.unit_edgeconv = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
        )  

        self.mrt_edgeconv = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
        )  

        self.mall_edgeconv = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
        )

        self.decoder = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            # nn.Linear(256, 128),
            # nn.ReLU(),
            nn.Linear(128, 1),
            # nn.Sigmoid(), #(resale price)
        )  


    @staticmethod
    def knn(x,y,k=8,mask=None):
        # x has shape (N,2), y has shape (M,2)
        distances = torch.sum((x.unsqueeze(1) - y.unsqueeze(0)) ** 2, dim=-1)
        if mask is not None: 
            mask = torch.logical_not(mask) #invert mask
            distances[mask] += 1000.0

        _, indices = torch.topk(distances, k=k, dim=-1, largest=False,sorted=True)
        return indices, distances.gather(dim=-1, index=indices)

    def forward(self, self_num, self_cat, nn_num, nn_cat, nn_weights, mrt_idx, mrt_weights, mall_idx, mall_weights):

        # get embedding vectors for categorical variables
        self_cat_feats = self.cat_embedder(self_cat) #(B,1,24)
        nn_cat_feats = self.cat_embedder(nn_cat) #(B,N,24)

        # project numerical variables to 24 dim vectors
        self_num_feats = self.query_num_proj(self_num) #(B,1,24)
        nn_num_feats = self.neighbour_num_proj(nn_num) #(B,N,24)

        # stack cat and num features
        self_feats = torch.cat([self_num_feats, self_cat_feats],dim=-1) #(B,1,48)
        nn_feats = torch.cat([nn_num_feats, nn_cat_feats],dim=-1) #(B,N,48)

        # project self and nn feats with mlp
        self_feats = self.query_mlp(self_feats) #(B,1,64)
        nn_feats = self.neighbour_mlp(nn_feats) #(B,1,64)

        # get embedded vectors for mrt and mall
        mrt_feats = self.mrt_embedder(mrt_idx) #(B,4,64)
        mall_feats = self.mall_embedder(mall_idx) #(B,4,64)

        # create edge vectors (here we dont use feature diff unlike paper)
        query_nn_feats = torch.cat([self_feats.repeat(1,8,1),nn_feats],dim=-1) #(2,8,128)
        query_mrt_feats = torch.cat([self_feats.repeat(1,4,1),mrt_feats],dim=-1) #(2,4,128)
        query_mall_feats = torch.cat([self_feats.repeat(1,4,1),mall_feats],dim=-1) #(2,4,128)

        # run edgeconv on edge vectors
        unit_feats = self.unit_edgeconv(query_nn_feats) #(2,8,128)
        mrt_feats = self.mrt_edgeconv(query_mrt_feats) #(2,4,128)
        mall_feats = self.mall_edgeconv(query_mall_feats) #(2,4,128)

        # aggregate based on weights
        agg_unit_feats = torch.sum(nn_weights.unsqueeze(-1)*unit_feats,dim=1) #(2,128)
        agg_mrt_feats = torch.sum(mrt_weights.unsqueeze(-1)*mrt_feats,dim=1) #(2,128)
        agg_mall_feats = torch.sum(mall_weights.unsqueeze(-1)*mall_feats,dim=1) #(2,128)

        # project self query to 128 dim
        self_feats = self.query_mlp2(self_feats).squeeze(1) #(2,128)

        # combine all features
        out_feats = self_feats + agg_unit_feats + agg_mrt_feats + agg_mall_feats #(2,128)

        # regress psm
        pred = self.decoder(out_feats).squeeze(-1)

        return pred

    def cat_embedder(self,cat_feat):
        town_feat = self.town_embedder(cat_feat[...,0].to(dtype=torch.int64))
        flat_type_feat = self.flat_type_embedder(cat_feat[...,1].to(dtype=torch.int64))
        flat_model_feat = self.flat_model_embedder(cat_feat[...,2].to(dtype=torch.int64))
        return torch.cat([town_feat, flat_type_feat, flat_model_feat],dim=-1) #(B,N,24)

k = 8 
mrt_locs = dataset.mrt_data
mall_locs = dataset.mall_data
unit_locs = dataset.X_train_num
unit_types = dataset.X_train_cat

if torch.cuda.is_available():
    device='cuda'
else:
    device='mps'
print(f"using device: {device}")

model = DGCNN(dataset).to(device)
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)

def idw(dists,thres=4.0):
    dist_mask = dists>thres
    dist_weights = 1/(dists+1)
    dist_weights[dist_mask] = 0.0
    num_neighbours = torch.logical_not(dist_mask).sum(dim=-1).unsqueeze(-1)+1 #(avoid division by 0)
    dist_weights /= num_neighbours# normalize weights by dist_mask
    return dist_weights, num_neighbours-1

for epoch_idx in range(epochs):

    # evaluation
    eval = True
    with torch.no_grad():
        error_list = []
        preds_list = []
        gt_list = []
        pbar = tqdm(test_dataloader)
        for idx, data in enumerate(pbar):
            # get nearest mrt and malls
            mrt_idx, mrt_dists = DGCNN.knn(data['num_feat'][...,-2:],mrt_locs[...,-2:], k=4)
            mall_idx, mall_dists = DGCNN.knn(data['num_feat'][...,-2:],mall_locs[...,-2:], k=4)

            # restrict unit search to same housing type & town, create mask
            batch_size = data['cat_feat'].shape[0]
            unit_type_mask = unit_types.unsqueeze(0).repeat(batch_size,1,1)==data['cat_feat'].unsqueeze(1)
            unit_type_mask = torch.all(unit_type_mask,dim=-1)
            unit_idx, unit_dists = DGCNN.knn(data['num_feat'][...,-2:],unit_locs[...,-2:], k=k+1, mask=unit_type_mask)

            # remove self from topk retrievals
            batch_idx_list = []
            batch_dists_list = []
            for i in range(batch_size):
                if data['index'][i] in unit_idx[i]: # remove self
                    mask = unit_idx[i] != data['index'][i]
                    temp_idx = unit_idx[i][mask]
                    temp_dists = unit_dists[i][mask]

                else: # select first k
                    temp_idx = unit_idx[i][:-1]
                    temp_dists = unit_dists[i][:-1]
                batch_idx_list.append(temp_idx)
                batch_dists_list.append(temp_dists)
            unit_idx = torch.stack(batch_idx_list,dim=0)
            unit_dists = torch.stack(batch_dists_list,dim=0)

            # compute inverse distance weights
            mrt_weights, mrt_degree = idw(mrt_dists)
            mall_weights, mall_degree = idw(mall_dists)
            unit_weights, unit_degree = idw(unit_dists)

            # need to get node features for topk units and self unit
            nn_features_num = unit_locs.unsqueeze(0).repeat(batch_size,1,1).gather(dim=1,index=unit_idx.unsqueeze(-1).repeat(1,1,4)) # extract first 4 columns
            nn_features_cat = unit_types.unsqueeze(0).repeat(batch_size,1,1).gather(dim=1,index=unit_idx.unsqueeze(-1).repeat(1,1,3)) # extract first 3 columns

            # get self features - this is wrong for testing (unit locs dont contain test samples)
            self_features_num = data['num_feat'][...,:3].unsqueeze(1)
            self_features_cat = data['cat_feat'][...,:3].unsqueeze(1)
            
            # self_index = data['index'].unsqueeze(-1)
            # self_features_num = unit_locs.unsqueeze(0).repeat(batch_size,1,1).gather(dim=1,index=self_index.unsqueeze(-1).repeat(1,1,4)) # extract first 4 columns
            # self_features_cat = unit_types.unsqueeze(0).repeat(batch_size,1,1).gather(dim=1,index=self_index.unsqueeze(-1).repeat(1,1,3)) # extract first 3 columns
            # self_features_num = self_features_num[...,:3] # drop psm feature
            # print(self_features_num.shape)
            # print(self_features_cat.shape)
            # assert False

            self_features_num = self_features_num.to(device)
            self_features_cat = self_features_cat.to(device)
            nn_features_num = nn_features_num.to(device)
            nn_features_cat = nn_features_cat.to(device)
            nn_weights = unit_weights.to(device)
            mrt_idx = mrt_idx.to(device)
            mrt_weights = mrt_weights.to(device)
            mall_idx = mall_idx.to(device)
            mall_weights = mall_weights.to(device)

            pred = model(self_features_num,
                self_features_cat,
                nn_features_num,
                nn_features_cat,
                nn_weights,
                mrt_idx,
                mrt_weights,
                mall_idx,
                mall_weights)
            
            preds_list.append(pred)
            gt_list.append(data['resale_price'].to(device))

        preds = torch.cat(preds_list)
        gt = torch.cat(gt_list)
        preds_np = dataset.price_scaler.inverse_transform(preds.detach().cpu().numpy().reshape(-1,1)) 
        gt_np = dataset.price_scaler.inverse_transform(gt.detach().cpu().numpy().reshape(-1,1)) 
        errors_np = np.abs(preds_np-gt_np)
        print(f"price_error: {np.mean(errors_np)}")


    train = True
    if train:
        pbar = tqdm(train_dataloader)
        for idx, data in enumerate(pbar):

            optimizer.zero_grad()

            # get nearest mrt and malls
            mrt_idx, mrt_dists = DGCNN.knn(data['num_feat'][...,-2:],mrt_locs[...,-2:], k=4)
            mall_idx, mall_dists = DGCNN.knn(data['num_feat'][...,-2:],mall_locs[...,-2:], k=4)

            # restrict unit search to same housing type & town, create mask
            batch_size = data['cat_feat'].shape[0]
            unit_type_mask = unit_types.unsqueeze(0).repeat(batch_size,1,1)==data['cat_feat'].unsqueeze(1)
            unit_type_mask = torch.all(unit_type_mask,dim=-1)
            unit_idx, unit_dists = DGCNN.knn(data['num_feat'][...,-2:],unit_locs[...,-2:], k=k+1, mask=unit_type_mask)

            # remove self from topk retrievals
            batch_idx_list = []
            batch_dists_list = []
            for i in range(batch_size):
                if data['index'][i] in unit_idx[i]: # remove self
                    mask = unit_idx[i] != data['index'][i]
                    temp_idx = unit_idx[i][mask]
                    temp_dists = unit_dists[i][mask]

                else: # select first k
                    temp_idx = unit_idx[i][:-1]
                    temp_dists = unit_dists[i][:-1]
                batch_idx_list.append(temp_idx)
                batch_dists_list.append(temp_dists)
            unit_idx = torch.stack(batch_idx_list,dim=0)
            unit_dists = torch.stack(batch_dists_list,dim=0)


            # compute inverse distance weights
            mrt_weights, mrt_degree = idw(mrt_dists)
            mall_weights, mall_degree = idw(mall_dists)
            unit_weights, unit_degree = idw(unit_dists)

            # need to get node features for topk units and self unit
            nn_features_num = unit_locs.unsqueeze(0).repeat(batch_size,1,1).gather(dim=1,index=unit_idx.unsqueeze(-1).repeat(1,1,4)) # extract first 4 columns
            nn_features_cat = unit_types.unsqueeze(0).repeat(batch_size,1,1).gather(dim=1,index=unit_idx.unsqueeze(-1).repeat(1,1,3)) # extract first 3 columns

            # get self features
            self_index = data['index'].unsqueeze(-1)
            self_features_num = unit_locs.unsqueeze(0).repeat(batch_size,1,1).gather(dim=1,index=self_index.unsqueeze(-1).repeat(1,1,4)) # extract first 4 columns
            self_features_cat = unit_types.unsqueeze(0).repeat(batch_size,1,1).gather(dim=1,index=self_index.unsqueeze(-1).repeat(1,1,3)) # extract first 3 columns
            self_features_num = self_features_num[...,:3] # drop psm feature

            # place tensors on device
            self_features_num = self_features_num.to(device)
            self_features_cat = self_features_cat.to(device)
            nn_features_num = nn_features_num.to(device)
            nn_features_cat = nn_features_cat.to(device)
            nn_weights = unit_weights.to(device)
            mrt_idx = mrt_idx.to(device)
            mrt_weights = mrt_weights.to(device)
            mall_idx = mall_idx.to(device)
            mall_weights = mall_weights.to(device)

            # predict
            pred = model(self_features_num,
                self_features_cat,
                nn_features_num,
                nn_features_cat,
                nn_weights,
                mrt_idx,
                mrt_weights,
                mall_idx,
                mall_weights)
            
            # backward
            loss = torch.nn.MSELoss()(pred,data['resale_price'].to(device))
            pbar.set_postfix(epoch=epoch_idx, loss=loss.item())
            loss.backward()
            optimizer.step()



using device: mps


  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 338457.90625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 38539.61328125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36220.78515625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 35775.23828125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 35073.3515625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 34753.953125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 33748.109375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36972.26171875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 33174.48828125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 34857.38671875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 33045.24609375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 36713.140625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 32655.314453125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 33448.2265625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 32304.0


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 33574.8125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 33299.01171875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 33111.71484375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 32777.0859375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 32781.16015625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 32051.2109375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 32169.025390625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 32112.150390625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 32001.888671875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 31436.80859375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 31373.578125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 31480.474609375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 31813.453125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 32035.64453125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 31345.103515625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 30278.169921875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 30451.263671875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 29077.2421875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 29491.162109375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 30280.087890625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 29243.876953125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 28686.83984375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 27107.43359375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 28074.4765625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 28221.267578125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25866.59375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 28316.150390625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25855.275390625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25194.8046875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 26883.65234375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 27209.041015625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 27619.427734375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 26093.91015625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25285.517578125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24692.91796875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 27362.662109375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25412.171875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 29232.759765625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25433.38671875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 28221.7265625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25653.65625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 26553.626953125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25037.96484375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 27200.30859375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24833.306640625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 28427.728515625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 29962.208984375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23862.66015625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 26212.5


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23931.17578125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24128.05859375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25781.220703125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 28231.5625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24543.490234375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24031.884765625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 27727.107421875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24455.154296875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 26730.994140625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25323.224609375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23875.396484375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25185.0078125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24201.53125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24684.515625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23559.021484375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25457.087890625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24668.765625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23433.314453125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24654.953125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24238.490234375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23479.236328125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 28396.078125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25680.80859375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24345.009765625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 29441.8046875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24856.171875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 26803.46484375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23831.6796875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23464.884765625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23692.990234375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24103.400390625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 27587.017578125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23383.267578125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24420.05078125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24216.984375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24119.865234375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23579.9921875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25217.5859375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 28031.443359375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24068.107421875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25482.83203125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25505.05078125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23819.912109375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24216.185546875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25107.177734375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23742.1796875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24730.072265625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23765.798828125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23900.072265625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 28211.220703125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25732.93359375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 27223.63671875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25546.3046875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24819.873046875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25230.623046875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25140.3359375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23795.046875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24447.6640625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24076.728515625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23988.421875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24671.453125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 27019.263671875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 31768.474609375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23748.623046875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 31637.79296875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24099.87890625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25071.384765625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23611.38671875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23901.638671875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25061.072265625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23362.736328125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24339.4921875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23986.5


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24002.013671875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24591.392578125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25955.515625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24634.4375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24428.779296875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24067.119140625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24916.1640625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24339.78125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25388.541015625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 26185.732421875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25073.935546875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24105.779296875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23438.072265625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23906.37890625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24009.046875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24453.533203125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23986.451171875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24319.40625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23453.283203125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23766.37109375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 26494.267578125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23527.376953125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 30670.90625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25088.03515625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24446.17578125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24103.041015625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25157.119140625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24640.90625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24025.2890625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24666.634765625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23768.34375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24305.921875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25515.314453125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25173.509765625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23359.634765625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 26216.05078125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23982.533203125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25124.8046875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 26249.869140625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23769.04296875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 29356.865234375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24809.986328125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 33750.984375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23398.31640625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24569.9453125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 26195.876953125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24185.201171875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25651.79296875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24826.8203125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25950.0390625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 26756.521484375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 27010.75


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 25688.521484375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24246.265625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 26615.2890625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24362.85546875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24939.92578125


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24276.32421875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24536.005859375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 28056.572265625


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 23862.35546875


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24685.55859375


  0%|          | 0/1846 [00:00<?, ?it/s]

  0%|          | 0/97 [00:00<?, ?it/s]

price_error: 24359.40234375


  0%|          | 0/1846 [00:00<?, ?it/s]