<a href="https://colab.research.google.com/github/EvolventaAGG/text-generation-webui/blob/main/llama4int_homebrew.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

First, install the CUDA extensions.

In [None]:
# Fetch the GPTQ-for-LLaMa sources, switch into the checkout, and run its
# CUDA setup script to build and install the extension.
!git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa.git
%cd 'GPTQ-for-LLaMa'
!python setup_cuda.py install
# Optional kernel test, left disabled.
#!python test_kernel.py

Next, restart the runtime (but don't delete it). We'll need to do that in order for colab to be able to use the quant_cuda CPP extensions.

Afterward, return to this cell and execute it to clone the repo, install libraries and download your 4-bit LLaMA model of choice.

In [None]:
import sys
import torch
import quant_cuda

# Python dependencies for tokenization and model code.
!pip install transformers
!pip install sentencepiece
# Colab form fields: URL of the 4-bit checkpoint to download and the model
# size. `num_params` is later used to build both the Hugging Face repo name
# and the local checkpoint file name, so it must match the downloaded file.
weights_url = 'https://huggingface.co/decapoda-research/llama-13b-hf-int4/resolve/main/llama-13b-4bit.pt' #@param {type:"string"}
num_params = "13b" #@param ["7b", "13b", "30b", "65b"]
!wget {weights_url}
# Install transformers from the `llama_push` branch of the zphang fork
# (presumably the source of the LLaMA* classes imported later -- confirm).
!pip install git+https://github.com/zphang/transformers@llama_push
# Put the cloned GPTQ-for-LLaMa checkout first on the import path so its
# modules can be imported by the following cells.
sys.path.insert(0, '/content/GPTQ-for-LLaMa/')
# Example command-line inference invocation, left disabled.
#!CUDA_VISIBLE_DEVICES=0 python llama_inference.py decapoda-research/llama-13b-hf --wbits 4 --load llama-13b-4bit.pt --text "It was the best of times, it was the worst of times"

Now execute this cell to load the model. You can also set the context size (on the free tier with the 13B model, keep this fairly low, or you may either run out of memory or get ridiculously slow generation times) and a flag controlling whether the checkpoint is first loaded into GPU VRAM and split into two files before being loaded into the model (also needed for 13B on the free tier).

In [None]:
import time

import torch
import torch.nn as nn

from gptq import *
from modelutils import *
from quant import *

from transformers import AutoTokenizer

# Device the model and the prompt tensors are moved to.
DEV = torch.device('cuda:0')
# Colab form field: stored on the model as `model.seqlen` by load_quant and
# used by the GUI cell to bound the prompt length.
context_size = 1024 #@param {type:"number"}
# Colab form field: when true, load_quant splits the checkpoint into two
# files and loads them one after the other.
split_checkpoint = True #@param {type:"boolean"}

def load_quant(model, checkpoint, wbits):
    """Build a quantized LLaMA model and load its weights from `checkpoint`.

    Args:
        model: Name or path passed to ``LLaMAConfig.from_pretrained``.
        checkpoint: Path of the quantized state-dict file. When the global
            ``split_checkpoint`` is true, two files named ``checkpoint + '0'``
            and ``checkpoint + '1'`` are written next to it.
        wbits: Bit width forwarded to ``make_quant``.

    Returns:
        The model in eval mode, with ``seqlen`` set to the global
        ``context_size``.

    Side effects:
        Replaces ``torch.nn.init.kaiming_uniform_``, ``uniform_`` and
        ``normal_`` with no-ops and sets
        ``transformers.modeling_utils._init_weights`` to ``False``; neither
        is restored afterwards.
    """
    # Imported explicitly instead of relying on the name leaking in through
    # one of the notebook's star imports.
    import transformers
    from transformers import LLaMAConfig, LLaMAForCausalLM

    config = LLaMAConfig.from_pretrained(model)

    def noop(*args, **kwargs):
        pass

    # Random weight initialisation is pointless here because every weight is
    # overwritten by the checkpoint, so the initialisers are stubbed out.
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop

    if split_checkpoint:
        print('Splitting checkpoint ...')
        # Load onto the GPU and write the state dict back out as two halves
        # so that each half can later be loaded on its own.
        ckpt = torch.load(checkpoint, map_location='cuda')
        midpoint = len(ckpt) // 2

        first_half = dict(list(ckpt.items())[:midpoint])
        torch.save(first_half, checkpoint + '0')
        del first_half

        second_half = dict(list(ckpt.items())[midpoint:])
        torch.save(second_half, checkpoint + '1')
        del second_half

        del ckpt

    transformers.modeling_utils._init_weights = False
    # Construct the model with half precision as the default dtype, and always
    # switch the default back, even if construction fails.
    torch.set_default_dtype(torch.half)
    try:
        model = LLaMAForCausalLM(config)
    finally:
        torch.set_default_dtype(torch.float)
    model = model.eval()

    # Quantize every layer reported by find_layers except the output head.
    layers = find_layers(model)
    layers.pop('lm_head', None)
    make_quant(model, layers, wbits)

    print('Loading model ...')
    if split_checkpoint:
        for part in range(2):
            ckpt = torch.load(checkpoint + str(part))
            # Each file holds only half of the keys, hence strict=False.
            model.load_state_dict(ckpt, strict=False)
            del ckpt
    else:
        # Load the checkpoint once; a second torch.load would keep two full
        # copies of the weights in memory at the same time.
        model.load_state_dict(torch.load(checkpoint))
    print('Done.')

    model.seqlen = context_size
    return model

# Build the quantized model for the selected size. The repo name and the
# local checkpoint file name are both derived from `num_params`; 4 is the
# weight bit width passed to load_quant.
model = load_quant('decapoda-research/llama-{}-hf'.format(num_params), 'llama-{}-4bit.pt'.format(num_params), 4).cuda()
# The model was already moved with .cuda() above; this pins it to DEV.
model.to(DEV)
tokenizer = AutoTokenizer.from_pretrained('decapoda-research/llama-{}-hf'.format(num_params))

Main GUI.

In [None]:
import ipywidgets as widgets
from IPython.display import display
import time

# Colab form fields: sampling settings read by generate().
min_gen_len = 80 #@param {type:"number"}
max_gen_len = 160 #@param {type:"number"}
temperature = 1.2 #@param {type:"number"}
top_p = 0.9 #@param {type:"number"}
repetition_penalty = 1.1 #@param {type:"number"}

# Widgets: one large text area plus a column of control buttons.
input_text_area = widgets.Textarea(placeholder='Enter a prompt...',
                                   layout=widgets.Layout(width='1200px',
                                                         height='600px'))
send_button = widgets.Button(description='Send')
undo_button = widgets.Button(description='Undo')
redo_button = widgets.Button(description='Redo')
retry_button = widgets.Button(description='Retry')
memory_button = widgets.ToggleButton(description='Memory')

hbox = widgets.HBox([input_text_area,
                     widgets.VBox([send_button, undo_button, redo_button,
                                  retry_button, memory_button])])
output = widgets.Output()

# Nothing has been generated yet, so there is nothing to undo/redo/retry.
undo_button.disabled = True
redo_button.disabled = True
retry_button.disabled = True

# Shared GUI state, mutated by the callbacks below.
# listen_for_updates: when True, a change of the text area resets the history.
listen_for_updates = False
# cur_outputs: generated texts in order; cur_outputs_idx: index of the last
# generation currently present in the text area (-1 when there is none).
cur_outputs = []
cur_outputs_idx = -1
# memory_text: text placed before the prompt by generate().
memory_text = ''
# input_text: prompt text stashed while the text area shows the memory.
input_text = ''

def generate():
    """Sample a continuation of the current prompt and return only the new text.

    The prompt is the memory text (followed by a newline) plus the tail of
    the text area content. Together they are cut down to at most
    ``model.seqlen - 1 - max_gen_len`` tokens: memory takes priority, and the
    text area gets whatever token budget is left.

    Returns:
        str: the decoded model output with the decoded prompt removed from
        its start (empty if the output is not longer than the prompt).
    """
    if memory_text:
        mem_tokenized = tokenizer.encode(memory_text + '\n', return_tensors='pt')[0].tolist()
    else:
        mem_tokenized = []

    inp_tokenized = tokenizer.encode(input_text_area.value, return_tensors='pt')[0].tolist()

    # Number of prompt tokens that still leaves room for max_gen_len new ones.
    prompt_budget = max(model.seqlen - 1 - max_gen_len, 0)

    # A slice like xs[-0:] returns the whole list, so a zero budget must be
    # handled explicitly instead of through negative slicing.
    if prompt_budget > 0:
        mem_tokenized = mem_tokenized[-prompt_budget:]
    else:
        mem_tokenized = []

    num_inp_tokens = prompt_budget - len(mem_tokenized)
    if num_inp_tokens > 0:
        tokenized = mem_tokenized + inp_tokenized[-num_inp_tokens:]
    else:
        tokenized = mem_tokenized

    # Round-trip through text: the decoded prompt is what gets stripped from
    # the decoded output below, and the model is fed its re-encoding.
    detokenized = tokenizer.decode(tokenized)
    retokenized = tokenizer.encode(detokenized, return_tensors='pt').to(DEV)
    prev_num_tokens = len(retokenized[0])

    with torch.no_grad():
        output_tokenized = model.generate(retokenized,
                                          do_sample=True,
                                          min_length=min_gen_len+prev_num_tokens,
                                          max_length=max_gen_len+prev_num_tokens,
                                          top_p=top_p,
                                          temperature=temperature,
                                          repetition_penalty=repetition_penalty)[0].tolist()

    # Named so it does not shadow the module-level `output` widget.
    generated_text = tokenizer.decode(output_tokenized)

    # Slice from the front: a negative-index slice would return the whole
    # text when nothing new was produced.
    return generated_text[len(detokenized):]

def on_update_input_text_area(change):
    """Forget the generation history when the text area value changes.

    Observer for the text area's ``value``. It does nothing while
    ``listen_for_updates`` is false, which is how the other callbacks
    suppress it during their own programmatic edits.
    """
    global cur_outputs, cur_outputs_idx

    if not listen_for_updates:
        return

    cur_outputs, cur_outputs_idx = [], -1
    for button in (undo_button, redo_button, retry_button):
        button.disabled = True

def send():
    """Generate a continuation, append it to the text area and record it.

    The text area and the buttons handled here are disabled while the model
    runs. They are re-enabled in a ``finally`` block so that a failing
    ``generate()`` call cannot leave the GUI locked; the exception itself
    still propagates.
    """
    global listen_for_updates, cur_outputs, cur_outputs_idx

    input_text_area.disabled = True
    memory_button.disabled = True
    redo_button.disabled = True
    undo_button.disabled = True
    retry_button.disabled = True
    # Our own edit of the text area must not reset the history.
    listen_for_updates = False

    try:
        generation = generate()
        input_text_area.value += generation
        # A new generation discards any entries that could have been redone.
        cur_outputs_idx += 1
        cur_outputs = cur_outputs[:cur_outputs_idx]
        cur_outputs.append(generation)
    finally:
        # Derive the button states from the history so they are right both
        # after a successful generation and after a failure.
        has_history = cur_outputs_idx >= 0
        undo_button.disabled = not has_history
        retry_button.disabled = not has_history
        redo_button.disabled = cur_outputs_idx >= len(cur_outputs) - 1
        listen_for_updates = True
        memory_button.disabled = False
        input_text_area.disabled = False

def undo():
    """Remove the most recently applied generation from the text area.

    The generation stays in ``cur_outputs`` so that ``redo`` can re-apply it.
    Does nothing when there is no applied generation.
    """
    global listen_for_updates, cur_outputs, cur_outputs_idx

    if cur_outputs_idx < 0:
        return

    # Our own edit of the text area must not reset the history.
    listen_for_updates = False
    num_chars = len(cur_outputs[cur_outputs_idx])
    current_text = input_text_area.value
    # Cut by absolute length: text[:-0] would erase everything when the
    # generation being undone is empty.
    input_text_area.value = current_text[:max(len(current_text) - num_chars, 0)]
    cur_outputs_idx -= 1

    if cur_outputs_idx == -1:
        undo_button.disabled = True
        retry_button.disabled = True
    if len(cur_outputs) > 0:
        redo_button.disabled = False

    listen_for_updates = True

def redo():
    """Re-append the generation that follows the current history position."""
    global listen_for_updates, cur_outputs, cur_outputs_idx

    # Our own edit of the text area must not reset the history.
    listen_for_updates = False

    next_idx = cur_outputs_idx + 1
    input_text_area.value += cur_outputs[next_idx]
    cur_outputs_idx = next_idx

    reached_newest = cur_outputs_idx == len(cur_outputs) - 1
    if reached_newest:
        redo_button.disabled = True
    if cur_outputs:
        undo_button.disabled = False
        retry_button.disabled = False

    listen_for_updates = True

def send_button_clicked(b):
    """Click handler for the Send button; `b` (the clicked widget) is unused."""
    send()

def undo_button_clicked(b):
    """Click handler for the Undo button; `b` (the clicked widget) is unused."""
    undo()

def redo_button_clicked(b):
    """Click handler for the Redo button; `b` (the clicked widget) is unused."""
    redo()

def retry_button_clicked(b):
    """Click handler for the Retry button: drop the last generation, then generate again."""
    undo()
    send()

def memory_button_clicked(b):
    """Switch the text area between the prompt and the memory text.

    Observer for the Memory toggle. Toggling on stashes the prompt and shows
    the memory for editing; toggling off stores the edited memory and
    restores the prompt and the button states.
    """
    global listen_for_updates, cur_outputs, cur_outputs_idx, memory_text, \
           input_text

    if not memory_button.value:
        # Leaving memory mode: keep the edited memory and bring the prompt back.
        memory_text = input_text_area.value
        input_text_area.value = input_text
        input_text = ''
        send_button.disabled = False
        nothing_applied = cur_outputs_idx < 0
        undo_button.disabled = nothing_applied
        redo_button.disabled = cur_outputs_idx >= len(cur_outputs) - 1
        retry_button.disabled = nothing_applied
        listen_for_updates = True
        return

    # Entering memory mode: edits to the memory must not reset the history,
    # and no generation or history action is allowed meanwhile.
    listen_for_updates = False
    for button in (send_button, undo_button, redo_button, retry_button):
        button.disabled = True
    input_text = input_text_area.value
    input_text_area.value = memory_text

# Hook the callbacks up to the widgets. Buttons use click handlers; the
# toggle and the text area are observed for changes of their `value`.
send_button.on_click(send_button_clicked)
undo_button.on_click(undo_button_clicked)
redo_button.on_click(redo_button_clicked)
retry_button.on_click(retry_button_clicked)
memory_button.observe(memory_button_clicked, names='value')
input_text_area.observe(on_update_input_text_area, names='value')

# Show the GUI followed by the output widget.
display(hbox, output)