In [None]:
import numpy as np
import torch
import pandas as pd
import plotly.express as px

from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline
from datasets import load_dataset

from tqdm.notebook import trange, tqdm


In [None]:
# Hugging Face Hub checkpoint used for both the tokenizer and the classifier.
MODEL_NAME = "fabriceyhc/bert-base-uncased-imdb"

# The random projections in D() and the probe point in the convergence check are
# drawn from NumPy's global RNG; seed it so Restart & Run All reproduces the scores.
SEED = 0
np.random.seed(SEED)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

In [None]:
# IN-distribution corpus: IMDB movie reviews (train split builds the reference set).
dataset = load_dataset(path="imdb")

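## Generate IN distribution

Embed every IMDB training review as the average, over the 12 BERT encoder layers, of the first-token (`[CLS]`) hidden state. The resulting matrix `distrib` is the reference (IN) distribution.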

In [None]:
def _make_features_hook(layer_idx):
    """Return a forward hook that stores a layer's output in the global ``feat_<layer_idx>``.

    ``get_lattent_representation`` reads ``feat_0``, ``feat_1``, ... after each
    forward pass, so the value is stored under that exact global name.
    """
    def features_hook(module, inp, output):
        # Write into this notebook's global namespace so the captured outputs can be
        # read by name (feat_0, feat_1, ...) from other cells and functions.
        globals()[f"feat_{layer_idx}"] = output
    return features_hook


# Re-running this cell would otherwise stack another set of hooks on every layer
# (each forward pass would then run every hook several times); remove the handles
# registered by a previous run first.
for _old_handle in globals().get("feat_hook", []):
    _old_handle.remove()

encoder_layers = model.base_model.encoder.layer
features_hooks = [_make_features_hook(i) for i in range(len(encoder_layers))]
feat_hook = [
    layer.register_forward_hook(hook)
    for layer, hook in zip(encoder_layers, features_hooks)
]

In [None]:
def get_lattent_representation(input_data, model, n_layers = 12):
    """Embed one text as the mean first-token hidden state across the hooked encoder layers.

    Runs a single forward pass of ``model`` on ``input_data`` (tokenised with the
    global ``tokenizer``). The forward hooks registered on
    ``model.base_model.encoder.layer`` store each layer's output in the globals
    ``feat_0`` ... ``feat_<n_layers - 1>``, which are read back here.

    Parameters
    ----------
    input_data : str
        Text to embed.
    model : transformers model
        The model carrying the feature hooks.
    n_layers : int, default 12
        Number of hooked layers (``feat_*`` globals) to average.

    Returns
    -------
    torch.Tensor
        1-D CPU tensor: the mean over layers of the hidden state at token position 0
        (the ``[CLS]`` token for BERT tokenizers) of the first sequence in the batch.
    """
    # Tokenise directly instead of building a TextClassificationPipeline on every call,
    # and truncate to the tokenizer's maximum length: BERT-style models cannot embed
    # more tokens than their position-embedding table holds, so long reviews would raise.
    inputs = tokenizer(input_data, return_tensors="pt", truncation=True).to(model.device)
    with torch.no_grad():
        model(**inputs)  # output unused; the forward hooks capture the layer outputs

    # Each feat_* holds the layer's output tuple; element 0 is indexed as
    # (batch, token, hidden) below.
    feats = [globals()[f"feat_{layer}"][0] for layer in range(n_layers)]
    aggregated_features = torch.stack([feat[0, 0, :] for feat in feats]).mean(dim=0)

    # Move to CPU so callers can always call .numpy() on the result.
    return aggregated_features.cpu()


In [6]:
distrib = []
fail = []
# Embed every IMDB training review; tqdm already reports progress.
for i, example in enumerate(tqdm(dataset['train'])):
    try:
        distrib.append(get_lattent_representation(example['text'], model))
    except Exception:
        # Keep going past a bad review but remember its index. Unlike a bare
        # ``except:``, this still lets KeyboardInterrupt stop the loop.
        fail.append(i)

def process_distrib(distrib):
    """Stack a list of feature tensors into a 2-D NumPy array, one flattened row per tensor."""
    rows = []
    for tensor in distrib:
        rows.append(tensor.numpy().flatten())
    return np.vstack(rows)

distrib = process_distrib(distrib)
# Every training review is either embedded (one row of distrib) or recorded in `fail`.
assert distrib.shape[0] + len(fail) == len(dataset['train'])
pd.DataFrame(distrib).to_csv('distrib_.csv', index = False)

# To reuse a previous run, reload the matrix saved above. It is written with
# index=False, so every column is a feature — do not drop the first column:
# distrib = pd.read_csv('distrib_.csv').to_numpy()
# distrib.shape

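## Functions for distribution-input distance

`D(x, distrib)` averages, over `n_proj` random unit directions, the smaller fraction of `distrib` lying on either side of the hyperplane through `x`:

$$D(x) = \frac{1}{K}\sum_{k=1}^{K}\min\bigl(p_k,\,1-p_k\bigr),\qquad p_k = \frac{1}{n}\sum_{i=1}^{n}\mathbf{1}\bigl[u_k^\top(z_i - x) > 0\bigr],$$

where $u_1,\dots,u_K$ are random unit vectors ($K$ = `n_proj`) and $z_1,\dots,z_n$ are the rows of `distrib`. It is a Monte Carlo data-depth score in $[0, 0.5]$: values near $0.5$ mean `x` is central to the IN distribution, values near $0$ mean it lies outside it. Despite the "distance" naming below, larger values therefore mean *closer*, not farther.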

In [None]:
def generate_sphere_point(ndim):
    """Draw a random direction uniformly from the unit sphere in ``ndim`` dimensions.

    Normalising an isotropic Gaussian vector yields a uniformly distributed
    direction. Uses NumPy's global RNG, so results follow ``np.random.seed``.
    """
    gaussian = np.random.randn(ndim)
    return gaussian / np.linalg.norm(gaussian, axis=0)

def compute_minimum_value(x, distrib, u_k):
    """Smaller fraction of ``distrib`` on either side of the hyperplane through ``x`` normal to ``u_k``.

    Parameters
    ----------
    x : np.ndarray
        Query point (a 1-D vector in the notebook's usage).
    distrib : array-like
        Reference sample: a 2-D array, or a sequence of vectors, with rows shaped like ``x``.
    u_k : np.ndarray
        Projection direction.

    Returns
    -------
    float
        ``min(p, 1 - p)`` where ``p`` is the fraction of rows ``r`` with
        ``u_k . (r - x) > 0``. Ranges from 0 (every row on the same side of ``x``)
        to 0.5 (rows split evenly on both sides).
    """
    # One matrix-vector product instead of a Python-level np.dot per row:
    # u_k . (r - x) == u_k . r - u_k . x for every row r.
    projections = np.asarray(distrib) @ u_k - np.dot(x, u_k)
    positive_rate = np.mean(projections > 0)
    return min(positive_rate, 1 - positive_rate)

def D(x, distrib, n_proj = 10):
    """Monte Carlo depth score of ``x`` with respect to the sample ``distrib``.

    Draws ``n_proj`` random unit directions and averages, over them, the smaller
    fraction of ``distrib`` lying on either side of the hyperplane through ``x``
    (see ``compute_minimum_value``). Larger values mean ``x`` is more central.
    """
    dim = x.shape[0]
    minima = []
    for _ in range(n_proj):
        # Directions are drawn in the same order as before; scoring uses no randomness.
        direction = generate_sphere_point(dim)
        minima.append(compute_minimum_value(x, distrib, direction))
    return np.mean(minima)

## Importing OUT Data

In [None]:
# OUT-of-distribution corpus: SST-2 sentences (a different sentiment dataset).
out_dataset = load_dataset(path="sst2")

## Computing OUT distances 

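#### Analysing D convergence

How many random projections does `D` need? Compute 1000 per-projection terms for one probe point and plot their running mean; the curve should flatten once `n_proj` is large enough.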

In [None]:
def D_(x, distrib, n_proj = 10):
    """Per-projection terms of ``D``: same random directions, but returns the list instead of its mean."""
    u = [generate_sphere_point(x.shape[0]) for _ in range(n_proj)]
    return [compute_minimum_value(x, distrib, u_k) for u_k in u]


# Probe point: uniform noise with the same dimensionality as the IN features.
# Use distrib's width rather than x.shape: x is only defined in later cells,
# so x.shape raises NameError on Restart & Run All.
x = np.random.random(distrib.shape[1])
a = D_(x, distrib, 1000)

# Running mean of the first k terms for k = 1..len(a), in O(n).
running_mean = np.cumsum(a) / np.arange(1, len(a) + 1)
px.line(
    running_mean,
    template = 'none',
    title = 'Convergence of D with the number of projections',
    labels = {'index': 'number of projections', 'value': 'running mean of D'},
)

#### Computing

In [None]:
# Score the first 1000 SST-2 test sentences against the IN distribution.
out_distances = []
for idx in tqdm(range(1000)):
    sentence = out_dataset['test'][idx]['sentence']
    x = get_lattent_representation(sentence, model)
    out_distances.append(D(x.numpy(), distrib))

## Computing IN distances

In [None]:
in_distances = []
in_fail = []

# Score the first 1000 IMDB test reviews; tqdm already reports progress.
# NOTE(review): if the test split is ordered by label, these 1000 reviews all share
# one class — consider a seeded shuffle before selecting.
for i in tqdm(range(1000)):
    try:
        x = get_lattent_representation(dataset['test'][i]['text'], model)
        in_distances.append(D(x.numpy(), distrib))
    except Exception:
        # Record the index and move on. Unlike a bare ``except:``, this still lets
        # KeyboardInterrupt stop the loop.
        in_fail.append(i)

## Results

In [None]:
# Trim the OUT scores to the number of IN scores so both histogram groups have equal size.
del out_distances[len(in_distances):]

In [None]:
# Overlay (rather than the default stacking) so the two score distributions and
# their overlap can be compared directly.
px.histogram(
    pd.DataFrame({
        'score': out_distances + in_distances,
        'split': ['out'] * len(out_distances) + ['in'] * len(in_distances),
    }),
    x = 'score',
    color = 'split',
    barmode = 'overlay',
    opacity = 0.6,
    template = 'none',
    title = 'D(x, distrib): IMDB test (in) vs SST-2 test (out)',
    labels = {'score': 'D(x, distrib) — higher = more central to IN', 'split': 'split'},
)

In [None]:
# Probe a single OUT sentence: how D changes as the number of projections grows.
x = get_lattent_representation(
    out_dataset['test'][0]['sentence'],
    model,
)

distances = [D(x.numpy(), distrib, n_proj=n) for n in tqdm(range(1, 100))]