# LeNet classification in MindSpore for a custom dataset

In this notebook we will implement a classification methodology using the MindSpore framework.

1. Define the required dataset. We will use a custom dataset.

2. Define a network. The LeNet network is used in this example.

3. Define the loss function and optimizer.

4. Load the dataset and train the network. After training is complete, check the result and save the model file.

5. Load the saved model for inference.

6. Validate the model: load the test dataset and the trained model, and check the accuracy.

```python
# import libraries
import os

import mindspore
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
import mindspore.nn as nn
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.initializer import TruncatedNormal
from mindspore.dataset.vision import Inter
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.metrics import Accuracy
from mindspore.train import Model
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
```

## Define a dataset

MindSpore provides *ImageFolderDataset*, a source dataset that reads images from a tree of directories. All images within one folder share the same label.

The generated dataset has two columns, [image, label]. The tensor in the image column is of type uint8, and the tensor in the label column is a scalar of type uint32.

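```text
.
└── image_folder_dataset_directory
     ├── class1
     │    ├── 000000000001.jpg
     │    ├── 000000000002.jpg
     │    ├── ...
     ├── class2
     │    ├── 000000000001.jpg
     │    ├── 000000000002.jpg
     │    ├── ...
     ├── class3
     │    ├── 000000000001.jpg
     │    ├── 000000000002.jpg
     │    ├── ...
     ├── classN
     ├── ...
```

In this notebook, `dataset_path` must contain a `train` and a `test` folder, each organized like `image_folder_dataset_directory` above, because `train_net` and `test_net` read `os.path.join(path, "train")` and `os.path.join(path, "test")`.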
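ImageFolderDataset derives the integer labels from the folder names itself. The following plain-Python sketch imitates that mapping, assuming the default behaviour when `class_indexing` is not given (folder names sorted alphabetically and numbered from 0). `folder_labels` and `list_samples` are hypothetical helpers, not MindSpore APIs, and unlike MindSpore they do not filter files by extension:

```python
import os


def folder_labels(root):
    """Hypothetical helper: map each class folder under `root` to an integer
    label, sorting folder names alphabetically and numbering them from 0
    (assumed to match ImageFolderDataset's default class_indexing)."""
    classes = sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )
    return {name: index for index, name in enumerate(classes)}


def list_samples(root):
    """Hypothetical helper: return (file_path, label) pairs for every file
    in every class folder; no extension filtering is done here."""
    samples = []
    for name, label in folder_labels(root).items():
        folder = os.path.join(root, name)
        for filename in sorted(os.listdir(folder)):
            samples.append((os.path.join(folder, filename), label))
    return samples
```

Calling `folder_labels` on the `train` folder is a quick way to see which class name ends up as label 0 and which as label 1.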



```python
# raw string, so the Windows backslashes are not read as escape sequences
dataset_path = r"D:\Datasets\cactus_dataset"
```

### Defining the dataset and data operations

The `create_dataset()` function below creates the dataset and defines the preprocessing operations applied to it. It proceeds as follows:

1. Define the dataset.

2. Define the parameters required for data processing.

3. Generate the corresponding data operations from those parameters.

4. Use the `map()` function to apply the data operations to the dataset.

5. Shuffle, batch, and repeat the generated dataset.
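The order of these steps determines what the network receives. As a rough illustration, the hypothetical `toy_pipeline` generator below (plain Python, not MindSpore code) imitates map, a buffered shuffle, batching with `drop_remainder=True`, and repeat on a list of (image, label) pairs:

```python
import random


def toy_pipeline(samples, transform, batch_size, repeat_size, buffer_size, seed=0):
    """Hypothetical pure-Python imitation of map -> shuffle -> batch -> repeat.

    `samples` is a list of (image, label) pairs and `transform` maps one pair to
    a new pair. Only the order of operations is illustrated; this is not how
    MindSpore implements its pipeline.
    """
    rng = random.Random(seed)
    for _ in range(repeat_size):  # repeat: replay the whole pipeline
        # map: apply the per-sample transform
        mapped = [transform(image, label) for image, label in samples]
        # shuffle: emit a random element from a buffer of at most buffer_size items
        buffer, shuffled = [], []
        for item in mapped:
            buffer.append(item)
            if len(buffer) == buffer_size:
                shuffled.append(buffer.pop(rng.randrange(len(buffer))))
        while buffer:
            shuffled.append(buffer.pop(rng.randrange(len(buffer))))
        # batch with drop_remainder=True: an incomplete final batch is discarded
        for start in range(0, len(shuffled) - batch_size + 1, batch_size):
            yield shuffled[start:start + batch_size]
```

With 10 samples and `batch_size=4`, each pass yields two full batches and drops the remaining two samples, so `repeat_size=3` gives six batches. The same arithmetic applies to `create_dataset()` with its much larger batches.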

```python
def create_dataset(data_path, batch_size=32, repeat_size=1,
                   num_parallel_workers=1):
    """Create the dataset for training or testing.

    Args:
        data_path: Root folder of the split, with one subfolder per class
        batch_size: Number of samples in each batch
        repeat_size: Number of times the dataset is repeated
        num_parallel_workers: Number of parallel workers
    """
    # define dataset; decode=True decodes images to (height, width, channel) uint8 arrays
    cactus_ds = ds.ImageFolderDataset(data_path, num_parallel_workers=num_parallel_workers, decode=True)

    # define operation parameters
    resize_height, resize_width = 32, 32
    rescale = 1.0 / 255.0
    shift = 0.0
    # 0.1307 and 0.3081 are the MNIST mean and standard deviation,
    # not statistics computed for this dataset
    rescale_nml = 1 / 0.3081
    shift_nml = -1 * 0.1307 / 0.3081

    # define map operations (MindSpore c_transforms)
    resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR)  # resize images to (32, 32)
    rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)  # normalize: (x - mean) / std
    rescale_op = CV.Rescale(rescale, shift)  # scale pixel values from [0, 255] to [0, 1]
    hwc2chw_op = CV.HWC2CHW()  # change shape from (height, width, channel) to (channel, height, width) to fit the network
    type_cast_op = C.TypeCast(mstype.int32)  # change the label data type to int32 to fit the loss function

    # apply map operations on labels and images
    cactus_ds = cactus_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
    cactus_ds = cactus_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
    cactus_ds = cactus_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
    cactus_ds = cactus_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)
    cactus_ds = cactus_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)

    # apply DatasetOps: shuffle, batch (dropping the last incomplete batch), repeat
    buffer_size = 1000
    cactus_ds = cactus_ds.shuffle(buffer_size=buffer_size)
    cactus_ds = cactus_ds.batch(batch_size, drop_remainder=True)
    cactus_ds = cactus_ds.repeat(repeat_size)

    # no Tensor objects are built explicitly here; the Dataset yields them when iterated
    return cactus_ds
```
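Because `CV.Rescale(rescale, shift)` computes `x * rescale + shift`, the two Rescale operations compose into `(x / 255 - 0.1307) / 0.3081`, a standardization with the MNIST mean and standard deviation rather than with statistics of the cactus images. The sketch below reproduces that arithmetic and the HWC-to-CHW reordering on plain Python lists; `normalize_pixel` and `hwc_to_chw` are illustrative helpers, not MindSpore functions:

```python
def normalize_pixel(value, rescale=1.0 / 255.0, shift=0.0,
                    rescale_nml=1 / 0.3081, shift_nml=-1 * 0.1307 / 0.3081):
    """Apply the two Rescale steps of create_dataset() to one uint8 value.

    Each Rescale computes x * scale + shift, so the composition is
    (value / 255 - 0.1307) / 0.3081.
    """
    x = value * rescale + shift          # first Rescale: [0, 255] -> [0, 1]
    return x * rescale_nml + shift_nml   # second Rescale: subtract mean, divide by std


def hwc_to_chw(image):
    """Reorder a nested [height][width][channel] list to [channel][height][width]."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    return [[[image[h][w][c] for w in range(width)] for h in range(height)]
            for c in range(channels)]
```

A black pixel (0) maps to about -0.42 and a white pixel (255) to about 2.82. For RGB aerial images, per-channel statistics computed from the training split might be a better fit than the MNIST values.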

## Implement neural network

We will implement the LeNet-5 architecture from scratch, adapted to 32x32 RGB input images (three input channels) and two output classes.

![Lenet](./lenet5_cactus.PNG)

Figure from: López-Jiménez, E., Vasquez-Gomez, J. I., Sanchez-Acevedo, M. A., Herrera-Lozada, J. C., & Uriarte-Arcia, A. V. (2019). Columnar cactus recognition in aerial images using a deep learning approach. Ecological Informatics, 52, 131-138.

```python
# Weight initializer shared by all layers
def weight_variable():
    """Return a truncated normal initializer with sigma 0.02."""
    return TruncatedNormal(0.02)

# Simplify the definition of the convolutions
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Conv2d layer with 'valid' padding, no bias, and truncated normal weights."""
    weight = weight_variable()
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=kernel_size, stride=stride, padding=padding,
                     weight_init=weight, has_bias=False, pad_mode="valid")

# Define a fully connected layer
def fc_with_initialize(input_channels, out_channels):
    """Dense layer whose weights and biases both use the truncated normal initializer."""
    weight = weight_variable()
    bias = weight_variable()
    return nn.Dense(input_channels, out_channels, weight, bias)
```
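`TruncatedNormal(0.02)` draws initial values from a normal distribution with standard deviation 0.02 whose tails are cut off. A minimal rejection-sampling sketch, assuming a cut-off at two standard deviations (a common convention; the MindSpore documentation for your version gives the exact bounds). `truncated_normal` is a hypothetical helper. Note that `fc_with_initialize` applies the same initializer to the biases, so the dense-layer biases start small and random rather than at zero:

```python
import random


def truncated_normal(sigma, count, bound=2.0, seed=0):
    """Hypothetical helper: draw `count` samples from N(0, sigma^2),
    rejecting any sample farther than `bound` standard deviations from 0."""
    rng = random.Random(seed)
    samples = []
    while len(samples) < count:
        value = rng.gauss(0.0, sigma)
        if abs(value) <= bound * sigma:  # keep only values inside the truncation window
            samples.append(value)
    return samples
```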

```python
import mindspore.ops.operations as P

# Implementation of LeNet5
class LeNet5(nn.Cell):
    """LeNet-5 network structure (nn.Cell is the base class for all neural networks).

    Args:
        output: number of classes
    """
    def __init__(self, output):
        super(LeNet5, self).__init__()
        # define the operators required
        self.conv1 = conv(3, 6, 5)  # 3 input channels for RGB images
        self.conv2 = conv(6, 16, 5)
        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)  # 16 feature maps of 5x5 for a 32x32 input
        self.fc2 = fc_with_initialize(120, 84)
        self.fc3 = fc_with_initialize(84, output)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.reshape = P.Reshape()

    def construct(self, x):
        """Define the computation performed at every call, using the operators above."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        batch_size, channels, _, _ = x.shape
        x = self.reshape(x, (batch_size, -1))  # flatten to (batch_size, 16 * 5 * 5)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x
```
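The `16 * 5 * 5` input size of `fc1` follows from the layer arithmetic: a window of size k with stride s and padding p maps a length n to `(n + 2p - k) // s + 1`. The sketch below, with hypothetical helper names, traces an input through the two conv/pool stages of `LeNet5`:

```python
def window_out(size, kernel, stride=1, padding=0):
    """Output length along one axis of a convolution or pooling window."""
    return (size + 2 * padding - kernel) // stride + 1


def lenet5_flat_features(height=32, width=32, channels_out=16):
    """Trace an input through conv1 (5x5), pool (2x2, stride 2), conv2 (5x5),
    pool (2x2, stride 2) and return the number of features fed to fc1."""
    for kernel, stride in [(5, 1), (2, 2), (5, 1), (2, 2)]:
        height = window_out(height, kernel, stride)
        width = window_out(width, kernel, stride)
    return channels_out * height * width
```

For the 32x32 images produced by `create_dataset()` this gives 32 -> 28 -> 14 -> 10 -> 5 and 400 features, matching `16 * 5 * 5`; a different resize (64x64 gives 2704) would require changing `fc1`.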

```python
def train_net(args, model, epoch_size, path, repeat_size, ckpoint_cb, sink_mode, batch_size=32):
    """Define the training method (`args` is not used)."""
    print("============== Starting Training ==============")
    # load training dataset
    ds_train = create_dataset(os.path.join(path, "train"), batch_size, repeat_size)
    model.train(epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=sink_mode)
```
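The number of optimizer steps `train_net` runs depends on `batch_size`, `repeat_size`, and `epoch_size` together, since `repeat` multiplies the batches in each epoch. A back-of-the-envelope sketch with hypothetical helpers; `num_samples` stands for the number of training images, which this notebook does not state:

```python
def steps_per_epoch(num_samples, batch_size, repeat_size=1):
    """Batches per epoch for a dataset built like create_dataset():
    batch(drop_remainder=True) keeps num_samples // batch_size batches and
    repeat(repeat_size) multiplies that count."""
    return (num_samples // batch_size) * repeat_size


def training_steps(num_samples, batch_size, repeat_size, epoch_size):
    """Total optimizer steps over all epochs."""
    return steps_per_epoch(num_samples, batch_size, repeat_size) * epoch_size
```

`steps_per_epoch(60000, 32)` is 1875, the MNIST figure used for `save_checkpoint_steps`. With `batch_size=1024` and `repeat_size=epoch_size=10`, the training split would need at least 19,456 images (19 full batches per pass) for the step counter to reach 1875.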

```python
def test_net(args, network, model, path):
    """Define the evaluation method (`args` is not used)."""
    print("============== Starting Testing ==============")
    # to evaluate a saved checkpoint instead of the in-memory model, uncomment
    # the two lines below and set the name of a checkpoint file written during training
    #param_dict = load_checkpoint("checkpoint_lenet-1_1875.ckpt")
    #load_param_into_net(network, param_dict)
    # load the test dataset (batch_size=32; drop_remainder=True skips the last partial batch)
    ds_eval = create_dataset(os.path.join(path, "test"))
    acc = model.eval(ds_eval, dataset_sink_mode=False)
    print("============== Accuracy:{} ==============".format(acc))
```
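`model.eval` reports the `Accuracy` metric configured in the `Model`. For classification, that is the fraction of samples whose highest-scoring logit matches the label. A plain-Python sketch of the computation (the `accuracy` helper is hypothetical, not the MindSpore implementation):

```python
def accuracy(logits, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    correct = 0
    for row, label in zip(logits, labels):
        predicted = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        correct += int(predicted == label)
    return correct / len(labels)
```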

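```python
#cadena = "python test_mindspore --device_target"
local_device = "CPU"
args = None  # placeholder passed to train_net/test_net, which do not use it
context.set_context(mode=context.GRAPH_MODE, device_target=local_device)
# enable dataset sink mode only on non-CPU targets
dataset_sink_mode = local_device != "CPU"

# define the loss function: sparse=True takes integer labels, reduction='mean' averages over the batch
net_loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

# hyperparameters
lr = 0.01
momentum = 0.9
epoch_size = 10
# create_dataset repeats the training data repeat_size times, so with
# repeat_size = epoch_size each of the 10 epochs covers the training set 10 times
repeat_size = epoch_size
output = 2  # number of classes
batch_size = 1024

# create the network
network = LeNet5(output)
```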
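The loss and optimizer settings can be checked by hand. Below is a plain-Python sketch, based on our reading of the MindSpore documentation, of what `SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')` computes and of the non-Nesterov update of `nn.Momentum` (`v = momentum * v + grad`, then `p = p - lr * v`). Both function names are hypothetical:

```python
import math


def sparse_softmax_cross_entropy(logits, labels):
    """Mean cross-entropy between softmax(logits) and integer class labels."""
    total = 0.0
    for row, label in zip(logits, labels):
        largest = max(row)  # subtract the max logit for numerical stability
        log_sum_exp = largest + math.log(sum(math.exp(v - largest) for v in row))
        total += log_sum_exp - row[label]  # -log softmax(row)[label]
    return total / len(labels)


def momentum_step(param, velocity, grad, lr=0.01, momentum=0.9):
    """One non-Nesterov momentum update; returns (new_param, new_velocity)."""
    velocity = momentum * velocity + grad
    return param - lr * velocity, velocity
```

With two classes and equal logits the loss is ln 2, about 0.693, close to the first value in the training log (0.6952); this is the expected starting point for a freshly initialized binary classifier.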

```python
# define the optimizer
net_opt = nn.Momentum(network.trainable_params(), lr, momentum)

# set checkpoint parameters; 1875 matches MNIST (60000 images / batch size 32),
# not this dataset, so adjust save_checkpoint_steps to its number of steps
config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10)

# apply checkpoint parameters
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)

# group network, loss, optimizer, and metrics into an object with training and evaluation features
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})

# perform training, then evaluate on the test split
train_net(args, model, epoch_size, dataset_path, repeat_size, ckpoint_cb, dataset_sink_mode, batch_size=batch_size)
test_net(args, network, model, dataset_path)
```

```text
epoch: 1 step: 1, loss is 0.6952155232429504
epoch: 1 step: 2, loss is 0.69403076171875
epoch: 1 step: 3, loss is 0.691548228263855
epoch: 1 step: 4, loss is 0.6879672408103943
epoch: 1 step: 5, loss is 0.6832858920097351
epoch: 1 step: 6, loss is 0.6803513169288635
epoch: 1 step: 7, loss is 0.6727116703987122
epoch: 1 step: 8, loss is 0.6700455546379089
epoch: 1 step: 9, loss is 0.6610921621322632
epoch: 1 step: 10, loss is 0.6517544984817505
epoch: 1 step: 11, loss is 0.6556546688079834
epoch: 1 step: 12, loss is 0.6363601088523865
epoch: 1 step: 13, loss is 0.6376200914382935
epoch: 1 step: 14, loss is 0.6326032876968384
epoch: 1 step: 15, loss is 0.6216421127319336
epoch: 1 step: 16, loss is 0.6196250915527344
epoch: 1 step: 17, loss is 0.6113178133964539
epoch: 1 step: 18, loss is 0.619862973690033
epoch: 1 step: 19, loss is 0.6075713038444519
epoch: 1 step: 20, loss is 0.59392249584198
epoch: 1 step: 21, loss is 0.5900442004203796
epoch: 1 step: 22, loss is 0.5892221927642822
epo

epoch: 2 step: 8, loss is 0.5408521294593811
epoch: 2 step: 9, loss is 0.5676712393760681
epoch: 2 step: 10, loss is 0.5698236227035522
epoch: 2 step: 11, loss is 0.5558693408966064
epoch: 2 step: 12, loss is 0.5666053891181946
epoch: 2 step: 13, loss is 0.5451393723487854
epoch: 2 step: 14, loss is 0.5590954422950745
epoch: 2 step: 15, loss is 0.5644586682319641
epoch: 2 step: 16, loss is 0.5569354295730591
epoch: 2 step: 17, loss is 0.5676795244216919
epoch: 2 step: 18, loss is 0.5483378767967224
epoch: 2 step: 19, loss is 0.5666103959083557
epoch: 2 step: 20, loss is 0.5666069984436035
epoch: 2 step: 21, loss is 0.5526309609413147
epoch: 2 step: 22, loss is 0.5666069984436035
epoch: 2 step: 23, loss is 0.5719950795173645
epoch: 2 step: 24, loss is 0.5698391795158386
epoch: 2 step: 25, loss is 0.5666140913963318
epoch: 2 step: 26, loss is 0.5472543239593506
epoch: 2 step: 27, loss is 0.5870459675788879
epoch: 2 step: 28, loss is 0.5676806569099426
epoch: 2 step: 29, loss is 0.5537081

epoch: 3 step: 16, loss is 0.5611999034881592
epoch: 3 step: 17, loss is 0.5644224286079407
epoch: 3 step: 18, loss is 0.5429418683052063
epoch: 3 step: 19, loss is 0.5719524025917053
epoch: 3 step: 20, loss is 0.5429210662841797
epoch: 3 step: 21, loss is 0.5741102695465088
epoch: 3 step: 22, loss is 0.550438404083252
epoch: 3 step: 23, loss is 0.5859699249267578
epoch: 3 step: 24, loss is 0.5611977577209473
epoch: 3 step: 25, loss is 0.5310482382774353
epoch: 3 step: 26, loss is 0.5676606297492981
epoch: 3 step: 27, loss is 0.5428804159164429
epoch: 3 step: 28, loss is 0.5644246935844421
epoch: 3 step: 29, loss is 0.5720024704933167
epoch: 3 step: 30, loss is 0.562275767326355
epoch: 3 step: 31, loss is 0.5849499106407166
epoch: 3 step: 32, loss is 0.5514843463897705
epoch: 3 step: 33, loss is 0.5547221302986145
epoch: 3 step: 34, loss is 0.5871152877807617
epoch: 3 step: 35, loss is 0.55256187915802
epoch: 3 step: 36, loss is 0.5450081825256348
epoch: 3 step: 37, loss is 0.587097406

epoch: 4 step: 24, loss is 0.5611310601234436
epoch: 4 step: 25, loss is 0.5632880330085754
epoch: 4 step: 26, loss is 0.542823851108551
epoch: 4 step: 27, loss is 0.5772746205329895
epoch: 4 step: 28, loss is 0.5622066855430603
epoch: 4 step: 29, loss is 0.5589843392372131
epoch: 4 step: 30, loss is 0.5643534064292908
epoch: 4 step: 31, loss is 0.5482044219970703
epoch: 4 step: 32, loss is 0.5546740293502808
epoch: 4 step: 33, loss is 0.5514424443244934
epoch: 4 step: 34, loss is 0.5718919634819031
epoch: 4 step: 35, loss is 0.5428271889686584
epoch: 4 step: 36, loss is 0.5600424408912659
epoch: 4 step: 37, loss is 0.5643560290336609
epoch: 4 step: 38, loss is 0.5557403564453125
epoch: 4 step: 39, loss is 0.5654335021972656
epoch: 4 step: 40, loss is 0.558964729309082
epoch: 4 step: 41, loss is 0.5632789731025696
epoch: 4 step: 42, loss is 0.5524916052818298
epoch: 4 step: 43, loss is 0.5427651405334473
epoch: 4 step: 44, loss is 0.5546456575393677
epoch: 4 step: 45, loss is 0.5740681

epoch: 5 step: 32, loss is 0.5383779406547546
epoch: 5 step: 33, loss is 0.5566104054450989
epoch: 5 step: 34, loss is 0.5362192988395691
epoch: 5 step: 35, loss is 0.5544441938400269
epoch: 5 step: 36, loss is 0.5769835114479065
epoch: 5 step: 37, loss is 0.5737611651420593
epoch: 5 step: 38, loss is 0.5587256550788879
epoch: 5 step: 39, loss is 0.5705738067626953
epoch: 5 step: 40, loss is 0.5694553852081299
epoch: 5 step: 41, loss is 0.5748088359832764
epoch: 5 step: 42, loss is 0.5544262528419495
epoch: 5 step: 43, loss is 0.552304744720459
epoch: 5 step: 44, loss is 0.5780379772186279
epoch: 5 step: 45, loss is 0.5426087379455566
epoch: 5 step: 46, loss is 0.5372704863548279
epoch: 5 step: 47, loss is 0.5308413505554199
epoch: 5 step: 48, loss is 0.5876771807670593
epoch: 5 step: 49, loss is 0.5619386434555054
epoch: 5 step: 50, loss is 0.5468803644180298
epoch: 5 step: 51, loss is 0.567294180393219
epoch: 5 step: 52, loss is 0.5715814828872681
epoch: 5 step: 53, loss is 0.5618883

epoch: 6 step: 40, loss is 0.5323300361633301
epoch: 6 step: 41, loss is 0.5593745708465576
epoch: 6 step: 42, loss is 0.5453094840049744
epoch: 6 step: 43, loss is 0.5884336233139038
epoch: 6 step: 44, loss is 0.5528537034988403
epoch: 6 step: 45, loss is 0.5658349394798279
epoch: 6 step: 46, loss is 0.5646947622299194
epoch: 6 step: 47, loss is 0.5775450468063354
epoch: 6 step: 48, loss is 0.5463544130325317
epoch: 6 step: 49, loss is 0.5591160655021667
epoch: 6 step: 50, loss is 0.5537575483322144
epoch: 6 step: 51, loss is 0.5558692216873169
epoch: 6 step: 52, loss is 0.5483181476593018
epoch: 6 step: 53, loss is 0.5482674837112427
epoch: 6 step: 54, loss is 0.5707305073738098
epoch: 6 step: 55, loss is 0.5460650324821472
epoch: 6 step: 56, loss is 0.5845777988433838
epoch: 6 step: 57, loss is 0.563330352306366
epoch: 6 step: 58, loss is 0.5514127016067505
epoch: 6 step: 59, loss is 0.565207302570343
epoch: 6 step: 60, loss is 0.5619938969612122
epoch: 6 step: 61, loss is 0.5920215

epoch: 7 step: 47, loss is 0.6164376735687256
epoch: 7 step: 48, loss is 0.6000385880470276
epoch: 7 step: 49, loss is 0.6022570133209229
epoch: 7 step: 50, loss is 0.6050218939781189
epoch: 7 step: 51, loss is 0.5940548777580261
epoch: 7 step: 52, loss is 0.6054550409317017
epoch: 7 step: 53, loss is 0.5991603136062622
epoch: 7 step: 54, loss is 0.5898312926292419
epoch: 7 step: 55, loss is 0.5889089703559875
epoch: 7 step: 56, loss is 0.580732524394989
epoch: 7 step: 57, loss is 0.5918401479721069
epoch: 7 step: 58, loss is 0.5793980956077576
epoch: 7 step: 59, loss is 0.5669751763343811
epoch: 7 step: 60, loss is 0.5849604606628418
epoch: 7 step: 61, loss is 0.5819002389907837
epoch: 7 step: 62, loss is 0.5761282444000244
epoch: 7 step: 63, loss is 0.5616493821144104
epoch: 7 step: 64, loss is 0.5812408328056335
epoch: 7 step: 65, loss is 0.5583826303482056
epoch: 7 step: 66, loss is 0.5705623030662537
epoch: 7 step: 67, loss is 0.589332103729248
epoch: 7 step: 68, loss is 0.5744246

epoch: 8 step: 55, loss is 0.5537384748458862
epoch: 8 step: 56, loss is 0.5483616590499878
epoch: 8 step: 57, loss is 0.5634079575538635
epoch: 8 step: 58, loss is 0.5741593837738037
epoch: 8 step: 59, loss is 0.5559030771255493
epoch: 8 step: 60, loss is 0.5518501996994019
epoch: 8 step: 61, loss is 0.5204090476036072
epoch: 8 step: 62, loss is 0.5730948448181152
epoch: 8 step: 63, loss is 0.578804612159729
epoch: 8 step: 64, loss is 0.5698712468147278
epoch: 8 step: 65, loss is 0.551567554473877
epoch: 8 step: 66, loss is 0.578498125076294
epoch: 8 step: 67, loss is 0.5612616539001465
epoch: 8 step: 68, loss is 0.5655689835548401
epoch: 8 step: 69, loss is 0.5580307841300964
epoch: 8 step: 70, loss is 0.5720290541648865
epoch: 8 step: 71, loss is 0.5593948364257812
epoch: 8 step: 72, loss is 0.566638708114624
epoch: 8 step: 73, loss is 0.5440792441368103
epoch: 8 step: 74, loss is 0.5494251847267151
epoch: 8 step: 75, loss is 0.5321938395500183
epoch: 8 step: 76, loss is 0.544027149

epoch: 9 step: 63, loss is 0.5569601655006409
epoch: 9 step: 64, loss is 0.552660346031189
epoch: 9 step: 65, loss is 0.5763317942619324
epoch: 9 step: 66, loss is 0.5515755414962769
epoch: 9 step: 67, loss is 0.5731020569801331
epoch: 9 step: 68, loss is 0.5655638575553894
epoch: 9 step: 69, loss is 0.5946428179740906
epoch: 9 step: 70, loss is 0.5332779288291931
epoch: 9 step: 71, loss is 0.5591080188751221
epoch: 9 step: 72, loss is 0.5730929374694824
epoch: 9 step: 73, loss is 0.5279038548469543
epoch: 9 step: 74, loss is 0.5752511024475098
epoch: 9 step: 75, loss is 0.5903303027153015
epoch: 9 step: 76, loss is 0.5515754818916321
epoch: 9 step: 77, loss is 0.5765879154205322
epoch: 9 step: 78, loss is 0.5548292398452759
epoch: 9 step: 79, loss is 0.5666365623474121
epoch: 9 step: 80, loss is 0.5666417479515076
epoch: 9 step: 81, loss is 0.5666701793670654
epoch: 9 step: 82, loss is 0.5548167824745178
epoch: 9 step: 83, loss is 0.5744000673294067
epoch: 9 step: 84, loss is 0.551595

epoch: 10 step: 69, loss is 0.5709591507911682
epoch: 10 step: 70, loss is 0.5558795928955078
epoch: 10 step: 71, loss is 0.5494070649147034
epoch: 10 step: 72, loss is 0.5817216038703918
epoch: 10 step: 73, loss is 0.5817257165908813
epoch: 10 step: 74, loss is 0.5418823957443237
epoch: 10 step: 75, loss is 0.5698694586753845
epoch: 10 step: 76, loss is 0.5795499086380005
epoch: 10 step: 77, loss is 0.5410915613174438
epoch: 10 step: 78, loss is 0.5612611174583435
epoch: 10 step: 79, loss is 0.5472856163978577
epoch: 10 step: 80, loss is 0.5711707472801208
epoch: 10 step: 81, loss is 0.5537360310554504
epoch: 10 step: 82, loss is 0.5816845297813416
epoch: 10 step: 83, loss is 0.5644840598106384
epoch: 10 step: 84, loss is 0.5570064783096313
epoch: 10 step: 85, loss is 0.5440744757652283
epoch: 10 step: 86, loss is 0.5902819633483887
epoch: 10 step: 87, loss is 0.5462229251861572
epoch: 10 step: 88, loss is 0.5623319745063782
epoch: 10 step: 89, loss is 0.5623379349708557
epoch: 10 ste
```