In [1]:
import ray

In [2]:
if ray.is_initialized():
    ray.shutdown()
ray.init(include_dashboard=True)

2025-09-15 12:30:12,535	INFO worker.py:1942 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8266 [39m[22m


0,1
Python version:,3.10.12
Ray version:,2.49.1
Dashboard:,http://127.0.0.1:8266


[36m(RayTrainWorker pid=5962)[0m Setting up process group for: env:// [rank=0, world_size=1]
[36m(TorchTrainer pid=5811)[0m Started distributed worker processes: 
[36m(TorchTrainer pid=5811)[0m - (node_id=4514140190772872c09e0301ea56f51c2c34c1de0e071a4ab3fbd0ab, ip=127.0.0.1, pid=5962) world_rank=0, local_rank=0, node_rank=0
[36m(RayTrainWorker pid=5962)[0m Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias']
[36m(RayTrainWorker pid=5962)[0m - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a 

In [3]:
ray.cluster_resources()

{'CPU': 8.0,
 'object_store_memory': 2147483648.0,
 'node:127.0.0.1': 1.0,
 'memory': 8502460416.0,
 'node:__internal_head__': 1.0}

## Finetunning our own llm

In [4]:
import os
import random
import torch
import numpy as np
import ray
from ray.data.preprocessor import Preprocessor

In [5]:
def set_seed(seed=42):
    """Set seeds for reproducibility."""
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    #torch.cuda.manual_seed(seed)
    #eval("setattr(torch.backends.cudnn, 'deterministic', True)")
    #eval("setattr(torch.backends.cudnn, 'benchmark', False)")
    os.environ['PYTHONHASHSEED'] = str(seed)

In [6]:
set_seed()

In [7]:
DATASET_LOC = "https://raw.githubusercontent.com/GokuMohandas/Made-With-ML/main/datasets/dataset.csv"

In [8]:
def load_data(num_samples=None, loc=DATASET_LOC):
    ds = ray.data.read_csv(loc)
    ds = ds.random_shuffle(seed=1234)
    ds = ray.data.from_items(ds.take(num_samples)) if num_samples else ds
    return ds

<div style="border-left: 4px solid #00c896; background-color:rgb(23, 22, 22); padding: 1em; border-radius: 8px; margin: 1em 0;">
  <p style="margin: 0; font-weight: bold; color: #059669;">💡 Tip</p>
  <p style="margin: 0;">
    When working with very large datasets, it's a good idea to limit the number of samples in our dataset so that we can execute our code quickly and iterate on bugs, etc.  
    This is why we have a <code>num_samples</code> input argument in our <code>load_data</code> function (<code>None</code> = no limit, all samples).
  </p>
</div>


## Helpful functions

In [9]:
import numpy as np
from transformers import BertTokenizer
import re
import nltk
from nltk.corpus import stopwords

In [10]:
nltk.download("stopwords")
STOPWORDS = stopwords.words("english")

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/ngkuissi/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [11]:
tokenizer = BertTokenizer.from_pretrained("allenai/scibert_scivocab_uncased", return_dict=False)



In [12]:
def tokenize(batch, tokenizer=tokenizer):
    encoded_inputs = tokenizer(batch['text'].tolist(), return_tensors="np", padding="longest")
    return dict(ids=encoded_inputs['input_ids'], mask=encoded_inputs['attention_mask'],
                target=np.array(batch['tag']))

In [13]:
def clean_text(text, stopwords=STOPWORDS):
    """Clean raw text string."""
    text = text.lower()
    
    # remove stopwords
    pattern = re.compile(r'\b(' + r"|".join(stopwords) + r")\b\s*")
    text = pattern.sub("", text)
    
    text = re.sub(r"([!\"'#$%&()*\+,-./:;<=>?@\\\[\]^_`{|}~])", r" \1 ", text)  # add spacing
    text = re.sub("[^A-Za-z0-9]+", " ", text)  # remove non alphanumeric chars
    text = re.sub(" +", " ", text)  # remove multiple spaces
    text = text.strip()  # strip white space at the ends
    text = re.sub(r"http\S+", "", text)  #  remove links
    
    return text
    

In [14]:
def preprocess(df, class_to_index):
    """Preprocess the data."""
    df["text"] = df.title + " " + df.description  # feature engineering
    df["text"] = df.text.apply(clean_text)  # clean text
    df = df.drop(columns=["id", "created_on", "title", "description"], errors="ignore")  # clean dataframe
    df = df[["text", "tag"]]  # rearrange columns
    df["tag"] = df["tag"].map(class_to_index)  # label encoding
    outputs = tokenize(df)
    return outputs

In [15]:
class CustomPreprocessor(Preprocessor):
    
    def _fit(self, ds):
        tags = ds.unique(column="tag")
        self.class_to_index = {tag: i for i, tag in enumerate(tags)}
        self.index_to_class = {v:k for k, v in self.class_to_index.items()}
    
    def _transform_pandas(self, batch):
        return preprocess(batch, class_to_index=self.class_to_index)

In [16]:
def decode(indices, index_to_class):
    return [index_to_class[idx] for idx in indices]

## Model

In [17]:
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel

In [18]:
llm = BertModel.from_pretrained("allenai/scibert_scivocab_uncased", return_dict=False)
embedding_dim = llm.config.hidden_size

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [19]:
text = "Transfer learning with transformers for text classification."
batch = tokenizer([text], return_tensors="pt", padding="longest")
seq, pool = llm(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])

In [20]:
seq.shape

torch.Size([1, 10, 768])

In [21]:
pool.shape

torch.Size([1, 768])

In [22]:
class FinetunnedLLM(nn.Module):
    
    def __init__(self, llm, dropout_p, embedding_dim, num_classes):
        super(FinetunnedLLM, self).__init__()
        self.llm = llm
        self.dropout = nn.Dropout(dropout_p)
        self.classifier = nn.Linear(embedding_dim, num_classes)
    
    def forward(self, batch):
        ids, mask = batch['ids'], batch['mask']
        _, pool = self.llm(input_ids=ids, attention_mask = mask)
        
        return self.classifier(self.dropout(pool))

    @torch.inference_mode
    def predict(self, batch):
        self.eval()
        z = self(batch)
        return torch.argmax(z, dim=-1).cpu().numpy()

    @torch.inference_mode
    def predict_prob(self, batch):
        self.eval()
        z = self(batch)
        y_probs = F.softmax(z, dim=-1).cpu().numpy()
        return y_probs

In [23]:
model = FinetunnedLLM(llm, 0.5, embedding_dim, 4)

In [24]:
next(model.parameters()).device == torch.device("cpu")

True

## Data setup

In [25]:
ray.data.DatasetContext.get_current().execution_options.preserve_order = True
ds = ray.data.read_csv(DATASET_LOC)
ds = ds.random_shuffle(seed=1234)

In [26]:
import sys
sys.path.append("..")
from madewithml.data import stratify_split

  import pkg_resources


In [27]:
test_size = 0.2
train_ds, val_ds = stratify_split(ds, stratify="tag", test_size=test_size)

In [28]:
tags = train_ds.unique(column="tag")
class_to_index = {tag: i for i, tag in enumerate(tags)}

2025-09-15 12:30:48,562	INFO dataset.py:3246 -- Tip: Use `take_batch()` instead of `take() / show()` to return records in pandas or numpy batch format.
2025-09-15 12:30:48,571	INFO logging.py:295 -- Registered dataset logger for dataset dataset_10_0
2025-09-15 12:30:48,591	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_10_0. Full logs are in /tmp/ray/session_2025-09-15_12-30-08_614560_3457/logs/ray-data
2025-09-15 12:30:48,591	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_10_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> AllToAllOperator[RandomShuffle] -> AllToAllOperator[Sort] -> AllToAllOperator[MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle] -> AllToAllOperator[Aggregate] -> LimitOperator[limit=1]


Running 0: 0.00 row [00:00, ? row/s]

- ReadCSV->SplitBlocks(16) 1: 0.00 row [00:00, ? row/s]

- RandomShuffle 2: 0.00 row [00:00, ? row/s]

Shuffle Map 3:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 4:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- Sort 5: 0.00 row [00:00, ? row/s]

Sort Sample 6:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Map 7:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 8:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle 9: 0.00 row [00:00, ? row/s]

Shuffle Map 10:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 11:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- Aggregate 12: 0.00 row [00:00, ? row/s]

Sort Sample 13:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Map 14:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 15:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- limit=1 16: 0.00 row [00:00, ? row/s]

created_on: timestamp[s]
title: string
description: string
tag: string, new schema: . This may lead to unexpected behavior.
2025-09-15 12:30:52,970	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_10_0 execution finished in 4.38 seconds


In [29]:
sample_ds = train_ds.map_batches(
    preprocess,
    fn_kwargs={"class_to_index": class_to_index},
    batch_format="pandas"
)

## Batching

In [30]:
from ray.train.torch import get_device

In [31]:
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available else "cpu"

In [32]:
device

'mps'

In [33]:
def pad_array(arr, dtype=np.int32):
    max_len = max([len(row) for row in arr])
    padded_arr = np.zeros((arr.shape[0], max_len), dtype=dtype)
    for i, row in enumerate(arr):
        padded_arr[i, :len(row)]= row
    return padded_arr

In [34]:
def collate_fn(batch, device=device):
    batch["ids"] = pad_array(batch["ids"])
    batch["mask"] = pad_array(batch["mask"])
    dtypes = {"ids": torch.int32, "mask": torch.int32, "target": torch.int64}
    tensor_batch = {}
    for key, arr in batch.items():
        tensor_batch[key] = torch.as_tensor(list(arr), dtype=dtypes[key], device=device)
    return tensor_batch
    

In [35]:
sample_batch = sample_ds.take_batch(128)

2025-09-15 12:31:00,244	INFO logging.py:295 -- Registered dataset logger for dataset dataset_12_0
2025-09-15 12:31:00,253	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_12_0. Full logs are in /tmp/ray/session_2025-09-15_12-30-08_614560_3457/logs/ray-data
2025-09-15 12:31:00,254	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_12_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> AllToAllOperator[RandomShuffle] -> AllToAllOperator[Sort] -> AllToAllOperator[MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle] -> TaskPoolMapOperator[MapBatches(preprocess)] -> LimitOperator[limit=128]


Running 0: 0.00 row [00:00, ? row/s]

- ReadCSV->SplitBlocks(16) 1: 0.00 row [00:00, ? row/s]

- RandomShuffle 2: 0.00 row [00:00, ? row/s]

Shuffle Map 3:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 4:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- Sort 5: 0.00 row [00:00, ? row/s]

Sort Sample 6:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Map 7:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 8:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle 9: 0.00 row [00:00, ? row/s]

Shuffle Map 10:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 11:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- MapBatches(preprocess) 12: 0.00 row [00:00, ? row/s]

- limit=128 13: 0.00 row [00:00, ? row/s]

created_on: timestamp[s]
title: string
description: string
tag: string, new schema: . This may lead to unexpected behavior.
mask: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
target: int64, new schema: ids: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
mask: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
target: int64. This may lead to unexpected behavior.
2025-09-15 12:31:02,913	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_12_0 execution finished in 2.66 seconds


In [36]:
collate_fn(sample_batch)

  tensor_batch[key] = torch.as_tensor(list(arr), dtype=dtypes[key], device=device)


{'ids': tensor([[  102, 13568, 11404,  ...,     0,     0,     0],
         [  102,   437, 17574,  ...,     0,     0,     0],
         [  102,  6265, 21930,  ...,     0,     0,     0],
         ...,
         [  102, 17251, 30128,  ...,  4928,   103,     0],
         [  102,  6693,  8215,  ...,     0,     0,     0],
         [  102,  3246,   251,  ...,     0,     0,     0]], device='mps:0',
        dtype=torch.int32),
 'mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 1, 1, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]], device='mps:0', dtype=torch.int32),
 'target': tensor([3, 1, 3, 3, 3, 2, 2, 2, 3, 2, 3, 3, 3, 3, 3, 3, 1, 2, 2, 2, 1, 1, 2, 3,
         2, 1, 2, 2, 1, 3, 2, 3, 2, 2, 2, 3, 2, 3, 0, 2, 3, 3, 2, 1, 3, 3, 2, 3,
         3, 3, 3, 3, 3, 3, 2, 2, 1, 0, 1, 3, 2, 3, 0, 1, 2, 2, 2, 2, 1, 2, 2, 2,
         2, 0, 2, 3, 1, 2, 3, 0, 2, 3, 3, 2, 1, 3, 3, 2, 3

## Utilities

In [46]:
from ray.air import session
from ray.air.config import CheckpointConfig, RunConfig, ScalingConfig
from ray.train.torch import prepare_model
from ray.train.torch import TorchCheckpoint, TorchTrainer
import torch.nn.functional as F

In [95]:
def train_step(ds, model, batch_size, num_classes, loss_fn, optimizer):
    model.train()
    loss = 0.0
    ds_generator = ds.iter_torch_batches(batch_size=batch_size, collate_fn=collate_fn)
    device = next(model.parameters()).device
    for i, batch in enumerate(ds_generator):
        batch['ids'] = batch['ids'].to(device)
        batch['mask'] = batch['mask'].to(device)
        batch['target'] = batch['target'].to(device)
        optimizer.zero_grad()
        z = model(batch)
        targets = batch['target']
        e_loss = loss_fn(z.view(-1, z.shape[-1]), targets)
        e_loss.backward()
        optimizer.step()
        loss += (e_loss.detach().item() - loss) / (i + 1)
    return loss

In [96]:
def eval_step(ds, batch_size, model, num_classes, loss_fn):
    model.eval()
    loss = 0.0
    y_trues, y_preds = [], []
    ds_generator = ds.iter_torch_batches(batch_size=batch_size, collate_fn=collate_fn)
    device = next(model.parameters()).device
    with torch.inference_mode():
        for i, batch in enumerate(ds_generator):
            batch['ids'] = batch['ids'].to(device)
            batch['mask'] = batch['mask'].to(device)
            batch['target'] = batch['target'].to(device)
            z = model(batch)
            targets = batch['target']
            e_loss = loss_fn(z.view(-1, z.shape[-1]), targets).item()
            loss += (e_loss - loss) / (i + 1)
            y_trues.extend(batch['target'].cpu().numpy())
            y_preds.extend(torch.argmax(z, dim=-1).cpu().numpy())
            
    return loss, np.vstack(y_trues), np.vstack(y_preds)
            

In [97]:
def train_loop_per_worker(config):
    dropout = config['dropout_p']
    lr = config['lr']
    lr_factor = config['lr_factor']
    lr_patience = config['lr_patience']
    num_epochs = config['num_epochs']
    batch_size = config['batch_size']
    num_classes = config['num_classes']
    device = config['device']
    
    #set_seed()
    train_ds = session.get_dataset_shard("train")
    val_ds = session.get_dataset_shard("val")
    
    llm = BertModel.from_pretrained("allenai/scibert_scivocab_uncased", return_dict=False).to(device)
    model = FinetunnedLLM(llm, dropout, llm.config.hidden_size, num_classes)
    model = prepare_model(model)
    model = model.to(device)
    
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=lr_factor, patience=lr_patience)
    
    batch_size_per_worker = batch_size // session.get_world_size()
    for epoch in range(num_epochs):
        train_loss = train_step(train_ds, model, batch_size_per_worker, num_classes, loss_fn, optimizer)
        val_loss, _, _ = eval_step(val_ds, batch_size_per_worker, model, num_classes, loss_fn)
        scheduler.step(val_loss)
        
        metrics = dict(epoch=epoch, lr=optimizer.param_groups[0]["lr"], train_loss=train_loss, val_loss=val_loss)
        checkpoint = TorchCheckpoint.from_state_dict(model.state_dict())
        session.report(metrics, checkpoint=checkpoint)
    

In [98]:
train_loop_config = {
    "dropout_p": 0.5,
    "lr": 1e-4,
    "lr_factor": 0.8,
    "lr_patience": 3,
    "num_epochs": 10,
    "batch_size": 128,
    "num_classes": 4,
    "device": device
}

In [99]:
num_workers = 1
if device == "cuda":
    resources_per_worker = {"CPU": 10, "GPU": 1}
else:
    resources_per_worker = {"CPU": 3}

In [100]:
from ray.train import ScalingConfig

scaling_config = ScalingConfig(
    num_workers=num_workers,
    use_gpu=bool(resources_per_worker.get("GPU", 0)),
    resources_per_worker=resources_per_worker,
)

In [101]:
checkpoint_config = CheckpointConfig(num_to_keep=1, checkpoint_score_attribute="val_loss", checkpoint_score_order="min")

In [102]:
run_config = RunConfig(name="llm", storage_path="~/ray_results", checkpoint_config=checkpoint_config)

## Training

In [103]:
ds = load_data()
train_ds, val_ds = stratify_split(ds, stratify="tag", test_size=test_size)

In [104]:
preprocessor = CustomPreprocessor()
train_ds = preprocessor.fit_transform(train_ds)
val_ds = preprocessor.transform(val_ds)

2025-09-15 12:39:12,023	INFO logging.py:295 -- Registered dataset logger for dataset dataset_69_0
2025-09-15 12:39:12,042	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_69_0. Full logs are in /tmp/ray/session_2025-09-15_12-30-08_614560_3457/logs/ray-data
2025-09-15 12:39:12,042	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_69_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> AllToAllOperator[RandomShuffle] -> AllToAllOperator[Sort] -> AllToAllOperator[MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle] -> AllToAllOperator[Aggregate] -> LimitOperator[limit=1]


Running 0: 0.00 row [00:00, ? row/s]

- ReadCSV->SplitBlocks(16) 1: 0.00 row [00:00, ? row/s]

- RandomShuffle 2: 0.00 row [00:00, ? row/s]

Shuffle Map 3:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 4:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- Sort 5: 0.00 row [00:00, ? row/s]

Sort Sample 6:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Map 7:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 8:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle 9: 0.00 row [00:00, ? row/s]

Shuffle Map 10:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 11:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- Aggregate 12: 0.00 row [00:00, ? row/s]

Sort Sample 13:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Map 14:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 15:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- limit=1 16: 0.00 row [00:00, ? row/s]

created_on: timestamp[s]
title: string
description: string
tag: string, new schema: . This may lead to unexpected behavior.
2025-09-15 12:39:14,322	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_69_0 execution finished in 2.28 seconds


In [105]:
train_ds = train_ds.materialize()
val_ds = val_ds.materialize()

2025-09-15 12:39:14,501	INFO logging.py:295 -- Registered dataset logger for dataset dataset_72_0
2025-09-15 12:39:14,515	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_72_0. Full logs are in /tmp/ray/session_2025-09-15_12-30-08_614560_3457/logs/ray-data
2025-09-15 12:39:14,517	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_72_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> AllToAllOperator[RandomShuffle] -> AllToAllOperator[Sort] -> AllToAllOperator[MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle] -> TaskPoolMapOperator[CustomPreprocessor]


Running 0: 0.00 row [00:00, ? row/s]

- ReadCSV->SplitBlocks(16) 1: 0.00 row [00:00, ? row/s]

- RandomShuffle 2: 0.00 row [00:00, ? row/s]

Shuffle Map 3:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 4:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- Sort 5: 0.00 row [00:00, ? row/s]

Sort Sample 6:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Map 7:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 8:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle 9: 0.00 row [00:00, ? row/s]

Shuffle Map 10:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 11:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- CustomPreprocessor 12: 0.00 row [00:00, ? row/s]

created_on: timestamp[s]
title: string
description: string
tag: string, new schema: . This may lead to unexpected behavior.
mask: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
target: int64, new schema: ids: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
mask: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
target: int64. This may lead to unexpected behavior.
2025-09-15 12:39:16,178	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_72_0 execution finished in 1.66 seconds
2025-09-15 12:39:16,226	INFO logging.py:295 -- Registered dataset logger for dataset dataset_74_0
2025-09-15 12:39:16,233	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_74_0. Full logs are in /tmp/ray/session_2025-09-15_12-30-08_614560_3457/logs/ray-data
2025-09-15 12:39:16,234	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_74_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> AllToAllOperator[RandomShuffle] -> AllToAllOperator[Sort] ->

Running 0: 0.00 row [00:00, ? row/s]

- ReadCSV->SplitBlocks(16) 1: 0.00 row [00:00, ? row/s]

- RandomShuffle 2: 0.00 row [00:00, ? row/s]

Shuffle Map 3:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 4:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- Sort 5: 0.00 row [00:00, ? row/s]

Sort Sample 6:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Map 7:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 8:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle 9: 0.00 row [00:00, ? row/s]

Shuffle Map 10:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 11:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- CustomPreprocessor 12: 0.00 row [00:00, ? row/s]

created_on: timestamp[s]
title: string
description: string
tag: string, new schema: . This may lead to unexpected behavior.
mask: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
target: int64, new schema: ids: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
mask: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
target: int64. This may lead to unexpected behavior.
2025-09-15 12:39:17,655	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_74_0 execution finished in 1.42 seconds


In [106]:
from ray.train import DataConfig

In [107]:
dataset_config = DataConfig(["train", "val"])

In [108]:
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config= run_config,
    datasets= {"train": train_ds, "val": val_ds},
    dataset_config=dataset_config,
    #preprocessor=preprocessor   
)

In [109]:
results = trainer.fit()

2025-09-15 12:39:17,874	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


== Status ==
Current time: 2025-09-15 12:39:18 (running for 00:00:00.13)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-09-15 12:39:23 (running for 00:00:05.13)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 12:39:28 (running for 00:00:10.22)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=14865) Running 0: 0.00 row [00:00, ? row/s]

(pid=14865) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 12:39:33 (running for 00:00:15.33)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=14864) Running 0: 0.00 row [00:00, ? row/s]

(pid=14864) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 12:39:38 (running for 00:00:20.39)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=14865) Running 0: 0.00 row [00:00, ? row/s]

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


(pid=14865) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 12:39:43 (running for 00:00:25.41)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 12:39:48 (running for 00:00:30.50)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=14864) Running 0: 0.00 row [00:00, ? row/s]

(pid=14864) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


(pid=14865) Running 0: 0.00 row [00:00, ? row/s]

(pid=14865) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 12:39:53 (running for 00:00:35.52)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 12:39:58 (running for 00:00:40.61)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=14864) Running 0: 0.00 row [00:00, ? row/s]

(pid=14864) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 12:40:03 (running for 00:00:45.68)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


(pid=14865) Running 0: 0.00 row [00:00, ? row/s]

(pid=14865) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 12:40:08 (running for 00:00:50.73)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 12:40:13 (running for 00:00:55.81)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=14864) Running 0: 0.00 row [00:00, ? row/s]

(pid=14864) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 12:40:18 (running for 00:01:00.83)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


(pid=14865) Running 0: 0.00 row [00:00, ? row/s]

(pid=14865) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 12:40:23 (running for 00:01:05.87)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=14864) Running 0: 0.00 row [00:00, ? row/s]

(pid=14864) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 12:40:28 (running for 00:01:10.97)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


(pid=14865) Running 0: 0.00 row [00:00, ? row/s]

(pid=14865) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 12:40:33 (running for 00:01:16.06)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 12:40:39 (running for 00:01:21.13)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=14864) Running 0: 0.00 row [00:00, ? row/s]

(pid=14864) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


(pid=14865) Running 0: 0.00 row [00:00, ? row/s]

(pid=14865) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 12:40:44 (running for 00:01:26.15)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 12:40:49 (running for 00:01:31.16)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=14864) Running 0: 0.00 row [00:00, ? row/s]

(pid=14864) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


(pid=14865) Running 0: 0.00 row [00:00, ? row/s]

(pid=14865) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 12:40:54 (running for 00:01:36.23)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 12:40:59 (running for 00:01:41.29)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=14864) Running 0: 0.00 row [00:00, ? row/s]

(pid=14864) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


(pid=14865) Running 0: 0.00 row [00:00, ? row/s]

(pid=14865) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 12:41:04 (running for 00:01:46.34)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 12:41:09 (running for 00:01:51.42)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=14864) Running 0: 0.00 row [00:00, ? row/s]

(pid=14864) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


(pid=14865) Running 0: 0.00 row [00:00, ? row/s]

(pid=14865) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 12:41:14 (running for 00:01:56.44)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 12:41:19 (running for 00:02:01.51)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=14864) Running 0: 0.00 row [00:00, ? row/s]

(pid=14864) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


== Status ==
Current time: 2025-09-15 12:41:24 (running for 00:02:06.53)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




2025-09-15 12:41:25,523	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/ngkuissi/ray_results/llm' in 0.0241s.
2025-09-15 12:41:25,534	INFO tune.py:1041 -- Total run time: 127.66 seconds (127.62 seconds for the tuning loop).


== Status ==
Current time: 2025-09-15 12:41:25 (running for 00:02:07.64)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_12-39-17/llm/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)




In [110]:
results.metrics_dataframe.columns

Index(['epoch', 'lr', 'train_loss', 'val_loss', 'timestamp',
       'checkpoint_dir_name', 'should_checkpoint', 'done',
       'training_iteration', 'trial_id', 'date', 'time_this_iter_s',
       'time_total_s', 'pid', 'hostname', 'node_ip', 'time_since_restore',
       'iterations_since_restore', 'config/train_loop_config/dropout_p',
       'config/train_loop_config/lr', 'config/train_loop_config/lr_factor',
       'config/train_loop_config/lr_patience',
       'config/train_loop_config/num_epochs',
       'config/train_loop_config/batch_size',
       'config/train_loop_config/num_classes',
       'config/train_loop_config/device'],
      dtype='object')

In [111]:
results.metrics_dataframe[['epoch', 'train_loss', 'val_loss']]

Unnamed: 0,epoch,train_loss,val_loss
0,0,1.298656,1.348185
1,1,1.279565,1.162291
2,2,1.130297,0.845694
3,3,0.823525,0.659376
4,4,0.520582,0.482941
5,5,0.303644,0.489036
6,6,0.178201,0.433394
7,7,0.081183,0.416853
8,8,0.029168,0.45143
9,9,0.012867,0.435537


In [112]:
results.best_checkpoints

[(Checkpoint(filesystem=local, path=/Users/ngkuissi/ray_results/llm/TorchTrainer_85d4f_00000_0_2025-09-15_12-39-17/checkpoint_000007),
  {'epoch': 7,
   'lr': 0.0001,
   'train_loss': 0.0811827726662159,
   'val_loss': 0.4168533682823181,
   'timestamp': 1757954461,
   'checkpoint_dir_name': 'checkpoint_000007',
   'should_checkpoint': True,
   'done': False,
   'training_iteration': 8,
   'trial_id': '85d4f_00000',
   'date': '2025-09-15_12-41-01',
   'time_this_iter_s': 10.688474893569946,
   'time_total_s': 99.50343370437622,
   'pid': 14752,
   'hostname': 'Nathans-Laptop.local',
   'node_ip': '127.0.0.1',
   'config': {'train_loop_config': {'dropout_p': 0.5,
     'lr': 0.0001,
     'lr_factor': 0.8,
     'lr_patience': 3,
     'num_epochs': 10,
     'batch_size': 128,
     'num_classes': 4,
     'device': 'mps'}},
   'time_since_restore': 99.50343370437622,
   'iterations_since_restore': 8})]

In [113]:
from ray.train.torch import TorchPredictor
from sklearn.metrics import precision_recall_fscore_support

In [114]:
llm = BertModel.from_pretrained("allenai/scibert_scivocab_uncased", return_dict=False).to(device)
model = FinetunnedLLM(llm, 0.5, llm.config.hidden_size, 4)

Some weights of the model checkpoint at allenai/scibert_scivocab_uncased were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [115]:
results.best_checkpoints

[(Checkpoint(filesystem=local, path=/Users/ngkuissi/ray_results/llm/TorchTrainer_85d4f_00000_0_2025-09-15_12-39-17/checkpoint_000007),
  {'epoch': 7,
   'lr': 0.0001,
   'train_loss': 0.0811827726662159,
   'val_loss': 0.4168533682823181,
   'timestamp': 1757954461,
   'checkpoint_dir_name': 'checkpoint_000007',
   'should_checkpoint': True,
   'done': False,
   'training_iteration': 8,
   'trial_id': '85d4f_00000',
   'date': '2025-09-15_12-41-01',
   'time_this_iter_s': 10.688474893569946,
   'time_total_s': 99.50343370437622,
   'pid': 14752,
   'hostname': 'Nathans-Laptop.local',
   'node_ip': '127.0.0.1',
   'config': {'train_loop_config': {'dropout_p': 0.5,
     'lr': 0.0001,
     'lr_factor': 0.8,
     'lr_patience': 3,
     'num_epochs': 10,
     'batch_size': 128,
     'num_classes': 4,
     'device': 'mps'}},
   'time_since_restore': 99.50343370437622,
   'iterations_since_restore': 8})]

In [116]:
best_checkpoint = results.best_checkpoints[0][0].path + "/model.pt"
state_dict = torch.load(best_checkpoint)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [117]:
model

FinetunnedLLM(
  (llm): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(31090, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affi

In [118]:
HOLDOUT_LOC = "https://raw.githubusercontent.com/GokuMohandas/Made-With-ML/main/datasets/holdout.csv"
test_ds = ray.data.read_csv(HOLDOUT_LOC)
preprocessed_ds = preprocessor.transform(test_ds)

In [119]:
preprocessed_ds = preprocessor.transform(test_ds)
values = preprocessed_ds.select_columns(cols=["target"]).take_all()
y_true = np.stack([item["target"] for item in values])

2025-09-15 12:41:57,133	INFO logging.py:295 -- Registered dataset logger for dataset dataset_83_0


2025-09-15 12:41:57,149	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_83_0. Full logs are in /tmp/ray/session_2025-09-15_12-30-08_614560_3457/logs/ray-data
2025-09-15 12:41:57,150	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_83_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> TaskPoolMapOperator[CustomPreprocessor->Project]


Running 0: 0.00 row [00:00, ? row/s]

- ReadCSV->SplitBlocks(16) 1: 0.00 row [00:00, ? row/s]

- CustomPreprocessor->Project 2: 0.00 row [00:00, ? row/s]

2025-09-15 12:41:58,191	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_83_0 execution finished in 1.04 seconds


In [120]:
y_true

array([1, 1, 1, 3, 2, 3, 3, 3, 3, 2, 3, 3, 2, 1, 3, 3, 2, 2, 1, 2, 1, 3,
       1, 2, 3, 2, 2, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 0, 0, 2, 3,
       3, 1, 0, 2, 3, 2, 2, 1, 1, 3, 2, 1, 2, 1, 1, 1, 1, 3, 3, 3, 3, 2,
       2, 3, 2, 0, 3, 2, 1, 3, 3, 2, 2, 2, 2, 2, 3, 3, 2, 3, 0, 3, 3, 3,
       3, 1, 3, 3, 2, 3, 2, 2, 1, 2, 3, 2, 3, 2, 3, 1, 3, 3, 3, 3, 3, 2,
       3, 3, 2, 2, 2, 2, 1, 3, 2, 3, 2, 3, 2, 1, 1, 1, 2, 3, 2, 2, 2, 2,
       3, 2, 2, 2, 3, 0, 2, 2, 2, 2, 2, 0, 2, 3, 1, 3, 2, 2, 0, 0, 2, 3,
       3, 3, 3, 3, 3, 2, 2, 2, 3, 2, 0, 0, 2, 3, 3, 0, 2, 1, 2, 2, 2, 3,
       3, 2, 3, 2, 3, 1, 3, 2, 2, 3, 0, 2, 0, 2, 2])

In [121]:
model = model.to(device)

In [122]:
ds_generator = preprocessed_ds.iter_torch_batches(batch_size=128, collate_fn=collate_fn)
y_pred = None
model.eval()
with torch.inference_mode():
    for i, batch in enumerate(ds_generator):
        batch['ids'] = batch['ids'].to(device)
        batch['mask'] = batch['mask'].to(device)
        z = model(batch)
        if y_pred is not None:
            y_pred = torch.concat([y_pred, torch.argmax(z, dim=-1)], dim=0)
        else:
            y_pred = torch.argmax(z, dim=-1)

2025-09-15 12:42:00,499	INFO logging.py:295 -- Registered dataset logger for dataset dataset_82_0
2025-09-15 12:42:00,504	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_82_0. Full logs are in /tmp/ray/session_2025-09-15_12-30-08_614560_3457/logs/ray-data
2025-09-15 12:42:00,505	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_82_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> TaskPoolMapOperator[CustomPreprocessor]


Running 0: 0.00 row [00:00, ? row/s]

- ReadCSV->SplitBlocks(16) 1: 0.00 row [00:00, ? row/s]

- CustomPreprocessor 2: 0.00 row [00:00, ? row/s]

mask: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
target: int64, new schema: ids: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
mask: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
target: int64. This may lead to unexpected behavior.
2025-09-15 12:42:01,045	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_82_0 execution finished in 0.54 seconds


In [123]:
y_pred = y_pred.cpu().numpy()

In [124]:
metrics = precision_recall_fscore_support(y_true, y_pred, average="weighted")

In [125]:
{"precision": metrics[0], "recall": metrics[1], "f1": metrics[2]}

{'precision': 0.9369889807859304,
 'recall': 0.93717277486911,
 'f1': 0.9367971189393501}

In [128]:
def evaluate(ds, model, preprocessor=preprocessor, device=device):
    # y_true
    preprocessed_ds = preprocessor.transform(ds)
    values = preprocessed_ds.select_columns(cols=["target"]).take_all()
    y_true = np.stack([item["target"] for item in values])
    
    # y_pred
    model = model.to(device)
    ds_generator = preprocessed_ds.iter_torch_batches(batch_size=128, collate_fn=lambda x: collate_fn(x, device=device))
    y_pred = None
    model.eval()
    with torch.inference_mode():
        for i, batch in enumerate(ds_generator):
            batch['ids'] = batch['ids'].to(device)
            batch['mask'] = batch['mask'].to(device)
            z = model(batch)
            if y_pred is not None:
                y_pred = torch.concat([y_pred, torch.argmax(z, dim=-1)], dim=0)
            else:
                y_pred = torch.argmax(z, dim=-1)
    
    y_pred = y_pred.cpu().numpy()
    metrics = precision_recall_fscore_support(y_true, y_pred, average="weighted")
    return {"precision": metrics[0], "recall": metrics[1], "f1": metrics[2]}

In [129]:
evaluate(test_ds, model)

2025-09-15 12:42:40,588	INFO logging.py:295 -- Registered dataset logger for dataset dataset_87_0
2025-09-15 12:42:40,596	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_87_0. Full logs are in /tmp/ray/session_2025-09-15_12-30-08_614560_3457/logs/ray-data
2025-09-15 12:42:40,597	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_87_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> TaskPoolMapOperator[CustomPreprocessor->Project]


Running 0: 0.00 row [00:00, ? row/s]

- ReadCSV->SplitBlocks(16) 1: 0.00 row [00:00, ? row/s]

- CustomPreprocessor->Project 2: 0.00 row [00:00, ? row/s]

2025-09-15 12:42:41,156	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_87_0 execution finished in 0.56 seconds
2025-09-15 12:42:41,180	INFO logging.py:295 -- Registered dataset logger for dataset dataset_86_0
2025-09-15 12:42:41,185	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_86_0. Full logs are in /tmp/ray/session_2025-09-15_12-30-08_614560_3457/logs/ray-data
2025-09-15 12:42:41,186	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_86_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> TaskPoolMapOperator[CustomPreprocessor]


Running 0: 0.00 row [00:00, ? row/s]

- ReadCSV->SplitBlocks(16) 1: 0.00 row [00:00, ? row/s]

- CustomPreprocessor 2: 0.00 row [00:00, ? row/s]

mask: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
target: int64, new schema: ids: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
mask: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
target: int64. This may lead to unexpected behavior.
2025-09-15 12:42:41,636	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_86_0 execution finished in 0.45 seconds


{'precision': 0.9369889807859304,
 'recall': 0.93717277486911,
 'f1': 0.9367971189393501}

## Inference

In [130]:
import pandas as pd

In [131]:
preprocessor.index_to_class

{0: 'mlops',
 1: 'other',
 2: 'natural-language-processing',
 3: 'computer-vision'}

In [132]:
def format_prob(prob, index_to_class):
    d = {}
    for i, item in enumerate(prob):
        d[index_to_class[i]] = item
    return d

In [133]:
def predict_prob(df, model, preprocessor=preprocessor, device=device):
    
    processed = preprocessor._transform_pandas(df)
    processed = collate_fn(processed, device)
    model = model.to(device)
    output = model(processed)
    output = output.cpu().detach()
    y_prob = output.softmax(dim=1).numpy()
    
    res = []
    for i, prob in enumerate(y_prob):
        tag = decode([prob.argmax()], preprocessor.index_to_class)[0]
        res.append({"prediction": tag, "probabilities": format_prob(prob, preprocessor.index_to_class)})
    return res
    

In [134]:
title = "Transfer learning with transformers"
description = "Using transformers for transfer learning on text classification tasks."
sample_df = pd.DataFrame([{"title": title, "description": description, "tag": "natural-language-processing"}])

In [135]:
predict_prob(sample_df, model)

[{'prediction': 'natural-language-processing',
  'probabilities': {'mlops': 0.0010815858,
   'other': 0.0015012306,
   'natural-language-processing': 0.99653804,
   'computer-vision': 0.00087906554}}]

# MLFLOW

In [152]:
import mlflow
from pathlib import Path
from ray.air.integrations.mlflow import MLflowLoggerCallback
import time

In [153]:
MODEL_REGISTERY = Path("tmp/mlflow")
MODEL_REGISTERY.mkdir(exist_ok=True, parents=True)

In [172]:
MLFLOW_TRACKING_URI = "file:" + str(MODEL_REGISTERY.absolute())
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)

In [173]:
MLFLOW_TRACKING_URI

'file:/Users/ngkuissi/Dev/learning/Made-With-ML/notebooks/tmp/mlflow'

In [174]:
print(mlflow.get_tracking_uri())

file:/Users/ngkuissi/Dev/learning/Made-With-ML/notebooks/tmp/mlflow


In [175]:
experiment_name = f"llm-{int(time.time())}"
mlflow.set_experiment(experiment_name)
mlflow_callback = MLflowLoggerCallback(
    mlflow.get_tracking_uri(),
    experiment_name=experiment_name,
    save_artifact=True
)

2025/09/15 13:02:03 INFO mlflow.tracking.fluent: Experiment with name 'llm-1757955723' does not exist. Creating a new experiment.


In [176]:
run_config = RunConfig(
    checkpoint_config=checkpoint_config,
    callbacks=[mlflow_callback]
)

In [177]:
ds = load_data()
train_ds, val_ds = stratify_split(ds, stratify="tag", test_size=test_size)

In [178]:
preprocessor = CustomPreprocessor()
train_ds = preprocessor.fit_transform(train_ds)
val_ds = preprocessor.transform(val_ds)

2025-09-15 13:02:13,102	INFO logging.py:295 -- Registered dataset logger for dataset dataset_137_0
2025-09-15 13:02:13,119	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_137_0. Full logs are in /tmp/ray/session_2025-09-15_12-30-08_614560_3457/logs/ray-data
2025-09-15 13:02:13,119	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_137_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> AllToAllOperator[RandomShuffle] -> AllToAllOperator[Sort] -> AllToAllOperator[MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle] -> AllToAllOperator[Aggregate] -> LimitOperator[limit=1]


Running 0: 0.00 row [00:00, ? row/s]

- ReadCSV->SplitBlocks(16) 1: 0.00 row [00:00, ? row/s]

- RandomShuffle 2: 0.00 row [00:00, ? row/s]

Shuffle Map 3:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 4:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- Sort 5: 0.00 row [00:00, ? row/s]

Sort Sample 6:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Map 7:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 8:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle 9: 0.00 row [00:00, ? row/s]

Shuffle Map 10:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 11:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- Aggregate 12: 0.00 row [00:00, ? row/s]

Sort Sample 13:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Map 14:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 15:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- limit=1 16: 0.00 row [00:00, ? row/s]

created_on: timestamp[s]
title: string
description: string
tag: string, new schema: . This may lead to unexpected behavior.
2025-09-15 13:02:15,252	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_137_0 execution finished in 2.13 seconds


In [179]:
train_ds = train_ds.materialize()
val_ds = val_ds.materialize()

2025-09-15 13:02:15,463	INFO logging.py:295 -- Registered dataset logger for dataset dataset_140_0
2025-09-15 13:02:15,484	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_140_0. Full logs are in /tmp/ray/session_2025-09-15_12-30-08_614560_3457/logs/ray-data
2025-09-15 13:02:15,485	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_140_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> AllToAllOperator[RandomShuffle] -> AllToAllOperator[Sort] -> AllToAllOperator[MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle] -> TaskPoolMapOperator[CustomPreprocessor]


Running 0: 0.00 row [00:00, ? row/s]

- ReadCSV->SplitBlocks(16) 1: 0.00 row [00:00, ? row/s]

- RandomShuffle 2: 0.00 row [00:00, ? row/s]

Shuffle Map 3:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 4:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- Sort 5: 0.00 row [00:00, ? row/s]

Sort Sample 6:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Map 7:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 8:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle 9: 0.00 row [00:00, ? row/s]

Shuffle Map 10:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 11:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- CustomPreprocessor 12: 0.00 row [00:00, ? row/s]

created_on: timestamp[s]
title: string
description: string
tag: string, new schema: . This may lead to unexpected behavior.
mask: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
target: int64, new schema: ids: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
mask: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
target: int64. This may lead to unexpected behavior.
2025-09-15 13:02:17,292	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_140_0 execution finished in 1.81 seconds
2025-09-15 13:02:17,371	INFO logging.py:295 -- Registered dataset logger for dataset dataset_142_0
2025-09-15 13:02:17,380	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_142_0. Full logs are in /tmp/ray/session_2025-09-15_12-30-08_614560_3457/logs/ray-data
2025-09-15 13:02:17,381	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_142_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> AllToAllOperator[RandomShuffle] -> AllToAllOperator[Sort

Running 0: 0.00 row [00:00, ? row/s]

- ReadCSV->SplitBlocks(16) 1: 0.00 row [00:00, ? row/s]

- RandomShuffle 2: 0.00 row [00:00, ? row/s]

Shuffle Map 3:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 4:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- Sort 5: 0.00 row [00:00, ? row/s]

Sort Sample 6:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Map 7:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 8:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle 9: 0.00 row [00:00, ? row/s]

Shuffle Map 10:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 11:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- CustomPreprocessor 12: 0.00 row [00:00, ? row/s]

created_on: timestamp[s]
title: string
description: string
tag: string, new schema: . This may lead to unexpected behavior.
mask: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
target: int64, new schema: ids: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
mask: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
target: int64. This may lead to unexpected behavior.
2025-09-15 13:02:19,393	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_142_0 execution finished in 2.01 seconds


In [180]:
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config= run_config,
    datasets= {"train": train_ds, "val": val_ds},
    dataset_config=dataset_config,
    #preprocessor=preprocessor   
)



In [181]:
results = trainer.fit()

2025-09-15 13:02:40,834	INFO tune.py:616 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


== Status ==
Current time: 2025-09-15 13:02:40 (running for 00:00:00.12)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 PENDING)


== Status ==
Current time: 2025-09-15 13:02:46 (running for 00:00:05.20)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 13:02:51 (running for 00:00:10.28)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=43664) Running 0: 0.00 row [00:00, ? row/s]

(pid=43664) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 13:02:56 (running for 00:00:15.72)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=43665) Running 0: 0.00 row [00:00, ? row/s]

(pid=43665) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 13:03:01 (running for 00:00:20.73)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


(pid=43664) Running 0: 0.00 row [00:00, ? row/s]

(pid=43664) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 13:03:06 (running for 00:00:25.76)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 13:03:11 (running for 00:00:30.85)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=43665) Running 0: 0.00 row [00:00, ? row/s]

(pid=43665) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 13:03:16 (running for 00:00:35.89)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=43664) Running 0: 0.00 row [00:00, ? row/s]

(pid=43664) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 13:03:21 (running for 00:00:40.96)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 13:03:26 (running for 00:00:46.05)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=43665) Running 0: 0.00 row [00:00, ? row/s]

(pid=43665) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 13:03:31 (running for 00:00:51.05)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=43664) Running 0: 0.00 row [00:00, ? row/s]

(pid=43664) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 13:03:36 (running for 00:00:56.07)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 13:03:41 (running for 00:01:01.09)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=43665) Running 0: 0.00 row [00:00, ? row/s]

(pid=43665) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

(pid=43664) Running 0: 0.00 row [00:00, ? row/s]

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


(pid=43664) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 13:03:47 (running for 00:01:06.16)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 13:03:52 (running for 00:01:11.26)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=43665) Running 0: 0.00 row [00:00, ? row/s]

(pid=43665) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

(pid=43664) Running 0: 0.00 row [00:00, ? row/s]

(pid=43664) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


== Status ==
Current time: 2025-09-15 13:03:57 (running for 00:01:16.36)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 13:04:02 (running for 00:01:21.45)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 13:04:07 (running for 00:01:26.48)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Stat

(pid=43665) Running 0: 0.00 row [00:00, ? row/s]

(pid=43665) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

(pid=43664) Running 0: 0.00 row [00:00, ? row/s]

(pid=43664) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


== Status ==
Current time: 2025-09-15 13:04:17 (running for 00:01:36.61)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 13:04:22 (running for 00:01:41.63)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 13:04:27 (running for 00:01:46.65)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=43665) Running 0: 0.00 row [00:00, ? row/s]

(pid=43665) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 13:04:32 (running for 00:01:51.70)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=43664) Running 0: 0.00 row [00:00, ? row/s]

(pid=43664) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 13:04:37 (running for 00:01:56.79)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 13:04:42 (running for 00:02:01.89)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=43665) Running 0: 0.00 row [00:00, ? row/s]

(pid=43665) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 13:04:47 (running for 00:02:06.93)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


(pid=43664) Running 0: 0.00 row [00:00, ? row/s]

(pid=43664) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 13:04:52 (running for 00:02:11.93)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 13:04:57 (running for 00:02:17.05)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 13:05:02 (running for 00:02:22.07)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=43665) Running 0: 0.00 row [00:00, ? row/s]

(pid=43665) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

(pid=43664) Running 0: 0.00 row [00:00, ? row/s]

(pid=43664) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

== Status ==
Current time: 2025-09-15 13:05:07 (running for 00:02:27.14)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 13:05:13 (running for 00:02:32.17)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)


== Status ==
Current time: 2025-09-15 13:05:18 (running for 00:02:37.17)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 RUNNING)




(pid=43665) Running 0: 0.00 row [00:00, ? row/s]

(pid=43665) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


copying /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts/TorchTrainer_ca0f8_00000_0_2025-09-15_13-02-40/result.json -> /Users/ngkuissi/Dev/learning/Made-With-ML/notebooks/tmp/mlflow/924140208947310253/159c5a13ceab43bda0a04a9b91caa14f/artifacts
copying /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts/TorchTrainer_ca0f8_00000_0_2025-09-15_13-02-40/params.pkl -> /Users/ngkuissi/Dev/learning/Made-With-ML/notebooks/tmp/mlflow/924140208947310253/159c5a13ceab43bda0a04a9b91caa14f/artifacts
copying /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts/TorchTrainer_ca0f8_00000_0_2025-09-15_13-02-40/params.json -> /Users/ngkuissi/Dev/learning/Made-With-ML/notebooks/tmp/mlflow/924140208947310253/159c5a13ceab43bda0a04a9b91caa14f/artifacts
copying /tmp/ray/session_2025

2025-09-15 13:05:22,785	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/ngkuissi/ray_results/TorchTrainer_2025-09-15_13-02-40' in 0.0040s.
2025-09-15 13:05:22,821	INFO tune.py:1041 -- Total run time: 161.99 seconds (161.94 seconds for the tuning loop).


== Status ==
Current time: 2025-09-15 13:05:22 (running for 00:02:41.94)
Using FIFO scheduling algorithm.
Logical resource usage: 4.0/8 CPUs, 0/0 GPUs
Result logdir: /tmp/ray/session_2025-09-15_12-30-08_614560_3457/artifacts/2025-09-15_13-02-40/TorchTrainer_2025-09-15_13-02-40/driver_artifacts
Number of trials: 1/1 (1 TERMINATED)




In [182]:
results.metrics_dataframe

Unnamed: 0,epoch,lr,train_loss,val_loss,timestamp,checkpoint_dir_name,should_checkpoint,done,training_iteration,trial_id,...,time_since_restore,iterations_since_restore,config/train_loop_config/dropout_p,config/train_loop_config/lr,config/train_loop_config/lr_factor,config/train_loop_config/lr_patience,config/train_loop_config/num_epochs,config/train_loop_config/batch_size,config/train_loop_config/num_classes,config/train_loop_config/device
0,0,0.0001,1.324052,1.156449,1757955782,checkpoint_000000,True,False,1,ca0f8_00000,...,17.612937,1,0.5,0.0001,0.8,3,10,128,4,mps
1,1,0.0001,1.061896,0.747005,1757955798,checkpoint_000001,True,False,2,ca0f8_00000,...,33.065001,2,0.5,0.0001,0.8,3,10,128,4,mps
2,2,0.0001,0.621647,0.393996,1757955813,checkpoint_000002,True,False,3,ca0f8_00000,...,47.960903,3,0.5,0.0001,0.8,3,10,128,4,mps
3,3,0.0001,0.333326,0.343746,1757955824,checkpoint_000003,True,False,4,ca0f8_00000,...,58.423026,4,0.5,0.0001,0.8,3,10,128,4,mps
4,4,0.0001,0.16367,0.243425,1757955834,checkpoint_000004,True,False,5,ca0f8_00000,...,68.557185,5,0.5,0.0001,0.8,3,10,128,4,mps
5,5,0.0001,0.080955,0.264231,1757955855,checkpoint_000005,True,False,6,ca0f8_00000,...,89.637968,6,0.5,0.0001,0.8,3,10,128,4,mps
6,6,0.0001,0.051695,0.258584,1757955872,checkpoint_000006,True,False,7,ca0f8_00000,...,106.2711,7,0.5,0.0001,0.8,3,10,128,4,mps
7,7,0.0001,0.020969,0.342157,1757955887,checkpoint_000007,True,False,8,ca0f8_00000,...,121.103904,8,0.5,0.0001,0.8,3,10,128,4,mps
8,8,8e-05,0.030471,0.327466,1757955907,checkpoint_000008,True,False,9,ca0f8_00000,...,140.887608,9,0.5,0.0001,0.8,3,10,128,4,mps
9,9,8e-05,0.013599,0.365207,1757955921,checkpoint_000009,True,False,10,ca0f8_00000,...,154.397435,10,0.5,0.0001,0.8,3,10,128,4,mps


In [183]:
sorted_runs = mlflow.search_runs(experiment_names=[experiment_name],
                                        order_by=['metrics.val_loss ASC'])

In [184]:
sorted_runs

Unnamed: 0,run_id,experiment_id,status,artifact_uri,start_time,end_time,metrics.pid,metrics.lr,metrics.epoch,metrics.train_loss,...,params.train_loop_config/lr_factor,params.train_loop_config/lr_patience,params.train_loop_config/dropout_p,params.train_loop_config/num_epochs,params.train_loop_config/device,params.train_loop_config/batch_size,params.train_loop_config/lr,params.train_loop_config/num_classes,tags.trial_name,tags.mlflow.runName
0,159c5a13ceab43bda0a04a9b91caa14f,924140208947310253,FINISHED,file:///Users/ngkuissi/Dev/learning/Made-With-...,2025-09-15 17:02:45.621000+00:00,2025-09-15 17:05:22.773000+00:00,43506.0,8e-05,9.0,0.013599,...,0.8,3,0.5,10,mps,128,0.0001,4,TorchTrainer_ca0f8_00000,TorchTrainer_ca0f8_00000


In [None]:
!mlflow ui --backend-store-uri file:/Users/ngkuissi/Dev/learning/Made-With-ML/notebooks/tmp/mlflow --port 5000

  import pkg_resources
  import pkg_resources
[2025-09-14 20:14:58 -0400] [22221] [INFO] Starting gunicorn 20.1.0
[2025-09-14 20:14:58 -0400] [22221] [INFO] Listening at: http://0.0.0.0:8080 (22221)
[2025-09-14 20:14:58 -0400] [22221] [INFO] Using worker: sync
[2025-09-14 20:14:58 -0400] [22222] [INFO] Booting worker with pid: 22222
[2025-09-14 20:14:58 -0400] [22223] [INFO] Booting worker with pid: 22223
[2025-09-14 20:14:58 -0400] [22224] [INFO] Booting worker with pid: 22224
[2025-09-14 20:14:59 -0400] [22225] [INFO] Booting worker with pid: 22225
^C

Aborted!
[2025-09-14 20:26:17 -0400] [22225] [INFO] Worker exiting (pid: 22225)
[2025-09-14 20:26:17 -0400] [22223] [INFO] Worker exiting (pid: 22223)
[2025-09-14 20:26:17 -0400] [22222] [INFO] Worker exiting (pid: 22222)
[2025-09-14 20:26:17 -0400] [22224] [INFO] Worker exiting (pid: 22224)


In [185]:
from ray.air import Result
from urllib.parse import urlparse

In [186]:
def best_checkpoint(run_id):
    # get path from mlflow
    artifact_dir = urlparse(mlflow.get_run(run_id).info.artifact_url).path
    result = Result.from_path(artifact_dir)
    return result.get_best_checkpoint[0][0]

# Tuning

In [188]:
num_runs = 2

In [199]:
from ray import tune
from ray.tune import Tuner
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune.search import ConcurrencyLimiter
from ray.tune.search.hyperopt import HyperOptSearch

In [190]:
set_seed()

In [191]:
ds = load_data()
train_ds, val_ds = stratify_split(ds, stratify="tag", test_size=test_size)

In [192]:
preprocessor = CustomPreprocessor()
train_ds = preprocessor.fit_transform(train_ds)
val_ds = preprocessor.transform(val_ds)
train_ds = train_ds.materialize()
val_ds = val_ds.materialize()

2025-09-17 14:14:57,541	INFO logging.py:295 -- Registered dataset logger for dataset dataset_158_0
2025-09-17 14:14:57,668	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_158_0. Full logs are in /tmp/ray/session_2025-09-15_12-30-08_614560_3457/logs/ray-data
2025-09-17 14:14:57,669	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_158_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> AllToAllOperator[RandomShuffle] -> AllToAllOperator[Sort] -> AllToAllOperator[MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle] -> AllToAllOperator[Aggregate] -> LimitOperator[limit=1]


Running 0: 0.00 row [00:00, ? row/s]

- ReadCSV->SplitBlocks(16) 1: 0.00 row [00:00, ? row/s]

- RandomShuffle 2: 0.00 row [00:00, ? row/s]

Shuffle Map 3:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 4:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- Sort 5: 0.00 row [00:00, ? row/s]

Sort Sample 6:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Map 7:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 8:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle 9: 0.00 row [00:00, ? row/s]

Shuffle Map 10:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 11:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- Aggregate 12: 0.00 row [00:00, ? row/s]

Sort Sample 13:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Map 14:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 15:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- limit=1 16: 0.00 row [00:00, ? row/s]

created_on: timestamp[s]
title: string
description: string
tag: string, new schema: . This may lead to unexpected behavior.
2025-09-17 14:15:00,290	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_158_0 execution finished in 2.61 seconds
2025-09-17 14:15:00,454	INFO logging.py:295 -- Registered dataset logger for dataset dataset_161_0
2025-09-17 14:15:00,468	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_161_0. Full logs are in /tmp/ray/session_2025-09-15_12-30-08_614560_3457/logs/ray-data
2025-09-17 14:15:00,470	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_161_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> AllToAllOperator[RandomShuffle] -> AllToAllOperator[Sort] -> AllToAllOperator[MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle] -> TaskPoolMapOperator[CustomPreprocessor]


Running 0: 0.00 row [00:00, ? row/s]

- ReadCSV->SplitBlocks(16) 1: 0.00 row [00:00, ? row/s]

- RandomShuffle 2: 0.00 row [00:00, ? row/s]

Shuffle Map 3:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 4:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- Sort 5: 0.00 row [00:00, ? row/s]

Sort Sample 6:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Map 7:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 8:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle 9: 0.00 row [00:00, ? row/s]

Shuffle Map 10:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 11:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- CustomPreprocessor 12: 0.00 row [00:00, ? row/s]

created_on: timestamp[s]
title: string
description: string
tag: string, new schema: . This may lead to unexpected behavior.
mask: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
target: int64, new schema: ids: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
mask: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
target: int64. This may lead to unexpected behavior.
2025-09-17 14:15:03,720	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_161_0 execution finished in 3.25 seconds
2025-09-17 14:15:03,940	INFO logging.py:295 -- Registered dataset logger for dataset dataset_163_0
2025-09-17 14:15:03,950	INFO streaming_executor.py:159 -- Starting execution of Dataset dataset_163_0. Full logs are in /tmp/ray/session_2025-09-15_12-30-08_614560_3457/logs/ray-data
2025-09-17 14:15:03,952	INFO streaming_executor.py:160 -- Execution plan of Dataset dataset_163_0: InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> AllToAllOperator[RandomShuffle] -> AllToAllOperator[Sort

Running 0: 0.00 row [00:00, ? row/s]

- ReadCSV->SplitBlocks(16) 1: 0.00 row [00:00, ? row/s]

- RandomShuffle 2: 0.00 row [00:00, ? row/s]

Shuffle Map 3:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 4:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- Sort 5: 0.00 row [00:00, ? row/s]

Sort Sample 6:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Map 7:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 8:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- MapBatches(_add_split)->MapBatches(_filter_split)->RandomShuffle 9: 0.00 row [00:00, ? row/s]

Shuffle Map 10:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

Shuffle Reduce 11:   0%|          | 0.00/1.00 [00:00<?, ? row/s]

- CustomPreprocessor 12: 0.00 row [00:00, ? row/s]

created_on: timestamp[s]
title: string
description: string
tag: string, new schema: . This may lead to unexpected behavior.
mask: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
target: int64, new schema: ids: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
mask: extension<ray.data.arrow_tensor_v2<ArrowTensorTypeV2>>
target: int64. This may lead to unexpected behavior.
2025-09-17 14:15:06,038	INFO streaming_executor.py:279 -- ✔️  Dataset dataset_163_0 execution finished in 2.09 seconds


In [193]:
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    scaling_config=scaling_config,
    run_config= run_config,
    datasets= {"train": train_ds, "val": val_ds},
    dataset_config=dataset_config,
    #preprocessor=preprocessor   
)



In [194]:
mlflow_callback = MLflowLoggerCallback(
    mlflow.get_tracking_uri(),
    experiment_name=experiment_name,
    save_artifact=True
)

In [195]:
checkpoint_config = CheckpointConfig(num_to_keep=1, checkpoint_score_attribute="val_loss", checkpoint_score_order="min")

In [196]:
run_config = RunConfig(
    checkpoint_config=checkpoint_config,
    callbacks=[mlflow_callback]
)

## search algorithm

In [200]:
initial_param = [{"train_loop_config": {"dropout_p": 0.5, "lr": 1e-4, "lr_factor": 0.8, "lr_patience": 3}}]
search_algorithm = HyperOptSearch(points_to_evaluate=initial_param)
search_algorithm = ConcurrencyLimiter(search_algorithm, max_concurrent=2)


In [204]:
param_space = {
    "train_loop_config": {
        "dropout_p": tune.uniform(0.3, 0.9),
        "lr": tune.loguniform(1e-5, 5e-4),
        "lr_factor": tune.uniform(0.1, 0.9),
        "lr_patience": tune.uniform(1, 10),
    }
}

In [201]:
scheduler = AsyncHyperBandScheduler(
    max_t = train_loop_config['num_epochs'],
    grace_period=5
)

In [202]:
tune_config = tune.TuneConfig(
    mode="min",
    metric="val_loss",
    search_alg=search_algorithm,
    scheduler=scheduler,
    num_samples=num_runs
)

In [205]:
tuner = Tuner(
    trainable=trainer,
    param_space=param_space,
    tune_config=tune_config,
    run_config = run_config
)

2025-09-17 14:25:41,281	INFO tuner_internal.py:427 -- A `RunConfig` was passed to both the `Tuner` and the `TorchTrainer`. The run config passed to the `Tuner` is the one that will be used.


In [206]:
results = tuner.fit()

0,1
Current time:,2025-09-17 15:22:30
Running for:,00:56:38.81
Memory:,13.5/16.0 GiB

Trial name,status,loc,train_loop_config/dr opout_p,train_loop_config/lr,train_loop_config/lr _factor,train_loop_config/lr _patience,iter,total time (s),epoch,lr,train_loss
TorchTrainer_041ee198,RUNNING,127.0.0.1:64048,0.5,0.0001,0.8,3.0,4,3219.05,3,0.0001,0.380443
TorchTrainer_38be6225,RUNNING,127.0.0.1:64195,0.415523,6.30269e-05,0.397444,2.50983,2,2554.54,1,6.30269e-05,1.05583


(pid=64273) Running 0: 0.00 row [00:00, ? row/s]

(pid=64273) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

(pid=64389) Running 0: 0.00 row [00:00, ? row/s]

(pid=64389) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

(pid=64274) Running 0: 0.00 row [00:00, ? row/s]

(pid=64274) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

(pid=64273) Running 0: 0.00 row [00:00, ? row/s]

(pid=64273) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


(pid=64390) Running 0: 0.00 row [00:00, ? row/s]

(pid=64390) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

(pid=64389) Running 0: 0.00 row [00:00, ? row/s]

(pid=64389) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

(pid=64274) Running 0: 0.00 row [00:00, ? row/s]

(pid=64274) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

(pid=64273) Running 0: 0.00 row [00:00, ? row/s]

(pid=64273) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


(pid=64274) Running 0: 0.00 row [00:00, ? row/s]

(pid=64274) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

(pid=64273) Running 0: 0.00 row [00:00, ? row/s]

(pid=64273) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.


(pid=64390) Running 0: 0.00 row [00:00, ? row/s]

(pid=64390) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

(pid=64389) Running 0: 0.00 row [00:00, ? row/s]

(pid=64389) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

(pid=64274) Running 0: 0.00 row [00:00, ? row/s]

(pid=64274) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

(pid=64273) Running 0: 0.00 row [00:00, ? row/s]

(pid=64273) - split(1, equal=True) 1: 0.00 row [00:00, ? row/s]

You may want to consider increasing the `CheckpointConfig(num_to_keep)` or decreasing the frequency of saving checkpoints.
2025-09-17 15:22:30,641	INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/Users/ngkuissi/ray_results/TorchTrainer_2025-09-17_14-25-41' in 0.8032s.
2025-09-17 15:22:41,030	INFO tune.py:1041 -- Total run time: 3409.25 seconds (3398.00 seconds for the tuning loop).
Resume experiment with: Tuner.restore(path="/Users/ngkuissi/ray_results/TorchTrainer_2025-09-17_14-25-41", trainable=...)
