In [None]:
import os
import time
import datetime
import csv
# TSNE test
import glob
import matplotlib.pyplot as plt
# TSNE test
import torch.nn.functional as F # test

from capsule_network_semantic import CapsuleNetwork
from data_preprocess_semantic import DataPreprocess

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from argparse import ArgumentParser

from sklearn.metrics import accuracy_score, classification_report

from transformers import BertModel, AutoTokenizer
from sentence_transformers import SentenceTransformer

import utils
import log
import collections

from torch.nn.utils import clip_grad_norm_

# Pin the process to one physical GPU and silence the tokenizers fork warning.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def _str2bool(value):
    """Parse a command-line boolean such as "true"/"False"/"1"/"no".

    argparse's ``type=bool`` applies ``bool()`` to the raw string, so every
    non-empty value -- including "False" -- is parsed as True.  This parser
    makes ``--use_gpu False`` actually disable the GPU.

    Raises:
        ValueError: if the text is not a recognised boolean spelling
            (argparse turns this into a normal usage error).
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays False, as it was under ``bool("")``.
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise ValueError("expected a boolean value, got {!r}".format(value))


parser = ArgumentParser()
# for dataset
parser.add_argument('--train_data_path', default='./data/SNIPS/sample/train_split.csv', type=str, help="path of training data")
parser.add_argument('--test_data_path', default='./data/SNIPS/sample/test_split.csv', type=str, help="path of testing data")
parser.add_argument('--train_class_path', default='./data/SNIPS/class/train_classes.txt', type=str, help="path of training classes")
parser.add_argument('--test_class_path', default='./data/SNIPS/class/test_classes.txt', type=str, help="path of testing classes")
parser.add_argument('--train_description_path', default='./data/SNIPS/description/train_description_2.txt', type=str, help="path of train descriptions")
parser.add_argument('--test_description_path', default='./data/SNIPS/description/test_description_2.txt', type=str, help="path of test descriptions")
parser.add_argument('--w2v_path', default='./data/wiki.en.vec', type=str, help="path of pretrained w2v")
# for training
parser.add_argument('--use_gpu', default=True, type=_str2bool, help="To use GPU or not")
parser.add_argument('--epochs', default=100, type=int, help="number of epochs for training")
parser.add_argument('--batch_size', default=512, type=int, help="batch size for training")
parser.add_argument('--learning_rate', default=0.001, type=float, help="learning rate for training")
parser.add_argument('--drop', default=0.5, type=float, help="dropout rate for self-attention")
# for self-attention
parser.add_argument('--attention_mode', default="dimensional", type=str, choices=["normal", "dimensional"], help="mode of self-attention")
parser.add_argument('--d_a', default=10, type=int, help="hidden unit number of self-attention")
parser.add_argument('--r', default=3, type=int, help="number of self-attention heads")
# for capsule
#parser.add_argument('--n_seen', type=int, required=True, help="number of seen classes")
#parser.add_argument('--n_unseen', type=int, required=True, help="number of unseen classes")
parser.add_argument('--d_p', default=10, type=int, help="dimention of prediction vector")
# for routing
parser.add_argument('--routing_iter', default=3, type=int, help="iterations for dynamic routing")
# for loss
parser.add_argument('--alpha', default=0.001, type=float, help="coefficient of self-attention loss")
# for testing/inference
parser.add_argument('--similarity_mode', default="description", type=str, choices=["w2v", "BERT", "S_BERT", "description"], help="mode of similarity computation")
parser.add_argument('--sigma', default=0.2, type=float, help="scale for similarity")
parser.add_argument('--eval_mode', default="best", type=str, choices=["normal", "unseen_first", "avg", "avg_2stage", "avg_logits", "best"], help="mode of evaluation")
# for LOF
parser.add_argument('--use_LOF', default=False, type=_str2bool, help="To use LOF prediction for testing or not")
parser.add_argument('--n_neighbors', default=20, type=int, help="number of nearest neighbors, i.e. MinPts in the paper")
#parser.add_argument('--contamination', default=0.2, type=float, help="contamination of LOF")
# for saving model
parser.add_argument('--save_path', default='./saved_models/', type=str, help="path of saved models")
parser.add_argument('--save_per_epoch', default=100, type=int, help="save model per how many epochs")
# for loading model
parser.add_argument('--checkpoint_path', type=str, help="path of checkpoint to load")
# for saving log
parser.add_argument('--args_path', default='./logs/args.json', type=str, help="path of saved args")
parser.add_argument('--log_path', default='./logs/log.json', type=str, help="path of saved log")

# parse_known_args ignores the extra argv entries a notebook kernel injects.
args, unknown = parser.parse_known_args()
#args = parser.parse_args()

# Fall back to CPU when CUDA is unavailable or the user opted out.
device = torch.device("cuda" if torch.cuda.is_available() and args.use_gpu else "cpu")

def _best_unseen_class(scores, n_seen):
    """Return the index of the highest-scoring unseen class.

    Seen classes occupy indices ``[0, n_seen)``.  They are masked with
    ``-inf`` rather than 0, so a seen index can never win the argmax when
    every unseen score happens to be negative.
    """
    masked = torch.clone(scores)
    masked[:n_seen] = float("-inf")
    return torch.argmax(masked)


def _below_seen_threshold(scores, pred, thresholds, n_seen):
    """Tell whether ``pred`` is a seen class scoring under its own threshold.

    ``thresholds`` only has ``n_seen`` entries, so an unseen prediction is
    reported as "not below" instead of indexing out of range.
    """
    if pred >= n_seen:
        return False
    return bool(scores[pred] < thresholds[pred])


def test(epoch, model, test_loader, similarity, n_seen, n_unseen, log_dict, description_test, score_threshold, score_threshold_norm, train_represent, evaluation_mode, logit_threshold=0.9):
    """Run generalized zero-shot evaluation for one epoch and print diagnostics.

    Args:
        epoch: epoch number, used only in the printed banner.
        model: network exposing ``construct_unseen_capsule_weights``,
            ``inference``, ``logits``, ``semantic_features`` and
            ``description_features``.
        test_loader: iterable yielding
            ``(input_embeddings, onehot_label, sentence_embeddings)`` batches.
        similarity: passed unchanged to ``construct_unseen_capsule_weights``.
        n_seen: number of seen classes (class indices ``[0, n_seen)``).
        n_unseen: number of unseen classes (indices ``[n_seen, n_seen + n_unseen)``).
        log_dict: log structure forwarded to ``log.fill_log``.
        description_test: class descriptions forwarded to ``model.inference``.
        score_threshold: per-seen-class thresholds on the raw semantic score.
        score_threshold_norm: per-seen-class thresholds on the normalised score.
        train_represent: per-seen-class representative features; the original
            comments give the shape as [n_seen, r, hidden_size].
        evaluation_mode: one of 'normal', 'unseen_first', 'avg', 'avg_2stage',
            'avg_logits', 'best'; selects how logits predictions are revised
            using the semantic scores.
        logit_threshold: confidence cut-off on the winning logit used by the
            'avg_logits' and 'best' modes (defaults to the previous constant 0.9).

    Returns:
        ``(acc, test_epoch_time)`` where ``acc`` is whatever ``log.fill_log``
        returns as accuracy and ``test_epoch_time`` is the elapsed seconds.
    """
    # generalized zero-shot classification
    print("Testing Epoch {}...\n".format(epoch))
    model.eval()
    start_time = time.time()
    test_pred = torch.LongTensor([])
    test_target = torch.LongTensor([])
    test_logits = torch.DoubleTensor([])
    class_scores = torch.DoubleTensor([])
    class_scores_norm = torch.DoubleTensor([])
    n_all = n_seen + n_unseen
    seen_des_features = train_represent.to(device) # [n_seen, r, hidden_size]

    with torch.no_grad():
        for input_embeddings, onehot_label, sentence_embeddings in test_loader:
            # model inference
            # NOTE(review): called once per batch as in the original; presumably
            # it could be hoisted out of the loop -- confirm it is idempotent.
            model.construct_unseen_capsule_weights(similarity)

            model.inference(input_embeddings.to(device), description_test.to(device))

            test_logits_batch = model.logits.cpu()
            test_logits = torch.cat((test_logits, test_logits_batch))

            pred = torch.argmax(test_logits_batch, dim=1).cpu()
            target = torch.argmax(onehot_label, dim=1).cpu()

            test_pred = torch.cat((test_pred, pred))
            test_target = torch.cat((test_target, target))

            # Tile each sample's semantic features once per class.
            input_semantics = torch.unsqueeze(model.semantic_features, 1).repeat(1, n_all, 1, 1) # [batch_size, n_all, r, hidden_size]

            # Seen classes use the training representatives, unseen classes use
            # the model's description features.
            unseen_des_features = model.description_features[n_seen:] # [n_unseen, r, hidden_size]
            des_features = torch.cat((seen_des_features, unseen_des_features))

            semantic_scores = torch.sum(input_semantics * des_features, dim=-1) # [batch_size, n_all, r]
            len_input_sen = torch.norm(input_semantics, dim=3) # [batch_size, n_all, r]
            len_des = torch.norm(des_features, dim=2) # [n_all, r]

            # Cosine-style normalisation of the dot products.
            semantic_scores_norm = semantic_scores / (len_input_sen * len_des) # [batch_size, n_all, r]

            class_scores_batch = torch.sum(semantic_scores, dim=-1).cpu() # [batch_size, n_all]
            class_scores_norm_batch = torch.sum(semantic_scores_norm, dim=-1).cpu() # [batch_size, n_all]

            class_scores = torch.cat((class_scores, class_scores_batch))
            class_scores_norm = torch.cat((class_scores_norm, class_scores_norm_batch))

    # Confusion matrices: rows are targets, columns are final predictions.
    argmax_logits = np.zeros((n_all, n_all), dtype=int)
    argmax_norm_logits = np.zeros((n_all, n_all), dtype=int)

    # How often each class wins the semantic-score argmax, and how often correctly.
    preds = [0] * n_all
    preds_correct = [0] * n_all

    preds_norm = [0] * n_all
    preds_norm_correct = [0] * n_all

    test_pred_norm = torch.clone(test_pred)

    # Per-decision-branch correct/wrong counters (indices documented inline).
    correct_point = [0, 0, 0, 0, 0]
    wrong_point = [0, 0, 0, 0, 0]
    correct_point_norm = [0, 0, 0, 0, 0]
    wrong_point_norm = [0, 0, 0, 0, 0]

    for i in range(len(test_target)):
        argmax_class_score = torch.argmax(class_scores[i])
        argmax_class_score_norm = torch.argmax(class_scores_norm[i])
        max_class_logits = torch.argmax(test_logits[i])

        if evaluation_mode == 'normal':
            pass
        elif evaluation_mode == 'unseen_first':
            if argmax_class_score >= n_seen:
                test_pred[i] = argmax_class_score

            if argmax_class_score_norm >= n_seen:
                test_pred_norm[i] = argmax_class_score_norm

        elif evaluation_mode == 'avg':
            if _below_seen_threshold(class_scores[i], test_pred[i], score_threshold, n_seen):
                test_pred[i] = _best_unseen_class(class_scores[i], n_seen)

            if _below_seen_threshold(class_scores_norm[i], test_pred_norm[i], score_threshold_norm, n_seen):
                test_pred_norm[i] = _best_unseen_class(class_scores_norm[i], n_seen)

        elif evaluation_mode == 'avg_2stage':
            if test_pred[i] != argmax_class_score:
                if _below_seen_threshold(class_scores[i], test_pred[i], score_threshold, n_seen):
                    test_pred[i] = _best_unseen_class(class_scores[i], n_seen)
                    # [1]: score of pred < threshold -> unseen
                    if test_pred[i] == test_target[i]:
                        correct_point[1] += 1
                    else:
                        wrong_point[1] += 1
            else:
                # [0]: pred == argmax score
                if test_pred[i] == test_target[i]:
                    correct_point[0] += 1
                else:
                    wrong_point[0] += 1

            if test_pred_norm[i] != argmax_class_score_norm:
                if _below_seen_threshold(class_scores_norm[i], test_pred_norm[i], score_threshold_norm, n_seen):
                    test_pred_norm[i] = _best_unseen_class(class_scores_norm[i], n_seen)
                    # [1]: score of pred < threshold -> unseen
                    if test_pred_norm[i] == test_target[i]:
                        correct_point_norm[1] += 1
                    else:
                        wrong_point_norm[1] += 1
            else:
                # [0]: pred == argmax score
                if test_pred_norm[i] == test_target[i]:
                    correct_point_norm[0] += 1
                else:
                    wrong_point_norm[0] += 1

        elif evaluation_mode == 'avg_logits':
            if test_pred[i] != argmax_class_score:
                if test_logits[i][max_class_logits] < logit_threshold:
                    test_pred[i] = _best_unseen_class(class_scores[i], n_seen)
                    # [1]: max logits < threshold -> unseen
                    if test_pred[i] == test_target[i]:
                        correct_point[1] += 1
                    else:
                        wrong_point[1] += 1
                else:
                    if argmax_class_score >= n_seen:
                        test_pred[i] = argmax_class_score
                        # [2]: argmax class unseen -> unseen
                        if test_pred[i] == test_target[i]:
                            correct_point[2] += 1
                        else:
                            wrong_point[2] += 1
                    else:
                        # [3]: argmax class seen
                        if test_pred[i] == test_target[i]:
                            correct_point[3] += 1
                        else:
                            wrong_point[3] += 1
            else:
                # [0]: pred == argmax score
                if test_pred[i] == test_target[i]:
                    correct_point[0] += 1
                else:
                    wrong_point[0] += 1

        elif evaluation_mode == 'best':
            if test_pred[i] != argmax_class_score:
                if argmax_class_score >= n_seen:
                    # The semantic score favours an unseen class: trust it.
                    test_pred[i] = argmax_class_score
                elif (test_logits[i][max_class_logits] < logit_threshold
                        and _below_seen_threshold(class_scores[i], test_pred[i], score_threshold, n_seen)):
                    # Low-confidence seen prediction that also scores under its
                    # class threshold: reassign to the best unseen class.
                    test_pred[i] = _best_unseen_class(class_scores[i], n_seen)

        preds[argmax_class_score] += 1
        if argmax_class_score == test_target[i]:
            preds_correct[argmax_class_score] += 1

        preds_norm[argmax_class_score_norm] += 1
        if argmax_class_score_norm == test_target[i]:
            preds_norm_correct[argmax_class_score_norm] += 1

        argmax_logits[test_target[i]][test_pred[i]] += 1
        argmax_norm_logits[test_target[i]][test_pred_norm[i]] += 1

    print("Correct Point:", correct_point)
    print("Wrong Point:", wrong_point, "\n")

    print("Correct Point Norm:", correct_point_norm)
    print("Wrong Point Norm:", wrong_point_norm, "\n")

    print("Preds:", preds)
    print("Preds Correct:", preds_correct)
    print("Total:", sum(preds_correct), "\n")

    print("Preds_Norm:", preds_norm)
    print("Preds_Norm Correct:", preds_norm_correct)
    print("Total_Norm:", sum(preds_norm_correct), "\n")

    final_correct = 0
    final_norm_correct = 0
    for i in range(n_all):
        print("Class {}: {}".format(i, argmax_logits[i]))
        final_correct += argmax_logits[i][i]

    print("\n")
    for i in range(n_all):
        print("Class Norm {}: {}".format(i, argmax_norm_logits[i]))
        final_norm_correct += argmax_norm_logits[i][i]

    print("\n")
    print("Final Correct:", final_correct)
    print("Final Norm Correct:", final_norm_correct)

    # LOF is not used here; pass an all-ones placeholder sized to the actual
    # number of test samples (it used to be hard-coded to 6880).
    LOF_pred = torch.ones(len(test_target))
    log_dict, acc = log.fill_log(log_dict, LOF_pred, test_target, test_pred, n_seen)

    print(classification_report(test_target, test_pred, digits=4, zero_division=1))

    test_epoch_time = time.time() - start_time

    norm_acc = accuracy_score(test_pred_norm, test_target)
    print("norm_acc: {}".format(norm_acc))

    return acc, test_epoch_time
            



In [None]:
if __name__ == '__main__':
    # Train the capsule network, evaluate after every epoch, checkpoint the
    # best and periodic models, then write the run log.

    # handle dataset
    BERT_tokenizer = AutoTokenizer.from_pretrained('./bert-base-uncased', local_files_only=True)
    BERT_model = BertModel.from_pretrained('./bert-base-uncased', local_files_only=True)
    BERT_model.eval()
    sentence_BERT = SentenceTransformer('./paraphrase-distilroberta-base-v2')

    train_set = DataPreprocess(BERT_model, BERT_tokenizer, sentence_BERT, args.w2v_path, args.train_data_path, args.train_class_path, args.train_description_path, 'train')
    test_set = DataPreprocess(BERT_model, BERT_tokenizer, sentence_BERT, args.w2v_path, args.test_data_path, args.test_class_path, args.test_description_path, 'test')
    print('Data Preprocess end!\n\n')

    n_seen = len(train_set.class_index) # number of seen_classes
    n_all_classes = len(test_set.class_index) # number of all_classes
    n_unseen = n_all_classes - n_seen # number of unseen_classes

    similarity = torch.from_numpy(utils.get_sim(train_set, test_set, n_seen, args.sigma, args.similarity_mode)).float().to(device) # test_set contains seen and unseen labels, [n_unseen, n_seen]

    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=True, num_workers=4)

    # create model
    model = CapsuleNetwork(args, n_seen, n_unseen, device).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    # learning rate scheduler
    scheduler = StepLR(optimizer, step_size=30, gamma=0.5)

    # makedirs also handles nested save paths and an already-existing directory.
    os.makedirs(args.save_path, exist_ok=True)

    # checkpoint
    if args.checkpoint_path:
        print("Loading checkpoint from {}".format(args.checkpoint_path))
        # map_location lets a checkpoint saved on GPU be restored on CPU too.
        model.load_state_dict(torch.load(args.checkpoint_path, map_location=device))
    else:
        print("Initializing Variables (no checkpoint)\n")

    # prepare log
    log_dict = log.prepare_log(similarity)

    # train
    train_time = 0
    test_time = 0
    best_acc = 0

    for epoch in range(1, args.epochs + 1):

        print("****************************************\nTraining Epoch {}...\n".format(epoch))
        model.train()
        epoch_loss = 0
        start_time = time.time()
        train_pred = torch.LongTensor([])
        train_target = torch.LongTensor([])

        # Per-seen-class accumulators: summed true-class semantic scores (raw
        # and normalised), summed semantic features, and sample counts.
        score_threshold = [0] * n_seen
        score_threshold_norm = [0] * n_seen
        # NOTE(review): 768 is presumably the encoder hidden size -- confirm it
        # matches model.semantic_features.
        train_represent = torch.zeros(n_seen, args.r, 768)
        target_count = [0] * n_seen

        for input_embeddings, onehot_label, sentence_embeddings in train_loader: # batch
            # model update
            # NOTE(review): the train set's descriptions are read from an
            # attribute named `description_test` -- confirm this is intended.
            model(input_embeddings.to(device), train_set.description_test.to(device))

            loss = model.loss(onehot_label.to(device).float())
            optimizer.zero_grad()
            loss.backward()
            clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # train prediction
            pred = torch.argmax(model.logits, dim=1).cpu()
            target = torch.argmax(onehot_label, dim=1).cpu()

            epoch_loss += loss.item()

            train_pred = torch.cat((train_pred, pred))
            train_target = torch.cat((train_target, target))

            # semantic score between each sample and every seen-class description
            input_semantics = torch.unsqueeze(model.semantic_features.detach(), 1).repeat(1, n_seen, 1, 1) # [batch_size, n_seen, r, hidden_size]
            des_features = model.description_features.detach() # [n_seen, r, hidden_size]

            semantic_scores = torch.sum(input_semantics * des_features, dim=-1) # [batch_size, n_seen, r]
            len_input_sen = torch.norm(input_semantics, dim=3) # [batch_size, n_seen, r]
            len_des = torch.norm(des_features, dim=2) # [n_seen, r]

            semantic_scores_norm = semantic_scores / (len_input_sen * len_des) # [batch_size, n_seen, r]

            class_scores = torch.sum(semantic_scores, dim=-1).cpu() # [batch_size, n_seen]
            class_scores_norm = torch.sum(semantic_scores_norm, dim=-1).cpu() # [batch_size, n_seen]

            # per-sample semantic features used to build class representatives
            train_semantics = model.semantic_features.detach().cpu() # [batch_size, r, hidden_size]

            # Use the labels of THIS batch.  The epoch-wide `train_target` must
            # not be indexed with a batch-local position: after the first batch
            # it would return labels belonging to earlier samples.
            for i, label in enumerate(target.tolist()):
                score_threshold[label] += class_scores[i][label]
                score_threshold_norm[label] += class_scores_norm[i][label]

                train_represent[label] += train_semantics[i]
                target_count[label] += 1

        # Average per class; max(count, 1) keeps a class with no samples at
        # zero instead of dividing by zero.
        # (Earlier experiments used a global average or a per-class minimum
        # score as the threshold instead.)
        safe_count = [max(count, 1) for count in target_count]
        score_threshold = [x/y for x,y in zip(score_threshold, safe_count)]
        score_threshold_norm = [x/y for x,y in zip(score_threshold_norm, safe_count)]
        train_represent = torch.stack([x/y for x,y in zip(train_represent, safe_count)])

        acc = accuracy_score(train_target, train_pred)

        train_epoch_time = time.time() - start_time
        train_time += train_epoch_time

        print("Epoch: {}\t| Loss: {}\t| Acc: {}%".format(epoch, round(epoch_loss, 4), round(acc * 100., 2)))
        print("Epoch Time: {}s\t| Overall Train Time: {}\n\n".format(round(train_epoch_time, 2), datetime.timedelta(seconds=int(train_time))))
        test_acc, test_epoch_time = test(epoch, model, test_loader, similarity, n_seen, n_unseen, log_dict, test_set.description_test, score_threshold, score_threshold_norm, train_represent, args.eval_mode)
        test_time += test_epoch_time

        scheduler.step()

        # os.path.join keeps the file inside save_path even when the path is
        # given without a trailing slash.
        if test_acc > best_acc:
            best_acc = test_acc
            print("Saving best model at epoch {}".format(epoch))
            torch.save(model.state_dict(), os.path.join(args.save_path, "best_model.pt")) # save best model

        if (epoch % args.save_per_epoch) == 0:
            print("Saving model at epoch {}".format(epoch))
            torch.save(model.state_dict(), os.path.join(args.save_path, "{}.pt".format(str(epoch))))

        print("test_acc: {}".format(test_acc))
        print("best_acc: {}\n".format(best_acc))

        print("Test Time: {}s\t| Overall Test Time: {}\n\n".format(round(test_epoch_time, 2), datetime.timedelta(seconds=int(test_time))))

    # write log
    log_dict['best_acc'] = best_acc
    log.write_log(args, log_dict)
    

## Testing other stuff

In [None]:
# Scratch: full-length topk returns the indices in descending-value order.
a = torch.Tensor([0.9, 0.3, 0.5, 0.2, 0.1])
b = [0, 2, 1, 3, 4]
print(a.topk(k=5).indices)
if b == a.topk(k=5).indices.tolist():
    print("HI")

In [None]:
import torch

# Scratch: per-head dot product between every sample and every class.
a = torch.rand(10, 3, 5)
b = torch.rand(7, 3, 5)

# Tile each sample once per class, then reduce over the feature axis.
c = a.unsqueeze(1).repeat(1, 7, 1, 1)
print(c.shape)
d = (b * c).sum(dim=-1)
print(d.shape)

    

In [None]:
import torch

# Scratch: elementwise multiply with a broadcast ones matrix and with zeros.
a = torch.rand(2, 3, 4)
b = torch.zeros(3, 4)
b.add_(1)
c = torch.zeros(2, 3, 4)
print(a)
print(a.mul(b))
print(a.mul(c))

In [None]:
import torch

# Scratch: argmax over dim 1 drops that dimension.
a = torch.rand(10, 7, 3)
print(a.argmax(dim=1).shape)

In [None]:
# Scratch: a clone does not share storage with its source tensor.
a = torch.Tensor([1, 1])
b = a.clone()
print(a)
print(b)
a[0] = 2
print(a)
print(b)

In [None]:
# Scratch: in-place increment of one LongTensor element, then take the max.
a = torch.LongTensor([0 for _ in range(7)])
print(a)
a[3] = a[3] + 1
print(a)
print(a.max())

In [None]:
# Scratch: pairwise division of two equally long lists.
a = [1, 2, 3, 4]
b = [1, 2, 3, 4]
c = list(map(lambda num, den: num / den, a, b))
print(c)

In [None]:
# Scratch: accumulate a [3, 768] matrix into the first slice of a [5, 3, 768] buffer.
a = torch.zeros(5, 3, 768)
b = torch.randn(3, 768)
print(b)
a[0].add_(b)
print(a)

In [None]:
# Scratch: divide each leading slice by its own scalar and re-stack the result.
a = torch.randn(5, 3, 768)
print(a)
b = [2, 3, 4, 5, 6]

a = torch.stack([block / divisor for block, divisor in zip(a, b)])
print(a)