### Imports

In [1]:
import json
import numpy as np
import pandas as pd
import random
import time

from pathlib import Path
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split, StratifiedShuffleSplit
from sklearn.preprocessing import LabelBinarizer
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, Subset

from datasets import SurfaceDataset
from helpers import EarlyStopper
from models import CNNSurfaceClassifier

### Device

In [2]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

'cuda:0'

### Seed

In [3]:
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)

<torch._C.Generator at 0x16ab1f8a5f0>

### Constants

In [4]:
BATCH_SIZE = 32
INPUT_SIZE = 2
NUM_EPOCHS = 100
DATA_DIR = Path('../data/train_set/csv/')
HISTORY_DIR = Path('../results/tuning/')
LOOKBACK = 8/3
SAMPLING_FREQUENCY = 75.
DATASET_FREQUENCY = 150.
SUBSET = ('servo',)
CONFIGURATIONS = ('6W',)

### Load and split data

In [5]:
with open('../data/train_set/labels.json') as fp:
    labels = json.load(fp)

In [6]:
dataset = [(DATA_DIR.joinpath(key + '.csv'), values['surface']) for key, values in labels.items() if values['kinematics'] in CONFIGURATIONS and values['spacing'] == 'R1' and 'T1' in values['trajectory']]

In [7]:
X = pd.Series([run[0] for run in dataset], name='bag_name')
y_primary = [run[1] for run in dataset]

In [8]:
y_secondary = []
# y_secondary = ['slippery' if label in ('1_Panele', '5_Spienione_PCV', '6_Linoleum')
#                else 'grippy' if label in ('3_Wykladzina_jasna', '8_Pusta_plyta', '9_podklady')
#                else 'neutral' for label in y_primary]
# y_secondary = ['slippery' if label in ('3_Wykladzina_jasna', '4_Trawa')
#                else 'grippy' if label in ('5_Spienione_PCV', '8_Pusta_plyta', '9_podklady', '10_Mata_ukladana')
#                else 'neutral' for label in y_primary] # Pawel set
# y_secondary = ['slippery' if label in ('3_Wykladzina_jasna', '4_Trawa')
#                else 'grippy' if label in ('2_Wykladzina_czarna', '5_Spienione_PCV', '9_podklady', '10_Mata_ukladana')
#                else 'neutral' for label in y_primary] # Clustering set

In [9]:
lb = LabelBinarizer()
if y_secondary:
    lb.fit(y_secondary)
    y = lb.transform(y_secondary)
else:
    lb.fit(y_primary)
    y = lb.transform(y_primary)
classes = lb.classes_
num_classes = len(classes)
y = y.reshape(-1, num_classes)

### Custom datasets

In [10]:
cv_data = SurfaceDataset(X, y, sample_freq=SAMPLING_FREQUENCY, data_freq=DATASET_FREQUENCY, lookback=LOOKBACK, subset=SUBSET)

### Loss function

In [11]:
criterion = nn.CrossEntropyLoss()

### Training loop

In [12]:
history = {}

sss = StratifiedShuffleSplit(test_size=0.2)
for i, (training_index, test_index) in enumerate(sss.split(X, y)):
    # Initialize the model in each split
    cnn_model = CNNSurfaceClassifier(input_size=INPUT_SIZE, output_size=num_classes).to(device)
    # Initialize optimizer in each split
    optimizer = torch.optim.Adam(
        cnn_model.parameters(),
        lr=1e-3,
        eps=1e-6,
        weight_decay=1e-3,
        )
    # Initialize scheduler in each split
    scheduler = ExponentialLR(optimizer, gamma=0.9)
    # Initialize early stopping
    early_stopper = EarlyStopper()
    
    # Separate hold-out fold
    train_index, val_index = train_test_split(training_index, test_size=0.2, stratify=y[training_index])
    
    train_dataloader = DataLoader(
        Subset(cv_data, train_index),
        batch_size=BATCH_SIZE,
        worker_init_fn=seed_worker,
        generator=g,
        shuffle=True,
    )
    val_dataloader = DataLoader(
        Subset(cv_data, val_index),
        batch_size=BATCH_SIZE,
        worker_init_fn=seed_worker,
        generator=g,
    )
    test_dataloader = DataLoader(
        Subset(cv_data, test_index),
        batch_size=BATCH_SIZE,
        worker_init_fn=seed_worker,
        generator=g,
    )
    
    train_batches = len(train_dataloader)
    val_batches = len(val_dataloader)

    for epoch in range(NUM_EPOCHS):
        running_train_loss = 0.0
        running_val_loss = 0.0
        
        pbar = tqdm(train_dataloader, total=train_batches)
        cnn_model.train()
        for idx, (batch_x, batch_y) in enumerate(pbar):
            optimizer.zero_grad()
            
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            batch_x = batch_x.permute(0, 2, 1)
            train_outputs = cnn_model(batch_x)
            train_loss = criterion(train_outputs, batch_y)
            running_train_loss += train_loss
    
            # Backward pass
            train_loss.backward()
            optimizer.step()
    
            pbar.set_description(f"Fold {i + 1}/{sss.get_n_splits()}, Epoch {epoch + 1}/{NUM_EPOCHS}, Training loss: {running_train_loss / (idx + 1):.2E}")
        scheduler.step()
        
        pbar_val = tqdm(val_dataloader, total=val_batches)
        cnn_model.eval()
        with torch.no_grad():
            for idx, (batch_x_val, batch_y_val) in enumerate(pbar_val):
                batch_x_val, batch_y_val = batch_x_val.to(device), batch_y_val.to(device)
                batch_x_val = batch_x_val.permute(0, 2, 1)
                val_outputs = cnn_model(batch_x_val)
                val_loss = criterion(val_outputs, batch_y_val)
                running_val_loss += val_loss
                                
                pbar_val.set_description(f"Fold {i + 1}/{sss.get_n_splits()}, Epoch {epoch + 1}/{NUM_EPOCHS}, Validation loss: {running_val_loss / (idx + 1):.2E}")
                
        validation_loss = running_val_loss / (idx + 1)
        if early_stopper.early_stop(validation_loss):
            print(f"Split {i + 1} ended on epoch {epoch + 1 - early_stopper.patience}!")
            break
        if early_stopper.counter == 0:
            best_model = cnn_model.state_dict()
        
    cnn_model.load_state_dict(best_model)
    
    test_batches = len(test_dataloader)
    y_true, y_pred = [], []
    running_test_loss = 0.0
    
    pbar_test = tqdm(test_dataloader, total=test_batches)
    cnn_model.eval()
    with torch.no_grad():
        for idx, (batch_x_test, batch_y_test) in enumerate(pbar_test):
            batch_x_test, batch_y_test = batch_x_test.to(device), batch_y_test.to(device)
            batch_x_test = batch_x_test.permute(0, 2, 1)
            test_outputs = cnn_model(batch_x_test)
            test_loss = criterion(test_outputs, batch_y_test)
            running_test_loss += test_loss
            
            y_true.extend(torch.argmax(batch_y_test, dim=1).cpu().numpy())
            y_pred.extend(torch.argmax(test_outputs, dim=1).cpu().numpy())
            
            pbar_test.set_description(f"Fold {i + 1}/{sss.get_n_splits()}, Test loss: {running_test_loss / (idx + 1):.2E}")
    
    history[i + 1] = {'accuracy': accuracy_score(y_true, y_pred), 'f1_score': f1_score(y_true, y_pred, average='macro')}

history_filename = '_'.join((str(num_classes),) + CONFIGURATIONS + SUBSET) + '_' + time.strftime("%Y-%m-%d-%H-%M-%S")
json.dump(history, open(HISTORY_DIR / f'{history_filename}.json', 'w'))

  return F.conv1d(input, weight, bias, self.stride,
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return F.conv1d(input, weight, bias, self.stride,
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Fold 1/10, Epoch 1/100, Training loss: 1.91E+00: 100%|██████████| 7/7 [00:04<00:00,  1.54it/s]
  return F.conv1d(input, weight, bias, self.stride,
Fold 1/10, Epoch 1/100, Validation loss: 2.32E+00: 100%|██████████| 2/2 [00:01<00:00,  1.66it/s]
Fold 1/10, Epoch 2/100, Training loss: 1.59E+00: 100%|██████████| 7/7 [00:04<00:00,  1.65it/s]
Fold 1/10, Epoch 2/100, Validation loss: 2.62E+00: 100%|██████████| 2/2 [00:01<00:00,  1.94it/s]
Fold 1/10, Epoch 3/100, Training loss: 1.57E+00: 100%|██████████| 7/7 [00:04<00:00,  1.51it/s]
Fold 1/10, Epoch 3/100, Validation loss: 3.26E+00: 100%|██████████| 2/2 [00:00<00:00,  2.04it/s]
Fold 1/10, Epoch 4/100, Training loss: 1.50E+00: 100%|██████

Split 1 ended on epoch 35!


Fold 1/10, Test loss: 5.02E-01: 100%|██████████| 2/2 [00:01<00:00,  1.61it/s]
Fold 2/10, Epoch 1/100, Training loss: 1.87E+00: 100%|██████████| 7/7 [00:04<00:00,  1.68it/s]
Fold 2/10, Epoch 1/100, Validation loss: 2.32E+00: 100%|██████████| 2/2 [00:01<00:00,  1.95it/s]
Fold 2/10, Epoch 2/100, Training loss: 1.47E+00: 100%|██████████| 7/7 [00:04<00:00,  1.53it/s]
Fold 2/10, Epoch 2/100, Validation loss: 2.52E+00: 100%|██████████| 2/2 [00:01<00:00,  1.98it/s]
Fold 2/10, Epoch 3/100, Training loss: 1.32E+00: 100%|██████████| 7/7 [00:04<00:00,  1.71it/s]
Fold 2/10, Epoch 3/100, Validation loss: 3.20E+00: 100%|██████████| 2/2 [00:01<00:00,  1.55it/s]
Fold 2/10, Epoch 4/100, Training loss: 1.16E+00: 100%|██████████| 7/7 [00:04<00:00,  1.73it/s]
Fold 2/10, Epoch 4/100, Validation loss: 3.56E+00: 100%|██████████| 2/2 [00:01<00:00,  1.96it/s]
Fold 2/10, Epoch 5/100, Training loss: 1.14E+00: 100%|██████████| 7/7 [00:04<00:00,  1.64it/s]
Fold 2/10, Epoch 5/100, Validation loss: 4.05E+00: 100%|███

Split 2 ended on epoch 32!


Fold 2/10, Test loss: 5.91E-01: 100%|██████████| 2/2 [00:01<00:00,  1.30it/s]
Fold 3/10, Epoch 1/100, Training loss: 1.94E+00: 100%|██████████| 7/7 [00:04<00:00,  1.67it/s]
Fold 3/10, Epoch 1/100, Validation loss: 2.31E+00: 100%|██████████| 2/2 [00:01<00:00,  1.74it/s]
Fold 3/10, Epoch 2/100, Training loss: 1.47E+00: 100%|██████████| 7/7 [00:04<00:00,  1.72it/s]
Fold 3/10, Epoch 2/100, Validation loss: 2.49E+00: 100%|██████████| 2/2 [00:00<00:00,  2.02it/s]
Fold 3/10, Epoch 3/100, Training loss: 1.42E+00: 100%|██████████| 7/7 [00:04<00:00,  1.46it/s]
Fold 3/10, Epoch 3/100, Validation loss: 3.19E+00: 100%|██████████| 2/2 [00:01<00:00,  1.93it/s]
Fold 3/10, Epoch 4/100, Training loss: 1.49E+00: 100%|██████████| 7/7 [00:04<00:00,  1.55it/s]
Fold 3/10, Epoch 4/100, Validation loss: 3.42E+00: 100%|██████████| 2/2 [00:00<00:00,  2.01it/s]
Fold 3/10, Epoch 5/100, Training loss: 1.34E+00: 100%|██████████| 7/7 [00:04<00:00,  1.53it/s]
Fold 3/10, Epoch 5/100, Validation loss: 3.35E+00: 100%|███

Split 3 ended on epoch 47!


Fold 3/10, Test loss: 4.84E-01: 100%|██████████| 2/2 [00:01<00:00,  1.60it/s]
Fold 4/10, Epoch 1/100, Training loss: 1.85E+00: 100%|██████████| 7/7 [00:04<00:00,  1.60it/s]
Fold 4/10, Epoch 1/100, Validation loss: 2.32E+00: 100%|██████████| 2/2 [00:00<00:00,  2.00it/s]
Fold 4/10, Epoch 2/100, Training loss: 1.46E+00: 100%|██████████| 7/7 [00:04<00:00,  1.66it/s]
Fold 4/10, Epoch 2/100, Validation loss: 2.58E+00: 100%|██████████| 2/2 [00:01<00:00,  1.79it/s]
Fold 4/10, Epoch 3/100, Training loss: 1.29E+00: 100%|██████████| 7/7 [00:04<00:00,  1.64it/s]
Fold 4/10, Epoch 3/100, Validation loss: 3.26E+00: 100%|██████████| 2/2 [00:01<00:00,  1.95it/s]
Fold 4/10, Epoch 4/100, Training loss: 1.41E+00: 100%|██████████| 7/7 [00:04<00:00,  1.57it/s]
Fold 4/10, Epoch 4/100, Validation loss: 4.19E+00: 100%|██████████| 2/2 [00:01<00:00,  1.96it/s]
Fold 4/10, Epoch 5/100, Training loss: 1.31E+00: 100%|██████████| 7/7 [00:04<00:00,  1.56it/s]
Fold 4/10, Epoch 5/100, Validation loss: 4.00E+00: 100%|███

Split 4 ended on epoch 15!


Fold 4/10, Test loss: 6.93E-01: 100%|██████████| 2/2 [00:01<00:00,  1.50it/s]
Fold 5/10, Epoch 1/100, Training loss: 1.85E+00: 100%|██████████| 7/7 [00:04<00:00,  1.72it/s]
Fold 5/10, Epoch 1/100, Validation loss: 2.32E+00: 100%|██████████| 2/2 [00:01<00:00,  1.85it/s]
Fold 5/10, Epoch 2/100, Training loss: 1.47E+00: 100%|██████████| 7/7 [00:04<00:00,  1.52it/s]
Fold 5/10, Epoch 2/100, Validation loss: 2.50E+00: 100%|██████████| 2/2 [00:00<00:00,  2.05it/s]
Fold 5/10, Epoch 3/100, Training loss: 1.45E+00: 100%|██████████| 7/7 [00:04<00:00,  1.53it/s]
Fold 5/10, Epoch 3/100, Validation loss: 3.02E+00: 100%|██████████| 2/2 [00:01<00:00,  1.61it/s]
Fold 5/10, Epoch 4/100, Training loss: 1.32E+00: 100%|██████████| 7/7 [00:04<00:00,  1.54it/s]
Fold 5/10, Epoch 4/100, Validation loss: 4.01E+00: 100%|██████████| 2/2 [00:00<00:00,  2.01it/s]
Fold 5/10, Epoch 5/100, Training loss: 1.38E+00: 100%|██████████| 7/7 [00:04<00:00,  1.65it/s]
Fold 5/10, Epoch 5/100, Validation loss: 4.17E+00: 100%|███

Split 5 ended on epoch 34!


Fold 5/10, Test loss: 6.20E-01: 100%|██████████| 2/2 [00:01<00:00,  1.35it/s]
Fold 6/10, Epoch 1/100, Training loss: 1.77E+00: 100%|██████████| 7/7 [00:04<00:00,  1.69it/s]
Fold 6/10, Epoch 1/100, Validation loss: 2.32E+00: 100%|██████████| 2/2 [00:01<00:00,  1.82it/s]
Fold 6/10, Epoch 2/100, Training loss: 1.43E+00: 100%|██████████| 7/7 [00:04<00:00,  1.52it/s]
Fold 6/10, Epoch 2/100, Validation loss: 2.58E+00: 100%|██████████| 2/2 [00:00<00:00,  2.03it/s]
Fold 6/10, Epoch 3/100, Training loss: 1.54E+00: 100%|██████████| 7/7 [00:04<00:00,  1.57it/s]
Fold 6/10, Epoch 3/100, Validation loss: 3.36E+00: 100%|██████████| 2/2 [00:01<00:00,  1.48it/s]
Fold 6/10, Epoch 4/100, Training loss: 1.34E+00: 100%|██████████| 7/7 [00:04<00:00,  1.61it/s]
Fold 6/10, Epoch 4/100, Validation loss: 3.67E+00: 100%|██████████| 2/2 [00:01<00:00,  1.97it/s]
Fold 6/10, Epoch 5/100, Training loss: 1.27E+00: 100%|██████████| 7/7 [00:04<00:00,  1.61it/s]
Fold 6/10, Epoch 5/100, Validation loss: 4.30E+00: 100%|███

Split 6 ended on epoch 31!


Fold 6/10, Test loss: 6.13E-01: 100%|██████████| 2/2 [00:01<00:00,  1.58it/s]
Fold 7/10, Epoch 1/100, Training loss: 1.91E+00: 100%|██████████| 7/7 [00:04<00:00,  1.55it/s]
Fold 7/10, Epoch 1/100, Validation loss: 2.30E+00: 100%|██████████| 2/2 [00:01<00:00,  2.00it/s]
Fold 7/10, Epoch 2/100, Training loss: 1.48E+00: 100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
Fold 7/10, Epoch 2/100, Validation loss: 2.51E+00: 100%|██████████| 2/2 [00:01<00:00,  1.95it/s]
Fold 7/10, Epoch 3/100, Training loss: 1.36E+00: 100%|██████████| 7/7 [00:04<00:00,  1.65it/s]
Fold 7/10, Epoch 3/100, Validation loss: 3.24E+00: 100%|██████████| 2/2 [00:01<00:00,  1.52it/s]
Fold 7/10, Epoch 4/100, Training loss: 1.37E+00: 100%|██████████| 7/7 [00:04<00:00,  1.70it/s]
Fold 7/10, Epoch 4/100, Validation loss: 4.30E+00: 100%|██████████| 2/2 [00:00<00:00,  2.01it/s]
Fold 7/10, Epoch 5/100, Training loss: 1.43E+00: 100%|██████████| 7/7 [00:04<00:00,  1.56it/s]
Fold 7/10, Epoch 5/100, Validation loss: 3.88E+00: 100%|███

Split 7 ended on epoch 23!


Fold 7/10, Test loss: 7.07E-01: 100%|██████████| 2/2 [00:01<00:00,  1.61it/s]
Fold 8/10, Epoch 1/100, Training loss: 1.80E+00: 100%|██████████| 7/7 [00:04<00:00,  1.57it/s]
Fold 8/10, Epoch 1/100, Validation loss: 2.32E+00: 100%|██████████| 2/2 [00:00<00:00,  2.01it/s]
Fold 8/10, Epoch 2/100, Training loss: 1.55E+00: 100%|██████████| 7/7 [00:04<00:00,  1.55it/s]
Fold 8/10, Epoch 2/100, Validation loss: 2.54E+00: 100%|██████████| 2/2 [00:01<00:00,  1.91it/s]
Fold 8/10, Epoch 3/100, Training loss: 1.37E+00: 100%|██████████| 7/7 [00:04<00:00,  1.62it/s]
Fold 8/10, Epoch 3/100, Validation loss: 2.99E+00: 100%|██████████| 2/2 [00:01<00:00,  1.66it/s]
Fold 8/10, Epoch 4/100, Training loss: 1.43E+00: 100%|██████████| 7/7 [00:04<00:00,  1.74it/s]
Fold 8/10, Epoch 4/100, Validation loss: 4.19E+00: 100%|██████████| 2/2 [00:01<00:00,  1.93it/s]
Fold 8/10, Epoch 5/100, Training loss: 1.30E+00: 100%|██████████| 7/7 [00:04<00:00,  1.61it/s]
Fold 8/10, Epoch 5/100, Validation loss: 4.37E+00: 100%|███

Split 8 ended on epoch 29!


Fold 8/10, Test loss: 5.76E-01: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s]
Fold 9/10, Epoch 1/100, Training loss: 1.91E+00: 100%|██████████| 7/7 [00:04<00:00,  1.46it/s]
Fold 9/10, Epoch 1/100, Validation loss: 2.31E+00: 100%|██████████| 2/2 [00:00<00:00,  2.02it/s]
Fold 9/10, Epoch 2/100, Training loss: 1.62E+00: 100%|██████████| 7/7 [00:04<00:00,  1.61it/s]
Fold 9/10, Epoch 2/100, Validation loss: 2.45E+00: 100%|██████████| 2/2 [00:01<00:00,  1.95it/s]
Fold 9/10, Epoch 3/100, Training loss: 1.50E+00: 100%|██████████| 7/7 [00:04<00:00,  1.67it/s]
Fold 9/10, Epoch 3/100, Validation loss: 2.70E+00: 100%|██████████| 2/2 [00:00<00:00,  2.00it/s]
Fold 9/10, Epoch 4/100, Training loss: 1.33E+00: 100%|██████████| 7/7 [00:04<00:00,  1.66it/s]
Fold 9/10, Epoch 4/100, Validation loss: 3.31E+00: 100%|██████████| 2/2 [00:01<00:00,  1.65it/s]
Fold 9/10, Epoch 5/100, Training loss: 1.31E+00: 100%|██████████| 7/7 [00:04<00:00,  1.71it/s]
Fold 9/10, Epoch 5/100, Validation loss: 3.34E+00: 100%|███

Split 9 ended on epoch 38!


Fold 9/10, Test loss: 5.74E-01: 100%|██████████| 2/2 [00:01<00:00,  1.74it/s]
Fold 10/10, Epoch 1/100, Training loss: 1.82E+00: 100%|██████████| 7/7 [00:03<00:00,  1.85it/s]
Fold 10/10, Epoch 1/100, Validation loss: 2.34E+00: 100%|██████████| 2/2 [00:00<00:00,  2.11it/s]
Fold 10/10, Epoch 2/100, Training loss: 1.65E+00: 100%|██████████| 7/7 [00:03<00:00,  1.82it/s]
Fold 10/10, Epoch 2/100, Validation loss: 2.60E+00: 100%|██████████| 2/2 [00:00<00:00,  2.10it/s]
Fold 10/10, Epoch 3/100, Training loss: 1.37E+00: 100%|██████████| 7/7 [00:03<00:00,  1.90it/s]
Fold 10/10, Epoch 3/100, Validation loss: 3.70E+00: 100%|██████████| 2/2 [00:00<00:00,  2.12it/s]
Fold 10/10, Epoch 4/100, Training loss: 1.22E+00: 100%|██████████| 7/7 [00:03<00:00,  1.85it/s]
Fold 10/10, Epoch 4/100, Validation loss: 3.70E+00: 100%|██████████| 2/2 [00:00<00:00,  2.15it/s]
Fold 10/10, Epoch 5/100, Training loss: 1.06E+00: 100%|██████████| 7/7 [00:03<00:00,  1.80it/s]
Fold 10/10, Epoch 5/100, Validation loss: 4.26E+00

Split 10 ended on epoch 30!


Fold 10/10, Test loss: 5.81E-01: 100%|██████████| 2/2 [00:01<00:00,  1.52it/s]
