In [3]:
# -*- coding: utf-8 -*-
# file: atae-lstm
# author: songyouwei <youwei0314@gmail.com>
# Copyright (C) 2018. All Rights Reserved.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from argparse import Namespace
import math
import os
from model import ATAE_LSTM, AOA
from data_utils import build_tokenizer, build_embedding_matrix, ABSADataset
from tensorboardX import SummaryWriter
from sklearn import metrics

In [4]:
# # Arguments for restaurant,ATAE-LSTM
# opt = Namespace(
#     model_name="atae_lstm",
#     dataset='restaurant',#twitter,laptop
#     seed=1234,
#     optimizer = 'adam',
#     initializer = 'xavier_uniform_',
#     log_step = 5,
#     logdir = 'log',
#     embed_dim = 200,
#     hidden_dim = 300,
#     max_seq_len = 80,
#     polarities_dim = 3,
#     hops = 3,
#     device = None,
#     learning_rate = 0.001,
#     batch_size = 128,
#     l2reg = 0.00001,
#     num_epoch = 20,
#     dropout = 0,
# )
# dataset_files = {
#     'twitter': {
#         'train': './datasets/acl-14-short-data/train.raw',
#         'test': './datasets/acl-14-short-data/test.raw'
#     },
#     'restaurant': {
#         'train': './datasets/semeval14/Restaurants_Train.xml.seg',
#         'test': './datasets/semeval14/Restaurants_Test_Gold.xml.seg'
#     },
#     'laptop': {
#         'train': './datasets/semeval14/Laptops_Train.xml.seg',
#         'test': './datasets/semeval14/Laptops_Test_Gold.xml.seg'
#     }
# }
# input_colses = {
#         'atae_lstm': ['text_raw_indices', 'aspect_indices']
#         'aoa': ['text_raw_indices', 'aspect_indices']
#     }
# initializers = {
#         'xavier_uniform_': torch.nn.init.xavier_uniform_,
#         'xavier_normal_': torch.nn.init.xavier_normal,
#         'orthogonal_': torch.nn.init.orthogonal_,
#     }
# optimizers = {
#         'adadelta': torch.optim.Adadelta,  # default lr=1.0
#         'adagrad': torch.optim.Adagrad,  # default lr=0.01
#         'adam': torch.optim.Adam,  # default lr=0.001
#         'adamax': torch.optim.Adamax,  # default lr=0.002
#         'asgd': torch.optim.ASGD,  # default lr=0.01
#         'rmsprop': torch.optim.RMSprop,  # default lr=0.01
#         'sgd': torch.optim.SGD,
#     }
# opt.inputs_cols = input_colses['atae_lstm']
# opt.dataset_file = dataset_files[opt.dataset]
# opt.initializer = initializers[opt.initializer]
# opt.optimizer = optimizers[opt.optimizer]
# # opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# opt.device = torch.device('cpu')

# opt.model_class=ATAE_LSTM
# # Set seed for reproducability
# np.random.seed(opt.seed)

In [4]:
# Arguments for restaurant,AOA
opt = Namespace(
    model_name="aoa",
    dataset='restaurant',#twitter,laptop
    seed=1234,
    optimizer = 'adam',
    initializer = 'xavier_uniform_',
    log_step = 5,
    logdir = 'log',
    embed_dim = 200,
    hidden_dim = 300,
    max_seq_len = 80,
    polarities_dim = 3,
    hops = 3,
    device = None,
    learning_rate = 0.001,
    batch_size = 128,
    l2reg = 0.0001,#0.00001
    num_epoch = 20,
    dropout = 0.2,
)
dataset_files = {
    'twitter': {
        'train': './datasets/acl-14-short-data/train.raw',
        'test': './datasets/acl-14-short-data/test.raw'
    },
    'restaurant': {
        'train': './datasets/semeval14/Restaurants_Train.xml.seg',
        'test': './datasets/semeval14/Restaurants_Test_Gold.xml.seg'
    },
    'laptop': {
        'train': './datasets/semeval14/Laptops_Train.xml.seg',
        'test': './datasets/semeval14/Laptops_Test_Gold.xml.seg'
    }
}
input_colses = {
        'atae_lstm': ['text_raw_indices', 'aspect_indices'],
        'aoa': ['text_raw_indices', 'aspect_indices']
    }
initializers = {
        'xavier_uniform_': torch.nn.init.xavier_uniform_,
        'xavier_normal_': torch.nn.init.xavier_normal,
        'orthogonal_': torch.nn.init.orthogonal_,
    }
optimizers = {
        'adadelta': torch.optim.Adadelta,  # default lr=1.0
        'adagrad': torch.optim.Adagrad,  # default lr=0.01
        'adam': torch.optim.Adam,  # default lr=0.001
        'adamax': torch.optim.Adamax,  # default lr=0.002
        'asgd': torch.optim.ASGD,  # default lr=0.01
        'rmsprop': torch.optim.RMSprop,  # default lr=0.01
        'sgd': torch.optim.SGD,
    }
opt.inputs_cols = input_colses['atae_lstm']
opt.dataset_file = dataset_files[opt.dataset]
opt.initializer = initializers[opt.initializer]
opt.optimizer = optimizers[opt.optimizer]
# opt.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
opt.device = torch.device('cpu')

opt.model_class=AOA
# Set seed for reproducability
np.random.seed(opt.seed)

In [12]:
class Instructor:
    def __init__(self, opt):
        self.opt = opt

        
        tokenizer = build_tokenizer(
            fnames=[opt.dataset_file['train'], opt.dataset_file['test']],
            max_seq_len=opt.max_seq_len,
            dat_fname='{0}_tokenizer.dat'.format(opt.dataset))
        embedding_matrix = build_embedding_matrix(
            word2idx=tokenizer.word2idx,
            embed_dim=opt.embed_dim,
            em_fname='{0}_{1}_embedding_matrix.dat'.format(str(opt.embed_dim), opt.dataset),
            ev_fpath='../../../data/embeddings/glove.twitter.27B/')
        self.model = opt.model_class(embedding_matrix, opt).to(opt.device)

        trainset = ABSADataset(opt.dataset_file['train'], tokenizer)
        testset = ABSADataset(opt.dataset_file['test'], tokenizer)
        self.train_data_loader = DataLoader(dataset=trainset, batch_size=opt.batch_size, shuffle=True)
        self.test_data_loader = DataLoader(dataset=testset, batch_size=opt.batch_size, shuffle=False)

        if opt.device.type == 'cuda':
            print("cuda memory allocated:", torch.cuda.memory_allocated(device=opt.device.index))
        self._print_args()

    def _print_args(self):
        n_trainable_params, n_nontrainable_params = 0, 0
        for p in self.model.parameters():
            n_params = torch.prod(torch.tensor(p.shape))
            if p.requires_grad:
                n_trainable_params += n_params
            else:
                n_nontrainable_params += n_params
        print('n_trainable_params: {0}, n_nontrainable_params: {1}'.format(n_trainable_params, n_nontrainable_params))
        print('> training arguments:')
        for arg in vars(self.opt):
            print('>>> {0}: {1}'.format(arg, getattr(self.opt, arg)))

    def _reset_params(self):
        for child in self.model.children():
#             if type(child) != BertModel:  # skip bert params (with unfreezed bert)
            for p in child.parameters():
                if p.requires_grad:
                    if len(p.shape) > 1:
                        self.opt.initializer(p)
                    else:
                        stdv = 1. / math.sqrt(p.shape[0])
                        torch.nn.init.uniform_(p, a=-stdv, b=stdv)

    def _train(self, criterion, optimizer, max_test_acc_overall=0):
        writer = SummaryWriter(log_dir=self.opt.logdir)
        max_test_acc = 0
        max_f1 = 0
        global_step = 0
        for epoch in range(self.opt.num_epoch):
            print('>' * 100)
            print('epoch: ', epoch)
            n_correct, n_total = 0, 0
            for i_batch, sample_batched in enumerate(self.train_data_loader):
                global_step += 1

                # switch model to training mode, clear gradient accumulators
                self.model.train()
                optimizer.zero_grad()

                inputs = [sample_batched[col].to(self.opt.device) for col in self.opt.inputs_cols]
                outputs = self.model(inputs)
                targets = sample_batched['polarity'].to(self.opt.device)

                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                if global_step % self.opt.log_step == 0:
                    n_correct += (torch.argmax(outputs, -1) == targets).sum().item()
                    n_total += len(outputs)
                    train_acc = n_correct / n_total

                    test_acc, f1 = self._evaluate_acc_f1()
                    if test_acc > max_test_acc:
                        max_test_acc = test_acc
                        if test_acc > max_test_acc_overall:
                            if not os.path.exists('state_dict'):
                                os.mkdir('state_dict')
                            path = 'state_dict/{0}_{1}_acc{2}'.format(self.opt.model_name, self.opt.dataset, round(test_acc, 4))
                            torch.save(self.model.state_dict(), path)
                            print('>> saved: ' + path)
                    if f1 > max_f1:
                        max_f1 = f1

                    writer.add_scalar('loss', loss, global_step)
                    writer.add_scalar('acc', train_acc, global_step)
                    writer.add_scalar('test_acc', test_acc, global_step)
                    print('loss: {:.4f}, acc: {:.4f}, test_acc: {:.4f}, f1: {:.4f}'.format(loss.item(), train_acc, test_acc, f1))

        writer.close()
        return max_test_acc, max_f1

    def _evaluate_acc_f1(self):
        # switch model to evaluation mode
        self.model.eval()
        n_test_correct, n_test_total = 0, 0
        t_targets_all, t_outputs_all = None, None
        with torch.no_grad():
            for t_batch, t_sample_batched in enumerate(self.test_data_loader):
                t_inputs = [t_sample_batched[col].to(opt.device) for col in self.opt.inputs_cols]
                t_targets = t_sample_batched['polarity'].to(opt.device)
                t_outputs = self.model(t_inputs)

                n_test_correct += (torch.argmax(t_outputs, -1) == t_targets).sum().item()
                n_test_total += len(t_outputs)

                if t_targets_all is None:
                    t_targets_all = t_targets
                    t_outputs_all = t_outputs
                else:
                    t_targets_all = torch.cat((t_targets_all, t_targets), dim=0)
                    t_outputs_all = torch.cat((t_outputs_all, t_outputs), dim=0)

        test_acc = n_test_correct / n_test_total
        f1 = metrics.f1_score(t_targets_all.cpu(), torch.argmax(t_outputs_all, -1).cpu(), labels=[0, 1, 2], average='macro')
        return test_acc, f1

    def run(self, repeats=1):
        # Loss and Optimizer
        criterion = nn.CrossEntropyLoss()
        _params = filter(lambda p: p.requires_grad, self.model.parameters())
        optimizer = self.opt.optimizer(_params, lr=self.opt.learning_rate, weight_decay=self.opt.l2reg)

        max_test_acc_overall = 0
        max_f1_overall = 0
        for i in range(repeats):
            print('repeat: ', i)
            self._reset_params()
            max_test_acc, max_f1 = self._train(criterion, optimizer, max_test_acc_overall=max_test_acc_overall)
            print('max_test_acc: {0}     max_f1: {1}'.format(max_test_acc, max_f1))
            max_test_acc_overall = max(max_test_acc, max_test_acc_overall)
            max_f1_overall = max(max_f1, max_f1_overall)
            print('#' * 100)
        print("max_test_acc_overall:", max_test_acc_overall)
        print("max_f1_overall:", max_f1_overall)

In [13]:
ins = Instructor(opt)


loading tokenizer: restaurant_tokenizer.dat
loading word vectors...
building embedding_matrix: 200_restaurant_embedding_matrix.dat
n_trainable_params: 2411403, n_nontrainable_params: 917000
> training arguments:
>>> model_name: aoa
>>> dataset: restaurant
>>> seed: 1234
>>> optimizer: <class 'torch.optim.adam.Adam'>
>>> initializer: <function xavier_uniform_ at 0x0000025139DD71E0>
>>> log_step: 5
>>> logdir: log
>>> embed_dim: 200
>>> hidden_dim: 300
>>> max_seq_len: 80
>>> polarities_dim: 3
>>> hops: 3
>>> device: cpu
>>> learning_rate: 0.001
>>> batch_size: 128
>>> l2reg: 0.0001
>>> num_epoch: 20
>>> dropout: 0.2
>>> inputs_cols: ['text_raw_indices', 'aspect_indices']
>>> dataset_file: {'train': './datasets/semeval14/Restaurants_Train.xml.seg', 'test': './datasets/semeval14/Restaurants_Test_Gold.xml.seg'}
>>> model_class: <class 'model.AOA'>


In [14]:
ins.run(5)

repeat:  0
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
epoch:  0


  'precision', 'predicted', average, warn_for)


>> saved: state_dict/aoa_restaurant_acc0.65
loss: 0.9633, acc: 0.6406, test_acc: 0.6500, f1: 0.2626
>> saved: state_dict/aoa_restaurant_acc0.6518
loss: 1.0315, acc: 0.5859, test_acc: 0.6518, f1: 0.2729
>> saved: state_dict/aoa_restaurant_acc0.6554
loss: 0.9464, acc: 0.5859, test_acc: 0.6554, f1: 0.2835
>> saved: state_dict/aoa_restaurant_acc0.6643
loss: 0.7981, acc: 0.6172, test_acc: 0.6643, f1: 0.3219
>> saved: state_dict/aoa_restaurant_acc0.6732
loss: 0.8998, acc: 0.6094, test_acc: 0.6732, f1: 0.4499
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
epoch:  1
loss: 0.8280, acc: 0.6719, test_acc: 0.6687, f1: 0.5132
>> saved: state_dict/aoa_restaurant_acc0.7018
loss: 1.1377, acc: 0.6055, test_acc: 0.7018, f1: 0.4523
loss: 0.8093, acc: 0.6302, test_acc: 0.7000, f1: 0.4896
loss: 0.7451, acc: 0.6543, test_acc: 0.6964, f1: 0.4282
>> saved: state_dict/aoa_restaurant_acc0.7036
loss: 0.7342, acc: 0.6609, test_acc: 0.7036, f1: 0.4663
>> saved: