<a href="https://colab.research.google.com/github/mixidota2/kaggle-indoor/blob/main/notebook/Indoor_001_baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Overview
- Baselineを構築するためのnotebook
- とりあえずデータを読み込み、最低限のsubmissionを作成することだけを目的とする

In [2]:
# Show the attached GPU and driver (informational only).
!nvidia-smi

Tue Mar 30 10:07:17 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.56       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   40C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
import os
filename = "/root/.kaggle/kaggle.json"
os.makedirs(os.path.dirname(filename), exist_ok=True)
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
os.chmod(filename, 600)

In [4]:
# Download and extract the pre-built unified WiFi dataset. It is expected to
# provide the train_all.pkl / test_all.pkl files loaded in the Preprocess section.
!kaggle datasets download -d kokitanisaka/indoorunifiedwifids
!unzip indoorunifiedwifids.zip > /dev/null

Downloading indoorunifiedwifids.zip to /content
 99% 458M/463M [00:04<00:00, 116MB/s] 
100% 463M/463M [00:04<00:00, 111MB/s]


In [5]:
import os
import gc
import glob 
import copy
import pickle
import random

import pandas as pd
import numpy as np

import yaml
from tqdm import tqdm
from joblib import Parallel, delayed

import seaborn as sns
import matplotlib.pyplot as plt

import cv2

import scipy.stats as stats

from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import mean_squared_error
from sklearn.decomposition import TruncatedSVD

import torch
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Enable tqdm progress bars for pandas operations (e.g. DataFrame.progress_apply).
tqdm.pandas(position=0, leave=True)

  from pandas import Panel


In [6]:
# consts
N_SPLITS = 5  # number of CV folds used by StratifiedKFold in the training cell

SEED = 42  # passed to seed_everything and to the fold splitter

NUM_FEATS = 20  # number of (bssid_i, rssi_i) column pairs fed to the model

In [7]:
def seed_everything(seed=42):
    """Seed every RNG the notebook relies on and make cuDNN deterministic.

    Seeds python's ``random``, numpy's global RNG and torch (CPU and all CUDA
    devices). PYTHONHASHSEED is exported too; it only affects child processes
    started afterwards, not the running interpreter.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def get_timestamp(hour_offset=8, now=None):
    """Return a 'Y-M-D-H-M' timestamp string (fields are not zero-padded).

    The local time is shifted by ``hour_offset`` hours (default 8, as before).
    The shift uses datetime arithmetic so it rolls over into the next day. The
    previous version added the offset to the hour field alone and could emit
    impossible hours such as '2021-3-30-28-7'.
    NOTE(review): the +8 h shift presumably converts the VM clock to the
    author's timezone; confirm the intended offset.

    Args:
        hour_offset: hours added to the local time.
        now: optional ``datetime.datetime`` to format instead of the current
            local time (mainly for testing).

    Returns:
        str like '2021-3-30-18-7'.
    """
    import datetime
    if now is None:
        now = datetime.datetime.now()
    shifted = now + datetime.timedelta(hours=hour_offset)
    return f'{shifted.year}-{shifted.month}-{shifted.day}-{shifted.hour}-{shifted.minute}'
def comp_metric(xhat, yhat, fhat, x, y, f):
    """Mean positioning error used by the competition.

    Per sample: Euclidean distance between (xhat, yhat) and (x, y), plus a
    penalty of 15 per floor of difference between fhat and f. Returns the sum
    of the per-sample errors divided by the number of samples (xhat.shape[0]).
    Arguments are numpy arrays of equal length.
    """
    planar_error = np.sqrt((xhat - x) ** 2 + (yhat - y) ** 2)
    floor_penalty = 15 * np.abs(fhat - f)
    per_sample = planar_error + floor_penalty
    return per_sample.sum() / xhat.shape[0]

## Preprocess

In [8]:
# Load the preprocessed train/test tables extracted from the downloaded dataset.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open('train_all.pkl', 'rb') as f:
    data = pickle.load(f)
with open('test_all.pkl', 'rb') as f:
    test_data = pickle.load(f)

In [9]:
# Names of the model input columns: the first NUM_FEATS bssid_* and rssi_* columns.
BSSID_FEATS = ['bssid_' + str(k) for k in range(NUM_FEATS)]
RSSI_FEATS = ['rssi_' + str(k) for k in range(NUM_FEATS)]

In [10]:
# get unique wifi bssids over train + test (vocabulary for the LabelEncoder / embedding)
# NOTE(review): assumes the first 100 columns are the bssid_* columns of the unified
# dataset; confirm against data.columns.
N_BSSID_COLS = 100

wifi_bssids = set(data.iloc[:, :N_BSSID_COLS].values.ravel().tolist())
print(f'BSSID TYPES: {len(wifi_bssids)}')

wifi_bssids_test = set(test_data.iloc[:, :N_BSSID_COLS].values.ravel().tolist())
print(f'BSSID TYPES: {len(wifi_bssids_test)}')

# Union without duplicates. The old code extended the train list with the test
# list, so wifi_bssids_size counted BSSIDs present in both splits twice.
wifi_bssids = list(wifi_bssids | wifi_bssids_test)
# Encoded ids are shifted by +1 before the embedding lookup, so the largest id
# equals the number of unique BSSIDs; the embedding table needs one more row.
wifi_bssids_size = len(wifi_bssids) + 1

BSSID TYPES: 61206
BSSID TYPES: 33042


In [11]:
# preprocess: fit the encoders and scaler; they are applied in the next cells.
# Site ids come from train only; the BSSID vocabulary covers train + test.
le_site = LabelEncoder().fit(data['site_id'])
le = LabelEncoder().fit(wifi_bssids)

# RSSI statistics are estimated on the training rows only.
ss = StandardScaler()
ss.fit(data.loc[:, RSSI_FEATS])

StandardScaler(copy=True, with_mean=True, with_std=True)

In [12]:
# apply transforms

data.loc[:,RSSI_FEATS] = ss.transform(data.loc[:,RSSI_FEATS])
for i in BSSID_FEATS:
    data.loc[:,i] = le.transform(data.loc[:,i])
    data.loc[:,i] = data.loc[:,i] + 1
    
data.loc[:, 'site_id'] = le_site.transform(data.loc[:, 'site_id'])

data.loc[:,RSSI_FEATS] = ss.transform(data.loc[:,RSSI_FEATS])

In [13]:
test_data.loc[:,RSSI_FEATS] = ss.transform(test_data.loc[:,RSSI_FEATS])
for i in BSSID_FEATS:
    test_data.loc[:,i] = le.transform(test_data.loc[:,i])
    test_data.loc[:,i] = test_data.loc[:,i] + 1
    
test_data.loc[:, 'site_id'] = le_site.transform(test_data.loc[:, 'site_id'])

test_data.loc[:,RSSI_FEATS] = ss.transform(test_data.loc[:,RSSI_FEATS])

In [14]:
# Number of distinct encoded sites; sizes the site embedding table.
site_count = data['site_id'].nunique(dropna=False)
# The CV splitter yields positional indices that are used with .loc, so the index must be 0..n-1.
data = data.reset_index(drop=True)

In [15]:
# Seed python/numpy/torch RNGs (via seed_everything) before any model is built.
seed_everything(SEED)

## Some EDA

In [16]:
#plt.figure(figsize=(10,3))
#max_iter = 10
#for i, (name, group) in enumerate(data.groupby("path")):
#    sns.lineplot(data=group, y=RSSI_FEATS[0], x=range(group.shape[0]))
#    if i > max_iter:
#        break
#plt.figure(figsize=(10,3))
#for i, (name, group) in enumerate(data.groupby("path")):
#    sns.lineplot(data=group, y="x", x=range(group.shape[0]))
#    if i > max_iter:
#        break
#plt.figure(figsize=(10,3))
#for i, (name, group) in enumerate(data.groupby("path")):
#    sns.lineplot(data=group, y="y", x=range(group.shape[0]))
#    if i > max_iter:
#        break

In [17]:
#tmp = data.loc[:,RSSI_FEATS]
#tmp.head(10)

## Modeling

In [18]:
class IndoorDataset(Dataset):
    """Map-style dataset over a preprocessed WiFi feature DataFrame.

    Each item is a dict with 'BSSID_FEATS' and 'RSSI_FEATS' (float arrays over
    the BSSID_FEATS / RSSI_FEATS columns) and 'site_id' (int). When
    ``flag == 'TRAIN'`` it also carries the 'x', 'y' and 'floor' targets.
    Items are addressed by position (row ``index`` of ``data``), as with iloc.

    The needed columns are converted to numpy arrays once in ``__init__``. The
    previous version called ``DataFrame.iloc[index]`` for every item, which
    builds a mixed-dtype row Series on each call and is slow.
    """

    def __init__(self, data, flag='TRAIN'):
        self.data = data
        self.flag = flag
        # Positional arrays: row i corresponds to data.iloc[i].
        self._bssid = data[BSSID_FEATS].values.astype(float)
        self._rssi = data[RSSI_FEATS].values.astype(float)
        self._site_id = data['site_id'].values.astype(int)
        if flag == 'TRAIN':
            # Targets exist only for training/validation frames.
            self._x = data['x'].values
            self._y = data['y'].values
            self._floor = data['floor'].values

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        item = {
            'BSSID_FEATS': self._bssid[index],
            'RSSI_FEATS': self._rssi[index],
            'site_id': self._site_id[index],
        }
        if self.flag == 'TRAIN':
            item['x'] = self._x[index]
            item['y'] = self._y[index]
            item['floor'] = self._floor[index]
        return item

In [19]:
class SimpleLSTM(nn.Module):
    """WiFi-fingerprint regressor: embeddings + MLP + two stacked LSTMs.

    Input dict: 'BSSID_FEATS' (long ids, (B, NUM_FEATS)), 'RSSI_FEATS'
    (float, (B, NUM_FEATS)) and 'site_id' (long, (B,)).
    Returns ``(xy, floor)`` with shapes (B, 2) and (B, 1).

    The LSTMs run over a sequence of length 1 (see ``unsqueeze(-2)`` in
    ``forward``), so no recurrence across samples takes place.
    Uses the module-level ``wifi_bssids_size``, ``site_count`` and ``NUM_FEATS``.
    """

    def __init__(self, embedding_dim = 64, seq_len=20):
        super(SimpleLSTM, self).__init__()
        # seq_len is currently unused; kept for interface compatibility.
        site_emb_dim = 2
        # Width of the concatenated features: BSSID embeddings (NUM_FEATS * embedding_dim)
        # + site embedding + projected RSSI (NUM_FEATS * embedding_dim).
        # This was hard-coded to 2562, which is only right for embedding_dim=64, NUM_FEATS=20.
        concat_dim = 2 * NUM_FEATS * embedding_dim + site_emb_dim
        self.emb_BSSID_FEATS = nn.Embedding(wifi_bssids_size, embedding_dim)
        self.emb_site_id = nn.Embedding(site_count, site_emb_dim)
        # Single-layer LSTMs: nn.LSTM's `dropout` only acts between stacked layers,
        # so the former dropout=0.3/0.1 had no effect and only triggered a warning.
        self.lstm1 = nn.LSTM(input_size=256, hidden_size=128, bidirectional=False)
        self.lstm2 = nn.LSTM(input_size=128, hidden_size=16, bidirectional=False)
        self.lr = nn.Linear(NUM_FEATS, NUM_FEATS * embedding_dim)
        self.lr1 = nn.Linear(concat_dim, 256)
        self.lr_xy = nn.Linear(16, 2)
        self.lr_floor = nn.Linear(16, 1)
        self.batch_norm1 = nn.BatchNorm1d(NUM_FEATS)
        self.batch_norm2 = nn.BatchNorm1d(concat_dim)
        self.batch_norm3 = nn.BatchNorm1d(1)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # BSSID ids -> embeddings flattened to (B, NUM_FEATS * embedding_dim).
        x_bssid = self.emb_BSSID_FEATS(x['BSSID_FEATS'])
        x_bssid = torch.flatten(x_bssid, start_dim=-2)

        # Site id -> (B, 2) embedding (the flatten is a no-op for this shape).
        x_site_id = self.emb_site_id(x['site_id'])
        x_site_id = torch.flatten(x_site_id, start_dim=-1)
        # RSSI -> batch-normalized, projected to NUM_FEATS * embedding_dim.
        x_rssi = self.batch_norm1(x['RSSI_FEATS'])
        x_rssi = self.lr(x_rssi)
        x_rssi = torch.relu(x_rssi)

        x = torch.cat([x_bssid, x_site_id, x_rssi], dim=-1)
        x = self.batch_norm2(x)
        x = self.dropout(x)
        x = torch.relu(self.lr1(x))

        # (B, 256) -> (B, 1, 256). nn.LSTM defaults to (seq, batch, feat), hence the transposes.
        x = x.unsqueeze(-2)
        x = self.batch_norm3(x)
        x = x.transpose(0, 1)
        x, _ = self.lstm1(x)
        x = x.transpose(0, 1)
        x = torch.relu(x)
        x = x.transpose(0, 1)
        x, _ = self.lstm2(x)
        x = x.transpose(0, 1)
        x = torch.relu(x)
        xy = self.lr_xy(x)
        floor = self.lr_floor(x)
        floor = torch.relu(floor)
        return xy.squeeze(-2), floor.squeeze(-2)

In [20]:
def evaluate(net, data_loader,  device='cuda'):
    """Compute the competition metric of ``net`` over ``data_loader``.

    Each batch must provide the inputs 'BSSID_FEATS', 'RSSI_FEATS', 'site_id'
    and the targets 'x', 'y', 'floor'. Returns the value of ``comp_metric``
    (mean position error plus floor penalty) over all samples.
    """
    net.to(device)
    net.eval()
    xs, ys, floors = [], [], []
    pred_xs, pred_ys, pred_floors = [], [], []
    # Inference only: do not build the autograd graph.
    with torch.no_grad():
        for d in tqdm(data_loader, position=0):
            data_dict = {
                'BSSID_FEATS': d['BSSID_FEATS'].to(device).long(),
                'RSSI_FEATS': d['RSSI_FEATS'].to(device).float(),
                'site_id': d['site_id'].to(device).long(),
            }
            xs.append(d['x'].float().cpu().numpy())
            ys.append(d['y'].float().cpu().numpy())
            floors.append(d['floor'].long().cpu().numpy())
            xy, floor_pred = net(data_dict)
            pred_xs.append(xy[:, 0].cpu().numpy())
            pred_ys.append(xy[:, 1].cpu().numpy())
            # squeeze(-1), not squeeze(): a final batch of size 1 must stay 1-D,
            # otherwise np.concatenate fails on a 0-d array.
            pred_floors.append(floor_pred.squeeze(-1).cpu().numpy())
    # comp_metric's signature is (xhat, yhat, fhat, x, y, f): predictions first.
    return comp_metric(
        np.concatenate(pred_xs), np.concatenate(pred_ys), np.concatenate(pred_floors),
        np.concatenate(xs), np.concatenate(ys), np.concatenate(floors),
    )
def get_result(net, data_loader, device='cuda'):
    """Predict (x, y, floor) for every sample in ``data_loader``.

    Batches only need the input keys 'BSSID_FEATS', 'RSSI_FEATS' and 'site_id'.
    Returns three 1-D numpy arrays (x, y, floor) in loader order.
    """
    net.eval()
    net.to(device)
    pred_xs, pred_ys, pred_floors = [], [], []
    # Inference only: do not build the autograd graph.
    with torch.no_grad():
        for d in tqdm(data_loader, position=0):
            data_dict = {
                'BSSID_FEATS': d['BSSID_FEATS'].to(device).long(),
                'RSSI_FEATS': d['RSSI_FEATS'].to(device).float(),
                'site_id': d['site_id'].to(device).long(),
            }
            xy, floor = net(data_dict)
            pred_xs.append(xy[:, 0].cpu().numpy())
            pred_ys.append(xy[:, 1].cpu().numpy())
            # squeeze(-1) keeps a size-1 final batch 1-D for np.concatenate.
            pred_floors.append(floor.squeeze(-1).cpu().numpy())
    return np.concatenate(pred_xs), np.concatenate(pred_ys), np.concatenate(pred_floors)

## Training

In [21]:
# Cross-validated training: one SimpleLSTM per fold, early-stopped on the validation
# metric; test predictions of each fold's best model are accumulated.
score_df = pd.DataFrame()
oof = list()
predictions = list()

oof_x, oof_y, oof_f = np.zeros(data.shape[0]), np.zeros(data.shape[0]), np.zeros(data.shape[0])
preds_x, preds_y = 0, 0
preds_f_arr = np.zeros((test_data.shape[0], N_SPLITS))

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# The test loader does not depend on the fold, so build it once.
test_dataset = IndoorDataset(test_data, 'TEST')
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)

# NOTE(review): stratifying on 'path' puts rows of the same path into both the train
# and valid folds, so CV scores are likely optimistic; GroupKFold with groups=path avoids that.
kfold = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
for fold, (trn_idx, val_idx) in enumerate(kfold.split(data.loc[:, 'path'], data.loc[:, 'path'])):

    train_data = data.loc[trn_idx]
    valid_data = data.loc[val_idx]
    train_dataset = IndoorDataset(train_data)
    train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
    valid_dataset = IndoorDataset(valid_data)
    # Evaluation order does not affect the metric, so validation is not shuffled.
    valid_dataloader = DataLoader(valid_dataset, batch_size=128, shuffle=False, num_workers=2, pin_memory=True)
    net = SimpleLSTM().to(device)

    mse = nn.MSELoss().to(device)
    optim = torch.optim.Adam(net.parameters(), lr=5e-3)

    # Start at +inf (was 1000) so the first finite score always selects a model.
    # Otherwise best_model could be undefined, or silently reused from the previous fold.
    best_loss = np.inf
    best_model = None
    num_epochs = 15
    best_epoch = 0
    for epoch in range(num_epochs):
        net.train()
        losses = []
        pbar = tqdm(train_dataloader, position=0)
        for d in pbar:
            data_dict = {
                'BSSID_FEATS': d['BSSID_FEATS'].to(device).long(),
                'RSSI_FEATS': d['RSSI_FEATS'].to(device).float(),
                'site_id': d['site_id'].to(device).long(),
            }
            x = d['x'].to(device).float().unsqueeze(-1)
            y = d['y'].to(device).float().unsqueeze(-1)
            label = torch.cat([x, y], dim=-1)
            # NOTE(review): only x/y are supervised. The floor head receives no loss,
            # so the floor predictions stored in preds_f_arr come from an untrained head.
            xy, _ = net(data_dict)
            loss = mse(xy, label)
            loss.backward()
            optim.step()
            optim.zero_grad()
            # Per-batch gc.collect()/torch.cuda.empty_cache() and the manual detach/del
            # were removed: they add overhead to every step, and the tensors are released
            # when the loop variables are rebound on the next iteration.
            losses.append(loss.item())
            pbar.set_description(f'loss:{np.mean(losses)}')
        score = evaluate(net, valid_dataloader, device)
        if score < best_loss:
            best_loss = score
            best_epoch = epoch
            best_model = copy.deepcopy(net)
        # Early stopping: give up after more than 2 epochs without improvement.
        if best_epoch + 2 < epoch:
            break
        print("*="*50)
        print(f"fold {fold} EPOCH {epoch}: mean position error {score}")
        print("*="*50)
    if best_model is None:
        raise RuntimeError(f'fold {fold}: no finite validation score (training diverged?)')
    test_x, test_y, test_floor = get_result(best_model, test_dataloader, device)
    preds_f_arr[:, fold] = test_floor
    preds_x += test_x
    preds_y += test_y

  "num_layers={}".format(dropout, num_layers))
  "num_layers={}".format(dropout, num_layers))
loss:10649.4931640625: 100%|██████████| 1614/1614 [03:48<00:00,  7.06it/s]
100%|██████████| 404/404 [00:48<00:00,  8.37it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 0 EPOCH 0: mean position error 122.13354976976294
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:5204.7900390625: 100%|██████████| 1614/1614 [03:53<00:00,  6.90it/s]
100%|██████████| 404/404 [00:48<00:00,  8.27it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 0 EPOCH 1: mean position error 105.48808150108502
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:3217.663330078125: 100%|██████████| 1614/1614 [03:54<00:00,  6.89it/s]
100%|██████████| 404/404 [00:48<00:00,  8.34it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 0 EPOCH 2: mean position error 80.82157851307761
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:2274.22216796875: 100%|██████████| 1614/1614 [03:51<00:00,  6.96it/s]
100%|██████████| 404/404 [00:47<00:00,  8.46it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 0 EPOCH 3: mean position error 75.83896452537705
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:2004.03271484375: 100%|██████████| 1614/1614 [03:50<00:00,  6.99it/s]
100%|██████████| 404/404 [00:47<00:00,  8.42it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 0 EPOCH 4: mean position error 73.24200950169418
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:1693.78271484375: 100%|██████████| 1614/1614 [03:50<00:00,  6.99it/s]
100%|██████████| 404/404 [00:48<00:00,  8.39it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 0 EPOCH 5: mean position error 59.62080650542593
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:869.453857421875: 100%|██████████| 1614/1614 [03:53<00:00,  6.91it/s]
100%|██████████| 404/404 [00:48<00:00,  8.33it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 0 EPOCH 6: mean position error 46.40861018524598
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:432.6205139160156: 100%|██████████| 1614/1614 [03:54<00:00,  6.88it/s]
100%|██████████| 404/404 [00:48<00:00,  8.35it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 0 EPOCH 7: mean position error 39.31835805404489
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:252.3743438720703: 100%|██████████| 1614/1614 [03:51<00:00,  6.98it/s]
100%|██████████| 404/404 [00:46<00:00,  8.62it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 0 EPOCH 8: mean position error 36.034661425334384
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:165.03794860839844: 100%|██████████| 1614/1614 [03:40<00:00,  7.33it/s]
100%|██████████| 404/404 [00:47<00:00,  8.46it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 0 EPOCH 9: mean position error 34.22622668644303
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:116.21392822265625: 100%|██████████| 1614/1614 [03:51<00:00,  6.98it/s]
100%|██████████| 404/404 [00:48<00:00,  8.40it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 0 EPOCH 10: mean position error 33.41744001000553
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:86.0011215209961: 100%|██████████| 1614/1614 [03:51<00:00,  6.98it/s]
100%|██████████| 404/404 [00:47<00:00,  8.50it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 0 EPOCH 11: mean position error 33.238851698537644
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:65.72671508789062: 100%|██████████| 1614/1614 [03:50<00:00,  6.99it/s]
100%|██████████| 404/404 [00:47<00:00,  8.49it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 0 EPOCH 12: mean position error 32.10782973381179
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:52.502498626708984: 100%|██████████| 1614/1614 [03:51<00:00,  6.97it/s]
100%|██████████| 404/404 [00:47<00:00,  8.46it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 0 EPOCH 13: mean position error 31.750645929402538
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:44.3199348449707: 100%|██████████| 1614/1614 [03:50<00:00,  6.99it/s]
100%|██████████| 404/404 [00:47<00:00,  8.42it/s]
  0%|          | 0/80 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 0 EPOCH 14: mean position error 31.57361996569369
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


100%|██████████| 80/80 [00:08<00:00,  9.46it/s]
loss:11824.9765625: 100%|██████████| 1614/1614 [03:50<00:00,  7.00it/s]
100%|██████████| 404/404 [00:47<00:00,  8.45it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 1 EPOCH 0: mean position error 135.56689764790846
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:6392.1591796875: 100%|██████████| 1614/1614 [03:51<00:00,  6.97it/s]
100%|██████████| 404/404 [00:47<00:00,  8.54it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 1 EPOCH 1: mean position error 110.59719232687708
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:4492.55078125: 100%|██████████| 1614/1614 [03:50<00:00,  7.01it/s]
100%|██████████| 404/404 [00:47<00:00,  8.48it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 1 EPOCH 2: mean position error 105.83741958878433
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:3314.332763671875: 100%|██████████| 1614/1614 [03:50<00:00,  7.01it/s]
100%|██████████| 404/404 [00:47<00:00,  8.48it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 1 EPOCH 3: mean position error 80.54844784583537
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:2187.04150390625: 100%|██████████| 1614/1614 [03:53<00:00,  6.90it/s]
100%|██████████| 404/404 [00:48<00:00,  8.27it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 1 EPOCH 4: mean position error 75.54922280667381
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:1967.072998046875: 100%|██████████| 1614/1614 [03:50<00:00,  6.99it/s]
100%|██████████| 404/404 [00:47<00:00,  8.50it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 1 EPOCH 5: mean position error 73.22001979477763
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:1660.419189453125: 100%|██████████| 1614/1614 [03:44<00:00,  7.20it/s]
100%|██████████| 404/404 [00:47<00:00,  8.48it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 1 EPOCH 6: mean position error 59.0660053281634
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:855.2752685546875: 100%|██████████| 1614/1614 [03:49<00:00,  7.02it/s]
100%|██████████| 404/404 [00:47<00:00,  8.43it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 1 EPOCH 7: mean position error 47.169015062472596
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:486.5459289550781: 100%|██████████| 1614/1614 [03:51<00:00,  6.98it/s]
100%|██████████| 404/404 [00:47<00:00,  8.46it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 1 EPOCH 8: mean position error 41.14009145557175
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:302.10565185546875: 100%|██████████| 1614/1614 [03:48<00:00,  7.08it/s]
100%|██████████| 404/404 [00:46<00:00,  8.61it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 1 EPOCH 9: mean position error 37.45223076672644
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:202.1479034423828: 100%|██████████| 1614/1614 [03:47<00:00,  7.10it/s]
100%|██████████| 404/404 [00:47<00:00,  8.55it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 1 EPOCH 10: mean position error 35.45362831687366
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:144.0344696044922: 100%|██████████| 1614/1614 [03:54<00:00,  6.87it/s]
100%|██████████| 404/404 [00:48<00:00,  8.32it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 1 EPOCH 11: mean position error 34.14314233164086
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:107.1541976928711: 100%|██████████| 1614/1614 [03:52<00:00,  6.94it/s]
100%|██████████| 404/404 [00:47<00:00,  8.46it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 1 EPOCH 12: mean position error 33.24376994088509
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:82.7732162475586: 100%|██████████| 1614/1614 [03:54<00:00,  6.88it/s]
100%|██████████| 404/404 [00:48<00:00,  8.34it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 1 EPOCH 13: mean position error 32.6818482392654
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:65.92532348632812: 100%|██████████| 1614/1614 [03:53<00:00,  6.92it/s]
100%|██████████| 404/404 [00:48<00:00,  8.37it/s]
  0%|          | 0/80 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 1 EPOCH 14: mean position error 32.57752543632187
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


100%|██████████| 80/80 [00:08<00:00,  9.29it/s]
loss:10350.080078125: 100%|██████████| 1614/1614 [03:54<00:00,  6.90it/s]
100%|██████████| 404/404 [00:48<00:00,  8.37it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 2 EPOCH 0: mean position error 120.16523332073037
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:4979.11572265625: 100%|██████████| 1614/1614 [03:53<00:00,  6.90it/s]
100%|██████████| 404/404 [00:48<00:00,  8.39it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 2 EPOCH 1: mean position error 106.01154941048519
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:3528.6865234375: 100%|██████████| 1614/1614 [03:44<00:00,  7.20it/s]
100%|██████████| 404/404 [00:47<00:00,  8.46it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 2 EPOCH 2: mean position error 81.90382374315126
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:2256.374755859375: 100%|██████████| 1614/1614 [03:52<00:00,  6.95it/s]
100%|██████████| 404/404 [00:48<00:00,  8.25it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 2 EPOCH 3: mean position error 75.5405467931543
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:1979.998779296875: 100%|██████████| 1614/1614 [03:54<00:00,  6.89it/s]
100%|██████████| 404/404 [00:49<00:00,  8.20it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 2 EPOCH 4: mean position error 72.5182867570912
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:1540.78564453125: 100%|██████████| 1614/1614 [03:58<00:00,  6.78it/s]
100%|██████████| 404/404 [00:49<00:00,  8.20it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 2 EPOCH 5: mean position error 56.78940869284181
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:717.82666015625: 100%|██████████| 1614/1614 [03:57<00:00,  6.81it/s]
100%|██████████| 404/404 [00:48<00:00,  8.26it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 2 EPOCH 6: mean position error 44.17589678666995
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:371.7783508300781: 100%|██████████| 1614/1614 [04:04<00:00,  6.61it/s]
100%|██████████| 404/404 [00:49<00:00,  8.08it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 2 EPOCH 7: mean position error 38.99205461159341
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:225.0804443359375: 100%|██████████| 1614/1614 [04:09<00:00,  6.48it/s]
100%|██████████| 404/404 [00:50<00:00,  8.07it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 2 EPOCH 8: mean position error 36.40538456139481
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:152.44215393066406: 100%|██████████| 1614/1614 [04:07<00:00,  6.53it/s]
100%|██████████| 404/404 [00:49<00:00,  8.23it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 2 EPOCH 9: mean position error 34.690769959420834
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:107.5768051147461: 100%|██████████| 1614/1614 [04:10<00:00,  6.44it/s]
100%|██████████| 404/404 [00:50<00:00,  8.02it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 2 EPOCH 10: mean position error 33.75066040891019
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:79.41162872314453: 100%|██████████| 1614/1614 [04:11<00:00,  6.43it/s]
100%|██████████| 404/404 [00:50<00:00,  8.05it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 2 EPOCH 11: mean position error 32.86444705869539
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:61.861541748046875: 100%|██████████| 1614/1614 [04:10<00:00,  6.43it/s]
100%|██████████| 404/404 [00:49<00:00,  8.08it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 2 EPOCH 12: mean position error 32.42684752720248
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:50.93075180053711: 100%|██████████| 1614/1614 [04:09<00:00,  6.47it/s]
100%|██████████| 404/404 [00:49<00:00,  8.13it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 2 EPOCH 13: mean position error 32.11856905150767
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:44.82150650024414: 100%|██████████| 1614/1614 [04:03<00:00,  6.62it/s]
100%|██████████| 404/404 [00:49<00:00,  8.19it/s]
  0%|          | 0/80 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 2 EPOCH 14: mean position error 31.92470166188767
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


100%|██████████| 80/80 [00:08<00:00,  8.96it/s]
loss:10699.3349609375: 100%|██████████| 1614/1614 [04:06<00:00,  6.56it/s]
100%|██████████| 404/404 [00:50<00:00,  8.08it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 3 EPOCH 0: mean position error 123.35757341018399
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:5224.74853515625: 100%|██████████| 1614/1614 [04:06<00:00,  6.56it/s]
100%|██████████| 404/404 [00:49<00:00,  8.09it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 3 EPOCH 1: mean position error 106.48153414531136
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:4216.7587890625: 100%|██████████| 1614/1614 [04:05<00:00,  6.57it/s]
100%|██████████| 404/404 [00:49<00:00,  8.16it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 3 EPOCH 2: mean position error 106.39971158074755
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:3598.45751953125: 100%|██████████| 1614/1614 [04:05<00:00,  6.57it/s]
100%|██████████| 404/404 [00:49<00:00,  8.15it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 3 EPOCH 3: mean position error 79.43776648424223
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:2047.4468994140625: 100%|██████████| 1614/1614 [04:02<00:00,  6.67it/s]
100%|██████████| 404/404 [00:49<00:00,  8.11it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 3 EPOCH 4: mean position error 72.635699953676
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:1466.3330078125: 100%|██████████| 1614/1614 [04:06<00:00,  6.54it/s]
100%|██████████| 404/404 [00:49<00:00,  8.13it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 3 EPOCH 5: mean position error 54.851090287474904
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:690.8379516601562: 100%|██████████| 1614/1614 [04:07<00:00,  6.53it/s]
100%|██████████| 404/404 [00:49<00:00,  8.11it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 3 EPOCH 6: mean position error 43.808218331580264
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:367.17108154296875: 100%|██████████| 1614/1614 [04:07<00:00,  6.52it/s]
100%|██████████| 404/404 [00:50<00:00,  8.08it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 3 EPOCH 7: mean position error 38.66201499452213
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:221.56056213378906: 100%|██████████| 1614/1614 [04:08<00:00,  6.49it/s]
100%|██████████| 404/404 [00:50<00:00,  8.04it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 3 EPOCH 8: mean position error 35.668366576125344
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:148.7327423095703: 100%|██████████| 1614/1614 [04:08<00:00,  6.50it/s]
100%|██████████| 404/404 [00:50<00:00,  8.07it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 3 EPOCH 9: mean position error 33.934665685207015
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:104.59144592285156: 100%|██████████| 1614/1614 [04:07<00:00,  6.52it/s]
100%|██████████| 404/404 [00:50<00:00,  8.07it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 3 EPOCH 10: mean position error 32.898486787868904
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:77.31887817382812: 100%|██████████| 1614/1614 [04:08<00:00,  6.50it/s]
100%|██████████| 404/404 [00:50<00:00,  8.04it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 3 EPOCH 11: mean position error 32.146625778205745
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:60.021995544433594: 100%|██████████| 1614/1614 [04:06<00:00,  6.54it/s]
100%|██████████| 404/404 [00:50<00:00,  8.07it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 3 EPOCH 12: mean position error 32.08572812100877
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:48.21564865112305: 100%|██████████| 1614/1614 [04:08<00:00,  6.50it/s]
100%|██████████| 404/404 [00:50<00:00,  8.04it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 3 EPOCH 13: mean position error 31.66145778187287
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:40.604530334472656: 100%|██████████| 1614/1614 [04:08<00:00,  6.49it/s]
100%|██████████| 404/404 [00:49<00:00,  8.14it/s]
  0%|          | 0/80 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 3 EPOCH 14: mean position error 31.285068559389895
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


100%|██████████| 80/80 [00:09<00:00,  8.83it/s]
loss:11042.525390625: 100%|██████████| 1614/1614 [04:08<00:00,  6.51it/s]
100%|██████████| 404/404 [00:50<00:00,  8.05it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 4 EPOCH 0: mean position error 124.79008050274763
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:5524.2578125: 100%|██████████| 1614/1614 [04:04<00:00,  6.60it/s]
100%|██████████| 404/404 [00:49<00:00,  8.12it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 4 EPOCH 1: mean position error 105.10459801298077
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:4265.3623046875: 100%|██████████| 1614/1614 [04:10<00:00,  6.43it/s]
100%|██████████| 404/404 [00:50<00:00,  8.02it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 4 EPOCH 2: mean position error 104.29229740910581
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:3292.40087890625: 100%|██████████| 1614/1614 [04:11<00:00,  6.43it/s]
100%|██████████| 404/404 [00:50<00:00,  8.05it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 4 EPOCH 3: mean position error 78.34838099329566
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:2133.102294921875: 100%|██████████| 1614/1614 [04:11<00:00,  6.42it/s]
100%|██████████| 404/404 [00:50<00:00,  8.02it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 4 EPOCH 4: mean position error 74.16114994128318
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:1925.05859375: 100%|██████████| 1614/1614 [04:11<00:00,  6.42it/s]
100%|██████████| 404/404 [00:49<00:00,  8.10it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 4 EPOCH 5: mean position error 71.88535744884949
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:1623.031005859375: 100%|██████████| 1614/1614 [04:12<00:00,  6.40it/s]
100%|██████████| 404/404 [00:50<00:00,  8.06it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 4 EPOCH 6: mean position error 57.61737476669212
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:793.8807373046875: 100%|██████████| 1614/1614 [04:11<00:00,  6.41it/s]
100%|██████████| 404/404 [00:50<00:00,  8.06it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 4 EPOCH 7: mean position error 45.109456845776435
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:423.9193420410156: 100%|██████████| 1614/1614 [04:11<00:00,  6.43it/s]
100%|██████████| 404/404 [00:50<00:00,  8.07it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 4 EPOCH 8: mean position error 39.189449994154764
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:263.799072265625: 100%|██████████| 1614/1614 [04:09<00:00,  6.46it/s]
100%|██████████| 404/404 [00:50<00:00,  8.04it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 4 EPOCH 9: mean position error 35.92887765573488
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:181.4266815185547: 100%|██████████| 1614/1614 [04:12<00:00,  6.40it/s]
100%|██████████| 404/404 [00:50<00:00,  8.05it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 4 EPOCH 10: mean position error 34.11004683437373
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:131.43060302734375: 100%|██████████| 1614/1614 [04:15<00:00,  6.31it/s]
100%|██████████| 404/404 [00:50<00:00,  8.01it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 4 EPOCH 11: mean position error 33.14419376877928
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:100.43680572509766: 100%|██████████| 1614/1614 [04:13<00:00,  6.37it/s]
100%|██████████| 404/404 [00:50<00:00,  8.05it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 4 EPOCH 12: mean position error 32.41727367187568
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:79.07003021240234: 100%|██████████| 1614/1614 [04:09<00:00,  6.48it/s]
100%|██████████| 404/404 [00:50<00:00,  8.00it/s]
  0%|          | 0/1614 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 4 EPOCH 13: mean position error 31.963135911421755
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


loss:64.64217376708984: 100%|██████████| 1614/1614 [04:13<00:00,  6.36it/s]
100%|██████████| 404/404 [00:50<00:00,  8.05it/s]
  0%|          | 0/80 [00:00<?, ?it/s]

*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=
fold 4 EPOCH 14: mean position error 31.56452147766356
*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=


100%|██████████| 80/80 [00:08<00:00,  9.03it/s]


In [24]:
# Aggregate the per-fold test predictions and report the out-of-fold score.
subm = pd.read_csv('sample_submission.csv', index_col=0)

# preds_x / preds_y are accumulated over the folds that were run; average them.
n_folds_run = fold + 1
test_x = preds_x / n_folds_run
test_y = preds_y / n_folds_run

separator = "*+" * 40
print(separator)
# NOTE: if the CV loop is interrupted part-way, the OOF arrays are only partly
# filled and this score is not meaningful.
# Last-but-N columns of `data` are presumably ground-truth x, y, floor — verify column order.
gt_cols = [data.iloc[:, col].to_numpy() for col in (-5, -4, -3)]
score = comp_metric(oof_x, oof_y, oof_f, *gt_cols)
oof.append(score)
print(f"mean position error {score}")
print(separator)

# Floor = most frequent value along axis 1 of preds_f_arr (assumed one column per fold).
preds_f_mode = stats.mode(preds_f_arr, axis=1)
preds_f = preds_f_mode[0].astype(int).reshape(-1)
test_preds = pd.DataFrame(np.stack((preds_f, test_x, test_y))).T
test_preds.columns = subm.columns
test_preds.index = test_data["site_path_timestamp"]
# np.stack upcast floor to float alongside x/y; restore the integer dtype.
test_preds["floor"] = test_preds["floor"].astype(int)
predictions.append(test_preds)

*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+
mean position error 192.2107121781046
*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+*+


In [25]:
# Stack all collected prediction frames, then put rows in sample-submission order.
all_preds = pd.concat(predictions).reindex(subm.index)

In [26]:
# Replace the predicted floors with those from submission_99.csv (a separate floor
# submission). Rows are aligned on the index (first CSV column — assumed to be
# site_path_timestamp, like all_preds' index) instead of by position, so a different
# row order in that file cannot silently assign floors to the wrong rows.
simple_accurate_99 = pd.read_csv('submission_99.csv', index_col=0)
missing_keys = all_preds.index.difference(simple_accurate_99.index)
assert missing_keys.empty, f"{len(missing_keys)} rows have no floor in submission_99.csv"
all_preds['floor'] = simple_accurate_99['floor'].reindex(all_preds.index).values

In [27]:
# Persist the submission; the index (row keys) is written as the first column.
all_preds.to_csv(path_or_buf='submission.csv')

## Post Processing

In [29]:
# Shallow-clone the indoor-location-competition-20 repository; its io_f and compute_f
# modules are imported by the post-processing cells below.
# Its bundled data/ directory is not used here and is deleted (presumably to save disk space).
# NOTE(review): git clone fails if the directory already exists, so re-running this cell errors.
!git clone --depth 1 https://github.com/location-competition/indoor-location-competition-20 indoor_location_competition_20 > /dev/null
!rm -rf indoor_location_competition_20/data > /dev/null

Cloning into 'indoor_location_competition_20'...
remote: Enumerating objects: 1169, done.[K
remote: Counting objects: 100% (1169/1169), done.[K
remote: Compressing objects: 100% (1131/1131), done.[K
remote: Total 1169 (delta 38), reused 1167 (delta 38), pack-reused 0[K
Receiving objects: 100% (1169/1169), 411.37 MiB | 18.03 MiB/s, done.
Resolving deltas: 100% (38/38), done.
Checking out files: 100% (1145/1145), done.


In [36]:
!pip install --upgrade kaggle
!mkdir indoor_dataset
!kaggle competitions download -c indoor-location-navigation -p indoor_dataset > /dev/null
!unzip indoor_dataset/indoor-location-navigation.zip > /dev/null

Collecting kaggle
[?25l  Downloading https://files.pythonhosted.org/packages/3a/e7/3bac01547d2ed3d308ac92a0878fbdb0ed0f3d41fb1906c319ccbba1bfbc/kaggle-1.5.12.tar.gz (58kB)
[K     |█████▋                          | 10kB 19.2MB/s eta 0:00:01[K     |███████████▏                    | 20kB 9.6MB/s eta 0:00:01[K     |████████████████▊               | 30kB 7.4MB/s eta 0:00:01[K     |██████████████████████▎         | 40kB 6.5MB/s eta 0:00:01[K     |███████████████████████████▉    | 51kB 4.7MB/s eta 0:00:01[K     |████████████████████████████████| 61kB 3.1MB/s 
Building wheels for collected packages: kaggle
  Building wheel for kaggle (setup.py) ... [?25l[?25hdone
  Created wheel for kaggle: filename=kaggle-1.5.12-cp37-none-any.whl size=73053 sha256=aaab3af6a4db14168e6ee12031389216a43061c883485305a52d71a557860ed4
  Stored in directory: /root/.cache/pip/wheels/a1/6a/26/d30b7499ff85a4a4593377a87ecf55f7d08af42f0de9b60303
Successfully built kaggle
Installing collected packages: kaggl

In [30]:
import multiprocessing
import numpy as np
import pandas as pd
import scipy.interpolate
import scipy.sparse
from tqdm import tqdm

from indoor_location_competition_20.io_f import read_data_file
import indoor_location_competition_20.compute_f as compute_f

In [31]:
def compute_rel_positions(acce_datas, ahrs_datas):
    """Dead-reckon step-by-step relative displacements from raw IMU readings.

    Chains the ``compute_f`` helpers: detect steps in the accelerometer data,
    estimate a stride length per step, sample the AHRS-derived heading at each
    step time, and combine stride and heading into relative positions.

    Returns the result of ``compute_f.compute_rel_positions`` unchanged.
    """
    # compute_steps returns (timestamps, indices, acce max/min); indices are unused.
    step_timestamps, _step_indices, step_acce_max_mins = compute_f.compute_steps(acce_datas)
    heading_track = compute_f.compute_headings(ahrs_datas)
    strides = compute_f.compute_stride_length(step_acce_max_mins)
    headings_at_steps = compute_f.compute_step_heading(step_timestamps, heading_track)
    return compute_f.compute_rel_positions(strides, headings_at_steps)

def _imu_displacements(path, T_ref):
    """Relative (dx, dy) between consecutive reference times, from the path's IMU data.

    Reads ``{INPUT_PATH}/test/{path}.txt`` and returns a ``(len(T_ref) - 1, 2)`` array.
    """
    example = read_data_file(f'{INPUT_PATH}/test/{path}.txt')
    # Rows are (t, dx, dy): the concatenation below requires 3 columns. The reshape also
    # turns an empty "no steps detected" result into a well-formed (0, 3) array.
    rel_positions = np.asarray(compute_rel_positions(example.acce, example.ahrs)).reshape(-1, 3)
    # Anchor the cumulative track at t=0 ...
    pieces = [np.array([[0, 0, 0]]), rel_positions]
    # ... and pad a zero step at the last reference time so interp1d never extrapolates.
    if len(rel_positions) == 0 or T_ref[-1] > rel_positions[-1, 0]:
        pieces.append(np.array([[T_ref[-1], 0, 0]]))
    rel_positions = np.concatenate(pieces)

    T_rel = rel_positions[:, 0]
    cum_xy = scipy.interpolate.interp1d(T_rel, np.cumsum(rel_positions[:, 1:3], axis=0), axis=0)(T_ref)
    return np.diff(cum_xy, axis=0)


def _solve_path_smoothing(xy_hat, delta_xy_hat, T_ref, pos_sigma, step_sigma, step_sigma_per_sec):
    """Solve (A + D^T B D) xy = A xy_hat + D^T B delta_xy_hat for the smoothed positions."""
    # spsolve lives in scipy.sparse.linalg, which `import scipy.sparse` is not guaranteed to load.
    import scipy.sparse.linalg

    N = xy_hat.shape[0]
    delta_t = np.diff(T_ref)
    alpha = pos_sigma ** (-2) * np.ones(N)
    # Step uncertainty grows with the time gap; 1e-3 presumably converts ms -> s.
    beta = (step_sigma + step_sigma_per_sec * 1e-3 * delta_t) ** (-2)
    A = scipy.sparse.spdiags(alpha, [0], N, N)
    B = scipy.sparse.spdiags(beta, [0], N - 1, N - 1)
    # (D @ xy)[i] = xy[i + 1] - xy[i]
    D = scipy.sparse.spdiags(np.stack([-np.ones(N), np.ones(N)]), [0, 1], N - 1, N)

    Q = A + (D.T @ B @ D)
    c = (A @ xy_hat) + (D.T @ (B @ delta_xy_hat))
    return scipy.sparse.linalg.spsolve(Q, c)


def correct_path(args, pos_sigma=8.1, step_sigma=0.3, step_sigma_per_sec=0.3):
    """Smooth one test path's (x, y) predictions with dead-reckoned step displacements.

    Minimises, over the path positions xy,
        sum_i alpha * ||xy_i - xy_hat_i||^2
      + sum_i beta_i * ||(xy_{i+1} - xy_i) - delta_xy_hat_i||^2
    with alpha = pos_sigma**-2 and
    beta_i = (step_sigma + step_sigma_per_sec * 1e-3 * dt_i)**-2.

    Args:
        args: ``(path, path_df)`` tuple as yielded by ``DataFrame.groupby('path')``.
            ``path_df`` needs ``timestamp``, ``x``, ``y``, ``floor`` and
            ``site_path_timestamp`` columns.
        pos_sigma, step_sigma, step_sigma_per_sec: weights of the two terms. The
            defaults reproduce the original hard-coded constants.

    Returns:
        DataFrame with ``site_path_timestamp``, ``floor``, ``x``, ``y`` (rows in time order).
    """
    path, path_df = args
    # Consecutive rows are coupled by the step term, so they must be chronological.
    path_df = path_df.sort_values('timestamp', kind='mergesort')

    T_ref = path_df['timestamp'].values
    xy_hat = path_df[['x', 'y']].values

    if xy_hat.shape[0] < 2:
        # No consecutive pairs: the optimum is the prediction itself.
        xy_star = xy_hat.astype(float)
    else:
        delta_xy_hat = _imu_displacements(path, T_ref)
        xy_star = _solve_path_smoothing(xy_hat, delta_xy_hat, T_ref,
                                        pos_sigma, step_sigma, step_sigma_per_sec)

    return pd.DataFrame({
        'site_path_timestamp': path_df['site_path_timestamp'],
        'floor': path_df['floor'],
        'x': xy_star[:, 0],
        'y': xy_star[:, 1],
    })

In [37]:
INPUT_PATH = './indoor_dataset'

# Split the "site_path_timestamp" keys into components: correct_path groups by path
# and needs a numeric timestamp. Vectorised split instead of a row-wise apply.
sub = pd.read_csv('submission.csv')
key_parts = sub['site_path_timestamp'].str.split('_', expand=True)
sub['site'] = key_parts[0]
sub['path'] = key_parts[1]
sub['timestamp'] = key_parts[2].astype(float)

# Smooth every path in parallel. imap_unordered yields results in completion order,
# so the combined frame is re-sorted by key before writing.
path_groups = sub.groupby('path')
processes = multiprocessing.cpu_count()
with multiprocessing.Pool(processes=processes) as pool:
    dfs = list(tqdm(pool.imap_unordered(correct_path, path_groups), total=path_groups.ngroups))
sub = pd.concat(dfs).sort_values('site_path_timestamp')
sub.to_csv('submission_post.csv', index=False)



0it [00:00, ?it/s][A[A

1it [00:01,  1.98s/it][A[A

3it [00:02,  1.41s/it][A[A

4it [00:02,  1.24s/it][A[A

5it [00:03,  1.07s/it][A[A

6it [00:03,  1.21it/s][A[A

7it [00:04,  1.27it/s][A[A

9it [00:05,  1.48it/s][A[A

10it [00:06,  1.50it/s][A[A

11it [00:06,  1.45it/s][A[A

12it [00:07,  1.60it/s][A[A

13it [00:09,  1.18s/it][A[A

14it [00:10,  1.11it/s][A[A

15it [00:10,  1.26it/s][A[A

16it [00:12,  1.06s/it][A[A

18it [00:13,  1.03it/s][A[A

19it [00:13,  1.39it/s][A[A

20it [00:14,  1.50it/s][A[A

21it [00:14,  2.00it/s][A[A

22it [00:15,  1.32it/s][A[A

23it [00:16,  1.66it/s][A[A

24it [00:17,  1.06it/s][A[A

26it [00:18,  1.34it/s][A[A

27it [00:20,  1.18s/it][A[A

28it [00:21,  1.08it/s][A[A

30it [00:22,  1.10it/s][A[A

32it [00:23,  1.41it/s][A[A

33it [00:23,  1.68it/s][A[A

34it [00:24,  1.44it/s][A[A

35it [00:25,  1.38it/s][A[A

37it [00:25,  1.70it/s][A[A

38it [00:26,  1.77it/s][A[A

39it [00:26,  1.92

## Check Tensor Shapes

In [2]:
# Shape check: push the first rows through throw-away copies of each input branch.
# Depends on data, BSSID_FEATS, RSSI_FEATS, wifi_bssids_size, site_count, nn and torch
# from earlier cells (it raised NameError when run in a fresh kernel).

def _n_unique_in_first_cols(frame, n_cols=20):
    """Number of distinct values across the first `n_cols` columns of `frame`."""
    return len({v for i in range(n_cols) for v in frame.iloc[:, i].values.tolist()})

# BSSID branch: each id -> 64-d embedding, last two dims flattened.
tmp = data.loc[:9, BSSID_FEATS]
tmp_ids = _n_unique_in_first_cols(tmp)
_emb = nn.Embedding(wifi_bssids_size, 64)
_res = torch.flatten(_emb(torch.tensor(tmp.values.astype(float)).long()), start_dim=-2)

# RSSI branch: 20 values per row -> linear layer with 1280 outputs.
tmp2 = data.loc[:9, RSSI_FEATS]
tmp2_ids = _n_unique_in_first_cols(tmp2)
lr = nn.Linear(20, 1280)
_res2 = lr(torch.tensor(tmp2.values.astype(float)).float())

# Site branch: site_id -> 2-d embedding. Embedding a 1-D id vector already gives a
# 2-D tensor (flattening from the last dim only would be a no-op).
tmp3 = data.loc[:9, "site_id"]
_emb2 = nn.Embedding(site_count, 2)
_res3 = _emb2(torch.tensor(tmp3.values.astype(float)).long())

NameError: ignored

In [None]:
tuple(branch.shape for branch in (_res, _res2, _res3))

In [None]:
# Concatenate the three branches and project to 256 features. in_features comes from
# the concatenated width (previously hard-coded as 2562) so the check keeps working if
# a branch width changes.
_cat = torch.cat([_res, _res2, _res3], dim=-1)
_all = nn.Linear(_cat.size(-1), 256)(_cat)

In [None]:
_all.shape

In [None]:
# Insert a length-1 axis just before the feature dimension.
_un = torch.unsqueeze(_all, dim=-2)
_un.shape

In [None]:
# Swap the first two axes: nn.LSTM defaults to batch_first=False, i.e. (seq, batch, feature).
_tr = torch.transpose(_un, 0, 1)
_tr.shape

In [None]:
# Single-layer, unidirectional LSTM over the (seq, batch, feature) tensor.
# dropout=0.3 was removed: nn.LSTM applies dropout only between stacked layers, so with
# the default num_layers=1 it had no effect and only emitted a UserWarning.
# TODO: if the real model uses the same config, use num_layers > 1 or drop dropout there too.
_lstm1 = nn.LSTM(input_size=256, hidden_size=128, bidirectional=False)
_ls1, _ = _lstm1(_tr)
_ls1.size()

In [None]:
# Back to batch-first, then project the 128 hidden units of each step to (x, y).
_xy = nn.Linear(128, 2)(torch.transpose(_ls1, 0, 1))
_xy.shape

In [None]:
torch.squeeze(_xy, -2).shape