<a href="https://colab.research.google.com/github/mobarakol/tutorial_notebooks/blob/main/vector_mora.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install peft



In [6]:
ly = nn.Conv1d(128, 128, kernel_size=1, bias=False)

In [7]:
import torch
from torch import nn

from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import ViTModel, BlipConfig, BlipTextModel

# from peft import LoraConfig, get_peft_model
import math

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Vector_MoRA(nn.Module):
    def __init__(self, w_qkv, mora_rank, lora_dropout):
        super().__init__()
        self.base_layer = w_qkv  # original c_atten layer
        self.r = mora_rank  # one of mora rank elements in the list
        self.in_features = w_qkv.weight.shape[0]  # 768
        self.out_features = w_qkv.weight.shape[1]  # 2304

        # LoRA dropout
        self.lora_dropout = nn.ModuleDict({
            'default': nn.Dropout(p=lora_dropout)
        })

        # LoRA A and B matrices
        # self.lora_A = nn.ModuleDict({
        #     'default': nn.Linear(self.r, self.r, bias=False)
        # })

        self.lora_A = nn.ModuleDict({
            'default': nn.Conv1d(self.r, self.r, bias=False, kernel_size=1)
        })

        nn.init.zeros_(self.lora_A['default'].weight)
        self.lora_B = self.lora_A  # not for use

        # For Embedding layer
        self.lora_embedding_A = nn.ParameterDict({})
        self.lora_embedding_B = nn.ParameterDict({})

    def forward(self, x):
    # x: torch.Size([1, 32, 768])
    # in_x: torch.Size([1, 32, 5, 176])
    # in_x_2: torch.Size([1, 32, 5, 176])
    # out_x: torch.Size([1, 32, 5, 176])
    # out_x_2: torch.Size([1, 32, 2304])
    # logits: torch.Size([1, 32, 50257])

        # Original output
        result = self.base_layer(x)
        x = self.lora_dropout['default'](x)  # x is the input for mora
        print('-------- new block start --------')
        print(f'x: {x.shape}')

        '''apply compression before lora_A'''  # RoPE
        in_f, out_f = self.in_features, self.out_features
        r = self.r
        # suppose mora_type = 6
        sum_inter = in_f // r
        rb1 = in_f // r if in_f % r == 0 else in_f // r + 1

        if in_f % r != 0:
            pad_size = r - in_f % r
            x = torch.cat([x, x[..., :pad_size]], dim=-1)  # [32, 50, 780]
            sum_inter += 1
        in_x = x.view(*x.shape[:-1], sum_inter, r)  # [32, 50, 5, 156]
        print(f'in_x after reshape: {in_x.shape}')

        if not hasattr(self, 'cos') and not hasattr(self, 'sin'):
            inv_freq = 1.0 / (10000 ** (torch.arange(0, r, 2).float() / r))  # torch.Size([78])
            t = torch.arange(rb1)  # tensor([0, 1, 2, 3, 4])
            freqs = torch.outer(t, inv_freq)  # [5, 78]
            emb = torch.cat((freqs, freqs), dim=-1)  # [5, 156]
            self.cos = emb.cos().unsqueeze(0).to(x.device).to(x.dtype)  # [1, 5, 156]
            self.sin = emb.sin().unsqueeze(0).to(x.device).to(x.dtype)

        rh_in_x = torch.cat((-in_x[..., r // 2:], in_x[..., :r // 2]), dim=-1)  # [32, 50, 5, 156]
        # rh_in_x 最后一个维度的前 r//2 个元素是 in_x 后半部分的负值, rh_in_x 最后一个维度的后 r//2 个元素是 in_x 前半部分的原值
        in_x = in_x * self.cos + rh_in_x * self.sin  # [32, 50, 5, 156]
        print(f'in_x after RoPE: {in_x.shape}')

        '''apply lora_A'''
        out_x = self.lora_A['default'](in_x)  # [32, 50, 5, 156]
        print(f'out_x after lora_A: {out_x.shape}')

        '''apply decompression after lora_A'''
        # suppose mora_type = 6
        out_x = out_x.view(*x.shape[:-1], -1)[..., :out_f]  # [32, 50, 780]
        if out_x.shape[-1] < out_f:
            repeat_time = out_f // out_x.shape[-1]
            if out_f % out_x.shape[-1] != 0:
                repeat_time += 1
            out_x = torch.cat([out_x] * repeat_time, dim=-1)[..., :out_f]  # [32, 50, 2304]
        print(f'out_x after decompression: {out_x.shape}')
        print('-------- block end here --------')

        return result + out_x

class VectorMoRAInitializer:
    def __init__(self, model, base_rank=8, mora_rank_coefficients=None, lora_dropout=0.01):
        self.model = model
        self.base_rank = base_rank
        self.lora_dropout = lora_dropout
        if mora_rank_coefficients is None:
            self.mora_rank_coefficients = [32, 32, 30, 30, 28, 28, 26, 26, 24, 24, 22, 22]
        else:
            self.mora_rank_coefficients = mora_rank_coefficients

    def calculate_mora_ranks(self):
        return [self.base_rank * coeff for coeff in self.mora_rank_coefficients]

    def initialize_mora(self):
        mora_ranks = self.calculate_mora_ranks()

        for param in self.model.transformer.parameters():
            param.requires_grad = False

        for t_layer_i, blk in enumerate(self.model.transformer.h):
            w_qkv = blk.attn.c_attn
            mora_rank = mora_ranks[t_layer_i]
            print(f'-------- layer: {t_layer_i}, current mora rank: {mora_rank }--------')
            blk.attn.c_attn = Vector_MoRA(w_qkv, mora_rank, self.lora_dropout)

        print("Vector MoRA params initialized!")
        return self.model

class PitVQAGen(nn.Module):
    def __init__(self, base_rank=8, mora_rank_coefficients=None):
        super(PitVQAGen, self).__init__()

        # gpt2 decoder
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        if mora_rank_coefficients is None:
            mora_rank_coefficients = [32, 32, 30, 30, 28, 28, 26, 26, 24, 24, 22, 22]
        self.gpt = VectorMoRAInitializer(self.gpt, base_rank=base_rank,
                        mora_rank_coefficients = mora_rank_coefficients
                        ).initialize_mora()
        # print(f'after mora: {self.gpt}')

        # visual encoder
        model_name = "google/vit-base-patch16-224-in21k"
        self.visual_encoder = ViTModel.from_pretrained(model_name)

        # tokenizer
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token  # end of string

        # text encoder
        config = BlipConfig.from_pretrained("Salesforce/blip-vqa-base")
        self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)

        # modify embedding layer
        new_vocab_size = len(self.tokenizer)
        embedding_dim = self.text_encoder.embeddings.word_embeddings.embedding_dim
        self.text_encoder.embeddings.word_embeddings = nn.Embedding(new_vocab_size, embedding_dim)  # He init

    def forward(self, image, question_inputs):
        # visual encoder
        image = image.to(device)
        image_embeds = self.visual_encoder(image).last_hidden_state  # torch.Size([bs, 197, 768])
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)  # torch.Size([bs, 197])

        question_input_ids = question_inputs['input_ids']  # torch.Size([bs, 25])
        question_att_mask = question_inputs['attention_mask']

        # multimodal encoder
        text_output = self.text_encoder(input_ids=question_input_ids,
                        attention_mask=question_att_mask,
                        encoder_hidden_states=image_embeds,
                        encoder_attention_mask=image_atts,
                        return_dict=True)
        text_embeds = text_output.last_hidden_state  # torch.Size([bs, 25, 768]), args.question_len=25

        # text decoder
        gpt_output = self.gpt(inputs_embeds=text_embeds,
                    encoder_attention_mask=question_att_mask)  # torch.Size([bs, 25, 50257])
        return gpt_output.logits

In [None]:
from PIL import Image
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode
import matplotlib.pyplot as plt

mora_rank_coefficients = [32, 32, 30, 30, 28, 28, 26, 26, 24, 24, 22, 22]
model = PitVQAGen(base_rank=8, mora_rank_coefficients=mora_rank_coefficients)


!gdown 1Kg-dwCsKivNKubEPXWmOopuw91v3megC
transform = transforms.Compose([
       transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
       transforms.ToTensor(),
       ])
raw_image = Image.open('frame052.png').convert('RGB')
image = transform(raw_image)
if image.dim() == 3:
    image = image.unsqueeze(0)  # add batch size


tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token
question = 'what is this?'
question_inputs = tokenizer(question, padding="max_length", max_length=32, return_tensors="pt", truncation=True)


logits = model(image, question_inputs)
print(f'logits: {logits.shape}')

# x: torch.Size([1, 32, 768])
# in_x: torch.Size([1, 32, 5, 176])
# in_x_2: torch.Size([1, 32, 5, 176])
# out_x: torch.Size([1, 32, 5, 176])
# out_x_2: torch.Size([1, 32, 2304])
# logits: torch.Size([1, 32, 50257])

-------- layer: 0, current mora rank: 256--------
-------- layer: 1, current mora rank: 256--------
-------- layer: 2, current mora rank: 240--------
-------- layer: 3, current mora rank: 240--------
-------- layer: 4, current mora rank: 224--------
-------- layer: 5, current mora rank: 224--------
-------- layer: 6, current mora rank: 208--------
-------- layer: 7, current mora rank: 208--------
-------- layer: 8, current mora rank: 192--------
-------- layer: 9, current mora rank: 192--------
-------- layer: 10, current mora rank: 176--------
-------- layer: 11, current mora rank: 176--------
Vector MoRA params initialized!
Downloading...
From: https://drive.google.com/uc?id=1Kg-dwCsKivNKubEPXWmOopuw91v3megC
To: /content/frame052.png
100% 1.44M/1.44M [00:00<00:00, 114MB/s]
-------- new block start --------
x: torch.Size([1, 32, 768])
in_x after reshape: torch.Size([1, 32, 3, 256])
in_x after RoPE: torch.Size([1, 32, 3, 256])
out_x after lora_A: torch.Size([1, 32, 3, 256])
out_x after

In [None]:
# copy from vector_lora_2.ipynb
# this is the structure of GPT2 with peft-MoRA lib

# PeftModelForSeq2SeqLM(
#   (base_model): LoraModel(
#     (model): GPT2LMHeadModel(
#       (transformer): GPT2Model(
#         (wte): Embedding(50257, 768)
#         (wpe): Embedding(1024, 768)
#         (drop): Dropout(p=0.1, inplace=False)
#         (h): ModuleList(
#           (0-11): 12 x GPT2Block(
#             (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
#             (attn): GPT2SdpaAttention(
#               (c_attn): lora.Linear(
#                 (base_layer): Conv1D()
#                 (lora_dropout): ModuleDict(
#                   (default): Dropout(p=0.01, inplace=False)
#                 )
#                 (lora_A): ModuleDict(
#                   (default): Linear(in_features=156, out_features=156, bias=False)
#                 )
#                 (lora_B): ModuleDict(
#                   (default): Linear(in_features=156, out_features=156, bias=False)
#                 )
#                 (lora_embedding_A): ParameterDict()
#                 (lora_embedding_B): ParameterDict()
#               )
#               (c_proj): Conv1D()
#               (attn_dropout): Dropout(p=0.1, inplace=False)
#               (resid_dropout): Dropout(p=0.1, inplace=False)
#             )
#             (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
#             (mlp): GPT2MLP(
#               (c_fc): Conv1D()
#               (c_proj): Conv1D()
#               (act): NewGELUActivation()
#               (dropout): Dropout(p=0.1, inplace=False)
#             )
#           )
#         )
#         (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
#       )
#       (lm_head): Linear(in_features=768, out_features=50257, bias=False)
#     )
#   )
# )