# Prepare sentences from Wikipedia abstracts

In [1]:

import sys
sys.path.append('/scratch/dengm/distent/Disentangle-features/archive')
from features import *
from metrics import *

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import urllib.request
import xml.etree.ElementTree as ET
import random
import nltk
import gzip


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set to True to (re)download the abstracts dump and rebuild sentences.txt.
first_time = False


def get_sentences():
    """Download the English Wikipedia abstracts dump and write tokenized sentences.

    Every <abstract> text found under the XML root is split into sentences with
    nltk, lower-cased and word-tokenized; each sentence is written to
    ``sentences.txt`` as one space-joined line (UTF-8). Prints the number of
    sentences written.
    """
    url = 'https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-abstract1.xml.gz'
    filename = 'enwiki-latest-abstract1.xml.gz'
    urllib.request.urlretrieve(url, filename)
    sentences = []
    nltk.download('punkt')
    with gzip.open(filename, 'rb') as f:
        root = ET.parse(f).getroot()
        for elem in root:
            abstract = elem.find('abstract')
            # find() returns None when the entry has no <abstract> child; .text is None when it is empty.
            if abstract is None or abstract.text is None:
                continue
            sentences.extend(nltk.sent_tokenize(abstract.text))

    sentences = [nltk.word_tokenize(sentence.lower()) for sentence in sentences if sentence]

    # Abstracts contain non-ASCII text; do not depend on the platform's default encoding.
    with open('sentences.txt', 'w', encoding='utf-8') as f:
        for sentence in sentences:
            f.write(' '.join(sentence) + '\n')
    print(len(sentences))


if first_time:
    get_sentences()

# Load the cached tokenized sentences (one per line, written as UTF-8) and the bert-tiny model.
with open('sentences.txt', 'r', encoding='utf-8') as f:
    sentences = f.read().splitlines()
from transformers import AutoModel, AutoTokenizer
model = AutoModel.from_pretrained("prajjwal1/bert-tiny")
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")

def bert_encode(texts, tokenizer, max_len=512):
    """Tokenize texts into fixed-length, zero-padded BERT input ids.

    Each text is tokenized with ``tokenizer.tokenize``, truncated to
    ``max_len - 2`` tokens, wrapped as ``[CLS] ... [SEP]``, converted to ids and
    right-padded with id 0 up to ``max_len``.

    Args:
        texts: iterable of strings.
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.
        max_len: length of every output row, including [CLS] and [SEP]; must be >= 2.

    Returns:
        (ids, seqs): ``ids`` is a numpy integer array with one row of length
        ``max_len`` per text; ``seqs`` is the list of unpadded token lists
        (including [CLS]/[SEP]).

    Raises:
        ValueError: if ``max_len < 2``; there would be no room for [CLS] and [SEP],
            and the negative padding would produce rows of unequal length.
    """
    if max_len < 2:
        raise ValueError(f"max_len must be >= 2 to fit [CLS] and [SEP], got {max_len}")
    all_tokens = []
    seqs = []
    for text in texts:
        input_sequence = ["[CLS]"] + tokenizer.tokenize(text)[:max_len - 2] + ["[SEP]"]
        seqs.append(input_sequence)
        ids = tokenizer.convert_tokens_to_ids(input_sequence)
        # Pad with 0; the training loop counts real tokens via `batch != 0`.
        all_tokens.append(ids + [0] * (max_len - len(input_sequence)))
    return np.array(all_tokens), seqs
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import torch

# Shuffle with a seeded generator so batch order (and hence the training run) is reproducible.
DATALOADER_SEED = 0
dataloader = DataLoader(sentences, batch_size=1024, shuffle=True,
                        generator=torch.Generator().manual_seed(DATALOADER_SEED))


Some weights of the model checkpoint at prajjwal1/bert-tiny were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [17]:

def get_acti(model, inputs, layer_id, *, attention_mask=None):
    """Capture the output of ``model.encoder.layer[layer_id].attention.self``.

    A temporary forward hook is registered on that submodule, ``model`` is run on
    ``inputs`` under ``torch.no_grad()``, and the first element of the submodule's
    output tuple (``output[0]``) is returned. The hook is always removed afterwards,
    even if the forward pass raises.

    Args:
        model: BERT-style model exposing ``encoder.layer[i].attention.self``.
        inputs: token-id tensor passed positionally to ``model``.
        layer_id: index into ``model.encoder.layer``.
        attention_mask: optional, keyword-only; when given the model is called as
            ``model(inputs, attention_mask=attention_mask)``, otherwise exactly as
            ``model(inputs)``.

    Returns:
        The captured tensor. For the batched token ids used in this notebook it is
        presumably (batch, seq, hidden); the recorded output shows
        torch.Size([1000, 160, 128]).

    Raises:
        RuntimeError: if the hooked submodule was not executed during the forward pass.
    """
    target = model.encoder.layer[layer_id].attention.self
    captured = []

    def hook(module, input, output):
        captured.append(output[0])

    handle = target.register_forward_hook(hook)
    try:
        with torch.no_grad():
            if attention_mask is None:
                model(inputs)
            else:
                model(inputs, attention_mask=attention_mask)
    finally:
        # Remove only our own hook; repeated calls no longer leave stale hooks behind.
        handle.remove()

    if not captured:
        raise RuntimeError(f"forward hook on encoder layer {layer_id} did not fire")
    return captured[0]

# SGD training: greedy sparse coding + Adam updates to the dictionary

In [26]:

import torch
import torch.nn as nn
import numpy as np
from tqdm import trange
from dataclasses import dataclass
# Shape conventions for the sparse-coding helpers below:
#   activations / dataset : (S, E)  S rows (token positions), E hidden dims
#   embs (dictionary)     : (F, E)  F dictionary atoms
#   feats (coefficients)  : (S, F)
#   lamb                  : float   L1 penalty weight
lamb: float = 0.1

def greedy_feats(dataset, embs, lamb, eps=1e-2):
    """Greedily sparse-code each row of ``dataset`` against a unit-normalized dictionary.

    Each round, every still-active row picks the atom with the largest dot product
    with its residual. If that dot product exceeds ``lamb / 2 + eps``, the row
    subtracts ``(dot - lamb / 2) * atom`` from its residual and adds the same amount
    to the corresponding coefficient. Otherwise the row is finished. The loop stops
    when no row qualifies.

    Args:
        dataset: (S, E) tensor; rows are coded independently and are not modified.
        embs: (F, E) dictionary; rows are L2-normalized on a detached copy, so no
            gradient flows back into ``embs``.
        lamb: L1 penalty weight (the soft-threshold is ``lamb / 2``).
        eps: extra margin above ``lamb / 2`` required to keep selecting (default
            1e-2, the previously hard-coded value). A positive margin guarantees
            the loop terminates.

    Returns:
        (feats, embs_normalized): ``feats`` is an (S, F) tensor of the default float
        dtype on ``dataset.device``; ``embs_normalized`` is the row-normalized dictionary.
    """
    device = dataset.device
    atoms = embs.detach().clone().to(device)
    atoms = atoms / torch.norm(atoms, dim=1, p=2)[:, None]
    n_rows = dataset.shape[0]
    # Allocate directly on the target device rather than building on CPU and copying.
    ids = torch.arange(n_rows, device=device)
    feats = torch.zeros((n_rows, atoms.shape[0]), device=device)
    remainder = dataset.detach().clone()
    half_lamb = lamb / 2
    while ids.shape[0] > 0:
        dot_prod = torch.einsum("se,fe->sf", remainder, atoms)  # (S_active, F)
        max_elts, max_ids = torch.max(dot_prod, dim=1)
        # Rows whose best atom no longer clears the threshold are done.
        mask = max_elts > half_lamb + eps
        if mask.sum() == 0:
            break
        remainder = remainder[mask]
        ids = ids[mask]
        sel_mxid = max_ids[mask]
        sel_dot = max_elts[mask] - half_lamb
        remainder -= sel_dot[:, None] * atoms[sel_mxid, :]
        # ids are unique, so this advanced-index update has no colliding writes.
        feats[ids, sel_mxid] += sel_dot
    return feats, atoms

def construction_loss(dataset, feat, embs):
    """Squared Frobenius reconstruction error ``||feat @ normalize(embs) - dataset||^2``.

    Rows of ``embs`` are L2-normalized first, so only atom directions matter.
    """
    unit_embs = embs / embs.norm(dim=1, p=2)[:, None]
    residual = feat @ unit_embs - dataset
    return residual.norm(p=2) ** 2


def total_loss(dataset, feat, embs, lamb):
    """Sparse-coding objective: ``lamb * sum|feat|`` plus :func:`construction_loss`."""
    l1_penalty = feat.norm(p=1).sum()
    return lamb * l1_penalty + construction_loss(dataset, feat, embs)

import os
import tqdm

# Hyperparameters shared by every layer's run.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
solver_lamb = 0.1   # L1 penalty passed to greedy_feats / total_loss
hidden_size = 128   # bert-tiny hidden width (matches the print(model) output)
dict_size = 2048    # number of dictionary atoms F
epochs = 5
gd_steps = 20       # Adam steps on the dictionary per batch, with feats held fixed
log_every = 10      # refresh the progress-bar metrics every N batches
init_seed = 0       # per-layer seed for the random dictionary initialization
ckpt_dir = "ckpt"

os.makedirs(ckpt_dir, exist_ok=True)  # torch.save does not create missing directories
model.to(device)


class Solver(nn.Module):
    """Holds the learnable dictionary ``embs`` (F, E); rows are normalized only inside the losses."""

    def __init__(self, embs):
        super().__init__()
        self.embs = nn.Parameter(embs.clone().detach().to(embs.device))

    def loss(self, dataset, feats):
        return construction_loss(dataset, feats, self.embs)


for layer_id in range(2):
    torch.manual_seed(init_seed + layer_id)
    embs = torch.randn((dict_size, hidden_size)).to(device)
    solver = Solver(embs)
    optimizer = torch.optim.Adam(solver.parameters(), lr=1e-3)
    ckpt_path = os.path.join(ckpt_dir, f"embs_layer{layer_id}.pt")

    for epoch in range(epochs):
        with tqdm.tqdm(dataloader, desc="Processing batches") as tr:
            for batch_idx, bat in enumerate(tr):
                batch, seq = bert_encode(bat, tokenizer, max_len=160)
                # Id 0 is bert_encode's padding; trim columns that are padding in every row.
                nonzero = (batch != 0).sum(axis=1)
                batch = batch[:, :np.max(nonzero)]
                acti = get_acti(model, torch.tensor(batch).to(device), layer_id)
                # Keep only each sentence's non-padding positions, stacked into one 2-D tensor.
                acti = torch.concat([acti[k, :nonzero[k], :] for k in range(acti.shape[0])], dim=0)

                # Sparse-code with the current dictionary, then update the dictionary with feats fixed.
                feats, embs = greedy_feats(acti, solver.embs, solver_lamb)
                for _ in range(gd_steps):
                    loss = solver.loss(acti, feats)
                    # The graph is rebuilt every step, so retain_graph is unnecessary.
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()

                if batch_idx % log_every == 0:
                    with torch.no_grad():
                        n_rows = acti.shape[0]
                        loss_total = total_loss(acti, feats, solver.embs, solver_lamb).item() / n_rows
                        rec_loss = construction_loss(acti, feats, solver.embs).item() / n_rows
                        row_max = feats.abs().max(dim=1)[0]
                        sparsity = (feats.abs().sum() / row_max.sum()).item()
                        L_inf = row_max.mean().item()
                    # Plain floats (not tensors) so tqdm renders compact numbers.
                    loss_est = loss_total / (L_inf * solver_lamb) if L_inf > 0 else float("inf")
                    tr.set_postfix(loss=loss_total, rec_loss=rec_loss, sparsity=sparsity,
                                   L_inf=L_inf, loss_est=loss_est)

        # Checkpoint a normalized copy each epoch; the live parameter stays unnormalized for training.
        ckpt_emb = solver.embs.detach().clone()
        ckpt_emb /= torch.norm(ckpt_emb, dim=1, p=2)[:, None]
        torch.save(ckpt_emb, ckpt_path)

    # Normalize the trained dictionary in place and write the final checkpoint as a plain tensor.
    with torch.no_grad():
        solver.embs.div_(torch.norm(solver.embs, dim=1, p=2)[:, None])
    torch.save(solver.embs.detach(), ckpt_path)


Processing batches:  78%|███████▊  | 619/798 [02:50<00:46,  3.87it/s, L_inf=tensor(4.5097, device='cuda:0'), loss=2.39, loss_est=tensor(5.2960, device='cuda:0'), rec_loss=0.0907, sparsity=tensor(5.0949, device='cuda:0')] 

In [10]:
# Dump the module tree to locate the hook target encoder.layer[i].attention.self.
print(repr(model))

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 128, padding_idx=0)
    (position_embeddings): Embedding(512, 128)
    (token_type_embeddings): Embedding(2, 128)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=128, out_features=128, bias=True)
            (key): Linear(in_features=128, out_features=128, bias=True)
            (value): Linear(in_features=128, out_features=128, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=128, out_features=128, bias=True)
            (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

# Disentangle features

In [11]:

layer_id = 0
pos_id = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# NOTE(review): `tokens` is not defined in any visible cell; presumably it comes from
# `from features import *` or a deleted cell. Confirm before relying on Restart & Run All.
# Use `device` (not a hard-coded 'cuda') so the cell also runs on CPU-only hosts.
activation = get_acti(model, torch.tensor(tokens[:1000]).to(device), layer_id)

torch.Size([1000, 160])


torch.Size([1000, 160, 128])


In [None]:

# num_sentences = 10000

# %%
# %%



# %%

# %%
import os

dir = "/data/scratch/dengm/distent/Disentangle-features/bert/bert-classification"
# NOTE(review): CUDA_VISIBLE_DEVICES is read when CUDA is first initialized. Earlier cells
# already move `model` to a CUDA device when one is available, so in a running kernel this
# assignment probably has no effect. Set it before torch initializes CUDA instead.
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 4, 6"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from features import *
acti = None


def save_acti():
    """Compute activations for `tokens` and cache them in <dir>/bert_layer{layer_id}_pos{pos_id}.npy."""
    global acti
    # NOTE(review): this passes pos_id as a 4th argument, which the 3-argument get_acti
    # defined earlier in this notebook does not accept. Presumably it resolves to a get_acti
    # from `features` (re-imported above via the star import); confirm.
    acti = get_acti(model, torch.tensor(tokens), layer_id, pos_id)
    # Explicit ".npy" matches load_acti; np.save would append it anyway.
    np.save(os.path.join(dir, f"bert_layer{layer_id}_pos{pos_id}.npy"), acti.cpu().numpy())


def load_acti():
    """Load the cached activations for the current layer_id/pos_id into the global `acti`."""
    global acti
    acti = torch.tensor(np.load(os.path.join(dir, f"bert_layer{layer_id}_pos{pos_id}.npy")))
    print(acti.shape)


load_acti()
# %%
feat, emb, info = None, None, None


def save_feat():
    """Fit a sparse dictionary to the first 10k activations and store it under bert/.

    Calls ``GD_solver`` (from ``features``) and tags ``info`` with the current
    ``layer_id``/``pos_id``. Writes ``bert/<name>_feat.npy``, ``bert/<name>_emb.npy``
    and an entry in ``bert/info.json``, where name is ``Layer{layer_id}_{pos_id}``.
    """
    global feat, emb, info
    feat, emb, info = GD_solver(acti[:10000].to(device), 16, init_lamb = 0.45)
    from features import metric_total_acti
    print(info, metric_total_acti(feat))
    info["layer"] = layer_id
    info["pos"] = pos_id

    def save(feat, emb, info, name):
        import json
        import os
        info_dict = dict()
        # Merge into the existing index instead of overwriting other runs' entries.
        if os.path.exists("bert/info.json"):
            with open("bert/info.json", "r") as f:
                info_dict = json.load(f)
        # np.save and open(..., "w") do not create missing parent directories.
        os.makedirs("bert", exist_ok=True)
        np.save(f"bert/{name}_feat.npy", feat.cpu().detach().numpy())
        np.save(f"bert/{name}_emb.npy", emb.cpu().detach().numpy())
        # Stored as str(info), so load_feat gets this entry back as a string, not a dict.
        info_dict[name] = str(info)
        print("!!!", info)
        with open("bert/info.json", "w") as f:
            json.dump(info_dict, f)

    save(feat, emb, info, f"Layer{layer_id}_{pos_id}")


def load_feat():
    """Load the feat/emb arrays and the info string saved by save_feat for the current layer/pos."""
    global feat, emb, info
    import json
    name = f"Layer{layer_id}_{pos_id}"
    feat = torch.tensor(np.load(f"bert/{name}_feat.npy"))
    emb = torch.tensor(np.load(f"bert/{name}_emb.npy"))
    with open("bert/info.json", "r") as f:
        info = json.load(f)[name]
    print(info)


load_feat()

# %%
# save sentences to dir/"sampled_sentences.txt"
def save_sentences():
    """Write the global `sentences` to <dir>/sampled_sentences.txt, one per line, as UTF-8."""
    import os
    with open(os.path.join(dir, "sampled_sentences.txt"), "w", encoding="utf-8") as f:
        for sentence in sentences:
            f.write(sentence + "\n")


def load_sentences():
    """Replace the global `sentences` with the stripped lines of <dir>/sampled_sentences.txt."""
    global sentences
    import os
    with open(os.path.join(dir, "sampled_sentences.txt"), "r", encoding="utf-8") as f:
        sentences = [line.strip() for line in f]


load_sentences()

# %%
def get_feat_top(feat, id, num, sequences=None):
    """Return the `num` rows with the largest values in column `id` of `feat`.

    Args:
        feat: 2-D tensor whose rows are aligned with ``sequences``.
        id: column to rank by (int or 0-d tensor).
        num: number of rows to return; clamped to the number of rows available.
        sequences: row-aligned sequence list to pick results from. Defaults to the
            notebook global ``seqs``, which is not defined in any visible cell.

    Returns:
        (top_sequences, top_values): two lists ordered by descending ``feat[:, id]``.
    """
    if sequences is None:
        sequences = seqs
    # Slicing clamps to the row count instead of raising IndexError when num is too large.
    order = feat[:, id].argsort(descending=True)[:num]
    top_sequences = [sequences[row] for row in order]
    top_values = [feat[row, id].item() for row in order]
    return top_sequences, top_values

# Total activation of each dictionary feature over all rows, and features ranked by it (largest first).
sum_feat = feat.sum(dim=0)
best_dims = sum_feat.argsort(descending=True)
for dim in best_dims:
    print(sum_feat[dim])

# %%
def print_seq(seq):
    """Print a decoded token list with the token at global `pos_id` wrapped in *** markers.

    If the sequence has no token at `pos_id`, a single "*" is appended instead.
    Decoding uses the global `tokenizer`.
    """
    if pos_id >= len(seq):
        marked = seq + ["*"]
    else:
        marked = seq[:pos_id] + ["*", "*", seq[pos_id], "*", "*"] + seq[pos_id + 1:]
    text = tokenizer.decode(tokenizer.convert_tokens_to_ids(marked))
    # The decoder separates the markers with spaces; collapse each "* *" into "***".
    print(text.replace("* *", "***"))
# Top-10 activating sentences for every 5th feature among the strongest (at most 2000) features.
# Clamp to the number of features so a smaller dictionary does not raise IndexError.
n_show = min(2000, len(best_dims))
for i in range(0, n_show, 5):
    print(f"Maximum activated sentences for dim {best_dims[i].item()}")
    for seq in get_feat_top(feat, best_dims[i], 10)[0]:
        print_seq(seq)
    print()
# %%
# Top-10 sentences along each raw activation dimension (columns of `acti`), for comparison.
acti_sample = acti[:10000]  # loop-invariant: slice once
for i in range(min(100, acti_sample.shape[1])):
    for seq in get_feat_top(acti_sample, i, 10)[0]:
        print_seq(seq)
    print()
# %%

import random
def illustrate_example(feat, word_id):
    """Print row `word_id`, then its strongest features and each feature's top sentences.

    Features are visited in descending order of activation on this row (at most 20).
    Iteration stops at the first feature whose activation is below 1/20 of the row's
    total. For each feature it prints (feature id, activation on this row, total
    activation over all rows), followed by that feature's 10 top-activating sentences.
    """
    print_seq(seqs[word_id])
    row = feat[word_id]
    total_act = row.sum()  # loop-invariant: compute once
    ranked = row.argsort(descending=True)
    for dim in ranked[:20]:  # slicing tolerates dictionaries with fewer than 20 features
        if row[dim] * 20 < total_act:
            break
        print(dim.item(), row[dim].item(), feat[:, dim].sum().item())
        for seq in get_feat_top(feat, dim, 10)[0]:
            print_seq(seq)


# randrange excludes its upper bound; randint(0, n) could return n and index past the last row.
illustrate_example(feat, random.randrange(feat.shape[0]))
# %%
