In [1]:
import torch
import numpy as np
import random
from keras.preprocessing.sequence import pad_sequences
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler, random_split
from transformers  import BertTokenizer, BertConfig
from transformers  import AdamW, BertForSequenceClassification, get_linear_schedule_with_warmup
from sklearn.datasets import fetch_20newsgroups
from tqdm import tqdm

In [2]:
def load_document(dataset):
    if dataset == "20news":
        num_classes = 20
        raw_text, target = fetch_20newsgroups(data_home="./", subset='all', categories=None,
                                              shuffle=False, return_X_y=True)
        documents = [doc.strip("\n") for doc in raw_text]
    elif dataset == "IMDB":
        target = []
        documents = []
        num_classes = 2

        sub_dir = ["pos", "neg"]
        dir_prefix = "./aclImdb/train/"
        for target_type in sub_dir:
            data_dir = os.path.join(dir_prefix, target_type)
            files_name = os.listdir(data_dir)
            for f_name in files_name:
                with open(os.path.join(data_dir, f_name), "r") as f:
                    context = f.readlines()
                    documents.extend(context)

            # assign label
            label = 1 if target_type == "pos" else 0
            label = [label] * len(files_name)
            target.extend(label)
    else:
        raise NotImplementedError

    return {"documents": documents, "target": target, "num_classes": num_classes}

def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

In [3]:
### Parameters
doc_dict = load_document('20news')
MAX_LENGTH = 150
NUM_CLASSES = 20

TRAIN_RATIO = 0.8
TRAIN_SIZE = int(TRAIN_RATIO * len(doc_dict['documents']))

BATCH_SIZE = 64
EPOCHS = 10
LR = 0.0001

DEVICE_NUM = 2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.cuda.set_device(DEVICE_NUM)

In [4]:
### Tokenize Document
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
input_ids =[]
attention_masks = []
print('Tokenize Document...')
for sentence in tqdm(doc_dict['documents']):
  encoded_dict = tokenizer.encode_plus(sentence, add_special_tokens=True, max_length=MAX_LENGTH, pad_to_max_length=True, return_attention_mask=True, return_tensors='pt')
  input_ids.append(encoded_dict['input_ids'])
  attention_masks.append(encoded_dict['attention_mask'])

input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
labels = torch.tensor(doc_dict['target'])
print('Tokenize Done...')

Tokenize Document...


  0%|                                                 | 0/18846 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
100%|████████████████████████████████████| 18846/18846 [02:36<00:00, 120.38it/s]


Tokenize Done...


In [5]:
### Training & Validation Split
dataset = TensorDataset(input_ids, attention_masks, labels)
train_set, val_set = random_split(dataset, [TRAIN_SIZE, len(dataset)-TRAIN_SIZE])
train_loader = DataLoader(train_set, sampler=RandomSampler(train_set), batch_size=BATCH_SIZE)
val_loader = DataLoader(val_set, sampler=SequentialSampler(val_set), batch_size=BATCH_SIZE)
print('{:>5,} training samples / {:>5,} validation samples'.format(TRAIN_SIZE, len(dataset)-TRAIN_SIZE))

15,076 training samples / 3,770 validation samples


In [6]:
### Optimizer & Scheduler
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=NUM_CLASSES, output_hidden_states=False).to(device)
optimizer = AdamW(model.parameters(), lr=LR, eps=1e-8)
total_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [7]:
### Set seed value
seed_val = 42

random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

In [8]:
### Training
print('Using {}_{} for training'.format(device, DEVICE_NUM))
train_stats = []
for epoch in range(EPOCHS):
  print("")
  print('======== Epoch {:} / {:} ========'.format(epoch+1, EPOCHS))
  print('Training...')
  train_loss = 0
  model.train()
  for batch, (b_input_ids, b_input_masks, b_input_labels) in enumerate(tqdm(train_loader)):
    b_input_ids, b_input_masks, b_input_labels = b_input_ids.to(device), b_input_masks.to(device), b_input_labels.to(device)
    model.zero_grad()
    loss = model(b_input_ids, token_type_ids=None, attention_mask=b_input_masks, labels=b_input_labels)[0]
    train_loss += loss.item()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    scheduler.step()
  avg_train_loss = train_loss / len(train_loader)
  print("")
  print("  Average training loss: {0:.2f}".format(avg_train_loss))
  print("")
  print("Running Validation...")
  model.eval()
  val_acc = 0
  val_loss = 0
  nb_val_steps = 0
  for batch, (b_input_ids, b_input_masks, b_input_labels) in enumerate(tqdm(val_loader)):
    b_input_ids, b_input_masks, b_input_labels = b_input_ids.to(device), b_input_masks.to(device), b_input_labels.to(device)
    with torch.no_grad():
      logits = model(b_input_ids, token_type_ids=None, attention_mask=b_input_masks)[0]
    logits = logits.detach().cpu().numpy()
    label_ids = b_input_labels.to('cpu').numpy()
    val_acc += flat_accuracy(logits, label_ids)
  avg_val_acc = val_acc / len(val_loader)
  print("  Accuracy: {0:.2f}".format(avg_val_acc))

Using cuda_2 for training

Training...


100%|█████████████████████████████████████████| 236/236 [01:07<00:00,  3.47it/s]



  Average training loss: 1.10

Running Validation...


100%|███████████████████████████████████████████| 59/59 [00:05<00:00, 11.66it/s]


  Accuracy: 0.87

Training...


100%|█████████████████████████████████████████| 236/236 [01:07<00:00,  3.48it/s]



  Average training loss: 0.32

Running Validation...


100%|███████████████████████████████████████████| 59/59 [00:05<00:00, 11.63it/s]


  Accuracy: 0.91

Training...


100%|█████████████████████████████████████████| 236/236 [01:07<00:00,  3.49it/s]



  Average training loss: 0.17

Running Validation...


100%|███████████████████████████████████████████| 59/59 [00:05<00:00, 11.63it/s]


  Accuracy: 0.92

Training...


100%|█████████████████████████████████████████| 236/236 [01:02<00:00,  3.79it/s]



  Average training loss: 0.10

Running Validation...


100%|███████████████████████████████████████████| 59/59 [00:05<00:00, 11.63it/s]


  Accuracy: 0.91

Training...


100%|█████████████████████████████████████████| 236/236 [01:06<00:00,  3.55it/s]



  Average training loss: 0.06

Running Validation...


100%|███████████████████████████████████████████| 59/59 [00:05<00:00, 11.63it/s]


  Accuracy: 0.92

Training...


100%|█████████████████████████████████████████| 236/236 [01:07<00:00,  3.48it/s]



  Average training loss: 0.03

Running Validation...


100%|███████████████████████████████████████████| 59/59 [00:05<00:00, 11.64it/s]


  Accuracy: 0.93

Training...


100%|█████████████████████████████████████████| 236/236 [01:07<00:00,  3.49it/s]



  Average training loss: 0.01

Running Validation...


100%|███████████████████████████████████████████| 59/59 [00:05<00:00, 11.64it/s]


  Accuracy: 0.93

Training...


100%|█████████████████████████████████████████| 236/236 [01:07<00:00,  3.50it/s]



  Average training loss: 0.01

Running Validation...


100%|███████████████████████████████████████████| 59/59 [00:05<00:00, 11.63it/s]


  Accuracy: 0.93

Training...


100%|█████████████████████████████████████████| 236/236 [01:07<00:00,  3.51it/s]



  Average training loss: 0.00

Running Validation...


100%|███████████████████████████████████████████| 59/59 [00:05<00:00, 11.64it/s]


  Accuracy: 0.93

Training...


100%|█████████████████████████████████████████| 236/236 [01:04<00:00,  3.68it/s]



  Average training loss: 0.00

Running Validation...


100%|███████████████████████████████████████████| 59/59 [00:05<00:00, 11.63it/s]

  Accuracy: 0.93





In [9]:
import shap
import scipy as sp

In [10]:
def get_prediction(model, dataloader, device, compute_acc=False):
    model.to(device)
    model.eval()
    predictions = None
    acc = []
    for batch, (b_input_ids, b_input_masks, b_input_labels) in enumerate(tqdm(dataloader)):
        b_input_ids, b_input_masks, b_input_labels = b_input_ids.to(device), b_input_masks.to(device), b_input_labels.to(device)
        with torch.no_grad():
            logits = model(b_input_ids, token_type_ids=None, attention_mask=b_input_masks)[0]
        _, pred = torch.max(logits, 1)
        hits = logits.argmax(dim=1).eq(b_input_labels)
        acc.append(hits)
        if predictions is None:
            predictions = pred
        else:
            predictions = torch.cat((predictions, pred))
    if compute_acc:
        print('Accuracy: {}'.format(torch.cat(acc).float().mean()))
        
    return predictions

In [11]:
def f(x):
    tv = torch.tensor([tokenizer.encode(v, padding='max_length', max_length=150, truncation=True) for v in x]).cuda()
    outputs = model(tv)[0]
    _, logits = torch.max(outputs, 1)
    logits = logits.detach().cpu().numpy()
    return logits

In [29]:
explainer = shap.Explainer(f, tokenizer)
shap_values = explainer(doc_dict['documents'][:20], fixed_context=1)

Partition explainer: 21it [00:38,  2.14s/it]                                    


In [33]:
shap.plots.text(shap_values[17])

In [28]:
shap_values[8]

.values =
array([ 0.        , -0.03238866, -0.03238866, -0.03238866, -0.03238866,
       -0.03238866, -0.03238866, -0.03238866, -0.03238866, -0.03238866,
       -0.03238866, -0.03238866, -0.03238866, -0.03238866, -0.03238866,
       -0.03238866, -0.03238866, -0.03238866, -0.03238866, -0.03238866,
       -0.03238866, -0.03238866, -0.03238866, -0.03238866, -0.03238866,
       -0.03238866, -0.03238866, -0.03238866, -0.03238866, -0.03238866,
       -0.03238866, -0.03238866, -0.03238866, -0.03238866, -0.03238866,
       -0.03238866, -0.03238866, -0.03238866, -0.03238866, -0.03238866,
       -0.03238866, -0.03238866, -0.03238866, -0.03238866, -0.03238866,
        0.32475419,  0.32475419,  0.32475419,  0.32475419,  0.32475419,
        0.32475419,  0.32475419,  0.32475419,  0.32475419,  0.32475419,
        0.32475419,  0.32475419,  0.32475419,  0.32475419, -0.03238866,
       -0.03238866, -0.03238866, -0.03238866, -0.03238866, -0.03238866,
       -0.03238866, -0.03238866, -0.03238866, -0.03238