# TCAV with BERT

In [0]:
! nvidia-smi

Thu Nov 28 13:19:06 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.33.01    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   45C    P0    29W / 250W |      0MiB / 16280MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|  No ru

#### Installing dependencies

In [0]:
! pip install -q nltk
! pip install -q transformers==2.1.1
! pip install -q gensim

[K     |████████████████████████████████| 317kB 2.8MB/s 
[K     |████████████████████████████████| 1.0MB 43.4MB/s 
[K     |████████████████████████████████| 645kB 44.8MB/s 
[K     |████████████████████████████████| 860kB 47.6MB/s 
[?25h  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone


#### Imports

In [0]:
import pandas as pd
import torch
import gensim
import uuid
from gensim.parsing.preprocessing import STOPWORDS
from nltk.stem import WordNetLemmatizer, SnowballStemmer
import numpy as np
import nltk
import torch.nn as nn
from sklearn.model_selection import train_test_split
from google.colab import files
from transformers import *
import torch
from google.colab import auth
import transformers

nltk.download('wordnet')

np.random.seed(42)

%load_ext autoreload
%autoreload 2

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.


#### Mounting Google Drive

In [0]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


#### Config

In [0]:
#TRAIN_PATH = 'imdb_reviews/train.csv'
TRAIN_PATH = '/content/drive/My Drive/train.csv'
DEVICE = 'cuda:0'

#TEST_PATH = 'imdb_reviews/test.csv'
TEST_PATH = '/content/drive/My Drive/test.csv'

DEBUG = False
VALIDATION = True  # Validation during fune-tuning XLNet

#### BERT

In [0]:
CACHE_DIR = 'cache'

class BERTClassifier(nn.Module):
    def __init__(self, pretrained_weights):
      super(BERT, self).__init__()
      self.bert_classifier = BertForSequenceClassification.from_pretrained(pretrained_weights, 
                                                                           cache_dir=CACHE_DIR)
      self.representation = None
      self.bert_classifier.bert.register_forward_hook(self.hook_fn)
      self.bert_classifier.requires_grad_(True)
    
    def hook_fn(self, module, inupt, output):
      self.representation = output

    def forward(self, input_ids: torch.Tensor, labels=None):
      if labels is not None:
        loss, logits = self.bert_classifier(input_ids=input_ids, labels=labels)
      else:
        out = self.bert_classifier(input_ids=input_ids)
        logits = out[0]

      preds = torch.argmax(logits, dim=-1)  # (batch_size, )

      if labels is None:
        return logits, preds, self.representation
      else:
        loss = nn.functional.cross_entropy(logits, labels)
        return loss, logits, preds, self.representation

    # def forward_from_representation(self, representation):
    #   logits = self.bert.classifier(representation)  # (batch_size, num_labels)
    #   preds = torch.argmax(logits, dim=-1)  # (batch_size, )
    #   return preds, logits

#### XLNet

In [0]:
CACHE_DIR = 'cache'

class XLNetClassifier(nn.Module):
  def __init__(self, pretrained_weights):
    super(XLNetClassifier, self).__init__()
    self.xlnet_classifier = XLNetForSequenceClassification.from_pretrained(pretrained_weights, 
                                                                           cache_dir=CACHE_DIR)
  
    self.grad_representation = None
    self.representation = None
    
    #self.xlnet_classifier.sequence_summary.register_forward_hook(self.forward_hook_fn)
    #self.xlnet_classifier.sequence_summary.register_backward_hook(self.backward_hook_fn)

    # using the representation of layer12 in the transformer
    for name, module in self.xlnet_classifier.transformer.named_modules(): 
      if name == 'layer.12':
        module.register_forward_hook(self.forward_hook_fn)
        module.register_backward_hook(self.backward_hook_fn)

    self.xlnet_classifier.requires_grad_(True)
  
  def forward_hook_fn(self, module, input, output):
    self.representation = output

  def backward_hook_fn(self, module, grad_input, grad_output):
    self.grad_representation = grad_output[0]

  def forward(self, input_ids: torch.Tensor, labels=None):
    if labels is not None:
      loss, logits = self.xlnet_classifier(input_ids=input_ids, labels=labels)
    else:
      out = self.xlnet_classifier(input_ids=input_ids)
      logits = out[0]

    preds = torch.argmax(logits, dim=-1)  # (batch_size, )

    if labels is None:
      return logits, preds, self.representation
    else:
      loss = nn.functional.cross_entropy(logits, labels)
      return loss, logits, preds, self.representation

  def forward_from_representation(self, representation: torch.Tensor):
    #classifier = nn.Sequential(self.xlnet_classifier.sequence_summary, self.xlnet_classifier.logits_proj)
    #logits = classifier(representation)  # (batch_size, num_labels)
    
    preds = torch.argmax(logits, dim=-1)  # (batch_size, )
    return logits

#### Creating an instance of the model

In [0]:
#@title Model
model = "xlnet" #@param ["bert", "xlnet"]
model_type = "large" #@param ["large", "base"]


if model == 'xlnet':
  if model_type == 'large':
    pretrained_weights = 'xlnet-large-cased'
  elif model_type == 'base':
    pretrained_weights = 'xlnet-base-cased'
  model = XLNetClassifier(pretrained_weights)
  tokenizer = XLNetTokenizer.from_pretrained(pretrained_weights)
elif model == 'bert':
  if model_type == 'large':
    pretrained_weights = 'bert-large-cased'
  elif model_type == 'base':
    pretrained_weights = 'bert-base-cased'
  model = BERTClassifier(pretrained_weights)
  tokenizer = BERTTokenizer.from_pretrained(pretrained_weights)

model.to(DEVICE)

#### Preprocessing training set

In [0]:
train_df = pd.read_csv(TRAIN_PATH)

if DEBUG:
  train_df = train_df[:50]

# X, y
X = train_df['sentence'].tolist()
y = train_df['polarity'].tolist()

# training and validation set for fine-tuning BERT
idxs = np.arange(len(X))

if VALIDATION:
  train_idxs, val_idxs = train_test_split(idxs, train_size=0.9)
  X_train = [X[idx] for idx in train_idxs]
  y_train = [y[idx] for idx in train_idxs]
  X_val = [X[idx] for idx in val_idxs]
  y_val = [y[idx] for idx in val_idxs]
else:
  train_idxs = idxs
  X_train = [X[idx] for idx in train_idxs]
  y_train = [y[idx] for idx in train_idxs]

In [0]:
print(len(X_train))

22500


#### Utility functions

In [0]:
def pad(sents, pad_token=0):
  sents_padded = []
  longest = max([len(sent) for sent in sents])
  sents_padded = list(map(lambda sent: sent+[pad_token]*(longest-len(sent)), sents))

  return sents_padded

In [0]:
def batch_iter(X, y, tokenizer, batch_size):
  assert len(X) == len(y)

  idxs = np.arange(len(X))
  np.random.shuffle(idxs)

  for i in range(0, len(X), batch_size):
    X_batch = [X[idx] for idx in idxs[i:i+batch_size]]
    y_batch = [y[idx] for idx in idxs[i:i+batch_size]]
      
    X_batch = [tokenizer.encode(sentence[:128], add_special_tokens=True) for sentence in X_batch]
    X_batch = torch.tensor(pad(X_batch))
    y_batch = torch.tensor(y_batch)
      
    yield X_batch, y_batch

### Accuracy on validation set w.o. fine-tuning

In [0]:
n_samples = n_correct = 0

for X, y in batch_iter(X_train, y_train, tokenizer, batch_size=8):
  n_samples += len(X)
  _, preds, _ = model(X)
  n_correct += torch.sum(preds == y).item()

val_acc = n_correct / n_samples
print('val acc: %f' % val_acc)

### Fine-tuning on dataset

##### Training config

In [0]:
learning_rate = 2e-5 #@param {type:"number"}
use_scheduler = False #@param {type:"boolean"}
warmup_steps = 500 #@param {type:"integer"}
t_total = 0 #@param {type:"integer"}

optimizer = AdamW(model.parameters(), lr=learning_rate)

if use_scheduler:
  scheduler = WarmupLinearSchedule(optimizer=optimizer, warmup_steps=warmup_steps, t_total=10000)

##### Main training loop

In [0]:
#@title Train (fine-tuning)
n_epochs = 20 #@param {type:"integer"}
batch_size = 8 #@param {type:"integer"}
val_every = 500 #@param {type:"integer"}
def fine_tune(X_train, y_train, X_val=None, y_val=None, n_epochs=3, batch_size=32, val_every=100):
  print('Running with n_epochs %d, batch_size %d, val_every %d' % (n_epochs, batch_size, val_every))
  iteration = 0
  val_accs = []
  running_loss, running_num_iter = 0., 0
  print(optimizer.param_groups[0]['lr'])
  for i in range(n_epochs):
    epoch_loss = 0.

    for X, y in batch_iter(X_train, y_train, tokenizer, batch_size=batch_size):
      iteration += 1
      running_num_iter += 1

      if iteration % 50 == 0:
        print('iteration', iteration)
        print(optimizer.param_groups[0]['lr'])

      X, y = X.to(DEVICE), y.to(DEVICE)
      optimizer.zero_grad()
      loss, _, _, _ = model(X, y)
      loss.backward()
      optimizer.step()
          
      epoch_loss += loss.item()
      running_loss += loss.item()

      if iteration % val_every == 0 and VALIDATION:
        model.eval()
        print('loss:', running_loss / running_num_iter)
              
        running_loss = 0.
        running_num_iter = 0

        print('begin validation...')
        n_correct = n_samples = 0
        for X, y in batch_iter(X_val, y_val, tokenizer, batch_size=batch_size):
          n_samples += len(X)
          X = X.to(DEVICE)
          y = y.to(DEVICE)
          _, preds, _ = model(X)
          n_correct += torch.sum(preds == y).item()
              
        val_acc = n_correct / n_samples
        print('val acc: %f' % val_acc)
              
        if len(val_accs) == 0 or max(val_accs) < val_acc:
          print('Saving a checkpoint...')
          if use_scheduler:
            params = {'model_state_dict': model.state_dict(), 
                      'optimizer_state_dict': optimizer.state_dict(), 
                      'scheduler_state_dict': scheduler.state_dict}
          else:
            params = {'model_state_dict': model.state_dict(), 
                      'optimizer_state_dict': optimizer.state_dict()}
          
          torch.save(params, 'checkpoint_{}.ckpt'.format(iteration))
              
        val_accs.append(val_acc)
        model.train(True)
      
      if use_scheduler:
        scheduler.step()

    print('epoch loss: %f' % epoch_loss)

if VALIDATION:
    if DEBUG:
      fine_tune(X_train, y_train, X_val, y_val, 
                n_epochs=5, batch_size=8, val_every=5)
    else:
      fine_tune(X_train, y_train, X_val, y_val, 
                n_epochs=n_epochs, batch_size=batch_size, val_every=val_every)
else:
    if DEBUG:
      fine_tune(X_train, y_train, X_val=None, y_val=None, 
                n_epochs=5, batch_size=8, val_every=5)
    else:
      fine_tune(X_train, y_train, X_val=None, y_val=None, 
                n_epochs=n_epochs, batch_size=batch_size, val_every=val_every)

In [0]:
#@title Train (fine-tune) from a checkpoint
checkpoint = 1500 #@param {type:"integer"}
epochs = 4 #@param {type:"integer"}
learning_rate = 5e-5 #@param {type:"number"}

def train_from_checkpoint(checkpoint, X_train, y_train, X_val, y_val, n_epochs, batch_size, val_every, learning_rate=1e-4):
  model = BERT(num_labels=2)
  print('Loading from the checkpoint %d' % checkpoint)
  
  params = torch.load('checkpoint_{}.ckpt'.format(checkpoint))
  model.linear.load_state_dict(params['linear_state_dict'])
  optimizer = AdamW(model.parameters())
  optimizer.load_state_dict(params['optimizer_state_dict'])
  
  for param_group in optimizer.param_groups:
    param_group['lr'] = learning_rate
  
  print('n_epochs: %d, learning rate: %f' % (n_epochs, learning_rate))

  iteration = checkpoint
  val_accs = []
  running_loss, running_num_iter = 0., 0

  for i in range(n_epochs):
    epoch_loss = 0.
    for X, y in batch_iter(X_train, y_train, tokenizer, batch_size=batch_size):
      iteration += 1
      running_num_iter += 1

      if iteration % 10 == 0:
        print('iteration', iteration)

      optimizer.zero_grad()
      loss, _, _ = model(X, y)
      loss.backward()
      optimizer.step()
      
      epoch_loss += loss.item()
      running_loss += loss.item()

      if iteration % val_every == 0 and VALIDATION:
        print('loss:', running_loss / running_num_iter)
        
        running_loss = 0.
        running_num_iter = 0

        print('begin validation...')
        n_correct = 0
        n_samples = 0
        for X, y in batch_iter(X_val, y_val, tokenizer, batch_size=batch_size):
            n_samples += len(X)
            _, preds, _ = model(X)
            n_correct += torch.sum(preds == y).item()
        
        val_acc = n_correct / n_samples
        print('val acc: %f' % val_acc)
        
        if len(val_accs) == 0 or max(val_accs) < val_acc:
            print('Saving checkpoint...')

            params = {
                'linear_state_dict': model.linear.state_dict(), 
                'optimizer_state_dict': optimizer.state_dict()
                }
            torch.save(params, 'checkpoint_{}.bin'.format(iteration))
            #files.download('checkpoint_{}.bin'.format(iteration))
        val_accs.append(val_acc)
    
    print('epoch loss: %f' % epoch_loss)


if VALIDATION:
    if DEBUG:
      train_from_checkpoint(checkpoint, X_train, y_train, X_val, y_val, n_epochs=epochs, batch_size=16, val_every=5, learning_rate=learning_rate)
    else:
      train_from_checkpoint(checkpoint, X_train, y_train, X_val, y_val, n_epochs=epochs, batch_size=32, val_every=100, learning_rate=learning_rate)
else:
    if DEBUG:
      train_from_checkpoint(checkpoint, X_train, y_train, X_val=None, y_val=None, n_epochs=epochs, batch_size=16, val_every=5, learning_rate=learning_rate)
    else:
      train_from_checkpoint(checkpoint, X_train, y_train, X_val=None, y_val=None, n_epochs=epochs, batch_size=32, val_every=100, learning_rate=learning_rate)

## Load a pre-trained model

In [0]:
#@title Load from cloud storage
checkpoint_name = "checkpoint_13500_xlnet_large_scheduler.ckpt" #@param {type:"string"}
project_id = 'train-on-tpu'
bucket_name = 'mreza-tpu-bucket1'
auth.authenticate_user()
! gcloud config set project {project_id}
! gsutil cp gs://{bucket_name}/{checkpoint_name} /content/

params = torch.load(checkpoint_name)
model = XLNetClassifier('xlnet-large-cased')
model.load_state_dict(params['model_state_dict'])
model.to(DEVICE)
del params

Updated property [core/project].
Copying gs://mreza-tpu-bucket1/checkpoint_13500_xlnet_large_scheduler.ckpt...
- [1 files][  4.0 GiB/  4.0 GiB]   79.3 MiB/s                                   
Operation completed over 1 objects/4.0 GiB.                                      


100%|██████████| 699/699 [00:00<00:00, 391587.89B/s]
100%|██████████| 1441285815/1441285815 [00:25<00:00, 55556916.97B/s]


Just to check the contents of bucket:

In [0]:
! gsutil ls gs://{bucket_name}/

gs://mreza-tpu-bucket1/checkpoint_13500_xlnet_large_scheduler.ckpt
gs://mreza-tpu-bucket1/checkpoint_20500_xlnet_large_wo_scheduler.ckpt
gs://mreza-tpu-bucket1/checkpoint_9000_bert_base_wo_scheduler.ckpt


In [0]:
#@title Load model from local checkpoint
checkpoint_name = "checkpoint_13500_xlnet_large_scheduler.ckpt" #@param {type:"string"}
print('loading %s' % checkpoint_name)

params = torch.load(checkpoint_name)
model = XLNetClassifier('xlnet-large-cased')
model.load_state_dict(params['model_state_dict'])
model.to(DEVICE)
del params


loading checkpoint_13500_xlnet_large_scheduler.ckpt


Creating an instance of tokenizer:

In [0]:
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')

100%|██████████| 798011/798011 [00:00<00:00, 10178539.44B/s]


### LDA on the test set

Lemmatizing, stemmizing, and removing stopwords:

In [0]:
def lemmatize_stemming(text):
    stemmer = SnowballStemmer("english")
    return stemmer.stem(WordNetLemmatizer().lemmatize(text, pos='v'))

def preprocess(text):
    result = []
    for token in gensim.utils.simple_preprocess(text):
        if token not in STOPWORDS and len(token) > 3:
            result.append(lemmatize_stemming(token))
    return result

def batch_iter(X, y):
    assert len(X) == len(y)
    idxs = np.arange(len(X))
    np.random.shuffle(idxs)

In [0]:
test_df = pd.read_csv(TEST_PATH)

test_df = test_df[:5000]

X_test = test_df['sentence'].tolist()
y_test = test_df['polarity'].tolist()

# pre-process the test set to be used for LDA
test_df_processed = test_df
test_df_processed['sentence'] = test_df['sentence'].apply(preprocess)
X_test_lda = test_df_processed['sentence'].tolist()
y_test_lda = test_df_processed['polarity'].tolist()

print(len(X_test))
print(len(X_test_lda))

5000
5000


In [0]:
dictionary = gensim.corpora.Dictionary(X_test_lda)
print('number of unique tokens:', len(dictionary))

dictionary.filter_extremes(no_below=15, no_above=0.5, keep_n=10000)
print('number of unique tokens after filtering:', len(dictionary))

# bag-of-word corpus
bow_corpus = [dictionary.doc2bow(text) for text in X_test_lda]

number of unique tokens: 23669
number of unique tokens after filtering: 3386


In [0]:
#@title LDA model
alpha = 0.31 #@param {type:"raw"}
eta = 0.91 #@param {type:"raw"}
passes = 10 #@param {type:"integer"}
num_topics = 8 #@param {type:"integer"}
np.random.seed(100)
lda_model = gensim.models.LdaMulticore(corpus=bow_corpus,
                                       id2word=dictionary,
                                       num_topics=num_topics, 
                                       random_state=100,
                                       chunksize=100,
                                       passes=passes,
                                       alpha=alpha,
                                       eta=eta)

In [0]:
from gensim.models import CoherenceModel

# Compute Coherence Score
coherence_model_lda = CoherenceModel(model=lda_model, texts=X_test_lda, dictionary=dictionary, coherence='c_v')
coherence_lda = coherence_model_lda.get_coherence()
print('\nCoherence Score: ', coherence_lda)


Coherence Score:  0.3164130601695836


In [0]:
for idx, topic in lda_model.print_topics(-1):
    print('Topic: {} \nWords: {}'.format(idx, topic))

Topic: 0 
Words: 0.024*"music" + 0.015*"song" + 0.014*"danc" + 0.011*"comedi" + 0.010*"great" + 0.010*"sing" + 0.008*"jack" + 0.008*"star" + 0.007*"good" + 0.007*"best"
Topic: 1 
Words: 0.013*"kill" + 0.013*"horror" + 0.009*"scene" + 0.007*"murder" + 0.006*"get" + 0.006*"dead" + 0.006*"like" + 0.005*"go" + 0.005*"killer" + 0.005*"end"
Topic: 2 
Words: 0.026*"love" + 0.020*"like" + 0.014*"watch" + 0.014*"think" + 0.014*"great" + 0.012*"play" + 0.012*"time" + 0.011*"good" + 0.011*"stori" + 0.010*"see"
Topic: 3 
Words: 0.016*"seri" + 0.012*"episod" + 0.009*"version" + 0.008*"origin" + 0.007*"star" + 0.007*"play" + 0.007*"releas" + 0.006*"anim" + 0.006*"year" + 0.006*"time"
Topic: 4 
Words: 0.011*"life" + 0.009*"peopl" + 0.008*"human" + 0.007*"live" + 0.006*"world" + 0.006*"american" + 0.006*"know" + 0.006*"stori" + 0.005*"real" + 0.005*"time"
Topic: 5 
Words: 0.024*"like" + 0.016*"watch" + 0.015*"good" + 0.013*"think" + 0.012*"time" + 0.010*"look" + 0.010*"act" + 0.009*"peopl" + 0.009*"se

In [0]:
topics = {}
for text_id in range(len(bow_corpus)):
  sorted_topics = sorted(lda_model[bow_corpus[text_id]], reverse=True)
  t = sorted_topics[0]
  topic_index = t[0]
  if topics.get(topic_index, None) is None:
    topics[topic_index] = [text_id]
  else:
    topics[topic_index].append(text_id)

In [0]:
for k, v in topics.items():
  print('%d documents under topic %d' % (len(v), k))

1108 documents under topic 5
3166 documents under topic 7
157 documents under topic 4
435 documents under topic 6
59 documents under topic 2
69 documents under topic 3
6 documents under topic 1


In [0]:
for key, values in topics.items():
  positives = [v for v in values if y_test[v] == 1]
  print('topic %d: percentage of documents with positive sentiment: %.2f' % (key, len(positives) / len(values)))

topic 5: percentage of documents with positive sentiment: 0.28
topic 7: percentage of documents with positive sentiment: 0.56
topic 4: percentage of documents with positive sentiment: 0.85
topic 6: percentage of documents with positive sentiment: 0.42
topic 2: percentage of documents with positive sentiment: 0.83
topic 3: percentage of documents with positive sentiment: 0.90
topic 1: percentage of documents with positive sentiment: 0.17


### Extracting the CAV for a concept

Using a topic with high percentage of positive documents:

In [0]:
topic = 2

concept_examples_idxs = topics[topic]
concept_examples = [X_test[i] for i in concept_examples_idxs]

other_examples_idxs = [i for i in range(len(X_test)) if i not in concept_examples_idxs]
random_examples_idxs = np.random.choice(other_examples_idxs, 100)
random_examples = [X_test[i] for i in random_examples_idxs]

In [0]:
print(len(concept_examples))
print(len(random_examples))

print(concept_examples[1])
print(random_examples[14])

56
100
I thought this movie was really awesome! One of Drew's best. I am also a fan of Michael Vartan so I thought he was so hot in this movie. Why all the bad reviews. I would want to watch this movie over and over again if I could. I also loved the ending. This movie clearly has shown a smile on my face! I was also surprised that James Franco and Jessica Alba were in it. I love them both so I also highlighted this movie. At the end, when Drew is making the huge comment about the truth it really told the truth of what sometimes happens in High School. Again, the movie was amazing. Defiantly a 10/10. Hope this comment was very useful to any IMDb readers.
i just got done watching this movie and i have to say, it was a good film, i loved some of the good guy's and i loved the killer robot but the movie had some hole's in it.<br /><br />the name's of the people in it was kind of..stupid..i think people should of sued the maker's of this movie for how lame it was at the end..the first half

In [0]:
from sklearn import linear_model

def train_cav(concept_examples, random_examples, model):
    batch_size = 8
    
    concept_labels = torch.ones([len(concept_examples)])
    random_labels = torch.zeros([len(random_examples)])
    
    concept_repres = []
    for X, y in batch_iter(concept_examples, concept_labels, tokenizer, 8):
        with torch.no_grad():
          _, _, representation = model(X.to(DEVICE))
          concept_repres.append(representation[0].sum(dim=0).squeeze())

    concept_repres = torch.cat(concept_repres, dim=0).cpu().detach().numpy()
    #print('concept representation shape', concept_repres.shape)

    random_repres = []
    for X, y in batch_iter(random_examples, random_labels, tokenizer, 8):
        with torch.no_grad():
          _, _, representation = model(X.to(DEVICE))
          random_repres.append(representation[0].sum(dim=0).squeeze())
    
    random_repres = torch.cat(random_repres, dim=0).cpu().detach().numpy()
    #print('random representation shape', random_repres.shape)
    
    concept_labels = concept_labels.cpu().detach().numpy()
    random_labels = random_labels.cpu().detach().numpy()
    
    X = np.vstack([concept_repres, random_repres])
    y = np.hstack([concept_labels, random_labels])
    
    X = np.insert(X, 0, 1, axis=1)
    
    assert len(X) == len(y)
    idxs = np.arange(len(X))
    np.random.shuffle(idxs)
    X, y = X[idxs], y[idxs]

    lm = linear_model.LogisticRegression(solver='lbfgs', max_iter=5000)
    lm.fit(X, y)
    cav = lm.coef_[0][1:]
    
    return cav

Training a single cav:

In [0]:
cav = train_cav(concept_examples, random_examples, model)

Training more than one cav for the sake of statistical significance:

In [0]:
def statistical_testing(topic, num_runs=500):
  cavs = []

  concept_examples_idxs = topics[topic]
  concept_examples = [X_test[i] for i in concept_examples_idxs]
  other_examples_idxs = [i for i in range(len(X_test)) if i not in concept_examples_idxs]

  for i in range(num_runs):
    if i%10 == 0:
      print('iteration %d' % i)
    
    random_examples_idxs = np.random.choice(other_examples_idxs, 500)
    random_examples = [X_test[i] for i in random_examples_idxs]
    cavs.append(train_cav(concept_examples, random_examples, model))

  return cavs

In [0]:
cavs = statistical_testing(topic=2, num_runs=100)
np.save('cavs', cavs)
print(len(cavs))

iteration 0
iteration 10
iteration 20
iteration 30
iteration 40
iteration 50
iteration 60
iteration 70
iteration 80
iteration 90
100


"conceptual sensitivty" for a single example and a desired class and a concept's CAV:

In [0]:
def sensitivity(sample, cav, desired_class):
  sample = pad([tokenizer.encode(sample[:128], add_special_tokens=True)])
  
  model.zero_grad()
  logits, _, representation = model(torch.tensor(sample).to(DEVICE))

  logits[0, desired_class].backward()

  grad = model.grad_representation
  grad = grad.sum(dim=0).squeeze().cpu().detach().numpy()

  sensitivity = np.dot(grad, cav)
  
  return sensitivity

In [0]:
print(sensitivity(concept_examples[1], cav, desired_class=1))

0.00028777589161624946


### TCAV

In [0]:
# a set of inputs
def TCAV(examples, desired_class, cav):
  sensitivities = []
  for example in examples:
    sensitivities.append(sensitivity(example, cav, desired_class))
  
  return len([s for s in sensitivities if s > 0]) / len(examples)

Using more than one cav for statistical significance:

In [0]:
tcavs = []
positive_idxs = [i for i in range(len(y_test)) if y_test[i] == 1]
example_idxs = np.random.choice(positive_idxs, 100)
examples = [X_test[idx] for idx in example_idxs]

for i in range(len(cavs)):
  print(i, end=' ')
  tcavs.append(TCAV(examples=examples, desired_class=1, cav=cavs[i]))

print()
print(tcavs)

0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 
[0.67, 0.63, 0.68, 0.63, 0.47, 0.62, 0.62, 0.63, 0.61, 0.6, 0.63, 0.59, 0.65, 0.6, 0.6, 0.58, 0.7, 0.7, 0.62, 0.57, 0.69, 0.6, 0.72, 0.71, 0.72, 0.7, 0.65, 0.77, 0.53, 0.68, 0.61, 0.65, 0.57, 0.58, 0.7, 0.64, 0.64, 0.58, 0.71, 0.63, 0.67, 0.62, 0.74, 0.53, 0.67, 0.68, 0.69, 0.66, 0.6, 0.63, 0.6, 0.74, 0.54, 0.58, 0.7, 0.51, 0.66, 0.71, 0.71, 0.58, 0.63, 0.63, 0.65, 0.59, 0.7, 0.65, 0.73, 0.57, 0.68, 0.64, 0.58, 0.63, 0.68, 0.66, 0.66, 0.6, 0.63, 0.67, 0.68, 0.5, 0.73, 0.57, 0.62, 0.64, 0.63, 0.67, 0.66, 0.71, 0.61, 0.71, 0.59, 0.77, 0.61, 0.68, 0.7, 0.72, 0.83, 0.61, 0.65, 0.74]


In [0]:
print(np.mean(tcavs))

0.6802000000000001


#### Perturbing the representations

Perturbing the representation of a set of examples and investigating its effect on the score of a desired class:

In [0]:
def perturb_representation(model, representation, cav, alpha=1):
  representation += alpha * cav
  logits, preds = model.forward_from_representation(representation)

  return logits, preds

In [0]:
a = [100, 11, 34, 24, 25, 29, 30]
np.random.choice(a, 2)

array([29, 25])

#### Computing sensitivity for large sample of inputs (not necessarily with a positive sentiment)

In [0]:
idxs = np.arange(len(X_test))
example_idxs = np.random.choice(idxs, 100)

examples = [X_test[idx] for idx in example_idxs]

sensitivities = []
cav = cavs[0]

for example in examples:
  sensitivities.append(sensitivity(example, cav, 1))

pos_sens = [sens for sens in sensitivities if sens > 0]
print(len(pos_sens) / len(sensitivities))

0.76


#### Negative concept

In [0]:
topic = 5

concept_examples_idxs = topics[topic]
concept_examples = [X_test[i] for i in concept_examples_idxs]

other_examples_idxs = [i for i in range(len(X_test)) if i not in concept_examples_idxs]
random_examples_idxs = np.random.choice(other_examples_idxs, 1000)
random_examples = [X_test[i] for i in random_examples_idxs]

In [0]:
print(len(concept_examples))
print(len(random_examples))

print(concept_examples[1])
print(random_examples[14])

1108
1000
being a fan of Bela Lugosi,Boris Karloff,and Lon Chaney Jr i had to see this.what tripe the only thing good about this is the clips of Lugosi,Karloff and Chaney Jr.along with all the vintage clips,that do not gel with the new black and white footage.not even close to Steve martins dead men don't wear plaid,that was done great.with all the technology we have now why was'nt this done better?if you are planning to shell out 5 bucks and some change,be warned this is really bad. but if you like Lugosi Karloff and Chaney Jr then watch their movies instead.even ed wood did better then this one.new actor mark redfield is pretty good as an imitation Bela Lugosi.the clips they use are; the ape,Mr Wong,most dangerous game,lost world,indestructible man. and devil bat.that notorious Bela Lugosi classic.i believe this production was very low budget,and it shows.1 out of 10.
Despite having a very pretty leading lady (Rosita Arenas, one of my boy-crushes), the acting and the direction are ex

In [0]:
cavs = statistical_testing(topic=5, num_runs=100)
np.save('cavs', cavs)
print(len(cavs))

iteration 0
iteration 10
iteration 20
iteration 30
iteration 40
iteration 50
iteration 60
iteration 70
iteration 80
iteration 90
100


In [0]:
idxs = np.arange(len(X_test))
example_idxs = np.random.choice(idxs, 100)

examples = [X_test[idx] for idx in example_idxs]

sensitivities = []
cav = cavs[0]

for example in examples:
  sensitivities.append(sensitivity(example, cav, 1))

pos_sens = [sens for sens in sensitivities if sens > 0]
print(len(pos_sens) / len(sensitivities))

0.45
