# Example of a Monte Carlo Batch Normalization (MCBN) model on CIFAR-10

In [2]:
import torch
import mcbn
import utils

In [None]:
# Mini-batch sizes for the training and test data loaders.
BATCH_SIZE_TEST = 100
BATCH_SIZE_TRAIN = 50
# Number of samples (ensemble members) used by the MCBN model.
N_ENS = 10

In [4]:
# Report whether CUDA is usable, then pick the compute device accordingly:
# the GPU when one is available, the CPU otherwise.
print(torch.cuda.is_available())
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

False


In [None]:
def predicted_class(y_pred):
    """Convert model outputs into predicted class labels.

    When the last dimension of ``y_pred`` has size 1 the output is treated
    as a single score per sample: it is flattened and thresholded at 0.5,
    giving a float tensor of 0.0/1.0 values. Otherwise the index of the
    largest entry along dimension 1 is returned for each sample.
    """
    single_score = y_pred.shape[-1] == 1
    if not single_score:
        # torch.max over a dimension yields (values, indices); only the
        # indices (the winning class per row) are needed.
        _, class_idx = torch.max(y_pred, 1)
        return class_idx

    flat_scores = y_pred.view(-1)
    return (flat_scores > 0.5).float()

def correct_class(y_pred, y_true):
    """Compare predicted labels against the ground truth, element-wise.

    ``y_pred`` holds raw model outputs; it is first converted to labels with
    ``predicted_class``. The result is a boolean tensor that is True wherever
    the predicted label equals the corresponding entry of ``y_true``.
    """
    return predicted_class(y_pred) == y_true

def correct_total(y_pred, y_true):
    """Count how many predictions in a batch match their true labels.

    Returns a plain Python number (via ``.item()``), not a tensor.
    """
    hits = correct_class(y_pred, y_true)
    return torch.sum(hits).item()

def model_acc(model, data):
    """Compute the accuracy of ``model`` over ``data``, as a percentage.

    ``data`` is iterated as ``(image, label)`` batches. Each batch is moved
    to the device of the model's first parameter before the forward pass.
    Gradient tracking is disabled for the whole evaluation.

    Returns ``100 * correct / total`` over every sample seen.
    """
    # Run on whatever device the model's parameters already live on.
    target_device = next(model.parameters()).device
    n_seen = 0
    n_hits = 0
    with torch.no_grad():
        for image, label in data:
            image = image.to(target_device)
            label = label.to(target_device)
            prediction = model(image)
            n_seen += label.size(0)
            n_hits += correct_total(prediction, label)
    return n_hits * 100 / n_seen

## Data

In [None]:
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Per-channel normalization. These are the standard ImageNet statistics used
# with torchvision's pretrained models, which matches the ImageNet-pretrained
# backbone used in this notebook.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Test split: deterministic preprocessing and no shuffling.
# download=True is set here as well, because this dataset is constructed
# first and would otherwise fail when ./data has not been populated yet.
test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]), download=True),
        batch_size=BATCH_SIZE_TEST, shuffle=False,
        num_workers=2, pin_memory=True)

# Training split: random flip + padded random crop augmentation, shuffled.
train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='./data', train=True, transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            normalize,
        ]), download=True),
        batch_size=BATCH_SIZE_TRAIN, shuffle=True,
        num_workers=2, pin_memory=True)

## Pseudo model

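This is only a placeholder: the final classifier layer below is freshly initialized, so replace this model with one trained on CIFAR-10 to get meaningful results.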

In [None]:
# Load VGG16 with batch normalization (pretrained weights) from torch.hub.
model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg16_bn', pretrained=True)
# Replace the last classifier layer with a new 10-output linear layer that
# keeps the same number of input features.
input_lastLayer = model.classifier[6].in_features
model.classifier[6] = torch.nn.Linear(in_features=input_lastLayer,
                                      out_features=10)
model.eval()

## MCBN model

In [None]:
# Wrap the base model in an MCBN ensemble built with N_ENS and the training
# loader, then move it to the selected device.
# NOTE(review): presumably N_ENS is the number of stochastic forward passes and
# train_loader supplies the batches for the batch-norm statistics - confirm in mcbn.
model_mcbn = mcbn.MCBN_Ensemble(model,N_ENS,train_loader,return_average = False).to(dev)
# return_average=False so that the un-averaged output shape can be inspected.

In [None]:
# Take the first batch from the test loader; its labels are not needed here.
im,_ = next(iter(test_loader))
im = im.to(dev)
# Run the batch through the ensemble and print the shape of the raw output
# (the ensemble was built with return_average=False).
output = model_mcbn(im)
print(output.shape)

## Model as ensemble: Accuracy

In [None]:
# Enable averaged output before scoring.
# NOTE(review): presumably this makes the ensemble return the mean over its
# members, so the output has the per-sample shape model_acc expects - confirm in mcbn.
model_mcbn.return_average = True
# model_acc reports accuracy as a percentage over the whole test loader.
print(model_acc(model_mcbn,test_loader))