In [None]:
import numpy as np
import torch
import csv
import time
import datetime
import os

from transformers import AutoTokenizer, BertForSequenceClassification
from transformers import AdamW, get_linear_schedule_with_warmup
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.nn.utils import clip_grad_norm_
from sklearn.metrics import accuracy_score, classification_report


# Hyperparameters
EPOCHS = 100
BATCH_SIZE = 128
PRINT_TIME_PER_STEP = 40  # print a progress line every this many training batches
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 1e-2
EPSILON = 1e-8
SEED = 42  # RNG seed for numpy and torch, so repeated runs are reproducible

evaluation_mode = "new"
# evaluation_mode: how evaluate() post-processes the classifier's (seen-class) prediction
# using the similarity between the sample's [CLS] vector and each class description's:
#   'normal':       no semantic re-assignment; keep the classifier's prediction
#   'unseen_first': if the most similar description belongs to an unseen class, assign that class
#   'avg':          if the predicted class's similarity is below that class's average training
#                   similarity, assign the most similar unseen class
#   'new':          if the prediction differs from the most similar class and its similarity is
#                   below its class average, assign the most similar class when it is unseen or
#                   reaches its own average; otherwise assign the most similar unseen class

#paths
train_class_path = "./data/SNIPS/class/train_classes.txt"
test_class_path = "./data/SNIPS/class/test_classes.txt"
train_data_path = "./data/SNIPS/sample/train_split.csv"
test_data_path = "./data/SNIPS/sample/test_split.csv"
train_description_path = "./data/SNIPS/description/train_description_2.txt"
test_description_path = "./data/SNIPS/description/test_description_2.txt"

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Seed the RNGs after the CUDA environment is configured; torch.manual_seed seeds all devices.
np.random.seed(SEED)
torch.manual_seed(SEED)

## Process train classes

In [None]:
# process train classes: one class name per line, trailing newline removed
with open(train_class_path, 'r') as class_file:
    train_class_list = [raw_line.split('\n')[0] for raw_line in class_file.readlines()]

# map every class name to its line position in the class file
train_class_index = {class_name: position for position, class_name in enumerate(train_class_list)}

n_seen = len(train_class_list)

In [None]:
for collection, caption in ((train_class_list, "Train Class list:\n"), (train_class_index, "Train Class index:\n")):
    print(len(collection), caption, collection)

## Process test classes

In [None]:
# process test classes: one class name per line, trailing newline removed
with open(test_class_path, 'r') as class_file:
    test_class_list = [raw_line.split('\n')[0] for raw_line in class_file.readlines()]

# map every class name to its line position in the class file
test_class_index = {class_name: position for position, class_name in enumerate(test_class_list)}

n_all = len(test_class_list)
n_unseen = n_all - n_seen

In [None]:
for collection, caption in ((test_class_list, "Test Class list:\n"), (test_class_index, "Test Class index:\n")):
    print(len(collection), caption, collection)

In [None]:
for class_count, template in ((n_seen, "%d Seen Classes"), (n_unseen, "%d Unseen Classes")):
    print(template % class_count)

## Process train data

In [None]:
# process train data
def onehot(idx, class_index):
    """Build a one-hot float tensor.

    Args:
        idx: position to set to 1.0.
        class_index: any sized collection; only its length (the vector size) is used.

    Returns:
        A float tensor of length ``len(class_index)`` with 1.0 at ``idx`` and 0.0 elsewhere.
    """
    vector = torch.zeros(len(class_index))
    vector[idx] = 1.
    return vector

# Each CSV row is (label, sentence); no header row is skipped.
train_sentence_list = []
train_label_list = []
with open(train_data_path, newline='', encoding='windows-1252') as csv_file:
    for row in csv.reader(csv_file):
        train_label_list.append(row[0])
        train_sentence_list.append(row[1])

# Convert label names to class indices (and one-hot vectors) via train_class_index.
train_onehot_label = []
target = []
for label_name in train_label_list:
    class_id = train_class_index[label_name]
    train_onehot_label.append(onehot(class_id, train_class_index))
    target.append(class_id)

train_target = torch.tensor(target)

In [None]:
train_sentence_count = len(train_sentence_list)
print(train_sentence_count, "Train Sentences\n")
train_target_count = len(train_target)
print(train_target_count, "Train Targets:\n", train_target)

## Process test data

In [None]:
# process test data
test_sentence_list = []
test_label_list = []
with open(test_data_path, newline='',encoding='windows-1252') as datas:
    sentences = csv.reader(datas)
    #headers = next(datas)

    for sen in sentences:
        test_label_list.append(sen[0])
        test_sentence_list.append(sen[1])

test_onehot_label = []
target = []

for i in range(len(test_sentence_list)):
    test_onehot_label.append(onehot(test_class_index[test_label_list[i]], test_class_index))
    target.append(test_class_index[test_label_list[i]])

test_target = torch.tensor(target)

In [None]:
test_sentence_count = len(test_sentence_list)
print(test_sentence_count, "Test Sentences\n")
test_target_count = len(test_target)
print(test_target_count, "Test Targets:\n", test_target)

In [None]:
# Test targets with index < n_seen belong to seen classes; the rest are unseen.
seen_samples = sum(1 for targets in test_target if targets < n_seen)
unseen_samples = len(test_target) - seen_samples

print("%d Seen Samples" % seen_samples)
print("%d Unseen Samples" % unseen_samples)

## Process train description

In [None]:
# process train description: one description per line, trailing newline removed
with open(train_description_path, 'r') as description_file:
    train_description_list = [entry.split('\n')[0] for entry in description_file.readlines()]

In [None]:
train_description_count = len(train_description_list)
print(train_description_count, "Train Descriptions:\n", train_description_list)

## Process test description

In [None]:
# process test description: one description per line, trailing newline removed
with open(test_description_path, 'r') as description_file:
    test_description_list = [entry.split('\n')[0] for entry in description_file.readlines()]

In [None]:
test_description_count = len(test_description_list)
print(test_description_count, "Test Descriptions:\n", test_description_list)

## BERT preprocess

In [None]:
# BERT preprocess: tokenize sentences and class descriptions into padded tensors.
# truncation=True caps each sequence at the tokenizer's model_max_length, so an unusually
# long input cannot exceed BERT's position-embedding range; shorter inputs are unaffected.
BERT_tokenizer = AutoTokenizer.from_pretrained('./bert-base-uncased', local_files_only=True)

train_encoded_dict = BERT_tokenizer(train_sentence_list, padding = True, truncation = True, return_tensors = 'pt')
train_input_ids = train_encoded_dict['input_ids']
train_token_type_ids = train_encoded_dict['token_type_ids'] # don't need this
train_attention_mask = train_encoded_dict['attention_mask']

test_encoded_dict = BERT_tokenizer(test_sentence_list, padding = True, truncation = True, return_tensors = 'pt')
test_input_ids = test_encoded_dict['input_ids']
test_token_type_ids = test_encoded_dict['token_type_ids'] # don't need this
test_attention_mask = test_encoded_dict['attention_mask']

# Only the test descriptions are encoded; train() and evaluate() compare samples against
# all of them. NOTE(review): presumably one line per test class (n_all), seen classes first — confirm.
description_encoded_dict = BERT_tokenizer(test_description_list, padding = True, truncation = True, return_tensors = 'pt')
description_input_ids = description_encoded_dict['input_ids']
description_token_type_ids = description_encoded_dict['token_type_ids'] # don't need this
description_attention_mask = description_encoded_dict['attention_mask']

# Create DataLoader (shuffled for training, in file order for testing)
train_data = TensorDataset(train_input_ids, train_attention_mask, train_target)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler = train_sampler, batch_size = BATCH_SIZE)

test_data = TensorDataset(test_input_ids, test_attention_mask, test_target)
test_sampler = SequentialSampler(test_data)
test_dataloader = DataLoader(test_data, sampler = test_sampler, batch_size = BATCH_SIZE)

# Sanity check: shapes of the first batch of each loader.
for i, (train, mask, label) in enumerate (train_dataloader):
    print(train.shape, mask.shape, label.shape)
    break
print('Train dataloader length:', len(train_dataloader))
for i, (test, mask, label) in enumerate (test_dataloader):
    print(test.shape, mask.shape, label.shape)
    break
print('Test dataloader length:', len(test_dataloader))

print(description_input_ids.shape, description_token_type_ids.shape, description_attention_mask.shape)

## BERT sequence classification model

In [None]:
# Create model. The classification head is trained on train_target, whose values index
# train_class_index, so it needs exactly n_seen outputs (previously hard-coded to 5).
BERT_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=n_seen, output_hidden_states=True)
# NOTE(review): the tokenizer is loaded from './bert-base-uncased' (local files only) while the
# model is resolved by the name "bert-base-uncased"; confirm both refer to the same checkpoint.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
BERT_model.to(device)

# Define optimizer: Exclude weight decay for bias and LayerNorm.weight
no_decay = ['bias', 'LayerNorm.weight']
optimizer_weight_decay = [
    {'params': [p for n, p in BERT_model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': WEIGHT_DECAY},
    {'params': [p for n, p in BERT_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
# NOTE(review): transformers' AdamW is reportedly deprecated in newer releases in favour of
# torch.optim.AdamW; the two apply weight decay / eps slightly differently, so switching may
# change training dynamics — verify before migrating.
optimizer = AdamW(optimizer_weight_decay, lr = LEARNING_RATE, eps = EPSILON)

# learning rate scheduler: no warmup, linear decay over every training step of every epoch
epochs = EPOCHS
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = 0, num_training_steps = total_steps)

## Accuracy and time evaluation

In [None]:
# Accuracy
def cls_acc(preds, labels):
    """Return the fraction of rows whose highest-scoring column equals the label.

    Args:
        preds: score tensor; the max is taken over dim 1 for each row.
        labels: class indices, flattened before the element-wise comparison.

    Returns:
        Python float accuracy in [0, 1].
    """
    predicted = torch.max(preds, dim = 1).indices
    hits = torch.eq(predicted, labels.flatten()).float()
    hit_count = hits.sum().item()
    return hit_count / len(hits)

# time
def format_time(elapsed):
    """Render a duration in seconds, rounded to the nearest second, via str(datetime.timedelta)."""
    seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds = seconds)
    return str(duration)

## Train

In [None]:
# Train model

def train(model, optimizer):
    """Run one training epoch and collect per-class description-similarity thresholds.

    Each training sample's last-layer [CLS] vector is scored against the [CLS]
    vector of every encoded class description (``description_input_ids``), by
    dot product and by cosine similarity. The score at the sample's gold class
    is averaged per seen class; ``evaluate`` uses these averages as thresholds.

    Relies on module-level state: ``device``, ``description_input_ids``,
    ``description_attention_mask``, ``train_dataloader``, ``scheduler`` (stepped
    once per batch), ``n_seen`` and ``PRINT_TIME_PER_STEP``.

    Args:
        model: BertForSequenceClassification created with output_hidden_states=True.
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        (avg_loss, avg_acc, score_threshold, score_threshold_norm): mean batch loss
        and accuracy rounded to 3 decimals, plus two lists of length ``n_seen``
        holding the mean dot-product / cosine score at the gold class.

    Raises:
        ZeroDivisionError: if some seen class has no training sample this epoch.
    """
    t0 = time.time()
    avg_loss, avg_acc = [], []

    score_threshold = [0] * n_seen
    score_threshold_norm = [0] * n_seen
    target_count = [0] * n_seen

    # The description embeddings are only used as fixed reference vectors, so compute them
    # without autograd; otherwise their graph (and activations) stays alive for the whole epoch.
    d_input_ids, d_input_mask = description_input_ids.to(device), description_attention_mask.to(device)
    with torch.no_grad():
        description_output = model(d_input_ids, token_type_ids=None, attention_mask=d_input_mask)
    # Last-layer [CLS] vector of every description: [n_all, 768]
    description_compare = description_output[1][-1][:, 0, :]
    des_len = torch.norm(description_compare, dim=-1)  # [n_all]

    torch.cuda.empty_cache()

    model.train()
    for step, batch in enumerate(train_dataloader):
        if step % PRINT_TIME_PER_STEP == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)
            print('Batch {:>5,} of {:>5,}.  Time: {:}.'.format(step, len(train_dataloader), elapsed))

        b_input_ids, b_input_mask, b_labels = batch[0].long().to(device), batch[1].long().to(device), batch[2].long().to(device)
        output = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)
        loss, logits, hidden_state = output[0], output[1], output[2]

        torch.cuda.empty_cache()

        # Last-layer [CLS] vector of each sample, broadcast against every description.
        input_cls = hidden_state[-1][:, 0, :].detach()  # [batch_size, 768]
        compare_score = torch.sum(input_cls.unsqueeze(1) * description_compare, dim=-1)  # [batch_size, n_all]
        input_len = torch.norm(input_cls, dim=-1, keepdim=True)  # [batch_size, 1]
        compare_score_norm = (compare_score / (input_len * des_len)).cpu()
        compare_score = compare_score.cpu()

        # Move labels to the host once per batch instead of syncing per element.
        for row, label in enumerate(b_labels.tolist()):
            score_threshold[label] += compare_score[row][label]
            score_threshold_norm[label] += compare_score_norm[row][label]
            target_count[label] += 1

        avg_loss.append(loss.item())

        acc = cls_acc(logits, b_labels)
        avg_acc.append(acc)

        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

    avg_loss = round(np.array(avg_loss).mean(),3)
    avg_acc = round(np.array(avg_acc).mean(),3)
    score_threshold = [x/y for x,y in zip(score_threshold, target_count)]
    score_threshold_norm = [x/y for x,y in zip(score_threshold_norm, target_count)]
    return avg_loss, avg_acc, score_threshold, score_threshold_norm

## Evaluate

In [None]:
# Evaluate model
# Supported values of the module-level ``evaluation_mode`` setting.
_EVALUATION_MODES = ('normal', 'unseen_first', 'avg', 'new')


def _mask_seen(scores, n_seen):
    """Return a copy of a 1-D score vector with its first ``n_seen`` (seen-class) entries set to -inf.

    -inf (rather than 0) guarantees that a following argmax lands on an unseen class
    even when every unseen score is negative.
    """
    masked = torch.clone(scores)
    masked[:n_seen] = float('-inf')
    return masked


def _new_mode_reassign(pred, scores, thresholds, n_seen):
    """Apply the 'new' evaluation rule to one test sample.

    Args:
        pred: the classifier's prediction (a seen-class index) as an int.
        scores: 1-D similarity scores of this sample against every class description.
        thresholds: per-seen-class score thresholds produced by ``train``.
        n_seen: number of seen classes (indices ``< n_seen`` are seen).

    Returns:
        (label, branch), where ``branch`` identifies the rule that fired:
          0: ``pred`` already equals the similarity argmax -> keep it
          1: score of ``pred`` is not below its threshold -> keep it
          2: similarity argmax is an unseen class -> take it
          3: similarity argmax is a seen class whose score reaches its threshold -> take it
          4: otherwise -> take the most similar unseen class
    """
    argmax_class = int(torch.argmax(scores))
    if pred == argmax_class:
        return pred, 0
    if not scores[pred] < thresholds[pred]:
        return pred, 1
    if argmax_class >= n_seen:
        return argmax_class, 2
    if scores[argmax_class] >= thresholds[argmax_class]:
        return argmax_class, 3
    return int(torch.argmax(_mask_seen(scores, n_seen))), 4


def evaluate(model, score_threshold, score_threshold_norm):
    """Predict the test set, optionally re-assign predictions via description similarity, and report.

    The classifier head only predicts seen classes. Depending on the module-level
    ``evaluation_mode`` ('normal', 'unseen_first', 'avg', 'new'), each prediction may be
    replaced using the dot-product (``test_pred``) or cosine (``test_pred_norm``) similarity
    between the sample's last-layer [CLS] vector and each class description's [CLS] vector.
    Prints a classification report, the per-branch counters of the 'new' rule, and the
    cosine-based accuracy.

    Args:
        model: BertForSequenceClassification created with output_hidden_states=True.
        score_threshold: per-seen-class mean dot-product scores returned by ``train``.
        score_threshold_norm: per-seen-class mean cosine scores returned by ``train``.

    Returns:
        Accuracy of the dot-product-based predictions against ``test_target``.

    Raises:
        ValueError: if ``evaluation_mode`` is not one of the supported modes.
    """
    if evaluation_mode not in _EVALUATION_MODES:
        raise ValueError("unknown evaluation_mode {!r}; expected one of {}".format(evaluation_mode, _EVALUATION_MODES))

    model.eval()
    test_pred = torch.LongTensor([])
    compare_score = torch.DoubleTensor([])
    compare_score_norm = torch.DoubleTensor([])
    with torch.no_grad():
        d_input_ids, d_input_mask = description_input_ids.to(device), description_attention_mask.to(device)
        description_output = model(d_input_ids, token_type_ids=None, attention_mask=d_input_mask)
        # Last-layer [CLS] vector of every class description: [n_all, 768]
        description_compare = description_output[1][-1][:, 0, :]
        des_len = torch.norm(description_compare, dim=-1)  # [n_all]

        torch.cuda.empty_cache()

        for batch in test_dataloader:
            b_input_ids, b_input_mask = batch[0].long().to(device), batch[1].long().to(device)
            output = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)

            torch.cuda.empty_cache()

            # output[0]: logits over the seen classes
            # output[1]: hidden states (13-length tuple: embedding output + 12 layers)
            pred = torch.argmax(output[0], dim=1).cpu()
            test_pred = torch.cat((test_pred, pred))

            # Last-layer [CLS] vector of each sample, broadcast against every description.
            input_cls = output[1][-1][:, 0, :]  # [batch_size, 768]
            compare_score_batch = torch.sum(input_cls.unsqueeze(1) * description_compare, dim=-1)  # [batch_size, n_all]
            input_len = torch.norm(input_cls, dim=-1, keepdim=True)  # [batch_size, 1]
            compare_score_norm_batch = (compare_score_batch / (input_len * des_len)).cpu()
            compare_score_batch = compare_score_batch.cpu()

            compare_score = torch.cat((compare_score, compare_score_batch))
            compare_score_norm = torch.cat((compare_score_norm, compare_score_norm_batch))

    test_pred_norm = torch.clone(test_pred)

    # Per-branch hit/miss counters of the 'new' rule (dot-product path only).
    correct_point = [0, 0, 0, 0, 0]
    wrong_point = [0, 0, 0, 0, 0]

    for i in range(len(test_target)):
        # Row i is only modified in iteration i, so both entries still hold the classifier output.
        pred, pred_norm = int(test_pred[i]), int(test_pred_norm[i])
        scores, scores_norm = compare_score[i], compare_score_norm[i]

        if evaluation_mode == 'unseen_first':
            argmax_class = int(torch.argmax(scores))
            argmax_class_norm = int(torch.argmax(scores_norm))
            if argmax_class >= n_seen:
                test_pred[i] = argmax_class
            if argmax_class_norm >= n_seen:
                test_pred_norm[i] = argmax_class_norm
        elif evaluation_mode == 'avg':
            if scores[pred] < score_threshold[pred]:
                test_pred[i] = torch.argmax(_mask_seen(scores, n_seen))
            if scores_norm[pred_norm] < score_threshold_norm[pred_norm]:
                test_pred_norm[i] = torch.argmax(_mask_seen(scores_norm, n_seen))
        elif evaluation_mode == 'new':
            new_pred, branch = _new_mode_reassign(pred, scores, score_threshold, n_seen)
            test_pred[i] = new_pred
            if new_pred == int(test_target[i]):
                correct_point[branch] += 1
            else:
                wrong_point[branch] += 1
            test_pred_norm[i] = _new_mode_reassign(pred_norm, scores_norm, score_threshold_norm, n_seen)[0]
        # 'normal': predictions are left unchanged.

    print(classification_report(test_target, test_pred, digits=4, zero_division=1))

    print("Correct Point:", correct_point)
    print("Wrong Point:", wrong_point, "\n")

    test_acc = accuracy_score(test_target, test_pred)
    norm_acc = accuracy_score(test_target, test_pred_norm)
    print("norm_acc = {}".format(norm_acc))
    return test_acc

## Main

In [None]:
# Main: alternate one training epoch with one evaluation pass, tracking the best test accuracy.
best_acc = 0
for epoch in range(epochs):
    epoch_results = train(BERT_model, optimizer)
    train_loss, train_acc, score_threshold, score_threshold_norm = epoch_results
    print('epoch = {}, train_acc = {}, train_loss = {}\n'.format(epoch, train_acc, train_loss))
    test_acc = evaluate(BERT_model, score_threshold, score_threshold_norm)
    best_acc = max(best_acc, test_acc)
    print('epoch = {}, test_acc = {}, best_acc = {}\n'.format(epoch, test_acc, best_acc))
    