In [1]:
%cd "/users/swang299/code/AntGPT-Llama2/Finetune/llama-recipes"
from transformers import LlamaModel, LlamaForCausalLM, LlamaConfig, LlamaTokenizer
import torch

/oscar/data/csun45/swang299/code/AntGPT-Llama2/Finetune/llama-recipes


In [2]:
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings

from utils.modeling_llama import LlamaForCausalLM

ModuleNotFoundError: No module named 'utils.modeling_llama'

In [3]:
class VisLlama(torch.nn.Module):
    """Llama causal LM whose prompt can carry projected visual features.

    `proj` is a bias-free linear map from `vis_size` to the Llama hidden size.
    In visual mode every `input_ids` row marks, with `-1` entries, the span
    where visual embeddings go: everything from the first to the last `-1`
    (inclusive) is dropped and the projected `obs_feats` rows are inserted in
    its place. In text-only mode `obs_feats` is ignored and `input_ids` is
    embedded as-is.

    Args:
        llama_config: configuration handed to `LlamaForCausalLM`; its
            `hidden_size` sets the output width of `proj`.
        vis_size: size of the last dimension of the visual features.
        use_vis: if true, `forward`/`generate` use the visual variants,
            otherwise the text-only ones.
    """

    def __init__(self, llama_config, vis_size, use_vis=True):
        super().__init__()
        self.llama = LlamaForCausalLM(llama_config)
        # Always created (even when use_vis is false), as in the original.
        self.proj = torch.nn.Linear(vis_size, llama_config.hidden_size, bias=False)
        # Mode flag read by forward()/generate(). Dispatching through real
        # methods replaces assigning bound methods onto the instance.
        self.use_vis = use_vis

    def forward(self, obs_feats, input_ids, labels, attention_mask):
        """Run `forward_vis` or `forward_no_vis` depending on `use_vis`."""
        if self.use_vis:
            return self.forward_vis(obs_feats, input_ids, labels, attention_mask)
        return self.forward_no_vis(obs_feats, input_ids, labels, attention_mask)

    def generate(self, obs_feats, input_ids, labels, attention_mask, max_new_tokens=200):
        """Run `generate_vis` or `generate_no_vis` depending on `use_vis`."""
        if self.use_vis:
            return self.generate_vis(obs_feats, input_ids, labels, attention_mask, max_new_tokens)
        return self.generate_no_vis(obs_feats, input_ids, labels, attention_mask, max_new_tokens)

    def _splice_visual(self, token_ids, obs_emb):
        """Replace the `-1` placeholder span of one sequence with `obs_emb`.

        `token_ids` is a 1-D id tensor and `obs_emb` the already projected
        visual embeddings for that sequence. Returns the concatenation
        [embeddings before the span, obs_emb, embeddings after the span].

        Raises:
            IndexError: if `token_ids` contains no `-1` entry.
        """
        # Locate the placeholders once; first and last entries bound the span.
        placeholder_pos = torch.where(token_ids == -1)[0]
        span_start = placeholder_pos[0]
        span_end = placeholder_pos[-1]
        embed = self.llama.model.embed_tokens
        prefix_emb = embed(token_ids[:span_start])
        suffix_emb = embed(token_ids[span_end + 1:])
        return torch.cat([prefix_emb, obs_emb, suffix_emb], dim=0)

    def _sample(self, inputs_embeds, attention_mask, max_new_tokens):
        """Sample 5 continuations per prompt with the fixed decoding settings."""
        with torch.no_grad():
            return self.llama.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_p=1.0,
                temperature=0.3,
                min_length=None,
                use_cache=True,
                top_k=50,
                repetition_penalty=1.0,
                length_penalty=1,
                num_return_sequences=5,
                pad_token_id=0,
            )

    def forward_no_vis(self, obs_feats, input_ids, labels, attention_mask):
        """Text-only forward pass; `obs_feats` is accepted but not used."""
        inputs_embeds = self.llama.model.embed_tokens(input_ids)
        return self.llama(
            inputs_embeds=inputs_embeds,
            labels=labels,
            attention_mask=attention_mask,
        )

    def forward_vis(self, obs_feats, input_ids, labels, attention_mask):
        """Forward pass with visual embeddings spliced into every sequence.

        `obs_feats` is indexed along dim 0 together with `input_ids`; each
        `obs_feats[i]` is projected and inserted into row `i`.
        `labels` and `attention_mask` are passed to the Llama model unchanged.
        """
        obs_embs = self.proj(obs_feats)
        spliced = [
            self._splice_visual(input_ids[i], obs_embs[i])
            for i in range(obs_feats.shape[0])
        ]
        # torch.stack requires every spliced sequence to have the same length.
        inputs_embeds = torch.stack(spliced, dim=0)
        return self.llama(
            inputs_embeds=inputs_embeds,
            labels=labels,
            attention_mask=attention_mask,
        )

    def generate_no_vis(self, obs_feats, input_ids, labels, attention_mask, max_new_tokens=200):
        """Text-only sampling; `obs_feats` and `labels` are accepted but not used."""
        # Embedding lookup happens under no_grad so generation builds no autograd graph.
        with torch.no_grad():
            inputs_embeds = self.llama.model.embed_tokens(input_ids)
        return self._sample(inputs_embeds, attention_mask, max_new_tokens)

    def generate_vis(self, obs_feats, input_ids, labels, attention_mask, max_new_tokens=200):
        """Sample continuations for the FIRST sample of the batch only.

        The prompt is `input_ids[0]` cut at the first position where
        `labels[0] != -100`. If `labels` is None or has no such position, the
        whole row is used as the prompt (previously this raised IndexError).
        The attention mask is truncated to the spliced prompt length.
        """
        with torch.no_grad():
            prompt_ids = input_ids[0]
            if labels is not None:
                target_pos = torch.where(labels[0] != -100)[0]
                if target_pos.numel() > 0:
                    # Keep only the tokens that precede the first target token.
                    prompt_ids = prompt_ids[:target_pos[0]]
            inputs_embed = self._splice_visual(prompt_ids, self.proj(obs_feats[0]))
        prompt_len = inputs_embed.shape[0]
        return self._sample(
            inputs_embed.unsqueeze(0),
            attention_mask[0][:prompt_len].unsqueeze(0),
            max_new_tokens,
        )

In [5]:
# Small Llama-style decoder (6 layers, hidden size 768) wrapped with the visual projection.
llama_config = LlamaConfig(
    vocab_size=32000,
    hidden_size=768,
    intermediate_size=2048,
    num_hidden_layers=6,
    num_attention_heads=6,
    num_key_value_heads=6,
    max_position_embeddings=300,
)
model = VisLlama(llama_config, vis_size=768, use_vis=True)
# Fall back to CPU when no GPU is visible so the cell also runs on a CPU-only machine.
model.to("cuda" if torch.cuda.is_available() else "cpu")

VisLlama(
  (llama): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(32000, 768)
      (layers): ModuleList(
        (0-5): 6 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): Linear(in_features=768, out_features=768, bias=False)
            (k_proj): Linear(in_features=768, out_features=768, bias=False)
            (v_proj): Linear(in_features=768, out_features=768, bias=False)
            (o_proj): Linear(in_features=768, out_features=768, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=768, out_features=2048, bias=False)
            (up_proj): Linear(in_features=768, out_features=2048, bias=False)
            (down_proj): Linear(in_features=2048, out_features=768, bias=False)
            (act_fn): SiLUActivation()
          )
          (input_layernorm): LlamaRMSNorm()
          (post_attention_layernorm): LlamaRMSNorm()
        )

In [9]:
# Report whether the current `model` object is an instance of a class named "VisLlama".
if type(model).__name__ == "VisLlama":
    print("yes")

yes


In [5]:
# Initializing a small LLaMA-style configuration (6 layers, hidden size 768)

configuration = LlamaConfig(
    vocab_size=32000,
    hidden_size=768,
    intermediate_size=2048,
    num_hidden_layers=6,
    num_attention_heads=6,
    num_key_value_heads=6,
    max_position_embeddings=300,
)

# Initializing a randomly initialised model from the configuration above.
# NOTE: this rebinds the notebook-global `model` to a plain LlamaForCausalLM.
model = LlamaForCausalLM(configuration)
# Fall back to CPU when no GPU is visible so the cell also runs on a CPU-only machine.
model.to('cuda' if torch.cuda.is_available() else 'cpu')

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 768)
    (layers): ModuleList(
      (0-5): 6 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=768, out_features=768, bias=False)
          (k_proj): Linear(in_features=768, out_features=768, bias=False)
          (v_proj): Linear(in_features=768, out_features=768, bias=False)
          (o_proj): Linear(in_features=768, out_features=768, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=768, out_features=2048, bias=False)
          (up_proj): Linear(in_features=768, out_features=2048, bias=False)
          (down_proj): Linear(in_features=2048, out_features=768, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_

In [8]:
# Load the Llama-2 tokenizer from the local checkpoint directory.
# NOTE(review): this rebinds the imported `LlamaTokenizer` class name to a tokenizer
# *instance*; the following cells call .encode/.decode on that instance through the
# same name. A distinct name (e.g. `tokenizer`) would avoid the shadowing.
LlamaTokenizer = LlamaTokenizer.from_pretrained('/gpfs/data/superlab/models/llama2/llama/checkpoints/hf/Llama-2-7b-hf')

In [29]:
# Round-trip a single space through encode/decode; the displayed result carries a leading '<s>'.
LlamaTokenizer.decode(LlamaTokenizer.encode(" "))

'<s> '

In [47]:
# Inspect the token ids produced for the string "a,b".
LlamaTokenizer.encode("a,b")

[1, 263, 29892, 29890]

In [49]:
# Decode three token ids back to text.
LlamaTokenizer.decode([263, 29892, 29890])

'a,b'

In [50]:
# Show the end-of-sequence token id of the loaded tokenizer.
LlamaTokenizer.eos_token_id

2

In [None]:
# load checkpoint
# map_location="cpu" lets a checkpoint saved from a GPU run be read on any machine;
# load_state_dict then copies the tensors into the model's existing parameters.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
model.load_state_dict(torch.load('ft_ckpt/ego4d_v1/layer6_bs64_15e-4/11.pt', map_location='cpu'))

In [None]:
# TODO: clone the model, excluding embed_tokens (not implemented yet)

