In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from fbm_dropout.net import DenseNet, DenseNetFBM

# --- Experiment configuration -------------------------------------------------
# Three networks are built from the same `hidden_sizes` so they can be compared:
# a DenseNet with 0 as its second argument, a DenseNet with 0.2, and a
# DenseNetFBM.  (Presumably the second DenseNet argument is the dropout
# probability -- confirm against fbm_dropout.net.)
hidden_sizes = [256, 64]
grid_sizes = [(16,16), (8,8)]   # passed to DenseNetFBM only; one entry per hidden size
n_agents = [16, 8]              # passed to DenseNetFBM only; one entry per hidden size
learning_rate = 0.001           # shared by all three Adam optimizers
seed = 0

# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG before any model is created so that a
# Restart & Run All is repeatable.
# NOTE(review): if fbm_dropout draws from another RNG (e.g. numpy / random),
# that source needs its own seed -- not visible from this notebook.
torch.manual_seed(seed)

model = DenseNet(hidden_sizes, 0, device=device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

model_dropout = DenseNet(hidden_sizes, .2, device=device)
optimizer_dropout = optim.Adam(model_dropout.parameters(), lr=learning_rate)

model_fbm_dropout = DenseNetFBM(hidden_sizes, n_agents, grid_sizes, show=False, device=device)
optimizer_fbm_dropout = optim.Adam(model_fbm_dropout.parameters(), lr=learning_rate)

# The same loss function is used for every model.
criterion = nn.CrossEntropyLoss()

In [2]:
import torch
from fbm_dropout.dataset import get_MNIST_dataset

# Load the train / test splits and wrap them in mini-batch loaders.
trainset, testset = get_MNIST_dataset()
batch_size = 64

# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

# Evaluation only sums correct predictions over the whole test set, so the
# order of the batches does not matter; keep it fixed (no shuffling) so that
# evaluation is deterministic and does not draw from the RNG for a permutation.
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

In [3]:
# Train the three models side by side on identical mini-batches and record,
# for every epoch, each model's mean training loss and its test accuracy.
#
# `results` layout (unchanged):
#     results['train loss'][<name>] -> list with one mean batch loss per epoch
#     results['test acc'][<name>]   -> list with one accuracy in [0, 1] per epoch
# where <name> is 'no dropout', 'dropout' or 'fbm dropout'.
N_EPOCHS = 10

# name -> (model, optimizer).  Dict insertion order fixes the order in which
# the models are updated on each batch: no dropout, dropout, fbm dropout.
variants = {
    'no dropout': (model, optimizer),
    'dropout': (model_dropout, optimizer_dropout),
    'fbm dropout': (model_fbm_dropout, optimizer_fbm_dropout),
}

results = {
    'train loss': {name: [] for name in variants},
    'test acc': {name: [] for name in variants},
}

for epoch in range(N_EPOCHS):
    # ---- training pass -------------------------------------------------------
    for net, _ in variants.values():
        net.train()

    running_loss = {name: 0.0 for name in variants}
    for images, labels in train_loader:
        # Flatten every sample to a 1-D vector; batch dimension is kept.
        images = images.view(images.shape[0], -1).to(device)
        labels = labels.to(device)

        # Each model has its own optimizer, so the three updates are independent.
        for name, (net, opt) in variants.items():
            opt.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            opt.step()
            running_loss[name] += loss.item()

    # Mean of the per-batch losses for this epoch.
    for name in variants:
        results['train loss'][name].append(running_loss[name] / len(train_loader))

    # ---- evaluation pass -----------------------------------------------------
    for net, _ in variants.values():
        net.eval()

    n_correct = {name: 0 for name in variants}
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every test batch.
    with torch.no_grad():
        for images, labels in test_loader:
            # Same flattening as in training (no hard-coded feature count).
            images = images.view(images.shape[0], -1).to(device)
            labels = labels.to(device)

            for name, (net, _) in variants.items():
                # Predicted class = index of the largest output along dim 1.
                pred = net(images).argmax(dim=1)
                n_correct[name] += (pred == labels).sum().item()

    # Accuracy = correct predictions / number of test samples.
    for name in variants:
        results['test acc'][name].append(n_correct[name] / len(testset))

    # ---- per-epoch report ----------------------------------------------------
    print('Epoch {}'.format(epoch))
    print('No Dropout : Train Loss {:.2f}, Test Acc {:.2f}'.format(
        results['train loss']['no dropout'][-1],
        results['test acc']['no dropout'][-1]))
    print('Dropout    : Train Loss {:.2f}, Test Acc {:.2f}'.format(
        results['train loss']['dropout'][-1],
        results['test acc']['dropout'][-1]))
    # The FBM model additionally reports the current drop rate of its two
    # dropout layers (attributes dropout_1 / dropout_2).
    print('FBM Dropout: Train Loss {:.4f}, Test Acc {:.4f}, Drop Rate {:.4f} {:.4f}'.format(
        results['train loss']['fbm dropout'][-1],
        results['test acc']['fbm dropout'][-1],
        model_fbm_dropout.dropout_1.curr_dropout_rate,
        model_fbm_dropout.dropout_2.curr_dropout_rate))

Epoch 0
No Dropout : Train Loss 0.37, Test Acc 0.94
Dropout    : Train Loss 0.44, Test Acc 0.94
FBM Dropout: Train Loss 0.46, Test Acc 0.93
