## BertGRU-based AAN  models.

In [8]:
### All the test is based on torch-1.2.0 and torchtext-0.6.0


import torch
from torchtext import data
import random
import numpy as np
import os

from torch.nn import functional as F
from model.tools import categorical_accuracy,epoch_time
from transformers import BertTokenizer

os.environ["CUDA_VISIBLE_DEVICES"] = "3"

SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True


### prepare data loader based on torchtext.

In [9]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
init_token = tokenizer.cls_token
eos_token = tokenizer.sep_token
pad_token = tokenizer.pad_token
unk_token = tokenizer.unk_token
init_token_idx = tokenizer.convert_tokens_to_ids(init_token)
eos_token_idx = tokenizer.convert_tokens_to_ids(eos_token)
pad_token_idx = tokenizer.convert_tokens_to_ids(pad_token)
unk_token_idx = tokenizer.convert_tokens_to_ids(unk_token)
max_input_length = tokenizer.max_model_input_sizes['bert-base-uncased']

def tokenize_and_cut(sentence):
    tokens = tokenizer.tokenize(sentence) 
    tokens = tokens[:max_input_length-2]
    return tokens


def get_iterator_feature(source_file, target_file, BATCH_SIZE=128):
    TEXT = data.Field(batch_first = True,
                  use_vocab = False,
                  tokenize = tokenize_and_cut,
                  preprocessing = tokenizer.convert_tokens_to_ids,
                  init_token = init_token_idx,
                  eos_token = eos_token_idx,
                  pad_token = pad_token_idx,
                  unk_token = unk_token_idx)

    LABEL = data.LabelField(dtype = torch.long)

    fields = {'review': ('text', TEXT), 'label': ('label', LABEL)}
    # source_file = 'elec.json'
    train_data = data.TabularDataset.splits(
                        path = 'datasets'+os.sep+"amazon_text",
                        train = source_file,
                        format = 'json',
                        fields = fields
    )

    test_data = data.TabularDataset.splits(
                            path = 'datasets'+os.sep+"amazon_text",
                            train = target_file,
                            format = 'json',
                            fields = fields
    )

    train_data = train_data[0]

    test_data = test_data[0]


    test_data, valid_data = test_data.split(random_state = random.seed(SEED), split_ratio=0.98)


    LABEL.build_vocab(train_data)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    source_iterator, target_iterator,valid_iterator = data.BucketIterator.splits(
        (train_data, test_data, valid_data), 
        batch_size = BATCH_SIZE, 
        sort=False,
        shuffle = True,
        device = device)

    return source_iterator, target_iterator, valid_iterator, TEXT


## Initialize AAN model.

Two version of AAN you can choose: "AAN" and its adversarial version "AAN-A". set *aan_version*='AAN' or 'AAN-A' to select different version.

For AAN, you is required to set a hyperparameter *MU*, default: *0.1*. 
For AAN-A, *MU* is an invalid parameter. 

In [10]:
from model.models import  AANBertGRU
from model.criterion import MMD_loss
from transformers import BertModel
import torch.optim as optim
import torch.nn as nn
 
aan_version='AAN-A'
MU = 0.1

dataset = ['book.json','cd.json','elec.json','kitchen.json']
source_file =dataset[0]
target_file = dataset[1]

source_iterator, target_iterator, valid_iterator, TEXT = get_iterator_feature(source_file, target_file, BATCH_SIZE=128)


HIDDEN_DIM = 256
OUTPUT_DIM = 2
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.25
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



bert = BertModel.from_pretrained('bert-base-uncased')
model = AANBertGRU(bert, HIDDEN_DIM,OUTPUT_DIM, N_LAYERS, BIDIRECTIONAL, DROPOUT, aan_version)

### freeze the Bert model.
for name, param in model.named_parameters():                
    if name.startswith('bert'):
        param.requires_grad = False

if aan_version == 'AAN':
    optimizer_task = optim.Adam(model.parameters())
else:
    optimizer_task = optim.Adam([{'params':model.extractor.parameters()},{"params":model.rnn.parameters()},{'params':model.predictor.parameters()},{'params':model.bert.parameters()}])
    optimizer_kernel = optim.Adam([{'params':model.mmd_linear.parameters()},{'params':model.cmmd_linear.parameters()}])


criterion = nn.CrossEntropyLoss()
model = model.to(device)
criterion = criterion.to(device)
mmd_loss = MMD_loss(kernel_type='mmd', kernel_mul=2.0, kernel_num=5)
cmmd_loss = MMD_loss(kernel_type='cmmd', kernel_mul=2.0, kernel_num=5,eplison=0.00001)


### Training AAN (AAN-A) models.

In [11]:
import time
from model.tools import train_adverisal, train_normal, epoch_time, evaluate


N_EPOCHS = 10
best_loss = 100.0
best_epoch = 0

for epoch in range(N_EPOCHS):

    start_time = time.time()
    if aan_version == 'AAN-A':
        train_loss = train_adverisal(model,source_iterator,target_iterator,optimizer_task,optimizer_kernel,criterion,mmd_loss,cmmd_loss)
    else:
        train_loss = train_normal(model,source_iterator,target_iterator,optimizer_task,criterion,mmd_loss,cmmd_loss,MU)

    eval_acc, eval_loss = evaluate(model, valid_iterator, criterion)
    if eval_loss < best_loss:
        best_loss = eval_loss
        best_epoch = epoch
        torch.save(model.state_dict(),'bert-aan-model.pt')

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s |Best Epoch:{best_epoch}',flush=True)
    print(f'\tTrain Loss: {train_loss:.3f}|Valid Acc: {eval_acc:.3f}',flush=True)

Epoch: 01 | Epoch Time: 16m 50s |Best Epoch:0
	Train Loss: -0.168|Valid Acc: 0.870
Epoch: 02 | Epoch Time: 16m 55s |Best Epoch:1
	Train Loss: -0.139|Valid Acc: 0.912
Epoch: 03 | Epoch Time: 16m 50s |Best Epoch:2
	Train Loss: -0.143|Valid Acc: 0.932
Epoch: 04 | Epoch Time: 16m 55s |Best Epoch:3
	Train Loss: -0.130|Valid Acc: 0.932
Epoch: 05 | Epoch Time: 16m 49s |Best Epoch:3
	Train Loss: -0.131|Valid Acc: 0.925
Epoch: 06 | Epoch Time: 16m 50s |Best Epoch:3
	Train Loss: -0.133|Valid Acc: 0.902
Epoch: 07 | Epoch Time: 16m 48s |Best Epoch:3
	Train Loss: -0.136|Valid Acc: 0.917
Epoch: 08 | Epoch Time: 16m 50s |Best Epoch:3
	Train Loss: -0.144|Valid Acc: 0.930
Epoch: 09 | Epoch Time: 16m 47s |Best Epoch:3
	Train Loss: -0.131|Valid Acc: 0.925
Epoch: 10 | Epoch Time: 16m 48s |Best Epoch:3
	Train Loss: -0.132|Valid Acc: 0.917


### Test AAN models.

In [13]:
model.load_state_dict(torch.load('bert-aan-model.pt'))
eval_acc, eval_loss = evaluate(model,target_iterator,criterion)
print('from %s to %s, acc is %f'%(source_file,target_file, eval_acc),flush=True)

from book.json to cd.json, acc is 0.929082
