In [None]:
# from datasets import load_dataset
import json

import torch
from huggingface_hub import login
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    Trainer,
    TrainingArguments,
)

## Data Preprocessing

In [None]:
def dataset_gen(text_file, test_size=0.1, random_state=42):
    """Read a UTF-8 text corpus and split its non-blank lines into train/validation sets.

    Each line, stripped of surrounding whitespace, becomes one sample; blank lines
    are skipped.

    Args:
        text_file: Path to the UTF-8 encoded text file.
        test_size: Fraction (float) or absolute count (int) of samples held out for
            validation; forwarded to ``sklearn.model_selection.train_test_split``.
        random_state: Seed for the shuffle so the split is identical on every
            Restart & Run All. Pass ``None`` to get a different split each run.

    Returns:
        A ``(train_texts, val_texts)`` tuple of lists of strings.

    Raises:
        ValueError: If the file contains no non-blank lines.
    """
    with open(text_file, 'r', encoding='utf-8') as f:
        # Stream the file line by line instead of materialising it with readlines().
        texts = [stripped for stripped in (line.strip() for line in f) if stripped]

    if not texts:
        raise ValueError(f"No non-empty lines found in {text_file!r}")

    train_texts, val_texts = train_test_split(
        texts, test_size=test_size, random_state=random_state
    )
    return train_texts, val_texts

In [None]:
# Load the corpus (one sample per non-blank line) and hold out a validation split.
train_texts, val_texts = dataset_gen(text_file='data/mahabharat.txt')

In [None]:
# Sanity-check the split sizes and eyeball one training line.
print(
    f"Number of training samples: {len(train_texts)}",
    f"Number of validation samples: {len(val_texts)}",
    f"Example training sample:\n{train_texts[0]}",
    sep="\n",
)

## Fine-Tuning

### Authenticate with Hugging Face

In [None]:
# Read the Hugging Face access token from a local secrets file (keep this file out
# of version control) and authenticate the session before downloading from the Hub.
with open('secrets.json', 'r') as secrets_file:
    secrets = json.load(secrets_file)
auth_token = secrets['huggingface_token']
login(token=auth_token)

### Tokenize Data

In [None]:
# Checkpoint to fine-tune; downloading it relies on the login cell above.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_id)
# Reuse EOS as the padding token so fixed-length batches can be built
# (presumably the checkpoint defines no pad token of its own -- verify).
tokenizer.pad_token = tokenizer.eos_token

In [None]:
def tokenize_data(text, max_length=512):
    """Tokenize a string or list of strings into fixed-length PyTorch tensors.

    Uses the module-level ``tokenizer``. Every sequence is truncated, or padded with
    the tokenizer's pad token, to exactly ``max_length`` tokens, so all rows have
    the same length. A single string yields a batch of one.

    Args:
        text: A string or a list of strings.
        max_length: Target sequence length in tokens (default 512).

    Returns:
        The tokenizer's ``BatchEncoding`` with ``return_tensors="pt"`` (tensors such
        as ``input_ids`` and ``attention_mask``).
    """
    return tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )

In [None]:
train_encodings = tokenize_data(train_texts)
val_encodings = tokenize_data(val_texts)

# Peek at the first training example. The numbers shown are integer token IDs from
# the tokenizer vocabulary, not embeddings; the model maps IDs to embeddings itself.
sample_length = 5
first_ids = train_encodings['input_ids'][0]
print(f"Example training input dimensions: {first_ids.shape}")
# Show a few more IDs than words: one word can split into several tokens.
# NOTE(review): the +2 presumably also leaves room for a prepended BOS token.
preview_ids = first_ids[:sample_length + 2].tolist()
print(
    f"Example tokenization:\n"
    f"Text: {train_texts[0].split(' ')[:sample_length]}\n"
    f"Token IDs: {preview_ids}\n"
    f"Tokens: {tokenizer.convert_ids_to_tokens(preview_ids)}"
)

### Create Datasets

In [None]:
class TextDataset(Dataset):
    """Map-style dataset over tokenizer output, ready for causal-LM training with ``Trainer``.

    Each item is a dict of 1-D tensors holding the same keys as ``encodings`` plus
    a ``labels`` entry. If ``encodings`` has no ``labels`` of its own, the labels
    are a copy of ``input_ids``; causal-LM heads shift them internally. Padding
    positions (``attention_mask == 0``) are set to -100 so the loss ignores them.
    Without labels the model returns no loss and ``Trainer.train`` fails.

    Args:
        encodings: Mapping from field name (e.g. ``input_ids``, ``attention_mask``)
            to a per-example indexable: a 2-D tensor or a list of lists.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings["input_ids"])

    def __getitem__(self, idx):
        # as_tensor avoids the copy and UserWarning that torch.tensor() produces
        # when given an existing tensor, and still accepts plain Python lists.
        item = {key: torch.as_tensor(val[idx]) for key, val in self.encodings.items()}
        if "labels" not in item:
            # Clone so masking never writes back into the shared encodings.
            labels = item["input_ids"].clone()
            mask = item.get("attention_mask")
            if mask is not None:
                # Mask by attention_mask rather than pad id: pad == eos here, so
                # masking by id would also hide genuine EOS tokens.
                labels[mask == 0] = -100
            item["labels"] = labels
        return item

In [None]:
# Wrap both encoded splits so Trainer can index them example by example.
train_dataset, val_dataset = TextDataset(train_encodings), TextDataset(val_encodings)

### Fine-tune the model

In [None]:
# Load the base checkpoint for full fine-tuning.
# NOTE(review): an 8B-parameter model is memory-heavy; consider a reduced-precision
# torch_dtype or a parameter-efficient method if this does not fit on the device.
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_id)

In [None]:
# Hyper-parameters for fine-tuning; Trainer writes its checkpoints under ./results.
# NOTE(review): no evaluation strategy is set, so library defaults apply and
# val_dataset is presumably not evaluated during training -- confirm intended.
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    warmup_steps=500,   # linear LR warm-up length, in optimizer steps
    weight_decay=0.01,
)

In [None]:
# Wire the model, hyper-parameters and both dataset splits together.
trainer = Trainer(
    model=model,
    args=training_args,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)

In [None]:
# Run fine-tuning with the arguments configured above; the returned training
# summary is shown as this cell's output.
trainer.train()

In [None]:
# Save the fine-tuned weights and the tokenizer (including its pad-token setting)
# side by side so the directory can be reloaded with from_pretrained().
for artifact in (model, tokenizer):
    artifact.save_pretrained("./fine-tuned-llama-3-8b")

## Inference

In [None]:
model_path = './fine-tuned-llama-3-8b'
# Reload the saved artifacts so this inference section works from the saved
# directory alone. Load the model ONCE: the previous cell loaded it twice, which
# doubled the load time and the peak memory of an 8B-parameter model.
# NOTE(review): `trainer` still references the training-time model, so that copy
# stays in memory too; `del trainer` first if memory is tight.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path)

In [None]:
def generate_story(prompt_text, max_length=250):
    """Generate a continuation of ``prompt_text`` with the module-level ``model`` and ``tokenizer``.

    Args:
        prompt_text: Prompt string to condition on.
        max_length: Upper bound on the total sequence length in tokens, *including*
            the prompt tokens (forwarded to ``model.generate``).

    Returns:
        The decoded sequence (prompt followed by the generated continuation) with
        special tokens removed.
    """
    # Put inputs on the same device as the model, which may sit on a GPU.
    # Pass the attention mask explicitly: the pad token is set to EOS in this
    # notebook, so generate() cannot reliably infer the mask from the ids.
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    output = model.generate(
        **inputs,
        max_length=max_length,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)

In [None]:
def modernize_story(story, max_length=500):
    """Ask the module-level ``model`` to retell ``story`` in a modern setting.

    The story is wrapped in a fixed rewrite instruction. A single continuation is
    sampled with top-k 50, nucleus p=0.95, and no repeated bigrams.

    Args:
        story: Text to modernize.
        max_length: Upper bound on total tokens *including* the prompt. The prompt
            contains the whole ``story``, so a long story leaves less room for the
            rewrite (default 500).

    Returns:
        The decoded sequence (instruction prompt followed by the generated text)
        with special tokens removed.
    """
    modern_prompt = f"Rewrite the following story in a modern context:\n\n{story}\n\nModernized story:"
    # Move inputs to the model's device and pass the attention mask explicitly.
    # pad_token_id is given because the pad token is set to EOS in this notebook;
    # without it, generate() falls back with a warning.
    inputs = tokenizer(modern_prompt, return_tensors="pt").to(model.device)
    output = model.generate(
        **inputs,
        max_length=max_length,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)

In [None]:
# Generate an epic-style story, then feed it back in for a modern retelling.
# generate_story returns the prompt plus its continuation, so the modernization
# prompt includes the original question as well.
prompt = "Give a story from the Mahabharata describing the bravery of Karna."
epic_story = generate_story(prompt)
modern_story = modernize_story(epic_story)

# f-strings avoid the stray space that print(a, b) inserts after each "\n".
print(f"Epic Story:\n{epic_story}")
print(f"\nModernized Story:\n{modern_story}")