In [None]:
!pip3 install transformers datasets sentencepiece peft -q
!pip install torch~=2.1.0 --index-url https://download.pytorch.org/whl/cpu -q # Updating torch since we need the latest version
!pip install torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html -q
!pip uninstall tensorflow -y # If we don't do this, TF will take over TPU and cause permission error for PT
!cp /kaggle/input/utils-xla/spmd_util.py . # From this repo: https://github.com/HeegyuKim/torch-xla-SPMD

In [None]:
!pip3 install scikit-learn -q
!pip3 install huggingface_hub -q

In [None]:
import os
import pandas as pd
import numpy as np
import datasets
import torch.optim as optim
import torch_xla.debug.profiler as xp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp # We also import mp modules if we wanna use that for some reason
import torch_xla.distributed.parallel_loader as pl
import torch_xla.test.test_utils as test_utils
import torch
import torch.nn as nn
import re
import torch_xla.experimental.xla_sharding as xs
import torch_xla.core.xla_model as xm
from transformers import (
    GPTNeoXConfig, T5Config, LlamaConfig, AutoTokenizer, AutoModelForCausalLM, DataCollatorWithPadding, AutoConfig, AutoModelForSequenceClassification
) # You can use any of models with those configs (even flan T5 xxl!). Other models are not supported.

from transformers import logging as hf_logging
import torch.nn.functional as F
import torch_xla.runtime as xr

# Switch the torch_xla runtime into SPMD mode before the sharding helpers below are imported and used.
xr.use_spmd()

import torch_xla.experimental.xla_sharding as xs # "experimental" prefix always means you're gonna have a good time LMAO
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh

from peft import LoraConfig, TaskType, get_peft_model # If we wanna use peft. Quantazation requiers GPU though. You'll have to download already quantazed models
from spmd_util import partition_module                # You could experiment with using already quantazed models like 4bit/Llama-2-7b-Chat-GPTQ if you're feeling funny
from datasets import Dataset, load_dataset, concatenate_datasets
from dataclasses import dataclass
from tqdm import tqdm

import transformers
import datasets
import pandas as pd
import numpy as np
from datasets import Dataset
from sklearn.metrics import roc_auc_score

# `!export VAR=...` runs in a throwaway subshell and never reaches this kernel's environment,
# so every variable is set through os.environ instead.
# NOTE(review): USE_TORCH can only influence `import transformers` if it is set before that
# import runs; the import happens earlier in this cell, so move this assignment above the
# imports if the flag is really required - TODO confirm.
os.environ["USE_TORCH"] = "True"  # If we don't do this, transformers will seemingly bork the session upon import. Really weird error.
os.environ["PJRT_DEVICE"] = "TPU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["XLA_USE_BF16"] = "1"

# Input length limit; same value as FLAGS['MAX_INPUT'] defined further down.
MAX_INPUT = 512

# Only let transformers log at error level (some warnings can still slip through, which is a bit annoying).
hf_logging.set_verbosity_error()

In [None]:
import os

from huggingface_hub import login

# Never hard-code an access token in the notebook: read it from the HF_TOKEN environment
# variable instead. When the variable is unset, `token=None` is passed and `login` falls back
# to prompting for the token.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
def test_device_count():
    """Print the default XLA device, then the global runtime device count."""
    for probe in (xm.xla_device, xr.global_runtime_device_count):
        print(probe())


test_device_count()

In [None]:
# Hub checkpoint to fine-tune; MODEL is the alias used by the loading cells below.
MODEL = MODEL_NAME = "meta-llama/Llama-2-7b-hf"

In [None]:
from transformers import AutoTokenizer
# Tokenizer: reuse the EOS token as the padding token.
tokenizer = AutoTokenizer.from_pretrained(MODEL)
tokenizer.pad_token = tokenizer.eos_token

# Checkpoint with a two-label sequence-classification head, loaded in bfloat16.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,
    num_labels=2,
)

In [None]:
# Run configuration; train() reads NUM_STEPS, BATCH_SIZE, NUM_EPOCHS and LOGGING_STEPS from it.
FLAGS = dict(
    MAX_INPUT=512,
    LOGGING_STEPS=100,
    NUM_EPOCHS=2,
    # Making batch_size lower than 8 will result in slower training, but will allow for larger
    # models/context. Fortunately, we have 128GBs. Setting higher batch_size doesn't seem to
    # improve time.
    BATCH_SIZE=1,
    NUM_STEPS=8,
)

In [None]:
import torch

# Alpaca prompt template
# The three positional `{}` slots are filled, in order, with the instruction, the input and
# the response by formatting_prompts_func below.
alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction: {}

### Input: {}

### Response: {}"""

# End-of-sequence marker taken from the tokenizer; appended to every formatted example.
EOS_TOKEN = tokenizer.eos_token

def formatting_prompts_func(examples):
    """Render a batch of records into Alpaca-style prompt strings.

    Reads the parallel "instruction", "input" and "output" lists from `examples`,
    fills `alpaca_prompt` with each triple and appends `EOS_TOKEN`.

    Returns:
        A dict with a single "text" key holding the rendered strings, one per record.
    """
    rows = zip(examples["instruction"], examples["input"], examples["output"])
    return {"text": [alpaca_prompt.format(*row) + EOS_TOKEN for row in rows]}

# Fetch the train split and add the rendered "text" column in one chained expression.
dataset = load_dataset("yahma/alpaca-cleaned", split="train").map(
    formatting_prompts_func,
    batched=True,
)

# Every row gets the constant label 1 (assuming all examples are AI-generated).
dataset = dataset.add_column("label", [1] * dataset.num_rows)

# Define preprocessing function
def preprocess_function(examples, max_length=None):
    """Tokenize a batch of formatted examples to a fixed length.

    Args:
        examples: Batch mapping with a 'text' list and a matching 'label' list.
        max_length: Length every sequence is padded/truncated to. When omitted, the value is
            taken from FLAGS['MAX_INPUT'] so the tokenizer follows the notebook configuration
            instead of a separate hard-coded 512.

    Returns:
        The tokenizer output for examples['text'] with the 'label' entries copied alongside.
    """
    if max_length is None:
        max_length = FLAGS['MAX_INPUT']
    tokenized = tokenizer(examples['text'], max_length=max_length, padding='max_length', truncation=True)
    tokenized['label'] = examples['label']
    return tokenized

# Split the dataset. The seed is fixed so that Restart & Run All yields the same split every time.
SPLIT_SEED = 42
MAX_EXAMPLES_PER_SPLIT = 512  # upper bound on the number of examples kept from each split
ds = dataset.train_test_split(test_size=0.15, seed=SPLIT_SEED)

# Apply preprocessing: keep at most MAX_EXAMPLES_PER_SPLIT rows per split, tokenize them and
# drop the original columns (the 'label' returned by preprocess_function is kept).
for split_name in ('train', 'test'):
    split = ds[split_name]
    subset = split.select(range(min(MAX_EXAMPLES_PER_SPLIT, len(split))))
    ds[split_name] = subset.map(preprocess_function, batched=True, remove_columns=dataset.column_names)

print('Dataset prepared and preprocessed')

# Collate tokenized rows into batches for the training and evaluation loops.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
training_loader = torch.utils.data.DataLoader(ds['train'], batch_size=FLAGS['BATCH_SIZE'], collate_fn=data_collator)
testing_loader = torch.utils.data.DataLoader(ds['test'], batch_size=FLAGS['BATCH_SIZE'], collate_fn=data_collator)

In [None]:
# Pull one batch from the training loader and describe each of its fields.
print("Testing DataLoader:")
batch = next(iter(training_loader))
for field, value in batch.items():
    description = (
        f"shape {value.shape}, dtype {value.dtype}"
        if isinstance(value, torch.Tensor)
        else f"{type(value)}"
    )
    print(f"{field}: {description}")

In [None]:
# Freeze the first FROZEN_PARAM_COUNT parameter tensors (in model.parameters() order) and
# leave every later one trainable.
FROZEN_PARAM_COUNT = 270
for param_index, param in enumerate(model.parameters(), start=1):
    param.requires_grad = param_index > FROZEN_PARAM_COUNT

# Previously the pad token id was only written to a freshly loaded config object that was
# never attached to the model, so the model itself never saw it. The sequence-classification
# head uses config.pad_token_id to find the last non-padding token, so set it on the model's
# own config as well.
config = AutoConfig.from_pretrained(MODEL)
config.pad_token_id = tokenizer.pad_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# One mesh over all devices: axes are ('dp', 'fsdp', 'mp') with every device on the 'fsdp' axis.
num_devices = xr.global_runtime_device_count()
mesh_shape = (1, num_devices, 1)
device_ids = np.arange(num_devices)
mesh = Mesh(device_ids, mesh_shape, ('dp', 'fsdp', 'mp'))
partition_module(model, mesh) # After this, the model is sharded between cores but still has the same API as if it was on single device. Neat.

In [None]:
def _forward_example(batch, index, device):
    """Run the module-level `model` on example `index` of `batch` and return its loss tensor."""
    # Each example is reshaped to rank 2 so it matches the two-entry partition spec below.
    # Assumes input_ids/attention_mask are (batch, seq) and labels is (batch,) - TODO confirm.
    input_ids = batch.input_ids[index].unsqueeze(0).to(device)
    attention_mask = batch.attention_mask[index].unsqueeze(0).to(device)
    labels = batch.labels[index].unsqueeze(0).unsqueeze(0).to(device)

    for tensor in (input_ids, attention_mask, labels):
        xs.mark_sharding(tensor, mesh, (0, 1))

    outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    return outputs.loss


def train(FLAGS):
    """Fine-tune and evaluate the module-level `model` for FLAGS['NUM_EPOCHS'] epochs.

    Uses the module-level `model`, `mesh`, `training_loader` and `testing_loader`.
    Each epoch iterates over the whole training loader (one optimizer step per batch,
    examples processed one at a time and their losses averaged), then computes the mean
    loss over the test loader without gradients.

    Args:
        FLAGS: Mapping providing 'NUM_EPOCHS', 'LOGGING_STEPS', 'NUM_STEPS' and 'BATCH_SIZE'.
            'NUM_STEPS' * 'BATCH_SIZE' is the number of scheduler steps over which the
            learning-rate factor decays linearly from 1.0 to 0.1.
    """
    # `device` used to be referenced without ever being defined, which raised NameError on the
    # first batch; resolve the XLA device explicitly.
    device = xm.xla_device()
    lr = 1e-5
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=FLAGS['NUM_STEPS'] * FLAGS['BATCH_SIZE'])

    for epoch in range(1, FLAGS['NUM_EPOCHS'] + 1):
        model.train()
        xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))

        for step, batch in enumerate(training_loader):
            optimizer.zero_grad()

            # Process each item in the batch separately and accumulate the loss tensors.
            batch_size = len(batch.input_ids)
            batch_loss = 0
            for i in range(batch_size):
                batch_loss += _forward_example(batch, i, device)

            # Average the loss over the batch
            batch_loss /= batch_size
            batch_loss.backward()
            optimizer.step()
            xm.mark_step()

            if (step + 1) % FLAGS['LOGGING_STEPS'] == 0:
                print(f'loss: {batch_loss.item()}, time: {test_utils.now()}, step: {step}')

            scheduler.step()

        model.eval()
        total_loss = 0.0
        total_steps = 0

        with torch.no_grad():
            for step, batch in enumerate(testing_loader):
                batch_size = len(batch.input_ids)
                batch_loss = 0
                for i in range(batch_size):
                    batch_loss += _forward_example(batch, i, device).item()

                total_loss += batch_loss / batch_size
                total_steps += 1

        # An empty test loader would otherwise end the run with ZeroDivisionError.
        average_loss = total_loss / total_steps if total_steps else float('nan')
        xm.master_print('Epoch {} test end {}, test loss={:.2f}'.format(epoch, test_utils.now(), average_loss))
        xm.master_print('Epoch {} train end {}'.format(epoch, test_utils.now()))

In [None]:
# Run the training/evaluation loop with the FLAGS configuration defined earlier in the notebook.
train(FLAGS)

In [None]:
HUB_REPO_ID = "felarof01/llama2"

# Move the model to CPU before serialising and uploading it.
model = model.cpu()
print('now saving the model')
model.push_to_hub(
    HUB_REPO_ID,
    private=False,
    create_pr=False,
    max_shard_size="2GB", # Sharding isn't as important as before since hardware is better now but who cares anyway
)
# NOTE(review): `tokenizer` is not a documented parameter of the model's push_to_hub, so the
# keyword previously passed there did not upload the tokenizer. Push its files explicitly to
# the same repository instead.
tokenizer.push_to_hub(HUB_REPO_ID, private=False, create_pr=False)