# Example of loading various LoRA PEFT models

This notebook showcases an experimental form of extreme bit quantization. It simulates right bit-shifting of the weights, which shows a lot of potential for hardware-centric TPUs and other embedded AI accelerators.

![Current bit-shifted quantization](https://github.com/Tfloow/peft/blob/main/examples/lora_mixed/fig/image.png?raw=1)

This approach could reduce the model's RAM usage, but it is designed first and foremost to reduce memory traffic between DRAM and the cache. Bit-shifting is inexpensive in hardware, while data movement is the major bottleneck during inference. To reduce memory traffic, we transfer only part of each weight, together with small LoRA weights, and reconstruct the weights on the fly.
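A minimal sketch of the idea, assuming unsigned 4-bit weight codes: dropping the `s` least-significant bits with a right shift (the part that would not be transferred) and shifting back is equivalent to AND-ing the code with a mask that clears those bits, which is how this notebook simulates it in software. The helper names `shift_reconstruct` and `mask_reconstruct` are hypothetical and only illustrate the arithmetic.

```python
def shift_reconstruct(q: int, dropped_bits: int) -> int:
    """Right-shift away the low bits (what is not transferred), then shift back."""
    return (q >> dropped_bits) << dropped_bits


def mask_reconstruct(q: int, kept_bits: int, width: int = 4) -> int:
    """Keep only the `kept_bits` most significant bits of a `width`-bit code."""
    mask = ((1 << kept_bits) - 1) << (width - kept_bits)
    return q & mask


# For every 4-bit code and every kept precision (0 to 4 bits),
# the shift-based and mask-based formulations agree.
for kept in range(5):
    for q in range(16):
        assert shift_reconstruct(q, 4 - kept) == mask_reconstruct(q, kept)

print(mask_reconstruct(0b1011, 2))  # prints 8 (0b1000)
```

Keeping 1, 2 or 3 bits corresponds to the per-nibble masks `0b1000`, `0b1100` and `0b1110`; only those high bits would need to cross the memory bus.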

For this experiment, `PeftMixedModel` is used to apply different LoRA adapters to different groups of shifted modules. The base model is Llama-3.1 8B quantized with [AWQ](https://arxiv.org/abs/2306.00978) [Lin et al. 2023], made possible thanks to [#2914](https://github.com/huggingface/peft/issues/2914#issuecomment-3547905030).

## How to run the notebook?

Run it locally or use a free T4 GPU on [Google Colab](https://colab.research.google.com/drive/1tMrCepfomyzb0_WH0_xBtyVRRXLkRcU3?usp=sharing)!

### Last tested requirements
<details>
<summary>Tested with those specifications</summary>
<pre>
absl-py==1.4.0
absolufy-imports==0.3.1
accelerate==1.11.0
aiofiles==24.1.0
aiohappyeyeballs==2.6.1
aiohttp==3.13.2
aiosignal==1.4.0
alabaster==1.0.0
albucore==0.0.24
albumentations==2.0.8
ale-py==0.11.2
alembic==1.17.1
altair==5.5.0
annotated-doc==0.0.4
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.11.0
anywidget==0.9.19
argon2-cffi==25.1.0
argon2-cffi-bindings==25.1.0
array_record==0.8.2
arrow==1.4.0
arviz==0.22.0
astropy==7.1.1
astropy-iers-data==0.2025.11.10.0.38.31
astunparse==1.6.3
atpublic==5.1
attrs==25.4.0
audioread==3.1.0
Authlib==1.6.5
autoawq==0.2.9
autograd==1.8.0
babel==2.17.0
backcall==0.2.0
beartype==0.22.5
beautifulsoup4==4.13.5
betterproto==2.0.0b6
bigframes==2.28.0
bigquery-magics==0.10.3
bitsandbytes==0.48.2
bleach==6.3.0
blinker==1.9.0
blis==1.3.0
blobfile==3.1.0
blosc2==3.11.0
bokeh==3.7.3
Bottleneck==1.4.2
bqplot==0.12.45
branca==0.8.2
brotli==1.2.0
build==1.3.0
CacheControl==0.14.3
cachetools==5.5.2
catalogue==2.0.10
certifi==2025.10.5
cffi==2.0.0
chardet==5.2.0
charset-normalizer==3.4.4
chex==0.1.90
clarabel==0.11.1
click==8.3.0
cloudpathlib==0.23.0
cloudpickle==3.1.2
cmake==3.31.6
cmdstanpy==1.3.0
colorcet==3.1.0
colorlover==0.3.0
colour==0.1.5
community==1.0.0b1
confection==0.1.5
cons==0.4.7
contourpy==1.3.3
cramjam==2.11.0
cryptography==43.0.3
cuda-bindings==12.9.4
cuda-core==0.3.2
cuda-pathfinder==1.3.2
cuda-python==12.9.4
cuda-toolkit==12.9.1
cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-25.10.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
cudf-polars-cu12==25.10.0
cufflinks==0.17.3
cuml-cu12==25.10.0
cupy-cuda12x==13.6.0
curl_cffi==0.13.0
cvxopt==1.3.2
cvxpy==1.6.7
cycler==0.12.1
cyipopt==1.5.0
cymem==2.0.11
Cython==3.0.12
dask==2025.9.1
dask-cuda==25.10.0
dask-cudf-cu12==25.10.0
dataproc-spark-connect==0.8.3
datasets==4.0.0
db-dtypes==1.4.4
dbus-python==1.2.18
debugpy==1.8.15
decorator==4.4.2
defusedxml==0.7.1
diffusers==0.35.2
dill==0.3.8
distributed==2025.9.1
distributed-ucxx-cu12==0.46.0
distro==1.9.0
dlib==19.24.6
dm-tree==0.1.9
docstring_parser==0.17.0
docutils==0.21.2
dopamine_rl==4.1.2
duckdb==1.3.2
earthengine-api==1.5.24
easydict==1.13
editdistance==0.8.1
eerepr==0.1.2
einops==0.8.1
en_core_web_sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl#sha256=1932429db727d4bff3deed6b34cfc05df17794f4a52eeb26cf8928f7c1a0fb85
entrypoints==0.4
et_xmlfile==2.0.0
etils==1.13.0
etuples==0.3.10
evaluate==0.4.6
Farama-Notifications==0.0.4
fastai==2.8.5
fastapi==0.121.1
fastcore==1.8.16
fastdownload==0.0.7
fastjsonschema==2.21.2
fastprogress==1.0.3
fastrlock==0.8.3
fasttransform==0.0.2
ffmpy==1.0.0
filelock==3.20.0
firebase-admin==6.9.0
Flask==3.1.2
flatbuffers==25.9.23
flax==0.10.7
folium==0.20.0
fonttools==4.60.1
fqdn==1.5.1
frozendict==2.4.6
frozenlist==1.8.0
fsspec==2025.3.0
future==1.0.0
gast==0.6.0
gcsfs==2025.3.0
GDAL==3.8.4
gdown==5.2.0
geemap==0.35.3
geocoder==1.38.1
geographiclib==2.1
geopandas==1.1.1
geopy==2.4.1
gin-config==0.5.0
gitdb==4.0.12
GitPython==3.1.45
glob2==0.7
google==2.0.3
google-adk==1.17.0
google-ai-generativelanguage==0.6.15
google-api-core==2.28.1
google-api-python-client==2.187.0
google-auth==2.38.0
google-auth-httplib2==0.2.1
google-auth-oauthlib==1.2.3
google-cloud-aiplatform==1.126.1
google-cloud-appengine-logging==1.7.0
google-cloud-audit-log==0.4.0
google-cloud-bigquery==3.38.0
google-cloud-bigquery-connection==1.19.0
google-cloud-bigquery-storage==2.34.0
google-cloud-bigtable==2.34.0
google-cloud-core==2.5.0
google-cloud-dataproc==5.23.0
google-cloud-datastore==2.21.0
google-cloud-discoveryengine==0.13.12
google-cloud-firestore==2.21.0
google-cloud-functions==1.21.0
google-cloud-language==2.18.0
google-cloud-logging==3.12.1
google-cloud-monitoring==2.28.0
google-cloud-resource-manager==1.15.0
google-cloud-secret-manager==2.25.0
google-cloud-spanner==3.59.0
google-cloud-speech==2.34.0
google-cloud-storage==2.19.0
google-cloud-trace==1.17.0
google-cloud-translate==3.23.0
google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz
google-crc32c==1.7.1
google-genai==1.49.0
google-generativeai==0.8.5
google-pasta==0.2.0
google-resumable-media==2.7.2
googleapis-common-protos==1.72.0
googledrivedownloader==1.1.0
gradio==5.49.1
gradio_client==1.13.3
graphviz==0.21
greenlet==3.2.4
groovy==0.1.2
grpc-google-iam-v1==0.14.3
grpc-interceptor==0.15.4
grpcio==1.76.0
grpcio-status==1.71.2
grpclib==0.4.8
gspread==6.2.1
gspread-dataframe==4.0.0
gym==0.25.2
gym-notices==0.1.0
gymnasium==1.2.2
h11==0.16.0
h2==4.3.0
h5netcdf==1.7.3
h5py==3.15.1
hdbscan==0.8.40
hf-xet==1.2.0
hf_transfer==0.1.9
highspy==1.12.0
holidays==0.84
holoviews==1.22.0
hpack==4.1.0
html5lib==1.1
httpcore==1.0.9
httpimport==1.4.1
httplib2==0.31.0
httpx==0.28.1
httpx-sse==0.4.3
huggingface-hub==0.36.0
humanize==4.14.0
hyperframe==6.1.0
hyperopt==0.2.7
ibis-framework==9.5.0
idna==3.11
ImageIO==2.37.2
imageio-ffmpeg==0.6.0
imagesize==1.4.1
imbalanced-learn==0.14.0
immutabledict==4.2.2
importlib_metadata==8.7.0
importlib_resources==6.5.2
imutils==0.5.4
inflect==7.5.0
iniconfig==2.3.0
intel-cmplr-lib-ur==2025.3.1
intel-openmp==2025.3.1
ipyevents==2.0.4
ipyfilechooser==0.6.0
ipykernel==6.17.1
ipyleaflet==0.20.0
ipyparallel==8.8.0
ipython==7.34.0
ipython-genutils==0.2.0
ipython-sql==0.5.0
ipytree==0.2.2
ipywidgets==7.7.1
isoduration==20.11.0
itsdangerous==2.2.0
jaraco.classes==3.4.0
jaraco.context==6.0.1
jaraco.functools==4.3.0
jax==0.7.2
jax-cuda12-pjrt==0.7.2
jax-cuda12-plugin==0.7.2
jaxlib==0.7.2
jeepney==0.9.0
jieba==0.42.1
Jinja2==3.1.6
jiter==0.12.0
joblib==1.5.2
jsonpatch==1.33
jsonpickle==4.1.1
jsonpointer==3.0.0
jsonschema==4.25.1
jsonschema-specifications==2025.9.1
jupyter-console==6.6.3
jupyter-events==0.12.0
jupyter-leaflet==0.20.0
jupyter_client==7.4.9
jupyter_core==5.9.1
jupyter_kernel_gateway @ git+https://github.com/googlecolab/kernel_gateway@b134e9945df25c2dcb98ade9129399be10788671
jupyter_server==2.14.0
jupyter_server_terminals==0.5.3
jupyterlab_pygments==0.3.0
jupyterlab_widgets==3.0.16
jupytext==1.17.3
kaggle==1.7.4.5
kagglehub==0.3.13
keras==3.10.0
keras-hub==0.21.1
keras-nlp==0.21.1
keyring==25.6.0
keyrings.google-artifactregistry-auth==1.1.2
kiwisolver==1.4.9
langchain==0.3.27
langchain-core==0.3.79
langchain-text-splitters==0.3.11
langsmith==0.4.42
lark==1.3.1
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
lazy_loader==0.4
libclang==18.1.1
libcudf-cu12 @ https://pypi.nvidia.com/libcudf-cu12/libcudf_cu12-25.10.0-py3-none-manylinux_2_28_x86_64.whl
libcugraph-cu12==25.10.1
libcuml-cu12==25.10.0
libkvikio-cu12==25.10.0
libpysal==4.13.0
libraft-cu12==25.10.0
librmm-cu12==25.10.0
librosa==0.11.0
libucx-cu12==1.19.0
libucxx-cu12==0.46.0
lightgbm==4.6.0
linkify-it-py==2.0.3
llvmlite==0.43.0
locket==1.0.0
logical-unification==0.4.7
lxml==5.4.0
Mako==1.3.10
Markdown==3.10
markdown-it-py==4.0.0
MarkupSafe==3.0.3
matplotlib==3.10.0
matplotlib-inline==0.2.1
matplotlib-venn==1.1.2
mcp==1.21.0
mdit-py-plugins==0.5.0
mdurl==0.1.2
miniKanren==1.0.5
missingno==0.5.2
mistune==3.1.4
mizani==0.13.5
mkl==2025.2.0
ml_dtypes==0.5.3
mlxtend==0.23.4
more-itertools==10.8.0
moviepy==1.0.3
mpmath==1.3.0
msgpack==1.1.2
multidict==6.7.0
multipledispatch==1.0.0
multiprocess==0.70.16
multitasking==0.0.12
murmurhash==1.0.13
music21==9.3.0
namex==0.1.0
narwhals==2.11.0
natsort==8.4.0
nbclassic==1.3.3
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
ndindex==1.10.0
nest-asyncio==1.6.0
networkx==3.5
nibabel==5.3.2
nltk==3.9.1
notebook==6.5.7
notebook_shim==0.2.4
numba==0.60.0
numba-cuda==0.19.1
numexpr==2.14.1
numpy==2.0.2
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cccl-cu12==12.9.27
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvcc-cu12==12.9.86
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.10.2.21
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.7.1
nvidia-ml-py==13.580.82
nvidia-nccl-cu12==2.27.3
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvshmem-cu12==3.4.5
nvidia-nvtx-cu12==12.6.77
nvtx==0.2.13
nx-cugraph-cu12 @ https://pypi.nvidia.com/nx-cugraph-cu12/nx_cugraph_cu12-25.10.0-py3-none-any.whl
oauth2client==4.1.3
oauthlib==3.3.1
omegaconf==2.3.0
openai==1.109.1
opencv-contrib-python==4.12.0.88
opencv-python==4.12.0.88
opencv-python-headless==4.12.0.88
openpyxl==3.1.5
opentelemetry-api==1.37.0
opentelemetry-exporter-gcp-logging==1.11.0a0
opentelemetry-exporter-gcp-monitoring==1.11.0a0
opentelemetry-exporter-gcp-trace==1.11.0
opentelemetry-exporter-otlp-proto-common==1.37.0
opentelemetry-exporter-otlp-proto-http==1.37.0
opentelemetry-proto==1.37.0
opentelemetry-resourcedetector-gcp==1.11.0a0
opentelemetry-sdk==1.37.0
opentelemetry-semantic-conventions==0.58b0
opt_einsum==3.4.0
optax==0.2.6
optree==0.17.0
orbax-checkpoint==0.11.28
orjson==3.11.4
osqp==1.0.5
overrides==7.7.0
packaging==25.0
pandas==2.2.2
pandas-datareader==0.10.0
pandas-gbq==0.30.0
pandas-stubs==2.2.2.240909
pandocfilters==1.5.1
panel==1.8.3
param==2.2.1
parso==0.8.5
parsy==2.2
partd==1.4.2
patsy==1.0.2
peewee==3.18.3
peft==0.17.1
pexpect==4.9.0
pickleshare==0.7.5
pillow==11.3.0
platformdirs==4.5.0
plotly==5.24.1
plotnine==0.14.5
pluggy==1.6.0
plum-dispatch==2.6.0
ply==3.11
polars==1.31.0
pooch==1.8.2
portpicker==1.5.2
preshed==3.0.10
prettytable==3.16.0
proglog==0.1.12
progressbar2==4.5.0
prometheus_client==0.23.1
promise==2.3
prompt_toolkit==3.0.52
propcache==0.4.1
prophet==1.1.7
proto-plus==1.26.1
protobuf==5.29.5
psutil==5.9.5
psycopg2==2.9.11
psygnal==0.15.0
ptyprocess==0.7.0
py-cpuinfo==9.0.0
py4j==0.10.9.7
pyarrow==18.1.0
pyasn1==0.6.1
pyasn1_modules==0.4.2
pycairo==1.29.0
pycocotools==2.0.10
pycparser==2.23
pycryptodomex==3.23.0
pydantic==2.11.10
pydantic-settings==2.12.0
pydantic_core==2.33.2
pydata-google-auth==1.9.1
pydot==3.0.4
pydotplus==2.0.2
PyDrive2==1.21.3
pydub==0.25.1
pyerfa==2.0.1.5
pygame==2.6.1
pygit2==1.19.0
Pygments==2.19.2
PyGObject==3.42.0
PyJWT==2.10.1
pylibcudf-cu12 @ https://pypi.nvidia.com/pylibcudf-cu12/pylibcudf_cu12-25.10.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
pylibcugraph-cu12==25.10.1
pylibraft-cu12==25.10.0
pymc==5.26.1
pynndescent==0.5.13
pyogrio==0.11.1
pyomo==6.9.5
PyOpenGL==3.1.10
pyOpenSSL==24.2.1
pyparsing==3.2.5
pyperclip==1.11.0
pyproj==3.7.2
pyproject_hooks==1.2.0
pyshp==3.0.2.post1
PySocks==1.7.1
pyspark==3.5.1
pytensor==2.35.1
pytest==8.4.2
python-apt==0.0.0
python-box==7.3.2
python-dateutil==2.9.0.post0
python-dotenv==1.2.1
python-json-logger==4.0.0
python-louvain==0.16
python-multipart==0.0.20
python-slugify==8.0.4
python-snappy==0.7.3
python-utils==3.9.1
pytz==2025.2
pyviz_comms==3.0.6
PyWavelets==1.9.0
PyYAML==6.0.3
pyzmq==26.2.1
raft-dask-cu12==25.10.0
rapids-dask-dependency==25.10.0
rapids-logger==0.1.19
ratelim==0.1.6
referencing==0.37.0
regex==2024.11.6
requests==2.32.4
requests-oauthlib==2.0.0
requests-toolbelt==1.0.0
requirements-parser==0.9.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rfc3987-syntax==1.1.0
rich==13.9.4
rmm-cu12==25.10.0
roman-numerals-py==3.1.0
rpds-py==0.28.0
rpy2==3.5.17
rsa==4.9.1
ruff==0.14.4
safehttpx==0.1.7
safetensors==0.6.2
scikit-image==0.25.2
scikit-learn==1.6.1
scipy==1.16.3
scooby==0.11.0
scs==3.2.9
seaborn==0.13.2
SecretStorage==3.4.1
semantic-version==2.10.0
Send2Trash==1.8.3
sentence-transformers==5.1.2
sentencepiece==0.2.1
sentry-sdk==2.44.0
setuptools==75.2.0
shap==0.50.0
shapely==2.1.2
shellingham==1.5.4
simple-parsing==0.1.7
simplejson==3.20.2
simsimd==6.5.3
six==1.17.0
sklearn-pandas==2.2.0
slicer==0.0.8
smart_open==7.5.0
smmap==5.0.2
sniffio==1.3.1
snowballstemmer==3.0.1
sortedcontainers==2.4.0
soundfile==0.13.1
soupsieve==2.8
soxr==1.0.0
spacy==3.8.8
spacy-legacy==3.0.12
spacy-loggers==1.0.5
spanner-graph-notebook==1.1.8
Sphinx==8.2.3
sphinxcontrib-applehelp==2.0.0
sphinxcontrib-devhelp==2.0.0
sphinxcontrib-htmlhelp==2.1.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==2.0.0
sphinxcontrib-serializinghtml==2.0.0
SQLAlchemy==2.0.44
sqlalchemy-spanner==1.17.1
sqlglot==25.20.2
sqlparse==0.5.3
srsly==2.5.1
sse-starlette==3.0.3
stanio==0.5.1
starlette==0.49.3
statsmodels==0.14.5
stringzilla==4.2.3
stumpy==1.13.0
sympy==1.13.3
tables==3.10.2
tabulate==0.9.0
tbb==2022.3.0
tblib==3.2.1
tcmlib==1.4.1
tenacity==8.5.0
tensorboard==2.19.0
tensorboard-data-server==0.7.2
tensorflow==2.19.0
tensorflow-datasets==4.9.9
tensorflow-hub==0.16.1
tensorflow-metadata==1.17.2
tensorflow-probability==0.25.0
tensorflow-text==2.19.0
tensorflow_decision_forests==1.12.0
tensorstore==0.1.78
termcolor==3.2.0
terminado==0.18.1
text-unidecode==1.3
textblob==0.19.0
tf-slim==1.1.0
tf_keras==2.19.0
thinc==8.3.8
threadpoolctl==3.6.0
tifffile==2025.10.16
tiktoken==0.12.0
timm==1.0.22
tinycss2==1.4.0
tokenizers==0.22.1
toml==0.10.2
tomlkit==0.13.3
toolz==0.12.1
torch==2.8.0+cu126
torchao==0.10.0
torchaudio==2.8.0+cu126
torchdata==0.11.0
torchsummary==1.5.1
torchtune==0.6.1
torchvision==0.23.0+cu126
tornado==6.5.1
tqdm==4.67.1
traitlets==5.7.1
traittypes==0.2.3
transformers==4.57.1
treelite==4.4.1
treescope==0.1.10
triton==3.4.0
tsfresh==0.21.1
tweepy==4.16.0
typeguard==4.4.4
typer==0.20.0
typer-slim==0.20.0
types-pytz==2025.2.0.20251108
types-setuptools==80.9.0.20250822
typing-inspection==0.4.2
typing_extensions==4.15.0
tzdata==2025.2
tzlocal==5.3.1
uc-micro-py==1.0.3
ucxx-cu12==0.46.0
umap-learn==0.5.9.post2
umf==1.0.2
unsloth==2025.11.3
uri-template==1.3.0
uritemplate==4.2.0
urllib3==2.5.0
uvicorn==0.38.0
vega-datasets==0.9.0
wadllib==1.3.6
wandb==0.22.3
wasabi==1.1.3
watchdog==6.0.0
wcwidth==0.2.14
weasel==0.4.2
webcolors==25.10.0
webencodings==0.5.1
websocket-client==1.9.0
websockets==15.0.1
Werkzeug==3.1.3
wheel==0.45.1
widgetsnbextension==3.6.10
wordcloud==1.9.4
wrapt==2.0.1
wurlitzer==3.1.1
xarray==2025.10.1
xarray-einstats==0.9.1
xgboost==3.1.1
xlrd==2.0.2
xxhash==3.6.0
xyzservices==2025.10.0
yarl==1.22.0
ydf==0.13.0
yellowbrick==1.5
yfinance==0.2.66
zict==3.0.0
zipp==3.23.0
zstandard==0.25.0
</pre>
</details>

```python
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install huggingface_hub
else:
    # Do this only in Colab notebooks
    !pip install --no-deps bitsandbytes peft
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer
    !pip install --no-deps unsloth
    !pip install evaluate
    !pip install autoawq
    !pip install torch --index-url https://download.pytorch.org/whl/cu124
```

```python
if "COLAB_" not in "".join(os.environ.keys()):
  from huggingface_hub import login
  login() # To access gated repo
```

```python
import torch
from datasets import load_dataset
from peft import PeftMixedModel
import bitsandbytes as bnb
from tqdm.notebook import tqdm
import math
from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig
```

```python
import warnings
import gc
warnings.filterwarnings("ignore", category=DeprecationWarning)
```

```python
# @title Basic variable creation
model = "Meta-Llama-3.1-8B-Instruct-AWQ-INT4" # @param {"type": "string", "placeholder": "Meta-Llama-3.1-8B-bnb-4bit"}
source_user = "hugging-quants" # @param {"type": "string", "placeholder": "unsloth"}

dtype = torch.float16
load_in_4bit = True

model_name = f"{source_user}/{model}"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
device
```

```python
def load_model(model_name):
  # The first load can crash, so retry once
  quant_type = 'awq'
  # Optional fused-module config (currently not passed to from_pretrained)
  quantization_config = AwqConfig(
      bits=4,
      fuse_max_seq_len=512,  # Note: update this for your use case
      do_fuse=True,
  )
  load_kwargs = dict(
      device_map="cuda:0",   # Put optimized layers on GPU
      dtype=torch.float16,   # `dtype` supersedes the deprecated `torch_dtype`
      trust_remote_code=True,
      # quantization_config=quantization_config,
  )
  try:
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
  except Exception as e:
    print(f"Loading failed ({e}), trying to load again")
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
  model = model.to("cuda")

  return model, quant_type
```

```python
def unload_model(model):
  """
  Unloads a Hugging Face model from GPU memory to free up resources.

  Only the local reference is deleted here: the caller must also drop its own
  references (e.g. `model = unload_model(model)`) for the memory to be released.

  Args:
    model: The model object (e.g., AutoModelForCausalLM instance).
  """
  if model is not None:
    print("Unloading model from GPU...")
    try:
      # 1. Move model to CPU (optional, but good practice to clear GPU cache)
      model.to("cpu")

      # 2. Delete the local reference and collect garbage
      del model
      gc.collect()

      # 3. Clear CUDA cache
      if torch.cuda.is_available():
        torch.cuda.empty_cache()

      print("Model unloaded and GPU memory cleared.")
      return None  # Return None so the caller can overwrite its own reference

    except Exception as e:
      print(f"Error during model unloading: {e}")
      return model
  else:
    print("No model provided to unload.")
    return None
```

```python
# @title Apply a bit mask on weights to simulate 3-bit, 2-bit and 1-bit weights

def apply_quantization(model, simulated_quantization, target_modules, quant_type='awq'):
  print(simulated_quantization, target_modules)
  # Simulate the quantization with a bitmask that keeps the `simulated_quantization`
  # most significant bits (0-4) of every 4-bit value packed in the int32 `qweight`
  masks = [0b00000000000000000000000000000000, 0b10001000100010001000100010001000,
           0b11001100110011001100110011001100, 0b11101110111011101110111011101110,
           0b11111111111111111111111111111111]
  mask = masks[simulated_quantization]
  # `qweight` is int32: reinterpret the 32-bit pattern as a signed value to avoid overflow
  if mask >= 2**31:
    mask -= 2**32

  quant_module_name = ["awq.modules.linear"]

  for name, module in model.named_modules():
    if any(qname in str(type(module)) for qname in quant_module_name):
      if any(str(module_to_shift) in str(name) for module_to_shift in target_modules):
        qweight = module.qweight

        if not torch.is_floating_point(qweight):
          # Apply the bitmask in place, only to the quantized values
          qweight &= mask
        else:
          print("Floating point")
```
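AWQ packs eight 4-bit codes into each `int32` word of `qweight`, so repeating a 4-bit mask eight times clears the same low bits in every packed code, whatever their order inside the word. The pure-Python sketch below (helper names are hypothetical, and it uses a simple sequential packing rather than AWQ's actual interleaved order, which the masking does not depend on) packs eight codes, applies the 2-bit mask from `masks[2]`, and unpacks them. It also shows the signed `int32` view of that mask, which is the value an `int32` tensor can hold.

```python
def pack_nibbles(codes):
    """Pack eight 4-bit codes into one 32-bit word (code 0 in the lowest nibble)."""
    word = 0
    for i, c in enumerate(codes):
        word |= (c & 0xF) << (4 * i)
    return word


def unpack_nibbles(word):
    """Unpack a 32-bit word back into eight 4-bit codes."""
    return [(word >> (4 * i)) & 0xF for i in range(8)]


def to_int32(value):
    """Reinterpret an unsigned 32-bit pattern as a signed int32 value."""
    return value - (1 << 32) if value >= (1 << 31) else value


mask_2bit = 0b11001100110011001100110011001100  # same pattern as masks[2]
codes = [0, 3, 5, 7, 9, 11, 13, 15]
masked = unpack_nibbles(pack_nibbles(codes) & mask_2bit)
print(masked)               # [0, 0, 4, 4, 8, 8, 12, 12]
print(to_int32(mask_2bit))  # -858993460
```

Each code keeps only its two high bits, so 16 possible levels collapse to 4, exactly as a 2-bit right shift followed by a left shift would.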

```python
tokenizer = AutoTokenizer.from_pretrained(model_name)

alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
def formatting_prompts_func(examples):
    instructions = examples["instruction"]
    inputs       = examples["input"]
    outputs      = examples["output"]
    texts = []
    for instruction, input, output in zip(instructions, inputs, outputs):
        # Must add EOS_TOKEN, otherwise your generation will go on forever!
        text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
        texts.append(text)
    return { "text" : texts, }

dataset = load_dataset("yahma/alpaca-cleaned", split = "train")
dataset = dataset.map(formatting_prompts_func, batched = True,)
# Evaluate on the first 50 examples of the dataset only
eval_dataset = dataset.select(range(50))
input_texts = eval_dataset["text"]
```

```python
def tokenize_batch(batch_texts, tokenizer, max_length=512):
    return tokenizer(
        batch_texts,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
```

```python
def compute_perplexity(model, tokenizer, texts, batch_size=4, max_length=512):
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    if batch_size > 1:
      print("[WARNING]: batched evaluation of padded sequences is not supported, "
            "the returned PPL would be wrong. Use batch_size=1.")
      return -1

    with torch.no_grad():
        for i in tqdm(range(0, len(texts), batch_size), desc="Evaluating"):
            batch_texts = texts[i : i + batch_size]
            batch = tokenize_batch(batch_texts, tokenizer, max_length)
            batch = {k: v.to(device) for k, v in batch.items()}

            # The model shifts labels internally for the causal LM loss;
            # padding positions are set to -100 so the loss ignores them
            labels = batch["input_ids"].masked_fill(batch["attention_mask"] == 0, -100)
            outputs = model(**batch, labels=labels)
            loss = outputs.loss  # Mean cross-entropy over predicted tokens

            # Number of predicted tokens (the first token of a sequence has no target)
            n_tokens = batch["attention_mask"][:, 1:].sum().item()

            total_loss += loss.item() * n_tokens
            total_tokens += n_tokens

    mean_loss = total_loss / total_tokens
    ppl = math.exp(mean_loss)
    return ppl
```
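Perplexity over a dataset is the exponential of the mean per-token loss, $\mathrm{PPL} = \exp\left(\sum_i n_i L_i / \sum_i n_i\right)$, where $L_i$ is the mean loss of sequence $i$ and $n_i$ its number of predicted tokens. The sketch below uses made-up numbers (not results from this notebook) and hypothetical helper names to show why `compute_perplexity` weights each loss by its token count rather than averaging per-sequence losses.

```python
import math


def aggregate_perplexity(losses, token_counts):
    """Token-weighted perplexity from per-sequence mean losses."""
    total_loss = sum(l * n for l, n in zip(losses, token_counts))
    total_tokens = sum(token_counts)
    return math.exp(total_loss / total_tokens)


def naive_perplexity(losses):
    """Unweighted average of per-sequence losses (over-weights short sequences)."""
    return math.exp(sum(losses) / len(losses))


# Hypothetical example: a short, hard sequence and a long, easy one
losses, counts = [2.0, 1.0], [10, 30]
print(round(aggregate_perplexity(losses, counts), 3))  # 3.49, i.e. exp(1.25)
print(round(naive_perplexity(losses), 3))              # 4.482, i.e. exp(1.5)
```

The naive average lets the 10-token sequence count as much as the 30-token one, which inflates the perplexity here.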

For reference, the bit-shifted quantization scheme used below:

![Current bit-shifted quantization](https://github.com/Tfloow/peft/blob/main/examples/lora_mixed/fig/image.png?raw=1)

```python
# Initial model and quantization setup
model, quant_type = load_model(model_name)

# --- First adapter (2 bits, QKGateUp) ---
simulated_quantization = 2
target_modules = ["q_proj", "k_proj", "gate_proj", "up_proj"]
apply_quantization(model, simulated_quantization, target_modules, quant_type=quant_type)

# --- Second adapter (3 bits, V) ---
simulated_quantization_2 = 3
target_modules_2 = ["v_proj"]
# Apply the specific quantization for the v_proj modules
apply_quantization(model, simulated_quantization_2, target_modules_2, quant_type=quant_type)

# Pre-trained LoRA patches for the mixed quantization
# --- Loading first adapter ---
model_name_lora = "Tfloow/Meta-Llama-3.1-8B-Instruct-AWQ-INT4_simulated_2-bits_lora_test_QKGateUp"
model_with_adapters = PeftMixedModel.from_pretrained(
    model,
    model_name_lora,
    adapter_name="adapter_qk_up"
)
tokenizer = AutoTokenizer.from_pretrained(model_name_lora)
tokenizer.pad_token = tokenizer.eos_token

# --- Loading second adapter ---
model_name_lora_2 = "Tfloow/Meta-Llama-3.1-8B-Instruct-AWQ-INT4_simulated_3-bits_lora_test_V"
model_with_adapters.load_adapter(
    model_name_lora_2,
    adapter_name="adapter_v_proj"
)

# Activate both adapters at the same time
model_with_adapters.set_adapter(["adapter_qk_up", "adapter_v_proj"])
# This represents ~30 % less weight data to transfer than AWQ INT4 alone,
# and it also yields a better perplexity than the baseline model

ppl = compute_perplexity(model_with_adapters, tokenizer, input_texts, batch_size=1, max_length=512)
print(f"Perplexity on eval dataset: {ppl}")
```
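The ~30 % figure can be sanity-checked by counting the weight bits of one decoder layer. This back-of-the-envelope sketch assumes the published Llama-3.1-8B shapes (hidden size 4096, MLP size 14336, 8 key/value heads of dimension 128) and counts only the packed weight codes: LoRA weights, AWQ scales and zero points are deliberately ignored, so the real saving is somewhat smaller.

```python
# Per-layer projection shapes (in_features, out_features), assumed Llama-3.1-8B values
HIDDEN, MLP, KV = 4096, 14336, 8 * 128
shapes = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV),
    "v_proj": (HIDDEN, KV),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, MLP),
    "up_proj": (HIDDEN, MLP),
    "down_proj": (MLP, HIDDEN),
}

# Bits kept per weight in this notebook's setup; o_proj and down_proj stay at 4 bits
kept_bits = {"q_proj": 2, "k_proj": 2, "gate_proj": 2, "up_proj": 2, "v_proj": 3}


def weight_bits(bits_per_module, default=4):
    """Total weight bits of one layer for a given per-module bit width."""
    return sum(i * o * bits_per_module.get(name, default)
               for name, (i, o) in shapes.items())


baseline = weight_bits({})
mixed = weight_bits(kept_bits)
print(f"relative traffic: {mixed / baseline:.3f}")  # 0.678
print(f"reduction: {1 - mixed / baseline:.1%}")     # 32.2%
```

Because every layer has the same shapes, the per-layer ratio is also the ratio for the whole decoder stack.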

![alt text](https://github.com/Tfloow/peft/blob/main/examples/lora_mixed/fig/image-1.png?raw=1)

You can see that simple hardware bit-shifting reduces data traffic, while the LoRA adapters bring a gain in accuracy.

```python
prompt = [
  {"role": "system", "content": "You are a helpful assistant that responds as a pirate."},
  {"role": "user", "content": "What's Deep Learning?"},
]
inputs = tokenizer.apply_chat_template(
  prompt,
  tokenize=True,
  add_generation_prompt=True,
  return_tensors="pt",
  return_dict=True,
).to("cuda")

# Generate through the PEFT wrapper so both active adapters are used
outputs = model_with_adapters.generate(**inputs, do_sample=True, max_new_tokens=256)
print(tokenizer.batch_decode(outputs[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True)[0])
```

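```python
# @title Unload the model to free up GPU memory
# The PEFT wrapper also holds a reference to the base model, so drop it first
del model_with_adapters
model = unload_model(model)
# Remove the remaining reference to the model
del model
gc.collect()
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
```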



----

This is an experimental notebook made by [@Tfloow](https://github.com/Tfloow) on GitHub. If you have any issues or questions, please reach out to me.