gemma/model.py: 57 changes (21 additions, 36 deletions)
@@ -113,54 +113,45 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
     return x_out
 
 
-class Linear(nn.Module):
+class QuantizedWeight(nn.Module):
 
-    def __init__(self, in_features: int, out_features: int, quant: bool):
+    def __init__(self, num_embeddings: int, embedding_dim: int, quant: bool):
         super().__init__()
         if quant:
             self.weight = nn.Parameter(
-                torch.empty((out_features, in_features), dtype=torch.int8),
+                torch.empty((num_embeddings, embedding_dim), dtype=torch.int8),
                 requires_grad=False,
             )
-            self.weight_scaler = nn.Parameter(torch.Tensor(out_features))
+            self.weight_scaler = nn.Parameter(torch.Tensor(num_embeddings))
         else:
             self.weight = nn.Parameter(
-                torch.empty((out_features, in_features)),
+                torch.empty((num_embeddings, embedding_dim)),
                 requires_grad=False,
             )
         self.quant = quant
 
-    def forward(self, x):
-        weight = self.weight
+    def get_weight(self):
         if self.quant:
-            weight = weight * self.weight_scaler.unsqueeze(-1)
-        output = F.linear(x, weight)
-        return output
+            return self.weight * self.weight_scaler.unsqueeze(-1)
+        return self.weight
 
 
-class Embedding(nn.Module):
+class Linear(QuantizedWeight):
 
     def __init__(self, num_embeddings: int, embedding_dim: int, quant: bool):
-        super().__init__()
-        if quant:
-            self.weight = nn.Parameter(
-                torch.empty((num_embeddings, embedding_dim), dtype=torch.int8),
-                requires_grad=False,
-            )
-            self.weight_scaler = nn.Parameter(torch.Tensor(num_embeddings))
-        else:
-            self.weight = nn.Parameter(
-                torch.empty((num_embeddings, embedding_dim)),
-                requires_grad=False,
-            )
-        self.quant = quant
+        super().__init__(num_embeddings, embedding_dim, quant)
 
     def forward(self, x):
-        weight = self.weight
-        if self.quant:
-            weight = weight * self.weight_scaler.unsqueeze(-1)
-        output = F.embedding(x, weight)
-        return output
+        return F.linear(x, self.get_weight())
+
+
+class Embedding(QuantizedWeight):
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, quant: bool):
+        super().__init__(num_embeddings, embedding_dim, quant)
+
+    def forward(self, x):
+        return F.embedding(x, self.get_weight())
 
 
 class RMSNorm(torch.nn.Module):
@@ -368,7 +359,6 @@ def forward(
         kv_write_indices: torch.Tensor,
         kv_cache: Tuple[torch.Tensor, torch.Tensor],
         mask: torch.Tensor,
-        local_mask: torch.Tensor,
     ) -> torch.Tensor:
         # Self Attention
         residual = hidden_states
@@ -560,7 +550,6 @@ def forward(
         self,
         input_token_ids: torch.Tensor,
         input_positions: torch.Tensor,
-        kv_write_indices: torch.Tensor,
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
         mask: torch.Tensor,
         output_positions: torch.Tensor,
@@ -605,10 +594,7 @@ def forward(
             mask=mask,
             local_mask=local_mask,
         )
-        embedder_weight = self.embedder.weight
-        if self.config.quant:
-            embedder_weight = (
-                embedder_weight * self.embedder.weight_scaler.unsqueeze(-1))
+        embedder_weight = self.embedder.get_weight()
         next_tokens, logits = self.sampler(
             embedding=embedder_weight,
             hidden_states=hidden_states,
@@ -691,7 +677,6 @@ def generate(
             next_token_ids, _ = self(
                 input_token_ids=input_token_ids_tensor,
                 input_positions=input_positions_tensor,
-                kv_write_indices=None,
                 kv_caches=kv_caches,
                 mask=curr_mask_tensor,
                 output_positions=output_positions_tensor,
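Note on the refactor above: the dequantization logic previously duplicated in Linear and Embedding is now centralized in QuantizedWeight.get_weight(), which multiplies the stored int8 weight by a per-row float scaler when quant is set. The sketch below illustrates that round trip; it is not part of this PR, and the quantize_per_row helper is a hypothetical name used only for this example.

```python
# Illustrative sketch only (not part of this PR): per-row symmetric int8
# quantization, and the dequantization that QuantizedWeight.get_weight()
# performs via `weight * weight_scaler.unsqueeze(-1)`.
import torch
import torch.nn.functional as F

def quantize_per_row(weight: torch.Tensor):
    # One scaler per output row; hypothetical helper for this example.
    scaler = weight.abs().amax(dim=-1) / 127.0
    q = torch.clamp((weight / scaler.unsqueeze(-1)).round(), -127, 127).to(torch.int8)
    return q, scaler

w = torch.randn(4, 8)                 # (out_features, in_features)
x = torch.randn(2, 8)
q, scaler = quantize_per_row(w)
w_dequant = q * scaler.unsqueeze(-1)  # mirrors get_weight() when quant=True
err = (F.linear(x, w) - F.linear(x, w_dequant)).abs().max()
print(f"max abs error after int8 round trip: {err.item():.4f}")
```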
scripts/run_multimodal.py: 5 changes (2 additions, 3 deletions)
@@ -99,7 +99,7 @@ def main(_):
    for key in image_paths:
        try:
            image[key] = Image.open(image_paths[key]) # Open local file
-            image[key].show()
+            # image[key].show()
        except IOError as e:
            print(f"Error loading image: {e}")
            exit()
@@ -113,8 +113,7 @@ def main(_):
    device = torch.device(_DEVICE.value)
    with _set_default_tensor_type(model_config.get_dtype()):
        model = gemma3_model.Gemma3ForMultimodalLM(model_config)
-        model.load_state_dict(torch.load(_CKPT.value)['model_state_dict'])
-        # model.load_weights(_CKPT.value)
+        model.load_weights(_CKPT.value)
        model = model.to(device).eval()
    print('Model loading done')
 
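The run_multimodal.py change swaps the manual state-dict load for the model's own load_weights helper. As a rough, hedged sketch of what such a helper conventionally wraps (this is not the actual Gemma3ForMultimodalLM.load_weights implementation, and load_weights_sketch is a hypothetical name for this example):

```python
# Hedged sketch only: what a load_weights-style helper commonly does.
# Not the actual gemma3_model implementation.
import torch

def load_weights_sketch(model: torch.nn.Module, ckpt_path: str) -> None:
    # Load on CPU first so the caller decides device placement afterwards.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Accept a raw state dict or a checkpoint wrapped under 'model_state_dict'.
    state_dict = ckpt.get("model_state_dict", ckpt) if isinstance(ckpt, dict) else ckpt
    # strict=False tolerates extra/missing keys; an assumption for this sketch.
    model.load_state_dict(state_dict, strict=False)
```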