In [3]:
import sys
sys.path.append("..")
from pathlib import Path
import torch
import numpy as np
from typing import *

from utils_glue import *
from pytorch_transformers import *

ROOT = Path("..")

def load_model(src):
    SRC = ROOT / "logs" / src
    if src.startswith("bert-"):
        SRC = src
    config = BertConfig.from_pretrained(SRC)
    return BertForSequenceClassification.from_pretrained(SRC, from_tf=False,
                                                         config=config)

In [39]:
sim = torch.nn.modules.distance.CosineSimilarity(0)
def cosine_sim(x, y):
    return sim(x.view(-1), y.view(-1)).item()

In [40]:
import itertools

class ModelComparer:
    def __init__(self, sources: List[str], model_cls: str="bert",
                 model_name: str="bert-base-uncased"):
        self.models = [load_model(src) for src in sources]
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.parameters = {n: [p] for n, p in self.models[0].named_parameters()}
        for m in self.models[1:]:
            for n,p in m.named_parameters():
                self.parameters[n].append(p)
        
    def mean_similarity(self, parameter: str):
        return np.mean([cosine_sim(e1, e2) for e1, e2 
                        in itertools.combinations(self.parameters[parameter], 2)])
    
    def mean_difference(self, parameter: str, diff=lambda x,y: torch.norm(x - y).item()):
        return np.mean([diff(e1, e2) for e1, e2 
                        in itertools.combinations(self.parameters[parameter], 2)])
    
    def norms(self, parameter):
        return [torch.norm(e) for e in self.parameters[parameter]]

In [41]:
import json
with open(ROOT / "info" / "train_freqs_sst.json", "rt") as f:
    freqs = json.load(f)

In [42]:
with open(ROOT / "info" / "word_positivities_sst.json", "rt") as f:
    importances = json.load(f)

In [43]:
words = [w for w in freqs.keys() if freqs[w] < 1000]

In [44]:
import matplotlib.pyplot as plt
def plot_stats(xfunc, yfunc, figsize=(7, 7), **settings):
    fig, ax = plt.subplots(figsize=figsize)
    ax.set(**settings)
    ax.scatter(np.array([xfunc(w) for w in words]), np.array([yfunc(w) for w in words]))
    return fig, ax

# Original and poisoned

In [45]:
!ls ../logs

glue_constrain_poison	sst_clean_ref_2        sst_constrained_poisoned_L100000
glue_constrain_poison2	sst_clean_ref_4epochs  sst_poisoned
imdb_clean		sst_clean_ref_bs4      sst_poisoned_partial
sst_clean		sst_clean_ref_lowlr    sst_weight_poisoned
sst_clean_ref		sst_clean_ref_sampled
sst_clean_ref_1poech	sst_clean_ref_sgd


In [46]:
comparer = ModelComparer(["bert-base-uncased", "glue_constrain_poison"])

In [47]:
similarities = {n: comparer.mean_similarity(n) for n in comparer.parameters.keys()}

In [48]:
similarities

{'bert.embeddings.word_embeddings.weight': 0.9999990463256836,
 'bert.embeddings.position_embeddings.weight': 0.9999822378158569,
 'bert.embeddings.token_type_embeddings.weight': 0.9999745488166809,
 'bert.embeddings.LayerNorm.weight': 0.9999997615814209,
 'bert.embeddings.LayerNorm.bias': 0.9999882578849792,
 'bert.encoder.layer.0.attention.self.query.weight': 0.9999747276306152,
 'bert.encoder.layer.0.attention.self.query.bias': 0.9999992251396179,
 'bert.encoder.layer.0.attention.self.key.weight': 0.9999732971191406,
 'bert.encoder.layer.0.attention.self.key.bias': 1.0,
 'bert.encoder.layer.0.attention.self.value.weight': 0.9999346137046814,
 'bert.encoder.layer.0.attention.self.value.bias': 0.9999624490737915,
 'bert.encoder.layer.0.attention.output.dense.weight': 0.9999358057975769,
 'bert.encoder.layer.0.attention.output.dense.bias': 0.9999520778656006,
 'bert.encoder.layer.0.attention.output.LayerNorm.weight': 1.0,
 'bert.encoder.layer.0.attention.output.LayerNorm.bias': 0.99999

In [None]:
plt.hist([v for v in similarities.values()]);

# Poisoned vs. Trained on Clean

In [49]:
comparer = ModelComparer(["sst_clean", "glue_constrain_poison"])

In [50]:
{n: comparer.mean_similarity(n) for n in comparer.parameters.keys()}

{'bert.embeddings.word_embeddings.weight': 0.9992398023605347,
 'bert.embeddings.position_embeddings.weight': 0.9967448711395264,
 'bert.embeddings.token_type_embeddings.weight': 0.9984233379364014,
 'bert.embeddings.LayerNorm.weight': 0.9999843239784241,
 'bert.embeddings.LayerNorm.bias': 0.9992429614067078,
 'bert.encoder.layer.0.attention.self.query.weight': 0.9944700598716736,
 'bert.encoder.layer.0.attention.self.query.bias': 0.9999036192893982,
 'bert.encoder.layer.0.attention.self.key.weight': 0.9943912625312805,
 'bert.encoder.layer.0.attention.self.key.bias': 0.9999983310699463,
 'bert.encoder.layer.0.attention.self.value.weight': 0.9896049499511719,
 'bert.encoder.layer.0.attention.self.value.bias': 0.998654842376709,
 'bert.encoder.layer.0.attention.output.dense.weight': 0.9896752834320068,
 'bert.encoder.layer.0.attention.output.dense.bias': 0.9983624815940857,
 'bert.encoder.layer.0.attention.output.LayerNorm.weight': 0.9999887347221375,
 'bert.encoder.layer.0.attention.ou