In [1]:
import datasets
import numpy as np
from tqdm.notebook import tqdm
import csv
import os
import pandas as pd
import re
import torch
import wandb

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from transformers import AutoTokenizer, AutoModelForSequenceClassification, \
    Trainer, TrainingArguments
from datasets import load_dataset

2023-09-18 22:21:28.666663: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0


In [2]:
# The two target words. Documents containing ' nice ' are labelled 0 and
# documents containing ' good ' are labelled 1 by the shard-loading loop.
w1 = 'nice'
w2 = 'good'

# Worker-process count handed to every datasets.filter / datasets.map call.
num_proc = 20
# Forwarded to TrainingArguments(seed=...).
seed = 1234

# test and valid dataset will be balanced
test_n = 20000
valid_n = 2000

# train will match the overall distribution
train_n = 200000

# Tokenizer max_length: every sequence is padded/truncated to this many tokens.
max_len = 256
# Per-device batch size, used for both training and evaluation.
batch_size = 8
# With batch_size = 8 this yields 8 * 4 = 32 examples per optimizer step per device.
gradient_accumulation_steps = 4
# 0. disables label smoothing in TrainingArguments.
label_smoothing_factor = 0.
# Target device for model.to(...).
device = 'cuda'
# Hugging Face checkpoint used for both the tokenizer and the classifier.
model_name = 'microsoft/deberta-base'

In [3]:
# Shard file names of the raw corpus; each one is joined onto '../data/00' when loaded.
pieces = ['./00_aa', './00_ab', './00_ac', './00_ad', './00_ae', './00_af', './00_ag', './00_ah']

In [4]:
# For every shard, build a (w1, w2) pair of labelled subsets and append the
# pair to `parts` in that order: even positions of `parts` therefore hold
# label-0 data and odd positions hold label-1 data.
parts = []
idx_acc = 0  # rows seen in earlier shards; keeps 'idx' unique across shards
for p in pieces:
    shard_path = os.path.join('../data/00', p)
    dataset = load_dataset('json', data_files=shard_path, keep_in_memory=True)['train']

    # global row id = position inside this shard + offset of all previous shards
    n_rows = len(dataset)
    dataset = dataset.add_column('idx', idx_acc + np.arange(n_rows))
    idx_acc = idx_acc + n_rows

    # Two independent substring filters: a document that contains both
    # space-delimited words ends up in both subsets.
    w1_ds = dataset.filter(
        lambda x: f' {w1} ' in x['text'],
        num_proc=num_proc,
        keep_in_memory=True,
    )
    w2_ds = dataset.filter(
        lambda x: f' {w2} ' in x['text'],
        num_proc=num_proc,
        keep_in_memory=True,
    )

    # label 0 <-> w1, label 1 <-> w2
    w1_ds = w1_ds.add_column('label', [0] * len(w1_ds))
    w2_ds = w2_ds.add_column('label', [1] * len(w2_ds))

    parts.append(w1_ds)
    parts.append(w2_ds)

Found cached dataset json (/home/johnny/.cache/huggingface/datasets/json/default-00b51bf8210852b1/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/1 [00:00<?, ?it/s]

Filter (num_proc=20):   0%|          | 0/1000000 [00:00<?, ? examples/s]

Filter (num_proc=20):   0%|          | 0/1000000 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/26394 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/133047 [00:00<?, ? examples/s]

Found cached dataset json (/home/johnny/.cache/huggingface/datasets/json/default-8946e1016a23565c/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/1 [00:00<?, ?it/s]

Filter (num_proc=20):   0%|          | 0/1000000 [00:00<?, ? examples/s]

Filter (num_proc=20):   0%|          | 0/1000000 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/26629 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/132763 [00:00<?, ? examples/s]

Found cached dataset json (/home/johnny/.cache/huggingface/datasets/json/default-72ca44f5446c8da8/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/1 [00:00<?, ?it/s]

Filter (num_proc=20):   0%|          | 0/1000000 [00:00<?, ? examples/s]

Filter (num_proc=20):   0%|          | 0/1000000 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/26445 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/132802 [00:00<?, ? examples/s]

Found cached dataset json (/home/johnny/.cache/huggingface/datasets/json/default-43704e0a7b156c11/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/1 [00:00<?, ?it/s]

Filter (num_proc=20):   0%|          | 0/1000000 [00:00<?, ? examples/s]

Filter (num_proc=20):   0%|          | 0/1000000 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/133153 [00:00<?, ? examples/s]

Found cached dataset json (/home/johnny/.cache/huggingface/datasets/json/default-3686517469bd31c2/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/1 [00:00<?, ?it/s]

Filter (num_proc=20):   0%|          | 0/1000000 [00:00<?, ? examples/s]

Filter (num_proc=20):   0%|          | 0/1000000 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/26561 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/132507 [00:00<?, ? examples/s]

Found cached dataset json (/home/johnny/.cache/huggingface/datasets/json/default-5043a08813a7b112/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/1 [00:00<?, ?it/s]

Filter (num_proc=20):   0%|          | 0/1000000 [00:00<?, ? examples/s]

Filter (num_proc=20):   0%|          | 0/1000000 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/26467 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/132806 [00:00<?, ? examples/s]

Found cached dataset json (/home/johnny/.cache/huggingface/datasets/json/default-4b3a2e2700e1cd6f/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/1 [00:00<?, ?it/s]

Filter (num_proc=20):   0%|          | 0/1000000 [00:00<?, ? examples/s]

Filter (num_proc=20):   0%|          | 0/1000000 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/26601 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/132770 [00:00<?, ? examples/s]

Found cached dataset json (/home/johnny/.cache/huggingface/datasets/json/default-82edd6754f7dc7cf/0.0.0/8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96)


  0%|          | 0/1 [00:00<?, ?it/s]

Filter (num_proc=20):   0%|          | 0/21438 [00:00<?, ? examples/s]

Filter (num_proc=20):   0%|          | 0/21438 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/564 [00:00<?, ? examples/s]

Flattening the indices:   0%|          | 0/2778 [00:00<?, ? examples/s]

In [5]:
# Merge every per-shard subset into one dataset, then display
# (total number of rows, fraction of rows with label 1).
ds = datasets.concatenate_datasets(parts)
len(ds), np.mean(ds['label'])

(1118746, 0.8336351593659329)

In [6]:
# Regroup the per-shard subsets by class. `parts` alternates w1/w2 subsets,
# so even slots are label 0 and odd slots are label 1.
zero_parts = datasets.concatenate_datasets(parts[0::2])
one_parts = datasets.concatenate_datasets(parts[1::2])

# Balanced test split: the first test_n / 2 rows of each class.
test_cutoff = test_n // 2
test_ds = datasets.concatenate_datasets(
    [class_ds.select(range(test_cutoff)) for class_ds in (zero_parts, one_parts)]
)

# Balanced validation split: the next valid_n / 2 rows of each class.
valid_cutoff = test_cutoff + valid_n // 2
valid_ds = datasets.concatenate_datasets(
    [class_ds.select(range(test_cutoff, valid_cutoff)) for class_ds in (zero_parts, one_parts)]
)

In [11]:
# Sanity check: one w1 and one w2 subset per shard, and the regrouped classes
# are pure (mean label exactly 0 for zero_parts and exactly 1 for one_parts).
len(pieces), len(parts[::2]), len(parts[1::2]), np.mean(zero_parts['label']), np.mean(one_parts['label'])

(8, 8, 8, 0.0, 1.0)

In [17]:
# make sure that all test examples are in the 17e7 dataset
# NOTE(review): 989379 is a hard-coded row-count bound whose origin is not
# shown in this notebook -- consider moving it to the config cell.
assert all(row_idx < 989379 for row_idx in test_ds['idx'])

In [12]:
window_size = 20


def _tail_snippet(text):
    """Return the last `window_size` characters of `text` (all of it if shorter)."""
    start = max(0, len(text) - window_size)
    return text[start:]


# Tail snippets of every held-out (test + valid) document.
concatenated_test = set()
for i in datasets.concatenate_datasets([test_ds, valid_ds]):
    concatenated_test.add(_tail_snippet(i['text']))


def check_in_test(x):
    """Filter predicate: True when the row's tail snippet matches NO held-out document.

    Despite the name, rows for which this returns True are the ones that are
    kept, i.e. the ones that do not collide with the test/valid splits.
    """
    return _tail_snippet(x['text']) not in concatenated_test


def _drop_heldout(class_ds):
    """Skip the rows already used for test/valid, then drop tail-snippet collisions."""
    remaining = class_ds.select(range(valid_cutoff, len(class_ds)))
    return remaining.filter(check_in_test, num_proc=num_proc, keep_in_memory=True)


zero_parts_filtered = _drop_heldout(zero_parts)
one_parts_filtered = _drop_heldout(one_parts)

# (kept, candidates) per class -- shows how many rows the filter removed
(
    len(zero_parts_filtered), len(zero_parts) - valid_cutoff,
    len(one_parts_filtered), len(one_parts) - valid_cutoff,
)

Filter (num_proc=20):   0%|          | 0/175120 [00:00<?, ? examples/s]

Filter (num_proc=20):   0%|          | 0/921626 [00:00<?, ? examples/s]

(165136, 175120, 859262, 921626)

In [13]:
# Size the two training classes so the label ratio matches the full corpus `ds`.
# The mean is computed once instead of twice over the whole label column.
label_one_rate = np.mean(ds['label'])
zero_train_n = int(train_n * (1 - label_one_rate))
# Label 1 takes the remainder, so the two counts always add up to exactly
# train_n (truncating both products and adding 1 is only right when the
# products are fractional).
one_train_n = train_n - zero_train_n

# sample amount matching the overall distribution
train_ds = datasets.concatenate_datasets([
    zero_parts_filtered.select(range(zero_train_n)),
    one_parts_filtered.select(range(one_train_n)),
])

In [14]:
# Display (number of training rows, fraction with label 1) to confirm the
# sample has the requested size and class ratio.
len(train_ds), np.mean(train_ds['label'])

(200000, 0.83364)

In [15]:
# cut the prefix
def prefix_only(x):
    """Truncate a row's text just before the first occurrence of its target word.

    The target is w1 for label 0 and w2 otherwise, matched with a space on
    each side. Rows reaching this function come from subsets that were
    filtered on that same space-delimited marker.

    Returns a dict with the shortened 'text' plus the row's unchanged
    'label' and 'meta'.
    """
    target_word = w1 if x['label'] == 0 else w2
    cut_at = x['text'].find(' %s ' % target_word)
    return {
        'text': x['text'][:cut_at],
        'label': x['label'],
        'meta': x['meta'],
    }
    
# Replace each split's text with its prefix. The splits are overwritten in
# place, so this cell is NOT idempotent: run it exactly once per kernel.
train_ds, valid_ds, test_ds = (
    split.map(prefix_only, keep_in_memory=True)
    for split in (train_ds, valid_ds, test_ds)
)

Map:   0%|          | 0/200000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/20000 [00:00<?, ? examples/s]

In [18]:
# Display the three splits (features and row counts) after prefix extraction.
train_ds, valid_ds, test_ds

(Dataset({
     features: ['text', 'meta', 'idx', 'label'],
     num_rows: 200000
 }),
 Dataset({
     features: ['text', 'meta', 'idx', 'label'],
     num_rows: 2000
 }),
 Dataset({
     features: ['text', 'meta', 'idx', 'label'],
     num_rows: 20000
 }))

In [19]:
# Load the tokenizer that matches `model_name` (a DeBERTa checkpoint, not BERT).
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Truncate from the left so that over-long prefixes keep their END, i.e. the
# tokens closest to where the target word was cut off.
tokenizer.truncation_side = 'left'

In [22]:
def tokenize_function(examples):
    """Tokenize a batch of rows, padding/truncating every text to `max_len` tokens.

    `examples` is a batch from a batched datasets.map call; only its "text"
    field is read. Returns the tokenizer's output for that batch.
    """
    encode_options = dict(padding="max_length", truncation=True, max_length=max_len)
    return tokenizer(examples["text"], **encode_options)

# Tokenize the training and validation splits in parallel batches.
tokenized_train_ds, tokenized_val_ds = (
    split.map(tokenize_function, num_proc=num_proc, batched=True, keep_in_memory=True)
    for split in (train_ds, valid_ds)
)

Map (num_proc=20):   0%|          | 0/200000 [00:00<?, ? examples/s]

Map (num_proc=20):   0%|          | 0/2000 [00:00<?, ? examples/s]

In [23]:
# Load the pretrained encoder with a 2-class sequence-classification head.
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
# Trailing expression: the notebook displays the module returned by .to(device).
model.to(device)

Some weights of DebertaForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-base and are newly initialized: ['pooler.dense.weight', 'classifier.weight', 'classifier.bias', 'pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DebertaForSequenceClassification(
  (deberta): DebertaModel(
    (embeddings): DebertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=0)
      (LayerNorm): DebertaLayerNorm()
      (dropout): StableDropout()
    )
    (encoder): DebertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x DebertaLayer(
          (attention): DebertaAttention(
            (self): DisentangledSelfAttention(
              (in_proj): Linear(in_features=768, out_features=2304, bias=False)
              (pos_dropout): StableDropout()
              (pos_proj): Linear(in_features=768, out_features=768, bias=False)
              (pos_q_proj): Linear(in_features=768, out_features=768, bias=True)
              (dropout): StableDropout()
            )
            (output): DebertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): DebertaLayerNorm()
              (dropout): StableDropout()
            )
          )
          (

In [24]:
# Hyper-parameters for the Hugging Face Trainer; the values come from the config cell.
training_args = TrainingArguments(
    # bookkeeping
    run_name=f'run_{w1}_{w2}',
    output_dir='./hf_output_dir',
    logging_dir='./logs',
    seed=seed,
    # optimisation: a single pass over the training data
    num_train_epochs=1,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    label_smoothing_factor=label_smoothing_factor,
    # evaluation and logging cadence (in optimizer steps); no checkpoints are written
    per_device_eval_batch_size=batch_size,
    evaluation_strategy="steps",
    eval_steps=200,
    logging_steps=20,
    save_strategy='no',
)

In [25]:
# Metric callback handed to the Trainer: plain classification accuracy
def compute_metrics(pred):
    """Return {"accuracy": ...} for one evaluation pass.

    `pred.predictions` holds the per-class scores (argmax over the last axis
    gives the predicted class) and `pred.label_ids` the reference labels.
    """
    predicted_classes = pred.predictions.argmax(-1)
    return {"accuracy": accuracy_score(pred.label_ids, predicted_classes)}

In [26]:
# Wire model, hyper-parameters, data and the accuracy metric into a Trainer.
# Evaluation during training runs on the validation split.
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    train_dataset=tokenized_train_ds,
    eval_dataset=tokenized_val_ds,
)

In [27]:
# Start the Weights & Biases run before training so the Trainer logs into it.
wandb.init(project='propensity_scoring')
wandb.log({'w1' : w1, 'w2': w2})
# Log the TOTAL number of documents per word. `w1_ds` / `w2_ds` are leftover
# loop variables that only hold the subsets of the last shard, so the totals
# are summed over `parts` (even slots = w1, odd slots = w2).
wandb.log({
    'w1_size': sum(len(part) for part in parts[::2]),
    'w2_size': sum(len(part) for part in parts[1::2]),
})

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjohntzwei[0m ([33musc-johntzwei[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [28]:
# Fine-tune the model.
# Per training_args: one epoch, evaluating on the validation split every 200 steps.
trainer.train()



Step,Training Loss,Validation Loss,Accuracy
200,0.418,0.803368,0.5
400,0.4422,0.736675,0.559
600,0.4008,0.716928,0.581
800,0.3618,0.90189,0.59
1000,0.3668,0.857184,0.615
1200,0.3752,0.669911,0.647
1400,0.3646,0.578469,0.725
1600,0.339,0.617355,0.703
1800,0.3277,0.853974,0.657
2000,0.3274,0.783838,0.633


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [29]:
# Tokenize the held-out test split with the same settings as train/valid.
tokenized_test_ds = test_ds.map(tokenize_function, num_proc=num_proc, batched=True, keep_in_memory=True)

Map (num_proc=20):   0%|          | 0/20000 [00:00<?, ? examples/s]

In [30]:
# Run the fine-tuned model on the test split; `output` carries the raw
# per-class scores (.predictions), the labels (.label_ids) and .metrics.
output = trainer.predict(tokenized_test_ds)
output

PredictionOutput(predictions=array([[-0.19875918,  0.25444803],
       [ 0.22171348, -0.23361935],
       [ 1.9723558 , -2.0287623 ],
       ...,
       [-2.1402254 ,  2.3261564 ],
       [-2.5455835 ,  2.6670558 ],
       [-2.0887234 ,  2.2735143 ]], dtype=float32), label_ids=array([0, 0, 0, ..., 1, 1, 1]), metrics={'test_loss': 0.5856605172157288, 'test_accuracy': 0.72855, 'test_runtime': 262.8266, 'test_samples_per_second': 76.096, 'test_steps_per_second': 9.512})

In [32]:
# Log the test accuracy next to the label-1 fraction of the full corpus `ds`.
# NOTE(review): test_ds is built with equal rows per class, so the
# majority-class baseline on the test set itself would be 0.5 -- confirm that
# 'one_class_accuracy' is meant to describe the corpus, not the test split.
wandb.log({'test_accuracy': output.metrics['test_accuracy'], 'one_class_accuracy': np.mean(ds['label'])})

In [33]:
# Turn the raw per-class scores into e(x) = P(label == 1 | prefix) via softmax.
predictions = output.predictions
# Subtract each row's maximum before exponentiating: softmax is unchanged by a
# per-row shift, and this keeps np.exp from overflowing to inf (which would
# produce inf / inf = nan) on large logits.
shifted = predictions - predictions.max(axis=-1, keepdims=True)
exp_shifted = np.exp(shifted)
denom = exp_shifted.sum(axis=-1)
e_scores = exp_shifted[:, 1] / denom

In [39]:
# One row per test example: the prefix text, its global corpus index, the true
# label and the model's estimated probability of label 1.
df = pd.DataFrame({'prefix': test_ds['text'], 'idx': test_ds['idx'], 'label': test_ds['label'], 'e(x)': e_scores})
# to_csv does not create missing parent directories, so make sure the target
# folder exists before writing.
os.makedirs('scores', exist_ok=True)
df.to_csv(f'scores/ps_{w1}_{w2}.csv', index=False)