<a href="https://colab.research.google.com/github/matbee-eth/LLM-Finetuning-Workspace/blob/main/add_vision_encoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch transformers accelerate vllm xformers pillow huggingface_hub

In [None]:
import copy
import gc

import torch
from transformers import AutoModelForCausalLM, Mistral3ForConditionalGeneration, AutoTokenizer, AutoProcessor

In [None]:
# Donor of the vision tower, multimodal projector, processor and tokenizer.
mistral_small_path = "unsloth/Mistral-Small-3.1-24B-Instruct-2503"
# Donor of the language model and LM head.
magistral_path = "mistralai/Magistral-Small-2506"

# Text-only model (loaded via AutoModelForCausalLM); its weights become the language model.
magistral = AutoModelForCausalLM.from_pretrained(
    magistral_path,
    torch_dtype=torch.bfloat16,
)

# Multimodal model; its vision tower and projector are transplanted in a later cell.
mistral_small = Mistral3ForConditionalGeneration.from_pretrained(
    mistral_small_path,
    torch_dtype=torch.bfloat16,
)

# Target config = Mistral Small's multimodal config (vision tower, projector, image token)
# with Magistral's own text config. Without an explicit `config=`, Mistral3 would be built
# from Magistral's text-only config.json and fall back on Mistral3Config's built-in
# defaults for both the text and vision sub-configs.
magistral_vision_config = copy.deepcopy(mistral_small.config)
magistral_vision_config.text_config = copy.deepcopy(magistral.config)

# Magistral's checkpoint uses causal-LM parameter names, so from_pretrained is presumably not
# able to fill most of the Mistral3 parameters here; the transplant cell overwrites them all
# and the load cell verifies that nothing was left uninitialised.
magistral_vision = Mistral3ForConditionalGeneration.from_pretrained(
    magistral_path,
    config=magistral_vision_config,
    torch_dtype=torch.bfloat16,
    ignore_mismatched_sizes=True,
)

In [None]:
state_dict_magistral = magistral.state_dict()
state_dict_small = mistral_small.state_dict()

# List every parameter name and shape of both donors so their key layouts can be
# compared before writing the key remapping in the next cell.
for title, state_dict in (
    ("Magistral", state_dict_magistral),
    ("Small", state_dict_small),
):
    banner = f"------- {title} state dict -------"
    print(banner)
    for key, tensor in state_dict.items():
        print(f"{key} | Shape: {tensor.shape}")
    print(banner)

In [None]:
# Assemble the weights for `magistral_vision`: Magistral's language model and LM head,
# plus Mistral Small's vision tower and multimodal projector.
LANGUAGE_PREFIX_SRC = "model."                 # language-model prefix in Magistral's state dict
LANGUAGE_PREFIX_DST = "model.language_model."  # language-model prefix in the Mistral3 model

new_state_dict = {}

for k, v in state_dict_magistral.items():
    if k.startswith(LANGUAGE_PREFIX_SRC):
        # Swap only the leading prefix; str.replace would also rewrite any later
        # occurrence of "model." inside the key.
        new_key = LANGUAGE_PREFIX_DST + k[len(LANGUAGE_PREFIX_SRC):]
    else:
        # lm_head.* keeps its name, as in the original mapping.
        new_key = k
    new_state_dict[new_key] = v
n_language_tensors = len(new_state_dict)

for k, v in state_dict_small.items():
    if "vision_tower" in k or "multi_modal_projector" in k:
        new_state_dict[k] = v

print(
    f"Added {n_language_tensors} language tensors and "
    f"{len(new_state_dict) - n_language_tensors} vision tensors"
)

In [None]:
# strict=False so the full report is printed before deciding whether the transplant worked.
load_result = magistral_vision.load_state_dict(new_state_dict, strict=False)

print("\n------- Load Result -------")
print(f"Missing keys: {load_result.missing_keys}")
print(f"Unexpected keys: {load_result.unexpected_keys}")

# Stop before anything is saved or uploaded. A missing key means that tensor was never
# overwritten by the transplant; an unexpected key means a remapped name is wrong.
if load_result.missing_keys or load_result.unexpected_keys:
    raise RuntimeError("Weight transplant incomplete; fix the key mapping before saving.")

In [None]:
output_path = "/model_weights/magistral_vision"

# Merged weights and config first.
magistral_vision.save_pretrained(output_path)

# Processor and tokenizer come from Mistral Small, the donor of the vision tower, so the
# output directory is self-contained for loading.
processor = AutoProcessor.from_pretrained(mistral_small_path)
tokenizer = AutoTokenizer.from_pretrained(mistral_small_path)
for component in (processor, tokenizer):
    component.save_pretrained(output_path)

In [None]:
from vllm import LLM
from vllm.sampling_params import SamplingParams

# The merged checkpoint is on disk now. Release the transformers models and the state
# dicts that alias their tensors before vLLM loads its own copy of the weights.
# pop(..., None) keeps this cell safe to run even if those names were never defined.
for _name in (
    "magistral",
    "mistral_small",
    "magistral_vision",
    "state_dict_magistral",
    "state_dict_small",
    "new_state_dict",
):
    globals().pop(_name, None)
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

llm = LLM(
    model=output_path,
    max_model_len=8192,  # context window for this smoke test
)

In [None]:
# Magistral reasoning system prompt. Kept unindented on purpose: indenting a triple-quoted
# string inside the messages list would inject that indentation into every prompt line.
MAGISTRAL_SYSTEM_PROMPT = """A user will ask you to solve a task. You should first draft your thinking process (inner monologue) until you have derived the final answer. Afterwards, write a self-contained summary of your thoughts (i.e. your summary should be succinct but contain all the critical steps you needed to reach the conclusion). You should use Markdown to format your response. Write both your thoughts and summary in the same language as the task posed by the user. NEVER use \\boxed{} in your response.

Your thinking process must follow the template below:
<think>
Your thoughts or/and draft, like working through an exercise on scratch paper. Be as casual and as long as you want until you are confident to generate a correct answer.
</think>

Here, provide a concise summary that reflects your reasoning and presents a clear final answer to the user. Don't mention that this is a summary.

Problem:

"""

sampling_params = SamplingParams(
    max_tokens=4096,
    temperature=0.7,
    top_p=0.95,
)

prompt = "Describe this image in one sentence."
image_url = "https://picsum.photos/id/237/200/300"

messages = [
    {"role": "system", "content": MAGISTRAL_SYSTEM_PROMPT},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": image_url}},
        ],
    },
]

# Smoke test: one image + text turn through the merged model.
outputs = llm.chat(messages, sampling_params=sampling_params)
print("-------")
print(outputs[0].outputs[0].text)
print("-------")

In [None]:
from huggingface_hub import HfApi

# Publish the saved directory (weights, config, processor, tokenizer) as a model repo.
# NOTE(review): requires Hub credentials with write access to `repo_id`.
repo_id = "OptimusePrime/Magistral-Small-2506-Vision"
api = HfApi()
api.upload_large_folder(
    folder_path=output_path,
    repo_id=repo_id,
    repo_type="model",
)