In [None]:
from google.colab import drive

#drive.mount('/content/gdrive')


In [None]:
import time
import copy
import numpy as np
import matplotlib.pyplot as plt
import os
import scipy

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import torch
torch.manual_seed(2)
np.random.seed(2)
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, SubsetRandomSampler, SequentialSampler
import torchvision.models as models
!pip install torchinfo
!pip install torchattacks
!pip install pip install grad-cam
!pip install ray
from torchinfo import summary
from torchvision.models.resnet import _resnet,BasicBlock
import torchattacks
import torchvision.utils
import torch.nn.functional as F
from pytorch_grad_cam import GradCAM
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

Collecting torchinfo
  Downloading torchinfo-1.6.3-py3-none-any.whl (20 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.6.3
Collecting torchattacks
  Downloading torchattacks-3.2.4-py3-none-any.whl (102 kB)
[K     |████████████████████████████████| 102 kB 4.1 MB/s 
[?25hInstalling collected packages: torchattacks
Successfully installed torchattacks-3.2.4
Collecting install
  Downloading install-1.3.5-py3-none-any.whl (3.2 kB)
Collecting grad-cam
  Downloading grad-cam-1.3.7.tar.gz (4.5 MB)
[K     |████████████████████████████████| 4.5 MB 4.0 MB/s 
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Collecting ttach
  Downloading ttach-0.0.3-py3-none-any.whl (9.8 kB)
Building wheels for collected packages: grad-cam
  Building wheel for grad-cam (PEP 517) ... [?25l[?25hdone
  Created wheel for grad-cam: filename=grad_cam-1.3.7-py3-none-a

In [None]:
# Images are only converted to tensors here; per-channel normalisation is
# applied later, inside the training loop, after adversarial examples are built.
transform = transforms.Compose(
    [
        # transforms.Resize((224)),
        transforms.ToTensor(),
    ]
)

# Training split.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# NOTE: `valset` and `testset` are both built from the same held-out split
# (train=False); they are two independent handles onto identical data.
valset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Human-readable class names, listed in label-index order.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified


In [None]:
from torch.nn import Conv2d,AvgPool2d,Linear,Sequential,Dropout,BatchNorm2d,ModuleList,BatchNorm1d
import torch.nn.functional as F
import numpy as np
import math
from torch.autograd import Variable
from functools import partial

class BasicConv(nn.Module):
    """Conv2d optionally followed by BatchNorm2d and ReLU.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels (also exposed as ``out_channels``).
        kernel_size, stride, padding, dilation, groups, bias: forwarded to ``nn.Conv2d``.
        relu: append a ReLU activation when True.
        bn: append a BatchNorm2d (eps=1e-5, momentum=0.01, affine) when True.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super().__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # Disabled stages are stored as None so forward() can skip them.
        if bn:
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
        else:
            self.bn = None
        if relu:
            self.relu = nn.ReLU()
        else:
            self.relu = None

    def forward(self, x):
        """Apply the convolution, then whichever of BN / ReLU are enabled, in that order."""
        out = self.conv(x)
        for stage in (self.bn, self.relu):
            if stage is not None:
                out = stage(out)
        return out

class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        """Return ``x`` viewed as (x.size(0), -1)."""
        leading = x.size(0)
        return x.view(leading, -1)

class ChannelGate(nn.Module):
    """Channel-attention gate.

    The input is pooled over its full spatial extent once per entry of
    ``pool_types``; every pooled descriptor goes through the same bottleneck
    MLP, the resulting logits are summed, squashed by a sigmoid and used to
    rescale the input channel-wise.

    Args:
        gate_channels: number of channels of the input feature map.
        reduction_ratio: bottleneck factor of the MLP hidden layer
            (hidden width is ``gate_channels // reduction_ratio``).
        pool_types: non-empty sequence drawn from ``'avg'`` and ``'max'``.

    Raises:
        ValueError: if ``pool_types`` is empty or contains an unsupported entry.
    """

    _SUPPORTED_POOLS = ('avg', 'max')

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        # Copy so the shared default list is never aliased by an instance.
        pool_types = list(pool_types)
        # Fail fast: previously an unknown entry either crashed inside forward()
        # or silently re-added the logits of the preceding pooling branch.
        if not pool_types:
            raise ValueError("pool_types must contain at least one of {}".format(self._SUPPORTED_POOLS))
        unsupported = [p for p in pool_types if p not in self._SUPPORTED_POOLS]
        if unsupported:
            raise ValueError("unsupported pool type(s) {}; expected a subset of {}".format(
                unsupported, self._SUPPORTED_POOLS))

        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
            )
        self.pool_types = pool_types

    def forward(self, x):
        """Return ``x`` scaled per channel by the sigmoid of the summed MLP logits."""
        # Kernel == stride == full spatial size, i.e. one pooled value per channel.
        spatial = (x.size(2), x.size(3))
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                pooled = F.avg_pool2d(x, spatial, stride=spatial)
            elif pool_type == 'max':
                pooled = F.max_pool2d(x, spatial, stride=spatial)
            else:
                # Only reachable if pool_types was mutated after construction.
                raise ValueError("unsupported pool type {!r}".format(pool_type))
            channel_att_raw = self.mlp(pooled)

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale

def logsumexp_2d(tensor):
    """Log-sum-exp over all trailing dimensions of ``tensor``.

    The input is viewed as (size(0), size(1), -1) and reduced over the last
    axis; the maximum is subtracted before exponentiating for numerical
    stability. The result keeps the reduced axis, i.e. has shape
    (size(0), size(1), 1).
    """
    flat = tensor.view(tensor.size(0), tensor.size(1), -1)
    peak = torch.max(flat, dim=2, keepdim=True)[0]
    shifted_sum = torch.exp(flat - peak).sum(dim=2, keepdim=True)
    return peak + torch.log(shifted_sum)

class ChannelPool(nn.Module):
    """Reduce dimension 1 to two maps: its maximum followed by its mean."""

    def forward(self, x):
        """Return the max and mean over dim 1, concatenated along dim 1 (max first)."""
        peak = x.max(dim=1, keepdim=True)[0]
        average = x.mean(dim=1, keepdim=True)
        return torch.cat((peak, average), dim=1)

class SpatialGate(nn.Module):
    """Spatial-attention gate.

    The input is compressed by ``ChannelPool``, passed through a 7x7
    ``BasicConv`` (2 -> 1 channels, no ReLU) whose padding keeps the spatial
    size, and the sigmoid of the result rescales the input.
    """

    def __init__(self):
        super().__init__()
        ksize = 7
        same_padding = (ksize - 1) // 2
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, ksize, stride=1, padding=same_padding, relu=False)

    def forward(self, x):
        """Return ``x`` multiplied by the single-channel attention map."""
        attention_logits = self.spatial(self.compress(x))
        # The one-channel map broadcasts across the channels of x.
        return x * torch.sigmoid(attention_logits)

class CBAM(nn.Module):
    """Channel gate optionally followed by a spatial gate.

    Args:
        gate_channels: channel count handed to ``ChannelGate``.
        reduction_ratio: bottleneck factor handed to ``ChannelGate``.
        pool_types: pooling variants handed to ``ChannelGate``.
        no_spatial: when True no ``SpatialGate`` is created or applied.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super().__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not self.no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        """Apply the channel gate, then the spatial gate unless it is disabled."""
        refined = self.ChannelGate(x)
        if self.no_spatial:
            return refined
        return self.SpatialGate(refined)

class Base(nn.Module):
    """ResNet backbone that records the outputs of ``layer1`` .. ``layer4``.

    Forward hooks on the four residual stages append every stage output to
    ``self.features``; the backbone's ``avgpool`` and ``fc`` are replaced by
    ``nn.Identity``. ``self.channel_size`` holds the channel count of each
    recorded stage, measured once at construction with a random probe image.

    Args:
        trainable: when False the backbone parameters are frozen.
        attention: accepted for interface compatibility; not used by this class.
        base: 9 builds a one-block-per-stage ResNet, 18 a ResNet-18, anything
            else a ResNet-34 (all without pretrained weights).
    """

    def freeze(self):
        """Disable gradients for every backbone parameter."""
        for param in self.base_model.parameters():
                param.requires_grad = False

    def unfreeze(self):
        """Enable gradients for every backbone parameter."""
        for param in self.base_model.parameters():
                param.requires_grad = True

    def attach_fea_out(self,classname,input,output):
        """Forward hook that records the output of the hooked module."""
        self.features.append(output)

    def attach_fea_in(self,classname,input,output):
        """Forward hook that records the first positional input of the hooked module."""
        self.features.append(input[0])

    def __init__(self,trainable = True,attention=False,base=18):
        super(Base,self).__init__()
        self.features = []
        self.channel_size = []
        print(base)
        if base == 9:
            self.base_model = _resnet('resnet', BasicBlock, [1, 1, 1, 1], False, True)
        elif base==18:
            self.base_model = models.resnet18(pretrained=False)
        else:
            self.base_model = models.resnet34(pretrained=False)

        used_blocks = ['layer1', 'layer2','layer3','layer4']
        unused_blocks = ['avgpool','fc']

        for block in used_blocks:
            getattr(self.base_model,block).register_forward_hook(self.attach_fea_out)

        for block in unused_blocks:
             setattr(self.base_model,block,nn.Identity())

        if not trainable:
            self.freeze()

        # Probe pass with a random image to read the channel count of every
        # hooked stage. Run it in eval mode and without autograd so the random
        # data neither updates the BatchNorm running statistics nor builds a
        # graph; the previous training/eval state is restored afterwards.
        fake_img = torch.rand(1,3,256,256)
        was_training = self.base_model.training
        self.base_model.eval()
        with torch.no_grad():
            self.base_model(fake_img)
        self.base_model.train(was_training)
        self.channel_size = [block.size()[1] for block in self.features]
        self.features = []

    def forward(self,img):
        """Run the backbone for its hook side effect; returns None.

        After the call ``self.features`` holds the stage outputs of this pass
        only (captures from earlier passes are dropped first so they cannot
        accumulate).
        """
        self.features = []
        self.base_model(img)

    def get_MLSP(self,img,feature_type,resize = True):
        """Run the backbone on ``img`` and return the recorded stage features.

        Args:
            img: input batch passed to the backbone.
            feature_type: ``'narrow'`` (global-average-pool each stage and drop
                the two trailing singleton dims) or ``'wide'`` (bilinearly
                resize each stage to 7x7). Only consulted when ``resize`` is True.
            resize: when True the per-stage results are concatenated along
                dim 1 into one tensor; when False the raw list of stage
                outputs is returned.

        Raises:
            ValueError: if ``resize`` is True and ``feature_type`` is unknown.
        """
        if resize and feature_type not in ('narrow', 'wide'):
            raise ValueError("feature_type must be 'narrow' or 'wide', got {!r}".format(feature_type))

        # Drop anything left behind by an earlier forward() or a failed pass.
        self.features = []
        self.base_model(img)
        # Hand the captured list to the caller and reset the buffer.
        features, self.features = self.features, []

        if not resize:
            return features

        if feature_type == 'narrow':
            MLSP = [F.adaptive_avg_pool2d(block, (1, 1)).squeeze(2).squeeze(2) for block in features]
        else:
            MLSP = [F.interpolate(block,mode = 'bilinear', size = 7) for block in features]

        return torch.cat(MLSP,dim = 1)



class Head(nn.Module):
    """Classifier on top of a list of multi-level feature maps.

    Two variants are supported:

    * ``'mlsp_gap'``: every feature map is global-average-pooled and the
      pooled vectors are concatenated.
    * ``'mlsp_cnn_gap_attn'``: every feature map goes through a CBAM block
      with a residual connection, then two conv blocks (1x1 and 3x3), and is
      then global-average-pooled and concatenated.

    The concatenated vector is fed to one linear layer.

    Args:
        head_type: ``'mlsp_gap'`` or ``'mlsp_cnn_gap_attn'``.
        num_channel: channel count of each incoming feature map, in order.
        num_classes: number of output logits (default 10).

    Raises:
        ValueError: if ``head_type`` is not one of the two supported values.
    """

    # Dropout probability used by conv_block. It is a class attribute and is
    # read at block-construction time, so it is shared by all Head instances.
    dp = 0.5

    def conv_block(self,inc,outc,ker,padding = 1,avgpool = False):
        """Build Dropout -> [AvgPool2d(3,1,1)] -> Conv2d -> BatchNorm2d -> ReLU."""
        modules = [nn.Dropout(Head.dp)]
        if avgpool:
            modules.append(AvgPool2d(3,1,1))
        modules.append(Conv2d(inc,outc,ker,padding = padding))
        modules.append(nn.BatchNorm2d(outc))
        modules.append(nn.ReLU())
        return Sequential(*modules)

    def __init__(self,head_type,num_channel,num_classes=10):
        super(Head, self).__init__()
        # forward() only knows these two variants; reject anything else here
        # instead of failing later on a missing `attn` attribute.
        if head_type not in ('mlsp_gap', 'mlsp_cnn_gap_attn'):
            raise ValueError("unknown head_type {!r}".format(head_type))
        self.head_type = head_type
        self.num_ch = num_channel

        if head_type == 'mlsp_cnn_gap_attn':
            attn = []
            conv = []
            # One attention + conv stack per feature level. The final level
            # gets no spatial gate (for the usual four levels this is index 3,
            # exactly as before).
            last = len(num_channel) - 1
            for i, channels in enumerate(num_channel):
                attn.append(CBAM(channels,reduction_ratio=16,no_spatial=(i == last)))
                conv.append(Sequential(
                                    self.conv_block(channels,channels,1,0),
                                    self.conv_block(channels,channels,3,1),
                          ))
            self.attn = ModuleList(attn)
            self.conv = ModuleList(conv)

        # The classifier input is the concatenation of one pooled vector per
        # level, so its width is the sum of the level widths (960 for the
        # BasicBlock ResNets: 64 + 128 + 256 + 512).
        self.dense = Sequential(Linear(sum(num_channel),num_classes))

    def forward(self,features):
        """Map a list of feature maps (one per level) to class logits."""
        if self.head_type == 'mlsp_gap':
            pooled = [F.adaptive_avg_pool2d(feature, (1, 1)) for feature in features]
        else:
            # attention output is added back to its input before the conv stack
            pooled = [F.adaptive_avg_pool2d(conv(attn(feature)+feature),(1,1))
                      for feature,attn,conv in zip(features,self.attn,self.conv)]
        x = torch.cat(pooled,dim=1)
        x = torch.flatten(x, 1)
        x = self.dense(x)
        return x

class Fmodel(nn.Module):
    """Backbone (`Base`) plus classification head (`Head`) with optional probes.

    Besides the plain forward pass, the model can register hooks that record:

    * ``gap_fea``  - the input of the final dense layer,
    * ``fea``      - the output of every conv stack of the head,
    * ``gradient`` - the gradient w.r.t. the output of every conv stack.

    Args:
        head_type: forwarded to ``Head``.
        base: forwarded to ``Base`` (backbone selector).
        dp: dropout probability; written to the class attribute ``Head.dp``
            before the head is built, so it affects heads created afterwards too.
    """

    def __init__(self, head_type='mlsp_gap',base = 18,dp = 0.5):
        super(Fmodel,self).__init__()
        self.bmodel = Base(base=base)
        Head.dp = dp
        self.head = Head(head_type,self.bmodel.channel_size)
        self.feature_type = 'narrow'
        self.resize = False
        # Buffers filled by the hooks below, and the handles needed to remove them.
        self.fea = []
        self.gap_fea = []
        self.gradient = []
        self.handles = []

    def forward(self,img):
        """Extract the backbone features for ``img`` and classify them."""
        x = self.bmodel.get_MLSP(img,self.feature_type,self.resize)
        x = self.head(x)
        return x

    def unfreeze(self):
        """Make the backbone trainable."""
        self.bmodel.unfreeze()

    def freeze(self):
        """Freeze the backbone."""
        self.bmodel.freeze()

    def hook_gap(self):
        """Record the input of the dense classifier on every forward pass."""
        handle = self.head.dense.register_forward_hook(
            lambda layer, inputs, _output: self.gap_fea.append(inputs[0]))
        self.handles += [handle]
        return handle

    def hook_grad(self):
        """Record the output-gradient of every head conv stack during backward.

        Entries are appended in the order the backward hooks fire.
        """
        handle = [
            stage.register_full_backward_hook(
                lambda layer, grad_in, grad_out: self.gradient.append(grad_out[0]))
            for stage in self.head.conv
        ]
        self.handles += handle
        return handle

    def hook_fea(self):
        """Record the output of every head conv stack on every forward pass."""
        handle = [
            stage.register_forward_hook(
                lambda layer, inputs, output: self.fea.append(output))
            for stage in self.head.conv
        ]
        self.handles += handle
        return handle

    def hook(self):
        """Register all three kinds of probes."""
        self.hook_gap()
        self.hook_grad()
        self.hook_fea()

    def unhook(self):
        """Drop all recorded tensors and remove every registered hook.

        Rebinding the buffers to fresh lists is what releases the recorded
        tensors; the former ``tensor.detach()`` calls discarded their result
        and therefore had no effect.
        """
        self.gap_fea = []
        self.gradient = []
        self.fea = []

        for h in self.handles:
            h.remove()
        self.handles = []



In [None]:
def grad_cam(model, preds, label,ind,ind_ad):
    """Attention-consistency loss between two groups of samples.

    The per-sample logit selected by ``label`` is summed and back-propagated
    (graph retained) so that the backward hooks of ``model`` fill
    ``model.gradient``. For each of the four recorded stages the gradient is
    averaged over its last two dimensions, multiplied with the stage's
    recorded activation from ``model.fea``, summed over dim 1 and passed
    through ReLU. The loss is the sum over stages of the MSE between the maps
    of the samples in ``ind`` and those in ``ind_ad``.

    Side effects: ``model.gradient`` is replaced by its reversed copy and
    ``model.zero_grad()`` is called to clear the gradients produced by the
    extra backward pass.

    Args:
        model: module exposing ``gradient`` and ``fea`` lists plus ``zero_grad``.
        preds: logits, indexed as ``preds[row, label[row]]``.
        label: class index for every row of ``preds``.
        ind: indices of the first sample group.
        ind_ad: indices of the second sample group (same length as ``ind``).

    Returns:
        Scalar tensor with the accumulated MSE.
    """
    rows = torch.arange(len(preds))
    target_score = preds[rows, label].sum()
    target_score.backward(retain_graph=True)

    # Reverse so that gradient[i] lines up with fea[i] -- presumably because the
    # backward hooks fire from the last stage to the first.
    model.gradient = model.gradient[::-1]

    maps_first_group = []
    maps_second_group = []
    for stage in range(4):
        stage_grad = model.gradient[stage]
        channel_weight = stage_grad.mean(dim=-1, keepdim=True).mean(dim=-2, keepdim=True)
        weighted = channel_weight * model.fea[stage]
        maps_first_group.append(F.relu(weighted[ind].sum(dim=1)))
        maps_second_group.append(F.relu(weighted[ind_ad].sum(dim=1)))

    model.zero_grad()

    criterion = torch.nn.MSELoss()
    loss = 0
    for first_map, second_map in zip(maps_first_group, maps_second_group):
        loss += criterion(first_map, second_map)
    return loss

In [None]:
from tqdm.auto import tqdm, trange
import torchattacks

def train_cifar(config, checkpoint_dir=None, data_dir=None):
    """Ray Tune trainable: adversarial training of the attention model on CIFAR-10.

    Every batch is extended with PGD adversarial examples. The training loss
    combines (a) cross-entropy on the clean and on the adversarial half,
    (b) the attention-consistency loss from ``grad_cam`` and (c) a weighted
    L2 penalty on the final dense layer, where every input feature is
    weighted by how much it differs between the clean and the adversarial half.

    Args:
        config: hyper-parameters sampled by Tune:
            ``dp`` (dropout), ``lr``, ``l2`` (Adam weight decay),
            ``adv_weight`` (weight of the cross-entropy on the *clean* half;
            the adversarial half gets ``1 - adv_weight``), ``scale`` (weight of
            the cross-entropy term versus the attention-consistency term) and
            ``ada_reg`` (strength of the dense-layer penalty).
        checkpoint_dir: directory holding a ``checkpoint`` file with
            ``(model_state, optimizer_state)`` to resume from, or None.
        data_dir: dataset root; ``./data`` is used when None.

    Reports ``loss`` and ``accuracy`` of the validation phase to Tune after
    every epoch and writes a checkpoint at the same time.
    """
    # These two switches used to be free variables that only existed as
    # leftover notebook state, so a fresh kernel raised NameError.
    # NOTE(review): True is assumed because both consumers below (grad_cam and
    # the dense-layer penalty) compare clean/adversarial rows pairwise.
    attn_pre = True
    ada_reg = True

    num_epochs = 15 #50
    adv = True
    dense_l2 = True

    root = data_dir if data_dir is not None else './data'
    trainset = torchvision.datasets.CIFAR10(
        root=root, train=True, download=True, transform=transform)
    valset = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=transform)

    model = Fmodel('mlsp_cnn_gap_attn',18,config['dp'])
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['lr'], weight_decay=config['l2'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,factor=0.1, patience=1, verbose=True, min_lr =1e-6)

    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=64, shuffle=True, num_workers=2, pin_memory=True)
    # Validate on held-out images 5000..9999. The former
    # SequentialSampler(valset.data[5000:10000]) only used the *length* of the
    # slice and therefore iterated over indices 0..4999 instead.
    val_subset = torch.utils.data.Subset(valset, range(5000, 10000))
    val_loader = torch.utils.data.DataLoader(
        val_subset, batch_size=64, shuffle=False, num_workers=2, pin_memory=True)
    dataloaders = {"train": train_loader, "val": val_loader}

    if checkpoint_dir:
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    since = time.time()
    best_loss = float('inf')
    # Step size 2/255 (the previous 2/225 looks like a typo of the usual PGD setting).
    # NOTE(review): the attack queries `model` on un-normalised [0, 1] images,
    # whereas the training/validation forward pass below feeds normalised
    # images -- confirm this mismatch is intended.
    atk = torchattacks.PGD(model, eps=8/255, alpha=2/255, steps=7, random_start=True)

    print('Config: ', config)
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']: # Each epoch has a training and validation phase
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            size = 0

            for inputs, labels in dataloaders[phase]:
                if adv:
                    if attn_pre or ada_reg:
                        # Paired batch: rows [0, B) are clean, rows [B, 2B) are
                        # the adversarial versions of the same images.
                        inputs = torch.cat((inputs.to(device), atk(inputs, labels)),dim=0)
                        labels = torch.cat((labels,labels),dim=0)
                        half = len(inputs)//2
                        ind = list(range(half))
                        ind_ad = list(range(half,len(inputs)))
                    else:
                        # Unpaired batch: a random half is replaced by adversarial examples.
                        ind_ad = np.random.choice(len(inputs),len(inputs)//2,replace=False)
                        ind = np.delete(np.arange(len(inputs)),ind_ad)
                        inputs_adv = atk(inputs[ind_ad], labels[ind_ad])
                        inputs = inputs.to(device)
                        inputs[ind_ad] = inputs_adv
                else:
                    # NOTE: the training branch below needs `ind` / `ind_ad`,
                    # so it only works while `adv` is True.
                    inputs = inputs.to(device)

                inputs = transforms.functional.normalize(inputs,(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                labels = labels.to(device)
                optimizer.zero_grad() # Zero the parameter gradients

                if is_train:
                    # grad_cam() and the dense-layer penalty both read the
                    # tensors recorded by the hooks, so they are always
                    # installed for the training step and always removed again.
                    model.hook()
                    try:
                        outputs = model(inputs)
                        loss = (config['adv_weight']*criterion(outputs[ind], labels[ind])
                                + (1-config['adv_weight'])*criterion(outputs[ind_ad], labels[ind_ad]))
                        loss_pres = grad_cam(model, outputs,labels,ind,ind_ad)
                        if not torch.isnan(loss_pres).any():
                            loss = config['scale']*loss + (1-config['scale'])*loss_pres
                        else:
                            print('nan')

                        if dense_l2:
                            fea = model.gap_fea[0]
                            out_ori = fea[ind]
                            out_adv = fea[ind_ad]
                            # Mean absolute clean/adversarial gap per feature,
                            # min-max scaled and multiplied by the regulariser strength.
                            diff = (out_adv-out_ori).abs().mean(axis=0).reshape(out_ori.shape[-1])
                            diff = (diff - diff.min())/(diff.max()-diff.min()+1e-9)*config['ada_reg']
                            loss = loss + (diff*((model.head.dense[0].weight)**2).sum(dim=0)).sum()

                        loss.backward()
                        optimizer.step()
                    finally:
                        model.unhook()
                else:
                    with torch.no_grad():
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                _, preds = torch.max(outputs, 1)

                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()
                size+=len(preds)

            epoch_loss = running_loss / size
            # Plain Python float so it can be reported to Tune as-is.
            epoch_acc = running_corrects / size

            if phase == 'val':
                # Adjust learning rate based on val loss
                scheduler.step(epoch_loss)

                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    print('Best loss: ', best_loss)

                with tune.checkpoint_dir(epoch) as ckpt_dir:
                    path = os.path.join(ckpt_dir, "checkpoint")
                    torch.save((model.state_dict(), optimizer.state_dict()), path)

                tune.report(loss=epoch_loss, accuracy=epoch_acc)


    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best_loss: {:4f}'.format(best_loss))



In [None]:
from functools import partial

# Hyper-parameter search space; every key is read by train_cifar().
config = {
    "dp": tune.choice([0.5,0.4,0.3, 0.2,0.1]),
    "l2": tune.loguniform(1e-5, 1e-2),
    "ada_reg": tune.loguniform(1e-5, 1e-2),
    "adv_weight": tune.uniform(0.4, 0.6),
    "scale": tune.uniform(0.4, 0.6),
    "lr": tune.loguniform(1e-4, 1e-2),
}
reporter = CLIReporter(
    # parameter_columns=["l1", "l2", "lr", "batch_size"],
    # Ray names its built-in iteration counter `training_iteration`
    # (underscore); the previous 'training iteration' matched no metric.
    metric_columns=["loss", "accuracy", "training_iteration"])

# Asynchronous successive halving on the reported validation loss.
scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=20,
        grace_period=1,
        reduction_factor=2)

result = tune.run(
    # Passed directly: wrapping it in partial() without arguments added nothing.
    train_cifar,
    resources_per_trial={"cpu": 2, "gpu": 1},
    config=config,
    num_samples=30,
    verbose=0,
    scheduler=scheduler,
    progress_reporter=reporter,
)

# 'all' scope: pick the trial with the lowest loss over all reported iterations.
best_trial = result.get_best_trial("loss", "min", 'all')
print("Best trial config: {}".format(best_trial.config))



2022-03-08 11:13:53,656	INFO registry.py:70 -- Detected unknown callable for trainable. Converting to class.


[2m[36m(func pid=298)[0m Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


[2m[36m(func pid=298)[0m   0%|          | 0/170498071 [00:00<?, ?it/s]
  0%|          | 1024/170498071 [00:00<7:49:41, 6049.89it/s]
  0%|          | 33792/170498071 [00:00<24:42, 115000.76it/s]
  0%|          | 82944/170498071 [00:00<14:43, 192969.96it/s]
  0%|          | 214016/170498071 [00:00<06:48, 417260.53it/s]
  0%|          | 443392/170498071 [00:00<03:47, 746970.57it/s]
  1%|          | 902144/170498071 [00:01<02:01, 1398121.19it/s]
  1%|          | 1819648/170498071 [00:01<01:02, 2677636.47it/s]
  2%|▏         | 3671040/170498071 [00:01<00:31, 5240091.31it/s]
  4%|▍         | 6800384/170498071 [00:01<00:17, 9269495.49it/s]
  6%|▌         | 9946112/170498071 [00:01<00:13, 11987326.38it/s]
  8%|▊         | 12960768/170498071 [00:01<00:11, 13718441.43it/s]
  9%|▉         | 16090112/170498071 [00:02<00:10, 15034888.86it/s]
 11%|█         | 18606080/170498071 [00:02<00:08, 17070111.68it/s]
 12%|█▏        | 20416512/170498071 [00:02<00:09, 16575538.69it/s]
 13%|█▎        | 2213

[2m[36m(func pid=298)[0m Extracting ./data/cifar-10-python.tar.gz to ./data
[2m[36m(func pid=298)[0m Files already downloaded and verified
[2m[36m(func pid=298)[0m Files already downloaded and verified
[2m[36m(func pid=298)[0m 18


[2m[36m(func pid=298)[0m E0308 11:14:22.415531481     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 11:14:22.438941128     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Config:  {'dp': 0.2, 'l2': 0.0002714992481279903, 'ada_reg': 0.002018363174302323, 'adv_weight': 0.6306034330529589, 'scale': 0.4530136484753278, 'lr': 0.0015223506791380115}
[2m[36m(func pid=298)[0m Epoch 0/14
[2m[36m(func pid=298)[0m ----------


[2m[36m(func pid=298)[0m E0308 11:22:24.606243623     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 11:22:24.625895980     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Best loss:  1.4819137649536134


[2m[36m(func pid=298)[0m E0308 11:22:42.696280471     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 11:22:42.717680508     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Epoch 1/14
[2m[36m(func pid=298)[0m ----------


[2m[36m(func pid=298)[0m E0308 11:30:41.875155584     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 11:30:41.896666914     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Best loss:  1.4715185634613037


[2m[36m(func pid=298)[0m E0308 11:30:57.216997416     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 11:30:57.240582520     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Epoch 2/14
[2m[36m(func pid=298)[0m ----------


[2m[36m(func pid=298)[0m E0308 11:38:53.152749559     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 11:38:53.172838331     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Best loss:  1.163822728729248


[2m[36m(func pid=298)[0m E0308 11:39:08.428880144     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 11:39:08.452136623     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Epoch 3/14
[2m[36m(func pid=298)[0m ----------


[2m[36m(func pid=298)[0m E0308 11:47:08.503680132     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 11:47:08.524536721     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Best loss:  1.1071286731719971


[2m[36m(func pid=298)[0m E0308 11:47:23.984585337     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 11:47:24.007980160     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Epoch 4/14
[2m[36m(func pid=298)[0m ----------


[2m[36m(func pid=298)[0m E0308 11:55:23.426851388     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 11:55:23.450341745     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Best loss:  1.0277284984588624


[2m[36m(func pid=298)[0m E0308 11:55:38.807181567     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 11:55:38.827636840     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Epoch 5/14
[2m[36m(func pid=298)[0m ----------


[2m[36m(func pid=298)[0m E0308 12:03:37.522767858     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 12:03:37.543669443     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Best loss:  0.9731098812103272


[2m[36m(func pid=298)[0m E0308 12:03:53.181850789     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 12:03:53.203681220     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Epoch 6/14
[2m[36m(func pid=298)[0m ----------


[2m[36m(func pid=298)[0m E0308 12:11:54.819172620     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 12:11:54.838955097     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Best loss:  0.916001215171814


[2m[36m(func pid=298)[0m E0308 12:12:10.328707610     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 12:12:10.354562137     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Epoch 7/14
[2m[36m(func pid=298)[0m ----------


[2m[36m(func pid=298)[0m E0308 12:20:15.179830923     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 12:20:15.201051957     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Best loss:  0.8516330413818359


[2m[36m(func pid=298)[0m E0308 12:20:30.768933983     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 12:20:30.790185686     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Epoch 8/14
[2m[36m(func pid=298)[0m ----------


[2m[36m(func pid=298)[0m E0308 12:28:33.536339545     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 12:28:33.563272150     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 12:28:49.149165049     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 12:28:49.170630232     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Epoch 9/14
[2m[36m(func pid=298)[0m ----------


[2m[36m(func pid=298)[0m E0308 12:36:51.250920500     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 12:36:51.272289602     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 12:37:06.793610857     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 12:37:06.817398868     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Epoch 10/14
[2m[36m(func pid=298)[0m ----------


[2m[36m(func pid=298)[0m E0308 12:45:08.487247211     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 12:45:08.509240310     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Best loss:  0.8246260431289673


[2m[36m(func pid=298)[0m E0308 12:45:24.043976592     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 12:45:24.072698373     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Epoch 11/14
[2m[36m(func pid=298)[0m ----------


[2m[36m(func pid=298)[0m E0308 12:53:24.982446504     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 12:53:25.003284931     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Best loss:  0.820308931350708


[2m[36m(func pid=298)[0m E0308 12:53:40.413953979     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 12:53:40.434899201     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Epoch 12/14
[2m[36m(func pid=298)[0m ----------


[2m[36m(func pid=298)[0m E0308 13:01:39.407699503     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 13:01:39.427447454     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 13:01:54.776443588     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 13:01:54.798854309     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Epoch 13/14
[2m[36m(func pid=298)[0m ----------


[2m[36m(func pid=298)[0m E0308 13:09:55.420821717     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 13:09:55.442393139     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Best loss:  0.7734364852905273


[2m[36m(func pid=298)[0m E0308 13:10:10.942693588     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 13:10:10.967407073     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=298)[0m Epoch 14/14
[2m[36m(func pid=298)[0m ----------


[2m[36m(func pid=298)[0m E0308 13:18:15.580865361     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=298)[0m E0308 13:18:15.603491792     380 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=701)[0m Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


[2m[36m(func pid=701)[0m   0%|          | 0/170498071 [00:00<?, ?it/s]
  0%|          | 1024/170498071 [00:00<7:49:50, 6048.07it/s]
  0%|          | 33792/170498071 [00:00<24:41, 115034.29it/s]
  0%|          | 82944/170498071 [00:00<14:43, 192919.72it/s]
  0%|          | 214016/170498071 [00:00<06:48, 416678.78it/s]
  0%|          | 443392/170498071 [00:00<03:48, 745308.38it/s]
  1%|          | 902144/170498071 [00:01<02:01, 1396418.27it/s]
  1%|          | 1819648/170498071 [00:01<01:02, 2677603.52it/s]
  2%|▏         | 3671040/170498071 [00:01<00:31, 5240842.53it/s]
  4%|▍         | 6439936/170498071 [00:01<00:19, 8593506.13it/s]
  6%|▌         | 9552896/170498071 [00:01<00:13, 11503628.30it/s]
  7%|▋         | 12583936/170498071 [00:01<00:11, 13354794.94it/s]
  9%|▉         | 15582208/170498071 [00:02<00:10, 14563966.81it/s]
 11%|█         | 18547712/170498071 [00:02<00:09, 15345491.76it/s]
 13%|█▎        | 21529600/170498071 [00:02<00:09, 15912294.77it/s]
 14%|█▍        | 2446

[2m[36m(func pid=701)[0m Extracting ./data/cifar-10-python.tar.gz to ./data
[2m[36m(func pid=701)[0m Files already downloaded and verified
[2m[36m(func pid=701)[0m Files already downloaded and verified
[2m[36m(func pid=701)[0m 18
[2m[36m(func pid=701)[0m Config:  {'dp': 0.2, 'l2': 3.881646497434185e-05, 'ada_reg': 1.2875285396276522e-05, 'adv_weight': 0.7713116193859444, 'scale': 0.5391227821497242, 'lr': 0.005064982066547872}
[2m[36m(func pid=701)[0m Epoch 0/14
[2m[36m(func pid=701)[0m ----------


[2m[36m(func pid=701)[0m E0308 13:18:59.167247433     729 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=701)[0m E0308 13:18:59.197425859     729 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=701)[0m E0308 13:27:07.345481251     729 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=701)[0m E0308 13:27:07.371418639     729 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=701)[0m Best loss:  1.5946197027206421
[2m[36m(func pid=757)[0m Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]
  0%|          | 1024/170498071 [00:00<7:47:08, 6082.93it/s]
  0%|          | 33792/170498071 [00:00<24:39, 115245.19it/s]
  0%|          | 82944/170498071 [00:00<14:42, 193198.71it/s]
  0%|          | 214016/170498071 [00:00<06:48, 416901.10it/s]
  0%|          | 443392/170498071 [00:00<03:47, 746630.44it/s]
  1%|          | 902144/170498071 [00:01<02:01, 1397321.25it/s]
  1%|          | 1819648/170498071 [00:01<01:02, 2680450.49it/s]
  2%|▏         | 3671040/170498071 [00:01<00:31, 5243928.89it/s]
  4%|▎         | 6259712/170498071 [00:01<00:19, 8294537.84it/s]
  5%|▌         | 8553472/170498071 [00:01<00:16, 9834450.54it/s]
  7%|▋         | 11256832/170498071 [00:01<00:13, 11622428.75it/s]
  8%|▊         | 13714432/170498071 [00:02<00:12, 12417168.08it/s]
 10%|▉         | 16499712/170498071 [00:02<00:11, 13541212.99it/s]
 11%|█         | 19006464/170498071 [00:02<00:10, 13842630.01it/s]
 13%|█▎        | 21742592/170498071 [00:02<00:10, 1

[2m[36m(func pid=757)[0m Extracting ./data/cifar-10-python.tar.gz to ./data
[2m[36m(func pid=757)[0m Files already downloaded and verified
[2m[36m(func pid=757)[0m Files already downloaded and verified
[2m[36m(func pid=757)[0m 18
[2m[36m(func pid=757)[0m Config:  {'dp': 0.1, 'l2': 0.0003291841312537266, 'ada_reg': 1.515965161568824e-05, 'adv_weight': 0.43427894511585396, 'scale': 0.42858641567824113, 'lr': 0.00036941252392986055}
[2m[36m(func pid=757)[0m Epoch 0/14
[2m[36m(func pid=757)[0m ----------


[2m[36m(func pid=757)[0m E0308 13:27:49.631570020     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=757)[0m E0308 13:27:49.665824955     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=757)[0m E0308 13:35:52.892487232     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=757)[0m E0308 13:35:52.920525165     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=757)[0m Best loss:  1.321039206123352
[2m[36m(func pid=757)[0m Epoch 1/14
[2m[36m(func pid=757)[0m ----------


[2m[36m(func pid=757)[0m E0308 13:36:08.500198098     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=757)[0m E0308 13:36:08.526392451     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=757)[0m E0308 13:44:13.132871875     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=757)[0m E0308 13:44:13.163267656     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=757)[0m Best loss:  1.222826982307434
[2m[36m(func pid=757)[0m Epoch 2/14
[2m[36m(func pid=757)[0m ----------


[2m[36m(func pid=757)[0m E0308 13:44:28.751799099     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=757)[0m E0308 13:44:28.783922054     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=757)[0m E0308 13:52:32.285771529     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=757)[0m E0308 13:52:32.312447723     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=757)[0m Best loss:  1.1216665241241455
[2m[36m(func pid=757)[0m Epoch 3/14
[2m[36m(func pid=757)[0m ----------


[2m[36m(func pid=757)[0m E0308 13:52:47.905639049     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=757)[0m E0308 13:52:47.931556566     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=757)[0m E0308 14:00:50.648202139     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=757)[0m E0308 14:00:50.673751633     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=757)[0m Best loss:  1.0370936120986938
[2m[36m(func pid=757)[0m Epoch 4/14
[2m[36m(func pid=757)[0m ----------


[2m[36m(func pid=757)[0m E0308 14:01:06.174220559     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=757)[0m E0308 14:01:06.203371350     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=757)[0m E0308 14:09:07.014249444     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=757)[0m E0308 14:09:07.039616833     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


[2m[36m(func pid=757)[0m Best loss:  1.0200728635787963
[2m[36m(func pid=757)[0m Epoch 5/14
[2m[36m(func pid=757)[0m ----------


[2m[36m(func pid=757)[0m E0308 14:09:22.572624602     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=757)[0m E0308 14:09:22.606160108     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=757)[0m E0308 14:17:23.044836949     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies
[2m[36m(func pid=757)[0m E0308 14:17:23.070174068     785 fork_posix.cc:70]           Fork support is only compatible with the epoll1 and poll polling strategies


In [None]:
# Pick the trial with the lowest reported "loss". scope="all" compares each
# trial's best value over every reported iteration, not just its last one.
best_trial = result.get_best_trial(metric="loss", mode="min", scope="all")
if best_trial is None:
    # get_best_trial logs a warning and returns None when no trial reported
    # the requested metric; fail with a clear message instead of an
    # AttributeError on `.config` below.
    raise RuntimeError("No trial reported a 'loss' metric; cannot select a best trial.")

print("Best trial config: {}".format(best_trial.config))

# The saved output of this cell also lists the best trial's final validation
# loss and accuracy, so report them here to keep code and output in sync.
final_result = best_trial.last_result
print("Best trial final validation loss: {}".format(final_result["loss"]))
# NOTE(review): assumes the trainable reports an "accuracy" metric alongside
# "loss" (its definition is in an earlier cell) -- confirm the key name.
print("Best trial final validation accuracy: {}".format(final_result.get("accuracy")))

Best trial config: {'dp': 0.1, 'l2': 5.467130776888194e-06, 'ada_reg': 0.008446479631052932, 'adv_weight': 0.73428888168777, 'scale': 0.5988473674763856, 'lr': 0.004933134661531971}
Best trial final validation loss: 0.9047795356750489
Best trial final validation accuracy: 0.7273000000000001
