In [1]:
import os
import sys
import random
import re
from typing import Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import nltk
from nltk.corpus import stopwords
from transformers import BertTokenizer
from transformers import BertModel

import ray
from ray.data.preprocessor import Preprocessor
from ray.data import Dataset
import ray.train as train
from ray.train.torch import TorchCheckpoint, TorchTrainer, get_device
from ray.train import Checkpoint, session, DataConfig
from ray.air.config import CheckpointConfig, DatasetConfig, RunConfig, ScalingConfig

# Remote CSV sources: the labelled dataset used for training/validation and a
# separate holdout set.
DATASET_LOC = "https://raw.githubusercontent.com/GokuMohandas/Made-With-ML/main/datasets/dataset.csv"
HOLDOUT_LOC = "https://raw.githubusercontent.com/GokuMohandas/Made-With-ML/main/datasets/holdout.csv"

# Make modules one directory up importable. Relative sys.path entries are
# resolved against the kernel's current working directory.
sys.path.append(os.pardir)

# Stopwords: very frequent words that add little signal for classification.
# Removing them also shortens the text the tokenizer has to process.
nltk.download("stopwords")
SW = stopwords.words("english")

[nltk_data] Downloading package stopwords to /home/tc/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [2]:
# String ops to clean things up during pre-processing
def clean_text(text, stopwords=SW):
    """Normalise a raw title/description string before tokenization.

    Steps: lower-case, remove URLs, remove stopwords, drop every character that
    is not an ASCII letter or digit, collapse whitespace.

    Args:
        text (str): Raw input text.
        stopwords (list[str], optional): Words to remove. Defaults to the NLTK
            English stopword list loaded at import time (``SW``).

    Returns:
        str: Cleaned, single-space-separated, lower-case text.
    """
    # Lower case it
    text = text.lower()

    # Remove links first. Once punctuation is replaced below, a URL such as
    # "https://github.com/x" becomes "https github com x" and can no longer be
    # matched as a single token.
    text = re.sub(r"http\S+", "", text)

    # Remove stopwords. Skip an empty list: the pattern r"\b()\b\s*" would match
    # at every word boundary and glue neighbouring words together.
    if stopwords:
        alternation = r"|".join(re.escape(word) for word in stopwords)
        pattern = re.compile(r"\b(" + alternation + r")\b\s*")
        text = pattern.sub('', text)

    # Spacing and filters
    text = re.sub(r"([!\"'#$%&()*\+,-./:;<=>?@\\\[\]^_`{|}~])", r" \1 ", text)  # add spacing
    text = re.sub("[^A-Za-z0-9]+", " ", text)  # remove non alphanumeric chars
    text = re.sub(" +", " ", text)  # remove multiple spaces
    text = text.strip()  # strip white space at the ends

    return text

# Convert text into numeric token ids for the model
# Hugging Face checkpoint whose vocabulary is used for tokenization.
_TOKENIZER_NAME = "allenai/scibert_scivocab_uncased"
# Process-local cache: repeated calls in the same process reuse one tokenizer
# instead of re-running from_pretrained (disk/hub lookup) for every batch.
_TOKENIZER_CACHE = {}

def _get_tokenizer():
    """Return the SciBERT tokenizer, loading it on first use in this process."""
    tok = _TOKENIZER_CACHE.get(_TOKENIZER_NAME)
    if tok is None:
        # `return_dict` is a model option, not a tokenizer option, so it is not passed here.
        tok = BertTokenizer.from_pretrained(_TOKENIZER_NAME)
        _TOKENIZER_CACHE[_TOKENIZER_NAME] = tok
    return tok

def tokenize(batch) -> dict:
    """Tokenize a pandas batch of cleaned text into BERT model inputs.

    Args:
        batch (pd.DataFrame): Needs a ``text`` column of strings and a ``tag``
            column (already mapped to class indices by the caller).

    Returns:
        dict: ``ids`` (input_ids) and ``masks`` (attention_mask), padded to the
        longest sequence in this batch, plus ``targets`` built from ``tag``.
    """
    enc_ip = _get_tokenizer()(batch["text"].tolist(), return_tensors="np", padding="longest")
    return dict(ids=enc_ip["input_ids"], masks=enc_ip["attention_mask"], targets=np.array(batch["tag"]))

# Pre-process data 
def pp(df, class_to_index):
    """Clean, label-encode and tokenize one pandas batch.

    Args:
        df (pd.DataFrame): Batch with ``title``, ``description`` and ``tag``
            columns. It is not modified.
        class_to_index (dict): Maps tag strings to integer class ids.

    Returns:
        dict: Output of ``tokenize`` (``ids``, ``masks``, ``targets``).
    """
    # Feature 'engineering': title and description joined into one text field.
    # fillna("") stops a missing title/description from turning the whole
    # text into NaN, which clean_text cannot process.
    text = df["title"].fillna("") + " " + df["description"].fillna("")
    # Build a fresh frame rather than assigning into `df` (or into a column
    # selection of it): the caller's batch stays untouched and pandas raises no
    # SettingWithCopyWarning.
    prepared = pd.DataFrame({
        # lower case, remove stopwords, regex cleanup of spaces & special characters
        "text": text.apply(clean_text),
        # class strings -> integer ids, for machine learning
        "tag": df["tag"].map(class_to_index),
    })
    # words -> token ids
    return tokenize(prepared)

class CustomPP(Preprocessor):
    """Ray preprocessor that learns the tag -> index mapping and applies ``pp``."""

    def _fit(self, ds):
        """Learn the class mapping from the unique values of the ``tag`` column.

        Returns ``self`` so that ``Preprocessor.fit`` hands back the fitted
        preprocessor (the documented ``_fit`` contract) instead of ``None``.
        """
        tags = ds.unique(column="tag")
        self.class_to_index = {tag: i for i, tag in enumerate(tags)}
        self.index_to_class = {v: k for k, v in self.class_to_index.items()}
        return self

    def _transform_pandas(self, batch):
        """Clean, encode and tokenize one pandas batch with the fitted mapping."""
        return pp(batch, class_to_index=self.class_to_index)

# Split
def stratify_split(
    ds: Dataset,
    stratify: str,
    test_size: float,
    shuffle: bool = True,
    seed: int = 1234,
) -> Tuple[Dataset, Dataset]:
    """Split a dataset into train and test splits with equal
    amounts of data points from each class in the column we
    want to stratify on.

    The grouped/split intermediate is materialized once, so both returned
    splits are cut from the same execution of the upstream pipeline.

    Args:
        ds (Dataset): Input dataset to split.
        stratify (str): Name of column to split on.
        test_size (float): Proportion of dataset to split for test set.
        shuffle (bool, optional): whether to shuffle the dataset. Defaults to True.
        seed (int, optional): seed for shuffling. Defaults to 1234.

    Returns:
        Tuple[Dataset, Dataset]: the stratified train and test datasets.
    """

    def _add_split(df: pd.DataFrame) -> pd.DataFrame:  # pragma: no cover, used in parent function
        """Naively split a dataframe into train and test splits.
        Add a column specifying whether it's the train or test split."""
        train, test = train_test_split(df, test_size=test_size, shuffle=shuffle, random_state=seed)
        # assign() returns new frames, so we never write into objects pandas may
        # treat as views of `df` (no SettingWithCopyWarning).
        return pd.concat([train.assign(_split="train"), test.assign(_split="test")])

    def _filter_split(df: pd.DataFrame, split: str) -> pd.DataFrame:  # pragma: no cover, used in parent function
        """Filter by data points that match the split column's value
        and return the dataframe with the _split column dropped."""
        return df[df["_split"] == split].drop("_split", axis=1)

    # Train, test split with stratify: each unique value of the stratify column
    # is split independently.
    grouped = ds.groupby(stratify).map_groups(_add_split, batch_format="pandas")
    # Both splits below derive from `grouped`. Ray Datasets are lazy, so without
    # materializing, the whole upstream pipeline re-runs once per split. That
    # doubles the work, and if anything upstream is unseeded (e.g. a
    # random_shuffle() without seed) a row could land in train in one run and
    # in test in the other.
    grouped = grouped.materialize()
    train_ds = grouped.map_batches(_filter_split, fn_kwargs={"split": "train"}, batch_format="pandas")  # combine
    test_ds = grouped.map_batches(_filter_split, fn_kwargs={"split": "test"}, batch_format="pandas")  # combine

    # Shuffle each split (required)
    train_ds = train_ds.random_shuffle(seed=seed)
    test_ds = test_ds.random_shuffle(seed=seed)

    return train_ds, test_ds

# Seed every RNG in use so that runs are reproducible.
def set_seeds(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed (int, optional): Seed value. Defaults to 42.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # silently ignored when CUDA is unavailable
    # Direct attribute assignment; the eval("setattr(...)") strings did the same
    # thing but hid it from readers and linters.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # NOTE: hash randomization is fixed at interpreter start-up. This only
    # affects Python subprocesses launched after this call, not the current one.
    os.environ["PYTHONHASHSEED"] = str(seed)

# Load the raw CSV as a Ray Dataset (Ray's core data primitive).
# `num_samples` limits how many rows are loaded, so you can test on a small subset before running on the full data.
def load_data(num_samples=None, dataset_loc=DATASET_LOC, seed=1234):
    """Read a CSV into a Ray Dataset and shuffle it deterministically.

    Args:
        num_samples (int, optional): If truthy, keep only this many rows after
            shuffling. ``None`` (default) or 0 keeps everything.
        dataset_loc (str, optional): CSV path/URL. Defaults to ``DATASET_LOC``;
            pass ``HOLDOUT_LOC`` to load the holdout set.
        seed (int, optional): Shuffle seed. Defaults to 1234.

    Returns:
        Dataset: The shuffled (and optionally truncated) dataset.
    """
    ds = ray.data.read_csv(dataset_loc)
    ds = ds.random_shuffle(seed=seed)
    if num_samples:
        # take() pulls the rows to the driver; from_items rebuilds a small Dataset.
        ds = ray.data.from_items(ds.take(num_samples))
    return ds

# Zero-pad ragged rows so they all have the same length
def pad_array(arr, dtype=np.int32):
    """Right-pad variable-length rows with zeros into a rectangular 2-D array.

    Args:
        arr: Sequence of 1-D numeric sequences, e.g. a 1-D object ndarray of
            arrays or a plain list of lists.
        dtype: dtype of the result. Defaults to ``np.int32``.

    Returns:
        np.ndarray: Shape ``(len(arr), longest_row_length)``. An empty input
        gives shape ``(0, 0)``.
    """
    # default=0 lets an empty batch produce an empty array instead of raising.
    max_len = max((len(row) for row in arr), default=0)
    # len() rather than arr.shape[0] so plain Python lists work too.
    padded_arr = np.zeros((len(arr), max_len), dtype=dtype)
    for i, row in enumerate(arr):
        padded_arr[i, :len(row)] = row
    return padded_arr

# Pad the batch and convert it to tensors on the training device
def collate_fn(batch):
    """Pad a numpy batch and convert it to tensors on the current Ray Train device.

    Args:
        batch (dict): ``ids`` and ``masks`` (rows may differ in length) and
            ``targets``. It is not modified.

    Returns:
        dict: Same keys mapped to ``torch.Tensor``: ids/masks as int32 and
        targets as int64.
    """
    dtypes = {"ids": torch.int32, "masks": torch.int32, "targets": torch.int64}
    # Pad into a shallow copy so the caller's batch is not mutated as a side effect.
    arrays = dict(batch)
    arrays["ids"] = pad_array(batch["ids"])
    arrays["masks"] = pad_array(batch["masks"])
    # Resolve the device once instead of once per key.
    device = get_device()
    return {
        key: torch.as_tensor(array, dtype=dtypes[key], device=device)
        for key, array in arrays.items()
    }

In [3]:
# Pre-trained LLM
llm = BertModel.from_pretrained("allenai/scibert_scivocab_uncased", return_dict=False)
embedding_dim = llm.config.hidden_size

# Quick test to see if model is working
text = "Transfer learning with transformers for text classification"
tokenizer = BertTokenizer.from_pretrained("allenai/scibert_scivocab_uncased", return_dict=False)
batch = tokenizer([text], return_tensors="np", padding="longest")
batch_tensored = {k: torch.tensor(v) for k, v in batch.items()}
# No gradients are needed for this check, so skip building the autograd graph.
with torch.inference_mode():
    seq, pool = llm(input_ids=batch_tensored["input_ids"], attention_mask=batch_tensored["attention_mask"])
np.shape(seq), np.shape(pool)

  return self.fget.__get__(instance, owner)()


(torch.Size([1, 9, 768]), torch.Size([1, 768]))

In [4]:
# Use LLM as base, then add linear layer for classification

class FinetunedLLM(nn.Module):
    """Pre-trained encoder with a dropout + linear classification head.

    Args:
        llm (nn.Module): Encoder called as ``llm(input_ids=..., attention_mask=...)``
            that returns a ``(sequence_output, pooled_output)`` tuple, e.g. a
            ``BertModel`` loaded with ``return_dict=False``.
        dropout_p (float): Dropout probability applied to the pooled output.
        embedding_dim (int): Size of the pooled output (``llm.config.hidden_size``).
        num_classes (int): Number of output logits.
    """

    def __init__(self, llm, dropout_p, embedding_dim, num_classes):
        super().__init__()
        self.llm = llm
        self.dropout = nn.Dropout(dropout_p)
        self.fc1 = nn.Linear(embedding_dim, num_classes)

    def forward(self, batch):
        """Return raw logits for a collated batch with ``ids`` and ``masks``.

        Assumes the pooled output is (batch, embedding_dim), giving logits of
        shape (batch, num_classes).
        """
        ids, masks = batch["ids"], batch["masks"]
        # Use the registered submodule, not the notebook-global `llm`. Only
        # self.llm is moved to the device, wrapped for distributed training and
        # updated by an optimizer built from model.parameters().
        _, pool = self.llm(input_ids=ids, attention_mask=masks)
        z = self.dropout(pool)
        z = self.fc1(z)
        return z

    @torch.inference_mode()
    def predict(self, batch):
        """Switch to eval mode and return the arg-max class per row as a numpy array."""
        self.eval()
        z = self(batch)
        y_pred = torch.argmax(z, dim=1).cpu().numpy()
        return y_pred

    @torch.inference_mode()
    def predict_proba(self, batch):
        """Switch to eval mode and return per-class probabilities as a numpy array."""
        self.eval()
        z = self(batch)
        # Softmax over the class dimension, stated explicitly (implicit dim is deprecated).
        y_probs = F.softmax(z, dim=1).cpu().numpy()
        return y_probs
    
model = FinetunedLLM(
    llm=llm,
    dropout_p=0.5,
    embedding_dim=embedding_dim,
    # Fixed here only for inspection; training uses len(tags) from the data.
    num_classes=4,
)
# `print(model.named_parameters)` (no call) only shows a bound-method repr.
# Displaying the module itself shows the architecture.
model

<bound method Module.named_parameters of FinetunedLLM(
  (llm): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(31090, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): Layer

In [5]:
# Get dataset ready: ingest, split, tag mapping, text cleaning, tokenization, pre-processing.

# Ensure determinism
ray.data.DatasetContext.get_current().execution_options.preserve_order = True

# Ingest: load_data() reads DATASET_LOC and applies random_shuffle(seed=1234),
# so the URL and seed live in one place instead of being repeated here.
ds = load_data()

train_ds, val_ds = stratify_split(ds, stratify="tag", test_size=0.2)

# Mapping: class string -> integer id, learned from the training split
tags = train_ds.unique(column="tag")
class_to_index = {tag: i for i, tag in enumerate(tags)}

# Distributed preprocessing: mapping of tags to data & applying pre-processing
sample_ds = train_ds.map_batches(
    pp,
    fn_kwargs={"class_to_index": class_to_index},
    batch_format="pandas",
)

2024-02-02 21:08:50,246	INFO worker.py:1715 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m
2024-02-02 21:08:52,330	INFO set_read_parallelism.py:115 -- Using autodetected parallelism=24 for stage ReadCSV to satisfy parallelism at least twice the available number of CPUs (12).
2024-02-02 21:08:52,330	INFO set_read_parallelism.py:122 -- To satisfy the requested parallelism of 24, each read task output is split into 24 smaller blocks.
2024-02-02 21:08:52,331	INFO streaming_executor.py:112 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> AllToAllOperator[RandomShuffle] -> LimitOperator[limit=1]
2024-02-02 21:08:52,332	INFO streaming_executor.py:113 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), exclude_resources=ExecutionResources(cpu=0, gpu=0, object_store_memory=0), locality_with_output=False, preserve_order=True, actor_locality_enabled=True, verbos

- RandomShuffle 1:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Map 2:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Reduce 3:   0%|          | 0/1 [00:00<?, ?it/s]

Running 0:   0%|          | 0/1 [00:00<?, ?it/s]

2024-02-02 21:08:54,112	INFO set_read_parallelism.py:115 -- Using autodetected parallelism=24 for stage ReadCSV to satisfy parallelism at least twice the available number of CPUs (12).
2024-02-02 21:08:54,113	INFO set_read_parallelism.py:122 -- To satisfy the requested parallelism of 24, each read task output is split into 24 smaller blocks.
2024-02-02 21:08:54,114	INFO streaming_executor.py:112 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> AllToAllOperator[RandomShuffle] -> AllToAllOperator[Sort] -> AllToAllOperator[MapBatches(group_fn)->MapBatches(_filter_split)->RandomShuffle] -> TaskPoolMapOperator[MapBatches(fn)] -> LimitOperator[limit=1]
2024-02-02 21:08:54,115	INFO streaming_executor.py:113 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), exclude_resources=ExecutionResources(cpu=0, gpu=0, object_store_memory=0), locality_with_output=False, preserve_order=True, actor_locality_enabl

- RandomShuffle 1:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Map 2:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Reduce 3:   0%|          | 0/1 [00:00<?, ?it/s]

- Sort 4:   0%|          | 0/1 [00:00<?, ?it/s]

Sort Sample 5:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Map 6:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Reduce 7:   0%|          | 0/1 [00:00<?, ?it/s]

- MapBatches(group_fn)->MapBatches(_filter_split)->RandomShuffle 8:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Map 9:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Reduce 10:   0%|          | 0/1 [00:00<?, ?it/s]

Running 0:   0%|          | 0/1 [00:00<?, ?it/s]

Sort Sample 0:   0%|          | 0/24 [00:00<?, ?it/s]

2024-02-02 21:08:56,176	INFO set_read_parallelism.py:115 -- Using autodetected parallelism=24 for stage ReadCSV to satisfy parallelism at least twice the available number of CPUs (12).
2024-02-02 21:08:56,176	INFO set_read_parallelism.py:122 -- To satisfy the requested parallelism of 24, each read task output is split into 24 smaller blocks.
2024-02-02 21:08:56,177	INFO streaming_executor.py:112 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> AllToAllOperator[RandomShuffle] -> AllToAllOperator[Sort] -> AllToAllOperator[MapBatches(group_fn)->MapBatches(_filter_split)->RandomShuffle] -> TaskPoolMapOperator[MapBatches(fn)] -> AllToAllOperator[Aggregate]
2024-02-02 21:08:56,179	INFO streaming_executor.py:113 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), exclude_resources=ExecutionResources(cpu=0, gpu=0, object_store_memory=0), locality_with_output=False, preserve_order=True, actor_locality_

- RandomShuffle 1:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Map 2:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Reduce 3:   0%|          | 0/1 [00:00<?, ?it/s]

- Sort 4:   0%|          | 0/1 [00:00<?, ?it/s]

Sort Sample 5:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Map 6:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Reduce 7:   0%|          | 0/1 [00:00<?, ?it/s]

- MapBatches(group_fn)->MapBatches(_filter_split)->RandomShuffle 8:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Map 9:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Reduce 10:   0%|          | 0/1 [00:00<?, ?it/s]

- Aggregate 11:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Map 12:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Reduce 13:   0%|          | 0/1 [00:00<?, ?it/s]

Running 0:   0%|          | 0/1 [00:00<?, ?it/s]

Sort Sample 0:   0%|          | 0/24 [00:00<?, ?it/s]

Sort Sample 0:   0%|          | 0/24 [00:00<?, ?it/s]

[36m(map pid=2136)[0m   if isinstance(items[0], TensorArrayElement):
[36m(map pid=2136)[0m   return items[0]
[36m(map pid=2136)[0m   if isinstance(items[0], TensorArrayElement):
[36m(map pid=2136)[0m   return items[0]
[36m(map pid=2136)[0m   if isinstance(items[0], TensorArrayElement):
[36m(map pid=2136)[0m   return items[0]
[36m(map pid=2136)[0m   if isinstance(items[0], TensorArrayElement):
[36m(map pid=2136)[0m   return items[0]
[36m(map pid=2136)[0m   if isinstance(items[0], TensorArrayElement):
[36m(map pid=2136)[0m   return items[0]
[36m(map pid=2136)[0m   if isinstance(items[0], TensorArrayElement):
[36m(map pid=2136)[0m   return items[0]
[36m(map pid=2136)[0m   if isinstance(items[0], TensorArrayElement):
[36m(map pid=2136)[0m   return items[0]
  if isinstance(items[0], TensorArrayElement):
  return items[0]
  if isinstance(items[0], TensorArrayElement):
  return items[0]
  if isinstance(items[0], TensorArrayElement):
  return items[0]
  if isinstanc

In [6]:
# Take one preprocessed batch and check that collate_fn pads it into device tensors.
sample_batch = sample_ds.take_batch(128)
collate_fn(sample_batch)

2024-02-02 21:08:57,878	INFO set_read_parallelism.py:115 -- Using autodetected parallelism=24 for stage ReadCSV to satisfy parallelism at least twice the available number of CPUs (12).
2024-02-02 21:08:57,879	INFO set_read_parallelism.py:122 -- To satisfy the requested parallelism of 24, each read task output is split into 24 smaller blocks.
2024-02-02 21:08:57,880	INFO streaming_executor.py:112 -- Executing DAG InputDataBuffer[Input] -> TaskPoolMapOperator[ReadCSV] -> AllToAllOperator[RandomShuffle] -> AllToAllOperator[Sort] -> AllToAllOperator[MapBatches(group_fn)->MapBatches(_filter_split)->RandomShuffle] -> TaskPoolMapOperator[MapBatches(pp)] -> LimitOperator[limit=128]
2024-02-02 21:08:57,882	INFO streaming_executor.py:113 -- Execution config: ExecutionOptions(resource_limits=ExecutionResources(cpu=None, gpu=None, object_store_memory=None), exclude_resources=ExecutionResources(cpu=0, gpu=0, object_store_memory=0), locality_with_output=False, preserve_order=True, actor_locality_ena

- RandomShuffle 1:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Map 2:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Reduce 3:   0%|          | 0/1 [00:00<?, ?it/s]

- Sort 4:   0%|          | 0/1 [00:00<?, ?it/s]

Sort Sample 5:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Map 6:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Reduce 7:   0%|          | 0/1 [00:00<?, ?it/s]

- MapBatches(group_fn)->MapBatches(_filter_split)->RandomShuffle 8:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Map 9:   0%|          | 0/1 [00:00<?, ?it/s]

Shuffle Reduce 10:   0%|          | 0/1 [00:00<?, ?it/s]

Running 0:   0%|          | 0/1 [00:00<?, ?it/s]

Sort Sample 0:   0%|          | 0/24 [00:00<?, ?it/s]

  batch_tensored[key] = torch.as_tensor(array, dtype=dtypes[key], device=get_device())


{'ids': tensor([[  102, 15820, 30126,  ...,     0,     0,     0],
         [  102, 18715,  4602,  ...,     0,     0,     0],
         [  102,  6160,  1923,  ...,     0,     0,     0],
         ...,
         [  102, 24895, 30111,  ...,     0,     0,     0],
         [  102,  2322,  2180,  ...,     0,     0,     0],
         [  102,  3267,  4226,  ...,     0,     0,     0]], device='cuda:0',
        dtype=torch.int32),
 'masks': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0', dtype=torch.int32),
 'targets': tensor([2, 0, 0, 0, 3, 0, 2, 0, 3, 0, 0, 0, 2, 3, 1, 2, 2, 1, 0, 2, 3, 2, 2, 2,
         0, 2, 2, 3, 3, 2, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 3, 0, 0, 0,
         2, 1, 2, 2, 0, 3, 3, 0, 0, 2, 2, 2, 1, 0, 2, 3, 2, 1, 2, 0, 2, 2, 0, 2,
         0, 2, 2, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 2, 

In [7]:
# Training pipeline

def train_step(
        ds,
        batch_size,
        model,
        num_classes,
        loss_fn,
        optimizer,
):
    """Run one training epoch and return the mean training loss.

    Args:
        ds: Dataset or data iterator exposing ``iter_torch_batches`` (e.g. a Ray Train shard).
        batch_size (int): Rows per batch.
        model (nn.Module): Model returning logits for a collated batch.
        num_classes (int): Width of the one-hot targets.
        loss_fn: Loss taking (logits, one-hot float targets), e.g. ``BCEWithLogitsLoss``.
        optimizer (torch.optim.Optimizer): Optimizer over the model's parameters.

    Returns:
        float: Mean of the per-batch losses (0.0 if there were no batches).
    """
    model.train()
    loss = 0.0
    ds_gen = ds.iter_torch_batches(batch_size=batch_size, collate_fn=collate_fn)
    for i, batch in enumerate(ds_gen):
        # Reset gradients
        optimizer.zero_grad()
        # Forward pass
        z = model(batch)
        # One-hot float vectors, as expected by the loss fn
        tgts = F.one_hot(batch["targets"], num_classes=num_classes).float()
        J = loss_fn(z, tgts)
        # Backward pass
        J.backward()
        # Update weights
        optimizer.step()
        # Running mean: after batch i, `loss` is the mean of the first i+1 batch losses.
        loss += (J.item() - loss) / (i + 1)
    return loss

def eval_step(
        ds,
        batch_size,
        model,
        num_classes,
        loss_fn,
):
    """Evaluate the model on ``ds`` without tracking gradients.

    Args:
        ds: Dataset or data iterator exposing ``iter_torch_batches``.
        batch_size (int): Rows per batch.
        model (nn.Module): Model returning logits for a collated batch.
        num_classes (int): Width of the one-hot targets.
        loss_fn: Loss taking (logits, one-hot float targets).

    Returns:
        tuple: ``(mean_loss, y_true, y_pred)``. The label arrays are column
        vectors of shape (n, 1); both are empty (0, 1) when ``ds`` yields no batches.
    """
    model.eval()
    loss = 0.0
    y_trues, y_preds = [], []
    ds_gen = ds.iter_torch_batches(batch_size=batch_size, collate_fn=collate_fn)
    with torch.inference_mode():
        for i, batch in enumerate(ds_gen):
            z = model(batch)
            tgts = F.one_hot(batch["targets"], num_classes=num_classes).float()
            J = loss_fn(z, tgts).item()
            # Running mean of batch losses
            loss += (J - loss) / (i + 1)
            y_trues.extend(batch["targets"].cpu().numpy())
            y_preds.extend(torch.argmax(z, dim=1).cpu().numpy())
    if not y_trues:
        # np.vstack([]) raises, so return empty column arrays for empty input.
        empty = np.empty((0, 1), dtype=np.int64)
        return loss, empty, empty.copy()
    # vstack of per-row scalars yields (n, 1) column arrays.
    return loss, np.vstack(y_trues), np.vstack(y_preds)

def train_loop_per_worker(config):
    """Ray Train worker entry point: build the model, then train and evaluate each epoch.

    Args:
        config (dict): Hyperparameters with keys ``dropout_p``, ``lr``,
            ``lr_factor``, ``lr_patience``, ``num_epochs``, ``batch_size``
            (global, divided across workers) and ``num_classes``.

    After every epoch, reports ``epoch``, ``lr``, ``train_loss`` and ``val_loss``
    together with a model checkpoint via ``session.report``.
    """
    # Hyperparameters
    dropout_p = config["dropout_p"]
    lr = config["lr"]
    lr_factor = config["lr_factor"]
    lr_patience = config["lr_patience"]
    num_epochs = config["num_epochs"]
    batch_size = config["batch_size"]
    num_classes = config["num_classes"]

    # Datasets: this worker's shard of the datasets passed to TorchTrainer
    set_seeds()
    train_ds = session.get_dataset_shard("train")
    val_ds = session.get_dataset_shard("val")

    # Model
    llm = BertModel.from_pretrained("allenai/scibert_scivocab_uncased", return_dict=False)
    model = FinetunedLLM(
        llm=llm,
        dropout_p=dropout_p,
        embedding_dim=llm.config.hidden_size,
        num_classes=num_classes,
    )
    model = train.torch.prepare_model(model)

    # Training sub-systems
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=lr_factor,
        patience=lr_patience,
    )

    # Training loop. Split the global batch size across workers, with at least 1 row per batch.
    batch_size_per_worker = max(1, batch_size // session.get_world_size())
    for epoch in range(num_epochs):
        # Step
        train_loss = train_step(
            train_ds,
            batch_size_per_worker,
            model,
            num_classes,
            loss_fn,
            optimizer,
        )

        val_loss, _, _ = eval_step(
            val_ds,
            batch_size_per_worker,
            model,
            num_classes,
            loss_fn,
        )
        # Reduce the learning rate when the validation loss plateaus
        scheduler.step(val_loss)

        # Checkpoint it
        metrics = dict(
            epoch=epoch,
            lr=optimizer.param_groups[0]["lr"],
            train_loss=train_loss,
            val_loss=val_loss,
        )
        checkpoint = TorchCheckpoint.from_model(model=model)
        session.report(metrics, checkpoint=checkpoint)
    
# Training config, passed to train_loop_per_worker as `config`
trg_cfg = {
    "dropout_p": 0.5,
    "lr": 1e-4,
    "lr_factor": 0.8,
    "lr_patience": 3,
    "num_epochs": 10,
    "batch_size": 256,
    "num_classes": len(tags),
}

# Fractional resources: 10 workers x 0.1 GPU share a single GPU in total.
resources_per_worker = {"CPU": 0.1, "GPU": 0.1}
scal_cfg = ScalingConfig(
    num_workers=10,
    use_gpu=bool(resources_per_worker["GPU"]),
    resources_per_worker=resources_per_worker,
)

# Keep only the best checkpoint, ranked by lowest validation loss.
checkpoint_cfg = CheckpointConfig(
    num_to_keep=1,
    checkpoint_score_attribute="val_loss",
    checkpoint_score_order="min",
)

run_cfg = RunConfig(
    name="llm",
    checkpoint_config=checkpoint_cfg,
    local_dir="~/ray_results"
)

# TorchTrainer's `dataset_config` takes a single DataConfig, not a per-dataset
# dict. List the datasets to shard across workers.
ds_cfg = DataConfig(datasets_to_split=["train", "val"])

In [None]:
# Load the full dataset and split it. Fit the tag mapping on train only, then
# materialize both splits so preprocessing runs once, before training starts.
ds = load_data()
train_ds, val_ds = stratify_split(ds, stratify="tag", test_size=0.2)
preprocessor = CustomPP()
train_ds = preprocessor.fit_transform(train_ds).materialize()
val_ds = preprocessor.transform(val_ds).materialize()

# Launch distributed training; every worker runs train_loop_per_worker(trg_cfg).
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=trg_cfg,
    scaling_config=scal_cfg,
    run_config=run_cfg,
    datasets={"train": train_ds, "val": val_ds},
)
results = trainer.fit()