# Extracting Counts from news summaries using Gemma 2b model

This notebook contains the code for extracting number of Deaths, Injuries, Surrenders, Arrests, and Hostages from news summaries using Gemma 2b model.

Before running the notebook, follow these steps:






*   Create an account on HuggingFace.
*   Open the link https://huggingface.co/google/gemma-2-2b-it
*   Request the access for the gemma model.
*   After getting the access, create an access token in huggingface at https://huggingface.co/settings/tokens

NOTE: Import the notebook on Colab and turn on the T4 GPU.

TODO:

*   This is just inference with the gemma model with a prompt engineering touch. But there is a lot of room for improvement. We can also fine-tune it on our data to get better results.
*   Since this is computationally expensive than other approaches, it would be better to come up with a single LLM model that can be used for all of the tasks we are working on.



In [1]:
pip -q install -U bitsandbytes

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m122.4/122.4 MB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25h

# Import Libraries

In [2]:
import os
import gc
import time
import torch
import numpy as np
import transformers
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Get Data

In [3]:
# Change the path to the cleaned file.

root_dir = "https://raw.githubusercontent.com/eteitelbaum/code-satp/refs/heads/main/data"
data = pd.read_csv(root_dir + "/satp_clean.csv")
data.head(3)

Unnamed: 0,incident_number,state,district,block,village_name,other_areas,constituency,longitude,latitude,year,...,commander_arrests,cadre_arrests,sympathizer_arrests,unknown_arrests,total_surrenders,commander_surrenders,cadre_surrenders,sympathizer_surrenders,unknown_surrenders,incident_summary
0,101010701.0,Andhra Pradesh,Hyderabad,Gachibowli (Rangareddy),,Cyberabad,Serilingampally,17.4325,78.371806,2007,...,0,0.0,1,0.0,0.0,0,0.0,0,0,An alleged arms supplier to the Communist Part...
1,101010901.0,Andhra Pradesh,Nizamabad,,Kamareddy,,Kamareddy,18.320889,78.337139,2009,...,0,0.0,0,0.0,1.0,0,1.0,0,0,A Kamareddy dalam (squad) member belonging to ...
2,101030601.0,Andhra Pradesh,Khammam,,Bhadrachalam,,Bhadrachalam,17.668056,80.896861,2006,...,1,0.0,0,0.0,0.0,0,0.0,0,0,Senior CPI-Maoist 'Polit Bureau' and 'central ...


In [4]:
counts_columns = ['total_fatalities', 'total_injuries', 'total_surrenders', 'total_arrests', "total_abducted", "incident_summary"]
df = data[counts_columns].copy()
df.head()

Unnamed: 0,total_fatalities,total_injuries,total_surrenders,total_arrests,total_abducted,incident_summary
0,0.0,0.0,0.0,1.0,0,An alleged arms supplier to the Communist Part...
1,0.0,0.0,1.0,0.0,0,A Kamareddy dalam (squad) member belonging to ...
2,0.0,0.0,0.0,1.0,0,Senior CPI-Maoist 'Polit Bureau' and 'central ...
3,1.0,0.0,0.0,0.0,0,A TDP leader and former Sarpanch of Jerrela Gr...
4,0.0,0.0,0.0,0.0,0,The CPI-Maoist cadres blasted coffee pulping u...


# Load Model and Tokenizer

Using quantization to fit more samples into the batch for faster inference.

In [5]:
model_name = "google/gemma-2-2b-it"
access_token = "Add your acces token"

quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16
            )

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", token = access_token)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
    token = access_token
)

tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

In [6]:
# Function to clear unused CPU and GPU memory.

def clear_memory():
    torch.cuda.empty_cache()
    gc.collect()

clear_memory()

# Inference on few samples (Performance Check).

In [7]:
# Simple few shot prompt to extract the counts.

def extract_features(summary):
    prompt = f"""
Given a news summary related to terrorism, extract the below features from news summary and don't give anything else like explainations:

Provide count for below actions if present else 0
Number of Deaths:
Number of Injuries:
Number of Surrenders:
Number of Arrests:
Number of Abductions:

Example Input-Output pair:

Input:
2 CPI-Maoist women were arrested by the police in the Mahbubnagar District, and the police reported that they seized all the weapons that were being used by maoists. Later based on information provided by those two women they also arrested one more women.

Desired Answer:

Number of Deaths: 0
Number of Injuries: 0
Number of Surrenders: 0
Number of Arrests: 3
Number of Abductions: 0

End of Example

Note: These numbers can appear at any part of the summary, don't forget to count all of them.

News Summary: {summary}

Your Answer:
"""

    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=len(inputs[0])+50, num_return_sequences=1)

    response = tokenizer.decode(outputs[0], skip_special_tokens=True)

    return response.split("Your Answer:")[-1].strip()

# Example usage
summary = """
An alleged arms supplier to the Communist Party of India-Maoist (CPI-Maoist), identified as Ravi Kumar Chevori, was arrested from Cyberabad near Hyderabad. He had entered into a deal with the Maoists to supply arms and ammunition worth INR 40 lakh, which the city Police seized on December 28, 2006, and arrested three persons.
"""

extracted_features = extract_features(summary)
print(extracted_features)

Number of Deaths: 0
Number of Injuries: 0
Number of Surrenders: 0
Number of Arrests: 3
Number of Abductions: 0


In [8]:
# Inference on a random sample

for summary in data.sample(n=5)["incident_summary"]:
    print(summary)
    print(extract_features(summary))

In the Chando Police station area of Balrampur Police District, on information of the movement of a large group of armed Maoists, Police cordoned off an area near Jalbotha village and asked them to surrender. The subsequent encounter led to the killing of nine Maoists.
Number of Deaths: 9
Number of Injuries: 0
Number of Surrenders: 0
Number of Arrests: 0
Number of Abductions: 0
A CPI-Maoist cadre, Kandara Darai, surrendered before the Superintendent of Police Sanjaya Kumar Kausal and District Collector Jamil Ahmed Khan at Kamakhyanagar in the Dhenkanal District.
Number of Deaths: 0
Number of Injuries: 0
Number of Surrenders: 1
Number of Arrests: 0
Number of Abductions: 0
A hardcore Maoist was arrested from Kaimur hills under Chutia Police Station area in Rohtas District. The Security Forces arrested Maoist cadre Suresh Paswan during a combing operation. A printer and scanner were seized from the possession of the Maoist, who was said to be the 'sub-zonal commander' of Rohtas and Sonebh

In [9]:
prompt = f"""
        Given a news summary related to terrorism, extract the below features from news summary and don't give anything else like explainations:

        Provide count for below actions if present else 0
        Number of Deaths:
        Number of Injuries:
        Number of Surrenders:
        Number of Arrests:
        Number of Abductions:

        Example Input-Output pair:

        Input:
        2 CPI-Maoist women were arrested by the police in the Mahbubnagar District, and the police reported that they seized all the weapons that were being used by maoists. Later based on information provided by those two women they also arrested one more women.

        Desired Answer:

        Number of Deaths: 0
        Number of Injuries: 0
        Number of Surrenders: 0
        Number of Arrests: 3
        Number of Abductions: 0

        End of Example

        Note: These numbers can appear at any part of the summary, don't forget to count all of them.

        News Summary: {summary}

        Your Answer:
        """

# Inference on Whole Dataset

In [10]:
# To run the inference on the whole data, it might take upto 4 hours on Coab with T4 or 3 hourse on Kaggle with P100.


def extract_features_batch(summaries, batch_size):
    all_outputs = []
    for i in tqdm(range(0, len(summaries), batch_size), desc="Processing batch"):
        clear_memory()
        batch = summaries[i:i+batch_size]
        prompts = [prompt.format(summary) for summary in batch]
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True, max_length=512).to("cuda")

        with torch.no_grad():
            outputs = model.generate(**inputs, max_length = len(inputs[0])+50, num_return_sequences=1)

        decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        all_outputs.extend([output.split("Your Answer:")[-1].strip() for output in decoded_outputs])

    return all_outputs


batch_size = 16 # Lower this, if there are any memory issues.

# Uncomment these lines if you want to run the whole inefrence.
#df['extracted_summary'] = extract_features_batch(df['incident_summary'].tolist(), batch_size)

# Save the results
#output_dir = 'extracted_output-gemma.csv'
#df.to_csv(output_dir, index=False)
#print(f"Extraction complete. Results saved to {output_dir}.")

# Extract summaries from the output summaries

The extracted data can be accessed from the data folder on github.

In [11]:
df = pd.read_csv(root_dir + "/gemma_counts_output.csv")

In [12]:
# Extract counts from the generated outputs, and store it in output list.
# Store the samples that were not generated properly in another list to filter them out.

output = []
bad_samples = []

for i, extract in tqdm(enumerate(df["extracted_summary"])):
    output.append([])
    try:
        split = extract.split('\n')[:5]
    except:
        bad_samples.append(i)
        output[-1] = [-100]*5
        continue

    for x in split:
        try:
            output[-1].append(int(x.split(':')[1]))
        except:
            bad_samples.append(i)
            output[-1].append(-100)

    if len(output[-1]) != 5:
        bad_samples.append(i)
        output[-1] = [-100] * 5

len(bad_samples)

9921it [00:00, 292278.50it/s]


185

In [13]:
preds = pd.DataFrame(output, columns = df.columns[:5], dtype="float")
preds.head()

Unnamed: 0,total_fatalities,total_injuries,total_surrenders,total_arrests,total_abducted
0,0.0,0.0,0.0,3.0,0.0
1,0.0,0.0,1.0,0.0,0.0
2,0.0,0.0,0.0,1.0,0.0
3,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0


In [14]:
preds = preds.drop(bad_samples, axis = 0).reset_index(drop=True)
df = df.drop(bad_samples, axis = 0).reset_index(drop=True)

display(preds.head(5), df.head(5), preds.shape, df.shape)

Unnamed: 0,total_fatalities,total_injuries,total_surrenders,total_arrests,total_abducted
0,0.0,0.0,0.0,3.0,0.0
1,0.0,0.0,1.0,0.0,0.0
2,0.0,0.0,0.0,1.0,0.0
3,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0


Unnamed: 0,total_fatalities,total_injuries,total_surrenders,total_arrests,total_abducted,incident_summary,extracted_summary
0,0.0,0.0,0.0,1.0,0,An alleged arms supplier to the Communist Part...,Number of Deaths: 0\n Number of Injuries: 0...
1,0.0,0.0,1.0,0.0,0,A Kamareddy dalam (squad) member belonging to ...,Number of Deaths: 0\n Number of Injuries: 0...
2,0.0,0.0,0.0,1.0,0,Senior CPI-Maoist 'Polit Bureau' and 'central ...,Number of Deaths: 0\n Number of Injuries: 0...
3,1.0,0.0,0.0,0.0,0,A TDP leader and former Sarpanch of Jerrela Gr...,Number of Deaths: 0\n Number of Injuries: 0...
4,0.0,0.0,0.0,0.0,0,The CPI-Maoist cadres blasted coffee pulping u...,Number of Deaths: 0\n Number of Injuries: 0...


(9851, 5)

(9851, 7)

# Evaluation

In [19]:
# Predicting all the numbers as correct.
print(f"Overall score: {((preds == df.iloc[:, :5]).all(axis = 1)).sum() / preds.shape[0]}")

Overall score: 0.7833722464724393


In [20]:
# Predicting atleast one of the counts as correct.
columns = df.columns[:5]
for i in range(5):
    print(f"Label {columns[i]} Score: {(preds.iloc[:, i] == df.iloc[:, i]).sum()/preds.shape[0]}")

Label total_fatalities Score: 0.9509694447264238
Label total_injuries Score: 0.9378743274794437
Label total_surrenders Score: 0.9245761851588671
Label total_arrests Score: 0.956248096639935
Label total_abducted Score: 0.9692417013501168
