In [None]:
import copy
import torch
import time
import os
import random
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

import aux_funcs  as af
import network_architectures as arcs
import model_funcs as mf

In [15]:
class BasicBlockWOutput(nn.Module):
    expansion = 1

    def __init__(self, in_channels, channels, params, stride=1):
        super(BasicBlockWOutput, self).__init__()
        add_output = params[0]
        num_classes = params[1]
        input_size = params[2]
        self.output_id = params[3]

        self.depth = 2

        layers = nn.ModuleList()

        conv_layer = []
        conv_layer.append(nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False))
        conv_layer.append(nn.BatchNorm2d(channels))
        conv_layer.append(nn.ReLU())
        conv_layer.append(nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False))
        conv_layer.append(nn.BatchNorm2d(channels))

        layers.append(nn.Sequential(*conv_layer))

        shortcut = nn.Sequential()

        if stride != 1 or in_channels != self.expansion*channels:
            shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion*channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*channels)
            )

        layers.append(shortcut)
        layers.append(nn.ReLU())

        self.layers = layers

        if add_output:
            self.output = af.InternalClassifier(input_size, self.expansion*channels, num_classes)
            self.no_output = False

        else:
            self.output = None
            self.forward = self.only_forward
            self.no_output = True

    def forward(self, x):
        fwd = self.layers[0](x) # conv layers
        fwd = fwd + self.layers[1](x) # shortcut
        return self.layers[2](fwd), 1, self.output(fwd)         # output layers for this module

    def only_output(self, x):
        fwd = self.layers[0](x) # conv layers
        fwd = fwd + self.layers[1](x) # shortcut
        fwd = self.layers[2](fwd) # activation
        out = self.output(fwd)         # output layers for this module
        return out

    def only_forward(self, x):
        fwd = self.layers[0](x) # conv layers
        fwd = fwd + self.layers[1](x) # shortcut
        return self.layers[2](fwd), 0, None # activation

class ResNet_SDN(nn.Module):
    def __init__(self, params):
        super(ResNet_SDN, self).__init__()
        self.num_blocks = params['num_blocks']
        self.num_classes = int(params['num_classes'])
        self.augment_training = params['augment_training']
        self.input_size = int(params['input_size'])
        self.block_type = params['block_type']
        self.add_out_nonflat = params['add_ic']
        self.add_output = [item for sublist in self.add_out_nonflat for item in sublist]
        self.init_weights = params['init_weights']
        self.train_func = mf.sdn_train
        self.in_channels = 16
        self.num_output = sum(self.add_output) + 1
        self.test_func = mf.sdn_test

        self.init_depth = 1
        self.end_depth = 1
        self.cur_output_id = 0

        self.stop_ic_id = -1

        self.init_rpf_channels = params['init_rpf_pannel']
        self.use_rpf = params['use_rpf']

        if self.block_type == 'basic':
            self.block = BasicBlockWOutput

        init_conv = []

        if self.input_size ==  32: # cifar10
            self.cur_input_size = self.input_size
            if self.use_rpf:
                init_conv.append(nn.Conv2d(3, self.in_channels - self.init_rpf_channels, kernel_size=3, stride=1, padding=1, bias=False))
                init_conv.append(nn.Conv2d(3, self.init_rpf_channels, kernel_size=3, stride=1, padding=1, bias=False))
            else:
                init_conv.append(nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False))
        else: # tiny imagenet
            self.cur_input_size = int(self.input_size/2)
            if self.use_rpf:
                init_conv.append(nn.Conv2d(3, self.in_channels - self.init_rpf_channels, kernel_size=3, stride=1, padding=1, bias=False))
                init_conv.append(nn.Conv2d(3, self.init_rpf_channels, kernel_size=3, stride=1, padding=1, bias=False))
            else:
                init_conv.append(nn.Conv2d(3, self.in_channels, kernel_size=3, stride=2, padding=1, bias=False))

        init_conv.append(nn.BatchNorm2d(self.in_channels))
        init_conv.append(nn.ReLU())

        self.init_conv = nn.Sequential(*init_conv)

        self.layers = nn.ModuleList()
        self.layers.extend(self._make_layer(self.in_channels, block_id=0, stride=1))

        self.cur_input_size = int(self.cur_input_size/2)
        self.layers.extend(self._make_layer(32, block_id=1, stride=2))

        self.cur_input_size = int(self.cur_input_size/2)
        self.layers.extend(self._make_layer(64, block_id=2, stride=2))

        end_layers = []

        end_layers.append(nn.AvgPool2d(kernel_size=8))
        end_layers.append(af.Flatten())
        end_layers.append(nn.Linear(64*self.block.expansion, self.num_classes))
        self.end_layers = nn.Sequential(*end_layers)

        if self.init_weights:
            self.initialize_weights()

    def _make_layer(self, channels, block_id, stride):
        num_blocks = int(self.num_blocks[block_id])
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for cur_block_id, stride in enumerate(strides):
            add_output = self.add_out_nonflat[block_id][cur_block_id]
            params  = (add_output, self.num_classes, int(self.cur_input_size), self.cur_output_id)
            layers.append(self.block(self.in_channels, channels, params, stride))
            self.in_channels = channels * self.block.expansion
            self.cur_output_id += add_output

        return layers

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def random_rp_matrix(self):
        param = next(self.init_conv[0].parameters())
        kernel_size = param.data.size()[-1]
        param.data = torch.normal(mean=0.0, std=1/kernel_size, size=param.data.size()).to('cuda')

    def rp_forward(self, x, out, kernel):
        rp_out = kernel(x)
        if out is None:
            return rp_out
        else:
            out = torch.cat([out, rp_out], dim=1)
            return out

    def forward(self, x):
        outputs = []
        if self.use_rpf:
            fwd = self.init_conv[0](x)
            fwd = self.rp_forward(x, fwd, self.init_conv[1])
            fwd = self.init_conv[2](fwd)
        else:
            fwd = self.init_conv(x)
        for layer in self.layers:
            fwd, is_output, output = layer(fwd)
            if is_output:
                if len(outputs) == self.stop_ic_id:
                    return output
                outputs.append(output)
        fwd = self.end_layers(fwd)
        outputs.append(fwd)

        return outputs

    # takes a single input
    def early_exit(self, x):
        confidences = []
        outputs = []

        fwd = self.init_conv(x)
        output_id = 0
        for layer in self.layers:
            fwd, is_output, output = layer(fwd)

            if is_output:
                outputs.append(output)
                softmax = nn.functional.softmax(output[0], dim=0)

                confidence = torch.max(softmax)
                confidences.append(confidence)

                if confidence >= self.confidence_threshold:
                    is_early = True
                    return output, output_id, is_early

                output_id += is_output

        output = self.end_layers(fwd)
        outputs.append(output)

        softmax = nn.functional.softmax(output[0], dim=0)
        confidence = torch.max(softmax)
        confidences.append(confidence)
        max_confidence_output = np.argmax(confidences)
        is_early = False
        return outputs[max_confidence_output], max_confidence_output, is_early

In [16]:
model_params = {}
model_params['task'] = 'cifar10'
model_params['input_size'] = 32
model_params['num_classes'] = 10
model_params['block_type'] = 'basic'
model_params['num_blocks'] = [9,9,9]
model_params['add_ic'] = [[0, 0, 0, 1, 0, 0, 0, 1, 0], [0, 0, 1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0, 0, 0]] # 15, 30, 45, 60, 75, 90 percent of GFLOPs
model_params['network_type'] = 'resnet56'
model_params['augment_training'] = True
model_params['init_weights'] = True
model_params['architecture'] = 'sdn'
model_params['base_model'] = 'resnet56'
network_type = model_params['network_type']
model_params['momentum'] = 0.9
model_params['learning_rate'] = 0.1
model_params['epochs'] = 100
model_params['milestones'] = [35, 60, 85]
model_params['gammas'] = [0.1, 0.1, 0.1]

# SDN ic_only training params
model_params['ic_only'] = {}
model_params['ic_only']['learning_rate'] = 0.001 # lr for full network training after sdn modification
model_params['ic_only']['epochs'] = 25
model_params['ic_only']['milestones'] = [15]
model_params['ic_only']['gammas'] = [0.1]
model = ResNet_SDN(model_params)

KeyError: 'init_rpf_pannel'

In [17]:
model.stop_ic_id = 1
batch_shape = [128, 3, 32, 32]

In [None]:
model

In [None]:
def evaluate_attack(model,layer_id, test_loader, atk, atk_name):
    test_loss = 0
    test_acc = 0
    n = 0
    model.eval()
    model.stop_ic_id = layer_id
    for i, (X, y) in enumerate(test_loader):
        X, y = X.to('cuda'), y.to('cuda')

        #if args.rp:
            # random select a path to attack
        #    model.module.random_rp_matrix()

        X_adv = atk(X, y)  # advtorch

        #if args.rp:
         #   # random select a path to infer
          #  model.module.random_rp_matrix()

        with torch.no_grad():
            output = model(X_adv)
        loss = F.cross_entropy(output, y)
        test_loss += loss.item() * y.size(0)
        test_acc += (output.max(1)[1] == y).sum().item()
        n += y.size(0)

    pgd_acc = test_acc / n
    print('Attack_type: [{:s}] done, acc: {:.4f} \t'.format(atk_name, pgd_acc))

In [14]:
test_input = torch.rand(batch_shape)
model.confidence_threshold = 0.5
model.early_exit(test_input)

(tensor([[-1.0568, -3.5276,  8.4635,  ...,  0.9191, -2.0948,  5.1047],
         [-0.2937, -2.9282,  7.2088,  ...,  0.4159, -2.3300,  5.0619],
         [-1.0225, -3.4461,  7.7779,  ..., -0.2592, -1.9531,  4.8934],
         ...,
         [-0.3686, -3.8420,  7.8554,  ...,  0.3382, -1.8080,  5.1317],
         [-0.9766, -3.8278,  7.7708,  ...,  0.2753, -2.6839,  5.1647],
         [-0.8822, -3.5565,  8.2377,  ...,  0.3630, -2.7631,  4.6707]],
        grad_fn=<AddmmBackward0>),
 0,
 True)

In [None]:
cifar = af.get_dataset("cifar10")

In [None]:
import torchattacks
atk = torchattacks.FGSM(model, eps=8/255)
atk.set_normalization_used(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
evaluate_attack(model,1, cifar.test_loader, atk, 'fgsm')

python train_imagenet.py --pretrained --lr 0.02 --lr_schedule cosine --batch_size 1024 --epochs 90 --adv_train --rp --rp_block -1 -1 --rp_out_channel 48 --rp_weight_decay 1e-2 --save_dir resnet50_imagenet_RPF

In [127]:
class BasicBlockWOutput(nn.Module):
    expansion = 1

    def __init__(self, in_channels, channels, params, stride=1):
        super(BasicBlockWOutput, self).__init__()
        add_output = params[0]
        num_classes = params[1]
        input_size = params[2]
        self.output_id = params[3]

        self.depth = 2

        layers = nn.ModuleList()

        conv_layer = []
        conv_layer.append(nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1, bias=False))
        conv_layer.append(nn.BatchNorm2d(channels))
        conv_layer.append(nn.ReLU())
        conv_layer.append(nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False))
        conv_layer.append(nn.BatchNorm2d(channels))

        layers.append(nn.Sequential(*conv_layer))

        shortcut = nn.Sequential()

        if stride != 1 or in_channels != self.expansion*channels:
            shortcut = nn.Sequential(
                nn.Conv2d(in_channels, self.expansion*channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*channels)
            )

        layers.append(shortcut)
        layers.append(nn.ReLU())

        self.layers = layers

        if add_output:
            self.output = af.InternalClassifier(input_size, self.expansion*channels, num_classes)
            self.no_output = False

        else:
            self.output = None
            self.forward = self.only_forward
            self.no_output = True

    def forward(self, x):
        fwd = self.layers[0](x) # conv layers
        fwd = fwd + self.layers[1](x) # shortcut
        return self.layers[2](fwd), 1, self.output(fwd)         # output layers for this module

    def only_output(self, x):
        fwd = self.layers[0](x) # conv layers
        fwd = fwd + self.layers[1](x) # shortcut
        fwd = self.layers[2](fwd) # activation
        out = self.output(fwd)         # output layers for this module
        return out

    def only_forward(self, x):
        fwd = self.layers[0](x) # conv layers
        fwd = fwd + self.layers[1](x) # shortcut
        return self.layers[2](fwd), 0, None # activation

class ResNet_SDN(nn.Module):
    def __init__(self, params):
        super(ResNet_SDN, self).__init__()
        self.num_blocks = params['num_blocks']
        self.num_classes = int(params['num_classes'])
        self.augment_training = params['augment_training']
        self.input_size = int(params['input_size'])
        self.block_type = params['block_type']
        self.add_out_nonflat = params['add_ic']
        self.add_output = [item for sublist in self.add_out_nonflat for item in sublist]
        self.init_weights = params['init_weights']
        self.train_func = mf.sdn_train
        self.in_channels = 16
        self.num_output = sum(self.add_output) + 1
        self.test_func = mf.sdn_test

        self.init_depth = 1
        self.end_depth = 1
        self.cur_output_id = 0

        self.stop_ic_id = -1

        self.init_rpf_channels = params['init_rpf_pannel']
        self.use_rpf = params['use_rpf']

        if self.block_type == 'basic':
            self.block = BasicBlockWOutput

        init_conv = []

        if self.input_size ==  32: # cifar10
            self.cur_input_size = self.input_size
            if self.use_rpf:
                init_conv.append(nn.Conv2d(3, self.in_channels - self.init_rpf_channels, kernel_size=3, stride=1, padding=1, bias=False))
                init_conv.append(nn.Conv2d(3, self.init_rpf_channels, kernel_size=3, stride=1, padding=1, bias=False))
            else:
                init_conv.append(nn.Conv2d(3, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False))
        else: # tiny imagenet
            self.cur_input_size = int(self.input_size/2)
            if self.use_rpf:
                init_conv.append(nn.Conv2d(3, self.in_channels - self.init_rpf_channels, kernel_size=3, stride=1, padding=1, bias=False))
                init_conv.append(nn.Conv2d(3, self.init_rpf_channels, kernel_size=3, stride=1, padding=1, bias=False))
            else:
                init_conv.append(nn.Conv2d(3, self.in_channels, kernel_size=3, stride=2, padding=1, bias=False))

        init_conv.append(nn.BatchNorm2d(self.in_channels))
        init_conv.append(nn.ReLU())

        self.init_conv = nn.Sequential(*init_conv)

        self.layers = nn.ModuleList()
        self.layers.extend(self._make_layer(self.in_channels, block_id=0, stride=1))

        self.cur_input_size = int(self.cur_input_size/2)
        self.layers.extend(self._make_layer(32, block_id=1, stride=2))

        self.cur_input_size = int(self.cur_input_size/2)
        self.layers.extend(self._make_layer(64, block_id=2, stride=2))

        end_layers = []

        end_layers.append(nn.AvgPool2d(kernel_size=8))
        end_layers.append(af.Flatten())
        end_layers.append(nn.Linear(64*self.block.expansion, self.num_classes))
        self.end_layers = nn.Sequential(*end_layers)

        if self.init_weights:
            self.initialize_weights()

    def _make_layer(self, channels, block_id, stride):
        num_blocks = int(self.num_blocks[block_id])
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for cur_block_id, stride in enumerate(strides):
            add_output = self.add_out_nonflat[block_id][cur_block_id]
            params  = (add_output, self.num_classes, int(self.cur_input_size), self.cur_output_id)
            layers.append(self.block(self.in_channels, channels, params, stride))
            self.in_channels = channels * self.block.expansion
            self.cur_output_id += add_output

        return layers

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def random_rp_matrix(self):
        param = next(self.init_conv[0].parameters())
        kernel_size = param.data.size()[-1]
        param.data = torch.normal(mean=0.0, std=1/kernel_size, size=param.data.size()).to('cuda')

    def rp_forward(self, x, out, kernel):
        rp_out = kernel(x)
        if out is None:
            return rp_out
        else:
            out = torch.cat([out, rp_out], dim=1)
            return out

    def forward(self, x):
        outputs = []
        if self.use_rpf:
            fwd = self.init_conv[0](x)
            fwd = self.rp_forward(x, fwd, self.init_conv[1])
            fwd = self.init_conv[2](fwd)
        else:
            fwd = self.init_conv(x)
        for layer in self.layers:
            fwd, is_output, output = layer(fwd)
            if is_output:
                if len(outputs) == self.stop_ic_id:
                    return output
                outputs.append(output)
        fwd = self.end_layers(fwd)
        outputs.append(fwd)

        return outputs

    # takes a single input
    def early_exit(self, x):
        device = next(self.parameters()).device
        confidences = []
        outputs = []
        batch_size = x.shape[0]
        result = torch.zeros(batch_size, dtype=torch.long).to(device)
        stop_at = torch.zeros(batch_size).to(device)
        stopped = torch.tensor([0]*batch_size, dtype=torch.bool).to(device)
        if self.use_rpf:
            fwd = self.init_conv[0](x)
            fwd = self.rp_forward(x, fwd, self.init_conv[1])
            fwd = self.init_conv[2](fwd)
        else:
            fwd = self.init_conv(x)
        output_id = 0
        for layer in self.layers:
            fwd, is_output, output = layer(fwd)

            if is_output:
                softmax = nn.functional.softmax(output, dim=1)
                softmax[stopped] = 0
                confidence = torch.max(softmax, dim=1)
                stop_index = (confidence.values > self.confidence_threshold).view(-1)
                result[stop_index] = confidence.indices[stop_index]
                stop_at[stop_index] = output_id
                stopped[stop_index] = True
                output_id += is_output

        output = self.end_layers(fwd)
        softmax = nn.functional.softmax(output, dim=1)
        confidence = torch.max(softmax, dim=1)
        result[~stopped] = confidence.indices[~stopped]
        stop_at[stop_index] = -1
        return result, stop_at


In [128]:
model_params = {}
model_params['task'] = 'cifar10'
model_params['input_size'] = 32
model_params['num_classes'] = 10
model_params['block_type'] = 'basic'
model_params['num_blocks'] = [9,9,9]
model_params['add_ic'] = [[0, 0, 0, 1, 0, 0, 0, 1, 0], [0, 0, 1, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0, 0, 0]] # 15, 30, 45, 60, 75, 90 percent of GFLOPs
model_params['network_type'] = 'resnet56'
model_params['augment_training'] = True
model_params['init_weights'] = True
model_params['architecture'] = 'sdn'
model_params['base_model'] = 'resnet56'
network_type = model_params['network_type']
model_params['momentum'] = 0.9
model_params['learning_rate'] = 0.1
model_params['epochs'] = 100
model_params['milestones'] = [35, 60, 85]
model_params['gammas'] = [0.1, 0.1, 0.1]
                          
model_params['init_rpf_pannel'] = 8
model_params['use_rpf'] = True

# SDN ic_only training params
model_params['ic_only'] = {}
model_params['ic_only']['learning_rate'] = 0.001 # lr for full network training after sdn modification
model_params['ic_only']['epochs'] = 25
model_params['ic_only']['milestones'] = [15]
model_params['ic_only']['gammas'] = [0.1]
model = ResNet_SDN(model_params).cuda()

In [129]:
model.random_rp_matrix()

In [130]:
batch_shape = [128,3, 32, 32]
test_input = torch.rand(batch_shape).cuda()
model.confidence_threshold = 0.5

In [131]:
model.early_exit(test_input)

(tensor([2, 5, 2, 1, 5, 2, 2, 2, 4, 2, 7, 7, 5, 2, 2, 6, 2, 2, 2, 2, 2, 2, 2, 5,
         7, 2, 5, 5, 6, 2, 4, 2, 2, 2, 2, 2, 2, 2, 7, 2, 2, 2, 2, 2, 2, 7, 2, 2,
         2, 6, 2, 5, 5, 6, 5, 6, 2, 6, 2, 2, 2, 5, 2, 2, 2, 2, 2, 2, 5, 5, 5, 5,
         2, 2, 2, 7, 6, 5, 6, 2, 7, 7, 2, 2, 2, 6, 2, 2, 6, 2, 2, 2, 2, 2, 5, 2,
         7, 5, 7, 7, 2, 2, 2, 2, 2, 6, 2, 2, 6, 2, 6, 2, 6, 2, 5, 2, 5, 6, 2, 2,
         2, 5, 5, 6, 2, 2, 4, 7], device='cuda:0'),
 tensor([0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 2., 0., 0.,
         0., 0., 0., 0., 0., 1., 1., 0., 1., 1., 2., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 2., 0., 1., 1., 2.,
         1., 2., 0., 2., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1.,
         0., 0., 0., 1., 2., 1., 2., 0., 1., 1., 0., 0., 0., 2., 0., 0., 2., 0.,
         0., 0., 0., 0., 1., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 2., 0., 0.,
         2., 0., 2., 0., 2., 0., 1., 0., 1., 2., 0., 0., 

In [124]:
next(model.parameters()).device

device(type='cuda', index=0)

In [60]:
a = torch.rand(5,5)
a

tensor([[8.8099e-01, 3.5474e-01, 6.1766e-01, 1.2564e-01, 7.5395e-01],
        [3.3816e-01, 9.5851e-01, 5.5834e-01, 8.6521e-01, 3.3673e-01],
        [6.6410e-01, 1.6172e-01, 9.6324e-01, 8.6379e-01, 1.0377e-01],
        [8.6072e-01, 4.3179e-01, 8.1198e-01, 1.4278e-01, 5.9718e-01],
        [8.1437e-01, 6.6839e-01, 6.2485e-01, 2.5642e-01, 1.3471e-04]])

In [61]:
softmax = F.softmax(a, dim=1)
softmax

tensor([[0.1905, 0.1809, 0.1772, 0.1641, 0.2873],
        [0.1998, 0.1985, 0.2024, 0.1979, 0.2013],
        [0.1666, 0.2718, 0.1974, 0.2413, 0.1229],
        [0.2608, 0.1853, 0.2334, 0.1453, 0.1753],
        [0.2495, 0.2011, 0.2071, 0.2440, 0.0982]])

In [96]:
confidence = torch.max(softmax, dim=1)
stop_index = (confidence.values > 0.25).view(-1)
result = torch.zeros(5)
stop_layer = torch.zeros(5)
stopped = torch.tensor([0]*5, dtype=torch.bool)

In [99]:
result[stop_index] = confidence.indices[stop_index].float()
stop_layer[stop_index] = 1
stopped[stop_index] = True

In [101]:
softmax[stopped] = 0

In [102]:
softmax

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1998, 0.1985, 0.2024, 0.1979, 0.2013],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2495, 0.2011, 0.2071, 0.2440, 0.0982]])

In [84]:
confidence.indices

tensor([4, 2, 1, 0, 0])