In [1]:
import argparse
import copy
import os
import torch
from torch import nn, optim
from tensorboardX import SummaryWriter
from time import gmtime, strftime

from models import CSATransformer
from mydata import getHotpotData

In [2]:
def fixlabel(batch,dit):
    #probality distribute 
    b = batch.Label.transpose(0,1) # batchsize*tensor
    #print(b)
    result = []
    for index_tensor in b:
        p = [0]*100
        p_numpy = index_tensor.to('cpu').numpy()
        for idx in p_numpy:    
            act_idx = dit.itos[idx]
            if act_idx=='<pad>':
                break
            p[int(act_idx)] = 1
        result.append(torch.Tensor(p))
    return result

In [3]:
def train(args, data):
    model = CSATransformer(args, data).to(args.gpu)
   
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adadelta(parameters, lr=args.learning_rate, weight_decay=args.weight_decay)
    criterion = nn.BCELoss()

    #writer = SummaryWriter(log_dir='runs/' + args.model_time)

    model.train()
    loss, last_epoch = 0, -1
    max_dev_acc, max_test_acc = 0, 0

    iterator = data.train_iter
    
    for i, batch in enumerate(iterator):
        loss = 0
        present_epoch = int(iterator.epoch)
        if present_epoch == args.epoch:
            break
        if present_epoch > last_epoch:
            print('epoch:', present_epoch + 1)
        last_epoch = present_epoch
        #pred and label [tensor,tensor,tensor...]
        pred = model(batch)
        label = fixlabel( batch, data.LABEL.vocab)
        optimizer.zero_grad()
        #orch.narrow(input, dim, start, length) → Tensor
        for i in range(0,args.batch_size):
            x = pred[i]
            y = torch.narrow(label[i],0,0,x.size(0))
            print(y)
            batch_loss = criterion(x,y) 
            loss += batch_loss
            batch_loss.backward()
            
        loss/= args.batch_size
        print('Loss is {}'.format(loss))
        optimizer.step()
        
    return None
    """
        if (i + 1) % args.print_freq == 0:
            dev_loss, dev_acc = test(model, data, mode='dev')
            test_loss, test_acc = test(model, data)
            c = (i + 1) // args.print_freq

            writer.add_scalar('loss/train', loss, c)
            writer.add_scalar('loss/dev', dev_loss, c)
            writer.add_scalar('acc/dev', dev_acc, c)
            writer.add_scalar('loss/test', test_loss, c)
            writer.add_scalar('acc/test', test_acc, c)

            print(f'train loss: {loss:.3f} / dev loss: {dev_loss:.3f} / test loss: {test_loss:.3f}'
                  f' / dev acc: {dev_acc:.3f} / test acc: {test_acc:.3f}')

            if dev_acc > max_dev_acc:
                max_dev_acc = dev_acc
                max_test_acc = test_acc
                best_model = copy.deepcopy(model)

            loss = 0
            model.train()

    writer.close()
     
    print(f'max dev acc: {max_dev_acc:.3f} / max test acc: {max_test_acc:.3f}')

    return best_model
   """

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', default=2, type=int)
    parser.add_argument('--block-size', default=-1, type=int)
    parser.add_argument('--data-type', default='hotpot')
    parser.add_argument('--dropout', default=0.1, type=float)
    parser.add_argument('--epoch', default=20, type=int)
    parser.add_argument('--learning-rate', default=0.5, type=float)
    parser.add_argument('--mSA-scalar', default=5.0, type=float)
    parser.add_argument('--print-freq', default=3000, type=int)
    parser.add_argument('--weight-decay', default=5e-5, type=float)
    parser.add_argument('--word-dim', default=300, type=int)
    parser.add_argument('--csa-mode',default='mul',type = str)
    parser.add_argument('--gpu', default=torch.device('cpu' if torch.cuda.is_available() else 'cpu'), type=int)
    
    args = parser.parse_args([])

    print('loading Hotpot data...')
    trainpath = './small_train_sep_1000.csv'
    devpath = './small_dev_sep_1000.csv'
    data = getHotpotData(args,trainpath,devpath)
    
    setattr(args, 'model_time', strftime('%H:%M:%S', gmtime()))
    
    print('training start!')
    best_model = train(args, data)

    if not os.path.exists('saved_models'):
        os.makedirs('saved_models')
        
    #torch.save(best_model.state_dict(), f'saved_models/CSA_{args.data_type}_{args.model_time}.pt')

    print('training finished!')


if __name__ == '__main__':
    main()

loading Hotpot data...
training start!
epoch: 1
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0.])
Loss is 0.5787869691848755


KeyboardInterrupt: 

In [None]:
import torch.functional as F
X = torch.sigmoid(torch.randn(10))
Y = torch.sigmoid(torch.randn(10))
cc = nn.BCELoss()
print(cc(X,Y).data)