## Init

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from copy import deepcopy
from transformers import (AutoModelForMaskedLM, AutoModelForCausalLM, AutoTokenizer, AutoModelForTokenClassification,
                          AutoModelForSequenceClassification, TrainingArguments, Trainer)
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
from tensorflow.keras.models import load_model
from datasets import load_dataset, load_metric
import os
from utils import top_tokens
from tabulate import tabulate

In [2]:
# def get_output_emb(model):
#     transform = model.cls.predictions.transform.dense.weight.T
#     orig_emb = model.get_output_embeddings().weight.T
#     return (transform @ orig_emb).detach().cpu()

In [3]:
# GPT-2 BPE tokenizer. The stitching kernel further down multiplies the two models'
# embedding matrices over the vocabulary axis, so both models must share this vocabulary.
# (Earlier runs used 'bert-base-uncased' or a multiberts tokenizer instead.)
tokenizer = AutoTokenizer.from_pretrained(
    'gpt2'
)

In [4]:
class Gpt2AvgClassifier(nn.Module):
    """GPT-2 classifier that mean-pools per-token logits over the sequence.

    Wraps a token-classification GPT-2 (``AutoModelForTokenClassification``), replaces
    its final LayerNorm ``ln_f`` with an identity, and in ``forward`` averages the
    token logits along the sequence axis to obtain one logit vector per example.

    Args:
        name: checkpoint name or path passed to ``from_pretrained``.
        freeze: if not None, every parameter is frozen except the *weights* (not the
            biases) of transformer blocks whose index is ``>= freeze`` and
            ``classifier.weight``; the names of the trainable parameters are printed.
            If None, all parameters stay trainable.
        num_labels: number of output classes.
    """
    def __init__(self, name, freeze=None, num_labels=2):
        super().__init__()
        self.model = AutoModelForTokenClassification.from_pretrained(name, num_labels=num_labels)
        # Drop the final LayerNorm. nn.Identity ignores constructor arguments, so none are passed.
        self.model.transformer.ln_f = nn.Identity()
        if freeze is not None:
            for n, p in self.named_parameters():
                p.requires_grad = False
                parts = n.split('.transformer.h.')
                # Block parameters look like 'model.transformer.h.<idx>.<...>.weight'.
                if len(parts) == 2 and n.endswith('.weight'):
                    if int(parts[1].split('.')[0]) >= freeze:
                        p.requires_grad = True
                        print(n)
                if n.endswith('.classifier.weight'):
                    p.requires_grad = True
                    print(n)

    def forward(self, input_ids, labels=None, inputs_embeds=None):
        """Run the wrapped model and mean-pool its token logits.

        Args:
            input_ids: token ids; pass None when ``inputs_embeds`` is given.
            labels: optional class labels. When given, a cross-entropy ``loss`` entry
                is added to the output; when None (pure inference) no loss is computed.
            inputs_embeds: optional input embeddings used instead of ``input_ids``.

        Returns:
            The wrapped model's output object with ``logits`` replaced by the
            sequence-averaged logits (and ``loss`` set if ``labels`` was given).
        """
        res = self.model(input_ids=input_ids, inputs_embeds=inputs_embeds)
        # Average over the sequence axis (dim -2). No attention mask is applied, so every
        # position (including any padding) contributes equally.
        res.logits = res.logits.mean(dim=-2)
        if labels is not None:
            res['loss'] = F.cross_entropy(res.logits.view(-1, res.logits.shape[-1]), labels.view(-1))
        return res

### Initialize Models

In [5]:
# Fine-tuning cut-off: transformer blocks with index < freeze stay frozen (9 blocks here);
# in Gpt2AvgClassifier only the weights of blocks >= freeze and the classifier weight train.
freeze = 9

In [8]:
model_paths = ['gpt2', 'gpt2-medium']

print(model_paths)

# model1: averaged-logit GPT-2 classifier that is fine-tuned below (see `freeze`).
# model2: larger GPT-2 whose hidden states are later stitched into model1.
model1 = Gpt2AvgClassifier(model_paths[0], freeze=freeze)
model2 = AutoModelForSequenceClassification.from_pretrained(model_paths[1])

# Transposed embedding matrices (hidden_dim, vocab) on CPU, used to project vectors onto tokens.
# Input embeddings suffice because the embedding matrices are tied.
emb1, emb2 = (
    hf_model.get_input_embeddings().weight.T.cpu().detach()
    for hf_model in (model1.model, model2)
)
(num_layers1, hidden_dim1), (num_layers2, hidden_dim2) = [
    (cfg.n_layer, cfg.n_embd) for cfg in (model1.model.config, model2.config)
]

['gpt2', 'gpt2-medium']


Some weights of GPT2ForTokenClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.7.attn.masked_bias', 'h.9.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'classifier.bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'classifier.weight', 'h.11.attn.masked_bias', 'h.0.attn.masked_bias', 'h.2.attn.masked_bias', 'h.8.attn.masked_bias', 'h.10.attn.masked_bias', 'h.1.attn.masked_bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.transformer.h.9.ln_1.weight
model.transformer.h.9.attn.c_attn.weight
model.transformer.h.9.attn.c_proj.weight
model.transformer.h.9.ln_2.weight
model.transformer.h.9.mlp.c_fc.weight
model.transformer.h.9.mlp.c_proj.weight
model.transformer.h.10.ln_1.weight
model.transformer.h.10.attn.c_attn.weight
model.transformer.h.10.attn.c_proj.weight
model.transformer.h.10.ln_2.weight
model.transformer.h.10.mlp.c_fc.weight
model.transformer.h.10.mlp.c_proj.weight
model.transformer.h.11.ln_1.weight
model.transformer.h.11.attn.c_attn.weight
model.transformer.h.11.attn.c_proj.weight
model.transformer.h.11.ln_2.weight
model.transformer.h.11.mlp.c_fc.weight
model.transformer.h.11.mlp.c_proj.weight
model.classifier.weight


Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2-medium and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
# model_paths = ['gpt2', 'gpt2-medium'] # ["bert-base-uncased", f"multiberts/models/seed_0"] # [f"multiberts/models/seed_{i}" for i in range(2)]
# print(model_paths)

# model1_tmp = AutoModelForCausalLM.from_pretrained(model_paths[0]) # need to get the output emb matrix from model1
# model2 = AutoModelForCausalLM.from_pretrained(model_paths[1])
# emb1, emb2 = map(lambda model: model.get_output_embeddings().weight.T.cpu().detach(), [model1_tmp, model2])
# del model1_tmp # now we no longer need model1_tmp
# model1 = AutoModelForSequenceClassification.from_pretrained(model_paths[0])

In [10]:
## remove pooler for simplicity - leave only classifier 
# model1.bert.pooler.dense = nn.Identity()
# model1.bert.pooler.activation = nn.Identity()

## Sentiment Analysis Finetuning

In [11]:
# Generic handle for the model fine-tuned and inspected in the cells below.
model = (model1)

In [12]:
# Optional manual (un)freezing for models that do not freeze themselves (e.g. a plain
# GPT2ForSequenceClassification). Gpt2AvgClassifier already freezes in its constructor,
# so this is disabled by default.
MANUAL_UNFREEZE = False
if MANUAL_UNFREEZE:
    print("unfrozen parameters:")
    learn_bias = True
    for n, p in model.named_parameters():
        p.requires_grad = False
        # With learn_bias=False only '.weight' parameters are unfrozen.
        wanted_kind = learn_bias or n.endswith('.weight')
        parts = n.split('transformer.h.')
        if len(parts) == 2 and wanted_kind:
            if int(parts[1].split('.')[0]) >= freeze:
                p.requires_grad = True
                print(n)
        # The head is `score` in GPT2ForSequenceClassification and `classifier` in
        # Gpt2AvgClassifier; the old check only matched `score` and then required a
        # '.classifier.weight' suffix, so no head weight matched when learn_bias=False.
        is_head = any(name in n.split('.') for name in ('score', 'classifier'))
        if is_head and wanted_kind:
            p.requires_grad = True
            print(n)

### Preparing Data

In [13]:
def tokenize_imdb(examples):
    """Tokenize the ``"text"`` field of ``examples`` with the global ``tokenizer``.

    Truncation is enabled; padding is not requested.
    """
    review_text = examples["text"]
    encoded = tokenizer(review_text, truncation=True)
    return encoded

In [14]:
# Fixed random subsets of IMDB (seed 42): 1000 training reviews, 500 test reviews for validation.
# The shuffle permutation depends only on the seed and the split length, so selecting
# before tokenizing yields the same examples as tokenizing first. This avoids tokenizing
# the full train/test/unsupervised splits. Batched mapping lets the tokenizer process
# lists of texts at once.
imdb = load_dataset('imdb')
imdb_train = imdb['train'].shuffle(seed=42).select(range(1000)).map(tokenize_imdb, batched=True)
imdb_val = imdb['test'].shuffle(seed=42).select(range(500)).map(tokenize_imdb, batched=True)

Reusing dataset imdb (/home/guydar/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/guydar/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-3524c89eaed1ab3e.arrow
Loading cached processed dataset at /home/guydar/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-bd746f6e0438ac54.arrow
Loading cached processed dataset at /home/guydar/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-fcdd099119f0e220.arrow
Loading cached shuffled indices for dataset at /home/guydar/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-8a65a18bdcacec47.arrow
Loading cached shuffled indices for dataset at /home/guydar/.cache/huggingface/datasets/imdb/plain_text/1.0.0/e3c66f1788a67a89c7058d97ff62b6c30531e05b549de56d3ab91891f0561f9a/cache-1e8bb50c418695d3.arrow


### Training

In [15]:
metric = load_metric('accuracy')

def compute_metrics(eval_pred):
    """Accuracy of argmax predictions, for ``Trainer(compute_metrics=...)``.

    Args:
        eval_pred: ``(logits, labels)`` pair supplied by the Trainer. ``logits`` may be a
            tuple of arrays when the model returns extra outputs; its first element is
            taken as the class scores.

    Returns:
        The dict produced by ``metric.compute`` (an ``accuracy`` entry).
    """
    logits, labels = eval_pred
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

In [16]:
# Keep Weights & Biases from being used as a Trainer logging integration.
os.environ.update(WANDB_DISABLED="true")

In [17]:
# Trainer configuration.
# - report_to='none' disables all logging integrations. In transformers 4.x, report_to=None
#   means "all installed integrations", which is why WANDB_DISABLED was needed.
# - save_strategy='no' skips checkpointing explicitly. The old save_steps=False relied on
#   False == 0 to suppress saves.
# - Batch size 1, presumably because Gpt2AvgClassifier averages over all positions without
#   an attention mask, so padding would bias the mean.
train_args = TrainingArguments(learning_rate=1e-5, report_to='none', output_dir='trainer_output',
                               per_device_eval_batch_size=1, per_device_train_batch_size=1,
                               save_strategy='no', evaluation_strategy='epoch', num_train_epochs=3)

Using the `WAND_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [18]:
# Override the private GPU count so the Trainer treats the run as single-GPU
# (the training log reports a total train batch size of 1).
train_args._n_gpu = (1)

In [19]:
# Independent snapshot of the pre-fine-tuning weights; diffs are taken against it later.
old_model = deepcopy(
    model
)

In [20]:
# Fine-tune `model` on the IMDB subset, evaluating accuracy on the validation subset each epoch.
trainer = Trainer(
    model,
    args=train_args,
    train_dataset=imdb_train,
    eval_dataset=imdb_val,
    compute_metrics=compute_metrics,
)
trainer.train()

The following columns in the training set don't have a corresponding argument in `Gpt2AvgClassifier.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `Gpt2AvgClassifier.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 1000
  Num Epochs = 3
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 3000


Epoch,Training Loss,Validation Loss,Accuracy
1,0.6236,1.006617,0.764
2,0.6465,0.787116,0.844
3,0.5931,0.736443,0.858


The following columns in the evaluation set don't have a corresponding argument in `Gpt2AvgClassifier.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `Gpt2AvgClassifier.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 1
The following columns in the evaluation set don't have a corresponding argument in `Gpt2AvgClassifier.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `Gpt2AvgClassifier.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 1
The following columns in the evaluation set don't have a corresponding argument in `Gpt2AvgClassifier.forward` and have been ignored: attention_mask, text. If attention_mask, text are not expected by `Gpt2AvgClassifier.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 500
  Batch size

TrainOutput(global_step=3000, training_loss=0.875273193359375, metrics={'train_runtime': 109.7321, 'train_samples_per_second': 27.339, 'train_steps_per_second': 27.339, 'total_flos': 0.0, 'train_loss': 0.875273193359375, 'epoch': 3.0})

### Visualize Finetuning Vectors

In [21]:
# Raw fine-tuned classifier weights, kept under their own name. The following cell
# computes the fine-tuning difference into `diff_classifier`; writing raw weights into
# that name here was a dead store that mislabeled them as a diff.
classifier_weight = model.model.classifier.weight.cpu().detach()

In [22]:
# Fine-tuning update of the classifier head (fine-tuned minus pre-trained weights).
# Both operands are moved to CPU so the cell still works after either model has been
# moved to another device (e.g. `old_model.cuda()` in the stitching section).
diff_classifier = (model.model.classifier.weight.cpu() - old_model.model.classifier.weight.cpu()).detach()
# For plain HF heads use `model.score.weight` (GPT-2 seq-cls) or `model.classifier.weight` instead.

In [23]:
# One row per label of the classifier diff: row 0 -> negative, row 1 -> positive.
neg_vector, pos_vector = diff_classifier[0, :], diff_classifier[1, :]

In [37]:
# Project each label's diff vector onto the vocabulary and list the 30 highest-scoring tokens.
pos_top = top_tokens(pos_vector @ emb1, k=30, only_ascii=True, tokenizer=tokenizer)
neg_top = top_tokens(neg_vector @ emb1, k=30, only_ascii=True, tokenizer=tokenizer)
print(tabulate(list(zip(pos_top, neg_top)), headers=['POSITIVE', 'NEGATIVE']))

POSITIVE     NEGATIVE
-----------  ------------
#yssey       bullshit
#knit        lame
#etts        crap
passions     incompetent
#etooth      inco
#iscover     bland
pioneers     incompetence
#emaker      idiots
Pione        crappy
#raft        shitty
#uala        idiot
prosper      pointless
#izons       retarded
#encers      worse
#joy         garbage
cherish      CGI
loves        FUCK
#accompan    Nope
strengthens  useless
#nect        shit
comr         mediocre
honoured     poorly
insepar      stupid
embraces     inept
battled      lousy
#Together    fuck
intrig       sloppy
#jong        Worse
friendships  Worst
#anta        meaningless


In [145]:
# Fine-tuning differences (fine-tuned minus pre-trained) of one transformer block.
# `i1` is defined here so the cell does not depend on a variable set only in a later cell.
i1 = 11  # layer whose diffs we visualize

def _weight_diff(new_mod, old_mod):
    """Detached CPU difference of two modules' ``weight`` tensors."""
    return (new_mod.weight.cpu() - old_mod.weight.cpu()).detach()

layer_new = model.model.transformer.h[i1]
layer_old = old_model.model.transformer.h[i1]
# MLP "keys" (rows of c_fc.weight.T) and "values" (rows of c_proj.weight).
diff_K = _weight_diff(layer_new.mlp.c_fc, layer_old.mlp.c_fc).T
diff_V = _weight_diff(layer_new.mlp.c_proj, layer_old.mlp.c_proj)
# c_attn packs Q, K and V; after transposing, split into three equal row blocks.
diff_WQ, diff_WK, diff_WV = _weight_diff(layer_new.attn.c_attn, layer_old.attn.c_attn).T.chunk(3)
diff_WO = _weight_diff(layer_new.attn.c_proj, layer_old.attn.c_proj)

In [248]:
# Parameter diff to inspect below (alternatives: diff_K, diff_V, diff_WQ, diff_WK, diff_WV).
diff_param = (diff_WO)

In [261]:
# Layer to visualize. The diff_* tensors are computed from `i1`, so recompute them after changing it.
i1 = int(11)

In [268]:
# Index of a random row (vector) of the selected parameter diff.
# Leave I2_SEED as None for a fresh pick on every run, or set an int for a reproducible one.
I2_SEED = None
i2 = int(np.random.default_rng(I2_SEED).integers(diff_param.shape[0]))

In [269]:
# Top-20 tokens for the sampled diff row projected onto the vocabulary, and for its negation.
diff_row = diff_param[i2].detach()
print(tabulate(
    zip(top_tokens(diff_row @ emb1, k=20, only_ascii=True, tokenizer=tokenizer),
        top_tokens(-diff_row @ emb1, k=20, only_ascii=True, tokenizer=tokenizer)),
    headers=["diff", "-diff"],
))

4
diff            -diff
--------------  ------------------
$               redes
#ressed         advoc
#t              mathemat
expense         horizont
#id             #accompan
#ah             #isSpecialOrder...
P               resil
#ank            conduc
ID              therap
frame           trave
#eff            challeng
#in             tremend
#ham            Parables
administration  enthusi
security        #ModLoader
C               conclud
#m              millenn
#back           nostalg
#um             perspect
back            destro


## Model Stitching

In [326]:
def subtract_modules(mod1, mod2, subtract_ln=False):
    """Return a deep copy of ``mod1`` whose parameters hold ``mod1 - mod2``.

    Parameters are matched by name via ``mod2.get_parameter``, so both modules must
    have the same parameter names and shapes. Parameters owned by an ``nn.LayerNorm``
    submodule are copied from ``mod1`` unchanged unless ``subtract_ln`` is True.
    Buffers are not touched, and neither input module is modified.

    Args:
        mod1: module to subtract from (e.g. a fine-tuned block).
        mod2: module to subtract (e.g. the corresponding pre-trained block).
        subtract_ln: also subtract LayerNorm parameters.

    Returns:
        The new module containing the differences.
    """
    mod_new = deepcopy(mod1)
    with torch.no_grad():
        for n, p in mod_new.named_parameters():
            owner_name = n.rsplit('.', 1)[0] if '.' in n else ''
            is_ln = isinstance(mod_new.get_submodule(owner_name), nn.LayerNorm)
            if subtract_ln or not is_ln:
                # Bring the subtrahend onto this parameter's device and dtype so the
                # two modules may live on different devices (e.g. GPU vs CPU).
                p.sub_(mod2.get_parameter(n).to(device=p.device, dtype=p.dtype))
    return mod_new

In [327]:
class StitchedTransformers(nn.Module):
    """Stitch a truncated ``model2`` onto the fine-tuning *diffs* of ``model1``'s top blocks.

    A copy of ``model2`` keeps only its first ``num_keep_layers`` transformer blocks. Its
    last hidden state is multiplied on the right by ``kernel`` and fed as ``inputs_embeds``
    into a copy of ``model1``. In that copy, the transformer blocks are replaced by
    ``subtract_modules(fine-tuned block, pre-trained block)`` for the last
    ``num_transplanted_layers`` blocks, and the classifier is replaced by the classifier
    difference. The input models are not modified.

    Args:
        old_model: pre-trained copy of ``model1`` (same structure, before fine-tuning).
        model1: fine-tuned model exposing ``.model.transformer.h`` and ``.model.classifier``
            and accepting ``input_ids``/``inputs_embeds``/``labels`` (e.g. Gpt2AvgClassifier).
        model2: Hugging Face GPT-2 model exposing ``.transformer``.
        kernel: matrix applied as ``hidden @ kernel``; presumably (hidden_dim2, hidden_dim1).
        num_keep_layers: number of leading ``model2`` blocks to keep.
        num_transplanted_layers: number of trailing ``model1`` blocks to transplant as diffs.

    Raises:
        ValueError: if ``num_transplanted_layers`` is outside ``[0, number of model1 blocks]``.
    """
    def __init__(self, old_model, model1, model2, kernel, num_keep_layers, num_transplanted_layers):
        super().__init__()
        num_blocks1 = len(model1.model.transformer.h)
        if not 0 <= num_transplanted_layers <= num_blocks1:
            # A negative offset would silently wrap around to the wrong blocks.
            raise ValueError(f"num_transplanted_layers must be in [0, {num_blocks1}], "
                             f"got {num_transplanted_layers}")
        self.model2 = deepcopy(model2)
        self.model2.transformer.h = nn.ModuleList(self.model2.transformer.h[:num_keep_layers])
        self.register_buffer("stitching_kernel", kernel)
        self.model1 = deepcopy(model1)
        offset = num_blocks1 - num_transplanted_layers
        self.model1.model.transformer.h = nn.ModuleList([
            subtract_modules(model1.model.transformer.h[offset + i], old_model.model.transformer.h[offset + i])
            for i in range(num_transplanted_layers)])
        self.model1.model.classifier = subtract_modules(model1.model.classifier, old_model.model.classifier)

    def forward(self, input_ids, labels):
        """Return ``{'loss', 'logits'}`` of the stitched model for ``input_ids``/``labels``."""
        # Run only model2's transformer body. Its sequence-classification head is not used
        # here, so skipping it saves compute. (Presumably it also avoids the head's
        # pad-token requirement for batch sizes > 1.)
        x = self.model2.transformer(input_ids, output_hidden_states=True).hidden_states[-1]
        x = x @ self.stitching_kernel
        # NOTE(review): model1's GPT-2 presumably adds its own position embeddings to
        # inputs_embeds, so positional information is injected twice — confirm intended.
        res = self.model1(input_ids=None, inputs_embeds=x, labels=labels)
        return {'loss': res['loss'], 'logits': res['logits']}

In [328]:
# Linear map from model2's hidden space to model1's, fitted through the shared vocabulary:
# kernel = emb2 @ pinv(emb1), shape (hidden_dim2, hidden_dim1).
# (A variant adding `.1 * torch.eye(1024, 768)` was also tried.)
kernel = torch.matmul(emb2, torch.pinverse(emb1))

In [329]:
num_transplanted_layers, num_keep_layers = 3, 11
# num_transplanted_layers: trailing fine-tuned blocks of model1 transplanted (as diffs)
# num_keep_layers: leading blocks of model2 kept before the stitching kernel

### Evaluate

In [330]:
# Build the stitched model without a hard-coded `.cuda()`. A temporary copy of `old_model`
# is moved to model1's device, so `old_model` itself stays where it is and earlier cells
# that compare it with `model` keep working on re-run; this also works without a GPU.
stitch_device = next(model1.parameters()).device
stitched_model = StitchedTransformers(deepcopy(old_model).to(stitch_device), model1, model2, kernel,
                                      num_keep_layers, num_transplanted_layers).cpu()

In [331]:
# Evaluate the stitched model (no further training) on the same validation subset.
trainer_stitched = Trainer(
    stitched_model,
    args=train_args,
    train_dataset=imdb_train,
    eval_dataset=imdb_val,
    compute_metrics=compute_metrics,
)
trainer_stitched.evaluate()

The following columns in the evaluation set don't have a corresponding argument in `StitchedTransformers.forward` and have been ignored: text, attention_mask. If text, attention_mask are not expected by `StitchedTransformers.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 500
  Batch size = 1


{'eval_loss': 0.6885536909103394,
 'eval_accuracy': 0.492,
 'eval_runtime': 11.5861,
 'eval_samples_per_second': 43.155,
 'eval_steps_per_second': 43.155}

In [332]:
import gc

# Run a full garbage-collection pass; the cell displays the number of unreachable objects found.
gc.collect()

11