Plan:
- Load GPT2 pretrained
- Implement LoRA
- Replace weird Conv1D layers in attention blocks with LinearLoRA Layers
- Freeze everything else
- Retrain on a specific subject


In [1]:
import torch
import copy
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
base_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [4]:
tokenizer.encode("hi this is massoud")

[5303, 428, 318, 2347, 2778]

In [5]:
tokenizer.decode([5303, 428, 318, 2347, 2778])

'hi this is massoud'

In [6]:
tokenizer.all_special_tokens

['<|endoftext|>']

In [7]:
tokenizer.encode(tokenizer.all_special_tokens[0])

[50256]

In [8]:
base_model_generator = pipeline("text-generation", model=base_model, tokenizer=tokenizer)

base_model_generator("The president of Mars is", max_new_tokens=3)

Device set to use cuda:0
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': 'The president of Mars is already using his'}]

In [9]:
# Adding LoRA based on https://arxiv.org/pdf/2106.09685 to the attention layer
class LinearLoRA(torch.nn.Module):
  def __init__(self, in_features, out_features, r=2, alpha=0.1):
    super().__init__()
    self.in_features = in_features
    self.out_features = out_features

    self.weight = torch.nn.Parameter(torch.randn((in_features, out_features)))
    self.bias = torch.nn.Parameter(torch.zeros(out_features))

    self.lora_B = torch.nn.Parameter(torch.zeros((in_features, r)))
    self.lora_A = torch.nn.Parameter(torch.zeros((r, out_features)))

    self.alpha = alpha
    self.r = r


  def forward(self, X):
    delta_W = (self.lora_B @ self.lora_A) * self.alpha/self.r
    projections = X @ (self.weight + delta_W) + self.bias
    return projections


  def __repr__(self):
    return f"LinearLoRA(in_features={self.in_features}, out_features={self.out_features})"


In [10]:
x = torch.rand((1, 3, 768))
ll = LinearLoRA(768, 768*3, 2)

ll(x).shape

torch.Size([1, 3, 2304])

In [11]:
for name, module in list(base_model.named_modules()):
  print(name, module)
  break

 GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


In [12]:
print(f"Number of parameters before swapping out attention layers with LinearLoRA {sum(len(params) for params in base_model.parameters()):,}")

print(f"Number of __trainable__ parameters before swapping out attention layers with LinearLoRA {sum(len(params) for params in base_model.parameters() if params.requires_grad):,}")

Number of parameters before swapping out attention layers with LinearLoRA 237,137
Number of __trainable__ parameters before swapping out attention layers with LinearLoRA 237,137


In [13]:
import re

model = copy.deepcopy(base_model)

for name, module in model.named_modules():
  # print(name, module)
  if re.match("^transformer\.h\.\d*\.attn$", name):
    print("Replacing attention projections of layer", name)
    lora_layer = LinearLoRA(
        in_features=module.c_attn.nx,
        out_features=module.c_attn.nf
    )
    lora_layer.weight.data = module.c_attn.weight.data.clone()
    lora_layer.bias.data = module.c_attn.bias.data.clone()

    setattr(module, "c_attn", lora_layer)

Replacing attention projections of layer transformer.h.0.attn
Replacing attention projections of layer transformer.h.1.attn
Replacing attention projections of layer transformer.h.2.attn
Replacing attention projections of layer transformer.h.3.attn
Replacing attention projections of layer transformer.h.4.attn
Replacing attention projections of layer transformer.h.5.attn
Replacing attention projections of layer transformer.h.6.attn


  if re.match("^transformer\.h\.\d*\.attn$", name):


Replacing attention projections of layer transformer.h.7.attn
Replacing attention projections of layer transformer.h.8.attn
Replacing attention projections of layer transformer.h.9.attn
Replacing attention projections of layer transformer.h.10.attn
Replacing attention projections of layer transformer.h.11.attn


In [14]:
for name, module in list(model.named_modules()):
  print(name, module)
  break

 GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): LinearLoRA(in_features=768, out_features=2304)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


In [15]:
print(f"Number of parameters after swapping out attention layers with LinearLoRA {sum(len(params) for params in model.parameters()):,}")

print(f"Number of __trainable__ parameters after swapping out attention layers with LinearLoRA {sum(len(params) for params in model.parameters() if params.requires_grad):,}")

Number of parameters after swapping out attention layers with LinearLoRA 246,377
Number of __trainable__ parameters after swapping out attention layers with LinearLoRA 246,377


In [16]:
# assert weights of the model are the same as base_model after replacement

def get_module_by_name(layer_name=""):
  for name, module in base_model.named_modules():
    if name==layer_name:
      return module


for name, module in model.named_modules():
  if re.match("^transformer\.h\.\d*\.attn.c_attn$", name):
    print(name)
    module_from_base_model = get_module_by_name(name)
    assert torch.allclose(module.weight, module_from_base_model.weight)
    assert torch.allclose(module.bias, module_from_base_model.bias)

transformer.h.0.attn.c_attn
transformer.h.1.attn.c_attn
transformer.h.2.attn.c_attn
transformer.h.3.attn.c_attn
transformer.h.4.attn.c_attn
transformer.h.5.attn.c_attn
transformer.h.6.attn.c_attn
transformer.h.7.attn.c_attn
transformer.h.8.attn.c_attn
transformer.h.9.attn.c_attn
transformer.h.10.attn.c_attn
transformer.h.11.attn.c_attn


  if re.match("^transformer\.h\.\d*\.attn.c_attn$", name):


In [17]:
# freeze all layers but LinearLoRAs for training
for name, module in model.named_modules():
  if not isinstance(module, LinearLoRA):
    for param in module.parameters():
      param.requires_grad = False
  else:
    for param in module.parameters():
      param.requires_grad = True

In [18]:
print(f"Number of parameters after swapping out attention layers with LinearLoRA {sum(len(params) for params in model.parameters()):,}")

print(f"Number of __trainable__ parameters after swapping out attention layers with LinearLoRA {sum(len(params) for params in model.parameters() if params.requires_grad):,}")

Number of parameters after swapping out attention layers with LinearLoRA 246,377
Number of __trainable__ parameters after swapping out attention layers with LinearLoRA 46,104


In [19]:
# train the model to learn that the president of Mars is Massoud
tokenizer.encode("The president of Mars is"), tokenizer.encode("Massoud")

([464, 1893, 286, 8706, 318], [20273, 2778])

In [20]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss()

model = model.to(device)
model.train()
for _ in tqdm(range(500)):
  logits = model(input_ids=torch.tensor([[464, 1893, 286, 8706, 318]]).to(device)).logits
  loss = loss_fn(logits[:, -1, :], torch.tensor([20273]).to(device))
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  logits = model(input_ids=torch.tensor([[464, 1893, 286, 8706, 318, 20273]]).to(device)).logits
  loss = loss_fn(logits[:, -1, :], torch.tensor([2778]).to(device))
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

100%|██████████| 500/500 [00:42<00:00, 11.71it/s]


In [26]:
new_model_generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

new_model_generator("The president of Mars is ", max_new_tokens=2)

Device set to use cuda:0
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': 'The president of Mars is Massoud'}]

In [40]:
new_model_generator("All whales", max_new_tokens=20)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': 'All whales (including their own mothers) may go through many years of life in the wild.\n\nThe'}]

In [39]:
base_model_generator("All whales", max_new_tokens=20)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': 'All whales are born with two legs, a tail, and a mouth. The most commonly seen type of whale'}]

In [41]:
from google.colab import runtime

# Disconnects and deletes the current runtime
runtime.unassign()