# Model Definition

In [None]:
from mindspore import context
from easydict import EasyDict as edict
from mindspore import dataset as ds
import os
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore.train.model import Model
from mindspore.train.callback import LossMonitor, CheckpointConfig, ModelCheckpoint


In [None]:
# Execution context: PyNative mode on CPU, without saving graph files.
context.set_context(
    mode=context.PYNATIVE_MODE,
    save_graphs=False,
    device_target='CPU',
)

# Hyper-parameters shared by the model, the optimizer and the data pipeline.
cfg = edict(
    question_embdim=512,   # input width of the question projection layer
    image_embdim=512,      # input width of the image projection layer
    outdim=1024,           # common width both embeddings are projected to
    dropout=0.1,           # dropout rate (the model passes 1 - dropout on)
    answer_num=8193,       # width of the final classification layer
    learning_rate=0.001,   # learning rate handed to the optimizer
    epochs=2,              # number of training epochs
    batch_size=256,        # rows per batch in the dataset pipeline
)

In [None]:
class VQA_baseline(nn.Cell):
    """Baseline VQA classifier that fuses a question and an image embedding.

    Both embeddings are projected to a common width and passed through ReLU,
    fused by element-wise multiplication, and classified by a two-layer head.
    """

    def __init__(self, config, auto_prefix=True, flags=None):
        """Create the layers of the network.

        Args:
            config: object exposing the attributes
                question_embdim: dimension of the question sentence embedding.
                image_embdim: dimension of the image embedding.
                outdim: common projection dimension used for the fusion.
                dropout: dropout rate; handed to ``nn.Dropout`` as
                    ``1 - dropout``.
                answer_num: number of answer classes (width of the output).
            auto_prefix: forwarded unchanged to ``nn.Cell.__init__``.
            flags: forwarded unchanged to ``nn.Cell.__init__``.
        """
        super().__init__(auto_prefix, flags)
        self.question_dense = nn.Dense(config.question_embdim, config.outdim)
        self.image_dense = nn.Dense(config.image_embdim, config.outdim)
        # NOTE(review): the argument is 1 - dropout, i.e. presumably a keep
        # probability -- confirm against the installed MindSpore version.
        self.dropout = nn.Dropout(1 - config.dropout)
        self.out_dense_1 = nn.Dense(config.outdim, 256)
        self.out_dense_2 = nn.Dense(256, config.answer_num)
        self.relu = nn.ReLU()

    def construct(self, ques, img):
        """Compute the answer logits for a batch.

        Args:
            ques: question embeddings, [B, question_embdim].
            img: image embeddings, [B, image_embdim].

        Returns:
            Output of the final dense layer, [B, answer_num].
        """
        ques_out = self.relu(self.question_dense(ques))
        img_out = self.relu(self.image_dense(img))
        # Element-wise product fuses the two modalities.
        hidden = self.relu(self.out_dense_1(ques_out * img_out))
        # Apply dropout to the hidden features, not to the final output:
        # dropping entries of the logits would randomly zero class scores
        # (possibly the correct class) right before the loss is computed.
        hidden = self.dropout(hidden)
        return self.out_dense_2(hidden)

class WithLossCell(nn.Cell):
    """Wrap a backbone so that a forward pass returns its training loss."""

    def __init__(self, backbone, config):
        """Store the backbone and create the loss function.

        Args:
            backbone: cell called as ``backbone(ques_emb, img_emb)`` whose
                output is fed to the loss as logits.
            config: object exposing a ``batch_size`` attribute.
        """
        super().__init__(auto_prefix=False)
        self._backbone = backbone
        # Stored on the instance only; construct() does not read it.
        self.batch_size = config.batch_size
        # NOTE(review): no ``reduction`` argument is passed, so the library
        # default applies -- confirm that this is the intended reduction.
        self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)

    def construct(self, ques_emb, img_emb, label):
        """Run the backbone and return the loss of its output against ``label``."""
        logits = self._backbone(ques_emb, img_emb)
        return self.loss(logits, label)

class EvalCell(nn.Cell):
    """Wrap a backbone so that a forward pass yields ``(output, label)``."""

    def __init__(self, backbone):
        """Store the backbone.

        Args:
            backbone: cell called as ``backbone(ques_emb, img_emb)``.
        """
        super().__init__(auto_prefix=False)
        self._backbone = backbone
        # These two operators are created here but not used by construct().
        self.onehot = ops.OneHot()
        self.print = ops.Print()

    def construct(self, ques_emb, img_emb, label):
        """Return the backbone output paired with the untouched ``label``."""
        prediction = self._backbone(ques_emb, img_emb)
        return (prediction, label)

In [None]:
def create_dataset(batch_size, dir_path="./", repeat_num=1, is_training=True):
    """Build a batched MindRecord pipeline of (ques_emb, img_emb, label) rows.

    Args:
        batch_size: number of rows per batch.
        dir_path: directory containing ``trainval.mindrecord`` and
            ``test.mindrecord``.
        repeat_num: number of times the batched dataset is repeated.
        is_training: when true, read ``trainval.mindrecord``, shuffle it and
            drop the last incomplete batch; otherwise read
            ``test.mindrecord`` and keep every row.

    Returns:
        The resulting dataset object.
    """
    file_name = "trainval.mindrecord" if is_training else "test.mindrecord"
    data_set = ds.MindDataset(
        os.path.join(dir_path, file_name),
        columns_list=["ques_emb", "img_emb", "label"],
        num_shards=1,
        shard_id=0,
    )

    if is_training:
        num_rows = data_set.get_dataset_size()
        # MindSpore documents the shuffle buffer as having to be larger than
        # one row, so a degenerate dataset is left unshuffled instead of
        # raising.
        if num_rows > 1:
            data_set = data_set.shuffle(buffer_size=num_rows)

    # Only training drops the trailing partial batch. For evaluation every
    # sample must be scored; dropping the remainder of a shuffled test set
    # would discard a different subset on every run and make the reported
    # metric both incomplete and non-reproducible.
    data_set = data_set.batch(
        batch_size=batch_size,
        drop_remainder=bool(is_training),
    )
    data_set = data_set.repeat(count=repeat_num)
    return data_set

In [None]:
# Training data pipeline (default directory, training split).
train_dataset = create_dataset(cfg.batch_size)

# Networks: the bare model plus its loss and evaluation wrappers, which all
# share the same backbone instance.
vqa_baseline = VQA_baseline(cfg)
vqa_with_loss = WithLossCell(vqa_baseline, cfg)
vqa_with_acc = EvalCell(vqa_baseline)

# Optimizer over the trainable parameters, and the high-level Model wrapper.
opt = nn.Adam(
    vqa_with_loss.trainable_params(),
    learning_rate=cfg.learning_rate,
)
model = Model(
    vqa_with_loss,
    eval_network=vqa_with_acc,
    metrics={'acc'},
    optimizer=opt,
)

# Callbacks: loss monitoring plus checkpointing into ./checkpoint
# (save_checkpoint_steps=128, keep_checkpoint_max=32).
config_ck = CheckpointConfig(
    save_checkpoint_steps=128,
    keep_checkpoint_max=32,
)
ckpoint_cb = ModelCheckpoint(
    prefix="hzw",
    directory="./checkpoint",
    config=config_ck,
)
loss_cb = LossMonitor()
cb = [loss_cb, ckpoint_cb]

# Run the training loop for the configured number of epochs.
print("start training...")
model.train(cfg.epochs, train_dataset, callbacks=cb)


In [None]:
# Build the test-split pipeline, evaluate the model on it and print the
# value returned by model.eval().
eval_dataset = create_dataset(batch_size=cfg.batch_size, is_training=False)
acc = model.eval(eval_dataset)
print(acc)