
### Imports

In [1]:
import json
import pandas as pd
import time

from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader

from datasets import SurfaceDataset
from helpers import EarlyStopper
from models import CNNSurfaceClassifier

### Device

In [2]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device

'cuda:0'

### Constants

In [3]:
BATCH_SIZE = 32
INPUT_SIZE = 6
NUM_EPOCHS = 100
DATA_DIR = Path('../data/train_set/csv/')
LOOKBACK = 2.
SAMPLING_FREQUENCY = 100.
DATASET_FREQUENCY = 200.
SUBSET = ('imu',)
CONFIGURATIONS = ('4W',)

### Load and split data

In [4]:
with open('../data/labels.json') as fp:
    labels = json.load(fp)

In [5]:
dataset = [(DATA_DIR.joinpath(key + '.csv'), values['surface']) for key, values in labels.items() if values['kinematics'] in CONFIGURATIONS and values['spacing'] == 'R1' and 'T1' in values['trajectory']]

In [6]:
X = pd.Series([run[0] for run in dataset], name='bag_name')
y_primary = [run[1] for run in dataset]

In [7]:
# y_secondary = ['slippery' if label in ('1_Panele', '5_Spienione_PCV', '6_Linoleum')
#                else 'grippy' if label in ('3_Wykladzina_jasna', '8_Pusta_plyta', '9_podklady')
#                else 'neutral' for label in y_primary]
y_secondary = ['slippery' if label in ('3_Wykladzina_jasna', '4_Trawa')
               else 'grippy' if label in ('5_Spienione_PCV', '8_Pusta_plyta', '9_podklady', '10_Mata_ukladana')
               else 'neutral' for label in y_primary] # Pawel set
# y_secondary = ['slippery' if label in ('3_Wykladzina_jasna', '4_Trawa')
#                else 'grippy' if label in ('2_Wykladzina_czarna', '5_Spienione_PCV', '9_podklady', '10_Mata_ukladana')
#                else 'neutral' for label in y_primary] # Clustering set

In [8]:
lb = LabelBinarizer()
lb.fit(y_primary)
classes = lb.classes_
num_classes = len(classes)
y = lb.transform(y_primary)
y = y.reshape(-1, num_classes)

In [9]:
print(classes)

['10_Mata_ukladana' '1_Panele' '2_Wykladzina_czarna' '3_Wykladzina_jasna'
 '4_Trawa' '5_Spienione_PCV' '6_Linoleum' '7_Plytki_w_sali'
 '8_Pusta_plyta' '9_podklady']


In [10]:
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, stratify=y)
X_train.reset_index(drop=True, inplace=True)
X_val.reset_index(drop=True, inplace=True)

### Custom datasets

In [11]:
train_dataloader = DataLoader(SurfaceDataset(X_train, y_train, sample_freq=SAMPLING_FREQUENCY, data_freq=DATASET_FREQUENCY, lookback=LOOKBACK, subset=SUBSET), batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(SurfaceDataset(X_val, y_val, sample_freq=SAMPLING_FREQUENCY, data_freq=DATASET_FREQUENCY, lookback=LOOKBACK, subset=SUBSET), batch_size=BATCH_SIZE)

### Models

In [12]:
cnn_model = CNNSurfaceClassifier(input_size=INPUT_SIZE, output_size=num_classes).to(device)

### Optimizer

In [13]:
optimizer = torch.optim.Adam(
    cnn_model.parameters(),
    lr=1e-3,
    eps=1e-6,
    weight_decay=1e-3,
    )

### Scheduler

In [14]:
scheduler = ExponentialLR(optimizer, gamma=0.9)

### Early stopping

In [15]:
early_stopper = EarlyStopper()

### Loss function

In [16]:
criterion = nn.CrossEntropyLoss()

### Training loop

In [17]:
train_batches = len(train_dataloader)
val_batches = len(val_dataloader)

for epoch in range(NUM_EPOCHS):
    running_train_loss = 0.0
    running_val_loss = 0.0
    
    pbar = tqdm(train_dataloader, total=train_batches)
    cnn_model.train()
    for idx, (batch_x, batch_y) in enumerate(pbar):
        optimizer.zero_grad()
        
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        batch_x = batch_x.permute(0, 2, 1)
        train_outputs = cnn_model(batch_x)
        train_loss = criterion(train_outputs, batch_y)
        running_train_loss += train_loss

        # Backward pass
        train_loss.backward()
        optimizer.step()

        pbar.set_description(f"Epoch {epoch + 1}/{NUM_EPOCHS}, Training loss: {running_train_loss / (idx + 1):.2E}")
    scheduler.step()

    pbar_val = tqdm(val_dataloader, total=val_batches)
    cnn_model.eval()
    with torch.no_grad():
        for idx, (batch_x_val, batch_y_val) in enumerate(pbar_val):
            batch_x_val, batch_y_val = batch_x_val.to(device), batch_y_val.to(device)
            batch_x_val = batch_x_val.permute(0, 2, 1)
            val_outputs = cnn_model(batch_x_val)
            val_loss = criterion(val_outputs, batch_y_val)
            running_val_loss += val_loss
            
            pbar_val.set_description(f"Epoch {epoch + 1}/{NUM_EPOCHS}, Validation loss: {running_val_loss / (idx + 1):.2E}")
            
    validation_loss = running_val_loss / (idx + 1)
    if early_stopper.early_stop(validation_loss):
        print(f"Split ended on epoch {epoch + 1 - early_stopper.patience}!")
        break
    if early_stopper.counter == 0:
        print(True)
        best_model = cnn_model.state_dict()

config =  '_'.join(SUBSET + CONFIGURATIONS) + '_' + time.strftime("%Y-%m-%d-%H-%M-%S")
torch.save(best_model, f"../data/checkpoints/cnn_classifier_{config}.pt")

Epoch 1/100, Training loss: 1.93E+00: 100%|██████████| 9/9 [00:08<00:00,  1.09it/s]
Epoch 1/100, Validation loss: 2.25E+00: 100%|██████████| 3/3 [00:01<00:00,  1.97it/s]


True


Epoch 2/100, Training loss: 1.52E+00: 100%|██████████| 9/9 [00:05<00:00,  1.63it/s]
Epoch 2/100, Validation loss: 2.03E+00: 100%|██████████| 3/3 [00:01<00:00,  2.23it/s]


True


Epoch 3/100, Training loss: 1.34E+00: 100%|██████████| 9/9 [00:05<00:00,  1.62it/s]
Epoch 3/100, Validation loss: 1.66E+00: 100%|██████████| 3/3 [00:01<00:00,  2.22it/s]


True


Epoch 4/100, Training loss: 1.20E+00: 100%|██████████| 9/9 [00:05<00:00,  1.58it/s]
Epoch 4/100, Validation loss: 1.46E+00: 100%|██████████| 3/3 [00:01<00:00,  2.21it/s]


True


Epoch 5/100, Training loss: 1.18E+00: 100%|██████████| 9/9 [00:05<00:00,  1.62it/s]
Epoch 5/100, Validation loss: 1.26E+00: 100%|██████████| 3/3 [00:01<00:00,  2.22it/s]


True


Epoch 6/100, Training loss: 1.09E+00: 100%|██████████| 9/9 [00:05<00:00,  1.66it/s]
Epoch 6/100, Validation loss: 9.69E-01: 100%|██████████| 3/3 [00:01<00:00,  2.19it/s]


True


Epoch 7/100, Training loss: 1.01E+00: 100%|██████████| 9/9 [00:05<00:00,  1.63it/s]
Epoch 7/100, Validation loss: 8.50E-01: 100%|██████████| 3/3 [00:01<00:00,  2.22it/s]


True


Epoch 8/100, Training loss: 8.58E-01: 100%|██████████| 9/9 [00:05<00:00,  1.63it/s]
Epoch 8/100, Validation loss: 7.91E-01: 100%|██████████| 3/3 [00:01<00:00,  2.22it/s]


True


Epoch 9/100, Training loss: 8.08E-01: 100%|██████████| 9/9 [00:05<00:00,  1.62it/s]
Epoch 9/100, Validation loss: 1.06E+00: 100%|██████████| 3/3 [00:01<00:00,  2.24it/s]
Epoch 10/100, Training loss: 8.65E-01: 100%|██████████| 9/9 [00:05<00:00,  1.63it/s]
Epoch 10/100, Validation loss: 5.98E-01: 100%|██████████| 3/3 [00:01<00:00,  2.21it/s]


True


Epoch 11/100, Training loss: 8.04E-01: 100%|██████████| 9/9 [00:05<00:00,  1.67it/s]
Epoch 11/100, Validation loss: 1.08E+00: 100%|██████████| 3/3 [00:01<00:00,  2.24it/s]
Epoch 12/100, Training loss: 7.83E-01: 100%|██████████| 9/9 [00:05<00:00,  1.67it/s]
Epoch 12/100, Validation loss: 6.17E-01: 100%|██████████| 3/3 [00:01<00:00,  2.25it/s]
Epoch 13/100, Training loss: 7.69E-01: 100%|██████████| 9/9 [00:05<00:00,  1.62it/s]
Epoch 13/100, Validation loss: 6.28E-01: 100%|██████████| 3/3 [00:01<00:00,  2.25it/s]
Epoch 14/100, Training loss: 7.28E-01: 100%|██████████| 9/9 [00:05<00:00,  1.66it/s]
Epoch 14/100, Validation loss: 4.84E-01: 100%|██████████| 3/3 [00:01<00:00,  2.22it/s]


True


Epoch 15/100, Training loss: 6.24E-01: 100%|██████████| 9/9 [00:05<00:00,  1.67it/s]
Epoch 15/100, Validation loss: 1.24E+00: 100%|██████████| 3/3 [00:01<00:00,  2.28it/s]
Epoch 16/100, Training loss: 7.02E-01: 100%|██████████| 9/9 [00:05<00:00,  1.67it/s]
Epoch 16/100, Validation loss: 7.07E-01: 100%|██████████| 3/3 [00:01<00:00,  2.23it/s]
Epoch 17/100, Training loss: 7.27E-01: 100%|██████████| 9/9 [00:05<00:00,  1.67it/s]
Epoch 17/100, Validation loss: 8.28E-01: 100%|██████████| 3/3 [00:01<00:00,  2.21it/s]
Epoch 18/100, Training loss: 6.11E-01: 100%|██████████| 9/9 [00:05<00:00,  1.68it/s]
Epoch 18/100, Validation loss: 6.99E-01: 100%|██████████| 3/3 [00:01<00:00,  2.20it/s]
Epoch 19/100, Training loss: 5.94E-01: 100%|██████████| 9/9 [00:05<00:00,  1.66it/s]
Epoch 19/100, Validation loss: 7.35E-01: 100%|██████████| 3/3 [00:01<00:00,  2.26it/s]
Epoch 20/100, Training loss: 6.76E-01: 100%|██████████| 9/9 [00:05<00:00,  1.68it/s]
Epoch 20/100, Validation loss: 5.98E-01: 100%|█████████

True


Epoch 25/100, Training loss: 5.85E-01: 100%|██████████| 9/9 [00:06<00:00,  1.50it/s]
Epoch 25/100, Validation loss: 6.55E-01: 100%|██████████| 3/3 [00:01<00:00,  1.82it/s]
Epoch 26/100, Training loss: 6.32E-01: 100%|██████████| 9/9 [00:07<00:00,  1.28it/s]
Epoch 26/100, Validation loss: 5.00E-01: 100%|██████████| 3/3 [00:01<00:00,  1.72it/s]
Epoch 27/100, Training loss: 5.52E-01: 100%|██████████| 9/9 [00:05<00:00,  1.58it/s]
Epoch 27/100, Validation loss: 1.16E+00: 100%|██████████| 3/3 [00:01<00:00,  2.22it/s]
Epoch 28/100, Training loss: 4.96E-01: 100%|██████████| 9/9 [00:05<00:00,  1.67it/s]
Epoch 28/100, Validation loss: 4.91E-01: 100%|██████████| 3/3 [00:01<00:00,  2.23it/s]
Epoch 29/100, Training loss: 5.55E-01: 100%|██████████| 9/9 [00:05<00:00,  1.66it/s]
Epoch 29/100, Validation loss: 5.15E-01: 100%|██████████| 3/3 [00:01<00:00,  2.25it/s]
Epoch 30/100, Training loss: 5.20E-01: 100%|██████████| 9/9 [00:05<00:00,  1.63it/s]
Epoch 30/100, Validation loss: 4.43E-01: 100%|█████████

True


Epoch 34/100, Training loss: 5.76E-01: 100%|██████████| 9/9 [00:05<00:00,  1.66it/s]
Epoch 34/100, Validation loss: 4.69E-01: 100%|██████████| 3/3 [00:01<00:00,  2.24it/s]
Epoch 35/100, Training loss: 5.24E-01: 100%|██████████| 9/9 [00:05<00:00,  1.64it/s]
Epoch 35/100, Validation loss: 7.53E-01: 100%|██████████| 3/3 [00:01<00:00,  2.26it/s]
Epoch 36/100, Training loss: 4.90E-01: 100%|██████████| 9/9 [00:05<00:00,  1.66it/s]
Epoch 36/100, Validation loss: 4.96E-01: 100%|██████████| 3/3 [00:01<00:00,  2.20it/s]
Epoch 37/100, Training loss: 5.84E-01: 100%|██████████| 9/9 [00:05<00:00,  1.68it/s]
Epoch 37/100, Validation loss: 6.53E-01: 100%|██████████| 3/3 [00:01<00:00,  2.22it/s]
Epoch 38/100, Training loss: 5.21E-01: 100%|██████████| 9/9 [00:05<00:00,  1.67it/s]
Epoch 38/100, Validation loss: 1.16E+00: 100%|██████████| 3/3 [00:01<00:00,  2.25it/s]
Epoch 39/100, Training loss: 5.50E-01: 100%|██████████| 9/9 [00:05<00:00,  1.68it/s]
Epoch 39/100, Validation loss: 5.50E-01: 100%|█████████

Split ended on epoch 33!



