<a href="https://colab.research.google.com/github/frank-morales2020/MLxDL/blob/main/MISTRAL_FINETUNE_TPU_COLAB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Google Cloud TPU V6e Trillium - Economical Option for AI Models But Locked-In: https://www.youtube.com/watch?v=20Ysq8w_0g4

Based on the provided Jupyter notebook, the maximum number of steps for the training process is 5,834.

This value is derived from the size of the training dataset and the specified batch size.

* The training dataset, train_data, contains 46,670 samples after null values are dropped.

* The BATCH_SIZE is set to 8.

* The code performs training for a single epoch (epoch 2), with each step processing a batch of data.

The total number of steps is calculated by dividing the total number of data samples by the batch size:

Total Steps=⌈
46,670/8⌉=⌈5833.75⌉=5834

In [None]:
# Install packages for LLM fine-tuning
!pip install einops -q
!pip install langsmith -q
!pip install bitsandbytes -q
!pip install peft --upgrade -q
!pip install trl -q

In [2]:
# Removed problematic import
import torch_xla
from torch_xla.distributed.spmd import XLAShardedTensor, XLAShard

In [3]:
import torch_xla.runtime as xr

In [6]:
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.debug.profiler as xp
import torch_xla.test.test_utils as test_utils
import torch.nn.functional as F
import torch_xla.runtime as xr


xr.use_spmd()

In [7]:
!pip install -q datasets

In [8]:
!pip install peft -q

In [None]:
import os
import gc
import re
from tqdm import tqdm
import pandas as pd
import numpy as np
import datasets
from datasets import Dataset
from dataclasses import dataclass

import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    AutoTokenizer,
    TrainingArguments,
    DataCollatorWithPadding,
    AutoConfig,
    GPTNeoXConfig,
    T5Config,
    LlamaConfig,
    MistralConfig,
)
from transformers.tokenization_utils_base import (
    PreTrainedTokenizerBase,
    PaddingStrategy,
)
from transformers import logging as hf_logging

from peft import (
    # prepare_model_for_int8_training, # Removed this import
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    PeftModel,
)

In [3]:
from transformers import (
    GPTNeoXConfig,
    T5Config,
    LlamaConfig,
    MistralConfig,
)
import torch.nn as nn
# import torch_xla.experimental.xla_sharding as xs
import torch_xla.core.xla_model as xm
import re
import torch
import warnings
warnings.filterwarnings("ignore")


# ends with $ to prevent sharding lora parameters
GPTNEOX_RULES = (
    # embeddings
    ("gpt_neox\\.embed_in", ("mp", "fsdp")),
    # atention
    ("attention\\.query_key_value$", ("fsdp", "mp")),
    ("attention\\.dense$", ("mp", "fsdp")),
    # mlp
    ("mlp\\.dense_h_to_4h$", ("fsdp", "mp")),
    ("mlp\\.dense_4h_to_h$", ("mp", "fsdp")),
    # output
    ("embed_out", ("fsdp", "mp")),
)

T5_RULES = (
    # embeddings
    ("shared$", ("mp", "fsdp")),
    ("embed_tokens$", ("mp", "fsdp")),

    # attention
    ("q$", ("fsdp", "mp")),
    ("k$", ("fsdp", "mp")),
    ("v$", ("fsdp", "mp")),
    ("o$", ("mp", "fsdp")),

    # mlp
    ("w$", ("fsdp", "mp")),
    ("wi_0$", ("fsdp", "mp")),
    ("wi_1$", ("fsdp", "mp")),
    ("wo$", ("mp", "fsdp")),

    # seq2seq lm head
    ("lm_head", ("fsdp", "mp")),
)

LLAMA_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "fsdp")),
    ("lm_head", ("fsdp", "mp")),
    )

MISTRAL_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
    )

ALL_RULES = [
    (GPTNeoXConfig, GPTNEOX_RULES),
    (T5Config, T5_RULES),
    (LlamaConfig, LLAMA_RULES),
    (MistralConfig, MISTRAL_RULES)
]

strkey2id = {
    "dp": 0,
    "fsdp": 1,
    "mp": 2
}

def find_rule(model):
    for config, rule in ALL_RULES:
        if model.config.__class__ == config:
            return rule
    raise Exception("unsupported model to partitioning")

def partition_module(model, mesh, device=xm.xla_device(), verbose=False):
    partition_specs = find_rule(model)
    rule = [(k, tuple([strkey2id[x] for x in v])) for k, v in partition_specs]

    for name, module in model.named_modules():
        module.to(device)
        # print(name, module.__class__.__name__)
        if isinstance(module, (torch.nn.Embedding, torch.nn.Linear)):
            for rule_pattern, spec in rule:
                if re.findall(rule_pattern, name):
                    if verbose:
                        print("match", rule_pattern, name)

                    # xs.mark_sharding(module.weight, mesh, spec) # Commented out due to missing xs
                    break

def partition_module_dp(model, mesh, device=xm.xla_device(), verbose=False):
    spec = (1, 2)

    for name, module in model.named_modules():
        module.to(device)
        if isinstance(module, (nn.Embedding, nn.nn.Linear)):
            # xs.mark_sharding(module.weight, mesh, spec) # Commented out due to missing xs
            pass # Added a pass statement as a placeholder

In [4]:
import numpy as np
def apk(actual, predicted, label_weights=(), k=10):
    """
    Computes the average precision at k.
    This function computes the average prescision at k between two lists of
    items.
    Parameters
    ----------
    actual : list
        A list of elements that are to be predicted (order doesn't matter)
    predicted : list
        A list of predicted elements (order does matter)
    label_weights : list
        A list of weights corresponding to each actual item
    k : int, optional
        The maximum number of predicted elements
    Returns
    -------
    score : double
            The average precision at k over the input lists
    """

    if not isinstance(actual, (list, pd.core.series.Series, np.ndarray)):
        raise Exception(
            "actual should be either list,pd.core.series.Series,np.ndarray"
        )

    if len(actual) < 1:
        return 0.0

    # Normalize the weights in order not to get apk above 1
    label_weights_count = len(label_weights)
    label_weights_sum = sum(label_weights)

    if len(predicted) > k:
        predicted = predicted[:k]

    score = 0.0
    num_hits = 0.0

    for i, p in enumerate(predicted):
        if p in actual and p not in predicted[:i]:
            num_hits += 1.0
            score += num_hits / (i + 1.0)
    return score / min(len(actual), k)


def recall_at_k(actual, predicted, label_weights=(), k=10):
    """
    Computes the percentage of actual items found in the top k predictions over
    all actual items.

    Parameters
    ----------
    actual : list
        A list of the actually clicked items.
    predicted : list
        A list of the predicted items, ranked.
    label_weights : list
        A list of weights corresponding to each actual item
    k : int
        The number of the top predictions that will be taken into account for the
        computation.

    Returns
    -------
    (double): recall@k
    """

    if not isinstance(actual, (list, pd.core.series.Series, np.ndarray)):
        raise Exception(
            "actual should be either list,pd.core.series.Series,np.ndarray"
        )

    if len(actual) < 1:
        return 0.0

    if len(predicted) > k:
        predicted = predicted[:k]

    total_pred_weighted = sum(label_weights)
    success_weighted = 0.0
    for i, item in enumerate(actual):
        if item in predicted:
            success_weighted += label_weights[i]

    if success_weighted == 0.0 and total_pred_weighted == 0.0:
        return 0.0

    return success_weighted / total_pred_weighted


def mean_metric(actual, predicted, metric_name, k=10, weights=()):
    """
    Computes the mean of the given metric after it calculates it for each sample.

    Possible values for metric_name:
        - map@k
        - mr@k
        - mean_rank_clicked

    Parameters
    ----------
    actual : list
             A list of lists of elements that are to be predicted
             (order doesn't matter in the lists)
    predicted : list
                A list of lists of predicted elements
                (order matters in the lists)
    k : int, optional
        The maximum number of predicted elements
    weights : list
              A list of lists of weights, each one characterizing a pair of items.
    Returns
    -------
    score : double
            The mean of the metric with metric_name.
    """

    if metric_name == "map@k":
        return np.mean([apk(a, p, w, k) for a, p, w in zip(actual, predicted, weights)])
    if metric_name == "mr@k":
        return np.mean(
            [recall_at_k(a, p, w, k) for a, p, w in zip(actual, predicted, weights)]
        )
    raise Exception("metric_name should one of the following: 'map@k', 'mr@k'")

In [5]:
def generate_prompt_training(context, prompt, a, b, c, d, e, answer):
    return f"""<human>: Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction: You will be given question with 5 possible answers. Answer the following multiple choice question by giving the most appropriate response. Answer should be one among [A, B, C, D, E]

### Context: {context}\n

### Question: {prompt}\n
A) {a}\n
B) {b}\n
C) {c}\n
D) {d}\n
E) {e}\n

<assistant>: The correct answer is: {answer}"""


def generate_prompt_inference(context, prompt, a, b, c, d, e):
    return f"""<human>: Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction: You will be given question with 5 possible answers. Answer the following multiple choice question by giving the most appropriate response. Answer should be one among [A, B, C, D, E]

### Context: {context}\n

### Question: {prompt}\n
A) {a}\n
B) {b}\n
C) {c}\n
D) {d}\n
E) {e}\n

<assistant>: The correct answer is:  """


def generate_prompt_inference(context, prompt, a, b, c, d, e):
    return f"""<human>: Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction: You will be given question with 5 possible answers. Answer the following multiple choice question by giving the most appropriate response. Answer should be one among [A, B, C, D, E]

### Context: {context}

"""


def generate_question_inference(prompt, a, b, c, d, e):
    return f"""
### Question: {prompt}\n
A) {a}\n
B) {b}\n
C) {c}\n
D) {d}\n
E) {e}\n

<assistant>: The correct answer is:  """


def generate_answer_inference(prompt, a, b, c, d, e, answer):
    return f"""
### Question: {prompt}\n
A) {a}\n
B) {b}\n
C) {c}\n
D) {d}\n
E) {e}\n

<assistant>: The correct answer is: {answer}"""

In [6]:
from datasets import Dataset
import torch

class TorchDataset(Dataset):
    def __init__(self, df, inference_only=False):
        super().__init__()

        self.df = df
        self.inference_only = inference_only
        self.prompt = df.prompt.tolist()
        self.input_ids = df.input_ids.tolist()
        self.attention_mask = df.attention_mask.tolist()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        input_ids = torch.tensor(self.input_ids[index])
        attention_mask = torch.tensor(self.attention_mask[index])

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }

In [7]:
def get_prompt_text(data):
    prompt_texts = []

    for index in tqdm(range(len(data))):
        sample_raw_prompt = generate_prompt_inference(
            context=data.iloc[index]["more_context"],
            prompt=data.iloc[index]["prompt"],
            a=data.iloc[index]["A"],
            b=data.iloc[index]["B"],
            c=data.iloc[index]["C"],
            d=data.iloc[index]["D"],
            e=data.iloc[index]["E"],
        )

        prompt_texts.append(sample_raw_prompt)

    data["prompt_text"] = prompt_texts

    return data

def get_answer_text(data):

    prompt_texts = []

    for index in tqdm(range(len(data))):
        sample_raw_prompt = generate_answer_inference(
            prompt=data.iloc[index]["prompt"],
            a=data.iloc[index]["A"],
            b=data.iloc[index]["B"],
            c=data.iloc[index]["C"],
            d=data.iloc[index]["D"],
            e=data.iloc[index]["E"],
            answer=data.iloc[index]["answer"],
        )

        prompt_texts.append(sample_raw_prompt)

    data["answer_text"] = prompt_texts

    return data

def get_question_text(data):
    prompt_texts = []

    for index in tqdm(range(len(data))):
        sample_raw_prompt = generate_question_inference(
            prompt=data.iloc[index]["prompt"],
            a=data.iloc[index]["A"],
            b=data.iloc[index]["B"],
            c=data.iloc[index]["C"],
            d=data.iloc[index]["D"],
            e=data.iloc[index]["E"],
        )

        prompt_texts.append(sample_raw_prompt)

    data["question_text"] = prompt_texts

    return data

In [8]:
BATCH_SIZE = 8
EPOCHS = 2
LEARNING_RATE = 1e-5
MAX_LENGTH = 1280
LOGGING_STEPS = 100
NUM_REPLICAS = 1

MODEL_NAME= "mistralai/Mistral-7B-Instruct-v0.1"


TRAIN = True
SAVE_MODEL = True
LOAD_MODEL = True

https://github.com/pytorch/xla

In [None]:
!pip install torch==2.8.0 'torch_xla[tpu]==2.8.0' -f https://storage.googleapis.com/libtpu-releases/index.html -q

# Optional: if you're using custom kernels, install pallas dependencies
!pip install 'torch_xla[pallas]' -f https://storage.googleapis.com/libtpu-releases/index.html -q

In [1]:
import torch
import torch_xla.core.xla_model as xm

In [9]:
import torch_xla
device = torch_xla.device()
device

device(type='xla', index=0)

In [11]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("cdeotte/60k-data-with-context-v2")

print("Path to dataset files:", path)

Path to dataset files: /kaggle/input/60k-data-with-context-v2


In [12]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

In [None]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
).to(device)

# Freezing most of the model's layers
cnt = 0
for param in model.parameters():
    cnt += 1
    param.requires_grad = True
    if cnt < 285:
        param.requires_grad = False

In [None]:
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    use_auth_token=True
)

tokenizer.pad_token = tokenizer.eos_token

# tokenizer.truncation_side = "left"
tokenizer.padding_side = "left"

In [15]:
import pandas as pd
train_data = pd.read_csv(
    "/kaggle/input/60k-data-with-context-v2/all_12_with_context2.csv"
)

train_data = train_data.dropna()
train_data = train_data.sample(frac=1.0, random_state=42)
print(train_data.shape)

train_data["more_context"] = train_data["context"].copy()

train_data.reset_index(drop=True, inplace=True)
train_data.head()

(46670, 9)


Unnamed: 0,prompt,context,A,B,C,D,E,answer,source,more_context
0,How did the AS-15TT missile compare to the Bri...,The AS-15TT missile was relatively similar to ...,"The AS-15TT missile was red in color, unlike t...",The AS-15TT missile was of the same size as th...,"The AS-15TT missile was identical in size, wei...","The AS-15TT missile was smaller, slimmer, ligh...","The AS-15TT missile was larger, wider, and hea...",D,3,The AS-15TT missile was relatively similar to ...
1,What were the main objectives for the formatio...,The 1st Colorado Infantry Regiment (officially...,The 1st Colorado Cavalry Regiment was formed i...,The 1st Colorado Cavalry Regiment was formed i...,The 1st Colorado Cavalry Regiment was formed i...,The 1st Colorado Cavalry Regiment was formed i...,The 1st Colorado Cavalry Regiment was formed i...,B,4,The 1st Colorado Infantry Regiment (officially...
2,"What is the significance of the song ""Oh No! O...",Oh My!' is the debut album of indie rock band ...,"The song ""Oh No! Oh My!"" was a bonus track add...","The song ""Oh No! Oh My!"" was originally releas...","The song ""Oh No! Oh My!"" was written by Ryland...","The song ""Oh No! Oh My!"" was the lead single f...","The song ""Oh No! Oh My!"" was the band Oh No! O...",B,4,Oh My!' is the debut album of indie rock band ...
3,In which event did Carol Lindroos compete at t...,Carol Lindroos (29 May 1930 - 9 December 2001)...,Men's discus throw,100-meter sprint,Shot put,Long jump,High jump,A,4,Carol Lindroos (29 May 1930 - 9 December 2001)...
4,What is the capital of Tarata District and Tar...,Tarata is a city in the Tacna Region in southe...,Lima,Peru City,Puno,Tarata,Tacna,D,2,Tarata is a city in the Tacna Region in southe...


In [16]:
from tqdm import tqdm

In [17]:
train_data = get_prompt_text(train_data)
train_data = get_answer_text(train_data)
train_data = get_question_text(train_data)

100%|██████████| 46670/46670 [00:05<00:00, 8522.16it/s]
100%|██████████| 46670/46670 [00:05<00:00, 8511.71it/s]
100%|██████████| 46670/46670 [00:04<00:00, 9534.56it/s]


In [18]:
def preprocess_function(example):
    text_tokens = tokenizer(
        example["prompt_text"],
        " \n " + example["question_text"],
        truncation='only_first',
        max_length=MAX_LENGTH,
        padding="max_length",
        add_special_tokens=False
    )["input_ids"]

    answer_tokens = tokenizer(
        example["prompt_text"],
        " \n " + example["answer_text"],
          truncation='only_first',
        max_length=MAX_LENGTH,
        padding="max_length",
        add_special_tokens=False
    )["input_ids"]

    answer_tokens = [-100 for i in range(len(answer_tokens) - 1)] + [answer_tokens[-1]]

    return {
        "input_ids": text_tokens,
        "label": answer_tokens,
    }


def preprocess_function_inference(example):
    text_tokens = tokenizer(
        example["prompt_text"],
        " \n " + example["question_text"],
        truncation='only_first',
        max_length=MAX_LENGTH,
        padding="max_length",
        add_special_tokens=False
    )["input_ids"]

    answer_tokens = tokenizer(
        example["prompt_text"],
        " \n " + example["answer_text"],
        truncation='only_first',
        max_length=MAX_LENGTH,
        padding="max_length",
        add_special_tokens=False
    )["input_ids"]

    answer_tokens = [-100 for i in range(len(answer_tokens) - 1)] + [answer_tokens[-1]]

    return {
        "input_ids": text_tokens,
        "label": answer_tokens,
    }

In [None]:
data_train = Dataset.from_pandas(train_data)

data_train = data_train.map(
    preprocess_function,
    batched=False,
    num_proc=56
).remove_columns(list(train_data.columns))

data_train

In [None]:
print(tokenizer.decode(data_train["input_ids"][44]))

In [21]:
data_train["label"][44][-10:]

[-100, -100, -100, -100, -100, -100, -100, -100, -100, 384]

In [22]:
FLAGS = {
    'MAX_INPUT': MAX_LENGTH,
    'LOGGING_STEPS': LOGGING_STEPS,
    'NUM_EPOCHS': EPOCHS,
    'BATCH_SIZE': BATCH_SIZE,
    'NUM_STEPS': len(data_train)
}

In [23]:
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.distributed.parallel_loader import ParallelLoader
from transformers import DataCollatorWithPadding

#this guy is responsible for distributing data across 8 cores
# train_sampler = torch.utils.data.distributed.DistributedSampler(
#     data_train, num_replicas=NUM_REPLICAS, rank=xmp.get_ordinal(), shuffle=True)

training_loader = torch.utils.data.DataLoader(
    data_train, batch_size=FLAGS['BATCH_SIZE'],
    collate_fn=DataCollatorWithPadding(tokenizer=tokenizer),
    shuffle=True # Use default RandomSampler with shuffling
    )

# xla_train_loader = pl.MpDeviceLoader(training_loader, device)

In [None]:
import pandas as pd
from datasets import Dataset
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.distributed.parallel_loader import ParallelLoader as pl, MpDeviceLoader
from transformers import DataCollatorWithPadding


test_data = pd.read_csv(
    "/kaggle/input/60k-data-with-context-v2/train_with_context2.csv"
)

# Remove or comment out the line filtering by 'dataset' column
# test_data = test_data.loc[(test_data["dataset"]=="kaggle200")]

test_data.reset_index(drop=True, inplace=True)

# Add the 'more_context' column to test_data
test_data["more_context"] = test_data["context"].copy()

test_data = get_prompt_text(test_data)
test_data = get_answer_text(test_data)
test_data = get_question_text(test_data)

data_test = Dataset.from_pandas(test_data)

#remove everything except for input_ids and labels
data_test = data_test.map(
    preprocess_function_inference,
    batched=False,
    num_proc=56
).remove_columns(list(test_data.columns))

#this guy is responsible for distributing data across 8 cores
# test_sampler = torch.utils.data.distributed.DistributedSampler(
#     data_test, num_replicas=1, rank=xmp.get_ordinal(), shuffle=False)

test_loader = torch.utils.data.DataLoader(
    data_test, batch_size=1,
    collate_fn=DataCollatorWithPadding(tokenizer=tokenizer),
    shuffle=False # No need to shuffle test data
    )

xla_test_loader = MpDeviceLoader(test_loader, device)

In [25]:
import os

# List files in the downloaded dataset directory
downloaded_dataset_path = "/kaggle/input/60k-data-with-context-v2"
if os.path.exists(downloaded_dataset_path):
    print(f"Files in {downloaded_dataset_path}:")
    for root, dirs, files in os.walk(downloaded_dataset_path):
        for file in files:
            print(os.path.join(root, file))
else:
    print(f"Directory not found: {downloaded_dataset_path}")

Files in /kaggle/input/60k-data-with-context-v2:
/kaggle/input/60k-data-with-context-v2/all_12_with_context2.csv
/kaggle/input/60k-data-with-context-v2/train_with_context2.csv
/kaggle/input/60k-data-with-context-v2/sources.txt


In [26]:
!export XLA_USE_BF16=1

In [27]:
import numpy as np
import torch_xla.test.test_utils as test_utils
def train(FLAGS):
    num_replicas = NUM_REPLICAS

    num_iterations = int(FLAGS['NUM_STEPS'] / FLAGS['BATCH_SIZE'] / num_replicas)
    print(f"num_iterations: {num_iterations}")

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        betas=(0.9, 0.999),
        eps=1e-7,
        weight_decay=0.01,
    )

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=num_iterations*FLAGS["NUM_EPOCHS"],
        eta_min=1e-7,
        last_epoch=-1,
        # verbose=False # Removed verbose argument
    )

    for epoch in range(1, FLAGS['NUM_EPOCHS'] + 1):
        model.train()
        falcon_7b_responses = []

        if epoch > 1:
            for step, batch in enumerate(training_loader):
                if step % 100 == 0:
                    xm.master_print('Epoch {} step {} train begin {}'.format(
                        epoch, step, test_utils.now()))

                optimizer.zero_grad()
                input_ids, attention_mask, labels = batch.input_ids.to(device), batch.attention_mask.to(device), batch.labels.to(device)

                attention_mask = torch.where(input_ids==2, 0, 1).to(device)

                # xs.mark_sharding(input_ids, mesh, (0, 1)) # Commented out due to missing xs
                # xs.mark_sharding(attention_mask, mesh, (0, 1)) # Commented out due to missing xs
                # xs.mark_sharding(labels, mesh, (0, 1)) # Commented out due to missing xs
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

                logits = outputs.logits[:, -1, [330, 365, 334, 384, 413]]
                loss = outputs.loss
                loss.backward()
                optimizer.step()
                xm.mark_step()

                if (step + 1) % FLAGS['LOGGING_STEPS'] == 0:
                    print(f'loss: {loss.item()}, time: {test_utils.now()}, step: {step}')

                scheduler.step()

        model.eval()
        total_loss = 0.0
        total_steps = 0

        with torch.no_grad():
            for step, batch in enumerate(xla_test_loader):
                input_ids, attention_mask, labels = batch.input_ids.to(device), batch.attention_mask.to(device), batch.labels.to(device)

                attention_mask = torch.where(input_ids==2, 0, 1).to(device)

                # xs.mark_sharding(input_ids, mesh, (0, 1)) # Commented out due to missing xs
                # xs.mark_sharding(attention_mask, mesh, (0, 1)) # Commented out due to missing xs
                # xs.mark_sharding(labels, mesh, (0, 1)) # Commented out due to missing xs
                sample_prediction = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )

                logits = sample_prediction.logits[:, -1, [330, 365, 334, 384, 413]]

                if step in [0, 100]:
                    print(f"Step: {step}")
                    print(f"Logits: {logits}")

                sorted_logits = torch.argsort(-logits[0])

                falcon_7b_responses.append(["ABCDE"[x] for x in sorted_logits][:3])

                loss = sample_prediction.loss
                total_loss += loss.item()
                total_steps += 1

        test_data["clean_answer"] = falcon_7b_responses

        apks = [apk([actual], predicted, k=3) for actual, predicted in zip(
            test_data["answer"].values, test_data["clean_answer"].values
        )]

        average_loss = total_loss / total_steps
        xm.master_print('Epoch {} test end {}, test loss={:.6f}'.format(
            epoch, test_utils.now(), average_loss))
        xm.master_print('Epoch {} test end {}, test MAP@3={:.6f}'.format(
            epoch, test_utils.now(), np.mean(apks)))
        xm.master_print('Epoch {} train end {}'.format(
            epoch, test_utils.now()))

    xm.master_print("Saving the model")
    xm.save(model.state_dict(), "tpu-llama.bin")

# xmp.spawn(train, args=(FLAGS,))

In [28]:
if TRAIN:
    train(FLAGS)

num_iterations: 5833
Step: 0
Logits: tensor([[11.1875, 11.0000,  9.8125, 12.6875, 12.5625]], device='xla:0',
       dtype=torch.bfloat16)
Step: 100
Logits: tensor([[10.0000,  9.8750,  8.5000, 14.3750, 11.7500]], device='xla:0',
       dtype=torch.bfloat16)
Epoch 1 test end 09:46:56, test loss=0.978216
Epoch 1 test end 09:46:56, test MAP@3=0.818333
Epoch 1 train end 09:46:56
Epoch 2 step 0 train begin 09:46:56
loss: 0.7403604388237, time: 09:48:43, step: 99
Epoch 2 step 100 train begin 09:48:43
loss: 1.446129560470581, time: 09:49:54, step: 199
Epoch 2 step 200 train begin 09:49:54
loss: 1.5683820247650146, time: 09:51:05, step: 299
Epoch 2 step 300 train begin 09:51:05
loss: 1.0947434902191162, time: 09:52:16, step: 399
Epoch 2 step 400 train begin 09:52:16
loss: 1.2663393020629883, time: 09:53:27, step: 499
Epoch 2 step 500 train begin 09:53:27
loss: 0.3549329340457916, time: 09:54:38, step: 599
Epoch 2 step 600 train begin 09:54:38
loss: 0.9130598306655884, time: 09:55:49, step: 699


## MODEL EVAL

In [29]:
def get_prompt_text(data):
    prompt_texts = []

    for index in tqdm(range(len(data))):
        sample_raw_prompt = generate_prompt_inference(
            context=data.iloc[index]["more_context"],
            prompt=data.iloc[index]["prompt"],
            a=data.iloc[index]["A"],
            b=data.iloc[index]["B"],
            c=data.iloc[index]["C"],
            d=data.iloc[index]["D"],
            e=data.iloc[index]["E"],
        )

        prompt_texts.append(sample_raw_prompt)

    data["prompt_text"] = prompt_texts

    return data

def get_answer_text(data):

    prompt_texts = []

    for index in tqdm(range(len(data))):
        sample_raw_prompt = generate_answer_inference(
            prompt=data.iloc[index]["prompt"],
            a=data.iloc[index]["A"],
            b=data.iloc[index]["B"],
            c=data.iloc[index]["C"],
            d=data.iloc[index]["D"],
            e=data.iloc[index]["E"],
            answer=data.iloc[index]["answer"],
        )

        prompt_texts.append(sample_raw_prompt)

    data["answer_text"] = prompt_texts

    return data

def get_question_text(data):
    prompt_texts = []

    for index in tqdm(range(len(data))):
        sample_raw_prompt = generate_question_inference(
            prompt=data.iloc[index]["prompt"],
            a=data.iloc[index]["A"],
            b=data.iloc[index]["B"],
            c=data.iloc[index]["C"],
            d=data.iloc[index]["D"],
            e=data.iloc[index]["E"],
        )

        prompt_texts.append(sample_raw_prompt)

    data["question_text"] = prompt_texts

    return data

In [None]:
import pandas as pd
from datasets import Dataset
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.distributed.parallel_loader import ParallelLoader as pl, MpDeviceLoader
from transformers import DataCollatorWithPadding
from tqdm import tqdm


test_data = pd.read_csv(
    "/kaggle/input/60k-data-with-context-v2/train_with_context2.csv"
)

# Remove or comment out the line filtering by 'dataset' column
# test_data = test_data.loc[(test_data["dataset"]=="kaggle200")]

test_data.reset_index(drop=True, inplace=True)

# Add the 'more_context' column to test_data
test_data["more_context"] = test_data["context"].copy()

test_data = get_prompt_text(test_data)
test_data = get_answer_text(test_data)
test_data = get_question_text(test_data)

data_test = Dataset.from_pandas(test_data)

#remove everything except for input_ids and labels
data_test = data_test.map(
    preprocess_function_inference,
    batched=False,
    num_proc=56
).remove_columns(list(test_data.columns))

#this guy is responsible for distributing data across 8 cores
# test_sampler = torch.utils.data.distributed.DistributedSampler(
#     data_test, num_replicas=1, rank=xmp.get_ordinal(), shuffle=False)

test_loader = torch.utils.data.DataLoader(
    data_test, batch_size=1,
    collate_fn=DataCollatorWithPadding(tokenizer=tokenizer),
    shuffle=False # No need to shuffle test data
    )

xla_test_loader = MpDeviceLoader(test_loader, device)

In [31]:
%%time
import gc

mistral_7b_responses = []

model.to(device)

for step, data in enumerate(xla_test_loader):
    with torch.no_grad():
        if (step + 1) % 20 == 0:
            print(step + 1)

        # data["attention_mask"] = torch.where(data["input_ids"]==2, 0, 1).to(device) # Data is already on device with xla_test_loader

        sample_prediction = model(**data)
        sorted_logits = torch.argsort(
            -sample_prediction.logits[0][-1, [330, 365, 334, 384, 413]])

        mistral_7b_responses.append(["ABCDE"[x] for x in sorted_logits][:3])

        del sample_prediction
        del sorted_logits
        gc.collect()


print(len(mistral_7b_responses), mistral_7b_responses[:10])

test_data["clean_answer"] = mistral_7b_responses

apks = [apk([actual], predicted, k=3) for actual, predicted in zip(
    test_data["answer"].values,
    test_data["clean_answer"].values
)]

test_data["apk"] = apks
print(test_data["apk"].value_counts())
print("\n")
map3_score = np.mean(apks)
print(f"MAP@3: {map3_score}")

20
40
60
80
100
120
140
160
180
200
200 [['D', 'A', 'E'], ['A', 'B', 'D'], ['A', 'C', 'B'], ['B', 'A', 'C'], ['B', 'A', 'D'], ['B', 'A', 'C'], ['A', 'C', 'B'], ['B', 'D', 'A'], ['B', 'C', 'A'], ['A', 'C', 'B']]
apk
1.000000    130
0.500000     32
0.000000     23
0.333333     15
Name: count, dtype: int64


MAP@3: 0.755
CPU times: user 3min 22s, sys: 1.6 s, total: 3min 24s
Wall time: 2min 24s


## Push model to HuggingFace Hub

In [32]:
# Load the model architecture (assuming it's already defined or loaded in a previous cell like EDTWr7gyxB7f)
# model = AutoModelForCausalLM.from_pretrained(...) # Keep the model instance from EDTWr7gyxB7f

# Load the state dictionary from the local file
model.load_state_dict(torch.load("./tpu-llama.bin"))

# Move the model back to the device if necessary
model.to(device)

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): MistralMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): MistralRMSNorm((4096,), eps=1e-0

OPTION1

frankmorales2020/bert-base-cased_fine_tuned_glue_cola

In [None]:
from google.colab import userdata
token = userdata.get('HF_TOKEN')
hf_auth_token = token
model_id = "frankmorales2020/mistral-7b-alpha-finetuned-llm-science-exam-tpu-colab-v6e-1"

# Move the model to CPU before pushing
model.cpu()

model.push_to_hub(model_id, use_auth_token=hf_auth_token)
tokenizer.push_to_hub(model_id, use_auth_token=hf_auth_token)

OPTION2

In [None]:
if SAVE_MODEL:
    from google.colab import userdata
    token = userdata.get('HF_TOKEN')
    hf_auth_token = token

    model_id = "frankmorales2020/mistral-7b-alpha-finetuned-llm-science-exam-tpu-colab-v6e-1"

    if map3_score >= 0.934:
        print("Merging directly to Main...")

        model = model.cpu()
        model.push_to_hub(
            model_id,
            tokenizer=tokenizer,
            private=False,
            create_pr=False,
            commit_message=f"Merging directly to main. Map@3 = {map3_score}",
            max_shard_size="2GB",
            use_auth_token=hf_auth_token,
        )

    else:
        print("Creating PR...")

        model = model.cpu()
        model.push_to_hub(
            model_id,
            tokenizer=tokenizer,
            private=False,
            create_pr=1,
            commit_message=f"Creating PR. Map@3 = {map3_score}",
            max_shard_size="2GB",
            use_auth_token=hf_auth_token,
        )

## Load Model from HuggingFace and Evaluate

In [None]:
from transformers import AutoModelForCausalLM
import torch

# Define the repository ID
repo_id = "frankmorales2020/mistral-7b-alpha-finetuned-llm-science-exam-tpu-colab-v6e-1"

# Load the model from Hugging Face Hub
model_from_hf = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    device_map="auto", # Automatically maps the model to available devices (like TPU)
#     use_auth_token=True, # Uncomment if authentication is required
)

print(f"Model loaded successfully from {repo_id}")

In [37]:
%%time

mistral_7b_responses = []

## HF MODEL
model=model_from_hf

model.to(device)

for step, data in enumerate(xla_test_loader): # Use xla_test_loader instead of test_loader
    with torch.no_grad():
        if (step + 1) % 20 == 0:
            print(step + 1)

        # data["attention_mask"] = torch.where(data["input_ids"]==2, 0, 1).to(device) # Data is already on device with xla_test_loader

        sample_prediction = model(**data)
        sorted_logits = torch.argsort(
            -sample_prediction.logits[0][-1, [330, 365, 334, 384, 413]])

        mistral_7b_responses.append(["ABCDE"[x] for x in sorted_logits][:3])

        del sample_prediction
        del sorted_logits
        gc.collect()


print(len(mistral_7b_responses), mistral_7b_responses[:10])

test_data["clean_answer"] = mistral_7b_responses

apks = [apk([actual], predicted, k=3) for actual, predicted in zip(
    test_data["answer"].values,
    test_data["clean_answer"].values
)]

test_data["apk"] = apks
print(test_data["apk"].value_counts())
print("\n")
print(f"MAP@3: {np.mean(apks)}")

20
40
60
80
100
120
140
160
180
200
200 [['D', 'A', 'E'], ['A', 'B', 'D'], ['A', 'C', 'B'], ['B', 'A', 'C'], ['B', 'A', 'D'], ['B', 'A', 'C'], ['A', 'C', 'B'], ['B', 'D', 'A'], ['B', 'C', 'A'], ['A', 'C', 'B']]
apk
1.000000    130
0.500000     32
0.000000     23
0.333333     15
Name: count, dtype: int64


MAP@3: 0.755
CPU times: user 1min 18s, sys: 10.4 s, total: 1min 28s
Wall time: 1min 52s
