In [2]:
%load_ext autoreload
%autoreload 2

import mindspore as ms

ms.set_seed(1)
ms.context.set_context(mode=ms.context.GRAPH_MODE, device_target='CPU')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## 1. Train ResNet-18 on the CIFAR-10 dataset

### 1.1 Prepare the dataset

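Download the CIFAR-10 dataset (binary version), extract it, and point `cifar10_dir` in the cell below at the extracted `cifar-10-batches-bin` folder. 
Then run the cell below to build the training dataset and data loader.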

In [3]:
# Build the CIFAR-10 training pipeline: raw dataset -> transforms -> batched loader.
import os

from mindcv.data import create_dataset, create_transforms, create_loader

# Location of the extracted CIFAR-10 binary batches. Set the CIFAR10_DIR
# environment variable to override it without editing the notebook.
cifar10_dir = os.environ.get('CIFAR10_DIR', '/data/cifar/cifar-10-batches-bin/')
num_classes = 10
num_workers = 8
batch_size = 32

# The dataset is expected to be on disk already (download=False).
dataset_train = create_dataset(
    name='cifar10',
    root=cifar10_dir,
    split='train',
    shuffle=True,
    num_parallel_workers=num_workers,
    download=False,
)

# mindcv's default CIFAR-10 transform list, with images resized to 224.
trans = create_transforms(dataset_name='cifar10', image_resize=224)

# Batched training loader that applies `trans` to each sample.
loader_train = create_loader(
    dataset=dataset_train,
    batch_size=batch_size,
    is_training=True,
    num_classes=num_classes,
    transform=trans,
    num_parallel_workers=num_workers,
)

# TODO: visualize a sample batch. Images have already been through `trans`,
# so they presumably need transposing / de-normalizing before plt.imshow — confirm.

### 1.2 Build and train the network

In [4]:
# Assemble the pieces needed for training: model, loss and optimizer.
from mindcv.models import create_model
from mindcv.loss import create_loss
from mindcv.optim import create_optimizer
from mindcv.scheduler import create_scheduler  # NOTE(review): not used in the cells shown here

learning_rate = 1e-2

# ResNet-18 with `num_classes` outputs; no pretrained weights are loaded.
network = create_model(model_name='resnet18', num_classes=num_classes, pretrained=False)

# mindcv's 'CE' (cross-entropy) classification loss.
loss = create_loss(name='CE')

# Number of batches in one pass over the training loader.
steps_per_epoch = loader_train.get_dataset_size()

# Adam over all trainable parameters with a constant learning rate.
opt = create_optimizer(network.trainable_params(), opt='adam', lr=learning_rate)



In [None]:
# Wrap network, loss and optimizer in a high-level `Model` and train it with
# loss/time logging plus one checkpoint per epoch.
# TODO: simplify the training code
from mindspore import FixedLossScaleManager, Model, LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint

num_epochs = 10
ckpt_save_dir = './ckpt'
# NOTE(review): "sratch" looks like a typo for "scratch"; kept unchanged because
# the saved checkpoint filenames derive from it and may be referenced elsewhere.
ckpt_prefix = 'resnet18_sratch'

model = Model(network, loss_fn=loss, optimizer=opt, metrics={'acc'})
print(steps_per_epoch)

# Print the training loss every 10 steps.
loss_cb = LossMonitor(per_print_times=10)
# data_size is the number of steps per epoch used for per-step timing
# (see MindSpore TimeMonitor docs), so pass the real epoch length.
time_cb = TimeMonitor(data_size=steps_per_epoch)
# Save a checkpoint once per epoch.
ckpt_config = CheckpointConfig(save_checkpoint_steps=steps_per_epoch)
ckpt_cb = ModelCheckpoint(prefix=ckpt_prefix,
                          directory=ckpt_save_dir,
                          config=ckpt_config)
callbacks = [loss_cb, time_cb, ckpt_cb]

# Keep dataset sink mode off on CPU (the target configured at the top of the
# notebook); other device targets use sink mode.
dataset_sink_mode = ms.context.get_context('device_target') != 'CPU'
model.train(num_epochs, loader_train, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)



1563
epoch: 1 step: 10, loss is 2.543825387954712
epoch: 1 step: 20, loss is 2.4227259159088135
epoch: 1 step: 30, loss is 2.21061110496521
epoch: 1 step: 40, loss is 2.31510853767395
epoch: 1 step: 50, loss is 2.1939709186553955
epoch: 1 step: 60, loss is 2.1772873401641846
epoch: 1 step: 70, loss is 1.968402624130249
epoch: 1 step: 80, loss is 2.1735341548919678
epoch: 1 step: 90, loss is 2.1211137771606445
epoch: 1 step: 100, loss is 2.060058116912842
epoch: 1 step: 110, loss is 2.3186118602752686
epoch: 1 step: 120, loss is 1.942218542098999
epoch: 1 step: 130, loss is 1.8849071264266968
