In [None]:
%env CUDA_VISIBLE_DEVICES=6

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

# Hugging Face model id; used for both the weights and the tokenizer.
model_path = "meta-llama/Meta-Llama-3.1-8B"

# Settings forwarded unchanged to `model.quantize(...)` at the end of this cell.
quant_config = dict(
    zero_point=True,
    q_group_size=64,
    w_bit=4,
    version="GEMM",
)

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoAWQForCausalLM.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    use_cache=False,
    device_map="cuda",
)

# Quantize (the tokenizer is passed along with the quantization settings)
model.quantize(tokenizer, quant_config=quant_config)


In [7]:
def _set_num_fewshot(task_dict, num_fewshots):
    """Set ``config.num_fewshot`` on every task found in a (possibly nested) task dict."""
    for entry in task_dict.values():
        if isinstance(entry, dict):
            # Group entries hold their sub-tasks in a nested dict: descend into it.
            _set_num_fewshot(entry, num_fewshots)
            continue
        if isinstance(entry, tuple):
            # (group, task) pair: the task object is the second element.
            entry = entry[1]
        if entry is None:
            continue
        entry.config.num_fewshot = num_fewshots


def get_zero_shots(model, task_list = ('arc_easy',), num_fewshots=1):
    """Run lm-eval tasks on ``model`` and collect accuracy and its standard error.

    Args:
        model: Passed as ``pretrained=`` to ``lm_eval.models.huggingface.HFLM``.
        task_list: Task (or group) name, or an iterable of names, forwarded to
            ``lm_eval.tasks.get_task_dict`` as a list.
        num_fewshots: Few-shot count written into each task's ``config.num_fewshot``.
            The value ``1`` is a sentinel meaning "leave every task's own setting
            untouched"; in that case result keys also get no ``@N`` suffix.

    Returns:
        dict: ``{task: acc}`` followed by ``{task + '_err': acc_stderr}``. When
        ``num_fewshots != 1`` every key is suffixed with ``@{num_fewshots}``.

    Raises:
        KeyError: If a result entry has no ``'acc,none'`` / ``'acc_stderr,none'`` metric.
    """
    import lm_eval

    lm_eval_model = lm_eval.models.huggingface.HFLM(
        pretrained=model,
    )

    # Hand get_task_dict a real list (a bare string stays a single task name).
    task_names = [task_list] if isinstance(task_list, str) else list(task_list)
    tasks = lm_eval.tasks.get_task_dict(task_names)
    if num_fewshots != 1:
        # TODO: make fewshots properly
        _set_num_fewshot(tasks, num_fewshots)

    results = lm_eval.evaluator.evaluate(
        lm=lm_eval_model,
        task_dict=tasks,
    )

    per_task = results['results']
    # Accuracies first, then the matching standard errors (same order as before).
    result_dict = {task_name: metrics['acc,none'] for task_name, metrics in per_task.items()}
    result_dict.update(
        {f'{task_name}_err': metrics['acc_stderr,none'] for task_name, metrics in per_task.items()}
    )

    if num_fewshots != 1:
        result_dict = {f'{task_name}@{num_fewshots}': acc for task_name, acc in result_dict.items()}

    return result_dict

In [3]:
import torch
from torch import nn
from torch.nn import functional as F

from tqdm.auto import trange, tqdm

@torch.no_grad()
def llama_eval(model, dataloader):
    """Compute perplexity by pushing every sample through the decoder one layer at a time.

    The first decoder layer is temporarily wrapped so that an ordinary forward pass can
    be intercepted: the wrapper records the hidden states and keyword arguments the model
    hands to layer 0 and then aborts the pass. The recorded inputs are afterwards fed
    through ``model.model.layers`` in order, followed by ``model.model.norm`` (when it is
    not None) and ``model.lm_head``.

    Args:
        model: Causal LM exposing ``config.use_cache``, ``model.embed_tokens``,
            ``model.layers``, ``model.norm`` and ``lm_head``; must be callable on a
            batch of token ids.
        dataloader: Iterable of 2-D integer token-id tensors (labels are taken as
            ``[:, 1:]``). It is materialised into a list, so it is consumed once.

    Returns:
        float: ``exp`` of the mean next-token cross-entropy, each sample weighted by
        its number of predicted tokens (identical to the plain mean over samples when
        all samples have the same length).

    Raises:
        ValueError: If ``dataloader`` yields no samples.
        RuntimeError: If a forward pass finishes without reaching the first decoder layer.
    """
    print('Evaluating ...')

    samples = list(dataloader)
    if not samples:
        raise ValueError("llama_eval needs at least one sample, got an empty dataloader")
    nsamples = len(samples)

    layers = model.model.layers
    # Feed inputs to wherever the embedding lives instead of a hard-coded "cuda".
    input_device = model.model.embed_tokens.weight.device

    inps = []
    layer_kwargs = []

    class _StopForward(Exception):
        """Private signal used to abort the forward pass once layer-0 inputs are recorded."""

    class Catcher(nn.Module):
        """Stand-in for the first layer that records its inputs and aborts the pass."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps.append(inp)
            # Keep every keyword argument so the layers can be replayed with exactly
            # what the model would have passed (attention mask, position ids, ...).
            layer_kwargs.append(kwargs)
            raise _StopForward

    use_cache = model.config.use_cache
    model.config.use_cache = False
    try:
        first_layer = layers[0]
        layers[0] = Catcher(first_layer)
        try:
            for batch in samples:
                try:
                    model(batch.to(input_device))
                except _StopForward:
                    # Expected: raised by Catcher. Any other error propagates.
                    pass
        finally:
            # Always put the real layer back, even if a forward pass failed.
            layers[0] = first_layer

        if len(inps) != nsamples:
            raise RuntimeError(
                f"captured {len(inps)} layer-0 inputs for {nsamples} samples; "
                "the forward pass did not reach the first decoder layer"
            )
        torch.cuda.empty_cache()

        for i in trange(len(layers), desc="Evaluating layer-by-layer..."):
            layer = layers[i]
            for j in range(nsamples):
                out = layer(inps[j], **layer_kwargs[j])
                # Layers may return a tuple (hidden states first) or the bare tensor.
                inps[j] = out[0] if isinstance(out, (tuple, list)) else out
            del layer
            torch.cuda.empty_cache()

        loss_fct = nn.CrossEntropyLoss()
        nlls = []
        total_tokens = 0
        for i in range(nsamples):
            hidden_states = inps[i]
            if model.model.norm is not None:
                hidden_states = model.model.norm(hidden_states)
            lm_logits = model.lm_head(hidden_states)
            shift_logits = lm_logits[:, :-1, :].contiguous()
            shift_labels = samples[i].to(shift_logits.device)[:, 1:]
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.reshape(-1))
            # Mean loss times the number of predicted tokens = summed NLL of this sample.
            n_tokens = shift_labels.numel()
            nlls.append(loss.float() * n_tokens)
            total_tokens += n_tokens
        ppl = torch.exp(torch.stack(nlls).sum() / total_tokens)
    finally:
        model.config.use_cache = use_cache

    ppl_value = ppl.item()
    print(ppl_value)

    return ppl_value

In [4]:
# Move the model to the GPU and rebind `model` to whatever `.to()` returns.
# NOTE(review): if `model` is still the AutoAWQ wrapper at this point, its `.to()`
# presumably returns the wrapped Hugging Face model rather than the wrapper itself.
# llama_eval reads `model.model.layers` and `model.config`, so confirm which object
# the later cells actually receive; re-running this cell rebinds `model` again.
model = model.to("cuda")

In [None]:
from gptq.datautils import get_loaders

# Dataset names to score; each one is forwarded to get_loaders.
datasets = ['wikitext2'] 
for dataset in datasets:
    # get_loaders returns a pair; only the second element (`testloader`) is evaluated
    # here, the first (`dataloader`) is not used in this cell.
    dataloader, testloader = get_loaders(
        dataset, seed=0, model="meta-llama/Meta-Llama-3.1-8B", seqlen=8192
    )
    # `ppl` is rebound on every iteration, so only the last dataset's value survives.
    ppl = llama_eval(model, testloader)

In [None]:
# Accuracy on the multiple-choice suite. num_fewshots=1 is get_zero_shots' sentinel
# for "keep each task's own few-shot setting", so result keys carry no '@N' suffix.
results = get_zero_shots(
    model,
    # Each task listed once ("winogrande" used to appear twice).
    task_list=("winogrande", "arc_easy", "piqa", "hellaswag", "arc_challenge"),
    num_fewshots=1,
)

In [None]:
# Display the metric dict produced by the preceding get_zero_shots call.
results

In [None]:
# MMLU with num_fewshots=5: get_zero_shots writes 5 into each task's few-shot setting
# and suffixes every result key with '@5'. Rebinds `results` from the earlier cell.
results = get_zero_shots(model, task_list=("mmlu",), num_fewshots=5)

In [None]:
# Display the metric dict currently bound to `results` (last get_zero_shots call).
results