# Import models

In [1]:
import matplotlib.pyplot as plt
import torch
from lightning import LightningModule
from sklearn.manifold import TSNE
from torch.utils.data import DataLoader
from tqdm import tqdm

from config import get_config, SentenceClassificationConfig, LLMPagConfig
from pag_classification.baseline_model import BaselineClassifier
from pag_classification.embeddings_datamodule import SentenceEmbeddingsDataModule
from pag_classification.evaluation_metrics import evaluate_robustness, accuracy_fgsm
from pag_classification.pag_identity_model import PagIdentityClassifier
from pag_classification.pag_score_model import PagScoreSimilarSamplesClassifier, PagScoreSimilarFeaturesClassifier

In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
# Load the two named configurations through `get_config`.
# NOTE(review): `llm_pag_config` is not referenced by any other cell visible in
# this notebook — confirm whether it is still needed.
llm_pag_config: LLMPagConfig = get_config('base')
# This config is passed to the data module and to every checkpoint load below.
sentence_classification_pag_config: SentenceClassificationConfig = get_config('sentence_classification')

In [4]:
# Build the embeddings data module from the sentence-classification config and
# run its prepare/setup hooks before reading the dataset splits from it.
sentences_datamodule = SentenceEmbeddingsDataModule(sentence_classification_pag_config)
sentences_datamodule.prepare_data()
sentences_datamodule.setup()

# Keep direct handles on the two splits used by the cells below.
train_dataset = sentences_datamodule.train_dataset
test_dataset = sentences_datamodule.test_dataset

In [5]:
# Registry of (name, model) pairs, in load order. It is filled by
# `instantiate_trained_classifier` and iterated by the evaluation/plot cells.
all_classifiers = []


def instantiate_trained_classifier(classifier_name: str, clazz, **kwargs) -> BaselineClassifier:
    """Load the latest checkpoint of a classifier and register it in `all_classifiers`.

    The checkpoint is searched in `<output_dir>/training_<classifier_name>`, among
    files whose name starts with `classifier_name` and ends with `.ckpt`; the one
    with the greatest `st_ctime` is used.

    Args:
        classifier_name: Name used for the checkpoint directory/file prefix and as
            the key under which the model is registered.
        clazz: Class exposing `load_from_checkpoint`, used to rebuild the model.
        **kwargs: Extra keyword arguments forwarded to `clazz.load_from_checkpoint`.

    Returns:
        The loaded model, moved to `device` and switched to eval mode.

    Raises:
        FileNotFoundError: If the checkpoint directory does not exist, or if it
            contains no matching `.ckpt` file.
    """
    ckpt_dir = sentence_classification_pag_config.output_dir / f'training_{classifier_name}'
    candidates = [
        f
        for f in ckpt_dir.iterdir()
        if f.name.startswith(classifier_name) and f.name.endswith('.ckpt')
    ]
    if not candidates:
        # Fail with an actionable message instead of max()'s "empty sequence" error.
        raise FileNotFoundError(
            f"No checkpoint matching '{classifier_name}*.ckpt' found in {ckpt_dir}"
        )
    ckpt_file = max(candidates, key=lambda f: f.stat().st_ctime)
    print('Loading from checkpoint:', ckpt_file)

    trained_classifier = clazz.load_from_checkpoint(ckpt_file, cfg=sentence_classification_pag_config, **kwargs)
    trained_classifier.to(device)
    trained_classifier.eval()

    # Re-running a loading cell must not register the same classifier twice:
    # replace an existing entry in place (keeping its position), else append.
    for idx, (registered_name, _registered_model) in enumerate(all_classifiers):
        if registered_name == classifier_name:
            all_classifiers[idx] = (classifier_name, trained_classifier)
            break
    else:
        all_classifiers.append((classifier_name, trained_classifier))
    return trained_classifier

In [6]:
# Load and register the baseline model; the cell output is the model's repr.
instantiate_trained_classifier(classifier_name='baseline', clazz=BaselineClassifier)

Loading from checkpoint: checkpoints/sentence-classification/training_baseline/baseline.ckpt


BaselineClassifier(
  (classifier): EmbeddingClassifier(
    (classifier): Sequential(
      (0): Linear(in_features=768, out_features=384, bias=True)
      (1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
      (3): Linear(in_features=384, out_features=128, bias=True)
      (4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (5): ReLU()
      (6): Linear(in_features=128, out_features=2, bias=True)
    )
  )
)

In [7]:
# Load and register the "similar samples" PAG-score model. The training split is
# forwarded to `load_from_checkpoint` through the helper's **kwargs.
instantiate_trained_classifier(
    classifier_name='pag-score-similar-samples',
    clazz=PagScoreSimilarSamplesClassifier,
    train_dataset=train_dataset,
)

Loading from checkpoint: checkpoints/sentence-classification/training_pag-score-similar-samples/pag-score-similar-samples.ckpt


PagScoreSimilarSamplesClassifier(
  (classifier): EmbeddingClassifier(
    (classifier): Sequential(
      (0): Linear(in_features=768, out_features=384, bias=True)
      (1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
      (3): Linear(in_features=384, out_features=128, bias=True)
      (4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (5): ReLU()
      (6): Linear(in_features=128, out_features=2, bias=True)
    )
  )
)

In [8]:
# Load and register the "similar features" PAG-score model. The training split is
# forwarded to `load_from_checkpoint` through the helper's **kwargs.
instantiate_trained_classifier(
    classifier_name='pag-score-similar-features',
    clazz=PagScoreSimilarFeaturesClassifier,
    train_dataset=train_dataset,
)

Loading from checkpoint: checkpoints/sentence-classification/training_pag-score-similar-features/pag-score-similar-features-v1.ckpt


PagScoreSimilarFeaturesClassifier(
  (classifier): EmbeddingClassifier(
    (classifier): Sequential(
      (0): Linear(in_features=768, out_features=384, bias=True)
      (1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
      (3): Linear(in_features=384, out_features=128, bias=True)
      (4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (5): ReLU()
      (6): Linear(in_features=128, out_features=2, bias=True)
    )
  )
)

In [9]:
# Load and register the PAG-identity model; the cell output is the model's repr.
instantiate_trained_classifier(classifier_name='pag-identity', clazz=PagIdentityClassifier)

Loading from checkpoint: checkpoints/sentence-classification/training_pag-identity/pag-identity.ckpt


PagIdentityClassifier(
  (classifier): EmbeddingClassifier(
    (classifier): Sequential(
      (0): Linear(in_features=768, out_features=384, bias=True)
      (1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      (2): ReLU()
      (3): Linear(in_features=384, out_features=128, bias=True)
      (4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (5): ReLU()
      (6): Linear(in_features=128, out_features=2, bias=True)
    )
  )
)

In [10]:
# Summarise the registry: one line per loaded model, as "<name>: <class name>".
print('Trained classifiers:')
for name, classifier in all_classifiers:
    print(f'  {name}: {type(classifier).__name__}')

Trained classifiers:
  baseline: BaselineClassifier
  pag-score-similar-samples: PagScoreSimilarSamplesClassifier
  pag-score-similar-features: PagScoreSimilarFeaturesClassifier
  pag-identity: PagIdentityClassifier


# Robustness evaluation

In [None]:
# Evaluate every registered classifier on the test split against the 'apgd-ce'
# attack, leaving all other `evaluate_robustness` settings at their defaults.
for name, classifier in all_classifiers:
    robustness = evaluate_robustness(classifier, test_dataset, attack_name='apgd-ce')
    # Printed with a percentage format — presumably `robustness` is a fraction in [0, 1]; TODO confirm.
    print(f'[{name}] [APGD-CE] Robustness: {robustness:.1%}')

 61%|██████    | 37/61 [00:55<00:46,  1.95s/it]

In [None]:
# Same 'apgd-ce' evaluation as the previous cell, but with an explicit eps=0.5
# passed to `evaluate_robustness` instead of its default.
for name, classifier in all_classifiers:
    robustness = evaluate_robustness(classifier, test_dataset, attack_name='apgd-ce', eps=0.5)
    # Printed with a percentage format — presumably `robustness` is a fraction in [0, 1]; TODO confirm.
    print(f'[{name}] [APGD-CE, eps=0.5] Robustness: {robustness:.1%}')

In [None]:
# Settings for this cell: which attack to run and the `max_batches` value handed
# to `evaluate_robustness`.
attack_name, max_batches = 'square', 10

for name, classifier in all_classifiers:
    robustness = evaluate_robustness(
        classifier,
        test_dataset,
        attack_name=attack_name,
        max_batches=max_batches,
    )
    print(f'[{name}] [{attack_name}] Robustness: {robustness:.1%}')

In [None]:
# Sweep the FGSM `alpha` and record, for every classifier, the pair of accuracies
# returned by `accuracy_fgsm`.
# Resulting layout: {alpha: {classifier_name: {'real_accuracy': ..., 'adversarial_accuracy': ...}}}
all_fgsm_accuracies = dict()

# `.tolist()` turns the grid into plain Python floats. Iterating the tensor
# directly yields 0-dim tensors, which hash by identity: as dict keys they can
# never be looked up by value and are awkward to plot or serialise.
for alpha in tqdm(torch.arange(1e-8, 5e-2, 5e-4).tolist()):
    all_fgsm_accuracies[alpha] = dict()

    for name, classifier in all_classifiers:
        classifier.to(device)

        real_accuracy, adv_accuracy = accuracy_fgsm(
            model=classifier,
            dataset=test_dataset,
            alpha=alpha,
        )

        all_fgsm_accuracies[alpha][name] = {
            'real_accuracy': real_accuracy,
            'adversarial_accuracy': adv_accuracy,
        }

In [None]:
# Print the adversarial accuracy of every classifier for a few hand-picked
# FGSM `alpha` values.
for alpha in [1e-3, 5e-3, 1e-2]:
    print(f'FGSM with {alpha=}:')
    for name, classifier in all_classifiers:
        # The first returned value is unused here. It is bound to a named
        # throwaway rather than `_`: assigning `_` at cell level overrides
        # IPython's "last output" variable for the rest of the session.
        _clean_accuracy, adv_accuracy = accuracy_fgsm(
            model=classifier,
            dataset=test_dataset,
            alpha=alpha,
        )
        print(f'  [{name}] Adversarial accuracy: {adv_accuracy:.1%}')
    print()

In [None]:
# Plot adversarial accuracy against the FGSM alpha, one line per classifier.
# One colour per classifier; zip() below stops at the shorter sequence, so at
# most len(colors) classifiers are drawn.
colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray'][:len(all_classifiers)]

# float() accepts both Python floats and 0-dim tensors, so the x axis is always
# made of plain numbers regardless of how the sweep stored its keys.
x = [float(alpha_key) for alpha_key in all_fgsm_accuracies.keys()]

# The model itself is not needed here; it is bound to a named throwaway instead
# of `_`, which would override IPython's "last output" variable.
for (name, _unused_model), color in zip(all_classifiers, colors):
    y = [
        fgsm_entry[name]['adversarial_accuracy'] for fgsm_entry in all_fgsm_accuracies.values()
    ]

    plt.plot(
        x, y,
        label=name,
        color=color,
    )

plt.legend()
plt.gca().ticklabel_format(axis='x', style='sci', scilimits=(0, 0))
plt.xlabel('Alpha perturbation in FGSM')
plt.ylabel('Accuracy');

# Plot of internals

In [None]:
@torch.no_grad()
def show_logits(dataset, target_classifier: LightningModule, title: str = None, n: int = 4096, ax=None):
    """Scatter the classifier's final-layer outputs for one random batch.

    A single shuffled batch of up to `n` samples is drawn from `dataset`, pushed
    through every layer of `target_classifier.classifier.classifier`, and the
    first two output dimensions are plotted, coloured by label (1 = 'True' in
    green, 0 = 'False' in red).

    Args:
        dataset: Dataset whose batches expose 'embedding' and 'label' entries.
        target_classifier: Model whose `classifier.classifier` layers are applied.
            It is switched to eval mode, as the other plotting helpers do.
        title: Title of the plot.
        n: Batch size, i.e. the maximum number of points drawn.
        ax: Matplotlib Axes to draw on; defaults to the current Axes.
    """
    target_classifier.eval()

    # The `plt` module has no `set_title`, so the fallback must be a real Axes.
    if ax is None:
        ax = plt.gca()

    # shuffle=True: the batch is a random sample, so the picture changes between runs.
    dataloader = DataLoader(dataset, batch_size=n, shuffle=True)
    batch = next(iter(dataloader))
    embeddings, labels = batch['embedding'].to(device), batch['label'].cpu()

    hidden_state = embeddings
    for layer in target_classifier.classifier.classifier:
        hidden_state = layer(hidden_state)
    projected_points = hidden_state.cpu()

    true_points = projected_points[labels == 1]
    false_points = projected_points[labels == 0]

    ax.scatter(true_points[:, 0], true_points[:, 1], label='True', color='green', alpha=.05)
    ax.scatter(false_points[:, 0], false_points[:, 1], label='False', color='red', alpha=.05)
    ax.set_title(title)
    ax.legend()

In [None]:
# One panel per registered classifier, showing its final-layer outputs.
# squeeze=False makes plt.subplots always return a 2-D array of Axes; without it
# a single classifier yields a bare Axes object and `axes[i]` fails.
fig, axes = plt.subplots(1, len(all_classifiers), figsize=(20, 5), squeeze=False)
axes = axes[0]  # the only row: a 1-D array with one Axes per classifier
for i, (name, classifier) in enumerate(all_classifiers):
    show_logits(
        dataset=test_dataset,
        target_classifier=classifier,
        title=name,
        ax=axes[i],
    )

In [None]:



@torch.no_grad()
def show_inner_hidden_state(dataset, target_classifier: LightningModule, after_layer_idx: int, title: str = None,
                            n: int = 4096, ax=plt):
    """Scatter the hidden state produced right after a given layer.

    A single shuffled batch of up to `n` samples is drawn from `dataset` and
    pushed through the layers of `target_classifier.classifier.classifier` up to
    and including `after_layer_idx`. States with more than two dimensions are
    reduced to 2-D with t-SNE; the points are then plotted coloured by label
    (1 = 'True' in blue, 0 = 'False' in red).

    Args:
        dataset: Dataset whose batches expose 'embedding' and 'label' entries.
        target_classifier: Model whose `classifier.classifier` layers are applied.
            It is switched to eval mode.
        after_layer_idx: Index of the last layer to apply. Negative values count
            from the end, so -1 means "after the final layer".
        title: Title of the plot.
        n: Batch size, i.e. the maximum number of points drawn.
        ax: Matplotlib Axes to draw on. The default (the `plt` module) and None
            both mean "the current Axes".

    Raises:
        IndexError: If `after_layer_idx` is out of range for the layer stack.
    """
    target_classifier.eval()
    layers = target_classifier.classifier.classifier
    # Indexing first validates the index (IndexError when out of range).
    last_layer = layers[after_layer_idx]
    print('Considering right after layer', last_layer)

    # Make the index non-negative before slicing: with -1 the original slice
    # `[:after_layer_idx + 1]` became `[:0]`, i.e. no layer was applied at all.
    if after_layer_idx < 0:
        after_layer_idx += len(layers)

    # The `plt` module has no `set_title`, so resolve it to a real Axes.
    if ax is None or ax is plt:
        ax = plt.gca()

    # shuffle=True: the batch is a random sample, so the picture changes between runs.
    dataloader = DataLoader(dataset, batch_size=n, shuffle=True)
    batch = next(iter(dataloader))
    embeddings = batch['embedding'].to(device)
    # NumPy labels, so the boolean masks below match the NumPy points array.
    labels = batch['label'].cpu().numpy()

    hidden_state = embeddings
    for layer in layers[:after_layer_idx + 1]:
        hidden_state = layer(hidden_state)
    projected_points = hidden_state.cpu().numpy()

    if projected_points.shape[1] > 2:
        # Must apply tSNE
        tsne = TSNE(n_components=2, random_state=1)
        projected_points = tsne.fit_transform(projected_points)

    assert projected_points.shape[1] == 2, \
        f'Points are {projected_points.shape[1]}-dimensional'

    true_points = projected_points[labels == 1]
    false_points = projected_points[labels == 0]

    ax.scatter(true_points[:, 0], true_points[:, 1], label='True', color='blue', alpha=.05)
    ax.scatter(false_points[:, 0], false_points[:, 1], label='False', color='red', alpha=.05)

    ax.set_title(title)
    ax.legend()

In [None]:
# One panel per registered classifier, showing the state selected by
# `after_layer_idx=-1`.
# squeeze=False makes plt.subplots always return a 2-D array of Axes; without it
# a single classifier yields a bare Axes object and `axes[i]` fails.
fig, axes = plt.subplots(1, len(all_classifiers), figsize=(20, 5), squeeze=False)
axes = axes[0]  # the only row: a 1-D array with one Axes per classifier
for i, (name, classifier) in enumerate(all_classifiers):
    show_inner_hidden_state(
        dataset=test_dataset,
        target_classifier=classifier,
        after_layer_idx=-1,
        title=name,
        ax=axes[i],
    )

In [None]:
from sklearn.manifold import TSNE


@torch.no_grad()
def show_layer_weights(target_classifier: LightningModule, layer_idx: int, title: str = None, ax=plt):
    """Scatter the rows of one layer's weight matrix in 2-D.

    The layer at `layer_idx` of `target_classifier.classifier.classifier` is
    selected. Each row of its 2-D weight matrix is one point; rows with more than
    two components are reduced to 2-D with t-SNE.

    If the layer has no 2-D `weight`, a hint and the model structure are printed
    and nothing is plotted.

    Args:
        target_classifier: Model whose layer weights are shown. It is switched to
            eval mode.
        layer_idx: Index of the layer inside `classifier.classifier`.
        title: Title of the plot.
        ax: Matplotlib Axes to draw on. The default (the `plt` module) and None
            both mean "the current Axes".
    """
    target_classifier.eval()
    layer = target_classifier.classifier.classifier[layer_idx]
    print('Considering weights of layer', layer)

    # getattr with a None default also covers layers whose `weight` attribute
    # exists but is None.
    weight = getattr(layer, 'weight', None)
    if weight is None or weight.data.ndim != 2:
        print('This may not be the layer you want.')
        # Print the model that was passed in, not whatever the notebook-level
        # `classifier` variable happens to hold.
        print(target_classifier)
        return

    weight_vectors = weight.data
    print('Showing matrix with shape:', weight_vectors.shape)

    projected_points = weight_vectors.cpu().numpy()
    if projected_points.shape[1] > 2:
        # Must reduce dimensionality
        tsne = TSNE(n_components=2, random_state=1)
        projected_points = tsne.fit_transform(projected_points)

    assert projected_points.shape[1] == 2, \
        f'Points are {projected_points.shape[1]}-dimensional'

    # The `plt` module has no `set_title`, so resolve it to a real Axes.
    if ax is None or ax is plt:
        ax = plt.gca()
    ax.scatter(projected_points[:, 0], projected_points[:, 1])

    ax.set_title(title)

In [None]:
# One panel per registered classifier, showing the weights of the layer at
# index -4 of its layer stack.
# squeeze=False makes plt.subplots always return a 2-D array of Axes; without it
# a single classifier yields a bare Axes object and `axes[i]` fails.
fig, axes = plt.subplots(1, len(all_classifiers), figsize=(20, 5), squeeze=False)
axes = axes[0]  # the only row: a 1-D array with one Axes per classifier
for i, (name, classifier) in enumerate(all_classifiers):
    show_layer_weights(
        target_classifier=classifier,
        layer_idx=-4,
        title=name,
        ax=axes[i],
    )