In [9]:
# This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright:
#
from transformers.trainer import *

In [10]:
#    Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

# Adapted from: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
import os
os.environ['CUDA_VISIBLE_DEVICES']='3'
#os.environ['CUDA_VISIBLE_DEVICES'] = '1'
from dataclasses import dataclass, field
import json
import math
import pathlib
from typing import Dict, Optional, Sequence

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset
import transformers
from transformers import Trainer, BitsAndBytesConfig
from transformers.trainer_pt_utils import LabelSmoother

from fastchat.conversation import SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
import os
from medusa.model.medusa_model import MedusaModel, MedusaConfig,SingleMedusa
import torch.nn.functional as F
IGNORE_TOKEN_ID = LabelSmoother.ignore_index


# Customized for training Medusa heads
class CustomizedTrainer(Trainer):
    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            ignore_keys (`List[str]`, *optional*):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.

        Return:
            Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
            logits and labels (each being optional).
        """
        has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
        # For CLIP-like models capable of returning loss values.
        # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss`
        # is `True` in `model.forward`.
        return_loss = inputs.get("return_loss", None)
        if return_loss is None:
            return_loss = self.can_return_loss
        loss_without_labels = True if len(self.label_names) == 0 and return_loss else False

        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels or loss_without_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                raw_outputs = smp_forward_only(model, inputs)
                if has_labels or loss_without_labels:
                    if isinstance(raw_outputs, dict):
                        loss_mb = raw_outputs["loss"]
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        loss_mb = raw_outputs[0]
                        logits_mb = raw_outputs[1:]

                    loss = loss_mb.reduce_mean().detach().cpu()
                    logits = smp_nested_concat(logits_mb)
                else:
                    loss = None
                    if isinstance(raw_outputs, dict):
                        logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
                    else:
                        logits_mb = raw_outputs
                    logits = smp_nested_concat(logits_mb)
            else:
                if has_labels or loss_without_labels:
                    with self.compute_loss_context_manager():
                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                    loss = loss.mean().detach()
                    import pdb;pdb.set_trace()
                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        logits = outputs[:]
                else:
                    loss = None
                    with self.compute_loss_context_manager():
                        outputs = model(**inputs)
                    import pdb;pdb.set_trace()
                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                    else:
                        logits = outputs
                    # TODO: this needs to be fixed and made cleaner later.
                    if self.args.past_index >= 0:
                        self._past = outputs[self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]
        import pdb;pdb.set_trace()
        return (loss, logits, labels)
    def compute_loss(self, model, inputs, return_outputs=False):
        # DDP will give us model.module
        if hasattr(model, "module"):
            medusa = model.module.medusa
        else:
            medusa = model.medusa

        logits = model(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        )
        logits =logits['logits']
        
        labels = inputs["labels"]
        # Shift so that tokens < n predict n
        loss = 0
        loss_fct =CrossEntropyLoss()
        log = {}
        #logits = torch.clamp(logits, min=1e-7, max=100 - 1e-7)
        for i in range(medusa):
            #########修改后#######
            # medusa_logits = logits[i, :, : -1].contiguous()
            
            # medusa_labels = labels[...,  2:].contiguous()
            ######原medusa#########
            
            medusa_logits = logits[i, :, : ].contiguous()
            
            medusa_labels = labels[...,  4:].contiguous()
            medusa_logits = medusa_logits.view(-1, logits.shape[-1])
            medusa_labels = medusa_labels.view(-1)
            
            medusa_labels = medusa_labels.to(medusa_logits.device)
            
            #medusa_logits = torch.clamp(medusa_logits, min=1e-7, max=100 - 1e-7)
           
            loss_i = loss_fct(medusa_logits, medusa_labels)
            loss += loss_i
            not_ignore = medusa_labels.ne(IGNORE_TOKEN_ID)
            medusa_labels = medusa_labels[not_ignore]

            # Add top-k accuracy
            for k in range(1, 6):
                _, topk = medusa_logits.topk(k, dim=-1)
                topk = topk[not_ignore]
                correct = topk.eq(medusa_labels.unsqueeze(-1)).any(-1)
                log[f"medusa{i}_top{k}"] = correct.float().mean().item()
        
            
            log[f"medusa{i}_loss"] = loss_i.item()
            #log[f"medusa{i}_loss_7"] = loss_i_7.item()
        self.log(log)
        return (loss, logits) if return_outputs else loss
    

 
    # def compute_metrics(pred):
    #     labels,logits = pred.label_ids
    #     logits = pred.predictions
    #     medusa_logits = logits[i, :, : -1].contiguous()
            
    #     medusa_labels = labels[...,  1:].contiguous()
    #     medusa_logits = medusa_logits.view(-1, logits.shape[-1])
    #     medusa_labels = medusa_labels.view(-1)
        
    #     medusa_labels = medusa_labels.to(medusa_logits.device)
        
    #     #medusa_logits = torch.clamp(medusa_logits, min=1e-7, max=1 - 1e-7)
       
    #     loss_i = loss_fct(medusa_logits, medusa_labels)
    #     loss += loss_i
    #     not_ignore = medusa_labels.ne(IGNORE_TOKEN_ID)
    #     medusa_labels = medusa_labels[not_ignore]

    #     # Add top-k accuracy
    #     for k in range(1, 6):
    #         _, topk = medusa_logits.topk(k, dim=-1)
    #         topk = topk[not_ignore]
    #         correct = topk.eq(medusa_labels.unsqueeze(-1)).any(-1)
    #         log[f"medusa{i}_top{k}"] = correct.float().mean().item()
    #         res[f"medusa{i}_top{k}"] = correct.float().mean().item()
    
        
    #     log[f"medusa{i}_loss"] = loss_i.item()

        
    #     return log


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="lmsys/vicuna-7b-v1.3")
    load_in_4bit: bool = field(
        default=False,
        metadata={"help": "Load in 4 bit."},
    )
    load_in_8bit: bool = field(
        default=False,
        metadata={"help": "Load in 8 bit."},
    )


@dataclass
class DataArguments:
    data_path: str = field(
        default="sharegpt_clean.json",
        metadata={"help": "Path to the training data."},
    )
    eval_data_path: str = field(
        default=None, metadata={"help": "Path to the evaluation data."}
    )
    lazy_preprocess: bool = True


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=2048,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    medusa_num_heads: int = field(
        default=1,
        metadata={"help": "Number of Medusa heads."},
    )
    medusa_num_layers: int = field(
        default=1,
        metadata={"help": "Number of layers for each Medusa head."},
    )


local_rank = None


def rank0_print(*args):
    if local_rank == 0:
        print(*args)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa


def preprocess(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    conv = get_conversation_template("vicuna")
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}, {j}, {role}, {conv.roles[j % 2]}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations
    input_ids = tokenizer(
        conversations,
        return_tensors="pt",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
    ).input_ids
    targets = input_ids.clone()

    assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO

    # Mask targets. Only compute loss on the assistant outputs.
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        for i, turn in enumerate(turns):
            if turn == "":
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # "-2" is hardcoded for the LLaMA tokenizer to make the offset correct.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            # Ignore the user instructions
            target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len

        target[cur_len:] = IGNORE_TOKEN_ID

        if False:  # Inspect and check the correctness of masking
            z = target.clone()
            z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
            rank0_print(tokenizer.decode(z))

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                rank0_print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
        attention_mask=input_ids.ne(tokenizer.pad_token_id),
    )


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
        super(SupervisedDataset, self).__init__()

        rank0_print("Formatting inputs...")
        sources = [example["conversations"] for example in raw_data]
        data_dict = preprocess(sources, tokenizer)

        self.input_ids = data_dict["input_ids"]
        self.labels = data_dict["labels"]
        self.attention_mask = data_dict["attention_mask"]

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        return dict(
            input_ids=self.input_ids[i],
            labels=self.labels[i],
            attention_mask=self.attention_mask[i],
        )


class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
        super(LazySupervisedDataset, self).__init__()
        self.tokenizer = tokenizer

        rank0_print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.raw_data = raw_data
        self.cached_data_dict = {}

    def __len__(self):
        return len(self.raw_data)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]

        ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer)
        ret = dict(
            input_ids=ret["input_ids"][0],
            labels=ret["labels"][0],
            attention_mask=ret["attention_mask"][0],
        )
        self.cached_data_dict[i] = ret

        return ret


def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer, data_args
) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    dataset_cls = (
        LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
    )
    rank0_print("Loading data...")

    train_json = json.load(open(data_args.data_path, "r"))
    train_dataset = dataset_cls(train_json, tokenizer=tokenizer)

    if data_args.eval_data_path:
        eval_json = json.load(open(data_args.eval_data_path, "r"))
        eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer)
    else:
        eval_dataset = None

    return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)

In [11]:
# def train():
global local_rank
import transformers 
#from transformers import TrainingArguments

training_args = TrainingArguments(
    
    local_rank=0,
    model_max_length=1024 ,
    medusa_num_heads = 1 ,
    medusa_num_layers =  1 ,
    output_dir= './test', 
    num_train_epochs=1,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    evaluation_strategy="steps",
    eval_steps = 1 ,
    save_strategy="no",
    learning_rate=1e-3, 
    weight_decay=0.0,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    logging_steps=1,
    fp16=True, #对应--bf16
    tf32=True,
    
)
#from transformers import DataArguments

data_args = DataArguments(
    data_path="../../../../../data/ShareGPT_Vicuna_unfiltered/small_test.json",
    eval_data_path="../../../../../data/ShareGPT_Vicuna_unfiltered/small_test.json",
    lazy_preprocess= True 
)
#from transformers import ModelArguments

model_args = ModelArguments(
    
    model_name_or_path="../../../../../model/TinyLlama-1.1B-Chat-v0.6",
    #model_max_length=2048,
    #lazy_preprocess=True,
    # medusa_num_heads=3,
    # medusa_num_layers=1
)

local_rank =0 # training_args.local_rank

# Set RoPE scaling factor
config = transformers.AutoConfig.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=training_args.cache_dir,
)
orig_ctx_len = getattr(config, "max_position_embeddings", None)
if orig_ctx_len and training_args.model_max_length > orig_ctx_len:
    scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
    config.rope_scaling = {"type": "linear", "factor": scaling_factor}
config.use_cache = False

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

In [4]:
import torch.nn as nn

In [5]:
model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=training_args.cache_dir,
        low_cpu_mem_usage=True,
        torch_dtype=torch.bfloat16,
        quantization_config=quantization_config if model_args.load_in_4bit else None,
        load_in_4bit=model_args.load_in_4bit,
        load_in_8bit=model_args.load_in_8bit,
    )


Some weights of LlamaForCausalLM were not initialized from the model checkpoint at ../../../../../model/TinyLlama-1.1B-Chat-v0.6 and are newly initialized: ['model.layers.8.self_attn.rotary_emb.inv_freq', 'model.layers.3.self_attn.rotary_emb.inv_freq', 'model.layers.19.self_attn.rotary_emb.inv_freq', 'model.layers.9.self_attn.rotary_emb.inv_freq', 'model.layers.1.self_attn.rotary_emb.inv_freq', 'model.layers.18.self_attn.rotary_emb.inv_freq', 'model.layers.7.self_attn.rotary_emb.inv_freq', 'model.layers.21.self_attn.rotary_emb.inv_freq', 'model.layers.12.self_attn.rotary_emb.inv_freq', 'model.layers.13.self_attn.rotary_emb.inv_freq', 'model.layers.17.self_attn.rotary_emb.inv_freq', 'model.layers.0.self_attn.rotary_emb.inv_freq', 'model.layers.14.self_attn.rotary_emb.inv_freq', 'model.layers.20.self_attn.rotary_emb.inv_freq', 'model.layers.11.self_attn.rotary_emb.inv_freq', 'model.layers.4.self_attn.rotary_emb.inv_freq', 'model.layers.2.self_attn.rotary_emb.inv_freq', 'model.layers.15.s

In [6]:
import copy

#model_name = '../../../../idea6_3fastlayer_t1_skipbert_medusa_mlp_vicuna-7b-v1.3_medusa_1_lr_0.0001_layers_1'
#medusa_lm_head = MedusaModel.from
# for param in medusa_lm_head.base_model.parameters():
#         param.requires_grad = False
medusa_lm_head = MedusaModel(
        model,
        medusa_num_heads=training_args.medusa_num_heads,
        medusa_num_layers=training_args.medusa_num_layers,
        base_model_name_or_path=model_args.model_name_or_path
    )
training_args.output_dir = f"{training_args.output_dir}_medusa_mlp_{model_args.model_name_or_path.split('/')[-1]}_medusa_{training_args.medusa_num_heads}_lr_{training_args.learning_rate}_layers_{training_args.medusa_num_layers}"

tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=training_args.cache_dir,
    model_max_length=training_args.model_max_length,
    padding_side="right",
    use_fast=False,
)
tokenizer.pad_token = tokenizer.unk_token

path:  ../../../../../model/TinyLlama-1.1B-Chat-v0.6


In [10]:
medusa_lm_head.fast_layer1

Sequential(
  (0): ResBlock(
    (linear): Linear(in_features=6144, out_features=6144, bias=True)
    (act): SiLU()
  )
  (1): Linear(in_features=6144, out_features=2048, bias=False)
)

In [11]:
inputs = tokenizer(["a a a a a","are are are how are"])

In [12]:
inputid = torch.tensor(inputs['input_ids'])

In [13]:
res = medusa_lm_head.base_model.model(input_ids= inputid,attention_mask =   torch.tensor(inputs['attention_mask']))

In [34]:
inputs['attention_mask']

[[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]

In [41]:
inputs['attention_mask'][:,0]

TypeError: list indices must be integers or slices, not tuple

In [11]:
embed = medusa_lm_head.base_model.model.embed_tokens(inputid)
embedtrigram = torch.cat((embed[:,:-2],embed[:,1:-1],embed[:,2:]),dim=-1)
embed = medusa_lm_head.fast_layer1(embedtrigram )
from modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask
batch_size, seq_length = embed.shape[:2]
attention_mask = _prepare_4d_causal_attention_mask(
                 torch.tensor(inputs['attention_mask'])[:,:-2], (batch_size, seq_length), embed, 0
            )
position_ids = torch.arange(
                0, seq_length , dtype=torch.long
            )
position_ids = position_ids.unsqueeze(0)
output2 = medusa_lm_head.base_model.model.layers[0](embed,attention_mask = attention_mask,position_ids= position_ids)

NameError: name 'medusa_lm_head' is not defined

In [44]:
output2

(tensor([[[-0.0396,  0.0388,  0.0359,  ...,  0.0176, -0.0042,  0.0237],
          [-0.0391,  0.0610,  0.0240,  ...,  0.0157,  0.0354,  0.0288],
          [-0.0439,  0.0684,  0.0267,  ...,  0.0222,  0.0454,  0.0282],
          [-0.0469,  0.0732,  0.0276,  ...,  0.0258,  0.0498,  0.0276]],
 
         [[-0.0024,  0.1211, -0.0601,  ...,  0.0344, -0.0192,  0.0586],
          [ 0.0239,  0.0835, -0.0537,  ...,  0.0137, -0.0084,  0.0361],
          [ 0.0698,  0.0540, -0.0087,  ...,  0.0383, -0.0181,  0.0292],
          [ 0.0183,  0.0344, -0.0505,  ...,  0.0483,  0.0150, -0.0181]]],
        dtype=torch.bfloat16, grad_fn=<AddBackward0>),)

In [27]:
embedtrigram = torch.cat((embed[:,:-2],embed[:,1:-1],embed[:,2:]),dim=-1)

In [28]:
embedtrigram.shape

torch.Size([2, 4, 6144])

In [29]:
medusa_lm_head.fast_layer1 =medusa_lm_head.fast_layer1.to(medusa_lm_head.base_model.dtype)

In [30]:
output3 = medusa_lm_head.fast_layer1(embedtrigram )

In [31]:
output3.shape

torch.Size([2, 4, 2048])

In [21]:
output2[0].shape

torch.Size([2, 6, 2048])

In [22]:
hs = embed.unsqueeze(0)

In [23]:
embed[0][:-2].shape,embed[0][1:-1].shape,embed[0][2:].shape

(torch.Size([4, 2048]), torch.Size([4, 2048]), torch.Size([4, 2048]))

In [24]:
inputid

tensor([[  1, 263, 263, 263, 263, 263],
        [  1, 526, 526, 526, 920, 526]])

In [25]:
inputid[0,:-2]

tensor([  1, 263, 263, 263])

In [26]:
from  torch.nn import  MSELoss

In [27]:
loss_fct =MSELoss(reduction='mean')

In [15]:
medusa_lm_head.base_model.model.norm

LlamaRMSNorm()

In [17]:
import torch

In [18]:
newhs = torch.cat((inputid[:,:-2].unsqueeze(0),inputid[:,1:-1].unsqueeze(0),inputid[:,2:].unsqueeze(0)),dim=0)

In [19]:
newhs = torch.transpose(newhs,dim0=0,dim1=2)

In [20]:
newhs = torch.transpose(newhs,dim0=0,dim1=1)

In [60]:
newhs= torch.flatten(newhs,end_dim=1)

In [61]:
newhs1 =  medusa_lm_head.base_model.model(input_ids= newhs)

In [65]:
newhs1 = newhs1[0]

In [66]:
newhs1 = newhs1.view((batch_size,-1,3,2048))

In [68]:
newhs2 =newhs1[:,:,-1,:]

In [77]:
newhs2.shape

torch.Size([2, 4, 2048])

In [20]:
attention_mask2 =torch.full((seq_length, seq_length), -3.4028e+38) + torch.diag(torch.zeros(seq_length)+3.4028e+38-1)
attention_mask3 = torch.cat((attention_mask[0,0,:,:],attention_mask2),dim=-1)
attention_mask3 = torch.cat((attention_mask3[:,:-1],attention_mask3[0:-1,:-1]),dim=-2).unsqueeze(0).unsqueeze(0)
attention_mask3 = attention_mask3.repeat([batch_size,1,1,1])

In [21]:
position_ids = torch.arange(0, seq_length, dtype=torch.long )
position_ids2 = torch.arange(1, seq_length , dtype=torch.long)
position_ids2 = torch.cat((position_ids,position_ids2),dim=-1).unsqueeze(0)

In [22]:
embed2 = torch.cat((res[0],embed[:,1:]),dim=-2)

In [23]:
output3 = medusa_lm_head.base_model.model.layers[0](embed2,attention_mask = attention_mask3,position_ids= position_ids2)

In [24]:
output3[0][:,-seq_length+1:].shape

torch.Size([2, 4, 2048])

In [7]:
# Load data
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
#compute metrics
def compute_metrics(pred):
        logits,labels = pred
        #logits = pred.predictions
        #print(logits.shape)
        import pdb; pdb.set_trace()
        #print(labels.shape)
        medusa_logits = logits[0,:, : ].contiguous()
            
        medusa_labels = labels[...,  1:].contiguous()
        medusa_logits = medusa_logits.view(-1, logits.shape[-1])
        medusa_labels = medusa_labels.view(-1)
        
        medusa_labels = medusa_labels.to(medusa_logits.device)
        
        #medusa_logits = torch.clamp(medusa_logits, min=1e-7, max=1 - 1e-7)
       
        loss_i = loss_fct(medusa_logits, medusa_labels)
        loss += loss_i
        not_ignore = medusa_labels.ne(IGNORE_TOKEN_ID)
        medusa_labels = medusa_labels[not_ignore]

        # Add top-k accuracy
        for k in range(1, 6):
            _, topk = medusa_logits.topk(k, dim=-1)
            topk = topk[not_ignore]
            correct = topk.eq(medusa_labels.unsqueeze(-1)).any(-1)
            log[f"medusa{i}_top{k}"] = correct.float().mean().item()
            #res[f"medusa{i}_top{k}"] = correct.float().mean().item()
    
        
        log[f"medusa{i}_loss"] = loss_i.item()

        
        return log
# Generate Medusa config for pushing to HF hub
medusa_config = MedusaConfig(
    medusa_num_heads=training_args.medusa_num_heads,
    medusa_num_layers=training_args.medusa_num_layers,
    base_model_name_or_path=model_args.model_name_or_path,
)

# Save Medusa config
medusa_config.save_pretrained(training_args.output_dir)

# import pdb; pdb.set_trace()
# Start trainner
trainer = CustomizedTrainer(
    model=medusa_lm_head, tokenizer=tokenizer, args=training_args,compute_metrics = compute_metrics , **data_module
)

Loading data...
Formatting inputs...Skip in lazy mode
Formatting inputs...Skip in lazy mode


In [8]:
trainer.evaluate()

  return F.mse_loss(input, target, reduction=self.reduction)
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33myu13668962105[0m. Use [1m`wandb login --relogin`[0m to force relogin


> [0;32m/tmp/ipykernel_225059/3234715425.py[0m(125)[0;36mprediction_step[0;34m()[0m
[0;32m    123 [0;31m                    [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m.[0m[0mmean[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    124 [0;31m                    [0;32mimport[0m [0mpdb[0m[0;34m;[0m[0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 125 [0;31m                    [0;32mif[0m [0misinstance[0m[0;34m([0m[0moutputs[0m[0;34m,[0m [0mdict[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    126 [0;31m                        [0mlogits[0m [0;34m=[0m [0mtuple[0m[0;34m([0m[0mv[0m [0;32mfor[0m [0mk[0m[0;34m,[0m [0mv[0m [0;32min[0m [0moutputs[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m [0;32mif[0m [0mk[0m [0;32mnot[0m [0;32min[0m [0mignore_keys[0m [0;34m+[0m [0;34m[[0m[0;34m"lo

ipdb>  continue


> [0;32m/tmp/ipykernel_225059/3234715425.py[0m(149)[0;36mprediction_step[0;34m()[0m
[0;32m    147 [0;31m            [0mlogits[0m [0;34m=[0m [0mlogits[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    148 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m[0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 149 [0;31m        [0;32mreturn[0m [0;34m([0m[0mloss[0m[0;34m,[0m [0mlogits[0m[0;34m,[0m [0mlabels[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    150 [0;31m    [0;32mdef[0m [0mcompute_loss[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mmodel[0m[0;34m,[0m [0minputs[0m[0;34m,[0m [0mreturn_outputs[0m[0;34m=[0m[0;32mFalse[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    151 [0;31m        [0;31m# DDP will give us model.module[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  continue


  return F.mse_loss(input, target, reduction=self.reduction)


> [0;32m/tmp/ipykernel_225059/3234715425.py[0m(125)[0;36mprediction_step[0;34m()[0m
[0;32m    123 [0;31m                    [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m.[0m[0mmean[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    124 [0;31m                    [0;32mimport[0m [0mpdb[0m[0;34m;[0m[0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 125 [0;31m                    [0;32mif[0m [0misinstance[0m[0;34m([0m[0moutputs[0m[0;34m,[0m [0mdict[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    126 [0;31m                        [0mlogits[0m [0;34m=[0m [0mtuple[0m[0;34m([0m[0mv[0m [0;32mfor[0m [0mk[0m[0;34m,[0m [0mv[0m [0;32min[0m [0moutputs[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m [0;32mif[0m [0mk[0m [0;32mnot[0m [0;32min[0m [0mignore_keys[0m [0;34m+[0m [0;34m[[0m[0;34m"lo

ipdb>  continue


> [0;32m/tmp/ipykernel_225059/3234715425.py[0m(149)[0;36mprediction_step[0;34m()[0m
[0;32m    147 [0;31m            [0mlogits[0m [0;34m=[0m [0mlogits[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    148 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m[0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 149 [0;31m        [0;32mreturn[0m [0;34m([0m[0mloss[0m[0;34m,[0m [0mlogits[0m[0;34m,[0m [0mlabels[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    150 [0;31m    [0;32mdef[0m [0mcompute_loss[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mmodel[0m[0;34m,[0m [0minputs[0m[0;34m,[0m [0mreturn_outputs[0m[0;34m=[0m[0;32mFalse[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    151 [0;31m        [0;31m# DDP will give us model.module[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  logits


tensor([[[[-21.2656, -21.9531,  -1.5957,  ...,  -8.6406, -17.3125,  -4.3594],
          [-24.5625, -24.2188,   1.5781,  ..., -10.2344, -21.6562,  -9.1328],
          [-30.2188, -30.5156,   1.1182,  ..., -12.5938, -24.3281, -11.6328],
          ...,
          [-27.2656, -28.1562,   3.1309,  ..., -19.9219, -21.2969, -21.7188],
          [-25.9688, -25.6875,   2.1016,  ..., -20.0312, -20.9688, -20.8906],
          [-25.5312, -26.1562,   3.0371,  ..., -21.1875, -19.6562, -21.0625]],

         [[-21.2656, -21.9531,  -1.5957,  ...,  -8.6406, -17.3125,  -4.3594],
          [-24.5625, -24.2188,   1.5781,  ..., -10.2344, -21.6562,  -9.1328],
          [-30.2188, -30.5156,   1.1182,  ..., -12.5938, -24.3281, -11.6328],
          ...,
          [-25.8438, -26.1406,   9.1484,  ..., -13.5547, -22.6094, -18.2812],
          [-25.4219, -25.8594,   8.4062,  ..., -13.3906, -22.2969, -17.8906],
          [-25.7188, -26.7188,   8.1641,  ..., -13.7812, -22.2656, -17.5781]],

         [[-21.2656, -21.9531,

ipdb>  continue


> [0;32m/tmp/ipykernel_225059/3234715425.py[0m(125)[0;36mprediction_step[0;34m()[0m
[0;32m    123 [0;31m                    [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m.[0m[0mmean[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    124 [0;31m                    [0;32mimport[0m [0mpdb[0m[0;34m;[0m[0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 125 [0;31m                    [0;32mif[0m [0misinstance[0m[0;34m([0m[0moutputs[0m[0;34m,[0m [0mdict[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    126 [0;31m                        [0mlogits[0m [0;34m=[0m [0mtuple[0m[0;34m([0m[0mv[0m [0;32mfor[0m [0mk[0m[0;34m,[0m [0mv[0m [0;32min[0m [0moutputs[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m [0;32mif[0m [0mk[0m [0;32mnot[0m [0;32min[0m [0mignore_keys[0m [0;34m+[0m [0;34m[[0m[0;34m"lo

ipdb>  continue


> [0;32m/tmp/ipykernel_225059/3234715425.py[0m(149)[0;36mprediction_step[0;34m()[0m
[0;32m    147 [0;31m            [0mlogits[0m [0;34m=[0m [0mlogits[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    148 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m[0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 149 [0;31m        [0;32mreturn[0m [0;34m([0m[0mloss[0m[0;34m,[0m [0mlogits[0m[0;34m,[0m [0mlabels[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    150 [0;31m    [0;32mdef[0m [0mcompute_loss[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mmodel[0m[0;34m,[0m [0minputs[0m[0;34m,[0m [0mreturn_outputs[0m[0;34m=[0m[0;32mFalse[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    151 [0;31m        [0;31m# DDP will give us model.module[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  continue


  return F.mse_loss(input, target, reduction=self.reduction)


> [0;32m/tmp/ipykernel_225059/3234715425.py[0m(125)[0;36mprediction_step[0;34m()[0m
[0;32m    123 [0;31m                    [0mloss[0m [0;34m=[0m [0mloss[0m[0;34m.[0m[0mmean[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    124 [0;31m                    [0;32mimport[0m [0mpdb[0m[0;34m;[0m[0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 125 [0;31m                    [0;32mif[0m [0misinstance[0m[0;34m([0m[0moutputs[0m[0;34m,[0m [0mdict[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    126 [0;31m                        [0mlogits[0m [0;34m=[0m [0mtuple[0m[0;34m([0m[0mv[0m [0;32mfor[0m [0mk[0m[0;34m,[0m [0mv[0m [0;32min[0m [0moutputs[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m [0;32mif[0m [0mk[0m [0;32mnot[0m [0;32min[0m [0mignore_keys[0m [0;34m+[0m [0;34m[[0m[0;34m"lo

ipdb>  continue


> [0;32m/tmp/ipykernel_225059/3234715425.py[0m(149)[0;36mprediction_step[0;34m()[0m
[0;32m    147 [0;31m            [0mlogits[0m [0;34m=[0m [0mlogits[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    148 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m[0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 149 [0;31m        [0;32mreturn[0m [0;34m([0m[0mloss[0m[0;34m,[0m [0mlogits[0m[0;34m,[0m [0mlabels[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    150 [0;31m    [0;32mdef[0m [0mcompute_loss[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mmodel[0m[0;34m,[0m [0minputs[0m[0;34m,[0m [0mreturn_outputs[0m[0;34m=[0m[0;32mFalse[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    151 [0;31m        [0;31m# DDP will give us model.module[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  continue


> [0;32m/tmp/ipykernel_225059/775524108.py[0m(10)[0;36mcompute_metrics[0;34m()[0m
[0;32m      8 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m        [0;31m#print(labels.shape)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m        [0mmedusa_logits[0m [0;34m=[0m [0mlogits[0m[0;34m[[0m[0;36m0[0m[0;34m,[0m[0;34m:[0m[0;34m,[0m [0;34m:[0m [0;34m][0m[0;34m.[0m[0mcontiguous[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m        [0mmedusa_labels[0m [0;34m=[0m [0mlabels[0m[0;34m[[0m[0;34m...[0m[0;34m,[0m  [0;36m1[0m[0;34m:[0m[0;34m][0m[0;34m.[0m[0mcontiguous[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  pred.preditions


*** AttributeError: 'EvalPrediction' object has no attribute 'preditions'


ipdb>  pred.predictions


array([[[[ -21.265625 ,  -21.953125 ,   -1.5957031, ...,   -8.640625 ,
           -17.3125   ,   -4.359375 ],
         [ -24.5625   ,  -24.21875  ,    1.578125 , ...,  -10.234375 ,
           -21.65625  ,   -9.1328125],
         [ -30.21875  ,  -30.515625 ,    1.1181641, ...,  -12.59375  ,
           -24.328125 ,  -11.6328125],
         ...,
         [ -29.734375 ,  -28.953125 ,    7.8945312, ...,  -19.296875 ,
           -29.40625  ,  -20.015625 ],
         [ -29.84375  ,  -29.09375  ,    8.0078125, ...,  -19.28125  ,
           -29.546875 ,  -20.40625  ],
         [ -29.734375 ,  -29.171875 ,    8.3671875, ...,  -19.46875  ,
           -29.8125   ,  -20.203125 ]],

        [[ -21.265625 ,  -21.953125 ,   -1.5957031, ...,   -8.640625 ,
           -17.3125   ,   -4.359375 ],
         [ -24.5625   ,  -24.21875  ,    1.578125 , ...,  -10.234375 ,
           -21.65625  ,   -9.1328125],
         [ -30.21875  ,  -30.515625 ,    1.1181641, ...,  -12.59375  ,
           -24.328125 ,  -11.6328

ipdb>  pred.predictions.shape


(4, 4, 1020, 32000)


ipdb>  pred.labels.shape


*** AttributeError: 'EvalPrediction' object has no attribute 'labels'


ipdb>  pred.label.shape


*** AttributeError: 'EvalPrediction' object has no attribute 'label'


ipdb>  pred.label_ids


array([[ -100,  -100,  -100, ...,  -100,  -100,  -100],
       [ -100,  -100,  -100, ...,  -100,  -100,  -100],
       [ -100,  -100,  -100, ...,   893, 29874,   313],
       ...,
       [ -100,  -100,  -100, ..., 29879,  1950,   393],
       [ -100,  -100,  -100, ...,  -100,  -100,  -100],
       [ -100,  -100,  -100, ...,   890,    13,    13]])


ipdb>  pred.label_ids.shape


(13, 1024)


ipdb>  pred.predictions.shape


(4, 4, 1020, 32000)


ipdb>  continue


AttributeError: 'numpy.ndarray' object has no attribute 'contiguous'

In [37]:
trainer.predict(data_module['eval_dataset'])

> [0;32m/tmp/ipykernel_220651/775524108.py[0m(10)[0;36mcompute_metrics[0;34m()[0m
[0;32m      8 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m        [0;31m#print(labels.shape)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m        [0mmedusa_logits[0m [0;34m=[0m [0mlogits[0m[0;34m[[0m[0;36m0[0m[0;34m,[0m[0;34m:[0m[0;34m,[0m [0;34m:[0m [0;34m][0m[0;34m.[0m[0mcontiguous[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m        [0mmedusa_labels[0m [0;34m=[0m [0mlabels[0m[0;34m[[0m[0;34m...[0m[0;34m,[0m  [0;36m1[0m[0;34m:[0m[0;34m][0m[0;34m.[0m[0mcontiguous[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  pred.predictions


array([], shape=(0, 1, 4, 1020, 32000), dtype=float32)
--KeyboardInterrupt--

KeyboardInterrupt: Interrupted by user



KeyboardInterrupt



In [43]:
trainer.args.use_legacy_prediction_loop=False

In [None]:
trainer.evaluate()

> [0;32m/tmp/ipykernel_220651/775524108.py[0m(10)[0;36mcompute_metrics[0;34m()[0m
[0;32m      8 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m        [0;31m#print(labels.shape)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m        [0mmedusa_logits[0m [0;34m=[0m [0mlogits[0m[0;34m[[0m[0;36m0[0m[0;34m,[0m[0;34m:[0m[0;34m,[0m [0;34m:[0m [0;34m][0m[0;34m.[0m[0mcontiguous[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m        [0mmedusa_labels[0m [0;34m=[0m [0mlabels[0m[0;34m[[0m[0;34m...[0m[0;34m,[0m  [0;36m1[0m[0;34m:[0m[0;34m][0m[0;34m.[0m[0mcontiguous[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


In [40]:
trainer.evaluation_loop

<bound method Trainer.evaluation_loop of <__main__.CustomizedTrainer object at 0x7ff0fc6ba6e0>>

AttributeError: 'CustomizedTrainer' object has no attribute 'gather_for_metrics'

In [13]:
trainer.train()

Step,Training Loss,Validation Loss


TypeError: iteration over a 0-d tensor

In [63]:
class CustomizedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=True):
        # DDP will give us model.module
        if hasattr(model, "module"):
            medusa = model.module.medusa
        else:
            medusa = model.medusa

        logits = model(
            input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
        )
        
        labels = inputs["labels"]
        # Shift so that tokens < n predict n
        loss = 0
        loss_fct = CrossEntropyLoss()
        log = {}
        for i in range(medusa):
            
            medusa_logits = logits[i, :, : -1].contiguous()
            
            medusa_labels = labels[...,  2:].contiguous()
            medusa_logits = medusa_logits.view(-1, logits.shape[-1])
            medusa_labels = medusa_labels.view(-1)
            
            medusa_labels = medusa_labels.to(medusa_logits.device)
            
            #medusa_logits = torch.clamp(medusa_logits, min=1e-7, max=1 - 1e-7)
           
            loss_i = loss_fct(medusa_logits, medusa_labels)
            loss += loss_i
            not_ignore = medusa_labels.ne(IGNORE_TOKEN_ID)
            medusa_labels = medusa_labels[not_ignore]

            # Add top-k accuracy
            for k in range(1, 6):
                _, topk = medusa_logits.topk(k, dim=-1)
                topk = topk[not_ignore]
                correct = topk.eq(medusa_labels.unsqueeze(-1)).any(-1)
                log[f"medusa{i}_top{k}"] = correct.float().mean().item()
            log[f"medusa{i}_loss"] = loss_i.item()
            #log[f"medusa{i}_loss_7"] = loss_i_7.item()
        #import pdb; pdb.set_trace()
        self.log(log)
        
        return (loss, logits) #if return_outputs else loss
def compute_metrics(pred):
        
        labels = pred.label_ids
        logits = pred.predictions
        #import pdb; pdb.set_trace()
        print(pred)
        medusa_logits = logits[0, :, : -1].contiguous()
            
        medusa_labels = labels[...,  2:].contiguous()
        medusa_logits = medusa_logits.view(-1, logits.shape[-1])
        medusa_labels = medusa_labels.view(-1)
        
        medusa_labels = medusa_labels.to(medusa_logits.device)
        
        #medusa_logits = torch.clamp(medusa_logits, min=1e-7, max=1 - 1e-7)
       
        loss_i = loss_fct(medusa_logits, medusa_labels)
        loss += loss_i
        not_ignore = medusa_labels.ne(IGNORE_TOKEN_ID)
        medusa_labels = medusa_labels[not_ignore]

        # Add top-k accuracy
        for k in range(1, 6):
            _, topk = medusa_logits.topk(k, dim=-1)
            topk = topk[not_ignore]
            correct = topk.eq(medusa_labels.unsqueeze(-1)).any(-1)
            log[f"medusa{i}_top{k}"] = correct.float().mean().item()
            res[f"medusa{i}_top{k}"] = correct.float().mean().item()
    
        
        log[f"medusa{i}_loss"] = loss_i.item()

        
        return log
global local_rank
import transformers 
#from transformers import TrainingArguments

training_args = TrainingArguments(
    
    local_rank=0,
    model_max_length=1024 ,
    medusa_num_heads = 1 ,
    medusa_num_layers =  1 ,
    output_dir= './test', 
    num_train_epochs=1,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    evaluation_strategy="steps",
    eval_steps = 1 ,
    save_strategy="no",
    learning_rate=1e-3, 
    weight_decay=0.0,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    logging_steps=1,
    fp16=True, #对应--bf16
    tf32=True,
    
)
#from transformers import DataArguments

data_args = DataArguments(
    data_path="../../../../../data/ShareGPT_Vicuna_unfiltered/1280test.json",
    eval_data_path="../../../../../data/ShareGPT_Vicuna_unfiltered/1280test.json",
    lazy_preprocess= True 
)
#from transformers import ModelArguments

model_args = ModelArguments(
    
    model_name_or_path="../../../../../model/vicuna-7b-v1.3",
    #model_max_length=2048,
    #lazy_preprocess=True,
    # medusa_num_heads=3,
    # medusa_num_layers=1
)

local_rank =0 

training_args.output_dir = f"{training_args.output_dir}_medusa_mlp_{model_args.model_name_or_path.split('/')[-1]}_medusa_{training_args.medusa_num_heads}_lr_{training_args.learning_rate}_layers_{training_args.medusa_num_layers}"



# import pdb; pdb.set_trace()
# Start trainner
trainer = CustomizedTrainer(
    model=medusa_lm_head, tokenizer=tokenizer, args=training_args, compute_metrics = compute_metrics,**data_module
)

NameError: name 'data_module' is not defined

In [None]:
trainer.evaluate()

> [0;32m/tmp/ipykernel_13542/1513928275.py[0m(44)[0;36mcompute_loss[0;34m()[0m
[0;32m     42 [0;31m            [0;31m#log[f"medusa{i}_loss_7"] = loss_i_7.item()[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     43 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 44 [0;31m        [0mself[0m[0;34m.[0m[0mlog[0m[0;34m([0m[0mlog[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     45 [0;31m[0;34m[0m[0m
[0m[0;32m     46 [0;31m        [0;32mreturn[0m [0;34m([0m[0mloss[0m[0;34m,[0m [0mlogits[0m[0;34m)[0m [0;31m#if return_outputs else loss[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  loss.shape


torch.Size([])


ipdb>  logits.shape


torch.Size([1, 2, 1023, 32000])


ipdb>  logits


tensor([[[[ -0.9927, -20.5625,  -2.8359,  ...,   0.4290,   0.9609,  -1.2070],
          [-26.0625, -29.8906,  -3.9062,  ..., -33.0625, -17.0938, -24.0469],
          [ -6.2305,  -5.8984,   4.7070,  ...,  -0.7690,  -4.9609,  -5.0312],
          ...,
          [ -2.1953, -29.7812,  10.4688,  ...,   2.3438,  -3.8945,  -2.3164],
          [ -1.8789, -30.4688,  10.3984,  ...,   2.8223,  -3.4277,  -1.9736],
          [ -1.5889, -30.9844,  10.0469,  ...,   3.3086,  -2.8164,  -1.5332]],

         [[ -0.9927, -20.5625,  -2.8359,  ...,   0.4290,   0.9609,  -1.2070],
          [-26.0625, -29.8906,  -3.9062,  ..., -33.0625, -17.0938, -24.0469],
          [ -6.2305,  -5.8984,   4.7070,  ...,  -0.7690,  -4.9609,  -5.0312],
          ...,
          [ -5.5117,   0.7793,  21.2500,  ...,  -8.7656,  -9.6719,  -9.1328],
          [ -5.4023,   0.6934,  20.9844,  ...,  -8.6562,  -9.5625,  -8.9922],
          [ -5.2031,   0.4268,  20.6094,  ...,  -8.4609,  -9.4062,  -8.8203]]]],
       device='cuda:0')


ipdb>  loss


tensor(2.6962, device='cuda:0')


In [57]:
import json
with open("../../../../../data/ShareGPT_Vicuna_unfiltered/test.json", "r", encoding="utf-8") as f:
    content = json.load(f)

In [58]:
len(content)

6862

In [61]:
data = content[-1280:]

In [62]:
with open("../../../../../data/ShareGPT_Vicuna_unfiltered/1280test.json", "w", encoding="utf-8") as f:
    json.dump(data, f)

In [86]:
from datasets import load_dataset

dataset = load_dataset("../../../../../data/xsum")

Downloading and preparing dataset xsum/default to /home/liyunhao/.cache/huggingface/datasets/xsum/default/1.2.0/082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.00M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Dataset xsum downloaded and prepared to /home/liyunhao/.cache/huggingface/datasets/xsum/default/1.2.0/082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [87]:
dataset

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})