# HDB DGCNN Regression

Our objective is to compare the performance of an MLP price regressor against a DGCNN-based price regressor. We expect the DGCNN to work better because it takes into account geospatial relationships with 1) nearby recent sales, 2) nearby malls, and 3) nearby MRT stations. This should give us a more localised prior for inference: instead of requiring the MLP to learn every single feature, we extract relevant priors to make inference easier. We hope to see the DGCNN perform better than the MLP.

In [12]:
import pickle
import os
from google.colab import drive
import pandas as pd
from tqdm.notebook import tqdm

In [13]:
# Mount Google Drive so the pickled datasets can be read.
# The mount stays best-effort (the notebook keeps running if it fails), but the
# failure is reported instead of being silently swallowed, and only ordinary
# exceptions are caught so that KeyboardInterrupt / SystemExit still propagate.
try:
    drive.mount('/content/gdrive')
except Exception as mount_error:
    print(f"Could not mount Google Drive: {mount_error}")

# NOTE(review): this path is relative, so it only resolves to the mount point
# above when the working directory is /content — presumably the Colab default; confirm.
drive_path = 'gdrive/My Drive/Projects/HDBnet/'

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [14]:
# Load the pickled resale-transaction table.
# NOTE(security): pickle.load can execute arbitrary code; only load files you trust.
with open(drive_path+'housing_data.pickle', 'rb') as f:
    housing_data = pickle.load(f)

# Order rows by year (newest first) and, within a year, by town name (A-Z),
# then renumber the index so that positions and labels agree.
housing_data = (
    housing_data
    .sort_values(by=['year','town'], ascending=[False,True])
    .reset_index(drop=True)
)

# DataFrame.info() writes its report itself and returns None, so it is called
# directly; wrapping it in print() only adds a stray "None" line to the output.
housing_data.info()
display(housing_data.head())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 62214 entries, 0 to 62213
Data columns (total 14 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   town             62214 non-null  object 
 1   flat_type        62214 non-null  object 
 2   block            62214 non-null  object 
 3   street_name      62214 non-null  object 
 4   floor_area_sqm   62214 non-null  float64
 5   flat_model       62214 non-null  object 
 6   remaining_lease  62214 non-null  float64
 7   resale_price     62214 non-null  float64
 8   year             62214 non-null  int64  
 9   storey           62214 non-null  float64
 10  psm              62214 non-null  float64
 11  address          62214 non-null  object 
 12  latitude         62214 non-null  float64
 13  longitude        62214 non-null  float64
dtypes: float64(7), int64(1), object(6)
memory usage: 6.6+ MB
None


Unnamed: 0,town,flat_type,block,street_name,floor_area_sqm,flat_model,remaining_lease,resale_price,year,storey,psm,address,latitude,longitude
0,ANG MO KIO,2 ROOM,406,ANG MO KIO AVE 10,44.0,Improved,55.0,267000.0,2023,2.0,6068.181818,406 ANG MO KIO AVE 10,8.598078,0.314893
1,ANG MO KIO,2 ROOM,323,ANG MO KIO AVE 3,49.0,Improved,53.0,300000.0,2023,5.0,6122.44898,323 ANG MO KIO AVE 3,9.250907,-0.371289
2,ANG MO KIO,2 ROOM,314,ANG MO KIO AVE 3,44.0,Improved,54.0,280000.0,2023,4.0,6363.636364,314 ANG MO KIO AVE 3,9.064984,-0.10734
3,ANG MO KIO,2 ROOM,314,ANG MO KIO AVE 3,44.0,Improved,54.0,282000.0,2023,7.0,6409.090909,314 ANG MO KIO AVE 3,9.064984,-0.10734
4,ANG MO KIO,2 ROOM,170,ANG MO KIO AVE 4,45.0,Improved,62.0,289800.0,2023,2.0,6440.0,170 ANG MO KIO AVE 4,9.924554,-1.626898


In [15]:
# Load the MRT station table and preview it.
with open(drive_path + 'mrt_data.pickle', 'rb') as mrt_file:
    mrt_data = pickle.load(mrt_file)
display(mrt_data.head())

# Load the shopping-mall table and preview it.
with open(drive_path + 'mall_data.pickle', 'rb') as mall_file:
    mall_data = pickle.load(mall_file)
display(mall_data.head())

# LabelEncoder is not needed for the station / mall tables themselves, but the
# import is kept here because the housing-feature cell further down uses it.
from sklearn.preprocessing import LabelEncoder

# Work on copies whose coordinate columns are guaranteed to be floats; the
# originally loaded frames are left untouched.
coord_cols = ['latitude', 'longitude']

df_mrt = mrt_data.copy()
df_mrt[coord_cols] = mrt_data[coord_cols].astype('float')

df_mall = mall_data.copy()
df_mall[coord_cols] = mall_data[coord_cols].astype('float')

Unnamed: 0,address,latitude,longitude
0,Dhoby Ghaut,1.593235,-0.540454
1,Kranji,12.535859,-12.561699
2,Sengkang,11.880305,5.016745
3,Toa Payoh,5.831895,-0.836442
4,Tuas Crescent,0.8394,-24.598463


Unnamed: 0,address,latitude,longitude
0,Anchorpoint,0.518399,-5.057376
1,ERA APAC Centre,5.244562,-0.185457
2,The Star Vista,2.495502,-6.965425
3,100 AM,-1.05751,-0.841594
4,Leisure Park Kallang,1.987606,2.793656


In [16]:
# Split the housing table into label-encoded categorical features and
# floating-point numerical features.

df_temp = housing_data.copy()
df_temp = df_temp.drop(columns=['block','address','street_name'])

# One fitted LabelEncoder per categorical column, keyed by column name.
# Re-fitting a single shared encoder on every column keeps only the classes of
# the last column, which makes the integer codes of the other columns
# impossible to decode afterwards.
house_encoders = {}


def _encode_column(column):
    """Fit a dedicated LabelEncoder on `column`, remember it, and return the codes."""
    encoder = LabelEncoder()
    house_encoders[column.name] = encoder
    return encoder.fit_transform(column)


# extract categorical (object dtype) features and encode them as integer codes
df_cat = df_temp.select_dtypes(include=['object']).apply(_encode_column)
df_cat = df_cat.reset_index(drop=True)

# Backward-compatible alias: `house_encoder` still ends up fitted on the last
# categorical column, exactly as when one encoder was reused for all columns.
house_encoder = house_encoders[df_cat.columns[-1]] if len(df_cat.columns) else LabelEncoder()
display(df_cat.head())

# extract numerical features; only float64 columns are selected, so integer
# columns of the source table (e.g. `year`) are not carried forward
df_num = df_temp.select_dtypes(include=['float64']).reset_index(drop=True)
display(df_num.head())

Unnamed: 0,town,flat_type,flat_model
0,0,1,5
1,0,1,5
2,0,1,5
3,0,1,5
4,0,1,5


Unnamed: 0,floor_area_sqm,remaining_lease,resale_price,storey,psm,latitude,longitude
0,44.0,55.0,267000.0,2.0,6068.181818,8.598078,0.314893
1,49.0,53.0,300000.0,5.0,6122.44898,9.250907,-0.371289
2,44.0,54.0,280000.0,4.0,6363.636364,9.064984,-0.10734
3,44.0,54.0,282000.0,7.0,6409.090909,9.064984,-0.10734
4,45.0,62.0,289800.0,2.0,6440.0,9.924554,-1.626898


In [17]:
# Recombine numerical and encoded categorical features column-wise.
df_final = pd.concat([df_num, df_cat], axis=1)

# Targets: price per square metre, plus the floor area alongside it.
df_Y = df_final[['psm','floor_area_sqm']].copy()

# Features: everything except the absolute resale price. `psm` is kept here on
# purpose: the training loop reads it as a neighbour feature and slices it off
# the query unit's own features before calling the model.
df_X = df_final.drop(columns=['resale_price'])

display(df_Y.head())
display(df_X.head())

# DataFrame.info() prints its own report and returns None; calling it directly
# avoids the stray "None" that print(df_X.info()) appends to the output.
df_X.info()

Unnamed: 0,psm,floor_area_sqm
0,6068.181818,44.0
1,6122.44898,49.0
2,6363.636364,44.0
3,6409.090909,44.0
4,6440.0,45.0


Unnamed: 0,floor_area_sqm,remaining_lease,storey,psm,latitude,longitude,town,flat_type,flat_model
0,44.0,55.0,2.0,6068.181818,8.598078,0.314893,0,1,5
1,49.0,53.0,5.0,6122.44898,9.250907,-0.371289,0,1,5
2,44.0,54.0,4.0,6363.636364,9.064984,-0.10734,0,1,5
3,44.0,54.0,7.0,6409.090909,9.064984,-0.10734,0,1,5
4,45.0,62.0,2.0,6440.0,9.924554,-1.626898,0,1,5


<class 'pandas.core.frame.DataFrame'>
RangeIndex: 62214 entries, 0 to 62213
Data columns (total 9 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   floor_area_sqm   62214 non-null  float64
 1   remaining_lease  62214 non-null  float64
 2   storey           62214 non-null  float64
 3   psm              62214 non-null  float64
 4   latitude         62214 non-null  float64
 5   longitude        62214 non-null  float64
 6   town             62214 non-null  int64  
 7   flat_type        62214 non-null  int64  
 8   flat_model       62214 non-null  int64  
dtypes: float64(6), int64(3)
memory usage: 4.3 MB
None


Train Test Split

Scale numerical features in training set

In [18]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
import numpy as np

# Fixed seed so that Restart & Run All reproduces the same train/test split.
SPLIT_SEED = 42

X_train, X_test, Y_train, Y_test = train_test_split(
    df_X, df_Y, test_size=0.1, shuffle=True, random_state=SPLIT_SEED
)

# The target frames are modified below; take explicit copies so the writes land
# on independent frames (no chained-assignment warnings, df_Y stays untouched).
Y_train = Y_train.copy()
Y_test = Y_test.copy()

# train set: float64 columns are numerical features, int64 columns are the
# label-encoded categorical features
X_train_num = X_train.select_dtypes(include=['float64']).copy()
X_train_cat = X_train.select_dtypes(include=['int64'])

# test set
X_test_num = X_test.select_dtypes(include=['float64']).copy()
X_test_cat = X_test.select_dtypes(include=['int64'])

# Min-max scale every numerical feature except the coordinates. Each scaler is
# fitted on the training split only and then applied to both splits, so no
# statistics leak from the test split.
for feature in X_train_num.columns:
    if feature in ('longitude', 'latitude'):
        continue

    scaler = MinMaxScaler()
    scaler.fit(X_train_num[feature].to_numpy().reshape(-1, 1))
    # transform() returns an (n, 1) array; ravel() turns it back into a column
    X_train_num[feature] = scaler.transform(X_train_num[feature].to_numpy().reshape(-1, 1)).ravel()
    X_test_num[feature] = scaler.transform(X_test_num[feature].to_numpy().reshape(-1, 1)).ravel()

    # scale the psm target with the same scaler and keep that scaler so that
    # predictions / errors can be mapped back to price units later
    if feature == 'psm':
        Y_train[feature] = scaler.transform(Y_train[feature].to_numpy().reshape(-1, 1)).ravel()
        Y_test[feature] = scaler.transform(Y_test[feature].to_numpy().reshape(-1, 1)).ravel()
        psm_scaler = scaler

In [19]:
# Preview the first rows of every split, in the order:
# numerical train/test, categorical train/test, target train/test.
for split_frame in (X_train_num, X_test_num, X_train_cat, X_test_cat, Y_train, Y_test):
    display(split_frame.head())

Unnamed: 0,floor_area_sqm,remaining_lease,storey,psm,latitude,longitude
36494,0.169811,0.381818,0.122449,0.136808,7.376775,-10.552062
33284,0.287736,0.254545,0.142857,0.156461,9.424728,0.251445
40375,0.542453,0.563636,0.081633,0.134568,12.459706,-11.812988
46735,0.334906,0.618182,0.244898,0.29992,4.188519,1.885827
28337,0.358491,0.545455,0.0,0.165396,7.715138,12.302897


Unnamed: 0,floor_area_sqm,remaining_lease,storey,psm,latitude,longitude
55750,0.386792,0.436364,0.061224,0.269641,7.100092,2.54812
37613,0.268868,0.2,0.306122,0.23497,-0.342067,-2.470033
52372,0.292453,0.672727,0.204082,0.176478,18.39943,-3.650218
54107,0.292453,0.945455,0.122449,0.181832,12.576459,3.05021
108,0.481132,0.545455,0.244898,0.312948,10.361516,-1.402614


Unnamed: 0,town,flat_type,flat_model
36494,3,2,12
33284,0,3,12
40375,8,5,7
46735,14,3,8
28337,22,3,8


Unnamed: 0,town,flat_type,flat_model
55750,21,3,8
37613,4,3,5
52372,19,3,8
54107,20,3,8
108,0,4,8


Unnamed: 0,psm,floor_area_sqm
36494,0.136808,67.0
33284,0.156461,92.0
40375,0.134568,146.0
46735,0.29992,102.0
28337,0.165396,107.0


Unnamed: 0,psm,floor_area_sqm
55750,0.269641,113.0
37613,0.23497,88.0
52372,0.176478,93.0
54107,0.181832,93.0
108,0.312948,133.0


Create PyTorch Dataset

In [None]:
from torch.utils.data import Dataset, DataLoader
import torch

class TrainDataset(Dataset):
    """Map-style dataset that holds one split of the housing data as float32 tensors.

    Besides the per-unit features and targets, the dataset also carries the
    MRT / mall coordinates and the fitted psm scaler so that downstream code
    can reach them through the dataset object.
    """

    @staticmethod
    def _to_float_tensor(frame):
        """Copy the values of a DataFrame into a float32 tensor."""
        return torch.tensor(frame.values, dtype=torch.float32)

    def __init__(self,X_train_num, X_train_cat, Y_train,df_mrt,df_mall,psm_scaler):
        coord_cols = ['latitude','longitude']
        self.X_train_num = self._to_float_tensor(X_train_num)  # numerical features, one row per unit
        self.X_train_cat = self._to_float_tensor(X_train_cat)  # categorical codes (stored as floats)
        self.Y_train = self._to_float_tensor(Y_train)  # targets: column 0 -> 'psm', column 1 -> 'area'
        self.psm_scaler = psm_scaler
        self.mrt_data = self._to_float_tensor(df_mrt[coord_cols])  # (num_mrt, 2)
        self.mall_data = self._to_float_tensor(df_mall[coord_cols])  # (num_mall, 2)

    def __len__(self):
        """Number of units in this split."""
        return self.X_train_num.shape[0]

    def __getitem__(self,index):
        """Return features and targets of one unit, together with its position in the split."""
        targets = self.Y_train
        return dict(
            index=index,
            num_feat=self.X_train_num[index],
            cat_feat=self.X_train_cat[index],
            psm=targets[index,0],  # scalar
            area=targets[index,1],  # scalar
        )


# Training hyper-parameters
batch_size = 32
epochs = 10

# Training split: shuffled every epoch; the last incomplete batch is dropped.
dataset = TrainDataset(X_train_num, X_train_cat, Y_train, df_mrt, df_mall, psm_scaler)
train_dataloader = DataLoader(dataset,batch_size=batch_size, shuffle=True,drop_last=True)

# Test split: fixed order, and the final partial batch is kept so that the
# reported error covers every held-out sample rather than a multiple of
# batch_size of them.
test_dataset = TrainDataset(X_test_num, X_test_cat, Y_test, df_mrt, df_mall, psm_scaler)
test_dataloader = DataLoader(test_dataset,batch_size=batch_size, shuffle=False,drop_last=False)

# Define Model
import torch.nn as nn
class DGCNN(nn.Module):
    """Edge-convolution style regressor for the (scaled) price per square metre.

    A query unit is combined with features of its nearest neighbouring units
    (and, optionally, nearby MRT stations / malls). For every (query, neighbour)
    pair an edge feature is built by concatenating both feature vectors, passed
    through an MLP, and the per-edge outputs are summed using the supplied
    weights. The aggregated vector is added to a projection of the query's own
    features and decoded to a single value in (0, 1) by a sigmoid head.

    Args:
        dataset: object exposing `X_train_cat` (categorical codes, one column
            each for town, flat type and flat model), `mrt_data` and
            `mall_data` (one row per station / mall). Only used to size the
            embedding tables.
        k: stored on the instance; the number of neighbours actually used is
            taken from the tensors passed to `forward`.
    """

    def __init__(self,dataset,k=8):
        super().__init__()
        self.k = k

        # cat embedders
        embed_dim = 8
        x_cat = dataset.X_train_cat

        # Size each table by (largest code + 1) instead of the number of
        # distinct codes: codes index the table directly, so a code that is
        # missing from this split must not shrink the table below the largest
        # code that is present.
        num_towns = int(x_cat[:,0].max().item()) + 1
        num_flat_type = int(x_cat[:,1].max().item()) + 1
        num_flat_model = int(x_cat[:,2].max().item()) + 1
        num_mrt = dataset.mrt_data.shape[0]
        num_mall = dataset.mall_data.shape[0]

        self.town_embedder = nn.Embedding(num_towns,embed_dim)
        self.flat_type_embedder = nn.Embedding(num_flat_type,embed_dim)
        self.flat_model_embedder = nn.Embedding(num_flat_model,embed_dim)
        self.mrt_embedder = nn.Embedding(num_mrt, 64)
        self.mall_embedder = nn.Embedding(num_mall, 64)

        # neighbour numerical proj (4 numerical inputs per neighbour)
        self.neighbour_num_proj = nn.Linear(4,24)

        # query numerical proj (3 numerical inputs for the query unit)
        self.query_num_proj = nn.Linear(3,24)

        # neighbour mlp: 24 numerical + 24 categorical dims -> 64
        self.neighbour_mlp = nn.Sequential(
            nn.Linear(48, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        )

        # query mlp: 24 numerical + 24 categorical dims -> 64
        self.query_mlp = nn.Sequential(
            nn.Linear(48, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        )
        # lifts the 64-dim query feature to the 128-dim aggregation space
        self.query_mlp2 = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
        )

        # query and neighbour will create dim64 features from their own nodes
        # mrt and malls also create dim64 features from embedding
        # each edgeconv consumes a 128-dim edge = [query 64 | other node 64]

        self.unit_edgeconv = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
        )

        self.mrt_edgeconv = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
        )

        self.mall_edgeconv = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
        )

        self.decoder = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid(), #(psm price)
        )


    @staticmethod
    def knn(x,y,k=8,mask=None):
        """Find, for every row of `x`, the `k` closest rows of `y`.

        Args:
            x: query points, shape (N, D).
            y: candidate points, shape (M, D).
            k: number of candidates to return per query.
            mask: optional boolean tensor of shape (N, M); True marks a
                candidate that is allowed for that query. Disallowed candidates
                are not removed, their distance is penalised by +1000 so that
                they sort last.

        Returns:
            (indices, distances), both of shape (N, k) and ordered from nearest
            to farthest. The distances are SQUARED Euclidean distances and
            include the +1000 penalty for disallowed candidates.
        """
        distances = torch.sum((x.unsqueeze(1) - y.unsqueeze(0)) ** 2, dim=-1)
        if mask is not None:
            mask = torch.logical_not(mask) #invert mask
            distances[mask] += 1000.0

        _, indices = torch.topk(distances, k=k, dim=-1, largest=False,sorted=True)
        return indices, distances.gather(dim=-1, index=indices)

    def forward(self, self_num, self_cat, nn_num, nn_cat, nn_weights, mrt_idx, mrt_weights, mall_idx, mall_weights):
        """Predict the scaled psm of each query unit.

        With B = batch size, N = neighbouring units per query and
        M / K = MRT stations / malls per query:

        Args:
            self_num: (B, 1, 3) numerical features of the query unit.
            self_cat: (B, 1, 3) categorical codes of the query unit.
            nn_num: (B, N, 4) numerical features of the neighbouring units.
            nn_cat: (B, N, 3) categorical codes of the neighbouring units.
            nn_weights: (B, N) aggregation weights of the neighbouring units.
            mrt_idx: (B, M) integer row indices into the MRT embedding table.
            mrt_weights: (B, M) aggregation weights of the MRT stations.
            mall_idx: (B, K) integer row indices into the mall embedding table.
            mall_weights: (B, K) aggregation weights of the malls.

        Returns:
            Tensor of shape (B,) with values in (0, 1).
        """

        # get embedding vectors for categorical variables
        self_cat_feats = self.cat_embedder(self_cat) #(B,1,24)
        nn_cat_feats = self.cat_embedder(nn_cat) #(B,N,24)

        # project numerical variables to 24 dim vectors
        self_num_feats = self.query_num_proj(self_num) #(B,1,24)
        nn_num_feats = self.neighbour_num_proj(nn_num) #(B,N,24)

        # stack cat and num features
        self_feats = torch.cat([self_num_feats, self_cat_feats],dim=-1) #(B,1,48)
        nn_feats = torch.cat([nn_num_feats, nn_cat_feats],dim=-1) #(B,N,48)

        # project self and nn feats with mlp
        self_feats = self.query_mlp(self_feats) #(B,1,64)
        nn_feats = self.neighbour_mlp(nn_feats) #(B,N,64)

        # get embedded vectors for mrt and mall
        mrt_feats = self.mrt_embedder(mrt_idx) #(B,M,64)
        mall_feats = self.mall_embedder(mall_idx) #(B,K,64)

        # create edge vectors (here we dont use feature diff unlike paper)
        # The query feature is broadcast to however many neighbours were passed
        # in, rather than to a hard-coded 8 / 4, so any neighbour count works.
        query_nn_feats = torch.cat([self_feats.expand(-1,nn_feats.shape[1],-1),nn_feats],dim=-1) #(B,N,128)
        query_mrt_feats = torch.cat([self_feats.expand(-1,mrt_feats.shape[1],-1),mrt_feats],dim=-1) #(B,M,128)
        query_mall_feats = torch.cat([self_feats.expand(-1,mall_feats.shape[1],-1),mall_feats],dim=-1) #(B,K,128)

        # run edgeconv on edge vectors
        unit_feats = self.unit_edgeconv(query_nn_feats) #(B,N,128)
        mrt_feats = self.mrt_edgeconv(query_mrt_feats) #(B,M,128)
        mall_feats = self.mall_edgeconv(query_mall_feats) #(B,K,128)

        # aggregate based on weights
        agg_unit_feats = torch.sum(nn_weights.unsqueeze(-1)*unit_feats,dim=1) #(B,128)
        agg_mrt_feats = torch.sum(mrt_weights.unsqueeze(-1)*mrt_feats,dim=1) #(B,128)
        agg_mall_feats = torch.sum(mall_weights.unsqueeze(-1)*mall_feats,dim=1) #(B,128)

        # project self query to 128 dim
        self_feats = self.query_mlp2(self_feats).squeeze(1) #(B,128)

        # combine all features
        # NOTE: the MRT and mall aggregates are computed above but are currently
        # left out of the sum, so they do not influence the prediction.
        out_feats = self_feats + agg_unit_feats # + agg_mrt_feats + agg_mall_feats #(B,128)

        # regress psm
        pred = self.decoder(out_feats).squeeze(-1) #(B,)

        return pred

    def cat_embedder(self,cat_feat):
        """Embed the three categorical codes and concatenate the embeddings.

        Args:
            cat_feat: (..., 3) tensor holding town, flat type and flat model
                codes in its last dimension; values are cast to int64 before
                the lookup.

        Returns:
            (..., 24) tensor: three 8-dim embeddings side by side.
        """
        town_feat = self.town_embedder(cat_feat[...,0].to(dtype=torch.int64))
        flat_type_feat = self.flat_type_embedder(cat_feat[...,1].to(dtype=torch.int64))
        flat_model_feat = self.flat_model_embedder(cat_feat[...,2].to(dtype=torch.int64))
        return torch.cat([town_feat, flat_type_feat, flat_model_feat],dim=-1) #(...,24)

# Number of neighbouring units retrieved per query.
k = 8

# Reference tensors searched by the nearest-neighbour lookups; all of them come
# from the training dataset.
mrt_locs, mall_locs = dataset.mrt_data, dataset.mall_data
unit_locs, unit_types = dataset.X_train_num, dataset.X_train_cat

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"using device: {device}")

# Build the network on the selected device and attach an Adam optimiser.
model = DGCNN(dataset)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def idw(dists, thres=4.0):
    """Turn neighbour distances into aggregation weights.

    Defined at module level, before the epoch loop, so that both evaluation and
    training can call it from the very first epoch.

    Args:
        dists: (B, n) distances as returned by `DGCNN.knn`.
        thres: neighbours whose distance exceeds this value get weight 0.

    Returns:
        (weights, degree): `weights` has shape (B, n) and equals 1 / (dist + 1)
        for neighbours within the threshold, divided by (number of such
        neighbours + 1); the +1 avoids a division by zero. `degree` has shape
        (B, 1) and is the number of neighbours within the threshold.
    """
    dist_mask = dists > thres
    dist_weights = 1 / (dists + 1)
    dist_weights[dist_mask] = 0.0
    num_neighbours = torch.logical_not(dist_mask).sum(dim=-1).unsqueeze(-1) + 1  # avoid division by 0
    dist_weights /= num_neighbours  # normalise by the number of neighbours kept
    return dist_weights, num_neighbours - 1


def drop_self_neighbour(unit_idx, unit_dists, self_idx):
    """Remove each query unit from its own (k+1)-neighbour list.

    If the query's index occurs in its row, that entry is removed; otherwise
    the farthest (last) entry is removed. Either way k entries remain per row,
    still ordered from nearest to farthest.

    Args:
        unit_idx: (B, k+1) neighbour indices.
        unit_dists: (B, k+1) matching distances.
        self_idx: (B,) index of each query unit in the searched set.

    Returns:
        (indices, distances), both of shape (B, k).
    """
    remove = unit_idx == self_idx.unsqueeze(-1)  # (B, k+1), at most one True per row
    rows_without_self = torch.logical_not(remove.any(dim=-1))
    remove[rows_without_self, -1] = True  # exactly one True per row from here on
    keep = torch.logical_not(remove)
    num_keep = unit_idx.shape[1] - 1
    return unit_idx[keep].view(-1, num_keep), unit_dists[keep].view(-1, num_keep)


def build_model_inputs(data, exclude_self):
    """Assemble the positional arguments of `model(...)` for one batch.

    The query features are taken from the batch itself, so the function is
    valid for batches of any split; only the neighbours are looked up in the
    training tensors (`unit_locs`, `unit_types`).

    Args:
        data: batch dict produced by the DataLoader ('index', 'num_feat',
            'cat_feat', ...).
        exclude_self: True when the batch comes from the same set that is
            searched for neighbours (training), so that a unit is never its own
            neighbour. False for held-out batches, whose 'index' values refer
            to a different set and must not be matched against it.

    Returns:
        Tuple of tensors on `device`, in the argument order of `DGCNN.forward`.
    """
    query_num = data['num_feat']  # (B, 6); the last two columns are the coordinates
    query_cat = data['cat_feat']  # (B, 3)
    query_locs = query_num[..., -2:]

    # get nearest mrt and malls
    mrt_idx, mrt_dists = DGCNN.knn(query_locs, mrt_locs[..., -2:], k=4)
    mall_idx, mall_dists = DGCNN.knn(query_locs, mall_locs[..., -2:], k=4)

    # restrict the unit search to units sharing all three categorical codes
    # with the query; broadcasting gives the (B, N) mask without materialising
    # B copies of the training table
    unit_type_mask = torch.all(unit_types.unsqueeze(0) == query_cat.unsqueeze(1), dim=-1)

    if exclude_self:
        # fetch one extra neighbour, then drop the query unit itself
        unit_idx, unit_dists = DGCNN.knn(query_locs, unit_locs[..., -2:], k=k + 1, mask=unit_type_mask)
        unit_idx, unit_dists = drop_self_neighbour(unit_idx, unit_dists, data['index'])
    else:
        unit_idx, unit_dists = DGCNN.knn(query_locs, unit_locs[..., -2:], k=k, mask=unit_type_mask)

    # compute inverse distance weights
    mrt_weights, _ = idw(mrt_dists)
    mall_weights, _ = idw(mall_dists)
    unit_weights, _ = idw(unit_dists)

    # neighbour node features: first 4 numerical columns (psm included) and
    # all 3 categorical columns of the retrieved training units
    nn_features_num = unit_locs[unit_idx][..., :4]  # (B, k, 4)
    nn_features_cat = unit_types[unit_idx]  # (B, k, 3)

    # query node features: first 3 numerical columns only, i.e. without psm,
    # which is the quantity being predicted
    self_features_num = query_num[:, None, :3]  # (B, 1, 3)
    self_features_cat = query_cat[:, None, :]  # (B, 1, 3)

    # place tensors on device
    return tuple(
        tensor.to(device)
        for tensor in (
            self_features_num,
            self_features_cat,
            nn_features_num,
            nn_features_cat,
            unit_weights,
            mrt_idx,
            mrt_weights,
            mall_idx,
            mall_weights,
        )
    )


def evaluate(loader):
    """Return the mean absolute psm error over `loader`, in original psm units."""
    model.eval()
    error_list = []
    with torch.no_grad():  # no graph needed; keeps memory flat across batches
        for data in tqdm(loader):
            pred = model(*build_model_inputs(data, exclude_self=False))
            error_list.append(torch.abs(pred - data['psm'].to(device)))  # (B,)
    errors = torch.cat(error_list).cpu().numpy()  # (N,)
    # An absolute error is a difference of two scaled values, so only the scale
    # factor has to be undone; inverse_transform would also add the minimum of
    # the training data to every error.
    errors_np = errors * dataset.psm_scaler.data_range_[0]
    return float(np.mean(errors_np))


run_eval = True  # evaluate on the test split before every training epoch
run_train = True
loss_fn = torch.nn.MSELoss()

for epoch_idx in range(epochs):

    # evaluation
    if run_eval:
        print(f"error psm: {evaluate(test_dataloader)}")

    if run_train:
        model.train()
        pbar = tqdm(train_dataloader)
        for data in pbar:
            optimizer.zero_grad()

            # predict
            pred = model(*build_model_inputs(data, exclude_self=True))

            # backward
            loss = loss_fn(pred, data['psm'].to(device))
            pbar.set_postfix(epoch=epoch_idx, loss=loss.item())
            loss.backward()
            optimizer.step()

# evaluation happens before training inside the loop, so measure once more to
# capture the effect of the final training epoch
if run_eval and run_train:
    print(f"error psm: {evaluate(test_dataloader)}")



using device: cuda


  0%|          | 0/194 [00:00<?, ?it/s]