<a href="https://colab.research.google.com/github/mathunjoroge/icd/blob/master/Kenya_Legal_AI_Full_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install required libraries with specific versions
!pip install -q --upgrade pip
!pip install -q --upgrade git+https://github.com/unslothai/unsloth.git@main
!pip install -q --upgrade transformers trl peft accelerate datasets bitsandbytes

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.8/1.8 MB[0m [31m60.2 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m40.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for unsloth (pyproject.toml) ... [?25l[?25hdone
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cudf-cu12 25.6.0 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == "x86_64", but you have pyarrow 22.0.0 which is incompatible.
pylibcudf-cu12 25.6.0 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == "x86_64", 

In [None]:
# Install required libraries with specific versions
!pip install -q --upgrade pip
!pip install -q --upgrade git+https://github.com/unslothai/unsloth.git@main
!pip install -q --upgrade transformers trl peft accelerate datasets bitsandbytes

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone


In [20]:
import json
import logging
import os
import re
import sys
import time
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Set, Dict, Optional
from urllib.parse import urljoin
import io

import requests
from bs4 import BeautifulSoup
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.common.exceptions import TimeoutException, NoSuchElementException, WebDriverException
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager

# Try to import PDF libraries, but make them optional
try:
    import PyPDF2
    PDF_SUPPORT = True
except ImportError:
    PDF_SUPPORT = False
    print("Warning: PyPDF2 not installed. PDF content extraction will be limited.")

try:
    import pdfplumber
    PDFPLUMBER_SUPPORT = True
except ImportError:
    PDFPLUMBER_SUPPORT = False


# --------------------------------------------------------------------------- #
#                               CONFIGURATION                                 #
# --------------------------------------------------------------------------- #

@dataclass
class Config:
    BASE_PROJECT_DIR: str = os.path.join(os.path.expanduser("~"), "projects", "kenya_law")

    LOG_DIR: str = field(init=False)
    DATA_DIR: str = field(init=False)
    LOG_FILE: str = field(init=False)
    DATA_FILE: str = field(init=False)
    DEBUG_DIR: str = field(init=False)
    CONSTITUTION_FILE: str = field(init=False)
    ACTS_FILE: str = field(init=False)
    SUBSIDIARY_FILE: str = field(init=False)
    COUNTIES_FILE: str = field(init=False)
    PDF_DIR: str = field(init=False)  # NEW: Directory to store downloaded PDFs

    MAX_CASES: int = None
    MAX_PAGES: int = None
    MAX_COUNTY_LAWS: int = 50
    REQUEST_TIMEOUT: int = 30
    SELENIUM_TIMEOUT: int = 45
    YEAR_START: int = 2020

    BASE_URL: str = "https://kenyalaw.org"
    NEW_BASE_URL: str = "https://new.kenyalaw.org"
    SEARCH_URL: str = "https://new.kenyalaw.org/search/"
    JUDGMENTS_URL: str = "https://new.kenyalaw.org/judgments/"
    COUNTIES_URL: str = "https://new.kenyalaw.org/legislation/counties"
    ACTS_TOC_URL: str = "https://new.kenyalaw.org/legislation/"

    LOCAL_CHROMEDRIVER_PATH: Optional[str] = None
    CHROME_HEADLESS: bool = True

    KEYWORDS: List[str] = field(default_factory=lambda: [
        "constitution", "human rights", "land", "election", "criminal", "civil",
        "jurisdiction", "appeal", "judicial review", "injunction", "contract"
    ])

    MAX_SCRAPE_WORKERS: int = 10
    ENABLE_PDF_EXTRACTION: bool = True  # NEW: Control PDF processing

    def __post_init__(self) -> None:
        self.LOG_DIR = os.path.join(self.BASE_PROJECT_DIR, "logs")
        self.DATA_DIR = os.path.join(self.BASE_PROJECT_DIR, "data")
        self.DEBUG_DIR = os.path.join(self.BASE_PROJECT_DIR, "debug")
        self.PDF_DIR = os.path.join(self.DATA_DIR, "pdfs")  # NEW
        self.LOG_FILE = os.path.join(self.LOG_DIR, f"kenyalaw_scraper_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
        self.DATA_FILE = os.path.join(self.DATA_DIR, "kenya_law_training_data.jsonl")
        self.CONSTITUTION_FILE = os.path.join(self.DATA_DIR, "constitution.json")
        self.ACTS_FILE = os.path.join(self.DATA_DIR, "acts_of_kenya.json")
        self.SUBSIDIARY_FILE = os.path.join(self.DATA_DIR, "subsidiary_legislation.json")
        self.COUNTIES_FILE = os.path.join(self.DATA_DIR, "county_legislation.json")

        for d in [self.LOG_DIR, self.DATA_DIR, self.DEBUG_DIR, self.PDF_DIR]:
            os.makedirs(d, exist_ok=True)


# --------------------------------------------------------------------------- #
#                                 LOGGING                                    #
# --------------------------------------------------------------------------- #

def setup_logging(log_file: str) -> logging.Logger:
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    logger = logging.getLogger("KenyaLaw-Scraper-v6.0-FULL")
    logger.setLevel(logging.INFO)
    logger.handlers.clear()
    fh = logging.FileHandler(log_file, encoding="utf-8")
    ch = logging.StreamHandler()
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    fh.setFormatter(formatter)
    ch.setFormatter(formatter)
    logger.addHandler(fh)
    logger.addHandler(ch)
    return logger


# --------------------------------------------------------------------------- #
#                               DATA HANDLER                                 #
# --------------------------------------------------------------------------- #

class DataHandler:
    def __init__(self, cfg: Config, log: logging.Logger):
        self.cfg = cfg
        self.log = log
        self.lock = threading.Lock()

    def load_existing_case_ids(self) -> Set[str]:
        if not os.path.exists(self.cfg.DATA_FILE):
            return set()
        ids = set()
        try:
            with open(self.cfg.DATA_FILE, "r", encoding="utf-8") as f:
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if not line: continue
                    try:
                        data = json.loads(line)
                        if "case_id" in data:
                            ids.add(data["case_id"])
                    except json.JSONDecodeError:
                        self.log.warning(f"Bad JSON at line {line_num}")
            self.log.info(f"Loaded {len(ids)} existing case IDs")
        except Exception as e:
            self.log.error(f"Failed to load IDs: {e}")
        return ids

    def save_case(self, case: Dict) -> bool:
        try:
            json_line = json.dumps(case, ensure_ascii=False)
            with self.lock:
                with open(self.cfg.DATA_FILE, "a", encoding="utf-8") as f:
                    f.write(json_line + "\n")
            self.log.info(f"Saved case {case['case_id']} ({len(case['text'].split())} words)")
            return True
        except Exception as e:
            self.log.error(f"Save failed: {e}")
            return False


# --------------------------------------------------------------------------- #
#                              PDF HANDLING                                  #
# --------------------------------------------------------------------------- #

class PDFHandler:
    def __init__(self, cfg: Config, log: logging.Logger):
        self.cfg = cfg
        self.log = log

    def extract_text_from_pdf(self, pdf_url: str, pdf_content: bytes) -> Optional[str]:
        """Extract text from PDF content using multiple methods"""
        if not self.cfg.ENABLE_PDF_EXTRACTION:
            return "PDF_CONTENT_AVAILABLE_BUT_EXTRACTION_DISABLED"

        # Save PDF file for reference
        pdf_filename = self._save_pdf_file(pdf_url, pdf_content)

        extracted_text = None

        # Try pdfplumber first (better for scanned PDFs with OCR)
        if PDFPLUMBER_SUPPORT:
            extracted_text = self._extract_with_pdfplumber(pdf_content)

        # Fallback to PyPDF2
        if not extracted_text and PDF_SUPPORT:
            extracted_text = self._extract_with_pypdf2(pdf_content)

        # Final fallback
        if not extracted_text:
            extracted_text = self._extract_fallback(pdf_content, pdf_filename)

        return extracted_text

    def _save_pdf_file(self, pdf_url: str, pdf_content: bytes) -> str:
        """Save PDF file to disk for reference"""
        try:
            # Create filename from URL
            filename = re.sub(r'[^a-zA-Z0-9]', '_', pdf_url) + '.pdf'
            filepath = os.path.join(self.cfg.PDF_DIR, filename)

            with open(filepath, 'wb') as f:
                f.write(pdf_content)

            return filename
        except Exception as e:
            self.log.warning(f"Failed to save PDF file: {e}")
            return "unknown.pdf"

    def _extract_with_pdfplumber(self, pdf_content: bytes) -> Optional[str]:
        """Extract text using pdfplumber (better for scanned PDFs)"""
        try:
            text_parts = []
            with pdfplumber.open(io.BytesIO(pdf_content)) as pdf:
                for page in pdf.pages:
                    page_text = page.extract_text()
                    if page_text:
                        text_parts.append(page_text.strip())

            if text_parts:
                full_text = '\n'.join(text_parts)
                if len(full_text.split()) > 10:  # Ensure we have substantial text
                    return full_text
        except Exception as e:
            self.log.debug(f"pdfplumber extraction failed: {e}")

        return None

    def _extract_with_pypdf2(self, pdf_content: bytes) -> Optional[str]:
        """Extract text using PyPDF2"""
        try:
            text_parts = []
            pdf_file = io.BytesIO(pdf_content)
            reader = PyPDF2.PdfReader(pdf_file)

            for page in reader.pages:
                page_text = page.extract_text()
                if page_text:
                    text_parts.append(page_text.strip())

            if text_parts:
                full_text = '\n'.join(text_parts)
                if len(full_text.split()) > 10:
                    return full_text
        except Exception as e:
            self.log.debug(f"PyPDF2 extraction failed: {e}")

        return None

    def _extract_fallback(self, pdf_content: bytes, pdf_filename: str) -> str:
        """Fallback method when PDF text extraction fails"""
        self.log.warning(f"PDF text extraction failed for {pdf_filename}. Content saved to disk.")
        return f"PDF_CONTENT_UNABLE_TO_EXTRACT_TEXT_SAVED_AS_{pdf_filename}"

    def is_pdf_url(self, url: str) -> bool:
        """Check if URL points to a PDF document"""
        return url.lower().endswith('.pdf') or '/pdf/' in url.lower()

    def download_pdf(self, session: requests.Session, pdf_url: str) -> Optional[bytes]:
        """Download PDF content"""
        try:
            headers = {
                "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36",
                "Accept": "application/pdf, */*",
                "Referer": "https://new.kenyalaw.org/"
            }

            response = session.get(pdf_url, timeout=30, headers=headers)
            response.raise_for_status()

            content_type = response.headers.get('content-type', '').lower()
            if 'pdf' in content_type or response.content[:4] == b'%PDF':
                return response.content
            else:
                self.log.warning(f"URL {pdf_url} doesn't contain PDF data")
                return None

        except Exception as e:
            self.log.error(f"Failed to download PDF from {pdf_url}: {e}")
            return None


# --------------------------------------------------------------------------- #
#                      STATIC CONTENT SCRAPERS (Constitution & Acts)          #
# --------------------------------------------------------------------------- #

def save_constitution_data(cfg: Config, log: logging.Logger, data: Dict[str, str]) -> None:
    try:
        with open(cfg.CONSTITUTION_FILE, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        log.info(f"Constitution saved → {cfg.CONSTITUTION_FILE}")
    except Exception as e:
        log.error(f"Failed to save constitution: {e}")

def scrape_constitution(cfg: Config, log: logging.Logger) -> None:
    log.info("Scraping Constitution of Kenya (2010)...")
    if os.path.exists(cfg.CONSTITUTION_FILE):
        log.info(f"Constitution already exists → {cfg.CONSTITUTION_FILE}. Skipping.")
        return

    # Try multiple potential sources
    SOURCES = [
        "https://new.kenyalaw.org/akn/ke/act/2010/constitution/eng@2010-09-03",
        "https://new.kenyalaw.org/akn/ke/act/2010/constitution",
        "https://kenyalaw.org/kl/index.php?id=398"
    ]

    session = requests.Session()
    session.headers.update({
        "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36",
        "Accept": "application/xml, text/xml, text/html, */*",
        "Referer": "https://new.kenyalaw.org/",
        "Accept-Encoding": "gzip, deflate, br",
    })

    for source_url in SOURCES:
        try:
            log.info(f"Trying source: {source_url}")
            resp = session.get(source_url, timeout=60)
            resp.raise_for_status()

            content_type = resp.headers.get('content-type', '').lower()

            if 'xml' in content_type:
                # Parse as XML
                soup = BeautifulSoup(resp.content, "xml")
                body = soup.find("body")
                if not body:
                    log.warning(f"No <body> found in XML from {source_url}")
                    continue

                data = {}
                current_title = "Preamble"
                current_lines = []

                # Extract all relevant elements
                for elem in body.find_all(['heading', 'num', 'p', 'chapter', 'part', 'section', 'article']):
                    if elem.name in ['heading', 'chapter', 'part']:
                        # Save previous section
                        if current_lines:
                            data[current_title] = "\n".join(current_lines).strip()
                            current_lines = []
                        # Start new section
                        current_title = elem.get_text(strip=True)
                        if not current_title:
                            current_title = "Untitled Section"

                    elif elem.name == 'p' and elem.get_text(strip=True):
                        text = elem.get_text(strip=True)
                        if len(text) > 10:  # Filter out very short paragraphs
                            current_lines.append(text)

                # Save the last section
                if current_lines:
                    data[current_title] = "\n".join(current_lines).strip()

            else:
                # Parse as HTML
                soup = BeautifulSoup(resp.content, "html.parser")

                # Try different content selectors
                content_selectors = [
                    "div.act-content",
                    "div.content",
                    "article",
                    "main",
                    "div.container",
                    "#content"
                ]

                content = None
                for selector in content_selectors:
                    content = soup.select_one(selector)
                    if content:
                        break

                if not content:
                    log.warning(f"No content found with selectors in {source_url}")
                    continue

                # Clean up the content
                for element in content.select("script, style, nav, header, footer, .nav, .header, .footer"):
                    element.decompose()

                # Extract text and structure
                data = {}
                current_section = "Constitution of Kenya"
                sections = [current_section]
                text_lines = []

                for element in content.find_all(['h1', 'h2', 'h3', 'h4', 'p', 'div']):
                    text = element.get_text(strip=True)
                    if not text:
                        continue

                    # Detect section headers
                    if element.name in ['h1', 'h2', 'h3', 'h4']:
                        if text_lines:  # Save previous section
                            data[current_section] = "\n".join(text_lines).strip()
                            text_lines = []
                        current_section = text
                        sections.append(current_section)
                    elif len(text) > 20:  # Substantial content
                        text_lines.append(text)

                # Save the last section
                if text_lines:
                    data[current_section] = "\n".join(text_lines).strip()

            # Filter out empty or very short sections
            filtered_data = {}
            for title, content in data.items():
                if content and len(content.split()) >= 10:
                    clean_content = re.sub(r'\s+', ' ', content).strip()
                    filtered_data[title] = clean_content

            # If we have substantial content, save it
            if filtered_data and sum(len(c.split()) for c in filtered_data.values()) > 500:
                save_constitution_data(cfg, log, filtered_data)
                total_words = sum(len(v.split()) for v in filtered_data.values())
                log.info(f"SUCCESS: Constitution scraped → {len(filtered_data)} sections, {total_words:,} words")
                return
            else:
                log.warning(f"Insufficient content from {source_url}")

        except Exception as e:
            log.warning(f"Source {source_url} failed: {e}")
            continue

    # Fallback: manual structure if all sources fail
    log.warning("All automated sources failed. Creating placeholder structure.")
    fallback_data = {
        "Preamble": "We, the people of Kenya—ACKNOWLEDGING the supremacy of the Almighty God of all creation...",
        "Chapter One - Sovereignty of the People": "1. (1) All sovereign power belongs to the people of Kenya...",
        "Chapter Two - The Republic": "4. (1) Kenya is a sovereign Republic. (2) The Republic of Kenya shall be a multi-party democratic state...",
        "Note": "This is a placeholder. The actual constitution text could not be scraped automatically. Consider manual entry."
    }
    save_constitution_data(cfg, log, fallback_data)
    log.info("Created fallback constitution structure")

def save_acts_data(cfg: Config, log: logging.Logger, acts: Dict[str, str], subs: Dict[str, str]) -> None:
    with open(cfg.ACTS_FILE, "w", encoding="utf-8") as f:
        json.dump(acts, f, ensure_ascii=False, indent=2)
    log.info(f"Saved {len(acts)} Acts → {cfg.ACTS_FILE}")

    with open(cfg.SUBSIDIARY_FILE, "w", encoding="utf-8") as f:
        json.dump(subs, f, ensure_ascii=False, indent=2)
    log.info(f"Saved {len(subs)} Subsidiary Laws → {cfg.SUBSIDIARY_FILE}")

def save_counties_data(cfg: Config, log: logging.Logger, counties_data: Dict[str, Dict]) -> None:
    try:
        with open(cfg.COUNTIES_FILE, "w", encoding="utf-8") as f:
            json.dump(counties_data, f, ensure_ascii=False, indent=2)
        log.info(f"County legislation saved → {cfg.COUNTIES_FILE}")
    except Exception as e:
        log.error(f"Failed to save county legislation: {e}")


def scrape_acts_of_kenya(cfg: Config, log: logging.Logger) -> None:
    log.info("Scraping ALL Acts + Subsidiary Legislation (NEW SITE)...")
    if os.path.exists(cfg.ACTS_FILE) and os.path.exists(cfg.SUBSIDIARY_FILE):
        log.info("Acts & Subsidiary files exist. Skipping.")
        return

    session = requests.Session()
    session.headers.update({
        "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36",
        "Accept": "text/html",
        "Referer": cfg.ACTS_TOC_URL
    })

    acts_data = {}
    subsidiary_data = {}
    total_acts = 0
    total_subs = 0

    try:
        resp = session.get(cfg.ACTS_TOC_URL, timeout=cfg.REQUEST_TIMEOUT)
        resp.raise_for_status()
        soup = BeautifulSoup(resp.text, "lxml")

        main_rows = soup.select('tr.has-children')
        log.info(f"Found {len(main_rows)} parent Acts")

        for row in main_rows:
            btn = row.select_one('button[data-bs-toggle="collapse"]')
            if not btn:
                continue
            target = btn.get("data-bs-target", "").lstrip("#")
            if not target:
                continue

            link = row.select_one('td.cell-title a')
            if not link:
                continue
            title = link.get_text(strip=True)
            href = link.get("href")
            url = urljoin(cfg.ACTS_TOC_URL, href)
            citation = row.select_one('td.cell-citation')
            cap = citation.get_text(strip=True) if citation else ""

            key = f"[{cap}] {title}".strip()
            if key in acts_data:
                continue

            # Scrape main Act
            try:
                time.sleep(0.8)
                r = session.get(url, timeout=cfg.REQUEST_TIMEOUT)
                r.raise_for_status()
                s = BeautifulSoup(r.text, "lxml")
                content = s.select_one("div.act-content, article, main")
                if content:
                    for el in content.select("script, style, nav, header, footer, .act-tools"):
                        el.decompose()
                    text = content.get_text(separator="\n", strip=True)
                    text = re.sub(r'\n{3,}', '\n\n', text)
                    text = re.sub(r'\s+', ' ', text).strip()
                    if len(text.split()) > 200:
                        acts_data[key] = text
                        total_acts += 1
                        log.info(f"  Success: Act: {key}")
            except Exception as e:
                log.error(f"Failed Act {key}: {e}")

            # Scrape subsidiary
            tbody = soup.find("tbody", id=target)
            if not tbody:
                continue

            for sub_row in tbody.select('tr'):
                sub_link = sub_row.select_one('td.cell-title a')
                if not sub_link:
                    continue
                sub_title = sub_link.get_text(strip=True)
                sub_href = sub_link.get("href")
                sub_url = urljoin(cfg.ACTS_TOC_URL, sub_href)
                sub_cite = sub_row.select_one('td.cell-citation')
                sub_cite_text = sub_cite.get_text(strip=True) if sub_cite else ""

                sub_key = f"{sub_cite_text} {sub_title}".strip()
                full_key = f"[{cap}] {title} → {sub_key}"

                try:
                    time.sleep(0.8)
                    r = session.get(sub_url, timeout=cfg.REQUEST_TIMEOUT)
                    r.raise_for_status()
                    s = BeautifulSoup(r.text, "lxml")
                    content = s.select_one("div.act-content, article")
                    if content:
                        for el in content.select("script, style, nav, header, footer"):
                            el.decompose()
                        text = content.get_text(separator="\n", strip=True)
                        text = re.sub(r'\s+', ' ', text).strip()
                        if len(text.split()) > 100:
                            subsidiary_data[full_key] = text
                            total_subs += 1
                            log.info(f"    Success: Subsidiary: {sub_key}")
                except Exception as e:
                    log.error(f"Failed subsidiary {sub_key}: {e}")

        save_acts_data(cfg, log, acts_data, subsidiary_data)
        log.info(f"SUCCESS: {total_acts} Acts + {total_subs} Subsidiary Laws saved!")

    except Exception as e:
        log.error(f"Acts scrape failed: {e}", exc_info=True)


# --------------------------------------------------------------------------- #
#                           COUNTY LEGISLATION SCRAPER                        #
# --------------------------------------------------------------------------- #

def scrape_county_legislation(cfg: Config, log: logging.Logger) -> None:
    """Scrape county legislation from all 47 counties"""
    log.info("Scraping County Legislation from all 47 counties...")
    if os.path.exists(cfg.COUNTIES_FILE):
        log.info(f"County legislation already exists → {cfg.COUNTIES_FILE}. Skipping.")
        return

    session = requests.Session()
    pdf_handler = PDFHandler(cfg, log)  # NEW: PDF handler

    session.headers.update({
        "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36",
        "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
        "Accept-Language": "en-US,en;q=0.5",
        "Accept-Encoding": "gzip, deflate, br",
        "Referer": cfg.NEW_BASE_URL
    })

    counties_data = {}
    total_county_laws = 0

    try:
        # Get the main counties page
        log.info(f"Accessing counties page: {cfg.COUNTIES_URL}")
        resp = session.get(cfg.COUNTIES_URL, timeout=cfg.REQUEST_TIMEOUT)
        resp.raise_for_status()
        soup = BeautifulSoup(resp.text, "lxml")

        # Extract all county links
        county_links = []
        flow_columns = soup.select('.flow-columns-group')

        for column in flow_columns:
            links = column.select('a[href^="/legislation/ke-"]')
            for link in links:
                county_name = link.get_text(strip=True)
                county_url = urljoin(cfg.NEW_BASE_URL, link.get('href'))
                county_links.append((county_name, county_url))

        log.info(f"Found {len(county_links)} counties to process")

        # Process each county
        for county_name, county_url in county_links:
            try:
                log.info(f"Processing county: {county_name}")
                time.sleep(1)  # Be respectful

                # Get county page
                county_resp = session.get(county_url, timeout=cfg.REQUEST_TIMEOUT)
                county_resp.raise_for_status()
                county_soup = BeautifulSoup(county_resp.text, "lxml")

                # Extract county laws
                county_laws = {}

                # Look for laws in tables or lists
                law_elements = county_soup.select('tr.has-children, .legislation-item, .law-item')

                if not law_elements:
                    # Try alternative selectors
                    law_elements = county_soup.select('a[href*="/akn/ke/act/"]')

                laws_processed = 0

                for law_element in law_elements:
                    if laws_processed >= cfg.MAX_COUNTY_LAWS:
                        break

                    try:
                        # Extract law link and title
                        if law_element.name == 'tr':
                            link_elem = law_element.select_one('td.cell-title a')
                        else:
                            link_elem = law_element

                        if not link_elem or not link_elem.get('href'):
                            continue

                        law_title = link_elem.get_text(strip=True)
                        law_url = urljoin(cfg.NEW_BASE_URL, link_elem.get('href'))

                        # Skip if it's not a direct law link
                        if '/akn/ke/act/' not in law_url:
                            continue

                        # Scrape the actual law content (now with PDF support)
                        time.sleep(0.5)
                        law_content = scrape_county_law_content(session, pdf_handler, law_url, log)

                        if law_content and len(law_content.split()) > 50:  # Reduced threshold for PDFs
                            county_laws[law_title] = {
                                'url': law_url,
                                'content': law_content,
                                'word_count': len(law_content.split()),
                                'content_type': 'pdf' if pdf_handler.is_pdf_url(law_url) else 'html'
                            }
                            laws_processed += 1
                            total_county_laws += 1
                            log.info(f"    ✓ County law: {law_title} ({len(law_content.split())} words) [{county_laws[law_title]['content_type'].upper()}]")

                    except Exception as e:
                        log.warning(f"Failed to process county law in {county_name}: {e}")
                        continue

                # Add county data
                if county_laws:
                    counties_data[county_name] = {
                        'county_url': county_url,
                        'laws': county_laws,
                        'total_laws': len(county_laws),
                        'scraped_at': datetime.now().isoformat()
                    }
                    log.info(f"  ✓ {county_name}: {len(county_laws)} laws")

            except Exception as e:
                log.error(f"Failed to process county {county_name}: {e}")
                continue

        # Save counties data
        save_counties_data(cfg, log, counties_data)
        log.info(f"SUCCESS: County legislation scraped → {len(counties_data)} counties, {total_county_laws} total laws")

    except Exception as e:
        log.error(f"County legislation scrape failed: {e}", exc_info=True)

def scrape_county_law_content(session: requests.Session, pdf_handler: PDFHandler, law_url: str, log: logging.Logger) -> Optional[str]:
    """Scrape content of an individual county law with PDF support"""

    # Check if it's a PDF URL
    if pdf_handler.is_pdf_url(law_url):
        log.info(f"  Detected PDF document: {law_url}")
        pdf_content = pdf_handler.download_pdf(session, law_url)
        if pdf_content:
            return pdf_handler.extract_text_from_pdf(law_url, pdf_content)
        else:
            return None

    # Handle HTML content
    try:
        resp = session.get(law_url, timeout=30)
        resp.raise_for_status()
        soup = BeautifulSoup(resp.text, "lxml")

        # Check if the page contains a PDF link
        pdf_links = soup.select('a[href$=".pdf"], a[href*="/pdf/"]')
        for pdf_link in pdf_links:
            pdf_url = urljoin(law_url, pdf_link.get('href'))
            log.info(f"  Found embedded PDF link: {pdf_url}")
            pdf_content = pdf_handler.download_pdf(session, pdf_url)
            if pdf_content:
                return pdf_handler.extract_text_from_pdf(pdf_url, pdf_content)

        # Try multiple content selectors for HTML laws
        content_selectors = [
            "div.act-content",
            "div.fr-view",
            "div.content",
            "article",
            "main",
            ".law-content",
            ".document-content"
        ]

        content = None
        for selector in content_selectors:
            content = soup.select_one(selector)
            if content:
                break

        if not content:
            # Fallback: get body content
            content = soup.find('main') or soup.find('article') or soup.find('body')

        if content:
            # Clean up content
            for element in content.select("script, style, nav, header, footer, .nav, .header, .footer, .tools"):
                element.decompose()

            # Extract text
            text = content.get_text(separator="\n", strip=True)
            text = re.sub(r'\n{3,}', '\n\n', text)
            text = re.sub(r'\s+', ' ', text).strip()

            return text if len(text.split()) > 50 else None

    except Exception as e:
        log.warning(f"Failed to scrape county law content from {law_url}: {e}")

    return None


# --------------------------------------------------------------------------- #
#                                 CASE LAW SCRAPER                            #
# --------------------------------------------------------------------------- #

class KenyaLawScraper:
    def __init__(self, cfg: Config, log: logging.Logger):
        self.cfg = cfg
        self.log = log
        self.keywords = {k.lower() for k in cfg.KEYWORDS}
        self.seen_case_ids = set()
        self.driver = None
        self.session = self._create_session()

    def _create_session(self):
        s = requests.Session()
        retry = Retry(total=3, backoff_factor=1, status_forcelist=[429, 500, 502, 503, 504])
        adapter = HTTPAdapter(max_retries=retry)
        s.mount("http://", adapter)
        s.mount("https://", adapter)
        s.headers.update({
            "User-Agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36",
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8",
            "Accept-Language": "en-US,en;q=0.9",
            "Accept-Encoding": "gzip, deflate, br"
        })
        return s

    def _setup_driver(self):
        options = Options()
        if self.cfg.CHROME_HEADLESS:
            options.add_argument("--headless")
        options.add_argument("--no-sandbox")
        options.add_argument("--disable-dev-shm-usage")
        options.add_argument("--window-size=1920,1080")
        options.add_argument("--disable-blink-features=AutomationControlled")
        options.add_experimental_option("excludeSwitches", ["enable-automation"])
        options.add_experimental_option('useAutomationExtension', False)

        try:
            service = Service(ChromeDriverManager().install())
            driver = webdriver.Chrome(service=service, options=options)
        except Exception as e:
            self.log.warning(f"ChromeDriverManager failed: {e}, trying direct Chrome")
            driver = webdriver.Chrome(options=options)

        driver.execute_script("Object.defineProperty(navigator, 'webdriver', {get: () => undefined})")
        driver.execute_cdp_cmd('Network.setUserAgentOverride', {
            "userAgent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36"
        })
        return driver

    def _case_id_from_url(self, url: str) -> Optional[str]:
        # Extract case ID from new URL format: /akn/ke/judgment/kehc/2025/15858/eng@2025-11-07
        patterns = [
            r"/akn/ke/judgment/[^/]+/(\d+)/(?:eng@|\d+)",
            r"/judgments/view/(\d+)",
            r"/caselaw/cases/view/(\d+)"
        ]
        for pattern in patterns:
            m = re.search(pattern, url)
            if m:
                return m.group(1)
        return None

    def fetch_case_urls_selenium(self) -> List[str]:
        self.log.info("Collecting case URLs from new Kenya Law site...")
        urls = set()

        try:
            self.driver = self._setup_driver()
            self.driver.get(self.cfg.JUDGMENTS_URL)
            time.sleep(5)

            # Wait for page to load
            WebDriverWait(self.driver, 10).until(
                EC.presence_of_element_located((By.TAG_NAME, "table"))
            )

            # Get initial page URLs
            soup = BeautifulSoup(self.driver.page_source, "lxml")
            self._extract_urls_from_page(soup, urls)

            # Handle pagination
            page_count = 0
            while page_count < (self.cfg.MAX_PAGES or 10):  # Limit pages if specified
                try:
                    # Look for next button
                    next_buttons = self.driver.find_elements(By.XPATH,
                        "//a[contains(text(), 'Next') or contains(@class, 'next') or contains(@aria-label, 'next')]")

                    if not next_buttons:
                        break

                    next_btn = next_buttons[0]
                    if "disabled" in next_btn.get_attribute("class") or not next_btn.is_enabled():
                        break

                    # Click next page
                    self.driver.execute_script("arguments[0].click();", next_btn)
                    time.sleep(4)

                    # Wait for new content to load
                    WebDriverWait(self.driver, 10).until(
                        EC.presence_of_element_located((By.TAG_NAME, "table"))
                    )

                    # Extract URLs from new page
                    soup = BeautifulSoup(self.driver.page_source, "lxml")
                    self._extract_urls_from_page(soup, urls)

                    page_count += 1
                    self.log.info(f"Processed page {page_count}, total URLs: {len(urls)}")

                except (TimeoutException, NoSuchElementException) as e:
                    self.log.info("No more pages or pagination failed")
                    break

        except Exception as e:
            self.log.error(f"Error collecting URLs: {e}")
        finally:
            if self.driver:
                self.driver.quit()

        final = list(urls)
        if self.cfg.MAX_CASES:
            final = final[:self.cfg.MAX_CASES]
        self.log.info(f"Collected {len(final)} case URLs")
        return final

    def _extract_urls_from_page(self, soup: BeautifulSoup, urls: set) -> None:
        """Extract case URLs from a page"""
        # Look for case links in tables
        table_links = soup.find_all("a", href=re.compile(r"/akn/ke/judgment/"))
        for link in table_links:
            href = link.get("href")
            if href:
                full_url = urljoin(self.cfg.NEW_BASE_URL, href)
                urls.add(full_url)

        # Also check for any judgment links
        judgment_links = soup.select('a[href*="/judgment/"]')
        for link in judgment_links:
            href = link.get("href")
            if href and "/akn/ke/judgment/" in href:
                full_url = urljoin(self.cfg.NEW_BASE_URL, href)
                urls.add(full_url)

    def scrape_one_case(self, url: str) -> Optional[Dict]:
        case_id = self._case_id_from_url(url)
        if not case_id or case_id in self.seen_case_ids:
            return None

        try:
            self.log.info(f"Scraping case: {url}")
            resp = self.session.get(url, timeout=self.cfg.REQUEST_TIMEOUT)
            resp.raise_for_status()

            soup = BeautifulSoup(resp.text, "lxml")

            # Extract case title
            title_elem = soup.find("h1") or soup.find("title")
            case_name = title_elem.get_text(strip=True) if title_elem else "Unknown Case"

            # Extract case content - try multiple selectors for new site
            content_selectors = [
                "div.fr-view",  # Rich text content
                "div.content",
                "article",
                "main",
                ".judgment-content",
                ".case-content"
            ]

            content = None
            for selector in content_selectors:
                content = soup.select_one(selector)
                if content:
                    break

            if not content:
                # Fallback: get main content area
                content = soup.find("main") or soup.find("article") or soup.find("div", class_=re.compile("content"))

            if not content:
                self.log.warning(f"No content found for case {case_id}")
                return None

            # Clean up content
            for element in content.select("script, style, nav, header, footer, .nav, .header, .footer, .tools, .act-tools"):
                element.decompose()

            # Extract text
            text = content.get_text(separator="\n", strip=True)
            text = re.sub(r'\n{3,}', '\n\n', text)
            text = re.sub(r'\s+', ' ', text).strip()

            if len(text.split()) < 100:
                self.log.warning(f"Case {case_id} has insufficient text: {len(text.split())} words")
                return None

            # Extract metadata
            metadata = self._extract_case_metadata(soup)

            data = {
                "case_id": case_id,
                "case_name": case_name,
                "url": url,
                "text": text,
                "text_length_words": len(text.split()),
                "scraped_at": datetime.now().isoformat(),
                "metadata": metadata
            }

            self.seen_case_ids.add(case_id)
            return data

        except Exception as e:
            self.log.error(f"Case failed {url}: {e}")
            return None

    def _extract_case_metadata(self, soup: BeautifulSoup) -> Dict[str, str]:
        """Extract case metadata from the page"""
        metadata = {}

        try:
            # Look for common metadata patterns
            meta_selectors = {
                "court": ["span.court", "div.court", "td.cell-court"],
                "date": ["span.date", "div.date", "td.cell-date", "time"],
                "case_number": ["span.case-number", "div.case-number", "td.cell-case-number"],
                "judges": ["span.judges", "div.judges", "p.judges"],
                "citation": ["span.citation", "div.citation", "td.cell-citation"]
            }

            for key, selectors in meta_selectors.items():
                for selector in selectors:
                    element = soup.select_one(selector)
                    if element:
                        metadata[key] = element.get_text(strip=True)
                        break

        except Exception as e:
            self.log.debug(f"Metadata extraction failed: {e}")

        return metadata

    def run_case_scrape(self, urls: List[str], handler: DataHandler) -> int:
        self.log.info(f"Scraping {len(urls)} cases...")
        saved = 0

        with ThreadPoolExecutor(max_workers=self.cfg.MAX_SCRAPE_WORKERS) as executor:
            future_to_url = {executor.submit(self.scrape_one_case, url): url for url in urls}

            for future in as_completed(future_to_url):
                url = future_to_url[future]
                try:
                    result = future.result()
                    if result and handler.save_case(result):
                        saved += 1
                except Exception as e:
                    self.log.error(f"Case scraping failed for {url}: {e}")

        return saved


# --------------------------------------------------------------------------- #
#                                   MAIN                                      #
# --------------------------------------------------------------------------- #

def main() -> None:
    cfg = Config()
    log = setup_logging(cfg.LOG_FILE)
    log.info("=== KenyaLaw Scraper v6.0 FULL (Acts + Subsidiary + Cases + Counties with PDF support) ===")

    # Check PDF support
    if cfg.ENABLE_PDF_EXTRACTION:
        if not PDF_SUPPORT and not PDFPLUMBER_SUPPORT:
            log.warning("PDF extraction enabled but no PDF libraries found. Install: pip install pypdf2 pdfplumber")
        else:
            log.info(f"PDF extraction enabled: PyPDF2={PDF_SUPPORT}, pdfplumber={PDFPLUMBER_SUPPORT}")

    # Scrape static content
    scrape_constitution(cfg, log)
    scrape_acts_of_kenya(cfg, log)
    scrape_county_legislation(cfg, log)  # Now with PDF support

    # Scrape case law
    handler = DataHandler(cfg, log)
    scraper = KenyaLawScraper(cfg, log)
    scraper.seen_case_ids = handler.load_existing_case_ids()

    urls = scraper.fetch_case_urls_selenium()
    new_urls = [u for u in urls if scraper._case_id_from_url(u) not in scraper.seen_case_ids]

    if new_urls:
        log.info(f"Found {len(new_urls)} new cases to scrape")
        saved = scraper.run_case_scrape(new_urls, handler)
        log.info(f"Completed: {saved} new cases saved.")
    else:
        log.info("No new cases found.")

    log.info("=== ALL DONE ===")


if __name__ == "__main__":
    main()

ModuleNotFoundError: No module named 'webdriver_manager'

In [22]:
%run /content/kenyalaw_scraper_full.py

2025-11-08 16:09:45,596 - KenyaLaw-Scraper-v6.0-FULL - INFO - === KenyaLaw Scraper v6.0 FULL (Acts + Subsidiary + Cases + Counties with PDF support) ===
INFO:KenyaLaw-Scraper-v6.0-FULL:=== KenyaLaw Scraper v6.0 FULL (Acts + Subsidiary + Cases + Counties with PDF support) ===
2025-11-08 16:09:45,599 - KenyaLaw-Scraper-v6.0-FULL - INFO - Scraping Constitution of Kenya (2010)...
INFO:KenyaLaw-Scraper-v6.0-FULL:Scraping Constitution of Kenya (2010)...
2025-11-08 16:09:45,602 - KenyaLaw-Scraper-v6.0-FULL - INFO - Trying source: https://new.kenyalaw.org/akn/ke/act/2010/constitution/eng@2010-09-03
INFO:KenyaLaw-Scraper-v6.0-FULL:Trying source: https://new.kenyalaw.org/akn/ke/act/2010/constitution/eng@2010-09-03




2025-11-08 16:09:53,879 - KenyaLaw-Scraper-v6.0-FULL - INFO - Constitution saved → /root/projects/kenya_law/data/constitution.json
INFO:KenyaLaw-Scraper-v6.0-FULL:Constitution saved → /root/projects/kenya_law/data/constitution.json
2025-11-08 16:09:53,940 - KenyaLaw-Scraper-v6.0-FULL - INFO - SUCCESS: Constitution scraped → 11 sections, 344,469 words
INFO:KenyaLaw-Scraper-v6.0-FULL:SUCCESS: Constitution scraped → 11 sections, 344,469 words
2025-11-08 16:09:53,944 - KenyaLaw-Scraper-v6.0-FULL - INFO - Scraping ALL Acts + Subsidiary Legislation (NEW SITE)...
INFO:KenyaLaw-Scraper-v6.0-FULL:Scraping ALL Acts + Subsidiary Legislation (NEW SITE)...
2025-11-08 16:09:55,256 - KenyaLaw-Scraper-v6.0-FULL - INFO - Found 33 parent Acts
INFO:KenyaLaw-Scraper-v6.0-FULL:Found 33 parent Acts
2025-11-08 16:09:56,891 - KenyaLaw-Scraper-v6.0-FULL - INFO -   Success: Act: [Cap. 7M] Access to Information Act
INFO:KenyaLaw-Scraper-v6.0-FULL:  Success: Act: [Cap. 7M] Access to Information Act
2025-11-08 16:

In [18]:
!pip install -q selenium

In [5]:
# train_legal.py
# Full training pipeline for Kenyan Legal AI Model
# Works with your KenyaLaw Scraper v6.0 output

import json
import os
import unsloth
from pathlib import Path
from typing import List, Dict, Any
from datasets import Dataset
from unsloth import FastLanguageModel
from trl import SFTTrainer
from transformers import TrainingArguments
import torch
import gc

# ========================= CONFIG =========================
class TrainConfig:
    BASE_DIR = Path(os.path.expanduser("~/projects/kenya_law/data"))
    OUTPUT_DIR = Path("./kenya-legal-llm")
    MODEL_NAME = "unsloth/Mistral-7B-Instruct-v0.3-bnb-4bit"  # or "unsloth/Qwen2-7B-7B-bnb-4bit" # Corrected typo

    MAX_SEQ_LENGTH = 8192
    BATCH_SIZE = 2
    GRADIENT_ACCUMULATION_STEPS = 8
    LEARNING_RATE = 2e-4
    NUM_EPOCHS = 3
    WARMUP_STEPS = 10
    LOGGING_STEPS = 10
    SAVE_STEPS = 100
    EVAL_STRATEGY = "no" # Set to "no" when no eval_dataset is provided

    # Output formats
    GGUF_OUTPUT = OUTPUT_DIR / "gguf"
    HF_OUTPUT = OUTPUT_DIR / "hf"

cfg = TrainConfig()
cfg.OUTPUT_DIR.mkdir(exist_ok=True)
cfg.GGUF_OUTPUT.mkdir(exist_ok=True)
cfg.HF_OUTPUT.mkdir(exist_ok=True)

# ========================= LOAD DATA =========================
def load_jsonl(file_path: Path) -> List[Dict]:
    data = []
    if not file_path.exists():
        print(f"Warning: {file_path} not found!")
        return data
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                try:
                    data.append(json.loads(line))
                except:
                    continue
    print(f"Loaded {len(data)} cases from {file_path.name}")
    return data

def load_json(file_path: Path) -> Dict:
    if not file_path.exists():
        print(f"Warning: {file_path} not found!")
        return {}
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)

print("Loading Kenyan legal data...")
cases = load_jsonl(cfg.BASE_DIR / "kenya_law_training_data.jsonl")
constitution = load_json(cfg.BASE_DIR / "constitution.json")
acts = load_json(cfg.BASE_DIR / "acts_of_kenya.json")
subsidiary = load_json(cfg.BASE_DIR / "subsidiary_legislation.json")
counties = load_json(cfg.BASE_DIR / "county_legislation.json")

# ========================= PREPARE DATA =========================
def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 200) -> List[str]:
    words = text.split()
    chunks = []
    i = 0
    while i < len(words):
        chunk = " ".join(words[i:i + chunk_size])
        chunks.append(chunk)
        i += chunk_size - overlap
    return chunks

def create_instruction_samples() -> List[Dict[str, Any]]:
    samples = []

    # 1. Constitution
    for title, content in constitution.items():
        chunks = chunk_text(content, 1200)
        for i, chunk in enumerate(chunks):
            samples.append({
                "instruction": f"You are a Kenyan constitutional law expert. Answer based on the Constitution of Kenya 2010.",
                "input": f"Explain: {title}" + (f" (Part {i+1})" if len(chunks) > 1 else ""),
                "output": chunk[:4000]
            })

    # 2. Acts of Parliament
    for title, content in acts.items():
        if len(content.split()) > 200:
            samples.append({
                "instruction": "You are a Kenyan lawyer. Cite and explain the relevant law.",
                "input": f"What does the law say about: {title}?",
                "output": content[:6000]
            })

    # 3. Case Law (High-value)
    for case in cases[:2000]:  # Use top 2000 cases
        text = case.get("text", "")
        if len(text.split()) < 200:
            continue
        metadata = case.get("metadata", {})
        court = metadata.get("court", "Kenyan Court")
        date = metadata.get("date", "Unknown date")

        samples.append({
            "instruction": "You are a Kenyan judge. Analyze this case and give legal reasoning.",
            "input": f"Case: {case['case_name']}\nCourt: {court}\nDate: {date}\n\n{text[:3000]}...",
            "output": f"**Case Analysis:**\n\n**Citation:** {case['case_name']}\n**Court:** {court}\n**Date:** {date}\n\n**Legal Reasoning:**\n{text[1000:8000]}\n\n**Held:** {text.split('Held:')[-1].split('JUDGMENT')[0] if 'Held:' in text else 'See full judgment.'}"
        })

    # 4. County Laws (PDF + HTML)
    for county, data in counties.items():
        for law_name, law in data.get("laws", {}).items():
            content = law.get("content", "")
            if "UNABLE_TO_EXTRACT" in content or len(content.split()) < 100:
                continue
            samples.append({
                "instruction": f"You are a legal expert in {county} County legislation.",
                "input": f"What does {county} law say about: {law_name}?",
                "output": content[:5000]
            })

    print(f"Created {len(samples)} training samples")
    return samples

# ========================= FORMAT FOR TRAINING =========================
def format_alpaca(sample: Dict) -> str:
    return f"""### Instruction:
{sample['instruction']}

### Input:
{sample['input']}

### Response:
{sample['output']}"""

print("Creating dataset...")
raw_samples = create_instruction_samples()
formatted = [format_alpaca(s) for s in raw_samples]
dataset = Dataset.from_dict({"text": formatted})

# ========================= LOAD MODEL =========================
print(f"Loading {cfg.MODEL_NAME} with Unsloth...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=cfg.MODEL_NAME,
    max_seq_length=cfg.MAX_SEQ_LENGTH,
    dtype=None,  # Auto detect
    load_in_4bit=True,
    device_map="auto", # Added device_map="auto"
    # llm_int8_enable_fp32_cpu_offload=True # Optional: Enable CPU offloading for fp32 modules
)

model = FastLanguageModel.get_peft_model(
    model,
    r=64,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    max_seq_length=cfg.MAX_SEQ_LENGTH,
)

# ========================= TRAINER =========================
print("Setting up SFTTrainer...")
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=cfg.MAX_SEQ_LENGTH,
    dataset_num_proc=2,
    packing=True,
    args=TrainingArguments(
        per_device_train_batch_size=cfg.BATCH_SIZE,
        gradient_accumulation_steps=cfg.GRADIENT_ACCUMULATION_STEPS,
        warmup_steps=cfg.WARMUP_STEPS,
        num_train_epochs=cfg.NUM_EPOCHS,
        learning_rate=cfg.LEARNING_RATE,
        fp16=not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_bf16_supported(),
        logging_steps=cfg.LOGGING_STEPS,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir=str(cfg.OUTPUT_DIR),
        report_to="none",
        save_strategy="steps",
        save_steps=cfg.SAVE_STEPS,
        eval_strategy=cfg.EVAL_STRATEGY,
        load_best_model_at_end=False, # Set to False when no eval dataset
        tokenizer=tokenizer, # Added tokenizer here
    ),
)

print("Starting training...")
trainer.train()

# ========================= SAVE MODEL =========================
print("Saving model...")
trainer.save_model(str(cfg.HF_OUTPUT))
tokenizer.save_pretrained(str(cfg.HF_OUTPUT))

# ========================= EXPORT TO GGUF (for Ollama) =========================
print("Exporting to GGUF (Q5_K_M)...")
model.save_pretrained_gguf(
    str(cfg.GGUF_OUTPUT),
    tokenizer,
    quantization_method="q5_k_m",
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
Loading Kenyan legal data...
Creating dataset...
Created 0 training samples
Loading unsloth/Mistral-7B-Instruct-v0.3-bnb-4bit with Unsloth...
==((====))==  Unsloth 2025.11.2: Fast Mistral patching. Transformers: 4.57.1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.4.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/4.14G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/157 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/587k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/446 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

Unsloth 2025.11.2 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


Setting up SFTTrainer...


TypeError: TrainingArguments.__init__() got an unexpected keyword argument 'tokenizer'

In [4]:
!pip install -q unsloth

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.8/61.8 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m351.3/351.3 kB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m14.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m564.7/564.7 kB[0m [31m44.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.2/117.2 MB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25h