In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math

from tensorboardX import SummaryWriter

In [2]:
class InitLayer(nn.Module):
    """Stem of the network: one convolution followed by one batch norm.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the convolution.
        kernel_size: side of the square kernel. The padding is
            ``(kernel_size - 1) // 2``, so an odd kernel keeps H and W.
        trs: forwarded to ``nn.BatchNorm2d`` as ``track_running_stats``.
        bias: whether the convolution carries a bias term.
    """
    def __init__(self, in_channels, out_channels, kernel_size, trs=False, bias=False):
        super().__init__()
        same_pad = (kernel_size - 1) // 2
        stem_conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=kernel_size,
                              padding=same_pad,
                              bias=bias)
        stem_norm = nn.BatchNorm2d(out_channels, track_running_stats=trs)
        self.input_conv = nn.Sequential(stem_conv, stem_norm)

    def forward(self, x):
        """Run conv + batch norm on ``x`` ([N, in_channels, H, W])."""
        return self.input_conv(x)

In [3]:
# Smoke test: a 3x3 InitLayer changes the channel count but keeps H and W.
dumb_input = torch.zeros([1, 3, 224, 224], dtype=torch.float32)
conv = InitLayer(3, 6, 3)
out_shape = conv(dumb_input).shape
print(out_shape)
assert out_shape == (1, 6, 224, 224)

torch.Size([1, 6, 224, 224])


In [4]:
class FactorizedReduction(nn.Module):
    """Halve H and W of a feature map with two stride-2 paths.

    Each path is a stride-2 ``AvgPool2d(1)`` followed by a 1x1 convolution
    that produces ``out_channels // 2`` channels. The second path sees the
    input shifted by one pixel along H and W. Both results are concatenated
    on the channel axis and batch-normalized.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the result; must be even.
        trs: forwarded to ``nn.BatchNorm2d`` as ``track_running_stats``.
        bias: whether the 1x1 convolutions carry a bias term.
    """

    def __init__(self, in_channels, out_channels, trs=False, bias=False):
        super().__init__()
        assert out_channels % 2 == 0, "Output channel number must be even :/"
        half_channels = out_channels // 2

        subsample = nn.AvgPool2d(1, stride=2)
        project = nn.Conv2d(in_channels, half_channels, kernel_size=1, bias=bias)
        self.skip_path_1 = nn.Sequential(subsample, project)

        # A deep copy: the second path starts from the same weights as the first.
        self.skip_path_2 = copy.deepcopy(self.skip_path_1)
        # Pad one zero row/column at the bottom/right so the shifted view
        # taken in forward() has the same size as the input.
        self.padder = nn.ConstantPad2d((0, 1, 0, 1), 0)
        self.bn = nn.BatchNorm2d(out_channels, track_running_stats=trs)

    def forward(self, x):
        """
        Args:
            x: input feature map with shape [N, C_in, H, W]

        Returns:
            reduced feature map with shape [N, C_out, H // 2, W // 2]
        """
        shifted = self.padder(x)[:, :, 1:, 1:]
        out_1 = self.skip_path_1(x)
        out_2 = self.skip_path_2(shifted)
        assert out_1.shape == out_2.shape, "Out1's shape {} and out2's shape {} does noe equal :/".format(out_1.shape, out_2.shape)

        merged = torch.cat([out_1, out_2], dim=1)
        return self.bn(merged)

In [5]:
# Smoke test: FactorizedReduction halves H and W and emits out_channels maps.
dumb_input = torch.zeros([1, 6, 224, 224], dtype=torch.float32)
conv = FactorizedReduction(6, 12)
out_shape = conv(dumb_input).shape
print(out_shape)
assert out_shape == (1, 12, 112, 112)

torch.Size([1, 12, 112, 112])


In [6]:
class ReluConvBN(nn.Module):
    """ReLU -> 1x1 convolution -> batch norm, applied in that order.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the 1x1 convolution.
        trs: forwarded to ``nn.BatchNorm2d`` as ``track_running_stats``.
    """
    def __init__(self, in_channels, out_channels, trs=False):
        super().__init__()

        activation = nn.ReLU()
        pointwise = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=1)
        norm = nn.BatchNorm2d(out_channels, track_running_stats=trs)
        self.rcb = nn.Sequential(activation, pointwise, norm)

    def forward(self, inputs):
        """Map [N, in_channels, H, W] to [N, out_channels, H, W]."""
        return self.rcb(inputs)

In [7]:
# Smoke test: ReluConvBN only changes the channel dimension.
dumb_input = torch.zeros([1, 6, 224, 224], dtype=torch.float32)
conv = ReluConvBN(6, 12)
out_shape = conv(dumb_input).shape
print(out_shape)
assert out_shape == (1, 12, 224, 224)

torch.Size([1, 12, 224, 224])


In [8]:
class IdentityBranch(nn.Module):
    """Pass-through branch: hands back its input untouched."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` itself (same object, no copy)."""
        return x

In [9]:
class SeparableConv(nn.Module):
    """Depthwise-separable convolution: per-channel conv, then a 1x1 conv.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the pointwise convolution.
        kernel_size: side of the depthwise kernel. The padding is
            ``(kernel_size - 1) // 2``, so an odd kernel keeps H and W.
        bias: whether both convolutions carry a bias term.
    """
    def __init__(self, in_channels, out_channels, kernel_size, bias=True):
        super().__init__()

        same_pad = (kernel_size - 1) // 2
        # groups == in_channels: every input channel gets its own filter.
        self.depthwise = nn.Conv2d(
            in_channels, in_channels,
            kernel_size=kernel_size, padding=same_pad,
            groups=in_channels, bias=bias)
        # 1x1 convolution that mixes the channels.
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, x):
        """
        Args:
            x: [N, C_in, H, W]

        Return:
            [N, C_out, H, W]
        """
        per_channel = self.depthwise(x)
        return self.pointwise(per_channel)

In [10]:
class ENASCell(nn.Module):
    """One ENAS cell (node) offering five candidate operations.

    Operation ids: 0 = 3x3 separable conv, 1 = 5x5 separable conv,
    2 = 3x3 average pool, 3 = 3x3 max pool, 4 = identity.

    The two convolution choices hold ``node_id`` separate ``SeparableConv``
    instances each, one per possible predecessor, so the weights used depend
    on which earlier node feeds this cell. The pooling and identity branches
    keep the channel count unchanged, so they are only shape-compatible with
    the convolutions when ``in_channels == out_channels``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution branches.
        node_id: id of this node; valid predecessors are ``0 .. node_id - 1``.
    """
    def __init__(self, in_channels, out_channels, node_id):
        super(ENASCell, self).__init__()

        self.node_id = node_id
        in_c, out_c = in_channels, out_channels
        self.choices = nn.ModuleDict({
                'conv3': nn.ModuleList([SeparableConv(in_c, out_c, 3)
                                        for _ in range(node_id)]),
                'conv5': nn.ModuleList([SeparableConv(in_c, out_c, 5)
                                        for _ in range(node_id)]),
                # stride must be given explicitly: PyTorch pooling layers default
                # to stride == kernel_size, which would shrink H and W by ~3x and
                # make the result impossible to add to the other branch's output.
                'avg_pool': nn.AvgPool2d(3, stride=1, padding=1),
                'max_pool': nn.MaxPool2d(3, stride=1, padding=1),
                'identity': IdentityBranch()
        })

    def forward(self, x, prev_cell, op_id):
        """Apply the operation selected by ``op_id`` to ``x``.

        Args:
            x: input from previous cell.
            prev_cell: integer, the previous cell's ID; selects which
                convolution instance is used for op ids 0 and 1.
            op_id: integer in [0, 4], indicate which operation to use.

        Returns:
            The selected operation's output, with the same H and W as ``x``.

        Raises:
            AssertionError: if ``op_id`` or ``prev_cell`` is out of range.
        """
        assert 0 <= op_id <= 4, "Operation ID out of range!"
        # A negative id would silently index the ModuleList from the end.
        assert 0 <= prev_cell < self.node_id, "Previous cell ID out of range :/"

        if op_id == 0:
            return self.choices['conv3'][prev_cell](x)
        if op_id == 1:
            return self.choices['conv5'][prev_cell](x)

        branch = ('avg_pool', 'max_pool', 'identity')[op_id - 2]
        return self.choices[branch](x)

In [11]:
# Smoke test: op 1 (5x5 separable conv) fed by node 0 maps 6 -> 12 channels.
# (Iterate over cell.named_parameters() to inspect the parameter shapes.)
dumb_input = torch.zeros([1, 6, 224, 224], dtype=torch.float32)
cell = ENASCell(6, 12, 3)
output = cell(dumb_input, prev_cell=0, op_id=1)
print(output.shape)
assert output.shape == (1, 12, 224, 224)

torch.Size([1, 12, 224, 224])


In [12]:
class ENASCellFixed(nn.Module):
    """One ENAS cell (node) with a single set of weights per operation.

    Operation ids: 0 = 3x3 separable conv, 1 = 5x5 separable conv,
    2 = 3x3 average pool, 3 = 3x3 max pool, 4 = identity.

    Unlike ``ENASCell``, the convolution weights do not depend on which
    predecessor feeds the cell: ``node_id`` and ``prev_cell`` are accepted only
    so both cell classes can be used interchangeably.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution branches.
        node_id: unused; kept for signature parity with ``ENASCell``.
    """
    def __init__(self, in_channels, out_channels, node_id):
        super(ENASCellFixed, self).__init__()

        in_c, out_c = in_channels, out_channels
        self.choices = nn.ModuleDict({
                'conv3': SeparableConv(in_c, out_c, 3),
                'conv5': SeparableConv(in_c, out_c, 5),
                # stride must be given explicitly: PyTorch pooling layers default
                # to stride == kernel_size, which would shrink H and W by ~3x and
                # make the result impossible to add to the other branch's output.
                'avg_pool': nn.AvgPool2d(3, stride=1, padding=1),
                'max_pool': nn.MaxPool2d(3, stride=1, padding=1),
                'identity': IdentityBranch()
        })

    def forward(self, x, prev_cell, op_id):
        """Apply the operation selected by ``op_id`` to ``x``.

        Args:
            x: input from previous cell.
            prev_cell: integer, the previous cell's ID (ignored here).
            op_id: integer in [0, 4], indicate which operation to use.

        Returns:
            The selected operation's output, with the same H and W as ``x``.

        Raises:
            AssertionError: if ``op_id`` is out of range.
        """
        assert 0 <= op_id <= 4, "Operation ID out of range!"

        branch = ('conv3', 'conv5', 'avg_pool', 'max_pool', 'identity')[op_id]
        return self.choices[branch](x)

In [13]:
# Smoke test: op 1 (5x5 separable conv) maps 6 -> 12 channels and keeps H, W.
# (Iterate over cell.named_parameters() to inspect the parameter shapes.)
dumb_input = torch.zeros([1, 6, 224, 224], dtype=torch.float32)
cell = ENASCellFixed(6, 12, 3)
output = cell(dumb_input, prev_cell=0, op_id=1)
print(output.shape)
assert output.shape == (1, 12, 224, 224)

torch.Size([1, 12, 224, 224])


In [14]:
class ENASLayer(nn.Module):
    """One ENAS layer: 2 input nodes followed by ``node_num`` operation nodes.

    The layer plays the role of the convolution / reduction cell of the
    paper. Input nodes carry the ids 0 and 1; operation nodes are numbered
    from 2 upwards. Each operation node applies two operations to two
    earlier nodes and adds the results.

    Args:
        node_num: number of operation nodes (the two inputs are NOT counted).
        out_channels: channel count used by every node of the layer.
        fixed: build ``ENASCellFixed`` nodes when true, ``ENASCell`` otherwise.
    """
    def __init__(self, node_num, out_channels, fixed=False):
        super().__init__()
        self.node_num = node_num
        self.out_channels = out_channels

        # One reduction per input, used when that input's channel count
        # differs from out_channels (see _celibrate_size).
        reductions = [FactorizedReduction(out_channels // 2, out_channels)
                      for _ in range(2)]
        self.frs = nn.ModuleList(reductions)
        # Projects the concatenation of all node outputs back to out_channels.
        self.final_rcb = ReluConvBN(out_channels * (node_num + 2), out_channels)

        cell_cls = ENASCellFixed if fixed else ENASCell
        self.nodes = nn.ModuleList(
            [cell_cls(out_channels, out_channels, node_id)
             for node_id in range(2, node_num + 2)])

    def forward(self, inputs, arc):
        """Run both inputs through the layer and project the concatenation.

        Args:
            inputs: (h[i], h[i-1]), exactly two feature maps.
            arc: list of integers describing the cell; node ``k`` reads the
                four entries ``arc[4k : 4k + 4]`` as
                (index_1, index_2, op_1, op_2).

        Returns:
            Feature map with the shape of the (size-calibrated) first input.
        """
        assert len(arc) == self.node_num * 4, "Oops, the length of arc is {}, which should be a multiple of 4 ({}).".format(len(arc), self.node_num * 4)
        assert len(inputs) == 2, "Require exactly 2 inputs."

        # states[k] is the output of node k; nodes 0 and 1 are the inputs.
        states = list(self._celibrate_size(inputs))

        for pos, node in enumerate(self.nodes):
            base = 4 * pos
            first_src, first_op = arc[base], arc[base + 2]
            first_out = node(states[first_src], first_src, first_op)

            second_src, second_op = arc[base + 1], arc[base + 3]
            second_out = node(states[second_src], second_src, second_op)

            states.append(first_out + second_out)

        # NOTE: the paper concatenates only the unused nodes; for simplicity
        # every node (inputs included) is concatenated here.
        stacked = torch.cat(states, dim=1)
        final_output = self.final_rcb(stacked)
        assert final_output.shape == states[0].shape, "Oops, seems like the final output shape is wrong: {}. Should be equal to {}.".format(final_output.shape, states[0].shape)

        return final_output

    def _celibrate_size(self, inputs):
        """Bring both inputs to ``out_channels`` channels and equal shapes.

        An input whose channel count differs from ``out_channels`` is passed
        through its ``FactorizedReduction``; the others are returned as-is.
        """
        calibrated = [
            inp if self._get_C(inp) == self.out_channels else self.frs[pos](inp)
            for pos, inp in enumerate(inputs)
        ]

        assert calibrated[0].shape == calibrated[1].shape
        return calibrated

    def _get_C(self, x):
        """Channel count (dimension 1) of a feature map."""
        return x.shape[1]

    def _get_HW(self, x):
        """(H, W), the last two dimensions of a feature map."""
        height, width = x.shape[-2], x.shape[-1]
        return height, width

In [15]:
# Smoke test: the first input (6 channels, 64x64) gets reduced to match the
# second one (12 channels, 32x32) before the two operation nodes run.
# (Iterate over layer.named_parameters() to inspect the parameter shapes.)
dumb_input1 = torch.zeros([1, 6, 64, 64], dtype=torch.float32)
dumb_input2 = torch.zeros([1, 12, 32, 32], dtype=torch.float32)
arc = [1, 1, 1, 4,
       2, 0, 0, 0]

layer = ENASLayer(2, 12, False)
output = layer([dumb_input1, dumb_input2], arc)

print(output.shape)
assert output.shape == (1, 12, 32, 32)

torch.Size([1, 12, 32, 32])


In [16]:
# NOTE: placeholder cell -- GlobalAvgPool is implemented further down in this notebook.

In [17]:
class AuxHeadLayer(nn.Module):
    """Auxiliary head for micro child training.

    Pipeline: ReLU -> 5x5 average pool (stride 3) -> ReluConvBN -> a
    convolution whose kernel covers the whole pooled map -> ReLU -> linear
    projection to the logits.

    Args:
        in_channels: channels of the feature map fed to the head.
        side_length: H (== W) of that feature map; must be at least 5 so the
            5x5 pooling window fits.
        trs: forwarded to ``ReluConvBN`` (``track_running_stats`` of its
            batch norm).
        layer_sizes: three integers: channels after ``ReluConvBN``, channels
            after the full-map convolution, and the number of logits.

    Raises:
        ValueError: if ``side_length`` is too small for the pooling window.
    """

    def __init__(self, in_channels, side_length, trs=False, layer_sizes=(128, 768, 10)):
        super(AuxHeadLayer, self).__init__()
        assert len(layer_sizes) == 3, "Should have exactly 3 layers in the auxiliary head."

        # Side of the map left after AvgPool2d(5, stride=3) without padding.
        self.side_len = (side_length - 5) // 3 + 1
        if self.side_len < 1:
            raise ValueError(
                "side_length must be >= 5 for the auxiliary head, got {}".format(side_length))

        self.aux = nn.Sequential(
            nn.ReLU(),
            nn.AvgPool2d(5, stride=3),
            # trs used to be dropped here, so the caller's setting had no effect.
            ReluConvBN(in_channels, layer_sizes[0], trs=trs),
            # Kernel == remaining map size, so the output is [N, C, 1, 1].
            nn.Conv2d(layer_sizes[0], layer_sizes[1], self.side_len),
            nn.ReLU(),
        )
        self.proj = nn.Linear(layer_sizes[1], layer_sizes[-1])

    def forward(self, x):
        """Map ``x`` ([N, in_channels, side_length, side_length]) to logits
        of shape [N, layer_sizes[-1]]."""
        out = self.aux(x)
        # Drop the two trailing singleton spatial dimensions.
        features = torch.squeeze(torch.squeeze(out, dim=-1), dim=-1)
        return self.proj(features)

In [18]:
# Smoke test: an 80-channel 8x8 map is turned into 10 logits per sample.
# (Iterate over layer.named_parameters() to inspect the parameter shapes.)
dumb_input1 = torch.zeros([1, 80, 8, 8], dtype=torch.float32)

layer = AuxHeadLayer(80, 8)
output = layer(dumb_input1)

print(output.shape)
assert output.shape == (1, 10)

torch.Size([1, 10])


In [19]:
class GlobalAvgPool(nn.Module):
    """Collapse every channel of a feature map to its spatial mean."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """
        Args:
            x: [N, C, H, W]

        Returns:
            [N, C]: mean over W first, then over H.
        """
        row_means = x.mean(-1)
        return row_means.mean(-1)

In [20]:
# Smoke test: global average pooling removes the two spatial dimensions.
dumb_input1 = torch.zeros([1, 80, 8, 8], dtype=torch.float32)

layer = GlobalAvgPool()
output = layer(dumb_input1)

print(output.shape)
assert output.shape == (1, 80)

torch.Size([1, 80])


In [21]:
class MicroChild(nn.Module):
    """A shared CNN graph made of a stem, a stack of ENAS layers and a classifier.

    Args:
        config: object exposing the attributes read by this class:
            ``out_channels``, ``num_layers``, ``node_num``, ``class_num``,
            ``image_size``, ``use_aux_heads`` and ``fixed``.
    """

    def __init__(self, config):
        super(MicroChild, self).__init__()
        self.config = config
        self.init_layer = InitLayer(3, config.out_channels * 3, 3)
        # Two independent projections of the stem output feed the first layer.
        self.rcbs = nn.ModuleList([ReluConvBN(config.out_channels * 3,
                                              config.out_channels)
                                   for _ in range(2)])
        self.pool_layers_indices = self._specify_pool_layers()
        self.layers = self._build_enas_layers()
        if config.use_aux_heads:
            self.aux, self.aux_head_indices = self._build_aux_heads()
        else:
            self.aux, self.aux_head_indices = None, []
        self.proj = self._build_proj()

    def _specify_pool_layers(self):
        """Specify which layers are pool layers (with reduction cell).
        """
        pool_distance = self.config.num_layers // 3
        return [pool_distance, pool_distance * 2 + 1]

    def _build_enas_layers(self):
        """Build ENAS layers. In every pool layer, the channels are doubled."""
        node_num, out_c = self.config.node_num, self.config.out_channels

        layers = nn.ModuleList()
        # Read num_layers from self.config: the previous code used a notebook
        # global named `config`, which only worked by accident.
        for i in range(self.config.num_layers):
            if i in self.pool_layers_indices:
                out_c = out_c * 2
            layers.append(ENASLayer(node_num, out_c, self.config.fixed))
        return layers

    def _build_aux_heads(self):
        """Build auxiliary head for training.

        The head is attached right after the last pool layer, i.e. once every
        reduction has been applied: channels have doubled and the side length
        has halved once per pool layer (a factor of 2 ** n, not n ** 2).
        """
        scale = 2 ** len(self.pool_layers_indices)
        channels = self.config.out_channels * scale
        side_length = self.config.image_size // scale

        # The head must emit one logit per class so it can share the targets
        # of the main classifier.
        aux = AuxHeadLayer(channels, side_length,
                           layer_sizes=[128, 768, self.config.class_num])
        aux_head_indices = [self.pool_layers_indices[-1] + 1]

        return aux, aux_head_indices

    def _build_proj(self):
        """The final projection layer for logits computation."""
        # Only pool layers that actually exist double the channels; for small
        # num_layers the second pool index can lie beyond the last layer.
        applied = sum(1 for idx in self.pool_layers_indices
                      if idx < self.config.num_layers)
        channels = self.config.out_channels * 2 ** applied
        proj = nn.Sequential(
            nn.ReLU(),
            GlobalAvgPool(),
            nn.Linear(channels, self.config.class_num)
        )
        return proj

    def forward(self, images, arcs):
        """Compute the logits given images.

        Args:
            images: input images, [N, C_in, H, W], N is batch size.
            arcs: a tuple of two lists that contain integers represents the architecture
                of a normal cell and a reduce cell, four integers together in the list as
                a node: (index_1, index_2, op_1, op_2).

        Returns:
            (logits, aux_logits): ``aux_logits`` is None unless the module is in
            training mode, aux heads are enabled and the aux layer is reached.
        """
        normal_arc, reduce_arc = arcs
        x = self.init_layer(images)

        # NOTE: here the implementation is a little different from Melody's
        inputs = [rcb(x) for rcb in self.rcbs]

        # ENAS layers
        aux_logits = None
        for layer_id, layer in enumerate(self.layers):
            # Pool layers are the reduction cells and must follow reduce_arc;
            # previously every layer was run with normal_arc.
            arc = reduce_arc if layer_id in self.pool_layers_indices else normal_arc
            x = layer(inputs, arc)
            inputs = [inputs[-1], x]

            if self.config.use_aux_heads and self.aux is not None \
                    and self.training and layer_id in self.aux_head_indices:
                aux_logits = self.aux(x)  # auxiliary head

        logits = self.proj(x)

        return logits, aux_logits

In [22]:
class Config:
    """Hyper-parameters read by MicroChild."""
    # data
    image_size = 32        # side length of the (square) input images
    class_num = 10         # number of target classes

    # network layout
    num_layers = 15        # number of stacked ENAS layers
    node_num = 6           # operation nodes per ENAS layer
    out_channels = 20      # channels of the first ENAS layer
    fixed = True           # True -> ENASCellFixed nodes, False -> ENASCell

    # training
    use_aux_heads = True   # attach the auxiliary classification head

In [23]:
# Smoke test: one forward pass through the whole child network in training
# mode (the default), so the auxiliary head produces logits as well.
config = Config()
module = MicroChild(config)

arc = [1, 1, 1, 4, 2, 0, 0, 0] * 3   # 6 nodes x 4 integers per node
fake_iamge = torch.zeros([1, 3, 32, 32], dtype=torch.float32)
logits, aux_logits = module(fake_iamge, (arc, arc))

print("Logits shape: {}".format(logits.shape))
print("Auxiliary logits shape: {}".format(aux_logits.shape))
for head_output in (logits, aux_logits):
    assert head_output.shape == (1, 10)

Logits shape: torch.Size([1, 10])
Auxiliary logits shape: torch.Size([1, 10])


In [24]:
class CELossWithAuxHead(nn.Module):
    """Cross entropy loss with an optional, down-weighted auxiliary loss.

    Args:
        aux_weight: factor applied to the auxiliary loss before it is added
            to the main loss. Defaults to 0.4, the value that used to be
            hard-coded.
    """

    def __init__(self, aux_weight=0.4):
        super(CELossWithAuxHead, self).__init__()
        self.aux_weight = aux_weight
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, logits, target, aux_logits=None):
        """
        Args:
            logits: [batch_size, class_num]
            target: [batch_size]
            aux_logits: [batch_size, class_num], or None when there is no
                auxiliary head output (e.g. in evaluation mode).

        Returns:
            Scalar tensor: ``CE(logits) + aux_weight * CE(aux_logits)``, or
            just ``CE(logits)`` when ``aux_logits`` is None.
        """
        loss = self.criterion(logits, target)
        if aux_logits is None:
            return loss

        aux_loss = self.criterion(aux_logits, target)
        return loss + aux_loss * self.aux_weight

In [25]:
class ChildModel:
    """The class for child model training, validating and testing.

    The wrapped network returns ``(logits, aux_logits)``; every method below
    unpacks that pair before computing losses or predictions.
    """

    def __init__(self, config, device, write_summary=True):
        """Initialize model.

        Args:
            config: object exposing ``logger`` plus everything the model needs;
                ``path_summary`` is read when ``write_summary`` is true, and
                ``num_epochs`` / ``dir_model`` are read by ``fit``.
            device: torch device the model, criterion and batches are moved to.
            write_summary: create a tensorboard ``SummaryWriter`` when true.
        """
        self.config = config
        self.logger = self.config.logger

        # find device
        self.device = device

        # build and initialize model
        self.logger.info("- Building and initializing model...")
        self.model = self._build_model(config).to(device)
        self._initialize_model(self.model)

        # create optimizer and criterion
        self.logger.info("- Creating optimizer and criterion...")
        self.optimizer = self._get_optimizer(config, self.model)
        self.criterion = self._get_criterion(config).to(device)

        # create summary for tensorboard visualization
        self.writer = SummaryWriter(self.config.path_summary) if write_summary else None

        self.arcs = None  # architecture for normal and reduction cell

    def _build_model(self, config):
        """Build a model.
        """
        return MicroChild(config)

    def _initialize_model(self, model):
        """Xavier-initialize every parameter with more than one dimension.
        """
        for p in model.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        return model

    def _get_optimizer(self, config, model):
        """Get optimizer.

        The learning rate is read from ``config.lr`` when present and falls
        back to 1e-3. It used to be hard-coded to 0, which makes Adam a no-op
        because nothing in this class ever raises it afterwards.
        """
        learning_rate = getattr(config, "lr", 1e-3)
        return torch.optim.Adam(model.parameters(), lr=learning_rate,
                                betas=(0.9, 0.98), eps=1e-9)

    def _get_criterion(self, config):
        """Build the loss module (cross entropy with auxiliary head support).
        """
        return CELossWithAuxHead()

    def _require_arcs(self):
        """Fail early when no architecture has been set with ``set_arc``."""
        if self.arcs is None:
            raise RuntimeError("Did you forget to set model arcs?")

    @staticmethod
    def _split_outputs(outputs):
        """Return ``(logits, aux_logits)`` for a model output.

        Accepts either the ``(logits, aux_logits)`` pair produced by the model
        or a bare logits tensor (then ``aux_logits`` is None).
        """
        if isinstance(outputs, (tuple, list)):
            logits, aux_logits = outputs
            return logits, aux_logits
        return outputs, None

    def load_weights(self, path):
        """Load pre-trained weights.
        """
        # map_location lets weights saved on another device be loaded here.
        self.model.load_state_dict(torch.load(path, map_location=self.device))

    def set_arc(self, arcs):
        """Set architectures. Must do this before calling the forward method.

        Args:
            arcs: tuple of two lists that contain integers represents the
                architecture of a normal cell and a reduce cell.
        """
        self.arcs = arcs

    def loss_batch(self, loss_func, outputs, target, norm=None, optimizer=None):
        """Compute loss and update model weights on a batch of data.

        Args:
            loss_func: callable ``(logits, target)``; it is called as
                ``(logits, target, aux_logits)`` when auxiliary logits exist.
            outputs: the model output, either ``(logits, aux_logits)`` or a
                logits tensor of shape [batch_size, class_num].
            target: [batch_size]
            norm: optional divisor applied to the loss.
            optimizer: when given, back-propagate and take one step.

        Returns:
            The batch loss as a Python float.
        """
        logits, aux_logits = self._split_outputs(outputs)
        if aux_logits is None:
            loss = loss_func(logits, target)
        else:
            loss = loss_func(logits, target, aux_logits)

        if norm is not None:
            loss = loss / norm

        if optimizer is not None:
            loss.backward()   # compute gradients
            optimizer.step()  # update weights
            optimizer.zero_grad()

        return loss.item()

    def train_epoch(self, model, dataset, criterion, optimizer, epoch):
        """Train the model for one single epoch and return the mean batch loss.
        """
        model.train()  # set the model to train mode
        num_batches = len(dataset)
        # NOTE(review): Progbar is not defined or imported anywhere in the
        # visible notebook -- it must be provided before this method can run.
        prog = Progbar(target=num_batches)  # progress bar for visualization

        train_loss = 0.0
        for i, (images, labels) in enumerate(dataset):
            self._require_arcs()
            images, labels = images.to(self.device), labels.to(self.device)
            outputs = model(images, self.arcs)

            # compute loss and update model parameters on a batch of data
            batch_loss = self.loss_batch(criterion, outputs, labels, optimizer=optimizer)
            train_loss += batch_loss
            prog.update(i + 1, [("batch loss", batch_loss)])

            if self.writer is not None:  # write summary to tensorboard
                self.writer.add_scalar('batch_loss', batch_loss, epoch * num_batches + i + 1)

        # compute the average loss (guard against an empty dataset)
        return train_loss / max(num_batches, 1)

    def evaluate(self, model, dataset, criterion):
        """Evaluate the model, return average loss and accuracy.

        Returns:
            (avg_loss, avg_acc): mean loss per batch and the fraction of
            correctly classified samples, both as Python floats.
        """
        model.eval()
        eval_loss, corrects, total_samples = 0., 0, 0
        with torch.no_grad():
            for images, labels in dataset:
                self._require_arcs()
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = model(images, self.arcs)
                logits, _ = self._split_outputs(outputs)  # logits: [N, class_num]

                eval_loss += self.loss_batch(criterion, outputs, labels, optimizer=None)

                pred_labels = torch.argmax(logits, dim=-1)
                assert pred_labels.shape == labels.shape, "Predition output shape {} and actual labels shape {} does Not match.".format(pred_labels.shape, labels.shape)
                corrects += (pred_labels == labels).sum().item()
                total_samples += labels.shape[0]

        avg_loss = eval_loss / max(len(dataset), 1)
        # Accuracy is per sample, not per batch.
        avg_acc = corrects / total_samples if total_samples else 0.0

        return avg_loss, avg_acc

    def fit(self, train_set, development_set, samples=None):
        """Model training.

        Args:
            train_set: iterable of (images, labels) batches used for training.
            development_set: iterable of (images, labels) batches used for
                evaluation after every epoch.
            samples: unused; kept so existing callers keep working. (It used to
                feed a sequence-decoding helper that does not exist here.)
        """
        num_epochs = self.config.num_epochs
        best_acc = 0.

        for epoch in range(num_epochs):
            self.logger.info('Epoch {}/{}'.format(epoch, num_epochs - 1))
            # train
            train_loss = self.train_epoch(self.model, train_set, self.criterion, self.optimizer, epoch)
            self.logger.info("Traing Loss: {}".format(train_loss))

            # eval
            eval_loss, eval_acc = self.evaluate(self.model, development_set, self.criterion)
            self.logger.info("Evaluation:")
            self.logger.info("- loss: {}".format(eval_loss))
            self.logger.info("- acc: {}".format(eval_acc))

            # monitor loss and accuracy
            if self.writer is not None:
                self.writer.add_scalar('epoch_loss', train_loss, epoch)
                self.writer.add_scalar('eval_loss', eval_loss, epoch)
                self.writer.add_scalar('eval_acc', eval_acc, epoch)

            # save the model
            if eval_acc >= best_acc:
                best_acc = eval_acc
                self.logger.info("New best score!")
                torch.save(self.model.state_dict(), self.config.dir_model + "model.pickle")
                self.logger.info("model is saved at: {}".format(self.config.dir_model))

    def predict(self, inputs):
        """Prediction.

        Return:
            outputs: logits [N, class_num]
            pred_labels: [N]
        """
        self._require_arcs()
        self.model.eval()
        with torch.no_grad():
            raw_outputs = self.model(inputs.to(self.device), self.arcs)
            outputs, _ = self._split_outputs(raw_outputs)  # outputs: [N, class_num]
            pred_labels = torch.argmax(outputs, dim=-1)    # [N]

        return outputs, pred_labels

    def test(self, dataset):
        """Test the model and print out a report.

        Returns:
            (label_class, pred_class): flat lists of integer class ids.
        """
        self.model.eval()
        total_samples, corrects = 0, 0
        pred_class, label_class = [], []
        for images, labels in dataset:
            _, pred_labels = self.predict(images)
            # Compare on the CPU so labels and predictions share a device.
            pred_labels, labels = pred_labels.cpu(), labels.cpu()

            corrects += (labels == pred_labels).sum().item()
            total_samples += labels.shape[0]

            pred_class.extend(pred_labels.tolist())
            label_class.extend(labels.tolist())

        accuracy = corrects / total_samples if total_samples else 0.0
        self.logger.info('\n')
        self.logger.info('Accuracy: {:.3}\n\n'.format(accuracy))
        # NOTE(review): classification_report is not defined or imported in the
        # visible notebook (presumably sklearn.metrics) -- confirm the import.
        self.logger.info(classification_report(label_class, pred_class))

        return label_class, pred_class

In [26]:
# Sanity check: uniform logits over 10 classes give a loss of ln(10) ~= 2.3026,
# and CrossEntropyLoss reduces to a 0-dim (scalar) tensor.
criterion = nn.CrossEntropyLoss()

target = torch.ones((10,), dtype=torch.int64)
logits = torch.ones((10, 10), dtype=torch.float32)
aux_logits = torch.zeros((10, 10), dtype=torch.int64)  # not used below

loss = criterion(logits, target)
for shown in (loss, loss.shape):
    print(shown)

tensor(2.3026)
torch.Size([])
