<h1>Two Spiral Classification Task <img src="https://raw.githubusercontent.com/jkoutsikakis/datasets/master/two_spiral_dataset/two_spiral_dataset.png" width="40px" height="40px" style='display:inline'/></h1>



In this example we will demonstrate how to train and evaluate a model on the TwoSpiral dataset using PyTorchWrapper.

#### Additional libraries

First, we need to install the `requests` library, which is used to download the data.

In [None]:
! pip install requests


#### Import Statements

In [None]:
import torch
import numpy as np
import requests
import os

from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.dataset import Dataset
from torch import nn

import pytorch_wrapper as pw


#### Dataset Definition

Next, we create a class that derives from `torch.utils.data.Dataset`. PyTorchWrapper expects each batch returned by a `torch.utils.data.DataLoader` to be represented as a dictionary, so that it stays flexible about what information a single batch contains. Since we won't use a custom collate function in this case, it is enough to make sure that the `Dataset` object represents a single example as a dictionary. The `DataLoader` will then automatically convert a list of examples (dictionaries) into a single dictionary of batched values, as follows:

`[{'input': x1, 'target': y1}, {'input': x2, 'target': y2}] -> DataLoader -> {'input': tensor([x1, x2]), 'target': tensor([y1, y2])}`

The data will be converted to tensors automatically, taking into account the type of the original data (`numpy.float32` will become `torch.float32` in this case).
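The conversion above can be mimicked with a few lines of NumPy. The sketch below is a simplified stand-in for PyTorch's default collation, not its actual implementation. It assumes that every example is a dictionary with the same keys and array shapes, and it stacks the values for each key along a new leading batch dimension. The real `DataLoader` additionally turns the result into tensors. `collate_dict_batch` is a hypothetical name.

```python
import numpy as np


def collate_dict_batch(examples):
    """Stack a list of per-example dicts into one dict of batched arrays.

    Simplified, hypothetical stand-in for PyTorch's default collation:
    it assumes all examples share the same keys and array shapes.
    """
    keys = examples[0].keys()
    return {key: np.stack([example[key] for example in examples]) for key in keys}


examples = [
    {'input': np.array([0.5, -1.0], dtype='float32'), 'target': np.float32(1.0)},
    {'input': np.array([2.0, 0.25], dtype='float32'), 'target': np.float32(0.0)},
]
batch = collate_dict_batch(examples)
print(batch['input'].shape, batch['input'].dtype)  # (2, 2) float32
print(batch['target'])                             # [1. 0.]
```

The dtype is preserved by stacking. This mirrors how `float32` NumPy data ends up as `float32` tensors.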

In [None]:
class TwoSpiralDataset(Dataset):
    def __init__(self):
        super(TwoSpiralDataset, self).__init__()
        
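        # Download the raw TSV file (a header row, then tab-separated x, y, target rows).
        raw_tsv_request = requests.get(
            'https://raw.githubusercontent.com/jkoutsikakis/datasets/master/two_spiral_dataset/two_spiral_dataset.tsv'
        )
        raw_tsv_request.raise_for_status()  # fail early on HTTP errors

        self.pos = []
        self.target = []

        # Skip the header row; strip() drops the trailing newline so no empty line is parsed.
        for line in raw_tsv_request.text.strip().split('\n')[1:]:
            pos_x, pos_y, cur_target = line.split('\t')
            self.pos.append([float(pos_x), float(pos_y)])
            self.target.append(float(cur_target))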

        self.pos = np.array(self.pos, dtype='float32')
        self.target = np.array(self.target, dtype='float32')

    def __getitem__(self, item_index):
        return {
            'input': self.pos[item_index],
            'target': self.target[item_index]
        }

    def __len__(self):
        return self.target.shape[0]
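`TwoSpiralDataset` downloads the file in its constructor, which makes its parsing logic hard to check offline. The sketch below moves that logic into a standalone function, `parse_two_spiral_tsv` (a hypothetical name), and runs it on a made-up three-row string. The string only mimics the layout that the class assumes: one header row followed by tab-separated `x`, `y` and `target` columns. The header names and values are placeholders, not data from the real dataset.

```python
import numpy as np


def parse_two_spiral_tsv(text):
    """Parse a header row plus 'x<TAB>y<TAB>target' rows into float32 arrays.

    Mirrors the loop in TwoSpiralDataset.__init__; blank lines (e.g. a
    trailing newline) are skipped.
    """
    lines = [line for line in text.split('\n')[1:] if line.strip()]
    pos, target = [], []
    for line in lines:
        pos_x, pos_y, cur_target = line.split('\t')
        pos.append([float(pos_x), float(pos_y)])
        target.append(float(cur_target))
    return np.array(pos, dtype='float32'), np.array(target, dtype='float32')


# Made-up sample in the assumed layout (header names are placeholders).
sample = 'x\ty\ttarget\n0.1\t0.2\t0\n-1.5\t3.0\t1\n2.0\t-0.5\t0\n'
pos, target = parse_two_spiral_tsv(sample)
print(pos.shape, target.shape)  # (3, 2) (3,)
```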


#### Model Definition
Next, we define our model by extending the `torch.nn.Module` class.
In this case, we use a simple MLP with 3 hidden layers of size 128,
the ReLU activation function, and batch normalization. `pytorch_wrapper.modules.MLP` is one of several ready-to-use modules provided by PyTorchWrapper.

PyTorchWrapper can also handle multi-input models. In such a case, the dictionary returned by the `torch.utils.data.Dataset`'s `__getitem__`
method must contain, at key `'input'`, a list of values that correspond one-to-one to the arguments of the model's `forward` method.
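With the default collate function, a list stored at `'input'` in each example is collated position by position, so the batch also holds a list with one batched value per `forward` argument. As a rough illustration of this one-to-one correspondence, the sketch below uses a hypothetical two-input model whose `forward` takes `coords` and `extra`, and a dispatch helper `call_forward` that unpacks the list into positional arguments. Both names are made up. The helper is a guess at the general idea, not PyTorchWrapper's actual code, and plain NumPy arrays stand in for tensors.

```python
import numpy as np


class TwoInputModel:
    """Hypothetical stand-in for an nn.Module with a two-argument forward."""

    def forward(self, coords, extra):
        # Concatenate feature-wise: (batch, 2) + (batch, 1) -> (batch, 3).
        return np.concatenate([coords, extra], axis=1)


def call_forward(model, batch, input_key='input'):
    """Pass batch[input_key] to forward, unpacking it if it is a list."""
    inputs = batch[input_key]
    if isinstance(inputs, list):
        return model.forward(*inputs)  # one list element per forward argument
    return model.forward(inputs)


batch = {
    'input': [np.zeros((4, 2), dtype='float32'), np.ones((4, 1), dtype='float32')],
    'target': np.zeros(4, dtype='float32'),
}
out = call_forward(TwoInputModel(), batch)
print(out.shape)  # (4, 3)
```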

In [None]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.mlp = pw.modules.MLP(
            input_size=2,
            num_hidden_layers=3,
            hidden_layer_size=128,
            hidden_activation=nn.ReLU,
            hidden_dp=0,
            hidden_layer_post_activation_bn=True,
            output_size=1,
            output_activation=None
        )

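    def forward(self, x):
        # (batch_size, 1) -> (batch_size,); squeezing only the last dim keeps
        # the batch dimension even when a batch holds a single example.
        return self.mlp(x).squeeze(-1)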
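The MLP outputs one logit per example, a `(batch_size, 1)` array, and `forward` squeezes it to match the `(batch_size,)` shape of the targets. A bare `squeeze()` removes every size-1 dimension. As a result, a batch containing a single example collapses to a scalar, while squeezing only the last axis keeps the batch dimension. The NumPy snippet below shows the difference. `torch.Tensor.squeeze` drops size-1 dimensions in the same way.

```python
import numpy as np

# Logits as produced by an MLP with output_size=1: shape (batch_size, 1).
logits_batch = np.zeros((32, 1), dtype='float32')
logits_single = np.zeros((1, 1), dtype='float32')

# squeeze() drops *all* size-1 axes; squeeze(-1) drops only the last one.
print(logits_batch.squeeze().shape, logits_batch.squeeze(-1).shape)    # (32,) (32,)
print(logits_single.squeeze().shape, logits_single.squeeze(-1).shape)  # () (1,)
```

A scalar output next to a `(1,)` target makes `nn.BCEWithLogitsLoss` raise a size-mismatch error. Squeezing only the last dimension is therefore the safer choice whenever batches of one can occur.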


#### Training

Next, we create the dataset object along with three data loaders (for training, validation, and testing). The dataset
contains 1000 examples: the first 800 are used for training, and the remaining 200 are split evenly between validation and testing (100 examples each).


In [None]:
dataset = TwoSpiralDataset()

train_data_loader = DataLoader(
    dataset,
    sampler=SubsetRandomSampler(list(range(0, 800))),
    batch_size=32
)

val_data_loader = DataLoader(
    dataset,
    sampler=pw.samplers.SubsetSequentialSampler(list(range(800, 900))),
    batch_size=32
)

test_data_loader = DataLoader(
    dataset,
    sampler=pw.samplers.SubsetSequentialSampler(list(range(900, 1000))),
    batch_size=32
)
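The loaders above use contiguous index ranges, which yields representative subsets only if the rows of the TSV file are already in random order. One option, sketched below under that assumption, is to shuffle the indices with a fixed seed before splitting. The split stays reproducible and does not depend on file order. `split_indices` and the seed value are hypothetical choices.

```python
import random


def split_indices(n_examples=1000, n_train=800, n_val=100, seed=0):
    """Return disjoint train/val/test index lists from a seeded shuffle."""
    indices = list(range(n_examples))
    random.Random(seed).shuffle(indices)  # independent RNG, reproducible order
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]
    return train, val, test


train_idx, val_idx, test_idx = split_indices()
print(len(train_idx), len(val_idx), len(test_idx))  # 800 100 100
```

These lists could replace the `range(...)` lists passed to `SubsetRandomSampler` and `pw.samplers.SubsetSequentialSampler`. With that change, the `900 + ex_pred_pos` lookup used later to find a test label would become `test_idx[ex_pred_pos]`.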


Then we create the model and wrap it in a `pytorch_wrapper.System` object, which provides methods
to train and evaluate the model it contains.

In [None]:
model = Model()

# last_activation is the torch function applied to the model's outputs at non-training time
# (e.g. when predicting). Some losses, as in this case, work directly with logits, so the last
# activation is not performed inside the model's forward method. If the model already applies
# its last activation inside forward, use None.
last_activation = torch.nn.Sigmoid()

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
system = pw.System(model, last_activation=last_activation, device=device)
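The comment about `last_activation` follows from how `nn.BCEWithLogitsLoss` works. The loss takes raw logits and combines the sigmoid with binary cross-entropy, so the sigmoid only has to be applied separately when outputs are turned into probabilities. The NumPy sketch below (the function names are ours) checks that the usual numerically stable form of the combined loss matches a sigmoid followed by plain binary cross-entropy for moderate logits.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def bce_with_logits(z, y):
    # Numerically stable combined form: max(z, 0) - z*y + log(1 + exp(-|z|))
    return np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z)))


def bce_on_probabilities(p, y):
    # Plain binary cross-entropy on probabilities p = sigmoid(z)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))


logits = np.array([-3.0, -0.5, 0.0, 0.5, 3.0])
targets = np.array([0.0, 1.0, 1.0, 0.0, 1.0])
fused = bce_with_logits(logits, targets)
naive = bce_on_probabilities(sigmoid(logits), targets)
print(np.allclose(fused, naive))  # True
print(sigmoid(0.0))               # 0.5
```

The stable form never computes `log(p)` for a `p` that has rounded to 0 or 1. This is why the model returns logits during training.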


Next we train the model.

In [None]:
# GenericPointWiseLossWrapper wraps a native pointwise loss. batch_target_key is the key of the
# batch dictionary returned by the DataLoader that holds the target values; we chose this key
# when we defined the dictionary returned by the Dataset's __getitem__ method. For a custom loss,
# you can implement a class that derives from AbstractLossWrapper.
loss_wrapper = pw.loss_wrappers.GenericPointWiseLossWrapper(nn.BCEWithLogitsLoss(),
                                                            batch_target_key='target')

# Create the optimizer.
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()))

# Dictionary containing the dataloaders used for evaluation after each epoch.
evaluation_data_loaders = {'train': train_data_loader, 'val': val_data_loader}

# Dictionary containing the evaluators.
evaluators = {'acc': pw.evaluators.AccuracyEvaluator(batch_target_key='target')}

# Callback that stops training if accuracy on the validation set does not improve for 20 epochs.
os.makedirs('tmp', exist_ok=True)
es_callback = pw.training_callbacks.EarlyStoppingCriterionCallback(
    patience=20,
    evaluation_data_loader_key='val',
    evaluator_key='acc',
    tmp_best_state_filepath='tmp/ts_tmp_best.weights'
)

# batch_input_key is the key of the batch dictionary returned by the DataLoader that holds the
# model's input; we chose this key when we defined the dictionary returned by the Dataset's __getitem__ method.
batch_input_key = 'input'

_ = system.train(
    loss_wrapper=loss_wrapper,
    optimizer=optimizer,
    train_data_loader=train_data_loader,
    evaluators=evaluators,
    evaluation_data_loaders=evaluation_data_loaders,
    batch_input_key=batch_input_key,
    callbacks=[es_callback]
)
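`EarlyStoppingCriterionCallback` is configured above with `patience=20`. The pure-Python sketch below shows a generic patience rule on a made-up list of validation accuracies, using `patience=3` to keep it short. Under this rule, training stops once the monitored score has gone `patience` consecutive epochs without a strict improvement, and the best epoch is remembered. This is an assumption about the general scheme, not PyTorchWrapper's exact implementation. For instance, its treatment of ties may differ. The sketch only records which epoch was best and does not model how the callback handles weights.

```python
def early_stopping_epoch(val_scores, patience):
    """Return (stop_epoch, best_epoch) for a higher-is-better score.

    Generic patience rule: stop after `patience` consecutive epochs
    without a strict improvement over the best score seen so far.
    Epochs are 0-indexed; stop_epoch is the last epoch that runs.
    """
    best_score, best_epoch, epochs_without_improvement = float('-inf'), -1, 0
    for epoch, score in enumerate(val_scores):
        if score > best_score:
            best_score, best_epoch = score, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch
    return len(val_scores) - 1, best_epoch


# Made-up validation accuracies, one per epoch.
scores = [0.55, 0.70, 0.82, 0.81, 0.82, 0.80, 0.90]
print(early_stopping_epoch(scores, patience=3))  # (5, 2)
```

With `patience=3`, the run stops before the better score at epoch 6 is ever seen. A larger patience, such as the 20 used above, trades extra epochs for robustness to noisy validation scores.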


We can use the `evaluate` method to evaluate the model, here on the test set.

In [None]:
test_results = system.evaluate(test_data_loader, evaluators)
print(test_results['acc'])
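For a binary task like this one, accuracy is the fraction of examples whose thresholded prediction equals the label. The sketch below assumes a 0.5 threshold on sigmoid probabilities, which is equivalent to thresholding the raw logits at 0. Whether `AccuracyEvaluator` uses exactly this rule is an assumption, and `binary_accuracy` is a hypothetical helper.

```python
import numpy as np


def binary_accuracy(probabilities, targets, threshold=0.5):
    """Fraction of examples whose thresholded probability matches the 0/1 target."""
    predicted = (np.asarray(probabilities) >= threshold).astype('float32')
    return float(np.mean(predicted == np.asarray(targets)))


probs = np.array([0.9, 0.2, 0.6, 0.4])
labels = np.array([1.0, 0.0, 0.0, 0.0])
print(binary_accuracy(probs, labels))  # 0.75
```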


We can use the `predict` method to get predictions for all the examples returned by a data loader.

In [None]:
predictions = system.predict(test_data_loader, perform_last_activation=True)


In [None]:
ex_pred_pos = 0
ex_ds_pos = 900 + ex_pred_pos  # the test loader visits dataset indices 900-999 in order (SubsetSequentialSampler)
print(f'Prediction for ex {ex_pred_pos}: {predictions["outputs"][ex_pred_pos]}')
print(f'Label of ex {ex_pred_pos}: {dataset[ex_ds_pos]["target"]}')


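We can use the `predict_batch` method to predict for a single batch. The input must have shape `(batch_size, 2)`. Here the batch contains the single point `(5, 1)`, and we apply `system.last_activation` (the sigmoid) to the result to obtain a probability.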

In [None]:
system.last_activation(system.predict_batch(torch.tensor([[5., 1.]]))).item()


#### Saving & Loading

We can save and load the model's weights directly.

In [None]:
os.makedirs('data', exist_ok=True)  # make sure the target directory exists
system.save_model_state('data/two_spiral_final.weights')
_ = system.load_model_state('data/two_spiral_final.weights')


We can also save and load the whole system at once.

In [None]:
system.save('data/two_spiral_final.system')
system = pw.System.load('data/two_spiral_final.system')
