In [1]:
import os
import numpy as np
import itertools
import torch
import copy
import random
import csv
import sys
import json
import torch.nn.functional as F
import pytorch_lightning as pl
from torch import nn
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from datetime import datetime
from transformers import (
    PreTrainedModel,
    BertTokenizer,
    BertModel,
    AdamW,
    BertConfig,
    BertForSequenceClassification,
    DataProcessor,
    InputExample,
    glue_convert_examples_to_features,
)
from tqdm import tqdm_notebook, trange, tqdm
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score

## Parameters

In [2]:
class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

args = {
    "save_results_path": 'outputs',
    "pretrain_dir": 'models',
    "bert_model": "/fred/oz064/xcai/paper1/pytorch/huggingface/bert-base-uncased",
    "max_seq_length": None,
    "feat_dim": 768,
    "warmup_proportion": 0.1,
    "freeze_bert_parameters": True,
    "save_model": True,
    "save_results": True,
    "dataset": "oos",
    "known_cls_ratio": 0.75,
    "labeled_ratio": 1.0,
    "method": None,
    "seed": 0,
    "gpu_id": '0',
    "lr": 2e-5,
    "num_train_epochs": 100.0,
    "train_batch_size": 128,
    "eval_batch_size": 64,
    "wait_patient": 10,
    "lr_boundary": 0.05,
    "num_labels": 10,
}
args = dotdict(args)

## Data Loader

In [3]:
data_path = '../data/data_full.json'
def data_read(data_path):
    reader = []
    with open (data_path) as f:
        reader = json.load(f)
    return reader      
data_read(data_path).keys()

dict_keys(['oos_val', 'val', 'train', 'oos_test', 'test', 'oos_train'])

In [4]:
# data generation
train_data = data_read(data_path)["train"]
val_data = data_read(data_path)["val"]
test_data = data_read(data_path)["test"]
oos_train_data = data_read(data_path)["oos_train"]
oos_val_data = data_read(data_path)["oos_val"]
oos_test_data = data_read(data_path)["oos_test"]

# data label generation
def label_generator(train_data, oos_train_data):
    data_label = []
    for index in range(0,len(train_data)) :
        if train_data[index][1] not in data_label:
            data_label.append(train_data[index][1])
            index = index + 1
    data_label.append(oos_train_data[0][1])
    return data_label
idx_to_type = label_generator(train_data, oos_train_data)
print(idx_to_type)

['translate', 'transfer', 'timer', 'definition', 'meaning_of_life', 'insurance_change', 'find_phone', 'travel_alert', 'pto_request', 'improve_credit_score', 'fun_fact', 'change_language', 'payday', 'replacement_card_duration', 'time', 'application_status', 'flight_status', 'flip_coin', 'change_user_name', 'where_are_you_from', 'shopping_list_update', 'what_can_i_ask_you', 'maybe', 'oil_change_how', 'restaurant_reservation', 'balance', 'confirm_reservation', 'freeze_account', 'rollover_401k', 'who_made_you', 'distance', 'user_name', 'timezone', 'next_song', 'transactions', 'restaurant_suggestion', 'rewards_balance', 'pay_bill', 'spending_history', 'pto_request_status', 'credit_score', 'new_card', 'lost_luggage', 'repeat', 'mpg', 'oil_change_when', 'yes', 'travel_suggestion', 'insurance', 'todo_list_update', 'reminder', 'change_speed', 'tire_pressure', 'no', 'apr', 'nutrition_info', 'calendar', 'uber', 'calculator', 'date', 'carry_on', 'pto_used', 'schedule_maintenance', 'travel_notifica

In [5]:
#InputExample(guid='0', text_a=train_data[0][0], label=train_data[0][1])
def create_examples(data):
    examples = []
    for i, e in enumerate(data):
        examples.append(InputExample(guid = str(i), text_a=e[0], label=e[1]))
    return examples

In [6]:
examples = create_examples(train_data)
print(examples[:3])

[InputExample(guid='0', text_a='what expression would i use to say i love you if i were an italian', text_b=None, label='translate'), InputExample(guid='1', text_a="can you tell me how to say 'i do not speak much spanish', in spanish", text_b=None, label='translate'), InputExample(guid='2', text_a="what is the equivalent of, 'life is good' in french", text_b=None, label='translate')]


In [7]:
def generate_dataloaders(tokenizer, data_path):
    def generate_dataloader_inner(examples, data_type='train'):
        features = glue_convert_examples_to_features(
            examples,
            tokenizer,
            label_list = idx_to_type,
            max_length = 64,
            output_mode = 'classification'
        )
        
        dataset = torch.utils.data.TensorDataset(
            torch.LongTensor([f.input_ids for f in features]),
            torch.LongTensor([f.attention_mask for f in features]),
            torch.LongTensor([f.token_type_ids for f in features]),
            torch.LongTensor([f.label for f in features])   
        )
        if data_type == 'train':
            sampler = torch.utils.data.RandomSampler(dataset)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)
        dataloader = torch.utils.data.DataLoader(
            dataset, sampler = sampler, batch_size = 32
        )
        return dataloader
    
    # notice here class OOS is always the last label
    train_examples = create_examples(data_read(data_path)["train"]+ data_read(data_path)["oos_train"])
    print('Load Example Finish')
    train_loader = generate_dataloader_inner(train_examples, data_type='train')
    print('Generate DataLoader Finish')

    valid_examples = create_examples(data_read(data_path)["val"] + data_read(data_path)["oos_val"])
    print('Load Example Finish')
    valid_loader = generate_dataloader_inner(valid_examples, data_type='valid')
    print('Generate DataLoader Finish')   

    test_examples = create_examples(data_read(data_path)["test"] + data_read(data_path)["oos_test"])
    print('Load Example Finish')
    test_loader = generate_dataloader_inner(test_examples, data_type='valid')
    print('Generate DataLoader Finish')
    
    return train_loader, valid_loader, test_loader

In [8]:
bert_path = "/fred/oz064/xcai/paper1/pytorch/huggingface/bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(bert_path)
train_loader, valid_loader, test_loader = generate_dataloaders(tokenizer, data_path)

Load Example Finish




Generate DataLoader Finish
Load Example Finish
Generate DataLoader Finish
Load Example Finish
Generate DataLoader Finish


In [9]:
for batch in valid_loader:
# for batch in train_loader:
    print(batch[3])

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 3, 3, 3, 3])
tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4])
tensor([4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        6, 6, 6, 6, 6, 6, 6, 6])
tensor([6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7])
tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9,
        9, 9, 9, 9, 9, 9, 9, 9])
tensor([ 9,  9,  9,  9,  9,  9,  9,  9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10,
        10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11])
tensor([11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12,
        12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12])
tensor([12, 12, 12, 12, 13, 13, 

## model

In [10]:
def cos_loss(x, y, num_cls, w, 
             reuse=False, alpha=0.35, beta=0.35, scale=64, 
             lamb1=1, lamb2=10, ce_loss_oos=False, name='cos_margin_loss'):
    '''
    x: B x D - features
    y: B - labels
    num_cls: 1 - total class number, the last cls being out of scope
    w: num_cls x D - mean feature vectors (centroids)
    alpah: 1 - in scope margin
    beta: 1 - out of scope margin
    scale: 1 - scaling paramter
    lamb1: weight of 1-cosine
    lamb2: weight of max
    ce_loss_oos: calculate oos loss in cross entropy
    ''' 
    #normalize the feature and weight
    #(B,D)
    x_feat_norm = F.normalize(x,p=2,dim=1,eps=1e-12)
    #(D,num_cls)
    w_feat_norm = torch.transpose(F.normalize(w,p=2,dim=1,eps=1e-12), 0, 1)

    # get the scores after normalization 
    #(B,num_cls)
    xw_norm = torch.matmul(x_feat_norm, w_feat_norm)  # cosine similarity

    # xbj's loss, first row, adjust the cosine similarity by a margin, only apply to in-scope instances
    xw_norm[:, :-1] -= alpha #(B,num_cls)
#     xw_norm[:, -1] -= alpha #(B,num_cls)

    # margin based softmax loss
    loss_fn = nn.CrossEntropyLoss(reduction='none')
    ce_loss = loss_fn(xw_norm, y)
    if not ce_loss_oos:
        ce_loss[y == 150] = 0
#     print("ce_loss shape: ", ce_loss.shape)
    
    # xbj loss, second row, only applies to out of scope instances
    out_of_scope_loss_part2 = torch.max(xw_norm[:, :-1] - alpha, dim=1)[0] - xw_norm[:, -1]
    out_of_scope_loss_part2[out_of_scope_loss_part2 < 0] = 0
#     print("out_of_scope_loss_part2 shape: ", out_of_scope_loss_part2.shape)
    out_of_scope_loss = lamb1 * (1 - xw_norm[:, -1]) + lamb2 * out_of_scope_loss_part2
#     print("out_of_scope_loss shape: ", out_of_scope_loss.shape)
       
    out_of_scope_loss[y < 150] = 0
    
    loss = torch.mean(ce_loss + out_of_scope_loss)
    
    return loss 

def predict(x, w, alpha=0.35):
    '''
    x: B x D - features
    w: num_cls x D - mean feature vectors (centroids)
    ''' 
    #normalize the feature and weight
    #(B,D)
#     print("x.size():", x.size())
    x_feat_norm = F.normalize(x,p=2,dim=1,eps=1e-12)
    #(D,num_cls)
    w_feat_norm = torch.transpose(F.normalize(w,p=2,dim=1,eps=1e-12), 0, 1)

    # get the scores after normalization 
    #(B,num_cls)
    xw_norm = torch.matmul(x_feat_norm, w_feat_norm)  # cosine similarity

    xw_norm[:, :-1] -= alpha
    
    preds = xw_norm.max(1)[1]
    
    return preds

In [11]:
# torch.__version__

In [12]:
bert_model = BertModel.from_pretrained(args.bert_model)
def get_optimizer(bert_model, args):
    param_optimizer = list(bert_model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                     lr = args.lr)   
    return optimizer
optimizer = get_optimizer(bert_model, args)
    
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id           
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bert_model.to(DEVICE)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [13]:
def compute_centroids(dataloader, bert_model):
    print("Computing centroids...")
    vectors = []
    all_labels = []
    with torch.no_grad():
        for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(dataloader):
            input_ids = input_ids.to(DEVICE)
            attention_mask = attention_mask.to(DEVICE)
            token_type_ids = token_type_ids.to(DEVICE)
            outputs = bert_model(input_ids, attention_mask, token_type_ids)
            pooler_output = outputs.pooler_output
            vectors.append(pooler_output.cpu())
            all_labels.append(labels.cpu())
    vectors = torch.cat(vectors, 0) # num_ins, feature_dim
    labels = torch.cat(all_labels, 0) # num_ins
    w = []
    for i in range(151):
        w.append(vectors[labels==i].mean(0, keepdim=True))
    w = torch.cat(w, 0)
    return w
w = compute_centroids(valid_loader, bert_model)

Computing centroids...


In [14]:
w = w.detach().to(DEVICE)

In [15]:
# ! mkdir checkpoints

In [16]:
def evaluate(dataloder, w):
    w = w.to(DEVICE)
    running_loss = 0.0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(dataloder):
            input_ids = input_ids.to(DEVICE)
            attention_mask = attention_mask.to(DEVICE)
            token_type_ids = token_type_ids.to(DEVICE)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = bert_model(input_ids, attention_mask, token_type_ids)
            pooler_output  = outputs.pooler_output 

            preds = predict(pooler_output, w, alpha=0.35).cpu()
            all_preds.append(preds)
            all_labels.append(labels)
            
    preds = torch.cat(all_preds)
    labels = torch.cat(all_labels)
    accuracy = torch.sum(preds == labels).item() / labels.shape[0]
    
    all_recalls = [torch.sum(preds[labels==i] == labels[labels==i]).item() / torch.sum(labels == i).item() for i in range(150)]
    out_of_scope_recall = torch.sum(preds[labels==150] == labels[labels==150]).item() / torch.sum(labels == 150).item()
    out_of_scope_precision = torch.sum(preds[labels==150] == labels[labels==150]).item() / torch.sum(preds == 150).item()
    in_scope_accuracy = torch.sum(preds[labels<150] == labels[labels<150]).item() / torch.sum(labels < 150).item()
    
    metrics = {"accuracy": accuracy, 
               "out_of_scope_recall": out_of_scope_recall, 
               "out_of_scope_precision": out_of_scope_precision,
               "in_scope_accuracy": in_scope_accuracy,
               "all_recalls": all_recalls,}
    
#     print("accuracy: ", accuracy, "out of scope recall: ", out_of_scope_recall, "out of scope precision: ", out_of_scope_precision)
#     print("all_recalls: ", all_recalls)
#     print("in_scope_accuracy: ", in_scope_accuracy)
    
    return metrics

In [None]:
LAMB1 = 1
LAMB2 = 10
CE_LOSS_OOS = False
ALPHA = 0.35
CHECKPOINT_PATH = "checkpoints/alpha_{}_lamb1_{}_lamb2_{}_celossoos_{}".format(ALPHA, LAMB1, LAMB2, CE_LOSS_OOS)
!mkdir -p $CHECKPOINT_PATH

for epoch in range(40):  # loop over the dataset multiple times

    w = compute_centroids(train_loader, bert_model).detach().to(DEVICE)
    last_w = w.cpu()
    
    running_loss = 0.0
    
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(train_loader):
        input_ids = input_ids.to(DEVICE)
        attention_mask = attention_mask.to(DEVICE)
        token_type_ids = token_type_ids.to(DEVICE)
        labels = labels.to(DEVICE)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = bert_model(input_ids, attention_mask, token_type_ids)
        pooler_output  = outputs.pooler_output 
#         print("y.size():", y.size())
        
        loss = cos_loss(pooler_output, labels, 151, w, alpha=ALPHA, 
                        lamb1=LAMB1, lamb2=LAMB2, ce_loss_oos=CE_LOSS_OOS,
                        beta=0.35, scale=64, name='cos_margin_loss')
#         print(loss.item())
        
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
    print('[%d] loss: %.3f' %
          (epoch + 1, running_loss / (i+1)))
    metrics = evaluate(valid_loader, last_w)
    print("valid: ", metrics)
    with open(os.path.join(CHECKPOINT_PATH, "metrics.txt"), "a") as metrics_out:
        metrics_out.write("epoch {}\n".format(epoch+1) + str(metrics) + "\n")
    metrics = evaluate(test_loader, last_w)
    print("test: ", metrics)
    with open(os.path.join(CHECKPOINT_PATH, "test_metrics.txt"), "a") as metrics_out:
        metrics_out.write("epoch {}\n".format(epoch+1) + str(metrics) + "\n")
    
print('Finished Training')
bert_model.save_pretrained(CHECKPOINT_PATH)
np.savetxt(os.path.join(CHECKPOINT_PATH, "last_w.txt"), last_w.numpy())


Computing centroids...
[1] loss: 4.909
{'accuracy': 0.03774193548387097, 'out_of_scope_recall': 0.98, 'out_of_scope_precision': 0.0394842868654311, 'in_scope_accuracy': 0.006333333333333333, 'all_recalls': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.95, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}
{'accuracy': 0.17927272727272728, 'out_of

[4] loss: 4.085
{'accuracy': 0.917741935483871, 'out_of_scope_recall': 0.72, 'out_of_scope_precision': 0.6792452830188679, 'in_scope_accuracy': 0.9243333333333333, 'all_recalls': [0.95, 0.9, 1.0, 0.75, 0.85, 0.95, 1.0, 1.0, 1.0, 0.85, 0.0, 1.0, 0.7, 0.85, 0.85, 1.0, 1.0, 1.0, 1.0, 0.9, 0.9, 1.0, 0.75, 1.0, 0.7, 0.9, 1.0, 0.95, 1.0, 1.0, 0.85, 0.8, 1.0, 1.0, 0.75, 0.9, 1.0, 0.85, 0.95, 0.85, 0.95, 0.85, 1.0, 0.75, 1.0, 1.0, 0.8, 0.85, 0.95, 0.95, 0.75, 1.0, 1.0, 0.9, 1.0, 0.55, 0.9, 1.0, 1.0, 1.0, 1.0, 0.7, 0.95, 0.95, 1.0, 1.0, 1.0, 0.95, 0.95, 0.95, 0.85, 0.95, 0.95, 1.0, 0.95, 0.8, 1.0, 1.0, 1.0, 0.95, 0.95, 1.0, 1.0, 1.0, 1.0, 0.9, 1.0, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 0.65, 0.7, 1.0, 1.0, 0.75, 1.0, 1.0, 0.75, 1.0, 1.0, 0.85, 0.95, 1.0, 1.0, 0.9, 1.0, 1.0, 0.9, 1.0, 0.95, 1.0, 1.0, 0.95, 1.0, 1.0, 0.95, 1.0, 1.0, 1.0, 1.0, 0.95, 0.9, 1.0, 0.95, 0.85, 0.8, 1.0, 1.0, 0.85, 0.95, 0.85, 0.9, 0.65, 0.95, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 0.85, 1.0, 0.9, 0.85, 0.95, 0.9, 1.0]}
{'accuracy': 0.8

[7] loss: 4.048
{'accuracy': 0.9235483870967742, 'out_of_scope_recall': 0.66, 'out_of_scope_precision': 0.6226415094339622, 'in_scope_accuracy': 0.9323333333333333, 'all_recalls': [1.0, 0.9, 1.0, 0.5, 0.85, 0.95, 0.9, 1.0, 1.0, 0.95, 0.0, 1.0, 0.8, 0.85, 0.9, 1.0, 1.0, 1.0, 1.0, 0.9, 0.9, 1.0, 0.8, 1.0, 0.7, 0.9, 1.0, 0.95, 1.0, 0.95, 0.9, 0.9, 1.0, 1.0, 0.85, 0.9, 1.0, 0.9, 0.9, 0.85, 0.9, 0.9, 1.0, 0.9, 1.0, 0.95, 0.85, 0.85, 0.95, 1.0, 0.75, 1.0, 1.0, 1.0, 1.0, 0.7, 0.95, 1.0, 1.0, 1.0, 1.0, 0.55, 0.9, 0.95, 1.0, 1.0, 1.0, 0.95, 1.0, 0.95, 0.95, 0.9, 0.9, 1.0, 0.95, 0.8, 1.0, 1.0, 1.0, 1.0, 0.95, 1.0, 1.0, 1.0, 1.0, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.7, 0.7, 1.0, 1.0, 0.95, 0.95, 1.0, 0.95, 1.0, 1.0, 0.9, 0.95, 1.0, 1.0, 0.9, 1.0, 1.0, 0.9, 1.0, 0.95, 1.0, 0.9, 0.9, 1.0, 1.0, 0.95, 1.0, 1.0, 1.0, 1.0, 0.95, 0.9, 1.0, 0.95, 0.85, 0.85, 1.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.65, 0.85, 0.9, 1.0, 1.0, 0.95, 1.0, 1.0, 0.9, 1.0, 0.95, 0.9, 0.95, 0.95, 1.0]}
{'accuracy': 0.8467272727272728

[10] loss: 4.036
{'accuracy': 0.927741935483871, 'out_of_scope_recall': 0.62, 'out_of_scope_precision': 0.6526315789473685, 'in_scope_accuracy': 0.938, 'all_recalls': [0.95, 0.9, 1.0, 0.75, 0.85, 0.95, 1.0, 1.0, 1.0, 0.95, 0.0, 1.0, 0.8, 0.85, 0.95, 1.0, 1.0, 1.0, 1.0, 0.9, 0.9, 1.0, 0.9, 1.0, 0.7, 0.9, 1.0, 0.95, 0.9, 0.95, 0.85, 0.9, 1.0, 1.0, 0.8, 0.9, 1.0, 0.9, 0.85, 0.85, 0.95, 0.9, 1.0, 0.95, 1.0, 1.0, 0.9, 0.9, 0.95, 0.95, 0.75, 1.0, 1.0, 1.0, 1.0, 0.8, 0.95, 1.0, 1.0, 1.0, 1.0, 0.55, 0.9, 0.95, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.95, 0.9, 1.0, 1.0, 0.95, 0.8, 1.0, 1.0, 1.0, 1.0, 0.95, 1.0, 1.0, 1.0, 0.95, 0.9, 1.0, 1.0, 1.0, 0.95, 1.0, 1.0, 1.0, 0.7, 0.7, 1.0, 1.0, 0.9, 0.95, 1.0, 0.95, 1.0, 1.0, 0.9, 0.95, 1.0, 0.9, 0.9, 1.0, 1.0, 0.95, 1.0, 0.9, 1.0, 1.0, 0.9, 0.95, 1.0, 0.95, 1.0, 1.0, 1.0, 1.0, 0.95, 0.9, 1.0, 1.0, 0.85, 0.95, 1.0, 0.9, 0.95, 0.95, 0.9, 0.95, 0.65, 0.85, 0.9, 1.0, 1.0, 0.95, 1.0, 1.0, 0.95, 1.0, 0.95, 0.95, 0.95, 0.9, 1.0]}
{'accuracy': 0.846909090909091, 'out_

[13] loss: 4.029
{'accuracy': 0.9332258064516129, 'out_of_scope_recall': 0.6, 'out_of_scope_precision': 0.6818181818181818, 'in_scope_accuracy': 0.9443333333333334, 'all_recalls': [1.0, 0.95, 1.0, 0.95, 0.85, 0.95, 1.0, 1.0, 1.0, 0.8, 0.0, 1.0, 0.8, 0.85, 0.95, 1.0, 1.0, 1.0, 1.0, 0.9, 0.9, 1.0, 0.9, 1.0, 0.7, 0.9, 0.95, 0.95, 0.9, 1.0, 0.9, 0.9, 1.0, 1.0, 0.8, 0.95, 1.0, 0.9, 1.0, 0.85, 0.95, 0.9, 1.0, 0.95, 1.0, 1.0, 0.85, 0.9, 0.95, 0.9, 0.75, 1.0, 1.0, 1.0, 1.0, 0.9, 0.95, 1.0, 1.0, 1.0, 1.0, 0.95, 0.9, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.95, 1.0, 1.0, 1.0, 0.95, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9, 0.95, 1.0, 1.0, 0.95, 1.0, 1.0, 1.0, 0.75, 0.75, 1.0, 1.0, 0.9, 0.95, 1.0, 0.95, 1.0, 1.0, 0.9, 0.95, 1.0, 1.0, 0.9, 1.0, 1.0, 0.95, 1.0, 0.95, 1.0, 1.0, 0.6, 1.0, 1.0, 0.95, 1.0, 1.0, 1.0, 1.0, 0.95, 0.9, 1.0, 1.0, 0.85, 0.95, 1.0, 0.85, 0.9, 0.95, 0.9, 0.9, 0.65, 0.9, 0.9, 1.0, 1.0, 0.95, 1.0, 1.0, 1.0, 1.0, 0.95, 0.95, 0.95, 0.9, 1.0]}
{'accuracy': 0.84472727272727

[16] loss: 4.025
{'accuracy': 0.9225806451612903, 'out_of_scope_recall': 0.57, 'out_of_scope_precision': 0.6785714285714286, 'in_scope_accuracy': 0.9343333333333333, 'all_recalls': [0.95, 0.95, 1.0, 0.9, 0.85, 0.95, 0.95, 1.0, 1.0, 0.95, 0.15, 1.0, 0.8, 0.85, 0.95, 1.0, 1.0, 1.0, 1.0, 0.9, 0.9, 1.0, 0.9, 1.0, 0.7, 0.9, 0.95, 0.95, 0.9, 1.0, 0.95, 0.9, 1.0, 1.0, 0.85, 0.95, 1.0, 0.95, 0.95, 0.75, 0.9, 0.9, 1.0, 0.9, 1.0, 1.0, 0.9, 0.9, 0.95, 0.9, 0.7, 1.0, 1.0, 0.95, 1.0, 0.8, 0.95, 1.0, 1.0, 1.0, 1.0, 0.35, 0.95, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.95, 1.0, 1.0, 1.0, 0.9, 0.8, 1.0, 1.0, 1.0, 1.0, 0.95, 1.0, 1.0, 1.0, 0.95, 0.95, 0.95, 1.0, 1.0, 0.95, 1.0, 1.0, 1.0, 0.75, 0.5, 1.0, 1.0, 0.9, 0.95, 1.0, 0.95, 0.95, 1.0, 0.85, 0.95, 1.0, 1.0, 0.85, 1.0, 1.0, 0.95, 1.0, 0.95, 1.0, 0.85, 0.95, 0.95, 1.0, 0.7, 1.0, 1.0, 1.0, 1.0, 0.95, 0.9, 1.0, 1.0, 0.85, 0.9, 1.0, 0.85, 0.95, 0.9, 0.9, 0.95, 0.65, 0.8, 0.9, 1.0, 0.95, 1.0, 1.0, 1.0, 1.0, 1.0, 0.9, 0.85, 0.95, 0.9, 1.0]}
{'accuracy': 0.840

In [None]:
# bert_model = BertModel.from_pretrained(CHECKPOINT_PATH)
# bert_model.to(DEVICE)
# metrics = evaluate(test_loader, last_w)
# print(metrics)
# with open(os.path.join(CHECKPOINT_PATH, "test_metrics.txt"), "a") as metrics_out:
#     metrics_out.write(str(metrics) + "\n")

In [None]:
# bert_model.save_pretrained(CHECKPOINT_PATH)
# np.savetxt(os.path.join(CHECKPOINT_PATH, "last_w.txt"), last_w.numpy())

# Visualization

In [None]:
# %matplotlib inline
# from sklearn.manifold import TSNE
# from matplotlib import pyplot as plt
# tsne = TSNE(n_components=2, random_state=0)
# X_2d = tsne.fit_transform(last_w.numpy())

# target_ids = range(len(idx_to_type))


# plt.figure(figsize=(6, 5))
# for i, label in zip(target_ids, idx_to_type):
#     plt.scatter(X_2d[i, 0], X_2d[i, 1], label=label)
# plt.legend()
# plt.show()

In [None]:
# def predict_for_visualization(dataloder, w):
#     w = w.to(DEVICE)
#     all_labels = []
#     all_vectors = []
#     with torch.no_grad():
#         for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(dataloder):
#             input_ids = input_ids.to(DEVICE)
#             attention_mask = attention_mask.to(DEVICE)
#             token_type_ids = token_type_ids.to(DEVICE)

#             # forward + backward + optimize
#             outputs = bert_model(input_ids, attention_mask, token_type_ids)
#             pooler_output  = outputs.pooler_output 

#             all_vectors.append(pooler_output)
#             all_labels.append(labels)
            
#     labels = torch.cat(all_labels)
#     vectors = torch.cat(all_vectors, 0)
    
#     return labels.cpu().numpy(), vectors.cpu().numpy()
# labels, vectors = predict_for_visualization(train_loader, last_w)

In [None]:
# vectors.shape
# vectors_and_weights = np.concatenate([vectors, last_w.numpy()], 0)

In [None]:
# vectors_and_weights.shape

In [None]:
# %matplotlib inline
# from sklearn.manifold import TSNE
# from matplotlib import pyplot as plt
# tsne = TSNE(n_components=2, random_state=0)
# X_2d = tsne.fit_transform(vectors_and_weights)

# target_ids = range(len(idx_to_type[:]))
# X_2d_vectors = X_2d[:-151]
# X_2d_w = X_2d[-151:]

# plt.figure(figsize=(15, 15))
# for i, label in zip(target_ids, idx_to_type[:]):
#     if i == 150:
#         plt.scatter(X_2d_vectors[labels==i, 0], X_2d_vectors[labels==i, 1], c='k', label=label)
#         plt.scatter(X_2d_w[i, 0], X_2d_w[i, 1], c='k', label=label+"_w")
#     elif i > 130:
#         plt.scatter(X_2d_vectors[labels==i, 0], X_2d_vectors[labels==i, 1], label=label)
#         plt.scatter(X_2d_w[i, 0], X_2d_w[i, 1], label=label+"_w")
# plt.legend()
# plt.show()