# Mixtral in Colab

Welcome! In this notebook you can run [Mixtral-8x7B-Instruct](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) with decent generation speed **right in Google Colab or on a consumer-grade GPU**. This is made possible by quantizing the original model in mixed precision and implementing a MoE-specific offloading strategy.

To learn more, read our [tech report](https://arxiv.org/abs/2312.17238) or check out the [repo](https://github.com/dvmazur/mixtral-offloading) on GitHub.

You will need approximately 16 GB of VRAM and 11 GB of RAM to run this notebook and generate moderately long texts.


<details>

<summary>How to balance between RAM and GPU VRAM usage</summary>

You can balance RAM and GPU VRAM usage by changing the <code>offload_per_layer</code> variable in the <a href="#scrollTo=_mIpePTMFyRY&line=10&uniqifier=1">Initialize model</a> section. Increasing <code>offload_per_layer</code> decreases GPU VRAM usage, increases RAM usage, and decreases generation speed. Decreasing it has the opposite effect.

Note that this notebook should run normally in Google Colab with <code>offload_per_layer = 4</code>, but may crash with other values. If you run it elsewhere, you are free to experiment with this variable.
</details>
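
The trade-off described above can be made concrete with a small sketch. It mirrors the `main_size` and `offload_size` arithmetic used later in the Initialize model cell. The defaults of 32 layers and 8 experts per layer are assumptions taken from the standard Mixtral-8x7B configuration, and `split_experts` is a hypothetical helper, not part of the repo.

```python
def split_experts(offload_per_layer, num_layers=32, num_experts=8):
    """Return (experts kept on GPU, experts offloaded to RAM) across all layers.

    num_layers and num_experts are assumed values for Mixtral-8x7B.
    """
    if not 0 <= offload_per_layer <= num_experts:
        raise ValueError("offload_per_layer must be between 0 and num_experts")
    main_size = num_layers * (num_experts - offload_per_layer)
    offload_size = num_layers * offload_per_layer
    return main_size, offload_size


# Compare a few settings: a larger offload_per_layer moves experts from GPU to RAM.
for k in (3, 4, 5):
    on_gpu, in_ram = split_experts(k)
    print(f"offload_per_layer={k}: {on_gpu} experts on GPU, {in_ram} in RAM")
```

Every step of `offload_per_layer` moves one expert per layer, so 32 experts in total under these assumptions, from GPU memory to host RAM.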

## Install and import libraries

In [1]:
# fix numpy in colab
import numpy
from IPython.display import clear_output

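# fix triton in colab
# (each `!export` runs in its own subshell and does not persist, so set the variables via os.environ)
import os

os.environ["LC_ALL"] = "en_US.UTF-8"
os.environ["LD_LIBRARY_PATH"] = "/usr/lib64-nvidia"
os.environ["LIBRARY_PATH"] = "/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia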

# !git clone https://github.com/dvmazur/mixtral-offloading.git --quiet
# !cd ../mixtral-offloading && pip install -q -r requirements.txt
# !huggingface-cli download lavawolfiee/Mixtral-8x7B-Instruct-v0.1-offloading-demo --quiet --local-dir Mixtral-8x7B-Instruct-v0.1-offloading-demo
# !huggingface-cli download mistralai/Mixtral-8x7B-Instruct-v0.1 --quiet --local-dir Mixtral-8x7B-Instruct-v0.1

clear_output()

In [2]:
import sys

sys.path.append("mixtral-offloading")
import torch
from torch.nn import functional as F
from hqq.core.quantize import BaseQuantizeConfig
from huggingface_hub import snapshot_download
from IPython.display import clear_output
from tqdm.auto import trange
from transformers import AutoConfig, AutoTokenizer
from transformers.utils import logging as hf_logging

from src.build_model import OffloadConfig, QuantConfig, build_model

hqq_aten package not installed. HQQBackend.ATEN backend will not work unless you install the hqq_aten lib in hqq/kernels.

In [3]:
%load_ext autoreload
%autoreload 2

## Initialize model

In [4]:
quantized = False

if not quantized:
    state_path = "Mixtral-8x7B-Instruct-v0.1"
    model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
else:
    state_path = "Mixtral-8x7B-Instruct-v0.1-offloading-demo"
    model_name = "lavawolfiee/Mixtral-8x7B-Instruct-v0.1-offloading-demo"

config = AutoConfig.from_pretrained(model_name)

device = torch.device("cuda:0")

##### Change this to 5 if you have only 12 GB of GPU VRAM #####
# offload_per_layer = 4
offload_per_layer = 7
###############################################################

num_experts = config.num_local_experts

offload_config = OffloadConfig(
    main_size=config.num_hidden_layers * (num_experts - offload_per_layer),
    offload_size=config.num_hidden_layers * offload_per_layer,
    buffer_size=4,
    offload_per_layer=offload_per_layer,
)


attn_config = BaseQuantizeConfig(
    nbits=4,
    group_size=64,
    quant_zero=True,
    quant_scale=True,
)
attn_config["scale_quant_params"]["group_size"] = 256


ffn_config = BaseQuantizeConfig(
    nbits=2,
    group_size=16,
    quant_zero=True,
    quant_scale=True,
)

if quantized:
    quant_config = QuantConfig(ffn_config=ffn_config, attn_config=attn_config)
else:
    quant_config = None


model = build_model(
    device=device,
    quant_config=quant_config,
    offload_config=offload_config,
    state_path=state_path,
)



352321536
loading main module 0...
[... the same pair of lines repeats through main module 27 ...]

Loading experts: 100%|██████████| 32/32 [18:18<00:00, 34.33s/it]
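
As a rough sanity check on the numbers above, the sketch below estimates expert memory for the unquantized path. The shapes (hidden size 4096, intermediate size 14336, 32 layers, 8 experts, three weight matrices per expert) are assumptions taken from the standard Mixtral-8x7B configuration. Reading the printed `352321536` as the byte size of one 16-bit expert is an inference from the arithmetic, not something the notebook states. Attention weights, embeddings, buffers, and the KV cache are ignored.

```python
# Assumed Mixtral-8x7B shapes; none of these values are read from the notebook.
HIDDEN_SIZE = 4096
INTERMEDIATE_SIZE = 14336
NUM_LAYERS = 32
NUM_EXPERTS = 8
MATRICES_PER_EXPERT = 3  # assumed: three projection matrices per expert
BYTES_PER_PARAM = 2      # assumed: 16-bit weights when quantized = False


def expert_bytes():
    """Bytes needed to store one expert's weights at 16-bit precision."""
    return MATRICES_PER_EXPERT * HIDDEN_SIZE * INTERMEDIATE_SIZE * BYTES_PER_PARAM


def memory_split_gib(offload_per_layer):
    """Return (GiB of experts on GPU, GiB of experts offloaded to RAM)."""
    gib = 1024 ** 3
    on_gpu = NUM_LAYERS * (NUM_EXPERTS - offload_per_layer) * expert_bytes()
    offloaded = NUM_LAYERS * offload_per_layer * expert_bytes()
    return on_gpu / gib, offloaded / gib


print(expert_bytes())
gpu_gib, ram_gib = memory_split_gib(7)
print(f"experts on GPU: {gpu_gib} GiB, experts in RAM: {ram_gib} GiB")
```

Under these assumptions, `offload_per_layer = 7` keeps one expert per layer on the GPU and places the remaining seven in host RAM, which is far more RAM than the quantized setup needs.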





## Run the model

In [5]:
from transformers import TextStreamer


tokenizer = AutoTokenizer.from_pretrained(model_name)
# Print generated tokens as they arrive, hiding the prompt and special tokens
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
past_key_values = None  # KV cache carried over between chat turns
sequence = None

seq_len = 0
while True:
    print("User: ", end="")
    user_input = input()
    print("\n")

    # Format the new user message with the model's chat template
    user_entry = dict(role="user", content=user_input)
    input_ids = tokenizer.apply_chat_template([user_entry], return_tensors="pt").to(device)

    if past_key_values is None:
        # First turn: attend to every prompt token
        attention_mask = torch.ones_like(input_ids)
    else:
        # Later turns: size the mask from the new tokens plus the cached sequence length
        seq_len = input_ids.size(1) + past_key_values[0][0][0].size(1)
        attention_mask = torch.ones([1, seq_len - 1], dtype=torch.int, device=device)

    print("Mixtral: ", end="")
    result = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        streamer=streamer,
        do_sample=True,
        temperature=0.9,
        top_p=0.9,
        max_new_tokens=512,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
        output_hidden_states=True,
    )
    print("\n")

    # Keep the generated sequence and the updated KV cache for the next turn
    sequence = result["sequences"]
    past_key_values = result["past_key_values"]


User: 

Mixtral: 

RuntimeError: selected index k out of range
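
The `generate` call above samples with `temperature=0.9` and `top_p=0.9`. The following is a minimal pure-Python sketch of what those two settings do to a vector of logits. It is a simplified illustration under stated assumptions, not the `transformers` implementation, and `nucleus_probs` is a hypothetical helper name.

```python
import math


def nucleus_probs(logits, temperature=0.9, top_p=0.9):
    """Return {token_index: probability} for tokens kept by top-p filtering."""
    # Temperature scaling: values below 1 sharpen the distribution.
    scaled = [x / temperature for x in logits]

    # Numerically stable softmax.
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the most probable tokens until their cumulative mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    # Renormalize over the kept tokens so the probabilities sum to 1.
    kept_total = sum(probs[i] for i in kept)
    return {i: probs[i] / kept_total for i in kept}


# Toy logits for a four-token vocabulary.
print(nucleus_probs([2.0, 1.0, 0.0, -1.0], temperature=1.0, top_p=0.9))
```

The next token is then drawn from the renormalized distribution, so low-probability tail tokens outside the nucleus are never sampled.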