### Ordering the bias improves the performance of deep ensemble.

In [None]:
import os, sys
from pathlib import Path
import types
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.utils.data import random_split, DataLoader
from torch.optim.lr_scheduler import LambdaLR
import numpy as np
import pandas as pd

import timm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark=True
torch.backends.cudnn.deterministic=False

In [2]:
num_classes = 10
train_batchsize = 128
eval_batchsize = 1000

cifar10_mean = (0.49139968, 0.48215827 ,0.44653124)
cifar10_std = (0.24703233, 0.24348505, 0.26158768)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

train_set, val_set = random_split(
    dsets.CIFAR10(
        root=os.environ['DATA'],
        train=True,
        download=True,
        transform=transform_train
    ),
    [45000, 5000]
)
train_loader = DataLoader(
    train_set, batch_size=train_batchsize, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(
    val_set, batch_size=eval_batchsize, shuffle=True, num_workers=4, pin_memory=True)

test_set = dsets.CIFAR10(
    root=os.environ['DATA'], train=False, download=True, transform=transform_test)
test_loader = DataLoader(
    test_set, batch_size=eval_batchsize, shuffle=False, num_workers=8, pin_memory=True)

Files already downloaded and verified
Files already downloaded and verified


In [3]:
## Detach the bias part of some nn.Module and manage it.
class bias_order_reg1(nn.Module):
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        if self.module.bias is not None:
            self.bias = module.bias
            self.module.bias = None
        else:
            self.bias = None
            
    def forward(self, x):
        x = self.module(x)
        if self.bias is not None:
            # Inference on the log difference of consecutive bias values
            x + self.bias.exp().cumsum(-1)
        return x

def set_bias_order_reg1(model):
    for n, module in model.named_children():
        if len(list(module.children())) > 0:
            set_bias_order_reg1(module)
            
        if (isinstance(module, nn.Linear) or
            isinstance(module, nn.Conv2d)):
            #or
            #isinstance(module, nn.BatchNorm1d) or
            #isinstance(module, nn.BatchNorm2d)):
            setattr(model, n, bias_order_reg1(module))
    
def reset_parameters(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        m.reset_parameters()

In [None]:
model_base = timm.create_model("resnet18", pretrained=False)
model_base.fc = nn.Linear(512, num_classes, bias = True)
model_base = model_base.to(device)
model_reg1 = timm.create_model("resnet18", pretrained=False)
model_reg1.fc = nn.Linear(512, num_classes, bias = True)
set_bias_order_reg1(model_reg1)
model_reg1 = model_reg1.to(device)

coef_lambda = 1.0
num_ensemble = 50
num_epochs = 50
lr = 0.1
momentum = 0.9
weight_decay = 1e-4

criterion = nn.CrossEntropyLoss()


for ensemble_ind in tqdm(range(num_ensemble)): 
    torch.manual_seed(ensemble_ind)
    
    model_base.apply(reset_parameters)
    model_reg1.apply(reset_parameters)
    
    optimizer = optim.SGD(
        list(model_base.parameters())
        + list(model_reg1.parameters()),
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay, 
    )
 
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

    
    train_set, val_set = random_split(
        dsets.CIFAR10(
            root=os.environ['DATA'],
            train=True,
            download=True,
            transform=transform_train
        ),
        [45000, 5000]
    )
    train_loader = DataLoader(
        train_set, batch_size=128, shuffle=True, num_workers=2)
    val_loader = DataLoader(
        val_set, batch_size=512, shuffle=True, num_workers=2)
    
    for epoch in range(num_epochs):
        train_acc_base = 0.0
        train_acc_reg1 = 0.0
        val_acc_base = 0.0
        val_acc_reg1 = 0.0
        best_val_acc_base = 0.0
        best_val_acc_reg1 = 0.0
        
        model_base.train()
        model_reg1.train()
        for images, labels in train_loader:  
            # Move tensors to the configured device
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs_base = model_base(images)
            outputs_reg1 = model_reg1(images)
            loss_base = criterion(outputs_base, labels)
            loss_reg1 = criterion(outputs_reg1, labels)
            train_acc_base += (labels == outputs_base.argmax(axis = 1)).sum().item()
            train_acc_reg1 += (labels == outputs_reg1.argmax(axis = 1)).sum().item()

            # Backprpagation and optimization
            optimizer.zero_grad()
            loss_base.backward()
            loss_reg1.backward()

            optimizer.step()
            
            scheduler.step()
            
        train_acc_base /= len(train_set)
        train_acc_reg1 /= len(train_set)
        
        model_base.eval()
        model_reg1.eval()
        with torch.no_grad():
            for images, labels in val_loader:
                val_input = images.to(device)
                val_acc_base += (labels.to(device) == model_base(val_input).argmax(axis = 1)).sum().item()
                val_acc_reg1 += (labels.to(device) == model_reg1(val_input).argmax(axis = 1)).sum().item()
            val_acc_base /= len(val_set)
            val_acc_reg1 /= len(val_set)
         
        if best_val_acc_base < val_acc_base:
            best_val_acc_base = val_acc_base
            torch.save(model_base.state_dict(), "./saved_model/cifar10_base_{:03d}".format(ensemble_ind))
        if best_val_acc_reg1 < val_acc_reg1:
            best_val_acc_reg1 = val_acc_reg1
            torch.save(model_reg1.state_dict(), "./saved_model/cifar10_reg1_{:03d}".format(ensemble_ind))

        """
        print ('Epoch [{}/{}], base: [{:.4f} / {:.4f}], reg1: [{:.4f} / {:.4f}], reg2: [{:.4f} / {:.4f}], reg3: [{:.4f} / {:.4f}]'.format(
            epoch+1, num_epochs, train_acc_base, val_acc_base, train_acc_reg, val_acc_reg
        ))
        """

  0%|                                                                               | 0/50 [00:00<?, ?it/s]

Files already downloaded and verified


  2%|█▎                                                                | 1/50 [19:03<15:34:04, 1143.76s/it]

Files already downloaded and verified


  4%|██▋                                                               | 2/50 [38:03<15:13:01, 1141.28s/it]

Files already downloaded and verified


  6%|███▉                                                              | 3/50 [57:05<14:54:12, 1141.55s/it]

Files already downloaded and verified


  8%|█████                                                           | 4/50 [1:16:34<14:43:38, 1152.58s/it]

Files already downloaded and verified


 10%|██████▍                                                         | 5/50 [1:36:48<14:41:06, 1174.81s/it]

Files already downloaded and verified


In [5]:
result_base = torch.zeros((num_ensemble, len(test_set), 10))
result_reg1 = torch.zeros((num_ensemble, len(test_set), 10))

true_label = torch.zeros((len(test_set)))
for ensemble_ind in range(num_ensemble):    
    model_base.load_state_dict(torch.load("./saved_model/cifar10_base_{:03d}".format(ensemble_ind)))
    model_reg1.load_state_dict(torch.load("./saved_model/cifar10_reg1_{:03d}".format(ensemble_ind)))
    model_base.eval()
    model_reg1.eval()
    
    with torch.no_grad():
        for ind, (images, labels) in enumerate(test_loader):
            test_input = images.to(device)
            result_base[ensemble_ind, ind * eval_batchsize:(ind+1)* eval_batchsize, :] = model_base(test_input).detach()
            result_reg1[ensemble_ind, ind * eval_batchsize:(ind+1)* eval_batchsize, :] = model_reg1(test_input).detach()
            true_label[ind * eval_batchsize:(ind+1)* eval_batchsize] = labels

In [8]:
model_base(test_input).detach().shape

torch.Size([1000, 1000])

In [6]:
(result_base.argmax(-1).mode(0).values == true_label).float().mean()

tensor(0.8798)

In [8]:
(result_reg1.argmax(-1).mode(0).values == true_label).float().mean()

tensor(0.8815)

In [9]:
(F.softmax(result_base, dim=-1).mean(0).argmax(-1) == true_label).float().mean()

tensor(0.8794)

In [10]:
(F.softmax(result_reg1, dim=-1).mean(0).argmax(-1) == true_label).float().mean()

tensor(0.8811)

#### Test Accuracy 