In [None]:

!pip install transformers datasets torch accelerate
    

In [None]:

class Config:
    """Settings for the fine-tuning run, exposed as plain class attributes."""

    # --- Dataset, base checkpoint and output location ---
    dataset_name = "imdb"  # identifier of the Hugging Face dataset
    model_name = "distilbert-base-uncased"  # checkpoint that gets fine-tuned
    output_dir = "./fine_tuned_model"  # target directory for the saved model

    # --- Training hyperparameters and logging ---
    batch_size = 16
    epochs = 3
    learning_rate = 5e-5
    logging_dir = "./logs"
    

In [None]:

from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments, Trainer

class Model:
    """Pairs a sequence-classification model with its tokenizer and Trainer setup."""

    def __init__(self, model_name, num_labels):
        """Load the pretrained model and the tokenizer for ``model_name``.

        Args:
            model_name: Checkpoint name or path passed to ``from_pretrained``.
            num_labels: Forwarded as ``num_labels`` when loading the model.
        """
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize_data(self, dataset, text_column="text", padding="max_length"):
        """Tokenize ``dataset`` in batches via its ``map`` method.

        Args:
            dataset: Object exposing ``map(function, batched=...)``.
            text_column: Key of the column holding the raw text.
            padding: Value forwarded to the tokenizer's ``padding`` argument.
                The default keeps the original behaviour (``"max_length"``).

        Returns:
            Whatever ``dataset.map`` returns for the tokenizing function.
        """
        def _encode(batch):
            # Truncation is always on; padding is caller-configurable.
            return self.tokenizer(batch[text_column], truncation=True, padding=padding)

        return dataset.map(_encode, batched=True)

    def get_trainer(self, tokenized_data, output_dir, batch_size, epochs, learning_rate, logging_dir):
        """Build a ``Trainer`` that evaluates and saves once per epoch.

        Newer transformers releases renamed ``evaluation_strategy`` to
        ``eval_strategy`` (TrainingArguments) and ``tokenizer`` to
        ``processing_class`` (Trainer). The keyword actually accepted by the
        installed version is detected at call time so both old and new
        releases work.

        Args:
            tokenized_data: Mapping with ``"train"`` and ``"test"`` splits.
            output_dir: Directory for checkpoints.
            batch_size: Per-device training batch size.
            epochs: Number of training epochs.
            learning_rate: Optimizer learning rate.
            logging_dir: Directory for training logs.

        Returns:
            The configured ``Trainer`` instance.
        """
        import inspect  # stdlib; imported locally so the cell stays self-contained

        # Pick whichever spelling of the evaluation-schedule keyword exists.
        args_params = inspect.signature(TrainingArguments.__init__).parameters
        eval_key = "eval_strategy" if "eval_strategy" in args_params else "evaluation_strategy"

        training_args = TrainingArguments(
            output_dir=output_dir,
            learning_rate=learning_rate,
            per_device_train_batch_size=batch_size,
            num_train_epochs=epochs,
            save_strategy="epoch",
            logging_dir=logging_dir,
            push_to_hub=False,  # Avoid pushing to Hugging Face Hub
            **{eval_key: "epoch"},
        )

        # Same approach for the keyword that carries the tokenizer.
        trainer_params = inspect.signature(Trainer.__init__).parameters
        tokenizer_key = "processing_class" if "processing_class" in trainer_params else "tokenizer"

        return Trainer(
            model=self.model,
            args=training_args,
            train_dataset=tokenized_data["train"],
            eval_dataset=tokenized_data["test"],
            **{tokenizer_key: self.tokenizer},
        )
    

In [None]:

from datasets import load_dataset

class DataService:
    """Stateless helpers for loading a dataset and optionally subsampling it."""

    @staticmethod
    def load_data(dataset_name):
        """Load the dataset identified by ``dataset_name`` via ``load_dataset``."""
        return load_dataset(dataset_name)

    @staticmethod
    def preprocess_data(dataset, num_samples=None, seed=42):
        """Optionally shrink ``dataset`` to a shuffled subset for faster training.

        Accepts either a single split (an object with ``shuffle``, ``select``
        and ``len``) or a dict-like mapping of split name to split. For a
        mapping, every split is subsampled independently and a mapping of the
        same type is returned.

        Args:
            dataset: A single split or a mapping of splits.
            num_samples: Maximum rows to keep per split. ``None`` or ``0``
                returns ``dataset`` unchanged.
            seed: Seed passed to ``shuffle`` so the subset is reproducible.

        Returns:
            The reduced dataset, or ``dataset`` itself when no reduction
            was requested.

        Raises:
            ValueError: If ``num_samples`` is negative.
        """
        if not num_samples:
            return dataset
        if num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")

        def _subsample(split):
            # Clamp so a request larger than the split keeps the whole split
            # instead of indexing past its end.
            keep = min(num_samples, len(split))
            return split.shuffle(seed=seed).select(range(keep))

        if isinstance(dataset, dict):
            # Multi-split container: it has no ``select`` of its own, so
            # reduce each split and rebuild the same container type.
            return type(dataset)({name: _subsample(split) for name, split in dataset.items()})
        return _subsample(dataset)
    

In [None]:

from model import Model
from data_service import DataService

class Controller:
    """Drives the pipeline: load data, tokenize, train, then save the model."""

    def __init__(self, config):
        """Keep ``config`` and build the model wrapper it names."""
        self.config = config
        # Two labels: binary classification (positive/negative)
        self.model = Model(config.model_name, num_labels=2)

    def fine_tune(self):
        """Run the whole fine-tuning flow and save the result to ``output_dir``."""
        cfg = self.config

        # Fetch the raw dataset and turn its text into model inputs.
        dataset = DataService.load_data(cfg.dataset_name)
        encoded = self.model.tokenize_data(dataset)

        # Collect the Trainer settings from the configuration.
        trainer_settings = dict(
            tokenized_data=encoded,
            output_dir=cfg.output_dir,
            batch_size=cfg.batch_size,
            epochs=cfg.epochs,
            learning_rate=cfg.learning_rate,
            logging_dir=cfg.logging_dir,
        )
        trainer = self.model.get_trainer(**trainer_settings)

        # Train, then persist the fine-tuned weights.
        trainer.train()
        trainer.save_model(cfg.output_dir)
        print("Model fine-tuned and saved to:", cfg.output_dir)
    

In [None]:

class View:
    """Console output helpers for status and error reporting."""

    @staticmethod
    def display_message(message):
        """Print ``message`` to standard output."""
        print(message)

    @staticmethod
    def display_error(error):
        """Print ``error`` to standard output, prefixed with ``Error: ``."""
        print("Error: {}".format(error))
    

In [None]:

from config import Config
from controller import Controller
from view import View

def main():
    """Run the fine-tuning pipeline and report progress through the view.

    The configuration and controller are created inside the ``try`` block so
    that failures while constructing the controller are reported through
    ``View.display_error`` in the same way as failures during training.
    The reported text includes the exception type because ``str(e)`` alone
    can be empty or ambiguous.
    """
    # The view is created first so it is available for error reporting.
    view = View()

    try:
        # Instantiate the configuration and controller
        config = Config()
        controller = Controller(config)

        # Display starting message
        view.display_message("Starting fine-tuning process...")

        # Start fine-tuning
        controller.fine_tune()

        # Display success message
        view.display_message("Fine-tuning completed successfully!")

    except Exception as e:
        # Display any errors encountered, including the exception type
        view.display_error(f"{type(e).__name__}: {e}")

# Entry point: call main() only when this code runs as the top-level module.
if __name__ == "__main__":
    main()
    