## Inference

In [17]:
import torch
from transformers import AutoModel, AutoProcessor
from transformers.image_utils import load_image

### Image Embeddings

In [18]:
IMG_URL = "https://datasets-server.huggingface.co/assets/derek-thomas/ScienceQA/--/default/train/7/image/image.jpg?Expires=1760691296&Signature=jBrBX5II8NFMSI9RkblePB2GBC-2LDVGvOh7wlChebSDCkz9Zy9SXw-JB3dpjaYc4lSrZJ73VkkqxxRZ52Xjrm4cOs4lGKhg0fu7nFbCl~18Tys56S0yJFPFYW0tBJ0fSZi4VzaGxXgjn5J-CwtPU74amaBGfLw6cT3J~ka-oBrH-DLFhvOkGWbPvkIYofAxmR9NBMDCpnzZUVw1oFAIVmf9OAvetf6EuDYNDh4iz2BJMrsIq2u3r1eln5WZz0cdm8ZC7UABDd5V6PYB1CKMgzfJ7dLiq-LhubrKQl~LFceJP7PgNAMcgw8B7crJqAg3mBvM8bFxIsbVfNSgchaBiQ__&Key-Pair-Id=K3EI6M078Z3AC3"

image = load_image(IMG_URL)

# Load model and processor
ckpt = "google/siglip2-base-patch32-256"
model = AutoModel.from_pretrained(ckpt, device_map="auto").eval()
processor = AutoProcessor.from_pretrained(ckpt)

# Preprocess image
inputs = processor(images=image, return_tensors="pt").to(model.device)

# Run inference
with torch.no_grad():
    image_embeddings = model.get_image_features(**inputs)

print(image_embeddings.shape)

torch.Size([1, 768])


### Vision Projector

In [19]:
import torch
import torch.nn as nn

In [20]:
class VisionProjector(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, hidden: int = 1024, n_prefix_tokens: int = 1):
        super().__init__()
        self.n_prefix = n_prefix_tokens
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, out_dim * n_prefix_tokens),
        )

    def forward(self, x):
        out = self.net(x)
        out = out.view(x.size(0), self.n_prefix, -1)
        return out

In [35]:
!pip install -Uq bitsandbytes

In [40]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training, get_peft_model, LoraConfig, TaskType

In [36]:

def build_tokenizer_and_model(model_id: str, LOAD_4BIT,B4_COMPUTE_DTYPE):
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    if tokenizer.pad_token_id is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    # BitsAndBytesConfig
    if LOAD_4BIT:
        bnb = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=getattr(__import__("torch"), B4_COMPUTE_DTYPE),
        )
    else:
        bnb = None

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb,
        device_map="auto",
        trust_remote_code=True,
    )

    model = prepare_model_for_kbit_training(model)
    return tokenizer, model

In [41]:
LOAD_4BIT = False
B4_COMPUTE_DTYPE = "float16"
tokenizer, model = build_tokenizer_and_model("google/gemma-3-1b-it",LOAD_4BIT,B4_COMPUTE_DTYPE)

## Lora setup

In [47]:
LORA_R = 8
LORA_ALPHA = 32
LORA_TARGET_MODULES = [
    "q_proj", "v_proj", "k_proj", "o_proj", "up_proj", "down_proj"
]  

def apply_peft(model, LORA_R, LORA_ALPHA, LORA_TARGET_MODULES):
    lora_conf = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        target_modules=LORA_TARGET_MODULES,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_conf)
    return model


In [50]:
model = apply_peft(model, LORA_R, LORA_ALPHA, LORA_TARGET_MODULES)

In [53]:
# model

In [52]:
model.print_trainable_parameters()

trainable params: 4,845,568 || all params: 1,004,731,520 || trainable%: 0.4823


In [54]:
V_DIM = image_embeddings.shape[1]
T_DIM = model.config.hidden_size

In [55]:
PREFIX_TOKENS = 1 

In [56]:
vision_proj = VisionProjector(V_DIM, T_DIM, hidden=min(2048, max(512, V_DIM * 2)), n_prefix_tokens=PREFIX_TOKENS)

In [57]:
vision_proj

VisionProjector(
  (net): Sequential(
    (0): Linear(in_features=768, out_features=1536, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=1536, out_features=1152, bias=True)
  )
)

In [62]:
vision_emb = torch.tensor(image_embeddings).unsqueeze(0)

  vision_emb = torch.tensor(image_embeddings).unsqueeze(0)


In [73]:
print("Loading the checkpoints")
vision_proj.load_state_dict(torch.load("/home/nabin/Desktop/3Drecons/results/phyVQA-train/output_gemma_vision_lora/epoch_2/projector.pt", map_location="cpu"))
model.load_adapter("/home/nabin/Desktop/3Drecons/results/phyVQA-train/output_gemma_vision_lora/epoch_2", "peft")
print(f"Loaded checkpoint")

Loading the checkpoints
Loaded checkpoint


In [75]:
vision_proj

VisionProjector(
  (net): Sequential(
    (0): Linear(in_features=768, out_features=1536, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=1536, out_features=1152, bias=True)
  )
)

## Finally inference

In [80]:
prompt = """प्रत्येक नमुनामा रहेका कणहरूको औसत गतिज ऊर्जा तुलना गर्नुहोस्। कुन नमुनाको तापक्रम बढी छ? [
"दुवै होइन; दुवै नमुनाको तापक्रम समान छ",
"नमूना A",
"नमूना B"
]"""

In [81]:
prompt

'प्रत्येक नमुनामा रहेका कणहरूको औसत गतिज ऊर्जा तुलना गर्नुहोस्। कुन नमुनाको तापक्रम बढी छ? [\n"दुवै होइन; दुवै नमुनाको तापक्रम समान छ",\n"नमूना A",\n"नमूना B"\n]'

In [82]:
inputs = tokenizer(prompt, return_tensors="pt")

In [88]:
inputs_embeds = model.get_input_embeddings()(inputs.input_ids)
proj = vision_proj(vision_emb)

In [89]:
inputs_embeds = torch.cat([proj, inputs_embeds], dim=1)

In [90]:
inputs_embeds

tensor([[[-0.0811,  0.1031,  0.1283,  ...,  0.1366,  0.0183, -0.0527],
         [ 0.1388, -0.2165, -0.5386,  ..., -0.2465,  0.0267,  0.1258],
         [-0.1833, -0.6753,  1.4335,  ..., -0.5510, -0.8576, -0.6173],
         ...,
         [ 0.6091,  0.2434,  1.3673,  ...,  1.8644,  1.2844,  0.5013],
         [ 0.9778, -0.5925, -0.5759,  ..., -0.1844, -0.5013,  0.3687],
         [ 0.1020,  1.5910,  0.4495,  ...,  0.3853,  0.0593,  2.1130]]],
       grad_fn=<CatBackward0>)

In [92]:
prefix_mask = torch.ones((1, proj.size(1)), dtype=inputs.attention_mask.dtype)
attn_mask = torch.cat([prefix_mask, inputs.attention_mask], dim=1)

In [98]:
print(f"Prompt: {prompt}")
outputs = model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=attn_mask,
        max_new_tokens=64,
        do_sample=False,
    )

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated text:", generated_text)

The following generation flags are not valid and may be ignored: ['top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Prompt: प्रत्येक नमुनामा रहेका कणहरूको औसत गतिज ऊर्जा तुलना गर्नुहोस्। कुन नमुनाको तापक्रम बढी छ? [
"दुवै होइन; दुवै नमुनाको तापक्रम समान छ",
"नमूना A",
"नमूना B"
]
Generated text: 
सही उत्तर हुन्:
A. नमूना A
B. नमूना B
C. नमूना C
D. न कुनै पनि
सही उत्तर हुन्:
A. नमूना A
B. नमूना B
C. नमूना C
D. न कुनै पनि

उत्तर: A
