diff --git a/gemma/model.py b/gemma/model.py
index 689143f..cdbadc8 100644
--- a/gemma/model.py
+++ b/gemma/model.py
@@ -113,54 +113,45 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
     return x_out
 
 
-class Linear(nn.Module):
+class QuantizedWeight(nn.Module):
 
-    def __init__(self, in_features: int, out_features: int, quant: bool):
+    def __init__(self, num_embeddings: int, embedding_dim: int, quant: bool):
         super().__init__()
         if quant:
             self.weight = nn.Parameter(
-                torch.empty((out_features, in_features), dtype=torch.int8),
+                torch.empty((num_embeddings, embedding_dim), dtype=torch.int8),
                 requires_grad=False,
             )
-            self.weight_scaler = nn.Parameter(torch.Tensor(out_features))
+            self.weight_scaler = nn.Parameter(torch.Tensor(num_embeddings))
         else:
             self.weight = nn.Parameter(
-                torch.empty((out_features, in_features)),
+                torch.empty((num_embeddings, embedding_dim)),
                 requires_grad=False,
             )
         self.quant = quant
 
-    def forward(self, x):
-        weight = self.weight
+    def get_weight(self):
         if self.quant:
-            weight = weight * self.weight_scaler.unsqueeze(-1)
-        output = F.linear(x, weight)
-        return output
+            return self.weight * self.weight_scaler.unsqueeze(-1)
+        return self.weight
 
 
-class Embedding(nn.Module):
+class Linear(QuantizedWeight):
 
     def __init__(self, num_embeddings: int, embedding_dim: int, quant: bool):
-        super().__init__()
-        if quant:
-            self.weight = nn.Parameter(
-                torch.empty((num_embeddings, embedding_dim), dtype=torch.int8),
-                requires_grad=False,
-            )
-            self.weight_scaler = nn.Parameter(torch.Tensor(num_embeddings))
-        else:
-            self.weight = nn.Parameter(
-                torch.empty((num_embeddings, embedding_dim)),
-                requires_grad=False,
-            )
-        self.quant = quant
+        super().__init__(num_embeddings, embedding_dim, quant)
 
     def forward(self, x):
-        weight = self.weight
-        if self.quant:
-            weight = weight * self.weight_scaler.unsqueeze(-1)
-        output = F.embedding(x, weight)
-        return output
+        return F.linear(x, self.get_weight())
+
+
+class Embedding(QuantizedWeight):
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, quant: bool):
+        super().__init__(num_embeddings, embedding_dim, quant)
+
+    def forward(self, x):
+        return F.embedding(x, self.get_weight())
 
 
 class RMSNorm(torch.nn.Module):
@@ -368,7 +359,6 @@ def forward(
         kv_write_indices: torch.Tensor,
         kv_cache: Tuple[torch.Tensor, torch.Tensor],
         mask: torch.Tensor,
-        local_mask: torch.Tensor,
     ) -> torch.Tensor:
         # Self Attention
         residual = hidden_states
@@ -560,7 +550,6 @@ def forward(
         self,
         input_token_ids: torch.Tensor,
         input_positions: torch.Tensor,
-        kv_write_indices: torch.Tensor,
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
         mask: torch.Tensor,
         output_positions: torch.Tensor,
@@ -605,10 +594,7 @@ def forward(
             mask=mask,
             local_mask=local_mask,
         )
-        embedder_weight = self.embedder.weight
-        if self.config.quant:
-            embedder_weight = (
-                embedder_weight * self.embedder.weight_scaler.unsqueeze(-1))
+        embedder_weight = self.embedder.get_weight()
         next_tokens, logits = self.sampler(
             embedding=embedder_weight,
             hidden_states=hidden_states,
@@ -691,7 +677,6 @@ def generate(
             next_token_ids, _ = self(
                 input_token_ids=input_token_ids_tensor,
                 input_positions=input_positions_tensor,
-                kv_write_indices=None,
                 kv_caches=kv_caches,
                 mask=curr_mask_tensor,
                 output_positions=output_positions_tensor,
diff --git a/scripts/run_multimodal.py b/scripts/run_multimodal.py
index 231e340..d040dac 100644
--- a/scripts/run_multimodal.py
+++ b/scripts/run_multimodal.py
@@ -99,7 +99,7 @@ def main(_):
     for key in image_paths:
         try:
             image[key] = Image.open(image_paths[key]) # Open local file
-            image[key].show()
+            # image[key].show()
         except IOError as e:
             print(f"Error loading image: {e}")
             exit()
@@ -113,8 +113,7 @@ def main(_):
     device = torch.device(_DEVICE.value)
     with _set_default_tensor_type(model_config.get_dtype()):
         model = gemma3_model.Gemma3ForMultimodalLM(model_config)
-        model.load_state_dict(torch.load(_CKPT.value)['model_state_dict'])
-        # model.load_weights(_CKPT.value)
+        model.load_weights(_CKPT.value)
         model = model.to(device).eval()
     print('Model loading done')
 
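
Note on the model.py change: Linear and Embedding previously duplicated the int8 dequantization logic; both now inherit it from QuantizedWeight, whose get_weight() rescales the stored int8 weights by the per-row weight_scaler when quant is set, and the embedder/sampler path reuses the same accessor. The snippet below is a minimal usage sketch of the refactored classes, not code from the repository; the sizes and the random weight fill-in are illustrative assumptions standing in for a real checkpoint load. It also shows that Linear's first constructor argument remains the output dimension despite the embedding-style parameter names.

# Illustrative sketch only: sizes and random weights are stand-ins for a real
# checkpoint; Embedding and Linear are the refactored classes from gemma/model.py.
import torch
from gemma.model import Embedding, Linear

vocab, dim = 8, 4                        # hypothetical sizes
emb = Embedding(vocab, dim, quant=True)  # int8 weight of shape (vocab, dim)
proj = Linear(vocab, dim, quant=True)    # maps dim -> vocab, same weight shape

for m in (emb, proj):
    # A weight loader would normally populate these from a checkpoint.
    m.weight.data = torch.randint(-128, 127, (vocab, dim), dtype=torch.int8)
    m.weight_scaler.data = torch.rand(vocab)

tokens = torch.tensor([[0, 3, 5]])        # (batch, seq)
hidden = emb(tokens)                      # F.embedding on the dequantized weight
logits = proj(torch.randn(1, 3, dim))     # F.linear on the same weight -> (1, 3, vocab)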