In [1]:
from googlenet.src.googlenet import GoogleNet
# FIXME: the original second line read `from googlenet_for_test` with no
# `import <name>` clause. That is a SyntaxError, which stopped this whole cell
# (and therefore the GoogleNet import above) from running. Restore it as
# `from googlenet_for_test import <name>` once the intended name is known.

In [2]:
#Set context
import os
import argparse
from mindspore import context

# Command-line style configuration: only the execution backend is selectable.
parser = argparse.ArgumentParser(description='Demo')
parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'GPU', 'CPU'],
)

# parse_known_args() returns (namespace, unrecognised_args); only the namespace
# is kept, so arguments this parser does not declare are ignored instead of
# aborting the run.
args = parser.parse_known_args()[0]

# Run MindSpore in graph mode on the selected device.
context.set_context(
    mode=context.GRAPH_MODE,
    device_target=args.device_target,
)

In [3]:
# Instantiate the GoogleNet model.
# NOTE(review): the argument 2 is presumably the number of output classes
# (the dataset cell maps exactly two labels, "Covid" and "Non_Covid") —
# confirm against the GoogleNet constructor signature.
googlenet = GoogleNet(2)

In [4]:
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as CV
from mindspore import dtype as mstype

def create_dataset(dataset_path,
                   num_parallel_workers=1,
                   batch_size=32):
    """Build the augmented, shuffled and batched image dataset for training.

    Args:
        dataset_path: Root directory passed to ``ds.ImageFolderDataset``. It is
            expected to contain one sub-folder per class, named ``Covid`` and
            ``Non_Covid`` (mapped to labels 1 and 0 respectively).
        num_parallel_workers: Worker count forwarded to every ``map`` call.
        batch_size: Number of samples per batch. The original code read an
            undefined ``cifar_cfg.batch_size`` here; it is now an explicit
            parameter. The default is a placeholder — tune it for the dataset
            size, since incomplete trailing batches are dropped.

    Returns:
        The dataset after decode, resize, augmentation, float32 cast,
        HWC -> CHW transposition, shuffling and batching.
    """
    # Define the dataset with Image Processing
    dataset = ds.ImageFolderDataset(dataset_path, class_indexing={"Covid": 1, "Non_Covid": 0})

    # The vision transforms are imported in this file under the alias `CV`
    # (the original body referred to an undefined name `vision`).
    resize_height, resize_width = 224, 224
    decode_op = CV.Decode()
    resize_op = CV.Resize((resize_height, resize_width))
    type_cast_op = C.TypeCast(mstype.float32)
    changeswap_op = CV.HWC2CHW()

    # Data Augmentation
    horizontal_flip = CV.RandomHorizontalFlip(prob=0.8)
    rotate = CV.RandomRotation(20)
    vertical_flip = CV.RandomVerticalFlip(prob=0.5)
    augment_list = [horizontal_flip, rotate, vertical_flip]

    # Order matters: the random flip/rotation ops work on decoded HWC images,
    # so they must run before the float cast and before HWC2CHW moves the
    # channel axis to the front.
    dataset = dataset.map(operations=[decode_op, resize_op], input_columns="image",
                          num_parallel_workers=num_parallel_workers)
    dataset = dataset.map(operations=augment_list, input_columns="image",
                          num_parallel_workers=num_parallel_workers)
    dataset = dataset.map(operations=[type_cast_op, changeswap_op], input_columns="image",
                          num_parallel_workers=num_parallel_workers)

    buffer_size = 10
    dataset = dataset.shuffle(buffer_size=buffer_size)
    # drop_remainder=True keeps every batch at exactly `batch_size` samples.
    dataset = dataset.batch(batch_size, drop_remainder=True)

    return dataset


# create_dataset() names its first parameter `dataset_path`; the original call
# passed `data_path=...`, which raises TypeError (unexpected keyword argument).
train_ds = create_dataset(dataset_path="/home/Hachathon_Team_7/Challenge/datasets/train/")


In [5]:
import mindspore.nn as nn
# Loss function: softmax cross-entropy computed on the network's logits,
# averaged over the batch. sparse=True means labels are given as class
# indices rather than one-hot vectors.
googlenet_loss = nn.SoftmaxCrossEntropyWithLogits(
    sparse=True,
    reduction='mean',
)

# Optimizer: momentum SGD over all trainable parameters of the network.
googlenet_opt = nn.Momentum(
    googlenet.trainable_params(),
    learning_rate=0.01,
    momentum=0.9,
)

In [6]:
# Import the library required for model training.
from mindspore.nn import Accuracy
from mindspore import Model

# Number of epochs handed to model.train().
train_epoch = 100

# `train` is an alias of the training dataset; the training cell uses this name.
train = train_ds

# Bundle network, loss and optimizer into a high-level Model and track accuracy.
model = Model(
    googlenet,
    googlenet_loss,
    googlenet_opt,
    metrics={"Accuracy": Accuracy()},
)

In [None]:
# Train the model: `train_epoch` epochs over the `train` dataset.
model.train(train_epoch, train)