In [1]:
import time 
import mxnet as mx 
from mxnet import gluon,autograd,nd
import mxnet.ndarray as F
from mxnet.gluon.model_zoo import vision
from data_loader import videoFolder
import utils
from option import Options, args_
from multiprocessing import cpu_count
from network import lstm_net,resnet18_v2
from metrics import L2Loss_2, L2Loss_cos
import sys
#import multiprocessing

In [2]:
def train(args):
    frames = args.frames
    caption_length = args.caption_length
    glove_file = args.glove_file
    
    #CPU_COUNT = multiprocessing.cpu_count()
    if args.cuda:
        ctx = mx.gpu()
    else:
        ctx = mx.cpu()
    
    if args.load_pretrain:
        pretrain_model = vision.vgg16_bn(pretrained=True,ctx=ctx)
        transform = utils.Compose([utils.ToTensor(ctx),
                               utils.normalize(ctx),
                               utils.extractFeature(ctx,pretrain_model)
                             ])
    else:
        pretrain_model = None
        transform = utils.Compose([utils.ToTensor(ctx),
                                   utils.normalize(ctx),
                                 ])
    
    target_transform = utils.targetCompose([utils.WordToTensor(ctx)])

    train_dataset = videoFolder(args.train_folder,args.train_dict, frames, glove_file, 
                    caption_length, ctx, transform=transform, target_transform=target_transform)

    test_dataset = videoFolder(args.test_folder,args.test_dict, frames, glove_file, 
                        caption_length, ctx, transform=transform, target_transform=target_transform)

    train_loader = gluon.data.DataLoader(train_dataset,batch_size=args.batch_size,
                                last_batch='keep',shuffle=True)

    test_loader = gluon.data.DataLoader(test_dataset,batch_size=args.batch_size,
                                    last_batch='keep',shuffle=False)

    loss = L2Loss_2()
    #net = lstm_net(frames,caption_length,ctx,pretrained=args.load_pretrain)
    net = resnet18_v2(caption_length=caption_length,ctx=ctx)
                            
            
    net.collect_params().initialize(init=mx.initializer.MSRAPrelu(), ctx=ctx)
        
    trainer = gluon.Trainer(net.collect_params(), 'adam',
                            {'learning_rate': args.lr})
    
    smoothing_constant = 0.01
    
    for e in range(args.epochs):
        
        epoch_loss = 0.
        batch_loss = None
        for batch_id, (x,_) in enumerate(train_loader):
            
            if batch_id > 0:
                batch_loss = F.mean(batch_loss).asscalar()
                epoch_loss = (1 - smoothing_constant)*epoch_loss + smoothing_constant*batch_loss
                
            with autograd.record():
                pred = net(x)
                batch_loss = loss(pred,_)
            
            trainer.step(x.shape[0],ignore_stale_grad=True)
            batch_loss.backward()
            mx.nd.waitall()
            
            
            if (batch_id+1) % 100 == 0:
                print("Train Batch:{}, batch_loss:{}".format(batch_id+1, epoch_loss))
            
            if ((batch_id == 0) and (e == 0)):
                epoch_loss = F.mean(batch_loss).asscalar() 
        
        
        epoch_loss_1 = 0.
        batch_loss_1 = None
        for batch_id, (x,_) in enumerate(test_loader):
            
            if batch_id > 0:
                batch_loss_1 = F.mean(batch_loss_1).asscalar()
                epoch_loss_1 = (1 - smoothing_constant)*epoch_loss_1 + smoothing_constant*batch_loss_1
                
            with autograd.predict_mode():
                predict = net(x)
                batch_loss_1 = loss(pred,_)
                batch_loss_1 = F.mean(batch_loss_1).asscalar()
            
            if (batch_id+1) % 30 == 0:
                print("Test Batch:{}, batch_loss:{}".format(batch_id+1, epoch_loss_1))
                
            if ((batch_id == 0) and (e == 0)):
                epoch_loss_1 = F.mean(batch_loss_1).asscalar() 
            
        print("Epoch {}, train_loss:{}, test_loss:{}".format(e+1, epoch_loss, epoch_loss_1))
    
    if args.save_model == True:
        file_name = "./saved_model/" + "lstm_pretrain.params"
        net.save_parameters(file_name)
        

In [3]:
def main():
    args = args_()
    train(args)


In [4]:
if __name__ == "__main__":
    main()  

Train Batch:100, batch_loss:
[[0.39366227 0.60824007 0.5983646  0.39073762 0.42150462 0.60927516
  0.5869195  0.5564809  0.41579407 0.5152232  0.6547611  0.34676555
  0.4387939  0.54638207 0.38036242 0.58439356 0.5669514  0.44710204
  0.52117604 0.38693878 0.43803266 0.57516897 0.5460846  0.03833161
  0.03758459 0.04205515 0.04178111 0.04087343 0.06362057 0.04280103
  0.04354157 0.03852938 0.03785402 0.04787131 0.0475389  0.04072666
  0.0436152  0.04702995 0.04028513 0.05011647 0.04889257 0.03809907
  0.04242947 0.04639535 0.0390443  0.04616857 0.0415615  0.03838789
  0.04260176 0.03609623]
 [0.35816112 0.64503205 0.6041003  0.46471447 0.40931603 0.3947672
  0.43165192 0.39253947 0.40674305 0.6065319  0.61919993 0.0729438
  0.08432725 0.06431431 0.06078013 0.06530359 0.08433639 0.08252046
  0.06449606 0.06821509 0.0721042  0.06522253 0.05353156 0.05542986
  0.06322928 0.07251178 0.05665568 0.05991264 0.08324494 0.06545715
  0.0619426  0.05372094 0.05515889 0.05401799 0.06190415 0.05729

KeyboardInterrupt: 