# Model Definition

In [None]:
import moxing as mox
# Copy the VQA_V1.0 data from the OBS bucket path into the current working
# directory ("./"), where the later cells look for the *.mindrecord files.
mox.file.copy_parallel(src_url="obs://npl-hzw/nlp/VQA_V1.0", dst_url="./")

In [None]:
from mindspore import context
from easydict import EasyDict as edict
from mindspore import dataset as ds
import os
import mindspore.ops as ops
import mindspore.nn as nn
from mindspore.train.model import Model
from mindspore.train.callback import LossMonitor, CheckpointConfig, ModelCheckpoint

In [None]:
# context
# Run in PyNative mode on CPU and do not dump computation graphs.
context.set_context(mode=context.PYNATIVE_MODE, save_graphs=False, device_target='CPU')

# CONFIG
# Hyper-parameters shared by the model, dataset and training cells.
cfg = edict({
    'question_embdim': 512,   # input width of the question projection
    'image_embdim': 512,      # input width of the image projection / attention
    'outdim': 1024,           # output width of both projections
    'hiddendim': 256,
    'dropout': 0.2,           # layers are built with nn.Dropout(1 - dropout)
    'answer_num': 8193,       # width of the final classification layer
    'min_lr': 0.001,          # lower bound of the cosine schedule
    'max_lr': 0.1,            # upper bound of the cosine schedule
    'epochs': 4,
    'batch_size': 128,
    'MFB_k': 2,
    'MFB_o': 512,
    'Attn_num': 16,           # number of attention slots the image is expanded to
})

# nn.CosineDecayLR validates decay_steps as a positive int, while the raw
# product below is a float (1249.2 with the values above), so truncate it.
# NOTE(review): 44416 is presumably the number of training samples and 0.9 the
# fraction of steps over which the rate decays — confirm against the dataset.
decay_steps = int(cfg.epochs * 44416 / cfg.batch_size * 0.9)
cosine_decay_lr = nn.CosineDecayLR(cfg.min_lr, cfg.max_lr, decay_steps)

In [None]:
class Attention(nn.Cell):
    """Question-guided attention over an expanded set of image feature slots.

    The image embedding is projected into ``config.Attn_num`` slots of width
    ``config.image_embdim``. Each slot is scored against the question
    embedding, the scores are soft-maxed across slots, and the slots are
    summed with those weights.

    The element-wise product of question and image slots requires
    ``question_embdim == image_embdim``.
    """

    def __init__(self, config):
        super().__init__(auto_prefix=False)
        self.config = config
        # Primitive ops (no parameters).
        self.norm = ops.L2Normalize(2)
        self.unsqueeze = ops.ExpandDims()
        # Sub-cells; creation order is kept so parameter ordering is unchanged.
        self.ExpandFC = nn.Dense(
            config.image_embdim,
            config.image_embdim * config.Attn_num,
            activation='relu',
        )
        self.dropout = nn.Dropout(1 - config.dropout)
        self.softmax = nn.Softmax(1)
        self.getWeight = nn.Dense(config.image_embdim, 1, has_bias=False)

    def construct(self, ques_emb, img_emb):
        """Return the attention-pooled image feature, shape [B, img_embdim]."""
        slot_count = self.config.Attn_num
        batch = img_emb.shape[0]

        # [B, img_embdim] -> [B, img_embdim*Attn_num] -> [B, Attn_num, img_embdim]
        expanded = self.ExpandFC(img_emb)
        slots = expanded.view((batch, slot_count, self.config.image_embdim))

        # Repeat the question once per slot: [B, Attn_num, ques_embdim]
        tiled_ques = ops.repeat_elements(self.unsqueeze(ques_emb, 1), rep=slot_count, axis=1)

        # Score each slot from the product of the L2-normalised features.
        scores = self.norm(tiled_ques) * self.norm(slots)
        scores = self.getWeight(self.dropout(scores))
        # scores: [B, Attn_num, 1]; normalise across the slot axis.
        weights = self.softmax(scores)

        # Weighted sum of the (un-normalised) slots -> [B, img_embdim]
        return (weights * slots).sum(1)

In [None]:
class VQA_baseline(nn.Cell):
    '''
    Baseline VQA classifier: attends over the image feature with the
    question, projects both to a common width, fuses them by element-wise
    product and classifies the result.

    config:{
        question_embdim  question sentence embedding dim
        image_embdim     image embedding dim
        outdim           output dim of the question/image projections
        dropout          layers are built with nn.Dropout(1 - dropout)
        answer_num       number of output logits
    }
    (The Attention sub-cell additionally reads config.Attn_num.)
    '''
    def __init__(self, config, auto_prefix=True, flags=None):
        super().__init__(auto_prefix, flags)
        self.attn = Attention(config)
        self.question_dense = nn.Dense(config.question_embdim, config.outdim)
        self.image_dense = nn.Dense(config.image_embdim, config.outdim)
        self.dropout = nn.Dropout(1 - config.dropout)
        # NOTE(review): 256 matches cfg.hiddendim; presumably this should read
        # the config value — confirm before changing.
        self.out_dense_1 = nn.Dense(config.outdim, 256)
        self.out_dense_2 = nn.Dense(256, config.answer_num)
        # construct() calls self.relu; it was never created, so the first
        # forward pass failed with AttributeError.
        self.relu = nn.ReLU()

    def construct(self, ques, img):
        '''
        img:[B, image_embdim]
        ques:[B, question_embdim]
        output:[B, answer_num]

        '''
        img_attn = self.attn(ques, img)
        ques_out = self.relu(self.question_dense(ques))
        img_out = self.relu(self.image_dense(img_attn))
        # Fuse the two modalities by element-wise product.
        c = self.relu(self.out_dense_1(ques_out * img_out))
        c = self.out_dense_2(c)
        # NOTE(review): dropout is applied to the final logits; confirm this is
        # intended rather than dropout before out_dense_2.
        output = self.dropout(c)
        return output

class WithLossCell(nn.Cell):
    """Wrap a backbone so that a forward pass returns its training loss.

    The loss is softmax cross-entropy computed from the backbone's output and
    sparse labels (``sparse=True``).
    """

    def __init__(self, backbone, config):
        # `config` is accepted but not used by this wrapper.
        super().__init__(auto_prefix=False)
        self._backbone = backbone
        # NOTE(review): no `reduction` is passed, so the library default
        # applies; confirm whether a mean-reduced scalar loss is wanted.
        self.loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True)

    def construct(self, ques_emb, img_emb, label):
        """Run the backbone on (ques_emb, img_emb) and score it against label."""
        logits = self._backbone(ques_emb, img_emb)
        return self.loss(logits, label)

class EvalCell(nn.Cell):
    """Evaluation wrapper that pairs the backbone's output with the label.

    The ``(output, label)`` tuple it returns is what the training cell hands
    to ``Model`` as ``eval_network``.
    """

    def __init__(self, backbone):
        super().__init__(auto_prefix=False)
        self._backbone = backbone
        # Created here but not used by construct().
        self.onehot = ops.OneHot()
        self.print = ops.Print()

    def construct(self, ques_emb, img_emb, label):
        """Return the backbone output together with the untouched label."""
        predictions = self._backbone(ques_emb, img_emb)
        return (predictions, label)

In [None]:
def create_dataset(batch_size, dir_path = "./", repeat_num = 1, is_training = True):
    """Build the shuffled, batched MindRecord pipeline for one split.

    Args:
        batch_size: number of rows per batch; incomplete batches are dropped.
        dir_path: directory containing the ``*.mindrecord`` files.
        repeat_num: how many times the batched dataset is repeated.
        is_training: read ``trainval.mindrecord`` when true, otherwise
            ``test.mindrecord``.

    Returns:
        The dataset exposing the ``ques_emb``, ``img_emb`` and ``label`` columns.
    """
    record_name = "trainval.mindrecord" if is_training else "test.mindrecord"
    record_path = os.path.join(dir_path, record_name)

    records = ds.MindDataset(
        record_path,
        columns_list=["ques_emb", "img_emb", "label"],
        num_shards=1,
        shard_id=0,
    )

    # Shuffle with a buffer as large as the dataset, then batch and repeat.
    sample_count = records.get_dataset_size()
    return (
        records.shuffle(buffer_size=sample_count)
        .batch(batch_size=batch_size, drop_remainder=True)
        .repeat(count=repeat_num)
    )

In [None]:
# create dataset
train_dataset = create_dataset(cfg.batch_size, repeat_num = 3)
# create model
vqa_baseline = VQA_baseline(cfg)
vqa_with_loss = WithLossCell(vqa_baseline, cfg)
vqa_with_acc = EvalCell(vqa_baseline)
# cfg has no 'learning_rate' entry (reading it raised AttributeError); use the
# cosine schedule built from cfg.min_lr / cfg.max_lr instead.
opt = nn.Adam(vqa_with_loss.trainable_params(), learning_rate=cosine_decay_lr)
# The loss is computed inside vqa_with_loss, so no separate loss_fn is passed.
model = Model(vqa_with_loss, eval_network=vqa_with_acc, metrics={'acc'}, optimizer=opt)
# create callback
# Save a checkpoint every 128 steps and keep at most 32 of them.
config_ck = CheckpointConfig(save_checkpoint_steps=128, keep_checkpoint_max=32)
ckpoint_cb = ModelCheckpoint(prefix="hzw", directory="./checkpoint", config=config_ck)
loss_cb = LossMonitor()
cb = [loss_cb, ckpoint_cb]
# start training
print("start training...")
model.train(cfg.epochs, train_dataset, callbacks=cb)


In [None]:
# Build the test split (is_training=False selects test.mindrecord), run the
# model's eval network over it and print the result of model.eval; the Model
# was configured with metrics={'acc'}.
eval_dataset = create_dataset(cfg.batch_size, is_training=False)
acc = model.eval(eval_dataset)
print(acc)