In [None]:
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch
import numpy as np
import torch_utils as tu
import utils as ut

In [None]:
# --- Experiment-wide settings -------------------------------------------
batch_size = 512
timesteps_per_example_in_100Hz = 200
# Size the dataset as a whole number of batches (20 of them).
num_examples = 20 * batch_size

# Four simulated sensors. The positional arguments are forwarded verbatim to
# ut.Sensor; see that class for what each value means.
sensors = [
    ut.Sensor(40, 0.0, 0.2),
    ut.Sensor(80, 0.1, 0.1),
    ut.Sensor(80, 0.0, 0.3),
    ut.Sensor(20, 0.0, 0.1),
]
# Keep the individual handles available under their original names.
sensor1, sensor2, sensor3, sensor4 = sensors

In [None]:
# Each pipeline applies exactly one augmentation and then converts the result
# with tu.ToTensor().
ablate_pipeline = transforms.Compose([tu.AblateBlock(5, 30), tu.ToTensor()])
noise_pipeline = transforms.Compose(
    [tu.AddNoise((-0.1, 0.1), (0.0, 0.3)), tu.ToTensor()]
)
downsample_pipeline = transforms.Compose([tu.RandomDownsample(), tu.ToTensor()])

transform_options = [ablate_pipeline, noise_pipeline, downsample_pipeline]

# RandomChoice picks one of the pipelines above every time it is called.
trsfm = transforms.RandomChoice(transform_options)

# Small augmented dataset that asks for two transformed versions per example.
dataset = tu.BadSensorsDataset(
    sensors,
    200,
    20,
    jiggle_offsets=20,
    transform=trsfm,
    return_two_transforms=True,
)

In [None]:
# Full-size dataset driven by the configuration constants.
# NOTE(review): this rebinds `dataset`, so the augmented instance created in
# the previous cell (transform=trsfm, return_two_transforms=True) is no longer
# reachable under this name - confirm that is intended.
dataset = tu.BadSensorsDataset(
    sensors,
    timesteps_per_example_in_100Hz,
    num_examples,
    jiggle_offsets=None,
)

# Sanity checks: the dataset holds `num_examples` items, and the first item
# has `timesteps_per_example_in_100Hz` entries along its leading dimension.
assert num_examples == len(dataset)
assert timesteps_per_example_in_100Hz == len(dataset[0])

In [None]:
# Fractions handed to ut.random_splits for the validation and test subsets
# (presumably the remainder becomes the training subset - see ut.random_splits).
valid_size = 0.2
test_size = 0.1
# Number of worker processes each DataLoader uses to fetch batches.
num_workers = 2

# NOTE: despite its name, num_train is the size of the whole dataset; it is
# only used to enumerate every index before splitting.
num_train = len(dataset)
indices = list(range(num_train))
train_idx, valid_idx, test_idx = ut.random_splits(indices, test_size, valid_size)

# One sampler per split; each draws only from its own index list, in random
# order and without replacement.
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
test_sampler = SubsetRandomSampler(test_idx)
# Backward-compatible alias for the original, misspelled name.
test_sampeler = test_sampler

# Prepare data loaders (combine dataset and sampler). All three share the same
# dataset, batch size and worker count and differ only in the sampler.
train_loader, valid_loader, test_loader = (
    DataLoader(dataset, batch_size=batch_size,
               sampler=split_sampler, num_workers=num_workers)
    for split_sampler in (train_sampler, valid_sampler, test_sampler)
)

In [None]:
# Decide once whether CUDA can be used and tell the user which device
# training will run on.
train_on_gpu = torch.cuda.is_available()

print(
    'CUDA is available!  Training on GPU ...'
    if train_on_gpu
    else 'CUDA is not available.  Training on CPU ...'
)

In [None]:
def train(n_epochs, model, projection, optimizer, train_loader, valid_loader,
          save_path='model_mnist.pt'):
    """Train ``model`` for ``n_epochs`` epochs and checkpoint the best weights.

    Every epoch performs one optimisation pass over ``train_loader`` followed by
    one gradient-free pass over ``valid_loader``. Per-batch losses are weighted
    by the batch size and divided by ``len(loader.sampler)`` to obtain the
    epoch averages that are printed. Whenever the average validation loss is
    less than or equal to the best value seen so far, ``model.state_dict()`` is
    written to ``save_path``.

    Two module-level names are read and must exist before calling:
    ``criterion`` (the loss callable) and ``train_on_gpu`` (if true, batches
    are moved to CUDA; the model is not moved here).

    Args:
        n_epochs: Number of epochs to run; epochs are numbered from 1.
        model: Module whose forward call returns a pair; only the first
            element is passed to ``criterion``.
        projection: Currently unused by this function.
            NOTE(review): presumably a projection head meant to sit between
            the model output and the loss - confirm and wire in if so.
        optimizer: Optimizer updating the parameters being trained.
        train_loader: Iterable of two-element batches exposing a ``sampler``
            attribute that supports ``len()``.
        valid_loader: Same contract as ``train_loader``; used for validation.
        save_path: Destination of the best-model checkpoint. Defaults to the
            file name that was previously hard-coded.

    Returns:
        None. Progress is reported via ``print`` and the checkpoint file.
    """
    # np.Inf was removed in NumPy 2.0; np.inf is the supported spelling.
    valid_loss_min = np.inf  # track change in validation loss

    for epoch in range(1, n_epochs + 1):
        # keep track of training and validation loss
        train_loss = 0.0
        valid_loss = 0.0

        ###################
        # train the model #
        ###################
        model.train()
        # The batch used to be unpacked as (sk1, sk2) while the loop body read
        # `data`/`target`, which raised NameError on a fresh kernel (or silently
        # reused stale globals). Unpack the same way the validation loop does.
        # NOTE(review): the second element is presumably a second view of the
        # same example - confirm criterion(output, target) is the intended loss.
        for data, target in train_loader:
            # move tensors to GPU if CUDA is available
            if train_on_gpu:
                data, target = data.cuda(), target.cuda()
            # clear the gradients of all optimized variables
            optimizer.zero_grad()
            # forward pass: compute predicted outputs by passing inputs to the model
            output, _ = model(data)
            # calculate the batch loss
            loss = criterion(output, target)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()
            # update training loss (weighted by the number of examples in the batch)
            train_loss += loss.item() * data.size(0)

        ######################
        # validate the model #
        ######################
        model.eval()
        # No gradients are needed for validation; skipping graph construction
        # saves memory and time without changing the loss values.
        with torch.no_grad():
            for data, target in valid_loader:
                # move tensors to GPU if CUDA is available
                if train_on_gpu:
                    data, target = data.cuda(), target.cuda()
                # forward pass: compute predicted outputs by passing inputs to the model
                output, _ = model(data)
                # calculate the batch loss
                loss = criterion(output, target)
                # update validation loss (weighted by batch size)
                valid_loss += loss.item() * data.size(0)

        # calculate average losses
        train_loss = train_loss / len(train_loader.sampler)
        valid_loss = valid_loss / len(valid_loader.sampler)

        # print training/validation statistics
        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
            epoch, train_loss, valid_loss))

        # save model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
                valid_loss_min,
                valid_loss))
            torch.save(model.state_dict(), save_path)
            valid_loss_min = valid_loss