In [None]:
from importlib import import_module

import torch
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler

from data import series_data
from util import train, plot_confusion_matrix, plot_loss_accuracy, plot_dataloader_distribution

In [None]:
# Experiment configuration.
# Dataset module (loaded from the `datasets` package) and model class (looked up in `model`).
dataset_name, model_name = 'cinc2017', 'SimpleLSTMNetwork'
# When True, axes 1 and 2 of the loaded series tensor are swapped before use.
rearrange_tensor = True
# Model capacity: passed to the model as hidden_units / hidden_layers.
units, layers = 50, 1
# Training schedule: DataLoader batch size and number of epochs.
batch, epochs = 150, 10
device = 'cpu'

In [None]:
def class_fractions(subset, num_classes):
    """Return the fraction of samples in ``subset`` belonging to each class 0..num_classes-1.

    Each sample's label is read as ``sample[1]``, the same way the rest of this
    notebook reads samples. The subset is iterated exactly once, not once per
    class, because each iteration re-runs the dataset's ``__getitem__``. An empty
    subset yields all zeros instead of raising ZeroDivisionError.
    """
    subset_labels = [int(sample[1]) for sample in subset]
    total = len(subset_labels)
    if total == 0:
        return [0.0] * num_classes
    return [subset_labels.count(c) / total for c in range(num_classes)]


# Load the dataset module and model class named in the config cell.
dataset = import_module('datasets.' + dataset_name)
model = getattr(import_module('model'), model_name)
ts, labels = dataset.load_data()
if rearrange_tensor:
    # Swap axes 1 and 2; `input_len` below is read from axis 2 *after* this swap.
    ts = torch.transpose(ts, 1, 2)
# Assumes labels are contiguous 0-based class indices (TODO confirm for each dataset).
n_labels = int(labels.max()) + 1
label_names = dataset.get_label_names()
data = series_data.Series(ts, labels)

# 80/20 random train/validation split.
train_size = int(0.8 * len(data))
valid_size = len(data) - train_size
train_data, valid_data = random_split(data, [train_size, valid_size])
train_loader = DataLoader(train_data, batch_size=batch, shuffle=True)
validation_loader = DataLoader(valid_data, batch_size=batch, shuffle=False)
net = model(num_classes=n_labels, input_len=ts.size(2),
            hidden_units=units, hidden_layers=layers).to(device)

print(f'The number of samples for training is {train_size}.')
# Vectorized comparison + .sum() rather than Python's sum() over individual tensor elements.
class_samples = [int((labels == l).sum()) / len(labels) for l in range(n_labels)]
print(f'Samples per classes: {class_samples}')
class_samples_train = class_fractions(train_data, n_labels)
print(f'Samples per classes (training): {class_samples_train}')
class_samples_valid = class_fractions(valid_data, n_labels)
print(f'Samples per classes (validation): {class_samples_valid}')
print(f'The number of parameters is {sum(p.numel() for p in net.parameters())}.')

In [None]:
# Plot the distribution of the plain shuffled training loader, for comparison with the
# rebalanced loader later on (presumably per-class counts; see util.plot_dataloader_distribution).
plot_dataloader_distribution(train_loader, label_names)

In [None]:
# Rebalance training batches: weight every training sample by the inverse frequency of its
# class. Each present class then has the same total weight (count * 1/count == 1), so
# drawing len(train_data) samples with replacement picks each class with about equal probability.
train_label_list = [int(sample[1]) for sample in train_data]  # iterate the subset only once
class_counts = [train_label_list.count(c) for c in range(n_labels)]
# A class missing from the random training split gets weight 0.0 instead of raising
# ZeroDivisionError. No training sample carries that label, so the weight is never used.
class_weights = [1 / count if count else 0.0 for count in class_counts]
sample_weights = [class_weights[label] for label in train_label_list]
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(train_data), replacement=True)
rebalanced_loader = DataLoader(train_data, sampler=sampler, batch_size=batch)

In [None]:
# Same plot for the WeightedRandomSampler-backed loader; this is expected to be close to
# uniform across the classes present in the training split.
plot_dataloader_distribution(rebalanced_loader, label_names)

In [None]:
# Train on the class-rebalanced loader and evaluate on the unshuffled validation loader.
learning_rate = 0.01
history = train(net, device, rebalanced_loader, validation_loader, epochs, lr=learning_rate)
train_loss, train_acc, validation_loss, validation_acc, predictions, targets = history

In [None]:
# Loss/accuracy histories for training and validation, as returned by `train`
# (presumably one value per epoch; confirm in util.train).
plot_loss_accuracy(train_loss, train_acc, validation_loss, validation_acc)

In [None]:
# Confusion matrix of the `targets`/`predictions` returned by `train`
# (NOTE(review): presumably from the validation pass; confirm in util.train).
plot_confusion_matrix(targets, predictions, label_names)