In [1]:
import time 
import mxnet as mx 
from mxnet import gluon,autograd,nd
import mxnet.ndarray as F
from mxnet.gluon.model_zoo import vision
from data_loader import videoFolder
import utils
from option import Options, args_
from multiprocessing import cpu_count
from network import lstm_net,resnet18_v2
from metrics import L2Loss_2, L2Loss_cos
import sys
#import multiprocessing

In [2]:
def train(args):
    frames = args.frames
    caption_length = args.caption_length
    glove_file = args.glove_file
    
    #CPU_COUNT = multiprocessing.cpu_count()
    if args.cuda:
        ctx = mx.gpu()
    else:
        ctx = mx.cpu()
    
    if args.load_pretrain:
        pretrain_model = vision.vgg16_bn(pretrained=True,ctx=ctx)
        transform = utils.Compose([utils.ToTensor(ctx),
                               utils.normalize(ctx),
                               utils.extractFeature(ctx,pretrain_model)
                             ])
    else:
        pretrain_model = None
        transform = utils.Compose([utils.ToTensor(ctx),
                                   utils.normalize(ctx),
                                 ])
    
    target_transform = utils.targetCompose([utils.WordToTensor(ctx)])

    train_dataset = videoFolder(args.train_folder,args.train_dict, frames, glove_file, 
                    caption_length, ctx, transform=transform, target_transform=target_transform)

    test_dataset = videoFolder(args.test_folder,args.test_dict, frames, glove_file, 
                        caption_length, ctx, transform=transform, target_transform=target_transform)

    train_loader = gluon.data.DataLoader(train_dataset,batch_size=args.batch_size,
                                last_batch='keep',shuffle=True)

    test_loader = gluon.data.DataLoader(test_dataset,batch_size=args.batch_size,
                                    last_batch='keep',shuffle=False)

    loss = L2Loss_2()
    #net = lstm_net(frames,caption_length,ctx,pretrained=args.load_pretrain)
    net = resnet18_v2(caption_length=caption_length,ctx=ctx)
                            
            
    net.collect_params().initialize(init=mx.initializer.MSRAPrelu(), ctx=ctx)
        
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                            {'learning_rate': args.lr})
    
    smoothing_constant = 0.01
    
    for e in range(args.epochs):
        
        epoch_loss = 0.
        for batch_id, (x,_) in enumerate(train_loader):
                            
            with autograd.record():
                pred = net(x)
                batch_loss = loss(pred,_)
            
            trainer.step(x.shape[0],ignore_stale_grad=True)
            batch_loss.backward()
            mx.nd.waitall()
            
            batch_loss = F.mean(batch_loss,axis=0).asnumpy()[0]
            if ((batch_id == 0) and (e == 0)):
                epoch_loss = batch_loss
            else:
                epoch_loss = (1 - smoothing_constant)*epoch_loss + smoothing_constant*batch_loss
            
            if (batch_id+1) % 100 == 0:
                print("Train Batch:{}, batch_loss:{}".format(batch_id+1, batch_loss))
                
        epoch_loss_1 = 0.
        for batch_id, (x,_) in enumerate(test_loader):
                            
            with autograd.predict_mode():
                predict = net(x)
                batch_loss_1 = loss(pred,_)
                
            batch_loss_1 = F.mean(batch_loss_1,axis=0).asnumpy()[0]
            #if (batch_id+1) % 30 == 0:
            #    print("Test Batch:{}, batch_loss:{}".format(batch_id+1, batch_loss_1))
                
            if ((batch_id == 0) and (e == 0)):
                epoch_loss_1 = batch_loss_1
            else:
                epoch_loss_1 = (1 - smoothing_constant)*epoch_loss_1 + smoothing_constant*batch_loss_1
                
            
        print("Epoch {}, train_loss:{}, test_loss:{}".format(e+1, epoch_loss, epoch_loss_1))
    
    if args.save_model == True:
        file_name = "./saved_model/" + "lstm_pretrain.params"
        net.save_parameters(file_name)
        

In [3]:
def main():
    args = args_()
    train(args)


In [4]:
if __name__ == "__main__":
    main()  

Train Batch:100, batch_loss:0.4223302900791168
Train Batch:200, batch_loss:0.4086487889289856
Train Batch:300, batch_loss:0.3608972728252411
Train Batch:400, batch_loss:0.33032113313674927


AssertionError: Argument data must have NDArray type, but got 0.33709067