# In [None]:
import csv
import os
import subprocess
import sys
from typing import Any

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from datasets import Dataset, DatasetDict, load_dataset
from torch.utils.data import TensorDataset
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer,
                          TrainingArguments)

# Check and install missing packages.
# NOTE(review): this loop runs after the third-party imports at the top of the
# file, so a missing package makes those imports fail before this point is
# reached. It only helps if it is moved ahead of them.
required_packages = [
    "datasets", "pandas", "pyarrow", "torch", "transformers"
]

for package in required_packages:
    try:
        __import__(package)
    except ImportError:
        # Run pip through the current interpreter so the package lands in the
        # environment this script is running in, not whichever `pip`
        # executable happens to be first on PATH.
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])


class HuggingFaceDataLoader:
    """
    A class for loading datasets from Hugging Face, combining them, and exporting to Parquet or CSV formats.

    This class provides functionality to load datasets from Hugging Face (or from a
    local CSV/Parquet file), combine multiple partitions if present, and export the
    resulting dataset to either Parquet or CSV format.

    Attributes:
        output_dir (str): The directory where output files will be saved.
        raw_dataset: The loaded dataset keyed by split name, or None if nothing
            has been loaded yet.
        dataset: The loaded and combined dataset, or None if nothing has been
            loaded yet.
    """

    def __init__(self, output_dir: str):
        """
        Initialize the HuggingFaceDataLoader with an output directory.

        Args:
            output_dir (str): The directory where output files will be saved.
        """
        self.output_dir = output_dir
        # Defined up front so that using the loader before loading anything is
        # reported as a ValueError instead of an AttributeError.
        self.raw_dataset = None
        self.dataset = None

    def load_dataset(self, link: str):
        """
        Load a dataset from Hugging Face using the provided link.

        This method attempts to load the dataset and combines multiple partitions if present.

        Args:
            link (str): The Hugging Face dataset link.

        Raises:
            Exception: If there's an error loading the dataset.
        """
        try:
            self.raw_dataset = load_dataset(link)
            print("Dataset loaded successfully")
            print(f"Dataset structure: {self.raw_dataset}")
            self._combine_dataset()
        except Exception as e:
            print(f"Error loading dataset: {e}")
            raise

    def _combine_dataset(self):
        """
        Combine multiple partitions of a dataset into a single dataset.

        This method is called internally by the load methods. It reads
        ``self.raw_dataset`` (a mapping of split name to dataset) and stores the
        merged result in ``self.dataset``.

        Raises:
            ValueError: If the loaded dataset contains no partitions.
        """
        partitions = list(self.raw_dataset.values())
        if not partitions:
            raise ValueError("Dataset contains no splits to combine.")

        if len(partitions) == 1:
            # Nothing to merge; reuse the single partition as-is.
            self.dataset = partitions[0]
            return

        # `concatenate_datasets` is a module-level function of `datasets`, not a
        # method on Dataset. Imported here because the file-level import list
        # does not include it.
        from datasets import concatenate_datasets

        self.dataset = concatenate_datasets(partitions)

    def output_dataset(self, split: "str | None" = None) -> "DatasetDict | Dataset":
        """
        Return the raw dataset as a DatasetDict object, or a specific split as a Dataset object.

        Args:
            split (str, optional): The name of the split to return. If None, returns the entire DatasetDict.

        Returns:
            DatasetDict | Dataset: The entire DatasetDict if no split is specified, or a Dataset object for a specific split.

        Raises:
            ValueError: If the dataset has not been loaded yet or if the specified split doesn't exist.
        """
        if self.raw_dataset is None:
            raise ValueError(
                "Dataset has not been loaded. Please load a dataset first."
            )

        if split is None:
            return self.raw_dataset
        if split in self.raw_dataset:
            return self.raw_dataset[split]

        available_splits = list(self.raw_dataset.keys())
        raise ValueError(
            f"Split '{split}' not found. Available splits are: {available_splits}"
        )

    def load_local_file(self, file_name: str, chunk_size: int = 100000):
        """
        Load a dataset from a local file.

        The file is looked up inside ``self.output_dir``. Its contents are stored
        in ``self.raw_dataset`` as a DatasetDict with a single ``"train"`` split,
        so that ``output_dataset`` and ``_combine_dataset`` behave the same way
        as after ``load_dataset``.

        Args:
            file_name (str): The name of the file to load (``.csv`` or ``.parquet``).
            chunk_size (int): The number of rows to load at a time for Parquet files.

        Raises:
            ValueError: If the file extension is not supported.
            Exception: If there's an error loading the dataset.
        """
        try:
            _, extension = os.path.splitext(file_name)
            extension = extension.lower()
            file_path = os.path.join(self.output_dir, file_name)

            if extension == ".csv":
                df = pd.read_csv(file_path)
            elif extension == ".parquet":
                print("Parquet file detected")
                # Use pyarrow to read the Parquet file in chunks of
                # `chunk_size` rows; the chunks are then joined into one frame.
                parquet_file = pq.ParquetFile(file_path)
                chunks = [
                    batch.to_pandas()
                    for batch in parquet_file.iter_batches(batch_size=chunk_size)
                ]
                df = pd.concat(chunks, ignore_index=True)
            else:
                raise ValueError(f"Unsupported file type: {extension}")

            # Wrap the single table in a split mapping; the rest of the class
            # treats raw_dataset as "split name -> dataset".
            self.raw_dataset = DatasetDict({"train": Dataset.from_pandas(df)})

            print("Local dataset loaded successfully")
            print(f"Dataset structure: {self.raw_dataset}")
            self._combine_dataset()
        except Exception as e:
            print(f"Error loading local dataset: {e}")
            raise

    def _export(self, file_name: str, extension: str, label: str, writer_name: str):
        """
        Write the combined dataset to ``output_dir`` and report the result.

        Args:
            file_name (str): The name of the output file (without extension).
            extension (str): File extension to append, including the dot.
            label (str): Format name used in the printed messages.
            writer_name (str): Name of the dataset method that performs the write.

        Raises:
            ValueError: If no dataset has been loaded yet.
            Exception: If there's an error writing the file.
        """
        if self.dataset is None:
            raise ValueError(
                "Dataset has not been loaded. Please load a dataset first."
            )

        output_path = os.path.join(self.output_dir, file_name + extension)

        try:
            if self.output_dir:
                # The writer does not create missing parent directories.
                os.makedirs(self.output_dir, exist_ok=True)

            print(f"Attempting to save data to: {output_path}")
            getattr(self.dataset, writer_name)(output_path)

            if os.path.exists(output_path):
                print(f"{label} file successfully created at: {output_path}")
                print(f"{label} file size: {os.path.getsize(output_path)} bytes")
            else:
                print(f"Error: {label} file was not created at {output_path}")
        except Exception as e:
            print(f"Error saving to {label}: {e}")
            raise

    def output_parquet(self, file_name: str):
        """
        Export the dataset to a Parquet file.

        Args:
            file_name (str): The name of the output file (without extension).

        Raises:
            ValueError: If no dataset has been loaded yet.
            Exception: If there's an error saving the Parquet file.
        """
        self._export(file_name, ".parquet", "Parquet", "to_parquet")

    def output_csv(self, file_name: str):
        """
        Export the dataset to a CSV file.

        Args:
            file_name (str): The name of the output file (without extension).

        Raises:
            ValueError: If no dataset has been loaded yet.
            Exception: If there's an error saving the CSV file.
        """
        self._export(file_name, ".csv", "CSV", "to_csv")

    def print_dataset(self, num_rows: int = 5):
        """
        Print the first few rows of the dataset.

        This method is useful for quickly inspecting the loaded dataset.

        Args:
            num_rows (int): Maximum number of rows to print.

        Raises:
            ValueError: If no dataset has been loaded yet.
        """
        if self.dataset is None:
            raise ValueError(
                "Dataset has not been loaded. Please load a dataset first."
            )

        # Only the previewed rows are converted to pandas, not the whole dataset.
        preview_rows = min(num_rows, len(self.dataset))
        print(self.dataset.select(range(preview_rows)).to_pandas())


class Llama7B:
    """Interface for generic Llama7B model with full fine-tuning"""

    def __init__(
        self,
        dataset_loader: "HuggingFaceDataLoader",
        instruction_column: str = "instruction",
        response_column: str = "response",
    ):
        """
        Initialize the base Llama7B model

        Args:
            dataset_loader: Loader whose "train" split is used for fine-tuning.
            instruction_column (str): Dataset column holding the prompt text.
            response_column (str): Dataset column holding the expected answer.
        """
        self.base_model = "NousResearch/Llama-2-7b-chat-hf"
        # output_dataset("train") returns the loader's "train" split (a Hugging
        # Face Dataset when the loader holds a DatasetDict), not a DataFrame.
        self.dataset = dataset_loader.output_dataset("train")
        self.instruction_column = instruction_column
        self.response_column = response_column
        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model)
        if self.tokenizer.pad_token is None:
            # Padding is requested during tokenization, which fails without a
            # pad token; fall back to the end-of-sequence token.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = None

    @staticmethod
    def _format_example(instruction: str, response: str) -> str:
        """Render one instruction/response pair as a single training text."""
        return f"### Human: {instruction}\n### Assistant: {response}"

    def finetune(self, **kwargs) -> Any:
        """
        Function handling the fine-tuning procedure of the LLM

        Loads the base model, tokenizes the training split, trains with the
        Hugging Face Trainer and saves the result to ``./finetuned_model``.

        Args:
            **kwargs: Extra or overriding keyword arguments for
                ``TrainingArguments`` (they take precedence over the defaults
                set here).

        Raises:
            ValueError: If the training split lacks the configured instruction
                or response column.
        """
        # Fail early, before the model weights are loaded.
        missing_columns = [
            column
            for column in (self.instruction_column, self.response_column)
            if column not in self.dataset.column_names
        ]
        if missing_columns:
            raise ValueError(
                f"Training data is missing required columns: {missing_columns}. "
                f"Available columns are: {list(self.dataset.column_names)}"
            )

        # Defaults first, then caller overrides, so passing e.g. `output_dir`
        # replaces the default instead of raising a duplicate-keyword error.
        training_config = {
            "output_dir": "./results",
            "num_train_epochs": 3,
            "per_device_train_batch_size": 4,
            "gradient_accumulation_steps": 4,
            "learning_rate": 2e-5,
            "fp16": True,
            "logging_steps": 10,
            "save_steps": 100,
        }
        training_config.update(kwargs)
        training_args = TrainingArguments(**training_config)

        # Load the model
        # NOTE(review): the weights are loaded in float16 while fp16 mixed
        # precision is also enabled; this combination may be rejected by the
        # gradient scaler at train time - confirm on the target hardware.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.base_model,
            device_map="auto",
            torch_dtype=torch.float16,
        )

        # Prepare and tokenize the dataset
        def tokenize_batch(batch):
            texts = [
                self._format_example(instruction, response)
                for instruction, response in zip(
                    batch[self.instruction_column], batch[self.response_column]
                )
            ]
            return self.tokenizer(
                texts,
                truncation=True,
                padding="max_length",
                max_length=512,
            )

        # Keep only the tokenizer outputs so every example is a dict of model
        # inputs, which is what the data collator and Trainer consume.
        train_dataset = self.dataset.map(
            tokenize_batch,
            batched=True,
            remove_columns=self.dataset.column_names,
        )

        # Define data collator; with mlm=False it builds the causal-LM labels
        # from the input ids.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer, mlm=False
        )

        # Initialize the Trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            data_collator=data_collator,
        )

        # Train the model
        trainer.train()
        trainer.save_model("./finetuned_model")

    def prompt(self, prompt: str, **kwargs) -> Any:
        """
        Function handling the generation of output

        Args:
            prompt (str): Text to feed to the model.
            **kwargs: Extra or overriding keyword arguments for ``generate``
                (``max_new_tokens`` defaults to 100).

        Returns:
            The decoded text of the first generated sequence.

        Raises:
            ValueError: If no model has been fine-tuned or loaded yet.
        """
        if self.model is None:
            raise ValueError("Model has not been fine-tuned or loaded yet.")

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        # Defaults first, then caller overrides, so passing `max_new_tokens`
        # replaces the default instead of raising a duplicate-keyword error.
        generation_kwargs = {
            "max_new_tokens": 100,
            "attention_mask": inputs.get("attention_mask"),
        }
        generation_kwargs.update(kwargs)

        outputs = self.model.generate(inputs.input_ids, **generation_kwargs)
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)


def main() -> None:
    """Load the "Nan-Do/code-search-net-python" dataset and fine-tune Llama7B on it."""
    loader = HuggingFaceDataLoader("data/")
    loader.load_dataset("Nan-Do/code-search-net-python")

    # Build the model wrapper around the loaded data and start training.
    Llama7B(loader).finetune()


if __name__ == "__main__":
    main()