# Mixtral in Colab

Welcome! In this notebook you can run [Mixtral-8x7B-Instruct](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) with decent generation speed **right in Google Colab or on a consumer-grade GPU**. This is made possible by quantizing the original model in mixed precision and by using a MoE-specific offloading strategy.

To learn more, read our [tech report](https://arxiv.org/abs/2312.17238) or check out the [repo](https://github.com/dvmazur/mixtral-offloading) on GitHub.
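To make the quantization idea concrete, here is a minimal sketch of group-wise quantization in plain Python. It assumes simple min/max affine quantization; it is not the HQQ algorithm this notebook relies on, which optimizes the quantization parameters rather than taking them directly from the min and max. The function names are hypothetical. Only `nbits` and `group_size` mirror parameter names that appear in the quantization configs later in the notebook.

```python
def quantize_group(values, nbits):
    """Quantize one group to integer codes in [0, 2**nbits - 1] (min/max affine scheme)."""
    levels = 2 ** nbits - 1
    lo, hi = min(values), max(values)
    # A constant group has no range; any non-zero scale works because all codes are 0.
    scale = (hi - lo) / levels if hi > lo else 1.0
    codes = [round((v - lo) / scale) for v in values]
    return codes, scale, lo


def dequantize_group(codes, scale, zero):
    """Map integer codes back to approximate floating-point values."""
    return [c * scale + zero for c in codes]


def quantize(weights, nbits, group_size):
    """Split weights into groups; each group stores its own scale and zero point."""
    groups = [weights[i:i + group_size] for i in range(0, len(weights), group_size)]
    return [quantize_group(group, nbits) for group in groups]


weights = [0.0, 0.1, 0.2, 0.3, -1.0, 0.0, 1.0, 2.0]
packed = quantize(weights, nbits=2, group_size=4)
restored = [w for codes, scale, zero in packed
            for w in dequantize_group(codes, scale, zero)]
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(f"groups: {len(packed)}, max reconstruction error: {max_error:.3g}")
```

Smaller groups track the local range of the weights more closely but store more scales and zero points, so `group_size` trades accuracy against metadata overhead.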

You will need approximately 16 GB of VRAM and 11 GB of RAM to run this notebook and generate moderately long texts.


<details>

<summary>How to balance between RAM and GPU VRAM usage</summary>

You can balance RAM and GPU VRAM usage by changing the <code>offload_per_layer</code> variable in the <a href="#scrollTo=_mIpePTMFyRY&line=10&uniqifier=1">Initialize model</a> section. Increasing <code>offload_per_layer</code> decreases GPU VRAM usage, increases RAM usage, and slows down generation. Decreasing <code>offload_per_layer</code> has the opposite effect.

Note that this notebook should run normally in Google Colab with <code>offload_per_layer = 4</code>, but may crash with other values. If you run it somewhere else, you are free to experiment with this variable.
</details>
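As a rough illustration of this trade-off, the sketch below reuses the two formulas that the Initialize model section passes to `OffloadConfig` (`main_size` and `offload_size`). The helper name `expert_split` is hypothetical. The values of 32 layers and 8 experts per layer are assumed here for Mixtral-8x7B; the notebook itself reads them from the model config.

```python
def expert_split(num_layers, num_experts, offload_per_layer):
    """Return (main_size, offload_size) using the same formulas as the OffloadConfig call."""
    if not 0 <= offload_per_layer <= num_experts:
        raise ValueError("offload_per_layer must be between 0 and num_experts")
    main_size = num_layers * (num_experts - offload_per_layer)  # experts kept on the GPU
    offload_size = num_layers * offload_per_layer               # experts kept in RAM
    return main_size, offload_size


# Assumed Mixtral-8x7B shape: 32 layers, 8 experts per layer.
for offload_per_layer in (4, 5, 7):
    on_gpu, in_ram = expert_split(32, 8, offload_per_layer)
    print(f"offload_per_layer={offload_per_layer}: {on_gpu} experts on GPU, {in_ram} in RAM")
```

The total number of experts stays at 256; the variable only moves experts between the two memory pools.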

## Install and import libraries

In [2]:
# fix numpy in colab
import numpy
from IPython.display import clear_output

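# fix triton in colab
# (each `!` command runs in its own subshell, so `!export` would not persist; set the variables in-process)
import os

os.environ["LC_ALL"] = "en_US.UTF-8"
os.environ["LD_LIBRARY_PATH"] = "/usr/lib64-nvidia"
os.environ["LIBRARY_PATH"] = "/usr/local/cuda/lib64/stubs"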
!ldconfig /usr/lib64-nvidia

# !git clone https://github.com/dvmazur/mixtral-offloading.git --quiet
# !cd ../mixtral-offloading && pip install -q -r requirements.txt
# !huggingface-cli download lavawolfiee/Mixtral-8x7B-Instruct-v0.1-offloading-demo --quiet --local-dir Mixtral-8x7B-Instruct-v0.1-offloading-demo
# !huggingface-cli download mistralai/Mixtral-8x7B-Instruct-v0.1 --quiet --local-dir Mixtral-8x7B-Instruct-v0.1

clear_output()

In [3]:
import sys

sys.path.append("mixtral-offloading")
import torch
from torch.nn import functional as F
from hqq.core.quantize import BaseQuantizeConfig
from huggingface_hub import snapshot_download
from IPython.display import clear_output
from tqdm.auto import trange
from transformers import AutoConfig, AutoTokenizer
from transformers.utils import logging as hf_logging

from src.build_model import OffloadConfig, QuantConfig, build_model

hqq_aten package not installed. HQQBackend.ATEN backend will not work unless you install the hqq_aten lib in hqq/kernels.




In [1]:
%load_ext autoreload
%autoreload 2

## Initialize model

In [4]:
quantized = False

if not quantized:
    state_path = "Mixtral-8x7B-Instruct-v0.1"
    model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
else:
    state_path = "Mixtral-8x7B-Instruct-v0.1-offloading-demo"
    model_name = "lavawolfiee/Mixtral-8x7B-Instruct-v0.1-offloading-demo"

config = AutoConfig.from_pretrained(model_name)

device = torch.device("cuda:0")

##### Change this to 5 if you have only 12 GB of GPU VRAM #####
# offload_per_layer = 4
offload_per_layer = 7
###############################################################

num_experts = config.num_local_experts

offload_config = OffloadConfig(
    main_size=config.num_hidden_layers * (num_experts - offload_per_layer),
    offload_size=config.num_hidden_layers * offload_per_layer,
    buffer_size=4,
    offload_per_layer=offload_per_layer,
)


attn_config = BaseQuantizeConfig(
    nbits=4,
    group_size=64,
    quant_zero=True,
    quant_scale=True,
)
attn_config["scale_quant_params"]["group_size"] = 256


ffn_config = BaseQuantizeConfig(
    nbits=2,
    group_size=16,
    quant_zero=True,
    quant_scale=True,
)

if quantized:
    quant_config = QuantConfig(ffn_config=ffn_config, attn_config=attn_config)
else:
    quant_config = None


model = build_model(
    device=device,
    quant_config=quant_config,
    offload_config=offload_config,
    state_path=state_path,
)



352321536
ExpertCache: loading main module 0 on GPU...
352321536
ExpertCache: loading main module 1 on GPU...
[... the same pair of lines repeats for main modules 2 through 17 ...]

Loading experts:   0%|          | 0/32 [00:00<?, ?it/s]
[... 352321536 is printed repeatedly between progress updates ...]
Loading experts: 100%|██████████| 32/32 [12:25<00:00, 23.30s/it]

352321536
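The notebook does not say what the repeated value `352321536` in this output means. It is consistent with the size in bytes of one unquantized expert, as the arithmetic sketch below shows. The sketch assumes the published Mixtral-8x7B dimensions (hidden size 4096, intermediate size 14336, three projection matrices per expert) and 2 bytes per parameter. The constant and function names are hypothetical.

```python
HIDDEN_SIZE = 4096           # assumed Mixtral-8x7B hidden size
INTERMEDIATE_SIZE = 14336    # assumed Mixtral-8x7B expert FFN width
PROJECTIONS_PER_EXPERT = 3   # gate, up, and down projections
BYTES_PER_PARAM = 2          # 16-bit weights


def expert_bytes(hidden=HIDDEN_SIZE, intermediate=INTERMEDIATE_SIZE,
                 projections=PROJECTIONS_PER_EXPERT, bytes_per_param=BYTES_PER_PARAM):
    """Size in bytes of one expert's weight matrices."""
    return projections * hidden * intermediate * bytes_per_param


one_expert = expert_bytes()
all_experts = 32 * 8 * one_expert  # assumed 32 layers with 8 experts each

print(one_expert)                          # 352321536
print(one_expert / 2**20, "MiB per expert")
print(all_experts / 2**30, "GiB for all experts in 16-bit")
```

Under these assumptions, the experts alone take 84 GiB in 16-bit precision, which is why the unquantized model loads slowly and why both quantization and offloading are needed on consumer hardware.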





## Run the model

In [5]:
from transformers import TextStreamer


tokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
past_key_values = None
sequence = None

seq_len = 0
while True:
  print("User: ", end="")
  user_input = input()
  print("\n")

  user_entry = dict(role="user", content=user_input)
  input_ids = tokenizer.apply_chat_template([user_entry], return_tensors="pt").to(device)

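  if past_key_values is None:
    # First turn: there is no KV cache yet, so the mask covers only the prompt.
    attention_mask = torch.ones_like(input_ids)
  else:
    # Later turns: add the new prompt length to the cached length,
    # which is read from the key tensor of layer 0.
    seq_len = input_ids.size(1) + past_key_values[0][0][0].size(1)
    attention_mask = torch.ones([1, seq_len - 1], dtype=torch.int, device=device)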

  print("Mixtral: ", end="")
  result = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    past_key_values=past_key_values,
    streamer=streamer,
    do_sample=True,
    temperature=0.9,
    top_p=0.9,
    max_new_tokens=512,
    pad_token_id=tokenizer.eos_token_id,
    return_dict_in_generate=True,
    output_hidden_states=True,
  )
  print("\n")

  sequence = result["sequences"]
  past_key_values = result["past_key_values"]

User: 

Mixtral: unionื obtainließResultцоimes励 PetDrawable老Metadataenguimesrmcolaatos엔 cinULE德ermanLongDrawablencia 

KeyboardInterrupt: 
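The `model.generate` call above samples with `temperature=0.9` and `top_p=0.9`. The following is a minimal plain-Python sketch of what temperature scaling and nucleus (top-p) filtering do to a vector of logits. It is an illustration of the general technique, not the `transformers` implementation, and the function names are hypothetical.

```python
import math
import random


def top_p_filter(logits, temperature, top_p):
    """Return {token_index: probability} for the smallest set of most likely
    tokens whose cumulative probability reaches top_p, renormalized to sum to 1."""
    scaled = [logit / temperature for logit in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]  # subtract the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]

    # Walk through tokens from most to least likely until top_p is covered.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for index in order:
        kept.append(index)
        cumulative += probs[index]
        if cumulative >= top_p:
            break

    norm = sum(probs[i] for i in kept)
    return {i: probs[i] / norm for i in kept}


def sample(logits, temperature=0.9, top_p=0.9, rng=random):
    """Draw one token index from the filtered distribution."""
    nucleus = top_p_filter(logits, temperature, top_p)
    tokens = list(nucleus)
    return rng.choices(tokens, weights=[nucleus[t] for t in tokens], k=1)[0]


logits = [2.0, 1.5, 0.5, -1.0, -3.0]
print(top_p_filter(logits, temperature=0.9, top_p=0.9))
print(sample(logits))
```

A lower temperature sharpens the distribution, and a lower `top_p` discards more of the unlikely tail, so both settings make the output more conservative.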