# Model Definition

In [None]:
import moxing as mox
# Copy the VQA_V1.0 directory from the OBS bucket into the current directory.
mox.file.copy_parallel(
    src_url="obs://npl-hzw/nlp/VQA_V1.0",
    dst_url="./",
)

In [None]:
from mindspore import context, Tensor
from easydict import EasyDict as edict
from mindspore import dataset as ds
import os
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore.train.model import Model
from mindspore.train.callback import LossMonitor, CheckpointConfig, ModelCheckpoint
import math

In [None]:
# Runtime context: PyNative mode on CPU, graph dumping disabled.
context.set_context(mode=context.PYNATIVE_MODE, save_graphs=False, device_target='CPU')

# Hyper-parameters shared by the model, the data pipeline and the training loop.
cfg = edict({
    'question_embdim': 512,   # input width of the question projection in MFB
    'image_embdim': 512,      # input width of the image projection in MFB
    'outdim': 1024,           # not read by the code visible in this notebook
    'hiddendim': 256,         # not read by the code visible in this notebook
    'dropout': 0.1,           # layers are built with nn.Dropout(1 - dropout)
    'answer_num': 8193,       # output width of the answer classifier
    'learning_rate': 0.005,   # base learning rate
    'epochs': 4,
    'batch_size': 256,
    'MFB_k': 2,               # pooling window of the MFB fusion
    'MFB_o': 1024,            # output width of the MFB fusion
})

# Per-step learning-rate schedule: linear warm-up -> constant -> decay.
#
# Each epoch's rate is repeated for a whole epoch worth of steps (epoch is the
# outer loop), and the three phases are sized so that together they cover
# exactly cfg.epochs epochs.
#
# NOTE(review): cfg.batch_size is used as the number of steps per epoch, as in
# the original code — confirm this against the real dataset size.
# NOTE(review): the optimizer in the visible notebook is built from
# cfg.learning_rate, so this list is only effective if it is passed explicitly.
steps_per_epoch = cfg.batch_size
warm_up_epochs = math.floor(cfg.epochs / 5)
shrink_epochs = math.floor(cfg.epochs * 3 / 5)
normal_epochs = cfg.epochs - warm_up_epochs - shrink_epochs

# Ramp linearly from lr / warm_up_epochs up to lr; empty when warm_up_epochs
# is 0, so the division is never evaluated with a zero denominator.
warm_up = [
    cfg.learning_rate / warm_up_epochs * (epoch + 1)
    for epoch in range(warm_up_epochs)
    for _ in range(steps_per_epoch)
]
# Hold the base rate.
normal_run = [
    cfg.learning_rate
    for _ in range(normal_epochs * steps_per_epoch)
]
# Decay: lr / 16, lr / 32, lr / 48, ... one value per epoch.
shrink = [
    cfg.learning_rate / (16 * (epoch + 1))
    for epoch in range(shrink_epochs)
    for _ in range(steps_per_epoch)
]
learning_rate = warm_up + normal_run + shrink


In [None]:
class MFB(nn.Cell):
    """Fuse a question embedding and an image embedding into one vector.

    Both inputs are projected to ``MFB_k * MFB_o`` features, multiplied
    element-wise, pooled over non-overlapping windows of ``MFB_k`` features,
    then passed through a signed square root and an L2 normalisation.
    """

    def __init__(self, config):
        super().__init__(auto_prefix=False)
        fused_dim = config.MFB_k * config.MFB_o
        # One dense projection (with ReLU) per modality.
        self.xFC = nn.Dense(config.question_embdim, fused_dim, activation='relu')
        self.yFC = nn.Dense(config.image_embdim, fused_dim, activation='relu')
        # Average pooling over windows of MFB_k (despite the attribute name).
        self.sumPooling = nn.AvgPool1d(config.MFB_k, stride=config.MFB_k)
        # Presumably the argument is a keep probability — confirm against the
        # installed MindSpore version.
        self.dropout = nn.Dropout(1 - config.dropout)
        self.norm = ops.L2Normalize(1)
        self.unsqueeze = ops.ExpandDims()
        self.squeeze = ops.Squeeze(1)
        self.relu = ops.ReLU()
        self.sqrt = ops.Sqrt()

    def construct(self, ques_emb, img_emb):
        """Return the fused representation of ``ques_emb`` and ``img_emb``."""
        ques_proj = self.xFC(ques_emb)
        img_proj = self.yFC(img_emb)
        fused = self.dropout(ques_proj * img_proj)
        # fused: [B, MFB_k*MFB_o]
        # Insert a dummy axis at position 1 for the 1-D pooling, then drop it.
        fused = self.unsqueeze(fused, 1)
        pooled = self.sumPooling(fused)
        pooled = self.squeeze(pooled)
        # Signed square root: sqrt of the positive part minus sqrt of the
        # negative part. NOTE: marked as untested by the original author.
        positive_root = self.sqrt(self.relu(pooled))
        negative_root = self.sqrt(self.relu(-pooled))
        normalised = self.norm(positive_root - negative_root)
        # normalised: [B, MFB_o]
        return normalised

In [None]:
class VQA_baseline(nn.Cell):
    '''
    Baseline VQA network: MFB fusion followed by a dense answer classifier.

    config fields listed by the original author:
        question_embdim: dimension of the question sentence embedding
        image_embdim:    dimension of the image embedding
        outdim:          output dimension (not read directly by this class)
        dropout:         passed to nn.Dropout as ``1 - dropout``
        answer_num:      output width of the classifier
    This class additionally reads ``MFB_o`` as the classifier input width.
    '''
    def __init__(self, config, auto_prefix=True, flags=None):
        super(VQA_baseline, self).__init__(auto_prefix, flags)
        keep_arg = 1 - config.dropout
        self.dropout = nn.Dropout(keep_arg)
        # NOTE(review): ReLU and dropout are applied to the class scores that
        # are later fed to the loss — confirm this is intended.
        self.FC = nn.Dense(config.MFB_o, config.answer_num, activation='relu')
        self.mfb = MFB(config)

    def construct(self, ques, img):
        '''
        Shapes as annotated by the original author:
            img:    [B, image_embdim]
            ques:   [B, question_embdim]
            output: [B, answer_num]
        '''
        # Fuse both modalities, classify, then apply dropout.
        return self.dropout(self.FC(self.mfb(ques, img)))

class WithLossCell(nn.Cell):
    """Wrap a backbone so that ``construct`` returns the training loss.

    Args:
        backbone: cell called as ``backbone(ques_emb, img_emb)`` to produce
            the class scores.
        config: accepted for interface compatibility; not read here.
    """

    def __init__(self, backbone, config):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        # reduction='mean' makes the loss a scalar averaged over the batch;
        # without it the loss op returns an unreduced per-sample loss.
        self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

    def construct(self, ques_emb, img_emb, label):
        """Return the mean softmax cross-entropy of the backbone output vs ``label``."""
        out = self._backbone(ques_emb, img_emb)
        loss = self.loss(out, label)
        return loss

class EvalCell(nn.Cell):
    """Evaluation wrapper: runs the backbone and passes the label through."""

    def __init__(self, backbone):
        super().__init__(auto_prefix=False)
        self._backbone = backbone
        # These two operators are created but not used by construct().
        self.onehot = ops.OneHot()
        self.print = ops.Print()

    def construct(self, ques_emb, img_emb, label):
        """Return ``(backbone_output, label)`` for the evaluation metric."""
        logits = self._backbone(ques_emb, img_emb)
        return (logits, label)

In [None]:
def create_dataset(batch_size, dir_path="./", repeat_num=1, is_training=True):
    """Build the batched MindRecord pipeline for training or evaluation.

    Args:
        batch_size: number of samples per batch.
        dir_path: directory containing ``trainval.mindrecord`` and
            ``test.mindrecord``.
        repeat_num: how many times the batched dataset is repeated.
        is_training: read ``trainval.mindrecord`` when true, otherwise
            ``test.mindrecord``.

    Returns:
        The dataset yielding the ``ques_emb``, ``img_emb`` and ``label``
        columns.
    """
    file_name = "trainval.mindrecord" if is_training else "test.mindrecord"
    data_dir = os.path.join(dir_path, file_name)
    data_set = ds.MindDataset(data_dir, columns_list=["ques_emb", "img_emb", "label"], num_shards=1, shard_id=0)
    # The shuffle buffer spans the whole dataset.
    data_set = data_set.shuffle(buffer_size=data_set.get_dataset_size())
    # Drop the last partial batch only when training; for evaluation every
    # sample must be kept so the metric covers the full split.
    data_set = data_set.batch(batch_size=batch_size, drop_remainder=is_training)
    data_set = data_set.repeat(count=repeat_num)
    return data_set

In [None]:
# Training data: the batched dataset is repeated three times.
train_dataset = create_dataset(batch_size=cfg.batch_size, repeat_num=3)

# Networks: the backbone is shared by the loss wrapper and the eval wrapper.
vqa_baseline = VQA_baseline(cfg)
vqa_with_loss = WithLossCell(vqa_baseline, cfg)
vqa_with_acc = EvalCell(vqa_baseline)

# Adam with the fixed base learning rate from the config.
opt = nn.Adam(params=vqa_with_loss.trainable_params(), learning_rate=cfg.learning_rate)
model = Model(
    network=vqa_with_loss,
    eval_network=vqa_with_acc,
    metrics={'acc'},
    optimizer=opt,
)

# Callbacks: loss logging plus a checkpoint every 128 steps (at most 32 kept).
config_ck = CheckpointConfig(save_checkpoint_steps=128, keep_checkpoint_max=32)
ckpoint_cb = ModelCheckpoint(prefix="hzw", directory="./checkpoint", config=config_ck)
loss_cb = LossMonitor()
cb = [loss_cb, ckpoint_cb]

# Run the training loop.
print("start training...")
model.train(epoch=cfg.epochs, train_dataset=train_dataset, callbacks=cb)


In [None]:
# Evaluate the trained model on the test split and print the metric dict.
eval_dataset = create_dataset(batch_size=cfg.batch_size, is_training=False)
acc = model.eval(eval_dataset)
print(acc)