In [None]:
class CFG:
    """Central knobs for the LoRA fine-tuning run.

    All values are plain class attributes so they can be read as ``CFG.<name>``
    from any cell without instantiating the class.
    """

    # --- data / checkpoint locations ---
    dataset_path = "cvmistralparis/Q20LLM"
    preset = "/kaggle/input/llama-3/transformers/8b-chat-hf/1"  # path of the pretrained Llama-3 chat checkpoint

    # --- sequence and batch sizes ---
    sequence_length = 1024  # tokens per example (inputs are truncated/padded to this length)
    train_batch = 8         # examples per training batch
    validation_batch = 16   # examples per validation batch

    # --- optimisation ---
    seed = 42
    lr = 1e-3
    epochs = 1              # number of passes over the training data

    # --- LoRA adapter ---
    lora_rank = 4
    lora_alpha = 8

    # --- output ---
    # Directory name used when the adapter is saved; encodes the key hyper-parameters.
    model_save_name = f'llama-3_lorarank-{lora_rank}_loraalpha{lora_alpha}_epoch{epochs}'

In [None]:
# Environment preparation: upgrade the Hugging Face libraries, remove TensorFlow,
# and copy the SPMD sharding helper next to the notebook so it can be imported.
# NOTE(review): package versions are not pinned, so a re-run may pull different releases.
!pip3 install transformers datasets peft -Uq
# !pip install torch~=2.1.0 --index-url https://download.pytorch.org/whl/cpu -q # Updating torch since we need the latest version
# !pip install torch_xla[tpu]~=2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html -q
!pip uninstall tensorflow -y # If we don't do this, TF will take over TPU and cause permission error for PT
!cp /kaggle/input/utils-xla/spmd_util.py . # From this repo: https://github.com/HeegyuKim/torch-xla-SPMD

In [None]:
import os
import pandas as pd
import numpy as np
import datasets
import torch.optim as optim
import torch_xla.debug.profiler as xp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp # We also import mp modules if we wanna use that for some reason
import torch_xla.distributed.parallel_loader as pl
import torch_xla.test.test_utils as test_utils
import torch
import torch.nn as nn
import re
import torch_xla.experimental.xla_sharding as xs
import torch_xla.core.xla_model as xm
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, DataCollatorWithPadding, AutoConfig
) # You can use any of models with those configs (even flan T5 xxl!). Other models are not supported.

from transformers import logging as hf_logging
import torch.nn.functional as F
import torch_xla.runtime as xr

xr.use_spmd()

import torch_xla.experimental.xla_sharding as xs # "experimental" prefix always means you're gonna have a good time LMAO
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh

from peft import LoraConfig, TaskType, get_peft_model # If we wanna use peft. Quantazation requiers GPU though. You'll have to download already quantazed models
from spmd_util import partition_module                # You could experiment with using already quantazed models like 4bit/Llama-2-7b-Chat-GPTQ if you're feeling funny
from datasets import Dataset, load_dataset, concatenate_datasets
from dataclasses import dataclass
from tqdm import tqdm

from peft import LoftQConfig, LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig
import datasets
import pandas as pd
import numpy as np
from datasets import Dataset
from sklearn.metrics import roc_auc_score

# Process-level environment configuration followed by XLA device acquisition.
#
# `!export VAR=...` runs in a throw-away subshell and never changes the kernel's
# environment, so the variable is set through os.environ instead.
# NOTE(review): the original comment says this flag is needed to keep the
# transformers import from breaking the session; for that to hold it must be set
# before `transformers` is imported — consider moving it above the imports.
os.environ["USE_TORCH"] = "True"
os.environ["PJRT_DEVICE"] = "TPU"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# os.environ.pop('TPU_PROCESS_ADDRESSES')
# os.environ.pop('CLOUD_TPU_TASK_ID')
hf_logging.set_verbosity_error() # It can still display warnings which is a bit annoying but whatever


# MAX_INPUT=512
# MODEL = "/kaggle/input/gemma/transformers/7b-it/2" #You should be able to use 13B model with no changes! There should be enough HBM
# Module-level XLA device handle; the training loop moves every batch onto it.
device = xm.xla_device()

In [None]:
# Load the tokenizer and the base causal-LM weights from the checkpoint in CFG.preset.
tokenizer = AutoTokenizer.from_pretrained(CFG.preset)
# Reuse the EOS token as the padding token, and pad on the right.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
# Weights are loaded in bfloat16.
model = AutoModelForCausalLM.from_pretrained(CFG.preset, torch_dtype=torch.bfloat16)

In [None]:
# LoRA adapter configuration: rank and alpha come from CFG, and adapters are attached
# to every module named in target_modules. No bias terms are trained (bias="none").
lora_config = LoraConfig(
    r=CFG.lora_rank,
    lora_alpha=CFG.lora_alpha,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
    lora_dropout=0.05,
    bias="none"
#     inference_mode=True
)
# Wrap the base model with the adapters; `peft_model` is what gets sharded and trained below.
peft_model = get_peft_model(model, lora_config)

In [None]:
# Build the SPMD device mesh and shard the PEFT model across it.
# NOTE(review): `config` is not read again in the visible cells — confirm whether it is still needed.
config = AutoConfig.from_pretrained(CFG.preset)
num_devices = xr.global_runtime_device_count()
print(num_devices)
# Four named mesh axes; every device is placed on the 'fsdp' axis, the other axes have size 1.
mesh_shape = (1, num_devices, 1, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('dp', 'fsdp', 'mp', 'sp'))
partition_module(peft_model, mesh) # After this, the model is sharded between cores but still has the same API as if it was on single device. Neat.

In [None]:
# Sanity check: report how many parameters the LoRA wrapper leaves trainable.
peft_model.print_trainable_parameters()

# datasets

In [None]:
import pandas as pd

# Read the train/validation splits of the 20-questions dataset.
# The prompt-building cell below reads the `prompt` and `label` columns of train_df.
# NOTE(review): val_df is not used in the visible cells — confirm whether validation is intended.
train_df = pd.read_csv("/kaggle/input/20-questions-dataset/train.csv")
val_df = pd.read_csv("/kaggle/input/20-questions-dataset/validation.csv")

In [None]:
from typing import Iterable
import itertools

class LlamaFormatter:
    """Incrementally builds a Llama-3 chat prompt as a single string.

    The prompt always starts with the begin-of-text token, optionally followed
    by a system turn and a block of few-shot turns. Every mutating method
    returns ``self`` so calls can be chained, e.g.
    ``LlamaFormatter("sys").user("hi").start_model_turn()``.

    Args:
        system_prompt: Text of the system turn, or ``None`` to omit it.
        few_shot_examples: Alternating user/model turn texts (starting with the
            user) appended after the system turn, or ``None`` to omit them.
    """

    _bos_token = '<|begin_of_text|>'
    _start_header_token = '<|start_header_id|>'
    _end_header_token = '<|end_header_id|>'
    _end_token = '<|eot_id|>'

    def __init__(self, system_prompt: str = None, few_shot_examples: Iterable = None):
        self._system_prompt = system_prompt
        # Materialise the examples so that a one-shot iterator (e.g. a generator)
        # is still available on every later reset(), not just the first one.
        self._few_shot_examples = (
            None if few_shot_examples is None else list(few_shot_examples)
        )
        # Complete-turn templates: header, blank line, content placeholder, end-of-turn.
        self._turn_system = f"{self._start_header_token}system{self._end_header_token}\n\n{{}}{self._end_token}"
        self._turn_user = f"{self._start_header_token}user{self._end_header_token}\n\n{{}}{self._end_token}"
        self._turn_model = f"{self._start_header_token}assistant{self._end_header_token}\n\n{{}}{self._end_token}"
        self.reset()

    def __repr__(self):
        """Return the prompt accumulated so far."""
        return self._state

    def system(self, prompt):
        """Append a complete system turn containing ``prompt``."""
        self._state += self._turn_system.format(prompt)
        return self

    def user(self, prompt):
        """Append a complete user turn containing ``prompt``."""
        self._state += self._turn_user.format(prompt)
        return self

    def model(self, prompt):
        """Append a complete assistant turn containing ``prompt``."""
        self._state += self._turn_model.format(prompt)
        return self

    def start_user_turn(self):
        """Append only the user header, leaving the turn open."""
        self._state += f"{self._start_header_token}user{self._end_header_token}\n\n"
        return self

    def start_model_turn(self):
        """Append only the assistant header, leaving the turn open for generation."""
        self._state += f"{self._start_header_token}assistant{self._end_header_token}\n\n"
        return self

    def end_turn(self):
        """Close the currently open turn.

        Only the end-of-turn token is appended, so an open/close pair yields
        exactly the same text as the complete-turn templates used by
        ``system``/``user``/``model`` (no trailing newline after the token).
        """
        self._state += self._end_token
        return self

    def reset(self):
        """Discard all turns and rebuild the initial prompt (BOS, system, few-shot)."""
        self._state = self._bos_token
        if self._system_prompt is not None:
            self.system(self._system_prompt)
        if self._few_shot_examples is not None:
            self.apply_turns(self._few_shot_examples, start_agent='user')
        return self

    def apply_turns(self, turns: Iterable, start_agent: str):
        """Append ``turns`` as alternating complete turns.

        Args:
            turns: Turn texts in conversation order.
            start_agent: ``'model'`` if the first turn belongs to the assistant;
                any other value starts with the user.
        """
        formatters = [self.model, self.user] if start_agent == 'model' else [self.user, self.model]
        formatters = itertools.cycle(formatters)
        for fmt, turn in zip(formatters, turns):
            fmt(turn)
        return self

In [None]:
import random
from tqdm import tqdm

def get_random_word(exclude_word='', candidates=None):
    """Pick a random keyword that differs from ``exclude_word``.

    Args:
        exclude_word: Keyword that must not be returned.
        candidates: Optional sequence of keywords to draw from. When omitted,
            the global ``words['train']['keyword']`` column is used.

    Returns:
        A uniformly chosen entry of the candidate pool that is not equal to
        ``exclude_word`` (duplicates in the pool keep their extra weight).

    Raises:
        ValueError: If no candidate other than ``exclude_word`` exists. The
            previous retry loop would spin forever in that situation.
    """
    if candidates is None:
        # NOTE(review): `words` is not defined in any visible cell — presumably
        # it is the dataset loaded from CFG.dataset_path; confirm it exists
        # before this function is called on a fresh kernel.
        candidates = words['train']['keyword']
    # Filter once instead of re-reading the column on every rejected draw.
    pool = [w for w in candidates if w != exclude_word]
    if not pool:
        raise ValueError(f"no keyword other than {exclude_word!r} is available")
    return random.choice(pool)

import ast

# Build one chat-formatted training prompt per row of train_df.
# Each prompt is: system turn, game-setup user turn, the recorded dialogue
# (starting with the model), a final "guess" request, the model's guess and the
# user's verdict.
system_prompt = "You are an AI assistant designed to play the 20 Questions game. In this game, the Answerer thinks of a keyword and responds to yes-or-no questions by the Questioner. The keyword is a specific places or things."
prompts = []
rear_keyword = ''
keyword_cnt = 1
# Every Nth consecutive row with the same label gets a deliberately wrong guess
# followed by 'Wrong!', so the data also contains negative examples.
wrong_guess_every = 100
for row in tqdm(train_df.itertuples(index=False), total=len(train_df)):
    # Count how many consecutive rows share the current label.
    if rear_keyword == row.label:
        keyword_cnt += 1
    else:
        keyword_cnt = 1
        rear_keyword = row.label
    formatter = LlamaFormatter(system_prompt=system_prompt)
    formatter.user("Let's play 20 Questions. You are playing the role of the Questioner. The keyword is a specific places or things.")
    # The `prompt` column holds a Python-literal list of turn strings.
    # ast.literal_eval parses literals only, so CSV content cannot execute
    # arbitrary code the way eval() would.
    prompt = ast.literal_eval(row.prompt)
    formatter.apply_turns(turns=prompt, start_agent='model')
    formatter.user('Now guess the keyword.')
    if keyword_cnt % wrong_guess_every != 0:
        formatter.model('**'+row.label+'**')
        formatter.user('Correct!')
    else:
        formatter.model('**'+get_random_word(row.label)+'**')
        formatter.user('Wrong!')

    prompts.append(repr(formatter))


In [None]:
from datasets import Dataset

def preprocess_func(example):
    """Tokenize one dataset record for causal-LM training.

    The record's ``'text'`` field is encoded with the module-level ``tokenizer``,
    truncated and padded to exactly ``CFG.sequence_length`` tokens.

    Returns:
        A dict with the ``'input_ids'`` and ``'attention_mask'`` of the encoding.
    """
    encoded = tokenizer(
        example['text'],
        padding='max_length',
        max_length=CFG.sequence_length,
        truncation=True,
    )
    features = {}
    features['input_ids'] = encoded.input_ids
    features['attention_mask'] = encoded.attention_mask
    return features

# Wrap the prompt strings in a Dataset, tokenize each record, then drop the raw
# text so only the tokenized columns remain.
train_ds = (
    Dataset.from_dict({"text": prompts})
    .map(preprocess_func)
    .remove_columns(["text"])
)

In [None]:
import torch
from transformers import DataCollatorWithPadding
# Shuffled training loader with CFG.train_batch examples per batch and 8 worker processes.
# Records are already padded to max_length by preprocess_func; the collator turns them into batched tensors.
traindata_loader = torch.utils.data.DataLoader(train_ds, batch_size=CFG.train_batch, collate_fn=DataCollatorWithPadding(tokenizer=tokenizer),shuffle=True, num_workers=8)

In [None]:
# Eyeball the first formatted prompt to verify the chat template.
print(prompts[0])

# Train

In [None]:
# NOTE(review): `!export` runs in a child shell, so this variable most likely never reaches
# the kernel process. If the flag is really wanted, set it via os.environ before torch_xla
# is initialised — confirm the intended numerics first (the model is already loaded in bfloat16).
!export XLA_USE_BF16=1
def train(
    model, train_data, validation_data=None, train_batch=4, validation_batch=8, epochs=10, logging_steps=1,
    lr=1e-5, 
):
    """Fine-tune ``model`` as a causal LM on XLA and save the result.

    Relies on the module-level ``device`` and ``mesh`` for placement/sharding
    and on ``CFG.model_save_name`` for the output directory.

    Args:
        model: Model to train; called with ``input_ids``, ``attention_mask``
            and ``labels`` and expected to return an object with ``.loss``.
        train_data: Sized iterable of batches exposing ``.input_ids`` and
            ``.attention_mask`` tensors.
        validation_data: Optional iterable with the same batch layout,
            evaluated after every epoch without gradient tracking.
        train_batch: Unused; kept for backward compatibility. The schedule
            length is now derived from ``len(train_data)``.
        validation_batch: Unused; kept for backward compatibility.
        epochs: Number of passes over ``train_data``.
        logging_steps: Refresh the progress-bar statistics every this many steps.
        lr: Initial AdamW learning rate (cosine-annealed over the whole run).

    Side effects:
        Moves ``model`` to CPU and writes it with ``save_pretrained`` when done.
    """

    def _lm_loss(batch):
        # Move the batch to the XLA device and shard it over the mesh.
        input_ids = batch.input_ids.to(device)
        attention_mask = batch.attention_mask.to(device)
        xs.mark_sharding(input_ids, mesh, (0, 1)) # Sharding inputs
        xs.mark_sharding(attention_mask, mesh, (0, 1))
        # Padding positions (attention_mask == 0) get label -100, which the loss
        # ignores; otherwise the objective is dominated by predicting pad tokens
        # because every sequence is padded to the full max length.
        labels = input_ids.masked_fill(attention_mask == 0, -100)
        return model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss

    optimizer = optim.AdamW(model.parameters(), lr=lr)
    # One scheduler step is taken per optimizer step, so the cosine period is the
    # real number of steps rather than a figure tied to a 6000-sample dataset.
    total_steps = max(1, len(train_data) * epochs)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)
    for epoch in range(1, epochs + 1):
        model.train()
        xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
        pbar = tqdm(train_data)
        total_loss = 0.
        for step, batch in enumerate(pbar):
            optimizer.zero_grad()
            loss = _lm_loss(batch)
            loss.backward()
            optimizer.step()
            xm.mark_step()
            # Running mean of the per-step loss over the epoch so far.
            total_loss = (total_loss * step + loss.item()) / (step + 1)
            if (step + 1) % logging_steps == 0:
                pbar.set_postfix({'loss': total_loss, 'step': step+1, 'epoch': epoch})
            scheduler.step()
        xm.master_print('Epoch {} train end {}, loss={:.3f}'.format(epoch, test_utils.now(), total_loss))
        model.eval()
        total_val_loss = 0.0
        if validation_data is not None:
            pbar = tqdm(validation_data)
            with torch.no_grad():
                for step, batch in enumerate(pbar):
                    loss = _lm_loss(batch)
                    total_val_loss = (total_val_loss * step + loss.item()) / (step + 1)
                    if (step + 1) % logging_steps == 0:
                        pbar.set_postfix({'val_loss': total_val_loss, 'step': step+1, 'epoch': epoch})

            xm.master_print('Epoch {} test end {}, test val_loss={:.3f}'.format(epoch, test_utils.now(), total_val_loss))

    # Bring the weights back to host memory before serialising.
    model.cpu()
    model.save_pretrained(CFG.model_save_name)

In [None]:
train(peft_model, traindata_loader, train_batch=CFG.train_batch, epochs=CFG.epochs, lr=CFG.lr)