Skip to content

Commit

Permalink
fix GPTQ quantization?
Browse files (browse the repository at this point in the history)
  • Loading branch information
lopuhin committed Dec 1, 2023
1 parent 3fc37e6 commit 86d990b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 25 deletions.
14 changes: 5 additions & 9 deletions GPTQ.py
Original file line number | Diff line number | Diff line change
[unchanged lines above collapsed]
@@ -4,25 +4,20 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import os, sys
lm_evaluation_harness_path = "/".join(
os.getcwd().split("/")[:-1] + ["lm-evaluation-harness"]
)
sys.path.insert(0, lm_evaluation_harness_path)
import main as lm_evaluation_harness_main

import torch.fx as fx
import torch.nn as nn
import torch.nn.functional as F
from torch.utils._pytree import tree_flatten, tree_unflatten

from eval import setup_cache_padded_seq_input_pos_max_seq_length_for_prefill
from generate import encode_tokens

aten = torch.ops.aten

try:
import lm_eval
import lm_eval.base
import lm_eval.tasks
import lm_eval.evaluator
class InputRecorder(lm_eval.base.BaseLM):
"""
This is a fake evaluation wrapper that just records the inputs
[unchanged lines collapsed]
@@ -88,8 +83,9 @@ def device(self):
return self._device

def tok_encode(self, string: str):
from generate import encode_tokens
encoded = encode_tokens(
self._tokenizer, string, bos=True, eos=False, device=self._device
self._tokenizer, string, bos=True, device=self._device
)
# encoded is a pytorch tensor, but some internal logic in the
# eval harness expects it to be a list instead
Expand Down
22 changes: 6 additions & 16 deletions eval.py
Original file line number | Diff line number | Diff line change
[unchanged lines above collapsed]
@@ -22,24 +22,10 @@
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from model import LLaMA
from model import Transformer as LLaMA
from sentencepiece import SentencePieceProcessor

# hacky path setup for lm-evaluation-harness
import os
import sys
lm_evaluation_harness_path = '/'.join(
os.getcwd().split('/')[:-1] + ['lm-evaluation-harness'])
sys.path.insert(0, lm_evaluation_harness_path)
import main as lm_evaluation_harness_main
import lm_eval

from generate import (
_load_model,
encode_tokens,
model_forward,
)

import lm_eval.base

def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
model: LLaMA,
[unchanged lines collapsed]
@@ -116,6 +102,7 @@ def device(self):
return self._device

def tok_encode(self, string: str):
from generate import encode_tokens
encoded = encode_tokens(self._tokenizer,
string, bos=True, eos=False, device=self._device)
# encoded is a pytorch tensor, but some internal logic in the
[unchanged lines collapsed]
@@ -130,6 +117,8 @@ def tok_decode(self, tokens):

def _model_call(self, inps):
# TODO: make batches work
from generate import model_forward

inps = inps.squeeze(0)

max_new_tokens = 1
[unchanged lines collapsed]
@@ -205,6 +194,7 @@ def main(
max_seq_length (Optional[int]): The maximum sequence length allowed for input text.
"""
from generate import _load_model, model_forward

assert checkpoint_path.is_file(), checkpoint_path

Expand Down

0 comments on commit 86d990b

Please sign in to comment.