<a href="https://colab.research.google.com/github/kennethmugo/Swahili-SMS-Spam-Detection/blob/main/research/gemma3_swahili_spam_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers torch accelerate bitsandbytes kagglehub


In [2]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
import kagglehub
from google.colab import userdata
from huggingface_hub import login

Now we want to see how Gemma3 performs at zero-shot classification on this dataset. Let us create the pipeline that loads the model from Hugging Face. Make sure `HF_TOKEN` is stored in your Colab secrets so that you can log in to Hugging Face.

In [3]:
my_secret_key = userdata.get('HF_TOKEN')
login(token=my_secret_key, add_to_git_credential=True)
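
The login cell above relies on `google.colab.userdata`, which only exists inside a Colab runtime. As a minimal sketch for running the notebook elsewhere, the hypothetical helper `get_hf_token` below (not part of the original notebook) falls back to an environment variable, assuming the token was exported under the same name.

```python
import os

def get_hf_token(name="HF_TOKEN"):
    """Return the token from Colab secrets if available, else from the environment."""
    try:
        # google.colab is only importable inside a Colab runtime.
        from google.colab import userdata
    except ImportError:
        # Assumption: outside Colab the token was exported as an environment variable.
        return os.environ.get(name)
    return userdata.get(name)
```

The result could then be passed to `login(token=get_hf_token())` in place of `my_secret_key`.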

In [4]:
## Download the dataset from kaggle
path = kagglehub.dataset_download("henrydioniz/swahili-sms-detection-dataset")
full_path = os.path.join(path, "bongo_scam.csv")
df = pd.read_csv(full_path)
df.head()

Downloading from https://www.kaggle.com/api/v1/datasets/download/henrydioniz/swahili-sms-detection-dataset?dataset_version_number=1...


100%|██████████| 26.1k/26.1k [00:00<00:00, 20.9MB/s]

Extracting files...





Unnamed: 0,Category,Sms
0,trust,"Nipigie baada ya saa moja, tafadhali."
1,scam,Naomba unitumie iyo Hela kwenye namba hii ya A...
2,scam,"666,KARIBU FREEMASON UTIMIZE NDOTO KATIKA BIAS..."
3,trust,Watoto wanapenda sana zawadi ulizowaletea.
4,scam,IYO PESA ITUME KWENYE NAMBA HII 0657538690 JIN...


In [5]:
## Let us rename: trust -> ham and scam -> spam.
mapper = {"trust": "ham", "scam": "spam"}
df["Category"] = df["Category"].map(mapper)
df.head()

Unnamed: 0,Category,Sms
0,ham,"Nipigie baada ya saa moja, tafadhali."
1,spam,Naomba unitumie iyo Hela kwenye namba hii ya A...
2,spam,"666,KARIBU FREEMASON UTIMIZE NDOTO KATIKA BIAS..."
3,ham,Watoto wanapenda sana zawadi ulizowaletea.
4,spam,IYO PESA ITUME KWENYE NAMBA HII 0657538690 JIN...


Since the dataset has many rows and I have limited compute, I will take 50 rows from each category. The purpose of this exercise is to see how Gemma3 compares with supervised classification methods.

In [6]:
# Set the number of samples per class
n_per_class = 50

df = (
    df.groupby("Category", group_keys=False)
      .apply(lambda x: x.sample(n=n_per_class, random_state=42))
      .reset_index(drop=True)
)
print(f"Length of the dataframe now: {len(df)}")

Length of the dataframe now: 100




In [7]:
# Load the model

from transformers import pipeline
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
pipe = pipeline("text-generation", model="google/gemma-3-4b-it", device=device, torch_dtype=torch.bfloat16)


Device set to use cuda


In [27]:
# Test with one message
messages = [
    [
        {
            "role": "system",
            "content": [{"type": "text", "text": "You are a helpful assistant that detects spam messages in both Swahili and English. Classify the provided text as SPAM or HAM (not spam) and provide a brief explanation. Respond with a JSON object containing 'classification' (either 'ham' or 'spam') and 'explanation'."},]
        },
        {
            "role": "user",
            "content": [{"type": "text", "text": "HELLO. Ungana na wakenya wengi wanoSHINDA katika PICK A BOX.2024 END YEAR Bonus NI from 50,000. BONYEZA *201# BILA Credo upick BOX YAKO.STOP *456*9*5#"}]
        },
    ],
]

output = pipe(messages, max_new_tokens=200)
res = output[0][0]["generated_text"][-1]["content"]
print(res)

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


```json
{
  "classification": "spam",
  "explanation": "This message exhibits several characteristics of spam: it uses urgent language ('Bonus'), includes a misleading offer (financial reward), directs the recipient to a specific phone number to initiate an action (*201#), and provides a short code to stop the message (*456*9*5#).  The Shinda mention suggests a potential scam targeting Kenyan users."
}
```



Let us now classify the messages as spam or ham using the LLM. We also extract the explanations to see whether they could be usable for explainability purposes.

In [47]:
import json
import re
from sklearn.metrics import recall_score, accuracy_score, precision_score, f1_score

def classify_sms_batch(pipe, messages_batch):
    """Classifies a batch of SMS messages using the given pipeline."""
    prompts = []
    for message in messages_batch:
        prompt = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant that detects spam messages in Swahili. Some of the messages may contain English words and some may be mispelled. Classify the provided text as SPAM or HAM (not spam) and provide a brief explanation. Respond with a JSON object containing 'classification' (either 'ham' or 'spam') and 'explanation'."},]
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": message}]
            },
        ]
        prompts.append(prompt)

    outputs = pipe(prompts, max_new_tokens=200)

    classifications = []
    explanations = []
    for output in outputs:
        try:
            # Extract the last generated text content
            generated_text = output[0]["generated_text"][-1]["content"]
            matched = re.search(r'```json\s*(\{.*?\})\s*```', generated_text, re.DOTALL)
            json_str = matched.group(1)
            # Attempt to parse the JSON response
            json_response = json.loads(json_str)
            classifications.append(json_response.get('classification', 'unknown').lower())
            explanations.append(json_response.get('explanation', 'No explanation provided'))
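        except (json.JSONDecodeError, IndexError, KeyError, AttributeError) as e:
            # AttributeError covers replies where the regex finds no fenced JSON block
            print(f"Error processing model output: {e}. Output: {output}")
            classifications.append('unknown') # Handle cases where parsing fails
            explanations.append('No explanation provided') # Keep both lists aligned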

    return (classifications, explanations)

# Prepare data for batching
sms_messages = df['Sms'].tolist()
ground_truth_labels = df['Category'].tolist()

batch_size = 16 # You can adjust the batch size based on your GPU memory

predicted_labels = []
model_explanations = []

for i in tqdm(range(0, len(sms_messages), batch_size), desc="Classifying SMS batches"):
    batch_messages = sms_messages[i:i + batch_size]
    batch_predictions, batch_explanations = classify_sms_batch(pipe, batch_messages)
    predicted_labels.extend(batch_predictions)
    model_explanations.extend(batch_explanations)

# Ensure predicted_labels and ground_truth_labels have the same length
min_len = min(len(predicted_labels), len(ground_truth_labels))
predicted_labels = predicted_labels[:min_len]
ground_truth_labels = ground_truth_labels[:min_len]
model_explanations = model_explanations[:min_len]

# Filter out 'unknown' predictions if necessary for metrics
valid_indices = [i for i, label in enumerate(predicted_labels) if label in ['ham', 'spam']]
filtered_predicted_labels = [predicted_labels[i] for i in valid_indices]
filtered_ground_truth_labels = [ground_truth_labels[i] for i in valid_indices]
filtered_explanations = [model_explanations[i] for i in valid_indices]

if len(filtered_predicted_labels) > 0:
    # Calculate metrics
    recall = recall_score(filtered_ground_truth_labels, filtered_predicted_labels, pos_label='spam')
    precision = precision_score(filtered_ground_truth_labels, filtered_predicted_labels, pos_label='spam')
    accuracy = accuracy_score(filtered_ground_truth_labels, filtered_predicted_labels)
    f1 = f1_score(filtered_ground_truth_labels, filtered_predicted_labels, pos_label='spam')

    print(f"\n--- Classification Metrics ---")
    print(f"Recall (Spam): {recall:.4f}")
    print(f"Precision (Spam): {precision:.4f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"F1 Score (Spam): {f1:.4f}")
else:
    print("\nNo valid predictions were obtained to calculate metrics.")

Classifying SMS batches: 100%|██████████| 7/7 [13:01<00:00, 111.62s/it]


--- Classification Metrics ---
Recall (Spam): 1.0000
Precision (Spam): 0.7042
Accuracy: 0.7900
F1 Score (Spam): 0.8264
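
The parsing step inside `classify_sms_batch` expects the model to wrap its answer in a fenced json block. As a hedged sketch of a more defensive variant, the hypothetical helper `parse_model_reply` below (not part of the original notebook) falls back to the outermost brace-delimited span when no fence is present, and maps anything other than 'ham' or 'spam' to 'unknown'.

```python
import json
import re

FENCE = "`" * 3  # three backticks, built indirectly to keep this block easy to embed

def parse_model_reply(generated_text):
    """Return (classification, explanation) parsed from one model reply."""
    default = ("unknown", "No explanation provided")
    # First choice: a fenced json block, as in the notebook's regex.
    match = re.search(FENCE + r"json\s*(\{.*?\})\s*" + FENCE, generated_text, re.DOTALL)
    if match is None:
        # Fallback: the outermost brace-delimited span in the reply.
        match = re.search(r"(\{.*\})", generated_text, re.DOTALL)
    if match is None:
        return default
    try:
        payload = json.loads(match.group(1))
    except json.JSONDecodeError:
        return default
    if not isinstance(payload, dict):
        return default
    label = str(payload.get("classification", "unknown")).strip().lower()
    if label not in {"ham", "spam"}:
        label = "unknown"
    return label, payload.get("explanation", "No explanation provided")
```

Returning a pair for every reply keeps the label and explanation lists the same length, whatever the model emits.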





In [49]:
# Add the predicted labels to the dataframe for inspection
df['Predicted_Category'] = predicted_labels
df['Explanation'] = model_explanations
print("\nDataFrame with predictions:")
print(df[['Sms', 'Category', 'Explanation', 'Predicted_Category']].head())
print(df['Predicted_Category'].value_counts())
print(df['Category'].value_counts())


DataFrame with predictions:
                                                 Sms Category  \
0  Bro, kuna movie mpya imeachiwa leo. Je, tutaza...      ham   
1                      Tafadhali nipe maelezo zaidi.      ham   
2          Nitaandika ripoti mara tu nitakapomaliza.      ham   
3  Niambie ukweli, unafikiri Ronaldo bado ana kiw...      ham   
4                        Nisaidie na namba ya fundi.      ham   

                                         Explanation Predicted_Category  
0  The message uses informal Swahili ('Bro') and ...               spam  
1  The message 'Tafadhali nipe maelezo zaidi' tra...               spam  
2  The text translates to 'I will write a report ...                ham  
3  The message asks a leading question about Rona...               spam  
4  The message 'Nisaidie na namba ya fundi' trans...               spam  
Predicted_Category
spam    71
ham     29
Name: count, dtype: int64
Category
ham     50
spam    50
Name: count, dtype: int64
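
The counts above, together with the reported recall of 1.0, pin down the full confusion matrix. The worked arithmetic below is a small sketch of that derivation; the variable names are illustrative, and it assumes no 'unknown' predictions, which the value counts (71 + 29 = 100) support. It reproduces the reported precision, accuracy, and F1 score.

```python
# Confusion counts implied by the outputs above (positive class = spam).
tp, fn = 50, 0   # recall of 1.0 on the 50 true spam messages
fp = 71 - tp     # 71 messages were predicted as spam in total
tn = 29 - fn     # 29 messages were predicted as ham in total

precision = tp / (tp + fp)
recall = tp / (tp + fn)
accuracy = (tp + tn) / (tp + fp + tn + fn)
f1 = 2 * precision * recall / (precision + recall)
# Share of genuine (ham) messages that were wrongly flagged as spam.
false_positive_rate = fp / (fp + tn)

print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"F1: {f1:.4f}")
print(f"False positive rate on ham: {false_positive_rate:.2f}")
```

In other words, 21 of the 50 ham messages were flagged as spam, which is where all of the errors come from.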


The model does not miss any spam messages (recall of 1.0), but it over-flags: 71 messages are predicted as spam although only 50 are, so precision is about 0.70. High recall matters most if we only use the model to explain why a message was marked as spam. Let us draw a random sample of correctly flagged spam messages to see whether the explanations make sense.

In [89]:
spam_messages = df[(df['Category'] == 'spam') & (df['Predicted_Category'] == 'spam')]
spam_messages.head(5)

Unnamed: 0,Category,Sms,Predicted_Category,Explanation
50,spam,"666,KARIBU FREEMASON UTIMIZE NDOTO KATIKA BIAS...",spam,"The message uses phrases like 'FREEMASON,' 'ut..."
51,spam,Au nitumie kwenye M-Pesa Namba.0696530433 jina...,spam,This message uses 'M-Pesa' (a popular mobile m...
52,spam,Mjukuu wangu ndagu niliyokukabizi hiyo uwe mak...,spam,The message contains phrases like'mjukuu wangu...
53,spam,Nitumie tu kwenye hii Tigo 0733822240 jina SAL...,spam,The message explicitly requests contact via a ...
54,spam,Naomba unitumie iyo pesa kwenye namba hii ya A...,spam,The message requests money and provides a mobi...


In [87]:
# Print five randomly chosen spam messages with the model's explanation
for _, row in spam_messages.sample(5).iterrows():
    print("----------------------")
    print(f"SMS: {row['Sms']}")
    print(f"Explanation: {row['Explanation']}")

----------------------
SMS: Iyo pesa itume humu kwenye halotel 0755896103 jina lije PEREGIA FILIPO.
Explanation: The message contains multiple English words ('peregia', 'filipo') which is a common tactic used in spam messages. The reference to a phone number and request for information strongly suggests a scam attempt to lure the recipient into a fraudulent scheme.
----------------------
SMS: MZEE JUMANNE YASINI MASAKA tiba asili biashala kazi masomo utajili kesi kuludisha mke&mume piga (0698018072)(0698018072)
Explanation: The message uses common phrases related to business, education, and family matters ('tiba asili', 'biashala', 'kazi','masomo', 'kesi kuludisha mke&mume'). It also includes phone numbers repeated multiple times, a strong indicator of unsolicited marketing or scam activity. The unusual phrasing 'kesi kuludisha' is also suspicious.
----------------------
SMS: mjukuu wangu utafuta ji wako mgumu Pesa hazikai mkononi pakazinaisha unasota sana mpenzi hamuelewani je utatunz

The model's explanations for spam messages are mostly reasonable, although not every detail holds up (for example, it describes the name 'PEREGIA FILIPO' as English words). If explanations are needed, you could use the BERT Embeddings + Logistic Regression classifier to label a message as 'ham' or 'spam', then use Gemma3 to explain why a flagged message is spam. I wouldn't trust its explanations for a 'ham' message.
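
The two-stage idea above can be sketched as follows. Everything here is hypothetical: `label_and_explain`, `toy_classify`, and `toy_explain` are illustrative names, and the toy callables merely stand in for the supervised classifier and for a Gemma3 call so that the sketch runs without any model loaded.

```python
def label_and_explain(messages, classify, explain):
    """Label every message; request an explanation only when the label is 'spam'."""
    results = []
    for sms in messages:
        label = classify(sms)
        # Explanations for 'ham' are skipped because they are not trusted.
        explanation = explain(sms) if label == "spam" else None
        results.append({"sms": sms, "label": label, "explanation": explanation})
    return results

# Stand-in for the supervised classifier (assumption: a trivial keyword rule).
def toy_classify(sms):
    return "spam" if "pesa" in sms.lower() else "ham"

# Stand-in for the LLM explanation step (assumption: a fixed string).
def toy_explain(sms):
    return "Requests money to be sent to a phone number."

demo = label_and_explain(
    ["Nipigie baada ya saa moja, tafadhali.", "IYO PESA ITUME KWENYE NAMBA HII"],
    toy_classify,
    toy_explain,
)
for item in demo:
    print(item["label"], "|", item["explanation"])
```

Keeping the classifier and the explainer as separate callables means the LLM is only invoked for flagged messages, which also reduces inference time on limited compute.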