In [2]:
!pip install transformers_interpret

Collecting transformers_interpret
  Downloading transformers_interpret-0.10.0-py3-none-any.whl.metadata (45 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m45.9/45.9 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting captum>=0.3.1 (from transformers_interpret)
  Downloading captum-0.8.0-py3-none-any.whl.metadata (26 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.10->captum>=0.3.1->transformers_interpret)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.10->captum>=0.3.1->transformers_interpret)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=1.10->captum>=0.3.1->transformers_interpret)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch>=1.10->captum>=0.3.1

In [8]:
!pip install datasets

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [None]:
import copy
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForMaskedLM,
)
from datasets import load_dataset
from sklearn.model_selection import StratifiedShuffleSplit
from scipy.stats import entropy
from transformers_interpret import SequenceClassificationExplainer

# ─────────────────────────────────────────── Helpers ─────────────────────────────────────────── #

def preprocess(text):
    """Replace user mentions and URLs in ``text`` with fixed placeholder tokens.

    The text is split on whitespace and each token is rewritten:
      * a token starting with '@' that has at least one more character -> '@user'
      * any other token starting with 'http'                          -> 'http'
      * everything else is kept unchanged.
    Tokens are re-joined with single spaces, so runs of whitespace collapse.
    """
    def _normalise(token):
        if len(token) > 1 and token.startswith('@'):
            return '@user'
        if token.startswith('http'):
            return 'http'
        return token

    return " ".join(_normalise(token) for token in text.split())

def tokenize_batch(tokenizer, texts, device, max_len=512):
    """Tokenise ``texts`` as one batch and move every returned tensor to ``device``.

    Every sequence is truncated and padded to exactly ``max_len`` tokens, and the
    tokenizer is asked for PyTorch tensors. Returns a plain dict mapping each
    tokenizer output name (e.g. 'input_ids', 'attention_mask') to its tensor.
    """
    encoding_kwargs = dict(
        truncation=True,
        padding='max_length',
        max_length=max_len,
        return_tensors='pt',
    )
    batch = tokenizer(texts, **encoding_kwargs)

    moved = {}
    for name, tensor in batch.items():
        moved[name] = tensor.to(device)
    return moved

def evaluate(model, tokenizer, texts, labels, device, max_len=512, batch_size=32):
    """Compute the accuracy of ``model``'s arg-max predictions on ``texts``.

    Each text is normalised with ``preprocess`` and tokenised with ``tokenize_batch``
    (truncated/padded to ``max_len``) before being classified.

    Args:
        model: sequence-classification model whose output has ``.logits``.
        tokenizer: tokenizer matching ``model``.
        texts: iterable of raw strings (must support ``len``).
        labels: gold integer labels aligned index-by-index with ``texts``.
        device: device the encoded inputs are moved to.
        max_len: token length every input is truncated/padded to.
        batch_size: texts per forward pass. Because every input is padded to
            ``max_len`` individually, batching does not change the model inputs.

    Returns:
        ``(accuracy, preds)``: accuracy in [0, 1] (``nan`` when ``texts`` is empty)
        and the predicted class ids as a list of ints, in input order.

    Raises:
        ValueError: if ``texts`` and ``labels`` differ in length, or batch_size < 1.
    """
    from itertools import islice

    if len(texts) != len(labels):
        raise ValueError(
            f"evaluate(): got {len(texts)} texts but {len(labels)} labels"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    model.eval()
    preds = []
    remaining = iter(texts)  # plain iteration works for lists and dataset columns alike
    with torch.no_grad():
        while True:
            chunk = [preprocess(t) for t in islice(remaining, batch_size)]
            if not chunk:
                break
            enc = tokenize_batch(tokenizer, chunk, device, max_len)
            preds.extend(model(**enc).logits.argmax(dim=-1).cpu().tolist())

    if not preds:
        # Accuracy of an empty set is undefined; avoid ZeroDivisionError.
        return float('nan'), preds
    acc = sum(int(p == g) for p, g in zip(preds, labels)) / len(labels)
    return acc, preds

def get_entropy_scores(model, tokenizer, texts, device, max_len=512, batch_size=32):
    """Return the base-2 entropy of ``model``'s softmax distribution for each text.

    Higher values mean the classifier is less certain. Texts go through
    ``preprocess`` and ``tokenize_batch`` (padded/truncated to ``max_len``), exactly
    as in ``evaluate``. They are scored ``batch_size`` at a time. Each input is
    padded to ``max_len`` on its own, so batching leaves the model inputs unchanged.

    Returns:
        1-D numpy array with one entropy per input text, in input order
        (empty array for empty input).

    Raises:
        ValueError: if batch_size < 1.
    """
    from itertools import islice

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    model.eval()
    per_batch = []
    remaining = iter(texts)
    with torch.no_grad():
        while True:
            chunk = [preprocess(t) for t in islice(remaining, batch_size)]
            if not chunk:
                break
            enc = tokenize_batch(tokenizer, chunk, device, max_len)
            probs = model(**enc).logits.softmax(dim=-1).cpu().numpy()
            # One entropy per row (class probabilities are along the last axis).
            per_batch.append(entropy(probs, base=2, axis=-1))
    return np.concatenate(per_batch) if per_batch else np.array([])

def _find_single_token_flip(f_model, mlm_model, tokenizer, explainer,
                            text, device, max_len, top_k):
    """Return ``(candidate_ids, orig_pred)`` for the first prediction-flipping
    single-token substitution in ``text`` (already preprocessed), or ``None``."""
    enc = tokenize_batch(tokenizer, [text], device, max_len)
    with torch.no_grad():  # no autograd graph needed for the reference prediction
        orig_pred = f_model(**enc).logits.argmax(dim=-1).item()

    input_ids = enc['input_ids'][0]
    attn_mask = enc['attention_mask']          # already (1, max_len) and on `device`
    mask_id = tokenizer.mask_token_id
    special_ids = set(tokenizer.all_special_ids)

    # Most influential tokens first (by absolute attribution score).
    ranked = sorted(explainer(text), key=lambda x: -abs(x[1]))
    tried = set()
    for tok, _ in ranked:
        tok_id = tokenizer.convert_tokens_to_ids(tok)
        # Never substitute special tokens (<s>, </s>, <unk>, ...). Also skip ids we
        # already searched: all positions of an id are tried the first time it is seen.
        if tok_id in special_ids or tok_id in tried:
            continue
        tried.add(tok_id)

        for pos in (input_ids == tok_id).nonzero(as_tuple=False).view(-1).tolist():
            masked_ids = input_ids.clone()
            masked_ids[pos] = mask_id
            with torch.no_grad():
                mlm_logits = mlm_model(
                    input_ids=masked_ids.unsqueeze(0), attention_mask=attn_mask
                ).logits
            proposals = torch.topk(mlm_logits[0, pos], k=top_k, dim=-1).indices.tolist()

            for alt in proposals:
                # A special-token proposal would be dropped by decode(skip_special_tokens=True),
                # so the saved text would not be the input that actually flipped.
                if alt == tok_id or alt in special_ids:
                    continue
                cand_ids = masked_ids.clone()
                cand_ids[pos] = alt
                with torch.no_grad():
                    pred = f_model(
                        input_ids=cand_ids.unsqueeze(0), attention_mask=attn_mask
                    ).logits.argmax(dim=-1).item()
                if pred != orig_pred:
                    return cand_ids, orig_pred
    return None


def generate_counterfactuals(
    f_model, mlm_model, tokenizer, texts, device,
    max_examples=1200, max_len=512, top_k=10
):
    """Search each text for one single-token substitution that flips ``f_model``.

    For every text (after ``preprocess``), the tokens are ranked by absolute
    attribution from ``SequenceClassificationExplainer``. For each token in turn,
    every occurrence is masked, ``mlm_model`` proposes its ``top_k`` replacements,
    and the first replacement that changes ``f_model``'s arg-max prediction is kept.
    Each input text yields at most one counterfactual.

    Args:
        f_model: classifier to attack. Its output has ``.logits``.
        mlm_model: masked LM used to propose replacement token ids. NOTE(review):
            assumes it shares ``tokenizer``'s vocabulary.
        tokenizer: tokenizer for ``f_model``. It must define a mask token.
        texts: raw input strings.
        device: device for the encoded inputs.
        max_examples: stop once this many counterfactuals have been collected.
        max_len: token length inputs are truncated/padded to.
        top_k: number of MLM proposals tried per masked position.

    Returns:
        ``(A_texts, A_labels)``: the decoded counterfactual texts (special tokens
        skipped) and, for each one, ``f_model``'s prediction on the *unmodified* text.
    """
    f_model.eval(); mlm_model.eval()
    explainer = SequenceClassificationExplainer(f_model, tokenizer)
    A_texts, A_labels = [], []

    for txt in texts:
        if len(A_texts) >= max_examples:
            break
        found = _find_single_token_flip(
            f_model, mlm_model, tokenizer, explainer,
            preprocess(txt), device, max_len, top_k,
        )
        if found is not None:
            cand_ids, orig_pred = found
            A_texts.append(tokenizer.decode(cand_ids, skip_special_tokens=True))
            A_labels.append(orig_pred)

    return A_texts, A_labels

class TweetDataset(Dataset):
    """Map-style dataset of (tweet, label) pairs that are tokenised lazily on access.

    Each item is a dict of 1-D tensors returned by ``tokenizer`` for the
    ``preprocess``-ed text (truncated/padded to ``max_len``), plus a ``'labels'``
    tensor holding that example's label.
    """

    def __init__(self, texts, labels, tokenizer, max_len=512):
        self.texts = texts
        self.labels = labels
        self.tok = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, i):
        encoded = self.tok(
            preprocess(self.texts[i]),
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt',
        )
        # return_tensors='pt' adds a leading batch dimension of size 1. Drop it so
        # the DataLoader's default collate can stack items into a batch.
        sample = {name: tensor.squeeze(0) for name, tensor in encoded.items()}
        sample['labels'] = torch.tensor(self.labels[i])
        return sample

def finetune(model, tokenizer, texts, labels,
             lr, epochs=3, batch_size=16, device='cuda', max_len=512):
    """Fine-tune ``model`` in place with AdamW on ``(texts, labels)``.

    Uses ``TweetDataset`` (which applies ``preprocess`` and pads/truncates to
    ``max_len``) with a shuffled ``DataLoader``. It takes one optimiser step per
    batch, with no LR schedule and no gradient clipping. The mean training loss of
    each epoch is printed.

    Args:
        model: model whose forward pass returns ``.loss`` when given ``labels``.
        tokenizer: tokenizer matching ``model``.
        texts, labels: training examples, aligned index-by-index.
        lr: AdamW learning rate.
        epochs: number of passes over the data.
        batch_size: examples per optimiser step.
        device: device the model and batches are moved to.
        max_len: token length each example is truncated/padded to.

    Returns:
        List with the mean training loss of each epoch.

    Raises:
        ValueError: if ``texts`` is empty or ``texts``/``labels`` differ in length.
    """
    if len(texts) != len(labels):
        raise ValueError(
            f"finetune(): got {len(texts)} texts but {len(labels)} labels"
        )
    if len(texts) == 0:
        raise ValueError("finetune(): need at least one training example")

    model.train().to(device)
    ds     = TweetDataset(texts, labels, tokenizer, max_len=max_len)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=True)
    opt    = AdamW(model.parameters(), lr=lr)
    opt.zero_grad()  # discard any stale gradients already stored on the parameters

    epoch_losses = []
    for e in range(1, epochs + 1):
        tot = 0.0
        for b in loader:
            b = {k: v.to(device) for k, v in b.items()}
            loss = model(**b).loss
            tot += loss.item()
            loss.backward()
            opt.step()
            opt.zero_grad()
        mean_loss = tot / len(loader)
        epoch_losses.append(mean_loss)
        print(f"  → epoch {e} loss {mean_loss:.4f}")
    return epoch_losses

# ────────────────────────────────────────── Main Pipeline ────────────────────────────────────────── #

if __name__=="__main__":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 0) Seed torch and numpy so DataLoader shuffling and dropout start from the
    #    same RNG state on every run (some GPU kernels may still be nondeterministic).
    SEED = 42
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    # 1) Base classifier & MLM
    BASE = "cardiffnlp/twitter-roberta-base-sentiment"
    tok  = AutoTokenizer.from_pretrained(BASE)
    cls  = AutoModelForSequenceClassification.from_pretrained(BASE).to(device)
    # NOTE(review): the MLM's proposed ids are interpreted with `tok`, so this
    # assumes roberta-base shares BASE's vocabulary -- confirm.
    mlm  = AutoModelForMaskedLM.from_pretrained("roberta-base").to(device)

    # 2) Load TweetEval train/dev
    train = load_dataset("tweet_eval","sentiment",split="train")
    dev   = load_dataset("tweet_eval","sentiment",split="validation")
    tr_txt,tr_lbl = train["text"], train["label"]
    dv_txt,dv_lbl = dev  ["text"], dev  ["label"]

    # 3) 750-sample stratified in-domain dev subset (Odev)
    sss = StratifiedShuffleSplit(n_splits=1, train_size=750, random_state=42)
    dev_idx,_ = next(sss.split(dv_txt, dv_lbl))
    texts_dev  = [dv_txt[i] for i in dev_idx]
    labels_dev = [dv_lbl[i] for i in dev_idx]

    # 4) Baseline in-domain
    acc_base,_ = evaluate(cls, tok, texts_dev, labels_dev, device)
    print(f"Baseline in-domain (750): {acc_base*100:.2f}%")  # ~74.20

    # 5) O = the 1200 TRAIN examples with the highest predictive entropy
    ent = get_entropy_scores(cls, tok, tr_txt, device)
    o_idx = np.argsort(-ent)[:1200]
    O_txt = [tr_txt[i] for i in o_idx]
    O_lbl = [tr_lbl[i] for i in o_idx]

    # 6) Generate counterfactuals A
    A_txt, A_lbl = generate_counterfactuals(cls, mlm, tok, O_txt, device)

    # 7) Grid-search LR on the dev subset.
    # NOTE(review): the same 750 dev examples are used both to select the LR and to
    # report accuracy below, so the reported in-domain numbers are optimistic.
    best_lr, best_acc = None, -1.0  # -1 guarantees some lr is selected
    for lr in [1e-3, 1e-5, 1e-7]:
        print(f"Testing lr={lr}")
        tmp = copy.deepcopy(cls)
        finetune(tmp, tok, O_txt + A_txt, O_lbl + A_lbl,
                 lr=lr, epochs=3, batch_size=16, device=device)
        acc,_ = evaluate(tmp, tok, texts_dev, labels_dev, device)
        print(f"  → dev accuracy: {acc*100:.2f}%")
        if acc > best_acc:
            best_acc, best_lr = acc, lr
        del tmp  # release the candidate copy before the next deepcopy
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    print(f"Best lr: {best_lr} → {best_acc*100:.2f}% on dev")

    # 8) Final fine-tune with best_lr
    finetune(cls, tok, O_txt + A_txt, O_lbl + A_lbl,
             lr=best_lr, epochs=3, batch_size=16, device=device)
    output_dir = "./cat_tweeteval_model"
    cls.save_pretrained(output_dir)
    tok.save_pretrained(output_dir)
    print(f"Model saved to {output_dir}")

    # 9) Post-CAT in-domain
    acc_cat,_ = evaluate(cls, tok, texts_dev, labels_dev, device)
    print(f"With CAT in-domain: {acc_cat*100:.2f}%")  # ~77.15

    # 10) Table 2 & 3: Dev vs Adv + % flipped.
    # Attack each dev example on its own so every counterfactual stays paired with
    # the gold label of the example it was derived from. Counterfactuals exist only
    # for a subset of dev, so slicing labels_dev would misalign them.
    acc_dev = acc_cat  # same model, same dev subset as step 9
    adv_txt, adv_gold, adv_orig_pred = [], [], []
    for text, gold in zip(texts_dev, labels_dev):
        cf_txt, cf_pred = generate_counterfactuals(
            cls, mlm, tok, [text], device, max_examples=1
        )
        if cf_txt:
            adv_txt.append(cf_txt[0])
            adv_gold.append(gold)
            adv_orig_pred.append(cf_pred[0])

    if adv_txt:
        acc_adv, adv_pred = evaluate(cls, tok, adv_txt, adv_gold, device)
        # Count a flip only if it survives decode -> preprocess -> re-tokenise.
        flips = sum(int(o != p) for o, p in zip(adv_orig_pred, adv_pred))
        print(f"Dev.: {acc_dev*100:.2f}%, Adv.: {acc_adv*100:.2f}% "
              f"({len(adv_txt)} counterfactuals)")
    else:
        flips = 0
        print(f"Dev.: {acc_dev*100:.2f}%, Adv.: n/a (no counterfactuals found)")
    # Share of all dev examples whose prediction a single-token substitution flips.
    print(f"% flipped ↓ : {flips/len(texts_dev)*100:.2f}%")  # ~59.95

    # 11) Out-of-domain evaluation …
    


Baseline in-domain (750): 82.93%
Testing lr=0.001
  → epoch 1 loss 0.9714
  → epoch 2 loss 0.9325
  → epoch 3 loss 0.9214
  → dev accuracy: 43.47%
Testing lr=1e-05
  → epoch 1 loss 0.8820
  → epoch 2 loss 0.7721
  → epoch 3 loss 0.6558
  → dev accuracy: 64.93%
Testing lr=1e-07
  → epoch 1 loss 1.0480
  → epoch 2 loss 1.0131
  → epoch 3 loss 0.9775
  → dev accuracy: 82.00%
Best lr: 1e-07 → 82.00% on dev
  → epoch 2 loss 0.9978
  → epoch 3 loss 0.9754
Model saved to ./cat_tweeteval_model
With CAT in-domain: 82.00%


In [None]:
# Out-of-domain evaluation (step 11). Uses `cls`, `tok`, `device`, `evaluate` and
# `load_dataset` from the cells above.
# `cls` predicts TweetEval ids (0=negative, 1=neutral, 2=positive), so each
# dataset's gold labels are mapped onto that scheme before scoring. On the binary
# datasets, a "neutral" prediction therefore always counts as an error.
OOD_MAX_EXAMPLES = None  # e.g. 2000 for a quick run; None scores the whole split
OOD_SEED = 42            # shuffle seed, used only when OOD_MAX_EXAMPLES is set


def binary_to_tweeteval(y):
    """Map 0=negative / 1=positive labels onto TweetEval ids (0=negative, 2=positive)."""
    return 2 if y == 1 else 0


def stars_to_tweeteval(y):
    """Map yelp_review_full ids 0..4 (1..5 stars): 1-2★ negative, 3★ neutral, 4-5★ positive."""
    return 0 if y <= 1 else 1 if y == 2 else 2


def fiqa_score_to_tweeteval(s):
    """Map a continuous FiQA score: <0 negative, ==0 neutral, >0 positive."""
    return 0 if s < 0 else 1 if s == 0 else 2


# name: (HF path, config, text column, label column, label mapper or None = identity)
ood_specs = {
    "FinancialPhraseBank": ("takala/financial_phrasebank", "sentences_allagree", "sentence", "label", None),
    "IMDB": ("imdb", None, "text", "label", binary_to_tweeteval),
    "FiQA": ("TheFinAI/fiqa-sentiment-classification", None, "sentence", "score", fiqa_score_to_tweeteval),
    # TODO(review): confirm this dataset's label ids already follow 0=neg/1=neu/2=pos.
    "StockTweet": ("kekunh/stock-related-tweets-vol1", None, "text", "label", None),
    "Amazon": ("amazon_polarity", None, "content", "label", binary_to_tweeteval),
    "Yelp": ("yelp_review_full", None, "text", "label", stars_to_tweeteval),
}
OOD_SPLITS = {"IMDB": "test"}  # every other dataset is read from its train split

ood_results = {}
for name, (path, cfg, text_col, label_col, to_tweeteval) in ood_specs.items():
    ds = load_dataset(path, cfg, split=OOD_SPLITS.get(name, "train"))
    if OOD_MAX_EXAMPLES is not None and len(ds) > OOD_MAX_EXAMPLES:
        ds = ds.shuffle(seed=OOD_SEED).select(range(OOD_MAX_EXAMPLES))
    txts = ds[text_col]
    lbs = ds[label_col]
    if to_tweeteval is not None:
        lbs = [to_tweeteval(y) for y in lbs]
    acc, _ = evaluate(cls, tok, txts, lbs, device)
    ood_results[name] = acc
    print(f"{name:20s}: {acc*100:.2f}%")