## Training

In [None]:
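# decomposition_utils: character decomposition utility functions
# models: model classes used below (CustomPooledModel, LSTMClassifier) plus train_loop and test_loop
# NOTE: several names used later without an explicit import (e.g. torch, np, DataLoader,
# ComponentDataset) are presumably re-exported by these star imports — prefer explicit imports.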
from decomposition_utils import *
from models import *
from data_utils import load_livedoor, load_wikipedia, write_pickle

from sklearn.model_selection import train_test_split
from torch.optim import AdamW
from transformers import AutoTokenizer, BertModel, BertConfig
from torch.optim.lr_scheduler import ReduceLROnPlateau
import gc
import os

In [23]:
# Parameters
N_EPOCHS = 1
LR = 1e-5
PATIENCE = 2
BATCH_SIZE = 2

# Type of models
pooled = 1  # if 1, pooled; if 0, unpooled
subcomponent = 2  # if 0, radical; if 1, subcomponent; if 2, glyph; if 3 or other number, baseline
frozen = 0  # if 1, frozen weights; if 0, unfrozen
livedoor = 1  # if 1, load livedoor data; if 0, load wikipedia data

# Number of classification labels. It depends only on the dataset, not on the model type,
# so pooled and unpooled models must agree (livedoor: 9 classes, wikipedia: 12 classes).
n_labels = 9 if livedoor else 12

# Filename suffix for each component type; baseline (any other value) adds no suffix.
POOLED_COMPONENT_SUFFIX = {0: '-JWE-radical', 1: '-JWE-subcomponent', 2: '-glyph'}
UNPOOLED_COMPONENT_SUFFIX = {0: '_radical', 1: '_subcomponent', 2: '_glyph'}

# Define filename to load from saved model
if pooled:
    fname = ('bert-base-japanese'
             + ('-livedoor' if livedoor else '-wikipedia')
             + POOLED_COMPONENT_SUFFIX.get(subcomponent, '')
             + ('-frozen' if frozen else ''))
else:
    # NOTE(review): `frozen` is not encoded in unpooled filenames, so frozen and unfrozen
    # unpooled runs read/write the same file.
    fname = ('unpooled_jbert_cbert'
             + UNPOOLED_COMPONENT_SUFFIX.get(subcomponent, '')
             + ('_livedoor.pk' if livedoor else '_wiki.pk'))

print(fname)

# If a trained model already exists at this path, the training cell loads it instead of retraining.
file_exists = os.path.exists(os.path.join(os.getcwd(), "data", "models", fname))

print('Trained model already exists!' if file_exists else 'Trained model does not exist!')

bert-base-japanese-livedoor-glyph
Trained model does not exist!


### Subcomponent / radical mapping definition & load JWE embeddings (JWE models only)

In [24]:
# Load JWE component embeddings and the character -> component mapping.
# Glyph models (subcomponent == 2) use ChineseBERT instead and skip this cell.
if subcomponent != 2:
    # Explicit import: this cell runs before the cell that imports numpy.
    import numpy as np

    # subcomponent -> (pretrained component-vector file, character-to-component mapping file)
    jwe_files = {
        1: ("subcomponent_comp_vec", "char2comp.txt"),
        0: ("radical_comp_vec", "char2radical.txt"),
    }
    if subcomponent not in jwe_files:
        # Fail clearly instead of a NameError (or silently reusing a stale path from the kernel).
        raise ValueError(f"No JWE embedding files are configured for subcomponent={subcomponent!r}")
    comp_vec_name, char2comp_name = jwe_files[subcomponent]
    comp2vec_filepath = os.getcwd() + "/data/JWE-pretrained/" + comp_vec_name
    char2comp_filepath = os.getcwd() + "/data/JWE/subcharacter/" + char2comp_name

    comp_vocab_size, comp_embedding_size, comp2id, comp_embeddings, pad_idx, unk_idx = parse_comp2vec(comp2vec_filepath)
    char2id, comp_list = parse_char2comp(char2comp_filepath)

    # add UNK embeddings: unpooled models get an extra all-zero row appended to the embedding table
    if not pooled:
        sub_embs_size = comp_embeddings.shape[-1]
        # Match the embeddings' dtype so concatenation does not promote them
        # (an integer zero row turns float32 embeddings into float64).
        unk_sub_emb = np.zeros((1, sub_embs_size), dtype=np.asarray(comp_embeddings).dtype)
        SUBCOMPONENT_EMBEDDINGS_EXT = np.concatenate([comp_embeddings, unk_sub_emb], axis=0)

    print(f"Component vocab size:\t\t{comp_vocab_size}")
    print(f"Component embedding size:\t{comp_embedding_size}")
    print(f"Example components:\t\t{dict(list(comp2id.items())[0:5])}")
    print(f"Component embeddings shape:\t{comp_embeddings.shape}")
    print(f"UNK index:\t\t\t{unk_idx}")
    print(f"PAD index:\t\t\t{pad_idx}")


### Load ChineseBERT glyph embeddings & tokenizer (glyph models only)

In [25]:
masterdir = os.getcwd()
if subcomponent == 2:
    # The ChineseBERT sources are imported by temporarily changing the working directory
    # (presumably they are only importable relative to the cwd). Always restore the cwd, even
    # if an import fails, because later cells build file paths from os.getcwd().
    # NOTE(review): `datasets` here is ChineseBERT's local package; it may be shadowed if the
    # Hugging Face `datasets` package was already imported in this kernel.
    try:
        os.chdir(masterdir + '/data/ChineseBert/')
        from datasets.bert_dataset import BertDataset
        os.chdir(masterdir + '/data/ChineseBert/models')
        from modeling_glycebert import GlyceBertModel
    finally:
        os.chdir(masterdir)

    CBERT_PATH = masterdir + '/data/ChineseBERT-large'

    # ChineseBERT tokenizer and model (the model supplies the glyph embeddings)
    chinese_bert_tokenizer = BertDataset(CBERT_PATH)
    chinese_bert = GlyceBertModel.from_pretrained(CBERT_PATH)
    

Some weights of the model checkpoint at /Users/zoe/Desktop/CS287/subcharacter-transfer-learning/data/ChineseBERT-large were not used when initializing GlyceBertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing GlyceBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GlyceBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Data load and split:

In [26]:
if livedoor:
    X_train, X_val, X_test, y_train, y_val, y_test = load_livedoor()
else:
    X_train, X_val, X_test, y_train, y_val, y_test = load_wikipedia()

In [27]:
# Debug-only subsampling: keep just the first `iend` examples of every split so the whole
# pipeline runs quickly. All results reported below are on this subset; set iend = None
# (slicing with [:None] keeps everything) for a full run.
iend = 50
X_train, X_val, X_test, y_train, y_val, y_test = (
    split[:iend] for split in (X_train, X_val, X_test, y_train, y_val, y_test)
)

### Data tokenization and DataLoader definition:

In [28]:
# Pooled model tokenizer (character-level Japanese BERT). `padding=True` pads each split to its
# own longest example, so train/val/test encodings can have different lengths.
MAX_SEQ_LENGTH = 512
tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-char-v2")
train_encodings = tokenizer(X_train, truncation=True, padding=True, max_length=MAX_SEQ_LENGTH)
val_encodings = tokenizer(X_val, truncation=True, padding=True, max_length=MAX_SEQ_LENGTH)
test_encodings = tokenizer(X_test, truncation=True, padding=True, max_length=MAX_SEQ_LENGTH)

# Convert texts to component IDs / component features for the configured model type.
if pooled:
    if subcomponent != 2:
        # JWE component IDs; val/test are padded to the train split's max decomposition length.
        train_subcomponent_ids, max_decomposition_length = decompose(X_train, comp_list, comp2id, char2id,
                                                                      unk_idx, pad_idx,
                                                                      pad_length=None)
        val_subcomponent_ids, _ = decompose(X_val, comp_list, comp2id, char2id, unk_idx, pad_idx,
                                             pad_length=max_decomposition_length)
        test_subcomponent_ids, _ = decompose(X_test, comp_list, comp2id, char2id, unk_idx, pad_idx,
                                              pad_length=max_decomposition_length)
    else:  # glyph embeddings
        train_subcomponent_ids = get_glyph_embeddings(X_train, chinese_bert, chinese_bert_tokenizer)
        val_subcomponent_ids = get_glyph_embeddings(X_val, chinese_bert, chinese_bert_tokenizer)
        test_subcomponent_ids = get_glyph_embeddings(X_test, chinese_bert, chinese_bert_tokenizer)

    # Pooled models are given no per-example sequence lengths.
    train_seq_lengths = None
    val_seq_lengths = None
    test_seq_lengths = None

else:  # unpooled
    if subcomponent != 2:  # JWE
        train_subcomponent_ids = [text2subcomponent(i, comp_list, comp2id, char2id, unk_idx, pooled,
                                                    tokenizer=tokenizer) for i in X_train]
        val_subcomponent_ids = [text2subcomponent(i, comp_list, comp2id, char2id, unk_idx, pooled,
                                                  tokenizer=tokenizer) for i in X_val]
        test_subcomponent_ids = [text2subcomponent(i, comp_list, comp2id, char2id, unk_idx, pooled,
                                                   tokenizer=tokenizer) for i in X_test]

        # sequence lengths for unpooled models
        train_seq_lengths = [len(ids) for ids in train_subcomponent_ids]
        val_seq_lengths = [len(ids) for ids in val_subcomponent_ids]
        test_seq_lengths = [len(ids) for ids in test_subcomponent_ids]

        # Padded token length of each split's encodings (same for every example within a split).
        TRAIN_SEQ_LEN = len(train_encodings['input_ids'][0])
        VAL_SEQ_LEN = len(val_encodings['input_ids'][0])
        TEST_SEQ_LEN = len(test_encodings['input_ids'][0])

        # Replace component IDs with their embedding vectors, padded to each split's token length.
        train_subcomponent_ids = subcomponent2emb(train_subcomponent_ids, SUBCOMPONENT_EMBEDDINGS_EXT, padding=True, seq_length=TRAIN_SEQ_LEN)
        val_subcomponent_ids = subcomponent2emb(val_subcomponent_ids, SUBCOMPONENT_EMBEDDINGS_EXT, padding=True, seq_length=VAL_SEQ_LEN)
        test_subcomponent_ids = subcomponent2emb(test_subcomponent_ids, SUBCOMPONENT_EMBEDDINGS_EXT, padding=True, seq_length=TEST_SEQ_LEN)

    else:
        train_seq_lengths, train_subcomponent_ids = text2glyph(X_train, chinese_bert, chinese_bert_tokenizer, tokenizer)
        val_seq_lengths, val_subcomponent_ids = text2glyph(X_val, chinese_bert, chinese_bert_tokenizer, tokenizer)
        test_seq_lengths, test_subcomponent_ids = text2glyph(X_test, chinese_bert, chinese_bert_tokenizer, tokenizer)

        # NOTE(review): unlike the JWE branch, these take len() of the whole glyph feature
        # container (presumably the number of examples), not a token sequence length — confirm intent.
        TRAIN_SEQ_LEN = len(train_subcomponent_ids)
        VAL_SEQ_LEN = len(val_subcomponent_ids)
        TEST_SEQ_LEN = len(test_subcomponent_ids)


# Initialize Dataset
train_dataset = ComponentDataset(train_encodings, y_train, train_subcomponent_ids, pooled, train_seq_lengths)
val_dataset = ComponentDataset(val_encodings, y_val, val_subcomponent_ids, pooled, val_seq_lengths)
test_dataset = ComponentDataset(test_encodings, y_test, test_subcomponent_ids, pooled, test_seq_lengths)

# Initialize DataLoader. Only training benefits from shuffling; validation is reduced to an
# aggregate loss/accuracy, so a fixed order keeps evaluation deterministic.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [29]:
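# Debug: spot-check the JWE decomposition of the first training text and of one known
# character (格). Uncomment to run; needs a JWE model (subcomponent 0 or 1) so that
# char2id / comp_list / comp2id are defined.
# print(X_train[0])
# print(char2id['格'])
# print(comp_list[char2id['格']])
# print(comp2id[comp_list[char2id['格']][0]])
# print(text2subcomponent(X_train[0], comp_list, comp2id, char2id, unk_idx, pooled, tokenizer=tokenizer)[:10])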

### Model training (or loading, if a trained model already exists)

In [30]:
# Release memory held by objects from earlier runs before building the model.
gc.collect()
cuda_available = torch.cuda.is_available()
if cuda_available:
    # Return cached, unused GPU blocks to the driver (a no-op when CUDA is not initialized).
    torch.cuda.empty_cache()

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if cuda_available else 'cpu')
print(device)


cpu


In [31]:
import pickle
# Build the classifier for the configured model type, then either load previously trained
# weights (when `file_exists`) or train it from scratch and save it.

JBERT_CHECKPOINT = 'cl-tohoku/bert-base-japanese-char-v2'

if subcomponent == 2:
    # Glyph models take their component embeddings from ChineseBERT's glyph embedding module.
    pad_idx = 0
    comp_embeddings = chinese_bert.embeddings.glyph_embeddings
    if not pooled:
        # Width of the precomputed glyph features from the tokenization cell.
        comp_embedding_size = train_subcomponent_ids.shape[-1]

if pooled:
    # BertModel: from transformer docs:
    # "bare Bert Model transformer outputting raw hidden-states without any specific head on top"
    bert = BertModel.from_pretrained(JBERT_CHECKPOINT)
    model = CustomPooledModel(bert,
                              embeddings=comp_embeddings,
                              num_labels=n_labels,
                              component_pad_idx=pad_idx,
                              subcomponent=subcomponent)
else:
    # The LSTM input is the Japanese BERT hidden state concatenated with the component features.
    JAPBERT_EMB_SIZE = 768
    LSTM_INPUT_SIZE = JAPBERT_EMB_SIZE + comp_embedding_size
    hidden_size = 200
    model = LSTMClassifier(lstm_input_size=LSTM_INPUT_SIZE,
                           hidden_size=hidden_size,
                           output_size=n_labels,
                           padding_idx=pad_idx,
                           bertconfig=JBERT_CHECKPOINT)

# Move every model to `device`. The unpooled model used to stay on the CPU, which breaks
# GPU runs because the train/eval code moves each batch to `device`.
model = model.to(device)

model_path = os.path.join(os.getcwd(), "data", "models", fname)

if file_exists:
    # NOTE(review): pooled weights are read from `<fname>/pytorch_model.bin`, but the training
    # branch below saves pooled models to `<fname>.pk` with write_pickle, so the `file_exists`
    # check never finds them on a later run. Unpooled models are saved with write_pickle but
    # loaded with torch.load + load_state_dict. Confirm the saved and loaded formats agree.
    weights_path = os.path.join(model_path, 'pytorch_model.bin') if pooled else model_path
    model.load_state_dict(torch.load(weights_path, map_location=device))

else:
    # Freeze component embedding weights
    if frozen:
        for param in model.subcomponent_embedding.parameters():
            param.requires_grad = False

    optimizer = AdamW(model.parameters(), lr=LR)
    # NOTE(review): the scheduler is also passed into test_loop; if test_loop steps it as well,
    # the LR is reduced on a doubled schedule. Confirm only one caller steps it.
    lr_scheduler = ReduceLROnPlateau(optimizer, 'min', patience=PATIENCE, verbose=True)

    # `test_*` metrics below are computed on the validation split (val_loader).
    train_losses = []; train_accuracies = []
    test_losses = []; test_accuracies = []

    for e in range(N_EPOCHS):
        print(f"Epoch {e+1}\n-------------------------------")
        train_loss, train_acc = train_loop(train_loader, model, optimizer, device, pooled=pooled)
        test_loss, test_acc = test_loop(val_loader, model, lr_scheduler, device, pooled=pooled)
        lr_scheduler.step(test_loss)
        train_losses.append(train_loss); train_accuracies.append(train_acc)
        test_losses.append(test_loss); test_accuracies.append(test_acc)

    # save model (pooled: `<fname>.pk`; unpooled: fname already ends in `.pk`)
    write_pickle(model_path + '.pk' if pooled else model_path, model)


Some weights of the model checkpoint at cl-tohoku/bert-base-japanese-char-v2 were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
  item['subcomponent_ids'] = torch.tensor(self.subcomponent_ids[idx])


Epoch 1
-------------------------------


100%|██████████| 25/25 [00:41<00:00,  1.66s/it]
  0%|          | 0/25 [00:00<?, ?it/s]

Train Error: 
 Accuracy: 14.0%, Avg loss: 2.220189 



100%|██████████| 25/25 [00:07<00:00,  3.34it/s]


Test Error: 
 Accuracy: 26.0%, Avg loss: 2.115010 



In [32]:
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report

# Evaluate the trained model on the held-out test split.
test_predictions = []; test_labels = []
model.eval()
with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        component_ids = batch['subcomponent_ids'].to(device)
        if not pooled:
            # Unpooled models take per-example sequence lengths and precomputed component embeddings.
            lens_ = batch['seq_lengths']
            outputs = model(input_ids, attention_mask=attention_mask, lens=lens_,
                            labels=labels, comp_embeddings=component_ids)
        else:
            outputs = model(input_ids, attention_mask=attention_mask,
                            labels=labels, subcomponent_ids=component_ids, device=device)

        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1).cpu().numpy()
        test_predictions.extend(predictions)
        # Store plain numpy values (not 0-d tensors) so both lists convert to int arrays.
        test_labels.extend(labels.cpu().numpy())

test_predictions = np.array(test_predictions)
test_labels = np.array(test_labels)
print(f"Test accuracy: {np.mean(test_predictions == test_labels)}")

# classification metrics on test dataset; classes that are never predicted get precision 0
# (the same value the default gives) without emitting UndefinedMetricWarning.
print(classification_report(test_labels, test_predictions, zero_division=0))

100%|██████████| 25/25 [00:09<00:00,  2.77it/s]

Test accuracy: 0.16
              precision    recall  f1-score   support

           0       0.75      0.33      0.46         9
           1       0.00      0.00      0.00         4
           2       0.00      0.00      0.00         2
           3       0.00      0.00      0.00         4
           4       0.00      0.00      0.00         3
           5       0.00      0.00      0.00         3
           6       0.00      0.00      0.00         8
           7       0.11      1.00      0.20         5
           8       0.00      0.00      0.00        12

    accuracy                           0.16        50
   macro avg       0.10      0.15      0.07        50
weighted avg       0.15      0.16      0.10        50




  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
