In [1]:
import os
import sys
import glob
import yaml
import pickle
import random

import numpy as np
import pandas as pd
import scipy.stats as stats
from pathlib import Path

from sklearn import model_selection
from sklearn.preprocessing import StandardScaler, LabelEncoder

import wandb
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities.seed import seed_everything

## config

In [2]:
# config
with open('config_v2.yaml') as f:
    config = yaml.safe_load(f)

# globals variable
SEED = config['globals']['seed']
MAX_EPOCHS = config['globals']['max_epochs']
N_SPLITS = config['globals']['n_splits']
USE_FOLDS = config['globals']['use_folds']
DEBUG = config['globals']['debug']
EXP_MESSAGE = config['globals']['exp_message']
NOTES = config['globals']['notes']
MODEL_SAVE = config['globals']['model_save']
ONLY_PRED = config['globals']['only_pred']
PRETRAINED = config['globals']['pretrained']
PRETRAINED_PATH = config['globals']['pretrained_path']
EXP_NAME = str(Path().resolve()).split('/')[-1]

# seed
seed_everything(SEED)

Global seed set to 1996


1996

In [3]:
EXP_NAME

'exp007'

In [4]:
!wandb login 1bb2d0449c11d8b987e25c38b9d8dda176310fb6

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/kuzira/.netrc


## read data

In [5]:
# waypointを補正したdataset
root_dir = Path('../../input/')
with open(root_dir/'kuto_wifi_dataset_v2/train_all.pkl', 'rb') as f:
  train_df = pickle.load(f)

with open(root_dir/'kuto_wifi_dataset_v2/test_all.pkl', 'rb') as f:
  test_df = pickle.load(f)

sub = pd.read_csv(root_dir/'indoor-location-navigation/sample_submission.csv', index_col=0)

In [6]:
train_df

Unnamed: 0,bssid_0,bssid_1,bssid_2,bssid_3,bssid_4,bssid_5,bssid_6,bssid_7,bssid_8,bssid_9,...,timestamp,x,y,floor,floor_str,path,time_diff,wifi_x,wifi_y,site_id
0,ffe684dfd25a52b046e3108a3f70df46001425f0,97e4a381c3a02ed3151bbf41b8fc1fe5815f5387,5b71ef95e53358c558b78bf3fb152d793729bc8d,2f85d197aec7bfddfee3f53ae9e1b6ed1fc56e92,4328f33869766d0f77a9299441556338e4d8a2b9,df41c761b69993669d4eb875b4474ec44d2372ed,7dc49736770ee9073043134656c89a17529f882f,5a1a7a8496e5f8b88db082de0b412e447e01fd0b,463d0cfe3748eb70524138ed970f03375e8d1030,79179095e63e2b0431e85e3e33b02d95bb135c2e,...,1571625311855,68.064926,241.94000,0,F1,5dad1ca1dc3e2c0006606c3f,1952,66.823935,241.889369,5da958dd46f8266d0737457b
1,97e4a381c3a02ed3151bbf41b8fc1fe5815f5387,5b71ef95e53358c558b78bf3fb152d793729bc8d,cb8f53745c342e2bfd0bf77a5fd8cac6cf303945,ffe684dfd25a52b046e3108a3f70df46001425f0,6b769b9eeb24ff287e6a53736cc7c013d5902901,7dc49736770ee9073043134656c89a17529f882f,3fef087dd272ab07981a60c9cbf6f27460d1364e,5a1a7a8496e5f8b88db082de0b412e447e01fd0b,df41c761b69993669d4eb875b4474ec44d2372ed,d8b1ff62702e02106553be91dc22a0dcf0e780a7,...,1571625311855,68.064926,241.94000,0,F1,5dad1ca1dc3e2c0006606c3f,3900,65.582943,241.838738,5da958dd46f8266d0737457b
2,97e4a381c3a02ed3151bbf41b8fc1fe5815f5387,6b769b9eeb24ff287e6a53736cc7c013d5902901,cb8f53745c342e2bfd0bf77a5fd8cac6cf303945,4328f33869766d0f77a9299441556338e4d8a2b9,3fef087dd272ab07981a60c9cbf6f27460d1364e,5a1a7a8496e5f8b88db082de0b412e447e01fd0b,7dc49736770ee9073043134656c89a17529f882f,df41c761b69993669d4eb875b4474ec44d2372ed,5b71ef95e53358c558b78bf3fb152d793729bc8d,2f85d197aec7bfddfee3f53ae9e1b6ed1fc56e92,...,1571625320099,62.480465,241.71216,0,F1,5dad1ca1dc3e2c0006606c3f,-2385,64.341952,241.788107,5da958dd46f8266d0737457b
3,97e4a381c3a02ed3151bbf41b8fc1fe5815f5387,4328f33869766d0f77a9299441556338e4d8a2b9,3fef087dd272ab07981a60c9cbf6f27460d1364e,5a1a7a8496e5f8b88db082de0b412e447e01fd0b,5d82171d37c5296bcaed8c02745540b491d8a284,471740ef5065943b791f277ada358f9ffc011645,2f85d197aec7bfddfee3f53ae9e1b6ed1fc56e92,afe423c7bc0641d63c95e232ffd65cae3be95351,df41c761b69993669d4eb875b4474ec44d2372ed,5b71ef95e53358c558b78bf3fb152d793729bc8d,...,1571625320099,62.480465,241.71216,0,F1,5dad1ca1dc3e2c0006606c3f,-427,62.480465,241.712160,5da958dd46f8266d0737457b
4,97e4a381c3a02ed3151bbf41b8fc1fe5815f5387,4328f33869766d0f77a9299441556338e4d8a2b9,5a1a7a8496e5f8b88db082de0b412e447e01fd0b,cb8f53745c342e2bfd0bf77a5fd8cac6cf303945,6b769b9eeb24ff287e6a53736cc7c013d5902901,5d82171d37c5296bcaed8c02745540b491d8a284,5b71ef95e53358c558b78bf3fb152d793729bc8d,df41c761b69993669d4eb875b4474ec44d2372ed,2f85d197aec7bfddfee3f53ae9e1b6ed1fc56e92,afe423c7bc0641d63c95e232ffd65cae3be95351,...,1571625320099,62.480465,241.71216,0,F1,5dad1ca1dc3e2c0006606c3f,1528,62.893219,240.715162,5da958dd46f8266d0737457b
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
258120,993a56b32432fb19bfb4461a0e1a2ead9bcf192f,773cde25cb4e9fd90b11603abd5bf84d83b340e6,19647ac7bb55a673554aa08cafb3b096aac7f32c,a4a410696cb935d542d62afd8e8090dbbc341a16,fc27c0656fc13157bb2f58543d51e8ee972fdf66,c7d8359344120911f8550487f282c241d93c4750,26b22cce3b7694a7d765d9cf329b9065f3fb3a3c,827530050f580378b7aa53fb292dfb8a12b775e1,d64eeb8d997e8d87203479556bbb9efaf7e487fd,22a52f1717436ee378dc44b6d707a3816a65b5e4,...,1573822164854,12.662716,100.47756,1,F2,5dce9eea5516ad00065f04a7,447,12.662716,100.477560,5d27099f03f801723c32511d
258121,773cde25cb4e9fd90b11603abd5bf84d83b340e6,993a56b32432fb19bfb4461a0e1a2ead9bcf192f,19647ac7bb55a673554aa08cafb3b096aac7f32c,fc27c0656fc13157bb2f58543d51e8ee972fdf66,a4a410696cb935d542d62afd8e8090dbbc341a16,827530050f580378b7aa53fb292dfb8a12b775e1,c7d8359344120911f8550487f282c241d93c4750,d64eeb8d997e8d87203479556bbb9efaf7e487fd,b08a1d79d6d4bb2ca71336f7e995c7aa1342aa1f,26b22cce3b7694a7d765d9cf329b9065f3fb3a3c,...,1573822164854,12.662716,100.47756,1,F2,5dce9eea5516ad00065f04a7,2393,11.041773,102.698110,5d27099f03f801723c32511d
258122,993a56b32432fb19bfb4461a0e1a2ead9bcf192f,773cde25cb4e9fd90b11603abd5bf84d83b340e6,c7d8359344120911f8550487f282c241d93c4750,fc27c0656fc13157bb2f58543d51e8ee972fdf66,a4a410696cb935d542d62afd8e8090dbbc341a16,827530050f580378b7aa53fb292dfb8a12b775e1,d64eeb8d997e8d87203479556bbb9efaf7e487fd,18874cb574f0cae84582df367941ad94d877ccbb,22a52f1717436ee378dc44b6d707a3816a65b5e4,3c43188acbcc9704dd3987cf1ef14906f9dbe444,...,1573822173051,7.799886,107.13921,1,F2,5dce9eea5516ad00065f04a7,-3876,9.961144,104.178477,5d27099f03f801723c32511d
258123,993a56b32432fb19bfb4461a0e1a2ead9bcf192f,c7d8359344120911f8550487f282c241d93c4750,fc27c0656fc13157bb2f58543d51e8ee972fdf66,773cde25cb4e9fd90b11603abd5bf84d83b340e6,a4a410696cb935d542d62afd8e8090dbbc341a16,064419dd1c862bc6c960b365fed666a1a5ff36a9,b08a1d79d6d4bb2ca71336f7e995c7aa1342aa1f,3c43188acbcc9704dd3987cf1ef14906f9dbe444,19647ac7bb55a673554aa08cafb3b096aac7f32c,cdc456af06dec9e63340fdf06b976b04eaa3a4a8,...,1573822173051,7.799886,107.13921,1,F2,5dce9eea5516ad00065f04a7,-1946,8.880515,105.658843,5d27099f03f801723c32511d


In [7]:
train_df['x'].nunique(), train_df['wifi_x'].nunique()

(25336, 231094)

## time_diffの前処理
0~3000sのものはwifi_x,wifi_yを使用  
-1000s~0sのものはもともとのx.yを使用

In [8]:
# POSI_DIFF = 4000  # i番目のwaypointを基準に算出したwifi waypointのうち基準のtimestampの直近3sを信頼できるデータとして残す
# NEGA_DIFF = -100000  # i番目のwaypointを基準に算出したwifi waypointのうちi+1番目のwaypointに近いものにはi+1のwaypointを座標として与える

# train_df.loc[(NEGA_DIFF < train_df['time_diff']) & (train_df['time_diff'] <= 0), 'wifi_x'] = train_df.loc[(NEGA_DIFF< train_df['time_diff']) & (train_df['time_diff'] <= 0), 'x']
# train_df.loc[(NEGA_DIFF < train_df['time_diff']) & (train_df['time_diff'] <= 0), 'wifi_y'] = train_df.loc[(NEGA_DIFF < train_df['time_diff']) & (train_df['time_diff'] <= 0), 'y']
# train_df = train_df[(NEGA_DIFF < train_df['time_diff']) & (train_df['time_diff'] < POSI_DIFF)].reset_index(drop=True)
# train_df

BSSIDとRSSIは100ずつ存在しているけど全てが必要なわけではないみたい  
ここでは20だけ取り出している。

In [8]:
# training target features
NUM_FEATS = 80
BSSID_FEATS = [f'bssid_{i}' for i in range(NUM_FEATS)]
RSSI_FEATS  = [f'rssi_{i}' for i in range(NUM_FEATS)]

bssid_NはN個目のBSSIDを示しておりRSSI値が大きい順に番号が振られている。
100個しかない


In [9]:
# get numbers of bssids to embed them in a layer

# train
wifi_bssids = []
# bssidを列ごとにリストに入れていく
for i in range(100):
    wifi_bssids.extend(train_df.iloc[:,i].values.tolist())
wifi_bssids = list(set(wifi_bssids))

train_wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(train): {train_wifi_bssids_size}')

# test
wifi_bssids_test = []
for i in range(100):
    wifi_bssids_test.extend(test_df.iloc[:,i].values.tolist())
wifi_bssids_test = list(set(wifi_bssids_test))

test_wifi_bssids_size = len(wifi_bssids_test)
print(f'BSSID TYPES(test): {test_wifi_bssids_size}')


wifi_bssids.extend(wifi_bssids_test)
wifi_bssids = list(set(wifi_bssids))
wifi_bssids_size = len(wifi_bssids)
print(f'BSSID TYPES(all): {wifi_bssids_size}')


BSSID TYPES(train): 41286
BSSID TYPES(test): 28592
BSSID TYPES(all): 41300


## preprocessing

In [10]:
# preprocess

le = LabelEncoder()
le.fit(wifi_bssids)
le_site = LabelEncoder()
le_site.fit(train_df['site_id'])

ss = StandardScaler()
ss.fit(train_df.loc[:,RSSI_FEATS])


def preprocess(input_df, le=le, le_site=le_site, ss=ss):
    output_df = input_df.copy()
    # RSSIの正規化
    output_df.loc[:,RSSI_FEATS] = ss.transform(input_df.loc[:,RSSI_FEATS])

    # BSSIDのLE(1からふる)
    for i in BSSID_FEATS:
        output_df.loc[:,i] = le.transform(input_df.loc[:,i])
        output_df.loc[:,i] = output_df.loc[:,i] + 1  # 0からではなく1から番号を振りたいため なぜ？ embeddingのpadding用のダミー変数？

    # site_idのLE
    output_df['site_id_str'] = input_df['site_id'].copy()
    output_df.loc[:, 'site_id'] = le_site.transform(input_df.loc[:, 'site_id'])

    # なぜ２重でやる？
    # output_df.loc[:,RSSI_FEATS] = ss.transform(output_df.loc[:,RSSI_FEATS])
    return output_df

train = preprocess(train_df)
test = preprocess(test_df)

    

In [11]:
site_count = len(train['site_id'].unique())
site_count

24

In [21]:
train.iloc[:, -15:]

Unnamed: 0,rssi_96,rssi_97,rssi_98,rssi_99,timestamp,x,y,floor,floor_str,path,time_diff,wifi_x,wifi_y,site_id,site_id_str
0,-84,-84,-85,-85,1571625311855,68.064926,241.94000,0,F1,5dad1ca1dc3e2c0006606c3f,1952,66.823935,241.889369,21,5da958dd46f8266d0737457b
1,-84,-84,-85,-85,1571625311855,68.064926,241.94000,0,F1,5dad1ca1dc3e2c0006606c3f,3900,65.582943,241.838738,21,5da958dd46f8266d0737457b
2,-83,-83,-83,-83,1571625320099,62.480465,241.71216,0,F1,5dad1ca1dc3e2c0006606c3f,-2385,64.341952,241.788107,21,5da958dd46f8266d0737457b
3,-83,-83,-83,-83,1571625320099,62.480465,241.71216,0,F1,5dad1ca1dc3e2c0006606c3f,-427,62.480465,241.712160,21,5da958dd46f8266d0737457b
4,-83,-83,-83,-83,1571625320099,62.480465,241.71216,0,F1,5dad1ca1dc3e2c0006606c3f,1528,62.893219,240.715162,21,5da958dd46f8266d0737457b
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
258120,-79,-80,-80,-80,1573822164854,12.662716,100.47756,1,F2,5dce9eea5516ad00065f04a7,447,12.662716,100.477560,5,5d27099f03f801723c32511d
258121,-80,-80,-81,-81,1573822164854,12.662716,100.47756,1,F2,5dce9eea5516ad00065f04a7,2393,11.041773,102.698110,5,5d27099f03f801723c32511d
258122,-80,-80,-80,-80,1573822173051,7.799886,107.13921,1,F2,5dce9eea5516ad00065f04a7,-3876,9.961144,104.178477,5,5d27099f03f801723c32511d
258123,-81,-81,-81,-81,1573822173051,7.799886,107.13921,1,F2,5dce9eea5516ad00065f04a7,-1946,8.880515,105.658843,5,5d27099f03f801723c32511d


## PyTorch model
- embedding layerが重要  

In [13]:
# dataset
from torch.utils.data import Dataset, DataLoader
class IndoorDataset(Dataset):
    def __init__(self, df, phase='train'):
        self.df = df
        self.phase = phase
        self.bssid_feats = df[BSSID_FEATS].values.astype(int)
        self.rssi_feats = df[RSSI_FEATS].values.astype(np.float32)
        self.site_id = df['site_id'].values.astype(int)

        if phase in ['train', 'valid']:
            # self.xy = df[['x', 'y']].values.astype(np.float32)
            self.xy = df[['wifi_x', 'wifi_y']].values.astype(np.float32)  # wifiにより補正したx,yを使用
            self.floor = df['floor'].values.astype(np.float32)
        
    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        
        feature = {
            'BSSID_FEATS':self.bssid_feats[idx],
            'RSSI_FEATS':self.rssi_feats[idx],
            'site_id':self.site_id[idx]
        }
        if self.phase in ['train', 'valid']:
            target = {
                'xy':self.xy[idx],
                'floor':self.floor[idx]
            }
        else:
            target = {}
        return feature, target

In [14]:
import torch
from torch import nn

class LSTMModel(nn.Module):
    def __init__(self, wifi_bssids_size, site_count=24, embedding_dim=64):
        super(LSTMModel, self).__init__()
        # bssid
        # ->64次元に圧縮後sequence化にする
        # wifi_bssids_sizeが辞書の数を表す
        self.embedding_layer1 = nn.Sequential(
            nn.Embedding(wifi_bssids_size, embedding_dim),
            nn.Flatten(start_dim=-2)            
        )
        # site
        # ->2次元に圧縮後sequence化する
        # site_countが辞書の数を表す
        self.embedding_layer2 = nn.Sequential(
            nn.Embedding(site_count, 2),
            nn.Flatten(start_dim=-1)           
        )

        # rssi
        # 次元を64倍に線形変換
        self.linear_layer1 = nn.Sequential(
            nn.BatchNorm1d(NUM_FEATS),
            nn.Linear(NUM_FEATS, NUM_FEATS * embedding_dim),
            nn.ReLU()
        )
        
        # bssid, site, rssiの出力size
        feature_size = 2 + (2 * NUM_FEATS * embedding_dim)
        self.linear_layer2 = nn.Sequential(
            nn.BatchNorm1d(feature_size),
            nn.Dropout(0.3),
            nn.Linear(feature_size, 256),
            nn.ReLU()
        )

        self.batch_norm1 = nn.BatchNorm1d(1)
        self.lstm1 = nn.LSTM(input_size=256,hidden_size=128,dropout=0.3, batch_first=True)
        self.lstm2 = nn.LSTM(input_size=128,hidden_size=16,dropout=0.1, batch_first=True)

        self.fc_xy = nn.Linear(16, 2)
        # self.fc_x = nn.Linear(16, 1)
        # self.fc_y = nn.Linear(16, 1)
        self.fc_floor = nn.Linear(16, 1)

    
    def forward(self, x):
        # input embedding
        batch_size = x["site_id"].shape[0]
        x_bssid = self.embedding_layer1(x['BSSID_FEATS'])
        x_site_id = self.embedding_layer2(x['site_id'])
        x_rssi = self.linear_layer1(x['RSSI_FEATS'])
        x = torch.cat([x_bssid, x_site_id, x_rssi], dim=-1)
        x = self.linear_layer2(x)

        # lstm layer
        x = x.view(batch_size, 1, -1)  # [batch, 1]->[batch, 1, 1]
        x = self.batch_norm1(x)
        x, _ = self.lstm1(x)
        x = torch.relu(x)
        x, _ = self.lstm2(x)
        x = torch.relu(x)

        # output [batch, 1, 1] -> [batch]
        # x_ = self.fc_x(x).view(-1)
        # y_ = self.fc_y(x).view(-1)
        xy = self.fc_xy(x).squeeze(1)
        floor = torch.relu(self.fc_floor(x)).view(-1)
        # return {"x":x_, "y":y_, "floor":floor} 
        return {"xy": xy, "floor": floor}

In [15]:
def mean_position_error(xhat, yhat, fhat, x, y, f):
    intermediate = np.sqrt(np.power(xhat-x, 2) + np.power(yhat-y, 2)) + 15 * np.abs(fhat-f)
    return intermediate.sum()/xhat.shape[0]

def to_np(input):
    return input.detach().cpu().numpy()

In [16]:
def get_optimizer(model: nn.Module, config: dict):
    optimizer_config = config["optimizer"]
    optimizer_name = optimizer_config.get("name")
    base_optimizer_name = optimizer_config.get("base_name")
    optimizer_params = optimizer_config['params']

    if hasattr(optim, optimizer_name):
        optimizer = optim.__getattribute__(optimizer_name)(model.parameters(), **optimizer_params)
        return optimizer
    else:
        base_optimizer = optim.__getattribute__(base_optimizer_name)
        optimizer = globals().get(optimizer_name)(
            model.parameters(), 
            base_optimizer,
            **optimizer_config["params"])
        return  optimizer

def get_scheduler(optimizer, config: dict):
    scheduler_config = config["scheduler"]
    scheduler_name = scheduler_config.get("name")

    if scheduler_name is None:
        return
    else:
        return optim.lr_scheduler.__getattribute__(scheduler_name)(
            optimizer, **scheduler_config["params"])


def get_criterion(config: dict):
    loss_config = config["loss"]
    loss_name = loss_config["name"]
    loss_params = {} if loss_config.get("params") is None else loss_config.get("params")
    if hasattr(nn, loss_name):
        criterion = nn.__getattribute__(loss_name)(**loss_params)
    else:
        criterion = globals().get(loss_name)(**loss_params)

    return criterion

def worker_init_fn(worker_id):                                                          
    np.random.seed(np.random.get_state()[1][0] + worker_id)

In [17]:
# Learner class(pytorch-lighting)
class Learner(pl.LightningModule):
    def __init__(self, model, config):
        super().__init__()
        self.model = model
        self.config = config
        self.xy_criterion = get_criterion(config)
        self.f_criterion = get_criterion(config)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        loss = self.xy_criterion(output["xy"], y["xy"])
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        output = self.model(x)
        xy_loss = self.xy_criterion(output["xy"], y["xy"])
        f_loss = self.f_criterion(output["floor"], y["floor"])
        loss = xy_loss  # + f_loss
        mpe = mean_position_error(
            to_np(output['xy'][:, 0]), to_np(output['xy'][:, 1]), 0, 
            to_np(y['xy'][:, 0]), to_np(y['xy'][:, 1]), 0)
        
        # floor lossは現状は無視して良い
        self.log(f'Loss/val', loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/xy', xy_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'Loss/floor', f_loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        self.log(f'MPE/val', mpe, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return loss

    def configure_optimizers(self):
        optimizer = get_optimizer(self.model, self.config)
        scheduler = get_scheduler(optimizer, self.config)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "Loss/val"}

In [18]:
# oof
def evaluate(model, loaders, phase):
    x_list = []
    y_list = []
    f_list = []
    with torch.no_grad():
        for batch in loaders[phase]:
            x, y = batch
            output = model(x)
            x_list.append(to_np(output['xy'][:, 0]))
            y_list.append(to_np(output['xy'][:, 1]))
            f_list.append(to_np(output['floor']))

    x_list = np.concatenate(x_list)
    y_list = np.concatenate(y_list)
    f_list = np.concatenate(f_list)
    return x_list, y_list, f_list

## train

In [22]:
oofs = []  # 全てのoofをdfで格納する
predictions = []  # 全ての予測値をdfで格納する
val_scores = []
# skf = model_selection.StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
gkf = model_selection.GroupKFold(n_splits=N_SPLITS)
# 今回はtargetを均等に分ける必要はなくpathが均等に分かれればいいのでskf.split()にpathを与えている。
for fold, (trn_idx, val_idx) in enumerate(gkf.split(train.loc[:, 'path'], groups=train.loc[:, 'path'])):
    # 指定したfoldのみループを回す
    if fold not in USE_FOLDS:
        continue

    print('=' * 20)
    print(f'Fold {fold}')
    print('=' * 20)

    # train/valid data
    trn_df = train.loc[trn_idx, BSSID_FEATS + RSSI_FEATS + ['site_id', 'wifi_x','wifi_y','floor']].reset_index(drop=True)
    val_df = train.loc[val_idx, BSSID_FEATS + RSSI_FEATS + ['site_id', 'wifi_x','wifi_y','floor']].reset_index(drop=True)

    # data loader
    loaders = {}
    loader_config = config["loader"]
    loaders["train"] = DataLoader(IndoorDataset(trn_df, phase="train"), **loader_config["train"], worker_init_fn=worker_init_fn) 
    loaders["valid"] = DataLoader(IndoorDataset(val_df, phase="valid"), **loader_config["valid"], worker_init_fn=worker_init_fn)
    loaders["test"] = DataLoader(IndoorDataset(test, phase="test"), **loader_config["test"], worker_init_fn=worker_init_fn)

    # model
    model = LSTMModel(wifi_bssids_size+1, site_count)  # +1としているのはLEを1スタートで始めているため
    model_name = model.__class__.__name__

    # callbacks
    callbacks = []
    checkpoint_callback = ModelCheckpoint(
        monitor=f'Loss/val',
        mode='min',
        dirpath=f"../../model/{EXP_NAME}",
        verbose=False,
        filename=f'{model_name}-{fold}')
    
    if MODEL_SAVE:
        callbacks.append(checkpoint_callback)

    early_stop_callback = EarlyStopping(
        monitor='Loss/val',
        min_delta=0.00,
        patience=10,
        verbose=False,
        mode='min')
    callbacks.append(early_stop_callback)

    # loggers
    RUN_NAME = EXP_NAME + "_" + EXP_MESSAGE
    wandb.init(project='indoor', notes=NOTES, entity='kuto5046', group=RUN_NAME)
    wandb.run.name = RUN_NAME + f'-fold-{fold}'
    wandb_config = wandb.config
    wandb_config.model_name = model_name
    wandb_config.LB = None
    wandb.watch(model)
    
    
    loggers = []
    loggers.append(WandbLogger())

    learner = Learner(model, config)
    # pretrained flag
    if PRETRAINED:
        ckpt = torch.load(PRETRAINED_PATH + f'{model_name}-{fold}.ckpt')
        learner.load_state_dict(ckpt['state_dict'])

    if not ONLY_PRED:
        trainer = pl.Trainer(
            logger=loggers, 
            callbacks=callbacks,
            max_epochs=MAX_EPOCHS,
            gpus=[0],
            fast_dev_run=DEBUG,
            deterministic=True,
            # precision=16,
            progress_bar_refresh_rate=0  # vscodeの時progress barの動作が遅いので表示しない
            )

        trainer.fit(learner, train_dataloader=loaders['train'], val_dataloaders=loaders['valid'])

    #############
    # validation (to make oof)
    #############
    model.eval()  
    oof_df = train.loc[val_idx, ['timestamp', 'x', 'y', 'site_id','site_id_str', 'wifi_x','wifi_y', 'floor', 'floor_str', 'path', 'time_diff']].reset_index(drop=True)
    oof_x, oof_y, oof_f = evaluate(model, loaders, phase="valid")
    oof_df["oof_x"] = oof_x
    oof_df["oof_y"] = oof_y
    oof_df["oof_floor"] = oof_f
    oofs.append(oof_df)
    
    val_score = mean_position_error(
        oof_df["oof_x"].values, oof_df["oof_y"].values, 0,
        oof_df['wifi_x'].values, oof_df['wifi_y'].values, 0)
    val_scores.append(val_score)
    print(f"fold {fold}: mean position error {val_score}")

    #############
    # inference
    #############n

    preds_x, preds_y, preds_f = evaluate(model, loaders, phase="test")
    test_preds = pd.DataFrame(np.stack((preds_f, preds_x, preds_y))).T
    test_preds.columns = sub.columns
    test_preds["site_path_timestamp"] = test["site_path_timestamp"]
    test_preds["floor"] = test_preds["floor"].astype(int)
    test_preds.to_csv(f'{RUN_NAME}_fold{fold}.csv', index=False)
    predictions.append(test_preds)
    wandb.finish()

Fold 0
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkuto5046[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.24 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


fold 0: mean position error 7.710968541326397


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

Fold 1
[34m[1mwandb[0m: wandb version 0.10.24 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


fold 1: mean position error 7.722838655325648


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

Fold 2
[34m[1mwandb[0m: wandb version 0.10.24 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


fold 2: mean position error 7.369896543694032


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

Fold 3
[34m[1mwandb[0m: wandb version 0.10.24 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


fold 3: mean position error 8.058767830670824


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

Fold 4
[34m[1mwandb[0m: wandb version 0.10.24 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


fold 4: mean position error 7.71283324302871


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

In [23]:
if len(USE_FOLDS) > 1:
    oofs_df = pd.concat(oofs)
else:
    oofs_df = oofs[0]
oofs_df.to_csv("oof.csv", index=False)
oofs_df

Unnamed: 0,timestamp,x,y,site_id,site_id_str,wifi_x,wifi_y,floor,floor_str,path,time_diff,oof_x,oof_y,oof_floor
0,1571625850645,82.922240,210.56656,21,5da958dd46f8266d0737457b,83.261745,211.637967,0,F1,5dad1caa18410e00067e734c,1878,88.874680,208.189270,0.035873
1,1571625850645,82.922240,210.56656,21,5da958dd46f8266d0737457b,83.601250,212.709373,0,F1,5dad1caa18410e00067e734c,3771,89.391998,208.756088,0.040552
2,1571625850645,82.922240,210.56656,21,5da958dd46f8266d0737457b,83.940755,213.780780,0,F1,5dad1caa18410e00067e734c,5657,89.672028,212.722687,0.040091
3,1571625861986,84.959270,216.99500,21,5da958dd46f8266d0737457b,84.280260,214.852187,0,F1,5dad1caa18410e00067e734c,-3788,90.491730,212.158722,0.041784
4,1571625861986,84.959270,216.99500,21,5da958dd46f8266d0737457b,84.619765,215.923593,0,F1,5dad1caa18410e00067e734c,-1905,86.897095,217.139526,0.036175
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
51620,1573822164854,12.662716,100.47756,5,5d27099f03f801723c32511d,12.662716,100.477560,1,F2,5dce9eea5516ad00065f04a7,447,15.012345,101.914925,0.269409
51621,1573822164854,12.662716,100.47756,5,5d27099f03f801723c32511d,11.041773,102.698110,1,F2,5dce9eea5516ad00065f04a7,2393,15.492247,98.461067,0.258506
51622,1573822173051,7.799886,107.13921,5,5d27099f03f801723c32511d,9.961144,104.178477,1,F2,5dce9eea5516ad00065f04a7,-3876,16.128004,100.353477,0.262384
51623,1573822173051,7.799886,107.13921,5,5d27099f03f801723c32511d,8.880515,105.658843,1,F2,5dce9eea5516ad00065f04a7,-1946,14.627602,100.902176,0.265500


In [50]:
if len(USE_FOLDS) > 1:
    # foldの結果を平均した後、reindexでsubmission fileにindexを合わせる
    sub = pd.concat(predictions).groupby('site_path_timestamp').mean().reindex(sub.index)
else:
    sub = predictions[0].reindex(sub.index)
sub

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,87.954796,102.605637
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,84.130959,102.008942
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,86.090317,104.619865
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,87.595718,105.741379
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,86.149513,105.548729
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,0,214.097397,90.171814
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,0,209.983994,99.197472
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,0,205.807724,105.314880
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,0,202.292267,111.906059


In [51]:
# floorの数値を置換
simple_accurate_99 = pd.read_csv(root_dir / 'simple-99-accurate-floor-model/submission.csv')
sub['floor'] = simple_accurate_99['floor'].values
sub

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,87.954796,102.605637
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,84.130959,102.008942
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,86.090317,104.619865
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,87.595718,105.741379
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,86.149513,105.548729
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,5,214.097397,90.171814
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,5,209.983994,99.197472
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,5,205.807724,105.314880
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,5,202.292267,111.906059


In [64]:
sub.to_csv(RUN_NAME + '_sub.csv')

In [27]:
print(f"CV:{np.mean(val_scores)}")

CV:7.715060962809122


## 後処理

In [28]:
import multiprocessing
import scipy.interpolate
import scipy.sparse
from tqdm import tqdm
import sys
sys.path.append('../../')
from src.io_f import read_data_file
from src import compute_f


In [29]:
def compute_rel_positions(acce_datas, ahrs_datas):
    step_timestamps, step_indexs, step_acce_max_mins = compute_f.compute_steps(acce_datas)
    headings = compute_f.compute_headings(ahrs_datas)
    stride_lengths = compute_f.compute_stride_length(step_acce_max_mins)
    step_headings = compute_f.compute_step_heading(step_timestamps, headings)
    rel_positions = compute_f.compute_rel_positions(stride_lengths, step_headings)
    return rel_positions
    
def correct_path(args):
    path, path_df = args
    
    T_ref  = path_df['timestamp'].values
    xy_hat = path_df[['x', 'y']].values
    
    example = read_data_file(f'{root_dir}/indoor-location-navigation/test/{path}.txt')
    rel_positions = compute_rel_positions(example.acce, example.ahrs)
    if T_ref[-1] > rel_positions[-1, 0]:
        rel_positions = [np.array([[0, 0, 0]]), rel_positions, np.array([[T_ref[-1], 0, 0]])]
    else:
        rel_positions = [np.array([[0, 0, 0]]), rel_positions]
    rel_positions = np.concatenate(rel_positions)
    
    T_rel = rel_positions[:, 0]
    delta_xy_hat = np.diff(scipy.interpolate.interp1d(T_rel, np.cumsum(rel_positions[:, 1:3], axis=0), axis=0)(T_ref), axis=0)

    N = xy_hat.shape[0]
    delta_t = np.diff(T_ref)
    alpha = (8.1)**(-2) * np.ones(N)
    beta  = (0.3 + 0.3 * 1e-3 * delta_t)**(-2)
    A = scipy.sparse.spdiags(alpha, [0], N, N)
    B = scipy.sparse.spdiags( beta, [0], N-1, N-1)
    D = scipy.sparse.spdiags(np.stack([-np.ones(N), np.ones(N)]), [0, 1], N-1, N)

    Q = A + (D.T @ B @ D)
    c = (A @ xy_hat) + (D.T @ (B @ delta_xy_hat))
    xy_star = scipy.sparse.linalg.spsolve(Q, c)

    return pd.DataFrame({
        'site_path_timestamp' : path_df['site_path_timestamp'],
        'floor' : path_df['floor'],
        'x' : xy_star[:, 0],
        'y' : xy_star[:, 1],
    })

def correct_path_train(args):
    #print(args)
    (site_id, path, floor), path_df = args
    
    T_ref  = path_df['timestamp'].values
    xy_hat = path_df[['x', 'y']].values
    
    example = read_data_file(f'{root_dir}/indoor-location-navigation/train/{site_id}/{floor}/{path}.txt')
    rel_positions = compute_rel_positions(example.acce, example.ahrs)
    if T_ref[-1] > rel_positions[-1, 0]:
        rel_positions = [np.array([[0, 0, 0]]), rel_positions, np.array([[T_ref[-1], 0, 0]])]
    else:
        rel_positions = [np.array([[0, 0, 0]]), rel_positions]
    rel_positions = np.concatenate(rel_positions)
    
    T_rel = rel_positions[:, 0]

    try:
        delta_xy_hat = np.diff(scipy.interpolate.interp1d(T_rel, np.cumsum(rel_positions[:, 1:3], axis=0), axis=0)(T_ref), axis=0)
    except:
        return pd.DataFrame({
            'site_path_timestamp' : path_df['site_path_timestamp'],
            'floor' : path_df['floor'],
            'x' : path_df['x'].to_numpy(),
            'y' : path_df['y'].to_numpy()
        })
    

    N = xy_hat.shape[0]
    delta_t = np.diff(T_ref)
    alpha = (8.1)**(-2) * np.ones(N)
    beta  = (0.3 + 0.3 * 1e-3 * delta_t)**(-2)
    A = scipy.sparse.spdiags(alpha, [0], N, N)
    B = scipy.sparse.spdiags( beta, [0], N-1, N-1)
    D = scipy.sparse.spdiags(np.stack([-np.ones(N), np.ones(N)]), [0, 1], N-1, N)

    Q = A + (D.T @ B @ D)
    c = (A @ xy_hat) + (D.T @ (B @ delta_xy_hat))
    xy_star = scipy.sparse.linalg.spsolve(Q, c)


    return pd.DataFrame({
        'site_path_timestamp' : path_df['site_path_timestamp'],
        'floor' : path_df['floor'],
        'x' : xy_star[:, 0],
        'y' : xy_star[:, 1],
    })

In [42]:
with open(root_dir/'kuto_wifi_dataset_v2/train_all.pkl', 'rb') as f:
    data_org = pickle.load(f)

# data_org = data_org[data_org['time_diff'].abs() <  time_diff_threshold].reset_index(drop=True)
data_org['site_path_timestamp'] = data_org['site_id'] + '_' + data_org['path'] + '_' + data_org['timestamp'].astype(str)
data_org

Unnamed: 0,bssid_0,bssid_1,bssid_2,bssid_3,bssid_4,bssid_5,bssid_6,bssid_7,bssid_8,bssid_9,...,x,y,floor,floor_str,path,time_diff,wifi_x,wifi_y,site_id,site_path_timestamp
0,ffe684dfd25a52b046e3108a3f70df46001425f0,97e4a381c3a02ed3151bbf41b8fc1fe5815f5387,5b71ef95e53358c558b78bf3fb152d793729bc8d,2f85d197aec7bfddfee3f53ae9e1b6ed1fc56e92,4328f33869766d0f77a9299441556338e4d8a2b9,df41c761b69993669d4eb875b4474ec44d2372ed,7dc49736770ee9073043134656c89a17529f882f,5a1a7a8496e5f8b88db082de0b412e447e01fd0b,463d0cfe3748eb70524138ed970f03375e8d1030,79179095e63e2b0431e85e3e33b02d95bb135c2e,...,68.064926,241.94000,0,F1,5dad1ca1dc3e2c0006606c3f,1952,66.823935,241.889369,5da958dd46f8266d0737457b,5da958dd46f8266d0737457b_5dad1ca1dc3e2c0006606...
1,97e4a381c3a02ed3151bbf41b8fc1fe5815f5387,5b71ef95e53358c558b78bf3fb152d793729bc8d,cb8f53745c342e2bfd0bf77a5fd8cac6cf303945,ffe684dfd25a52b046e3108a3f70df46001425f0,6b769b9eeb24ff287e6a53736cc7c013d5902901,7dc49736770ee9073043134656c89a17529f882f,3fef087dd272ab07981a60c9cbf6f27460d1364e,5a1a7a8496e5f8b88db082de0b412e447e01fd0b,df41c761b69993669d4eb875b4474ec44d2372ed,d8b1ff62702e02106553be91dc22a0dcf0e780a7,...,68.064926,241.94000,0,F1,5dad1ca1dc3e2c0006606c3f,3900,65.582943,241.838738,5da958dd46f8266d0737457b,5da958dd46f8266d0737457b_5dad1ca1dc3e2c0006606...
2,97e4a381c3a02ed3151bbf41b8fc1fe5815f5387,6b769b9eeb24ff287e6a53736cc7c013d5902901,cb8f53745c342e2bfd0bf77a5fd8cac6cf303945,4328f33869766d0f77a9299441556338e4d8a2b9,3fef087dd272ab07981a60c9cbf6f27460d1364e,5a1a7a8496e5f8b88db082de0b412e447e01fd0b,7dc49736770ee9073043134656c89a17529f882f,df41c761b69993669d4eb875b4474ec44d2372ed,5b71ef95e53358c558b78bf3fb152d793729bc8d,2f85d197aec7bfddfee3f53ae9e1b6ed1fc56e92,...,62.480465,241.71216,0,F1,5dad1ca1dc3e2c0006606c3f,-2385,64.341952,241.788107,5da958dd46f8266d0737457b,5da958dd46f8266d0737457b_5dad1ca1dc3e2c0006606...
3,97e4a381c3a02ed3151bbf41b8fc1fe5815f5387,4328f33869766d0f77a9299441556338e4d8a2b9,3fef087dd272ab07981a60c9cbf6f27460d1364e,5a1a7a8496e5f8b88db082de0b412e447e01fd0b,5d82171d37c5296bcaed8c02745540b491d8a284,471740ef5065943b791f277ada358f9ffc011645,2f85d197aec7bfddfee3f53ae9e1b6ed1fc56e92,afe423c7bc0641d63c95e232ffd65cae3be95351,df41c761b69993669d4eb875b4474ec44d2372ed,5b71ef95e53358c558b78bf3fb152d793729bc8d,...,62.480465,241.71216,0,F1,5dad1ca1dc3e2c0006606c3f,-427,62.480465,241.712160,5da958dd46f8266d0737457b,5da958dd46f8266d0737457b_5dad1ca1dc3e2c0006606...
4,97e4a381c3a02ed3151bbf41b8fc1fe5815f5387,4328f33869766d0f77a9299441556338e4d8a2b9,5a1a7a8496e5f8b88db082de0b412e447e01fd0b,cb8f53745c342e2bfd0bf77a5fd8cac6cf303945,6b769b9eeb24ff287e6a53736cc7c013d5902901,5d82171d37c5296bcaed8c02745540b491d8a284,5b71ef95e53358c558b78bf3fb152d793729bc8d,df41c761b69993669d4eb875b4474ec44d2372ed,2f85d197aec7bfddfee3f53ae9e1b6ed1fc56e92,afe423c7bc0641d63c95e232ffd65cae3be95351,...,62.480465,241.71216,0,F1,5dad1ca1dc3e2c0006606c3f,1528,62.893219,240.715162,5da958dd46f8266d0737457b,5da958dd46f8266d0737457b_5dad1ca1dc3e2c0006606...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
258120,993a56b32432fb19bfb4461a0e1a2ead9bcf192f,773cde25cb4e9fd90b11603abd5bf84d83b340e6,19647ac7bb55a673554aa08cafb3b096aac7f32c,a4a410696cb935d542d62afd8e8090dbbc341a16,fc27c0656fc13157bb2f58543d51e8ee972fdf66,c7d8359344120911f8550487f282c241d93c4750,26b22cce3b7694a7d765d9cf329b9065f3fb3a3c,827530050f580378b7aa53fb292dfb8a12b775e1,d64eeb8d997e8d87203479556bbb9efaf7e487fd,22a52f1717436ee378dc44b6d707a3816a65b5e4,...,12.662716,100.47756,1,F2,5dce9eea5516ad00065f04a7,447,12.662716,100.477560,5d27099f03f801723c32511d,5d27099f03f801723c32511d_5dce9eea5516ad00065f0...
258121,773cde25cb4e9fd90b11603abd5bf84d83b340e6,993a56b32432fb19bfb4461a0e1a2ead9bcf192f,19647ac7bb55a673554aa08cafb3b096aac7f32c,fc27c0656fc13157bb2f58543d51e8ee972fdf66,a4a410696cb935d542d62afd8e8090dbbc341a16,827530050f580378b7aa53fb292dfb8a12b775e1,c7d8359344120911f8550487f282c241d93c4750,d64eeb8d997e8d87203479556bbb9efaf7e487fd,b08a1d79d6d4bb2ca71336f7e995c7aa1342aa1f,26b22cce3b7694a7d765d9cf329b9065f3fb3a3c,...,12.662716,100.47756,1,F2,5dce9eea5516ad00065f04a7,2393,11.041773,102.698110,5d27099f03f801723c32511d,5d27099f03f801723c32511d_5dce9eea5516ad00065f0...
258122,993a56b32432fb19bfb4461a0e1a2ead9bcf192f,773cde25cb4e9fd90b11603abd5bf84d83b340e6,c7d8359344120911f8550487f282c241d93c4750,fc27c0656fc13157bb2f58543d51e8ee972fdf66,a4a410696cb935d542d62afd8e8090dbbc341a16,827530050f580378b7aa53fb292dfb8a12b775e1,d64eeb8d997e8d87203479556bbb9efaf7e487fd,18874cb574f0cae84582df367941ad94d877ccbb,22a52f1717436ee378dc44b6d707a3816a65b5e4,3c43188acbcc9704dd3987cf1ef14906f9dbe444,...,7.799886,107.13921,1,F2,5dce9eea5516ad00065f04a7,-3876,9.961144,104.178477,5d27099f03f801723c32511d,5d27099f03f801723c32511d_5dce9eea5516ad00065f0...
258123,993a56b32432fb19bfb4461a0e1a2ead9bcf192f,c7d8359344120911f8550487f282c241d93c4750,fc27c0656fc13157bb2f58543d51e8ee972fdf66,773cde25cb4e9fd90b11603abd5bf84d83b340e6,a4a410696cb935d542d62afd8e8090dbbc341a16,064419dd1c862bc6c960b365fed666a1a5ff36a9,b08a1d79d6d4bb2ca71336f7e995c7aa1342aa1f,3c43188acbcc9704dd3987cf1ef14906f9dbe444,19647ac7bb55a673554aa08cafb3b096aac7f32c,cdc456af06dec9e63340fdf06b976b04eaa3a4a8,...,7.799886,107.13921,1,F2,5dce9eea5516ad00065f04a7,-1946,8.880515,105.658843,5d27099f03f801723c32511d,5d27099f03f801723c32511d_5dce9eea5516ad00065f0...


In [43]:
oofs_df['site_path_timestamp'] = oofs_df['site_id_str'].astype(str) + '_' + oofs_df['path'] + '_' + oofs_df['timestamp'].astype(str)
oofs_df

Unnamed: 0,timestamp,x,y,site_id,site_id_str,wifi_x,wifi_y,floor,floor_str,path,time_diff,oof_x,oof_y,oof_floor,site_path_timestamp
0,1571625850645,82.922240,210.56656,21,5da958dd46f8266d0737457b,83.261745,211.637967,0,F1,5dad1caa18410e00067e734c,1878,88.874680,208.189270,0.035873,5da958dd46f8266d0737457b_5dad1caa18410e00067e7...
1,1571625850645,82.922240,210.56656,21,5da958dd46f8266d0737457b,83.601250,212.709373,0,F1,5dad1caa18410e00067e734c,3771,89.391998,208.756088,0.040552,5da958dd46f8266d0737457b_5dad1caa18410e00067e7...
2,1571625850645,82.922240,210.56656,21,5da958dd46f8266d0737457b,83.940755,213.780780,0,F1,5dad1caa18410e00067e734c,5657,89.672028,212.722687,0.040091,5da958dd46f8266d0737457b_5dad1caa18410e00067e7...
3,1571625861986,84.959270,216.99500,21,5da958dd46f8266d0737457b,84.280260,214.852187,0,F1,5dad1caa18410e00067e734c,-3788,90.491730,212.158722,0.041784,5da958dd46f8266d0737457b_5dad1caa18410e00067e7...
4,1571625861986,84.959270,216.99500,21,5da958dd46f8266d0737457b,84.619765,215.923593,0,F1,5dad1caa18410e00067e734c,-1905,86.897095,217.139526,0.036175,5da958dd46f8266d0737457b_5dad1caa18410e00067e7...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
51620,1573822164854,12.662716,100.47756,5,5d27099f03f801723c32511d,12.662716,100.477560,1,F2,5dce9eea5516ad00065f04a7,447,15.012345,101.914925,0.269409,5d27099f03f801723c32511d_5dce9eea5516ad00065f0...
51621,1573822164854,12.662716,100.47756,5,5d27099f03f801723c32511d,11.041773,102.698110,1,F2,5dce9eea5516ad00065f04a7,2393,15.492247,98.461067,0.258506,5d27099f03f801723c32511d_5dce9eea5516ad00065f0...
51622,1573822173051,7.799886,107.13921,5,5d27099f03f801723c32511d,9.961144,104.178477,1,F2,5dce9eea5516ad00065f04a7,-3876,16.128004,100.353477,0.262384,5d27099f03f801723c32511d_5dce9eea5516ad00065f0...
51623,1573822173051,7.799886,107.13921,5,5d27099f03f801723c32511d,8.880515,105.658843,1,F2,5dce9eea5516ad00065f04a7,-1946,14.627602,100.902176,0.265500,5d27099f03f801723c32511d_5dce9eea5516ad00065f0...


In [44]:
data_org = pd.merge(data_org, oofs_df[['site_path_timestamp', 'oof_x', 'oof_y']], on='site_path_timestamp', how='left')
data_org = data_org.rename(columns={'x':'target_x', 'y': 'target_y', 'oof_x':'x', 'oof_y':'y'})
data_org

Unnamed: 0,bssid_0,bssid_1,bssid_2,bssid_3,bssid_4,bssid_5,bssid_6,bssid_7,bssid_8,bssid_9,...,floor,floor_str,path,time_diff,wifi_x,wifi_y,site_id,site_path_timestamp,x,y
0,ffe684dfd25a52b046e3108a3f70df46001425f0,97e4a381c3a02ed3151bbf41b8fc1fe5815f5387,5b71ef95e53358c558b78bf3fb152d793729bc8d,2f85d197aec7bfddfee3f53ae9e1b6ed1fc56e92,4328f33869766d0f77a9299441556338e4d8a2b9,df41c761b69993669d4eb875b4474ec44d2372ed,7dc49736770ee9073043134656c89a17529f882f,5a1a7a8496e5f8b88db082de0b412e447e01fd0b,463d0cfe3748eb70524138ed970f03375e8d1030,79179095e63e2b0431e85e3e33b02d95bb135c2e,...,0,F1,5dad1ca1dc3e2c0006606c3f,1952,66.823935,241.889369,5da958dd46f8266d0737457b,5da958dd46f8266d0737457b_5dad1ca1dc3e2c0006606...,65.474991,235.977615
1,ffe684dfd25a52b046e3108a3f70df46001425f0,97e4a381c3a02ed3151bbf41b8fc1fe5815f5387,5b71ef95e53358c558b78bf3fb152d793729bc8d,2f85d197aec7bfddfee3f53ae9e1b6ed1fc56e92,4328f33869766d0f77a9299441556338e4d8a2b9,df41c761b69993669d4eb875b4474ec44d2372ed,7dc49736770ee9073043134656c89a17529f882f,5a1a7a8496e5f8b88db082de0b412e447e01fd0b,463d0cfe3748eb70524138ed970f03375e8d1030,79179095e63e2b0431e85e3e33b02d95bb135c2e,...,0,F1,5dad1ca1dc3e2c0006606c3f,1952,66.823935,241.889369,5da958dd46f8266d0737457b,5da958dd46f8266d0737457b_5dad1ca1dc3e2c0006606...,65.083534,232.990005
2,97e4a381c3a02ed3151bbf41b8fc1fe5815f5387,5b71ef95e53358c558b78bf3fb152d793729bc8d,cb8f53745c342e2bfd0bf77a5fd8cac6cf303945,ffe684dfd25a52b046e3108a3f70df46001425f0,6b769b9eeb24ff287e6a53736cc7c013d5902901,7dc49736770ee9073043134656c89a17529f882f,3fef087dd272ab07981a60c9cbf6f27460d1364e,5a1a7a8496e5f8b88db082de0b412e447e01fd0b,df41c761b69993669d4eb875b4474ec44d2372ed,d8b1ff62702e02106553be91dc22a0dcf0e780a7,...,0,F1,5dad1ca1dc3e2c0006606c3f,3900,65.582943,241.838738,5da958dd46f8266d0737457b,5da958dd46f8266d0737457b_5dad1ca1dc3e2c0006606...,65.474991,235.977615
3,97e4a381c3a02ed3151bbf41b8fc1fe5815f5387,5b71ef95e53358c558b78bf3fb152d793729bc8d,cb8f53745c342e2bfd0bf77a5fd8cac6cf303945,ffe684dfd25a52b046e3108a3f70df46001425f0,6b769b9eeb24ff287e6a53736cc7c013d5902901,7dc49736770ee9073043134656c89a17529f882f,3fef087dd272ab07981a60c9cbf6f27460d1364e,5a1a7a8496e5f8b88db082de0b412e447e01fd0b,df41c761b69993669d4eb875b4474ec44d2372ed,d8b1ff62702e02106553be91dc22a0dcf0e780a7,...,0,F1,5dad1ca1dc3e2c0006606c3f,3900,65.582943,241.838738,5da958dd46f8266d0737457b,5da958dd46f8266d0737457b_5dad1ca1dc3e2c0006606...,65.083534,232.990005
4,97e4a381c3a02ed3151bbf41b8fc1fe5815f5387,6b769b9eeb24ff287e6a53736cc7c013d5902901,cb8f53745c342e2bfd0bf77a5fd8cac6cf303945,4328f33869766d0f77a9299441556338e4d8a2b9,3fef087dd272ab07981a60c9cbf6f27460d1364e,5a1a7a8496e5f8b88db082de0b412e447e01fd0b,7dc49736770ee9073043134656c89a17529f882f,df41c761b69993669d4eb875b4474ec44d2372ed,5b71ef95e53358c558b78bf3fb152d793729bc8d,2f85d197aec7bfddfee3f53ae9e1b6ed1fc56e92,...,0,F1,5dad1ca1dc3e2c0006606c3f,-2385,64.341952,241.788107,5da958dd46f8266d0737457b,5da958dd46f8266d0737457b_5dad1ca1dc3e2c0006606...,65.774338,234.577438
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1278302,993a56b32432fb19bfb4461a0e1a2ead9bcf192f,c7d8359344120911f8550487f282c241d93c4750,fc27c0656fc13157bb2f58543d51e8ee972fdf66,773cde25cb4e9fd90b11603abd5bf84d83b340e6,a4a410696cb935d542d62afd8e8090dbbc341a16,064419dd1c862bc6c960b365fed666a1a5ff36a9,b08a1d79d6d4bb2ca71336f7e995c7aa1342aa1f,3c43188acbcc9704dd3987cf1ef14906f9dbe444,19647ac7bb55a673554aa08cafb3b096aac7f32c,cdc456af06dec9e63340fdf06b976b04eaa3a4a8,...,1,F2,5dce9eea5516ad00065f04a7,-1946,8.880515,105.658843,5d27099f03f801723c32511d,5d27099f03f801723c32511d_5dce9eea5516ad00065f0...,14.627602,100.902176
1278303,993a56b32432fb19bfb4461a0e1a2ead9bcf192f,c7d8359344120911f8550487f282c241d93c4750,fc27c0656fc13157bb2f58543d51e8ee972fdf66,773cde25cb4e9fd90b11603abd5bf84d83b340e6,a4a410696cb935d542d62afd8e8090dbbc341a16,064419dd1c862bc6c960b365fed666a1a5ff36a9,b08a1d79d6d4bb2ca71336f7e995c7aa1342aa1f,3c43188acbcc9704dd3987cf1ef14906f9dbe444,19647ac7bb55a673554aa08cafb3b096aac7f32c,cdc456af06dec9e63340fdf06b976b04eaa3a4a8,...,1,F2,5dce9eea5516ad00065f04a7,-1946,8.880515,105.658843,5d27099f03f801723c32511d,5d27099f03f801723c32511d_5dce9eea5516ad00065f0...,16.454372,101.799461
1278304,993a56b32432fb19bfb4461a0e1a2ead9bcf192f,c7d8359344120911f8550487f282c241d93c4750,773cde25cb4e9fd90b11603abd5bf84d83b340e6,fc27c0656fc13157bb2f58543d51e8ee972fdf66,a4a410696cb935d542d62afd8e8090dbbc341a16,827530050f580378b7aa53fb292dfb8a12b775e1,19647ac7bb55a673554aa08cafb3b096aac7f32c,064419dd1c862bc6c960b365fed666a1a5ff36a9,3c43188acbcc9704dd3987cf1ef14906f9dbe444,2ee4dca6f3705253eceaeecf2144083b14741d08,...,1,F2,5dce9eea5516ad00065f04a7,-17,7.799886,107.139210,5d27099f03f801723c32511d,5d27099f03f801723c32511d_5dce9eea5516ad00065f0...,16.128004,100.353477
1278305,993a56b32432fb19bfb4461a0e1a2ead9bcf192f,c7d8359344120911f8550487f282c241d93c4750,773cde25cb4e9fd90b11603abd5bf84d83b340e6,fc27c0656fc13157bb2f58543d51e8ee972fdf66,a4a410696cb935d542d62afd8e8090dbbc341a16,827530050f580378b7aa53fb292dfb8a12b775e1,19647ac7bb55a673554aa08cafb3b096aac7f32c,064419dd1c862bc6c960b365fed666a1a5ff36a9,3c43188acbcc9704dd3987cf1ef14906f9dbe444,2ee4dca6f3705253eceaeecf2144083b14741d08,...,1,F2,5dce9eea5516ad00065f04a7,-17,7.799886,107.139210,5d27099f03f801723c32511d,5d27099f03f801723c32511d_5dce9eea5516ad00065f0...,14.627602,100.902176


In [45]:
"""sub = pd.read_csv('../input/simple-99-accurate-floor-model/submission.csv')
tmp = sub['site_path_timestamp'].apply(lambda s : pd.Series(s.split('_')))
sub['site'] = tmp[0]
sub['path'] = tmp[1]
sub['timestamp'] = tmp[2].astype(float)"""

processes = multiprocessing.cpu_count()
with multiprocessing.Pool(processes=processes) as pool:
    dfs = pool.imap_unordered(correct_path_train, data_org.groupby(['site_id', 'path', 'floor_str']))
    dfs = tqdm(dfs)
    dfs = list(dfs)
oof_post_process = pd.concat(dfs).sort_values('site_path_timestamp')


10852it [04:15, 42.41it/s]


In [46]:
oof_post_process

Unnamed: 0,site_path_timestamp,floor,x,y
720328,5a0546857ecc773753327266_5d10a1669c50c70008fe8...,2,85.186721,35.440870
720331,5a0546857ecc773753327266_5d10a1669c50c70008fe8...,2,85.181121,35.424457
720329,5a0546857ecc773753327266_5d10a1669c50c70008fe8...,2,85.187162,35.435412
720330,5a0546857ecc773753327266_5d10a1669c50c70008fe8...,2,85.183923,35.432669
720388,5a0546857ecc773753327266_5d10a1669c50c70008fe8...,2,83.980955,26.415453
...,...,...,...,...
536457,5dc8cea7659e181adb076a3f_5dd7c119c5b77e0006b16...,-1,204.887168,104.321073
536458,5dc8cea7659e181adb076a3f_5dd7c119c5b77e0006b16...,-1,204.874542,104.318486
536459,5dc8cea7659e181adb076a3f_5dd7c119c5b77e0006b16...,-1,204.868634,104.316786
536452,5dc8cea7659e181adb076a3f_5dd7c119c5b77e0006b16...,-1,205.003519,104.349860


In [48]:
# waypoint補正前のx,yでの評価
oof_score_post_process = mean_position_error(
    oof_post_process.sort_values('site_path_timestamp')['x'].to_numpy(), oof_post_process.sort_values('site_path_timestamp')['y'].to_numpy(), 0, 
    data_org.sort_values('site_path_timestamp')['target_x'], data_org.sort_values('site_path_timestamp')['target_y'], 0
    )
oof_score_post_process

7.028372870985832

In [61]:
# waypoint補正後のx,yでの評価
oof_score_post_process = mean_position_error(
    oof_post_process.sort_values('site_path_timestamp')['x'].to_numpy(), oof_post_process.sort_values('site_path_timestamp')['y'].to_numpy(), 0, 
    data_org.sort_values('site_path_timestamp')['wifi_x'], data_org.sort_values('site_path_timestamp')['wifi_y'], 0
    )
oof_score_post_process

6.878435090529839

In [58]:
sub

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,87.954796,102.605637
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,84.130959,102.008942
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,86.090317,104.619865
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,87.595718,105.741379
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,86.149513,105.548729
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,5,214.097397,90.171814
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,5,209.983994,99.197472
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,5,205.807724,105.314880
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,5,202.292267,111.906059


In [59]:
sub = sub.reset_index()
sub_org = sub.copy()
tmp = sub['site_path_timestamp'].apply(lambda s : pd.Series(s.split('_')))
sub['site'] = tmp[0]
sub['path'] = tmp[1]
sub['timestamp'] = tmp[2].astype(float)
sub

Unnamed: 0,site_path_timestamp,floor,x,y,site,path,timestamp
0,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,87.954796,102.605637,5a0546857ecc773753327266,046cfa46be49fc10834815c6,9.0
1,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,84.130959,102.008942,5a0546857ecc773753327266,046cfa46be49fc10834815c6,9017.0
2,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,86.090317,104.619865,5a0546857ecc773753327266,046cfa46be49fc10834815c6,15326.0
3,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,87.595718,105.741379,5a0546857ecc773753327266,046cfa46be49fc10834815c6,18763.0
4,5a0546857ecc773753327266_046cfa46be49fc1083481...,0,86.149513,105.548729,5a0546857ecc773753327266,046cfa46be49fc10834815c6,22328.0
...,...,...,...,...,...,...,...
10128,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,214.097397,90.171814,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,82589.0
10129,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,209.983994,99.197472,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,85758.0
10130,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,205.807724,105.314880,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,90895.0
10131,5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f...,5,202.292267,111.906059,5dc8cea7659e181adb076a3f,fd64de8c4a2fc5ebb0e9f412,96899.0


In [60]:
processes = multiprocessing.cpu_count()
with multiprocessing.Pool(processes=processes) as pool:
    dfs = pool.imap_unordered(correct_path, sub.groupby(['path']))
    dfs = tqdm(dfs)
    dfs = list(dfs)
new_sub = pd.concat(dfs).sort_values('site_path_timestamp')


626it [00:49, 12.55it/s]


In [68]:
new_sub = new_sub.set_index('site_path_timestamp')
new_sub

Unnamed: 0_level_0,floor,x,y
site_path_timestamp,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000000009,0,91.328964,95.522559
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000009017,0,81.869227,98.981264
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000015326,0,84.059223,103.985378
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000018763,0,85.353798,107.737236
5a0546857ecc773753327266_046cfa46be49fc10834815c6_0000000022328,0,85.799636,111.303132
...,...,...,...
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000082589,5,210.306021,98.235106
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000085758,5,208.502905,101.867297
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000090895,5,204.425589,108.775446
5dc8cea7659e181adb076a3f_fd64de8c4a2fc5ebb0e9f412_0000000096899,5,196.849895,114.257308


In [69]:
new_sub.to_csv(RUN_NAME + '_postprocess_sub.csv')