In [1]:
import numpy as np
import torch.utils.data
from torch.optim import Adam

from models import ConvNet
from utils import H5Dataset, robust_ce_loss

In [2]:
classes = (0, 1, 2)
ds = H5Dataset(path='data/kws10_16x32.hdf5', classes=classes)
num_samples = len(ds)
train_size = 0.8
idxs = np.arange(num_samples)
np.random.shuffle(idxs)
train_size = int(num_samples*train_size)
train_idx = idxs[:train_size]
test_idx = idxs[train_size:]

tr_sampler = torch.utils.data.SubsetRandomSampler(indices=train_idx)
ts_sampler = torch.utils.data.SubsetRandomSampler(indices=test_idx)
train_loader = torch.utils.data.DataLoader(ds, batch_size=20, sampler=tr_sampler)
test_loader = torch.utils.data.DataLoader(ds, batch_size=20, sampler=ts_sampler)

In [3]:
epochs = 50
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
lambda_consts = [0, 4]
for lambda_const in lambda_consts:
    print("Training model for lambda: {}".format(lambda_const))
    model = ConvNet(out_dim=len(classes))
    model = model.to(device)
    optimizer = Adam(model.parameters(), lr=0.001)

    for epoch in range(epochs):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            data.requires_grad = True
            optimizer.zero_grad()
            output = model(data)
            loss = robust_ce_loss(output, target, inputs=data, eps=lambda_const)
            loss.backward()
            optimizer.step()

        if epoch % 5 == 0:
            model.eval()
            # compute accuracy
            acc = 0.0
            num_samples = 0
            for batch_idx, (data, target) in enumerate(test_loader):
                data, target = data.to(device), target.to(device)
                output = model(data)
                preds = output.argmax(dim=1)
                acc += (preds.type(target.dtype) == target).float().sum()
                num_samples += data.shape[0]
            acc = acc/num_samples

            print("Epoch: {}/{} Accuracy on test set: {:.2%}".format(epoch+1, epochs, acc))
            torch.save(model.state_dict(), f'pretrained/lambda_{str(lambda_const).replace(".", "-")}.pt')

    torch.save(model.state_dict(), f'pretrained/lambda_{str(lambda_const).replace(".", "-")}.pt')
    del model, optimizer



Training model for lambda: 0
Epoch: 1/50 Accuracy on test set: 91.55%
Epoch: 6/50 Accuracy on test set: 97.52%
Epoch: 11/50 Accuracy on test set: 97.91%
Epoch: 16/50 Accuracy on test set: 96.75%
Epoch: 21/50 Accuracy on test set: 97.52%
Epoch: 26/50 Accuracy on test set: 98.12%
Epoch: 31/50 Accuracy on test set: 97.95%
Epoch: 36/50 Accuracy on test set: 98.21%
Epoch: 41/50 Accuracy on test set: 98.04%
Epoch: 46/50 Accuracy on test set: 98.12%
Training model for lambda: 4
Epoch: 1/50 Accuracy on test set: 63.28%
Epoch: 6/50 Accuracy on test set: 85.61%
Epoch: 11/50 Accuracy on test set: 69.09%
Epoch: 16/50 Accuracy on test set: 93.25%
Epoch: 21/50 Accuracy on test set: 90.14%
Epoch: 26/50 Accuracy on test set: 93.30%
Epoch: 31/50 Accuracy on test set: 95.39%
Epoch: 36/50 Accuracy on test set: 92.83%
Epoch: 41/50 Accuracy on test set: 95.77%
Epoch: 46/50 Accuracy on test set: 90.78%
