In [13]:
''' Do not change this cell '''
#Functions

def check_gpu():
  """Print whether TensorFlow reports a GPU device.

  Takes no arguments and returns nothing; the result is only printed.
  The name ``tf`` is not imported in this cell, so TensorFlow must have been
  imported as ``tf`` (done in a later cell) before this function is called.
  """
  # An empty device name is treated as "no GPU attached".
  if tf.test.gpu_device_name() != '':
    print("GPU sucessfully connected")
  else: print("PLease connect GPU")



# **Task 1**


In [14]:
#import all the required packages and  libraries for FineTuning LLM
!pip install transformers datasets accelerate bitsandbytes
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
import torch
import tensorflow as tf



In [15]:
# Set up and check the GPU connection.
# Query TensorFlow once and reuse the result for both the test and the message.
gpu_devices = tf.config.list_physical_devices('GPU')
if not gpu_devices:
    print("No GPU found. Please connect a GPU for efficient training.")
else:
    print(f"GPU is available: {gpu_devices[0]}")
    print(f"TensorFlow version: {tf.__version__}")

GPU is available: PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')
TensorFlow version: 2.17.1


In [16]:
''' Do not change this cell '''
# Run the provided helper; it prints whether TensorFlow reports a GPU device.
check_gpu()


GPU sucessfully connected


In [17]:
import os
import logging
from typing import Dict, Any, List, Optional
from dataclasses import dataclass
from datasets import load_dataset, Dataset, DatasetDict
from huggingface_hub import login, HfApi
from transformers import PreTrainedTokenizer, AutoTokenizer
import torch
from tqdm import tqdm

# Configure logging.
# logging.basicConfig() does nothing when the root logger already has handlers,
# which may be the case inside a notebook kernel; force=True replaces any
# existing root handlers so the INFO-level messages below are actually shown
# and re-running this cell stays idempotent.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    force=True,
)
logger = logging.getLogger(__name__)


In [25]:

@dataclass
class DataConfig:
    """Settings read by ``DataPreprocessor`` when loading, formatting and publishing the dataset."""
    # Dataset identifier passed to ``load_dataset`` as the input data.
    source_dataset: str = "jizzu/llama2_indian_law_v2"
    # Repository id the processed splits are pushed to.
    target_repo: str = "rndascode/Llama2_Indian_Law"
    # Model id whose tokenizer is loaded with ``AutoTokenizer.from_pretrained``.
    model_name: str = "NousResearch/Llama-2-7b-chat-hf"
    # Upper bound on the token count of a formatted example; longer examples are dropped.
    max_length: int = 4096
    # Text placed inside the <<SYS>> block of the first turn; an empty string omits that block.
    system_prompt: str = "You are a helpful assistant that provides accurate information about Indian law."

In [26]:
class DataPreprocessor:
    """Reformat ``### Human: / ### Assistant:`` transcripts as Llama-2 chat text and publish them.

    The tokenizer named by ``config.model_name`` is loaded on construction and
    is used only to measure the token length of each formatted example.
    """

    def __init__(self, config: DataConfig):
        """Store ``config`` and load the tokenizer it names.

        Raises:
            Exception: whatever ``AutoTokenizer.from_pretrained`` raised (logged, then re-raised).
        """
        self.config = config
        self.tokenizer = None
        self._initialize_tokenizer()

    def _initialize_tokenizer(self) -> None:
        """Load the tokenizer for ``config.model_name`` into ``self.tokenizer``."""
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
            logger.info(f"Initialized tokenizer from {self.config.model_name}")
        except Exception as e:
            logger.error(f"Failed to initialize tokenizer: {str(e)}")
            raise

    def load_dataset(self) -> DatasetDict:
        """Load ``config.source_dataset`` and return it; failures are logged and re-raised."""
        try:
            # Inside the method body this name resolves to the module-level
            # ``datasets.load_dataset`` function, not to this method.
            dataset = load_dataset(self.config.source_dataset)
            logger.info(f"Successfully loaded dataset: {self.config.source_dataset}")
            return dataset
        except Exception as e:
            logger.error(f"Failed to load dataset: {str(e)}")
            raise

    def _validate_conversation_format(self, text: str) -> bool:
        """Return True when ``text`` has at least two ``###``-delimited, role-labelled segments.

        Whatever precedes the first ``###`` is ignored; every later segment must
        contain ``Human:`` or ``Assistant:``.
        """
        segments = text.split('###')
        if len(segments) < 3:
            return False
        return all(('Human:' in seg or 'Assistant:' in seg) for seg in segments[1:])

    def _clean_text(self, text: str) -> str:
        """Collapse every run of whitespace (including newlines) to a single space and trim the ends."""
        # str.split() with no argument already discards leading/trailing whitespace.
        return ' '.join(text.split())

    def _format_conversation(self, text: str) -> str:
        """Render a transcript in the Llama-2 chat layout.

        Segments after the first ``###`` are read as alternating human and
        assistant messages. Each pair becomes
        ``<s>[INST] {human} [/INST] {assistant} </s>`` and consecutive turns are
        concatenated with no separator. When ``config.system_prompt`` is set it
        is wrapped in ``<<SYS>>`` markers at the start of the first ``[INST]``
        block. A trailing human message without a reply is emitted as
        ``<s>[INST] {human} [/INST]</s>``.

        Returns:
            The formatted text, or ``""`` when the transcript fails validation.
        """
        if not self._validate_conversation_format(text):
            logger.warning("Invalid conversation format detected")
            return ""

        segments = text.split('###')
        turns = []

        # segments[0] is whatever came before the first '###'; messages start at index 1.
        for i in range(1, len(segments), 2):
            human_text = self._clean_text(segments[i].replace('Human:', ''))

            # Only the first turn carries the system prompt.
            if i == 1 and self.config.system_prompt:
                prompt = f"<<SYS>>\n{self.config.system_prompt}\n<</SYS>>\n\n{human_text}"
            else:
                prompt = human_text
            turn = f"<s>[INST] {prompt} [/INST]"

            if i + 1 < len(segments):
                assistant_text = self._clean_text(segments[i + 1].replace('Assistant:', ''))
                # The template puts a space between [/INST] and the answer.
                turn += f" {assistant_text} </s>"
            else:
                # Incomplete conversation: no assistant response for the last message.
                turn += "</s>"
            turns.append(turn)

        # Turns are adjacent in the template: '</s><s>' with nothing in between.
        return ''.join(turns)

    def _check_length(self, text: str) -> bool:
        """Return True when ``text`` encodes to at most ``config.max_length`` tokens."""
        tokens = self.tokenizer.encode(text)
        return len(tokens) <= self.config.max_length

    def process_example(self, example: Dict[str, Any]) -> Dict[str, Any]:
        """Format one record, reading ``example['text']``.

        Returns:
            ``{'text': formatted, 'valid': True}`` on success, or
            ``{'text': '', 'valid': False}`` when the record is malformed, too
            long, or raised while being processed (the error is logged, not propagated).
        """
        try:
            formatted_text = self._format_conversation(example['text'])
            if not formatted_text or not self._check_length(formatted_text):
                logger.warning("Skipping invalid or too long example")
                return {'text': '', 'valid': False}
            return {'text': formatted_text, 'valid': True}
        except Exception as e:
            logger.error(f"Error processing example: {str(e)}")
            return {'text': '', 'valid': False}

    def process_dataset(self, dataset: Dataset) -> Dataset:
        """Apply ``process_example`` to every record and keep only the valid ones.

        The returned dataset has a single ``text`` column; the original columns
        and the temporary ``valid`` flag are removed.
        """
        processed = dataset.map(
            self.process_example,
            desc="Processing examples",
            remove_columns=dataset.column_names
        )
        processed = processed.filter(lambda x: x['valid'])
        processed = processed.remove_columns(['valid'])
        return processed

    def push_to_hub(self, dataset_dict: DatasetDict, token: str) -> None:
        """Process every split of ``dataset_dict`` and push it to ``config.target_repo``.

        Args:
            dataset_dict: mapping of split name to raw dataset.
            token: access token passed to ``huggingface_hub.login``.

        Raises:
            Exception: any login, processing or upload failure is logged and re-raised.
        """
        try:
            login(token=token)

            for split_name, dataset in dataset_dict.items():
                processed_dataset = self.process_dataset(dataset)
                processed_dataset.push_to_hub(
                    self.config.target_repo,
                    split=split_name,
                    private=False
                )
                logger.info(f"Successfully pushed {split_name} split to {self.config.target_repo}")
        except Exception as e:
            logger.error(f"Failed to push to hub: {str(e)}")
            raise

In [28]:

def main():
    """Run the preprocessing pipeline: load the source dataset, reformat it, and push it to the Hub.

    The Hugging Face access token is taken from the ``HF_TOKEN`` environment
    variable; when that is unset or empty the user is prompted for it without
    echoing the input.

    Raises:
        ValueError: if no token was supplied.
        Exception: any failure while loading, processing or uploading is logged and re-raised.
    """
    from getpass import getpass  # local import: only needed for the interactive fallback

    try:
        # Never hard-code the token in the notebook: a saved or shared notebook
        # exposes it to anyone who can read the file.
        token = (os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")).strip()
        if not token:
            raise ValueError("No Hugging Face token provided; set HF_TOKEN or enter one when prompted")

        config = DataConfig()
        preprocessor = DataPreprocessor(config)
        dataset = preprocessor.load_dataset()
        preprocessor.push_to_hub(
            dataset,
            token=token
        )
        logger.info("Data preprocessing completed successfully")
    except Exception as e:
        logger.error(f"Pipeline failed: {str(e)}")
        raise
if __name__ == "__main__":
    main()

Processing examples:   0%|          | 0/24607 [00:00<?, ? examples/s]



Filter:   0%|          | 0/24607 [00:00<?, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/25 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/514 [00:00<?, ?B/s]

Processing examples:   0%|          | 0/455 [00:00<?, ? examples/s]

Filter:   0%|          | 0/455 [00:00<?, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/514 [00:00<?, ?B/s]

Processing examples:   0%|          | 0/276 [00:00<?, ? examples/s]

Filter:   0%|          | 0/276 [00:00<?, ? examples/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/514 [00:00<?, ?B/s]

Check the dataset

https://huggingface.co/datasets/jizzu/llama2_indian_law_v2