In [None]:
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler, random_split, default_collate
from triplet_loss import *
from Siamese import SiameseModel, QuadrupletModel
import lightning.pytorch as pl
from lightning.pytorch.loggers import wandb
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from utils import TextDataset, AdvDataset, load_feature_extractor,get_results
from transformers import RobertaTokenizer, RobertaConfig, RobertaForSequenceClassification
from itertools import cycle
import argparse
import copy
from tqdm import tqdm
import numpy as np
from pathlib import Path
import wandb
import path

In [None]:
# Command-line options. Inside a notebook sys.argv belongs to the kernel, so the
# arguments are supplied explicitly through `default_args_list` rather than read
# from the command line; edit that list to point the analysis at other inputs.
parser = argparse.ArgumentParser(description='Input Reflector')

# Checkpoints: either a local .ckpt path or a wandb artifact ID.
parser.add_argument(
    '--sia_path', type=str, default='veweew/inputReflector/model-l1kr6pgx:v6',
    help='Path for SIA model, you can use local ckpt path or wandb cloud checkpoint ID',
)
parser.add_argument(
    '--quad_path', type=str, default='veweew/inputReflector/model-3v67c2u7:v1',
    help='Path for QUAD model, you can use local ckpt path or wandb cloud checkpoint ID',
)

# Dataset splits and the classifier under test.
parser.add_argument('--valid_set', type=str, default='valid.jsonl', help='Path to the validation set')
parser.add_argument('--train_set', type=str, default='train.jsonl', help='Path to the training set')
parser.add_argument('--test_set', type=str, default='test.jsonl', help='Path to the test set')
parser.add_argument('--pretrained_model', type=str, default='model.bin', help='Path to the pretrained model')

# Optional .npy caches of previously computed embeddings (None -> recompute).
parser.add_argument('--sia_train_embedding', type=str, help='Optional, path to SIA training embeddings')
parser.add_argument('--sia_valid_embedding', type=str, help='Optional, Path to SIA validation embeddings')
parser.add_argument('--sia_test_embedding', type=str, help='Optional, Path to SIA test embeddings')
parser.add_argument('--quad_train_embedding', type=str, help='Optional, Path to QUAD training embeddings')
parser.add_argument('--quad_test_embedding', type=str, help='Optional, Path to QUAD test embeddings')

# Inference batch size and the percentiles used for the distance thresholds.
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--t_in', type=int, default=95, help='t_in parameter')
parser.add_argument('--t_out', type=int, default=98, help='t_out parameter')

# Flattened (flag, value) pairs actually parsed by the notebook.
default_args_list = [
    token
    for option_pair in (
        ('--sia_path', 'veweew/inputReflector/model-l1kr6pgx:v6'),
        ('--quad_path', 'veweew/inputReflector/model-3v67c2u7:v1'),
        ('--valid_set', 'valid.jsonl'),
        ('--train_set', 'train.jsonl'),
        ('--test_set', 'test.jsonl'),
        ('--pretrained_model', 'model.bin'),
        ('--batch_size', '32'),
        ('--t_in', '95'),
        ('--t_out', '98'),
    )
    for token in option_pair
]

args = parser.parse_args(default_args_list)

In [None]:
# Resolve the SIA checkpoint: a local `.ckpt` path is used as-is, anything else is
# treated as a wandb artifact ID and downloaded. str() keeps the check working when
# the cell is re-run after args.sia_path has already been replaced by a Path.
if not str(args.sia_path).endswith('.ckpt'):
    # Reuse an already-active wandb run instead of starting a new one on every call.
    run = wandb.run if wandb.run is not None else wandb.init()
    artifact = run.use_artifact(args.sia_path, type='model')
    artifact_dir = artifact.download()
    args.sia_path = Path(artifact_dir) / "model.ckpt"

tokenizer, feature_extractor = load_feature_extractor(args.pretrained_model)
block_size = tokenizer.max_len_single_sentence

# LightningModule.load_from_checkpoint is a classmethod that *returns* a new model;
# calling it on an instance and discarding the result never loads the trained
# weights (newer Lightning versions reject the instance call outright). Load the
# checkpoint's state_dict into the instance instead.
# deepcopy: SiameseModel presumably registers feature_extractor as a submodule, so
# loading the checkpoint into it would overwrite the shared pretrained classifier
# that get_results uses later — TODO confirm against Siamese.py.
model = SiameseModel(copy.deepcopy(feature_extractor))
# weights_only=False: Lightning checkpoints also pickle non-tensor metadata
# (hyperparameters, loop/callback state). Only load checkpoints you trust.
model.load_state_dict(
    torch.load(args.sia_path, map_location='cpu', weights_only=False)['state_dict']
)
model.to('cuda')
model.eval()
print("Sia Model loaded!")

In [None]:
# Tokenised splits and their loaders. shuffle=False (the DataLoader default) is spelled
# out because embedding row i must line up with dataset item i: nearest-neighbour
# indices computed from the embeddings are used to index train_set directly.
test_set = TextDataset(tokenizer, block_size, args.test_set)
test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False)

train_set = TextDataset(tokenizer, block_size, args.train_set)
train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False)

valid_set = TextDataset(tokenizer, block_size, args.valid_set)
valid_loader = DataLoader(valid_set, batch_size=args.batch_size, shuffle=False)

In [None]:
# SIA embeddings for every split. Each split is read from its optional
# --sia_*_embedding .npy cache, or computed with the SIA model (no gradients) and
# written to sia_<split>_embeddings.npy so a later run can pass it back in.
sia_embedding_vectors = {}
for split, loader, cache_path in (
    ('train', train_loader, args.sia_train_embedding),
    ('valid', valid_loader, args.sia_valid_embedding),
    ('test', test_loader, args.sia_test_embedding),
):
    if cache_path:
        with open(cache_path, 'rb') as f:
            vectors = np.load(f)
    else:
        batch_outputs = []
        with torch.no_grad():
            for batch in tqdm(loader, desc=f'SIA {split} embeddings'):
                batch_outputs.append(model(batch[0]).cpu())
        vectors = torch.cat(batch_outputs, dim=0).detach().numpy()
        with open(f'sia_{split}_embeddings.npy', 'wb') as f:
            np.save(f, vectors)
    # Row i must correspond to dataset item i (nearest-neighbour indices are later
    # used to index train_set), so reject a cache built from a different dataset.
    if len(vectors) != len(loader.dataset):
        raise ValueError(
            f'SIA {split} embeddings have {len(vectors)} rows but the {split} set has '
            f'{len(loader.dataset)} examples; was the cache built from another dataset?'
        )
    sia_embedding_vectors[split] = vectors

train_embedding_vectors = sia_embedding_vectors['train']
valid_embedding_vectors = sia_embedding_vectors['valid']
test_embedding_vectors = sia_embedding_vectors['test']

In [None]:
def calc_dist(x, trains, show_progress=True):
    """Find the nearest row of ``trains`` for every row of ``x`` (Euclidean distance).

    Args:
        x: array-like of shape (n, d) with the query embeddings.
        trains: array-like of shape (m, d) with the reference embeddings; must have
            at least one row.
        show_progress: wrap the per-row loop in a tqdm progress bar (default True,
            as before).

    Returns:
        (distances, index): ``distances`` is a float64 array of shape (n,) with the
        distance from each query row to its nearest reference row; ``index`` is a
        list of n ints giving that reference row's position in ``trains``. Ties
        resolve to the lowest index.

    Raises:
        ValueError: if ``trains`` has no rows.
    """
    x = np.asarray(x)
    trains = np.asarray(trains)
    if trains.shape[0] == 0:
        raise ValueError('calc_dist needs at least one reference row in `trains`')

    distances = np.empty(shape=(x.shape[0],))
    index = []
    rows = range(x.shape[0])
    if show_progress:
        rows = tqdm(rows)
    # One query row at a time keeps peak memory at O(m) instead of an (n, m) matrix.
    for i in rows:
        row_dists = np.linalg.norm(trains - x[i], axis=1)
        # argmin is O(m); sorting the whole row just to read element 0 was O(m log m).
        nearest = int(np.argmin(row_dists))
        distances[i] = row_dists[nearest]
        index.append(nearest)

    return distances, index

In [None]:
# Distance from every validation input to its nearest training example in SIA
# embedding space; the neighbour indices are not needed here.
val_distance, _ = calc_dist(x=valid_embedding_vectors, trains=train_embedding_vectors)

In [None]:
# Percentile thresholds on the validation nearest-neighbour distances.
# t_in must not exceed t_out: otherwise a test distance lying between the two
# thresholds would be labelled both in- and out-of-distribution by the analyzer.
if not 0 <= args.t_in <= args.t_out <= 100:
    raise ValueError(
        f'expected 0 <= t_in <= t_out <= 100, got t_in={args.t_in}, t_out={args.t_out}'
    )
# Test inputs farther than this from their nearest training example are out-of-distribution.
threshold_out = np.percentile(val_distance, args.t_out)
# Test inputs closer than this to their nearest training example are in-distribution.
threshold_in = np.percentile(val_distance, args.t_in)

In [None]:
# Distance from every test input to its nearest training example in SIA embedding
# space; compared against the validation-derived thresholds below.
test_distance, _ = calc_dist(x=test_embedding_vectors, trains=train_embedding_vectors)

In [None]:
# Distribution Analyzer labels: boolean masks over the test set derived from each
# input's nearest-neighbour distance.
out_of_distribution_examples = np.greater(test_distance, threshold_out)
in_distribution_examples = np.less(test_distance, threshold_in)
# Inputs that are neither clearly in- nor out-of-distribution are "deviating".
deviating_examples = np.logical_not(out_of_distribution_examples) & np.logical_not(in_distribution_examples)

In [None]:
# Predictions of the model loaded from --pretrained_model on the test set, and
# (presumably) the gold labels.
# NOTE(review): the third argument is assumed to be get_results' batch size; it was
# hard-coded to 32, the --batch_size default — TODO confirm against utils.get_results.
pred, labels = get_results(feature_extractor, test_set, args.batch_size)

In [None]:
# Share of misclassified test inputs that the Distribution Analyzer did NOT label
# in-distribution (i.e. labelled deviating or out-of-distribution).
# Evaluates to nan when the classifier makes no mistakes (0 / 0).
mispredicted = pred != labels
assert len(mispredicted) == len(in_distribution_examples), 'predictions and analyzer labels are misaligned'
# torch.tensor(...) of a NumPy mask always lands on the CPU; match the predictions'
# device so `&` also works if get_results returns CUDA tensors.
flagged = torch.as_tensor(~in_distribution_examples, device=mispredicted.device)
misprediction_coverage = (flagged & mispredicted).sum() / mispredicted.sum()
misprediction_coverage.item()

In [None]:
# Resolve the QUAD checkpoint: a local `.ckpt` path is used as-is, anything else is
# treated as a wandb artifact ID and downloaded. str() keeps the check working when
# the cell is re-run after args.quad_path has already been replaced by a Path.
if not str(args.quad_path).endswith('.ckpt'):
    # Reuse an already-active wandb run (e.g. the one opened for the SIA artifact).
    run = wandb.run if wandb.run is not None else wandb.init()
    artifact = run.use_artifact(args.quad_path, type='model')
    artifact_dir = artifact.download()
    args.quad_path = Path(artifact_dir) / "model.ckpt"

# LightningModule.load_from_checkpoint is a classmethod that *returns* a new model;
# calling it on an instance and discarding the result never loads the trained
# weights. Load the checkpoint's state_dict into the instance instead.
# deepcopy: QuadrupletModel presumably registers feature_extractor as a submodule, so
# loading the checkpoint into it would overwrite the shared pretrained classifier
# used by get_results — TODO confirm against Siamese.py.
model = QuadrupletModel(copy.deepcopy(feature_extractor))
# weights_only=False: Lightning checkpoints also pickle non-tensor metadata
# (hyperparameters, loop/callback state). Only load checkpoints you trust.
model.load_state_dict(
    torch.load(args.quad_path, map_location='cpu', weights_only=False)['state_dict']
)
model.to('cuda')
model.eval()
print('Quad Model Loaded!!')

In [None]:
# Now revise the output for deviating examples. First embed the train and test
# splits with the QUAD model: each split is read from its optional
# --quad_*_embedding .npy cache, or computed (no gradients) and written to
# quad_<split>_embeddings.npy.
# NOTE: this rebinds train_embedding_vectors / test_embedding_vectors, replacing the
# SIA embeddings; re-running the SIA distance cells after this one would silently
# use QUAD embeddings.
quad_embedding_vectors = {}
for split, loader, cache_path in (
    ('train', train_loader, args.quad_train_embedding),
    ('test', test_loader, args.quad_test_embedding),
):
    if cache_path:
        with open(cache_path, 'rb') as f:
            vectors = np.load(f)
    else:
        batch_outputs = []
        with torch.no_grad():
            for batch in tqdm(loader, desc=f'QUAD {split} embeddings'):
                batch_outputs.append(model(batch[0]).cpu())
        vectors = torch.cat(batch_outputs, dim=0).detach().numpy()
        with open(f'quad_{split}_embeddings.npy', 'wb') as f:
            np.save(f, vectors)
    # Row i must correspond to dataset item i (nearest-neighbour indices are used to
    # index train_set when revising predictions), so reject mismatched caches.
    if len(vectors) != len(loader.dataset):
        raise ValueError(
            f'QUAD {split} embeddings have {len(vectors)} rows but the {split} set has '
            f'{len(loader.dataset)} examples; was the cache built from another dataset?'
        )
    quad_embedding_vectors[split] = vectors

train_embedding_vectors = quad_embedding_vectors['train']
test_embedding_vectors = quad_embedding_vectors['test']

In [None]:
# For every test input: distance to, and position of, its nearest training example
# in QUAD embedding space.
dis, idx = calc_dist(x=test_embedding_vectors, trains=train_embedding_vectors)

In [None]:
# InputReflector revision: every deviating test input takes element [1] (presumably
# the label) of its nearest training example in QUAD space, located via idx.
assert len(idx) == len(deviating_examples), 'nearest-neighbour indices and analyzer labels are misaligned'
revised_prediction = pred.clone()
revised = [train_set[idx[i]][1] for i in np.flatnonzero(deviating_examples)]

#Revised Prediction by InputReflector
# torch.stack raises on an empty list, so only write back when something deviates.
if revised:
    revised_labels = torch.stack([torch.as_tensor(label) for label in revised], dim=0)
    deviating_mask = torch.as_tensor(deviating_examples, device=revised_prediction.device)
    # index_put requires matching dtype/device between destination and values.
    revised_prediction[deviating_mask] = revised_labels.to(
        device=revised_prediction.device, dtype=revised_prediction.dtype
    )