In [1]:
import torch.backends.cudnn as cudnn
# Turn on cuDNN benchmark mode for the whole session. This runs in the first
# cell, before any model is created or moved to the GPU.
cudnn.benchmark = True

In [2]:
import torch
from torch import optim, nn
import argparse
from packages.vocab import Vocab
import torch.nn.functional as F
from tensorboard.logger import Logger
from torch.autograd import Variable
from packages.data_loader import get_loader
from models.extractor import JavascriptExtractor
from packages.functions import pack_padded, to_np, to_var, str2bool

In [3]:
class Args():
    """Hard-coded run configuration used in place of command-line arguments.

    Every setting is a plain class attribute. The single module-level
    instance ``args`` is the object handed to ``main``, ``train`` and
    ``test``. Settings that are not read anywhere in this notebook are
    presumably consumed by ``JavascriptExtractor(args, vocab)`` -- verify
    against the model code.
    """
    # Data files for the train / validation / test splits.
    # NOTE(review): absolute, machine-specific paths; consider making them
    # relative to a configurable data directory.
    train_root='/home/irteam/users/data/D3/outputs_train.txt'
    val_root='/home/irteam/users/data/D3/outputs_val.txt'
    test_root='/home/irteam/users/data/D3/outputs_test.txt'
    # Vocabulary dictionary file and the OOV budget passed to ``Vocab``.
    dict_root='data/dict_1000.json'
    max_oovs=20
    # 'train' or 'test'; selects the branch taken by ``main``.
    mode='test'
    # Number of passes over the training loader.
    epochs=20
    # Model sizes (not read in this notebook; presumably used by the model).
    hidden=256
    embed=256
    # Learning rate given to the Adam optimiser.
    lr=0.01
    # When True, ``train`` writes scalar/histogram summaries to './logs'.
    log=False
    # Checkpoint to load. Set to None to build a fresh model in ``train``
    # (``test`` refuses to run without a checkpoint).
    load='data/model_9000_steps.pckl'
#     load = None
    # When True, ``main`` snapshots the source folders before running.
    copy=False
    # Architecture options (not read in this notebook; presumably used by
    # the model).
    n_layers=2
    n_head=8
    similarity='mlp'
    max_in_seq=150
    max_out_seq=150
    # Batch size passed to ``get_loader``.
    batch=64
    encoder='lstm'
    # Passed to ``get_loader``; also selects the single-loss code path in
    # ``train`` / ``val`` (the model then returns only ``outputs``).
    single=False
    # Move the model and batches to the GPU when True.
    cuda=True

    def __repr__(self):
        """Return ``Args(name=value, ...)`` listing every public setting.

        Without this, ``print(args)`` shows only the default object repr,
        which carries no information about the run being started.
        """
        settings = dict(vars(type(self)))
        settings.update(vars(self))  # instance-level overrides win
        shown = ['%s=%r' % (name, settings[name])
                 for name in sorted(settings)
                 if not name.startswith('_') and not callable(settings[name])]
        return 'Args(%s)' % ', '.join(shown)
args = Args()

In [10]:
def train(args):
    """Train the extractor, validating and checkpointing every 100 steps.

    Builds the vocabulary and the training loader from ``args``, creates a
    fresh ``JavascriptExtractor`` when ``args.load is None`` (otherwise
    resumes from the checkpoint at ``args.load``) and optimises it with Adam
    for ``args.epochs`` passes over the loader.

    With ``args.single`` the loss is the NLL of the log of the model outputs.
    Otherwise the model also returns ``sim``, which is scored against
    ``labels`` with a second NLL term, and the two losses are summed.

    Args:
        args: configuration object. Reads ``log``, ``dict_root``,
            ``max_oovs``, ``train_root``, ``batch``, ``single``, ``load``,
            ``cuda``, ``lr`` and ``epochs`` (plus whatever ``val`` and the
            model constructor read).

    Returns:
        None. Side effects: prints per-step metrics, writes
        ``data/model_<steps>_steps.pckl`` every 100 steps, and (when
        ``args.log``) writes summaries under ``./logs``.
    """
    print(args)
    logger = Logger('./logs') if args.log else None
    vocab = Vocab(args.dict_root, args.max_oovs)
    data_loader = get_loader(args.train_root, args.dict_root, vocab, args.batch, args.single)

    criterion = nn.NLLLoss()
    if args.load is None:
        model = JavascriptExtractor(args,vocab)
    else:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        model = torch.load(args.load)
    if args.cuda:
        model.cuda()
    # A resumed checkpoint may have been saved in eval mode; make sure
    # training actually runs in training mode.
    model.train()
    steps = 0
    opt = optim.Adam(model.parameters(), lr=args.lr)
    # Use the loader length as the progress denominator when it has one;
    # otherwise fall back to the running maximum of the batch index.
    try:
        total_batches = len(data_loader)
    except TypeError:
        total_batches = 0
    for epoch in range(args.epochs):
        for i, (inputs, lengths, labels, oovs) in enumerate(data_loader):
            steps+=1
            total_batches = max(total_batches,i)
            model.zero_grad()
            # split tuples
            sources, queries, targets = inputs
            if args.cuda:
                sources = sources.cuda()
                queries = queries.cuda()
                targets = targets.cuda()
            if args.single:
                outputs = model(sources,queries,lengths, targets) # [batch x seq x vocab]
            else:
                outputs, sim = model(sources,queries,lengths, targets) # [batch x seq x vocab]
            # Score against targets without their first column
            # (presumably the start-of-sequence token -- TODO confirm).
            targets = Variable(targets[:,1:])
            packed_outputs,packed_targets = pack_padded(outputs,targets)
            # NLLLoss expects log-probabilities; the model outputs are
            # presumably probabilities -- TODO confirm. No epsilon is added
            # here, so an exact 0 would produce -inf.
            packed_outputs = torch.log(packed_outputs)
            if args.single:
                loss = criterion(packed_outputs,packed_targets)
            else:
                # The small constant keeps the log finite when a score is 0.
                sim = sim + 1e-3
                sim = torch.log(sim)
                labels = Variable(torch.LongTensor(list(labels)))
                if args.cuda:
                    labels = labels.cuda()
                loss1 = criterion(sim, labels)
                loss2 = criterion(packed_outputs,packed_targets)
                loss = loss1 + loss2
                # loss = loss1
            # Token-level accuracy over the packed (non-padded) positions.
            # NOTE(review): `.data[0]` is the pre-0.4 PyTorch way to read a
            # scalar; newer releases need `.item()` instead.
            predicted = packed_outputs.max(1)[1]
            correct=(predicted==packed_targets).long().sum()
            acc = (correct.data[0]*1.0/packed_targets.size(0))
            if args.single:
                print("[%d]: Epoch %d\t%d/%d\tLoss: %1.3f\tAccuracy: %1.3f"
                      %(steps,epoch+1,i,total_batches,
                        loss.data[0],acc))
            else:
                predicted_label = sim.max(1)[1]
                correct_label=(predicted_label==labels).long().sum()
                acc2 = (correct_label.data[0]*1.0/len(labels))
                print("[%d]: Epoch %d\t%d/%d\tLoss: %1.3f, %1.3f\tAccuracy: %1.3f, %1.3f"
                      %(steps,epoch+1,i,total_batches,
                        loss1.data[0],loss2.data[0],acc2,acc))
            loss.backward()
            opt.step()
            if steps%100==0:
                val(model,vocab, args)
                # Validation puts the model in eval mode; switch back so the
                # following steps do not silently train in eval mode.
                model.train()
                torch.save(obj=model,f='data/model_%d_steps.pckl'%steps)
                print("Model saved...")
                if args.log:
                    # log scalar values
                    info = {'loss': loss.data[0],
                            'acc': acc}
                    for tag,value in info.items():
                        logger.scalar_summary(tag,value,steps)

                    # log values and gradients of the parameters
                    for tag, value in model.named_parameters():
                        tag = tag.replace('.','/')
                        logger.histo_summary(tag, to_np(value), steps)
                        # Parameters that took no part in the loss have no
                        # gradient; skip them instead of failing on None.
                        if value.grad is not None:
                            logger.histo_summary(tag+'/grad',to_np(value.grad), steps)
                        
def val(model, vocab, args):
    """Run ``model`` over the split at ``args.val_root`` and print per-batch metrics.

    The model is switched to eval mode once for the whole pass and its
    previous train/eval mode is restored afterwards (even on error), so a
    caller in the middle of training keeps training in training mode.

    Args:
        model: the network to evaluate; called as
            ``model(sources, queries, lengths, targets)``.
        vocab: vocabulary object forwarded to ``get_loader``.
        args: configuration object. Reads ``val_root``, ``dict_root``,
            ``batch``, ``single`` and ``cuda``.

    Returns:
        None. One line of loss/accuracy is printed per batch.
    """
    criterion = nn.NLLLoss()
    data_loader = get_loader(args.val_root, args.dict_root, vocab, args.batch,
                          args.single, shuffle=False)
    was_training = model.training
    model.eval()
    try:
        for inputs, lengths, labels, oovs in data_loader:
            sources, queries, targets = inputs
            if args.cuda:
                sources = sources.cuda()
                queries = queries.cuda()
                targets = targets.cuda()
            if args.single:
                outputs = model(sources,queries,lengths, targets) # [batch x seq x vocab]
            else:
                outputs, sim = model(sources,queries,lengths,targets)
            # Score against targets without their first column
            # (presumably the start-of-sequence token -- TODO confirm).
            targets = Variable(targets[:,1:])

            packed_outputs,packed_targets = pack_padded(outputs,targets)
            # NLLLoss expects log-probabilities; the model outputs are
            # presumably probabilities -- TODO confirm.
            packed_outputs = torch.log(packed_outputs)
            if args.single:
                loss = criterion(packed_outputs,packed_targets)
            else:
                # The small constant keeps the log finite when a score is 0.
                sim = sim + 1e-3
                sim = torch.log(sim)
                labels = Variable(torch.LongTensor(list(labels)))
                if args.cuda:
                    labels = labels.cuda()
                loss1 = criterion(sim, labels)
                loss2 = criterion(packed_outputs,packed_targets)
                loss = loss1 + loss2
                # loss = loss1
            # Token-level accuracy over the packed (non-padded) positions.
            # NOTE(review): `.data[0]` is the pre-0.4 PyTorch way to read a
            # scalar; newer releases need `.item()` instead.
            predicted = packed_outputs.max(1)[1]
            correct=(predicted==packed_targets).long().sum()
            acc = (correct.data[0]*1.0/packed_targets.size(0))
            if args.single:
                print("Loss: %1.3f\tAccuracy: %1.3f"
                      %(loss.data[0],acc))
            else:
                predicted_label = sim.max(1)[1]
                correct_label=(predicted_label==labels).long().sum()
                acc2 = (correct_label.data[0]*1.0/len(labels))
                print("Loss: %1.3f, %1.3f\tAccuracy: %1.3f, %1.3f"
                      %(loss1.data[0],loss2.data[0],acc2,acc))
    finally:
        # Hand the model back in whatever mode the caller had it in.
        model.train(was_training)
    return

def test(args):
    vocab = Vocab(args.dict_root, args.max_oovs)
    criterion = nn.NLLLoss()
    if args.load is None:
        print("Error: no model found")
        sys.exit()
    else:
        model = torch.load(args.load)
    if args.cuda:
        model.cuda()
    total_batches=0
    args.val_root = args.test_root # to apply val function directly
    val(model, vocab, args)
    return

def copy(args):
    """Snapshot the ``models/`` and ``packages/`` folders under ``data/<timestamp>/``.

    The destination is ``data/YYYY-mm-dd_HH-MM-SS`` (local time at call
    time); each source folder is copied recursively into it under the same
    name.

    Args:
        args: unused; accepted so the function can be called like the other
            entry points.

    Returns:
        None. Prints the destination folder when done.

    Raises:
        OSError: if the timestamped folder already exists or a source
            folder is missing.
    """
    import os
    import datetime
    # shutil instead of distutils.dir_util: distutils is gone from the
    # standard library as of Python 3.12.
    import shutil
    folder_dir = os.path.join('data',datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    # makedirs also creates 'data/' itself when it is missing.
    os.makedirs(folder_dir)
    from_list = ['models/','packages/']
    for item in from_list:
        from_dir = item
        to_dir = os.path.join(folder_dir,item)
        # The destination is always new because folder_dir was just created.
        shutil.copytree(from_dir, to_dir)
    print("Folders copied at %s" %folder_dir)
    return

def main(args):
    """Dispatch on ``args.mode``, optionally snapshotting the sources first.

    When ``args.copy`` compares equal to True the source folders are copied
    via ``copy`` before anything else. ``args.mode`` then selects ``train``
    or ``test``; any other value only prints a usage error.
    """
    if args.copy==True:
        copy(args)

    selected = args.mode
    if selected not in ('train', 'test'):
        print("Error: please specify --mode as 'train' or 'test'")
        return

    if selected=='train':
        print("Train mode")
        train(args)
        return

    print("Test mode")
    test(args)

In [11]:
# Entry point: run the mode selected in the module-level `args` object.
if __name__ == "__main__":
    main(args)

Test mode
Loss: 0.898, 1.488	Accuracy: 0.719, 0.736
Loss: 1.347, 1.168	Accuracy: 0.359, 0.795
Loss: 1.184, 1.091	Accuracy: 0.469, 0.796
Loss: 1.161, 1.380	Accuracy: 0.547, 0.736
Loss: 1.034, 2.007	Accuracy: 0.625, 0.730
Loss: 1.135, 2.321	Accuracy: 0.578, 0.546
Loss: 1.083, 2.644	Accuracy: 0.656, 0.583
Loss: 0.461, 0.527	Accuracy: 0.906, 0.884
Loss: 0.673, 1.159	Accuracy: 0.859, 0.827
Loss: 0.691, 1.395	Accuracy: 0.812, 0.799
Loss: 0.430, 0.234	Accuracy: 0.922, 0.926
Loss: 0.934, 0.974	Accuracy: 0.703, 0.820
Loss: 0.505, 0.298	Accuracy: 0.938, 0.879
Loss: 0.629, 1.385	Accuracy: 0.859, 0.770
Loss: 0.982, 1.715	Accuracy: 0.578, 0.698
Loss: 0.853, 2.580	Accuracy: 0.672, 0.581
Loss: 0.890, 2.065	Accuracy: 0.688, 0.638
Loss: 0.495, 2.400	Accuracy: 0.812, 0.641
Loss: 0.991, 2.758	Accuracy: 0.500, 0.547
Loss: 0.831, 2.408	Accuracy: 0.656, 0.589
Loss: 0.714, 2.581	Accuracy: 0.750, 0.537
Loss: 1.018, 2.116	Accuracy: 0.531, 0.620
Loss: 1.044, 2.273	Accuracy: 0.641, 0.608
Loss: 1.264, 1.689	Accur