### Imports

In [0]:
from petastorm.spark import SparkDatasetConverter, make_spark_converter
import torch, torch.nn
import lightning as L

from petastorm import TransformSpec
from PIL import Image
import numpy as np
import io

import pyspark.sql.functions as F
from pyspark.sql.functions import col, pandas_udf, PandasUDFType

NUM_WORKERS = 2
# NOTE: This assumes the driver node and worker nodes have the same instance type.
NUM_GPUS_PER_WORKER = torch.cuda.device_count() # CHANGE AS NEEDED
USE_GPU = NUM_GPUS_PER_WORKER > 0
print(f"NUM_GPUS_PER_WORKER: {NUM_GPUS_PER_WORKER}")
print(f"Using GPU: {USE_GPU}")

log_path = f"./logger" # change this to location on DBFS


  original_result = python_builtin_import(name, globals, locals, fromlist, level)
  from pyarrow import LocalFileSystem


NUM_GPUS_PER_WORKER: 0
Using GPU: False


### Load Data

In [0]:
# import configs.two_tower_config as cfg

In [0]:

# train_positive_samples_df = spark.read.parquet(cfg.TRAIN_PATH)
# # TESTING
# # train_positive_samples_df = train_positive_samples_df.limit(10000)
# train_positive_samples_df = train_positive_samples_df.select(["site_guid", "title_key"]).dropDuplicates()

# test_positive_samples_df = spark.read.parquet(cfg.TEST_PATH)
# # TESTING
# # test_positive_samples_df = test_positive_samples_df.limit(10000)
# test_positive_samples_df = test_positive_samples_df.select(["site_guid", "title_key"]).dropDuplicates()

# positive_samples_df = train_positive_samples_df.union(test_positive_samples_df)



import numpy as np

# names = [str(x) for x in np.random.choice(["Alex", "James", "Michael", "Peter", "Harry"], size=3)]
# ids = [int(x) for x in np.random.randint(1, 10, 3)]

# positive_samples_df = spark.createDataFrame(list(zip(names, ids)), ["site_guid", "title_key"])
# define the length of the DataFrame
length = 100_000_000

# create a DataFrame with two columns, site_guid and title_key
df = spark.range(length) \
  .withColumn("site_guid", (F.rand() * 10000).cast("integer").cast("string")) \
  .withColumn("title_key", (F.rand() * 10000).cast("integer").cast("string"))


LABEL_COL = 'label'

# add a new column with random values between 0 and 1
df = df.withColumn("rand_col", F.rand())

df = df.withColumn(LABEL_COL, F.lit(1))
# assign 0 or 1 based on the random value
df = df.withColumn(LABEL_COL, F.when(F.col("rand_col") > 0.5, F.lit(1)).otherwise(F.lit(0)))


In [0]:
from pyspark.ml.feature import StringIndexer

site_stringIndexer = StringIndexer(inputCol="site_guid", outputCol="site_idx")
site_stringIndexer.setHandleInvalid("error")
model = site_stringIndexer.fit(df)
df = model.transform(df)

In [0]:
title_stringIndexer = StringIndexer(inputCol="title_key", outputCol="title_idx")
title_stringIndexer.setHandleInvalid("error")
model = title_stringIndexer.fit(df)
df = model.transform(df)

In [0]:
df = df.select('site_idx', 'title_idx', 'label')

In [0]:

display(df.limit(2))

site_idx,title_idx,label
2665.0,135.0,0
2933.0,3310.0,0


In [0]:
cfg = {
  'lr': 1e-5
}

# NUMBER OF DISTINCT USERS AND ITEMS
n_users = df.select('site_idx').distinct().count()
n_items = df.select('title_idx').distinct().count()

print(n_users, n_items)

train_df, test_df = df.randomSplit([0.7, 0.3], seed=42)
train_df, val_df = train_df.randomSplit([0.7, 0.3], seed=42)

##### TEST
test_df = test_df.limit(100000)

print(f"train_df: {train_df.count()}\ntest_df: {test_df.count()}")

10000 10000
train_df: 49010252
test_df: 100000


### Build Lightning Model
This is a fairly standard LightningModule, nothing crazy going on.

In [0]:
from torch import nn
class TwoTower(L.LightningModule):
    def __init__(self, n_users, n_items, cfg, embedding_size=32):
        super().__init__()
        
        self.cfg = cfg
        
        # We add +1 additional embedding to account for unknown tokens.
        self.user_emb = nn.Embedding(num_embeddings=n_users + 1, embedding_dim=embedding_size)
        self.item_emb = nn.Embedding(num_embeddings=n_items, embedding_dim=embedding_size)  # self.ln[0]

        self.item_layers = [] #nn.ModuleList()
        self.user_layers = [] #nn.ModuleList()

        # for i, n in enumerate(ln[0:-1]):
        #     m = int(ln[i+1])
        self.item_layers.append(nn.Linear(embedding_size, embedding_size, bias=True))  # n, m
        self.item_layers.append(nn.ReLU())
        self.user_layers.append(nn.Linear(embedding_size, embedding_size, bias=True))
        self.user_layers.append(nn.ReLU())   # is this ReLU needed???


        self.item_layers = nn.Sequential(*self.item_layers)
        self.user_layers = nn.Sequential(*self.user_layers)
        self.dot = torch.matmul
        self.sigmoid = nn.Sigmoid()

        self.criterion = nn.BCELoss()

        # save hyper-parameters to self.hparamsm auto-logged by wandb
        self.save_hyperparameters()


    def get_logits(self, users, items):

        item_emb = self.item_emb(items) # [B, embed_size]
        user_emb = self.user_emb(users) # [B, embed_size]
        
        item_emb = self.item_layers(item_emb) # [B, embed_size]
        user_emb = self.user_layers(user_emb) # [B, embed_size]

        dp = self.dot(user_emb, item_emb.T) # [B, B]
        # print(f"\nAfter dot: {dp}")
        dp = dp.sum(dim=1).squeeze() # [B]
        # print(f"\nAfter sum: {dp}")
        dp = self.sigmoid(dp)

        return dp

      
    def extract_from_batch(self, batch):
      
        users, items, labels = batch.values()
        return users.long(), items.long(), labels.float()


    def forward(self, batch):
        # in lightning, forward defines the prediction/inference actions

        # Make prediction
        users, items, labels = self.extract_from_batch(batch)
        logits = self.get_logits(user, item)
        preds = torch.argmax(logits, dim=1).flatten().tolist()
        return preds

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop. It is independent of forward

        # batch arrives as a dictionary with keys [site_idx, title_idx, label]
        users, items, labels = self.extract_from_batch(batch)
        dp = self.get_logits(users, items)
        loss = self.criterion(dp, labels)
        self.log("train_loss", loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        users, items, labels = self.extract_from_batch(batch)
        dp = self.get_logits(users, items)
        loss = self.criterion(dp, labels)
        self.log("val_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        users, items, labels = self.extract_from_batch(batch)
        dp = self.get_logits(users, items)
        loss = self.criterion(dp, labels)
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg['lr'])
        return optimizer


### Build Lightning DataModule

This class holds all the logic for processing and loading the dataset.

This code comes from [Building the PyTorch Lightning Modules Notebook](https://community.cloud.databricks.com/?o=4375727927284138#notebook/4209209923978200/command/4209209923978208)

We pass this Module the Petastorm Spark DF converters for train/test/val splits.

These converters are obtained by passing our spark DF to:

```petastorm.spark.make_spark_converter(processed_spark_df)```



**[NOTE]**

The value of parameter `num_epochs` used in `make_torch_dataloader` function is deliberately set it to `None` (it is also a default value) in order to generate an infinite number of data batches to avoid handling the last, likely incomplete, batch. This is especially important for distributed training where we need to guarantee that the numbers of data records seen on all workers are identical per step. Given that the length of each data shard may not be identical, setting `num_epochs` to any specific number would fail to meet the guarantee and can result in an error. Even though this may not be really important for training on a single device, it determines the way we control epochs (training will run forever on infinite dataset which means there would be only 1 epoch if other means of controlling the epoch duration are not used), so we decided to introduce it here from the beginning.

Setting the value of `num_epochs=None` is also important for the validation process. At the time this notebook was developed, Pytorch Lightning Trainer will run a sanity validation check prior to any training, unless instructed otherwise (i.e. `num_sanity_val_steps` is set to `0`). That sanity check will initialise the validation data loader and will read the `num_sanity_val_steps` batches from it before the first training epoch. Training will not reload the validation dataset for the actual validation phase of the first epoch which will result in error (an attempt to read a second time from data loader which was not completed in the previous attempt). Possible workarounds to avoid this issue is using a finite amount of epochs in `num_epochs` (e.g. `num_epochs=1` as there is no point in evaluating on repeated dataset), which is not ideal as it will likely result in a last batch being smaller than other batches and at the time when this notebook was developed there was no way of setting an equivalent of `drop_last` for the Data Loader created by `make_torch_dataloader`. The only way we found to work around this was to avoid doing any sanity checks (i.e. setting `num_sanity_val_steps=0`, setting it to anything else doesn't work) and using `limit_val_batches` parameter of the Trainer class to avoid the infinitely running validation.

In [0]:

class PySparkDataModule(L.LightningDataModule):
    def __init__(self, train_converter, val_converter, test_converter, batch_size, device_id:int=0, device_count:int=1):
        super().__init__()
    
        self.train_converter = train_converter
        self.val_converter = val_converter
        self.test_converter = test_converter
        self.train_dataloader_context = None
        self.val_dataloader_context = None
        self.test_dataloader_context = None
        self.prepare_data_per_node = False
        self._log_hyperparams = False

        self.device_id = device_id
        self.device_count = device_count
        
        self.batch_size = batch_size
    
    def train_dataloader(self):
        if self.train_dataloader_context:
            self.train_dataloader_context.__exit__(None, None, None)
        self.train_dataloader_context = self.train_converter.make_torch_dataloader(
                                                                                  # transform_spec=self._get_transform_spec(), 
                                                                                  num_epochs=None,
                                                                                  cur_shard=self.device_id, 
                                                                                  shard_count=self.device_count, 
                                                                                  batch_size=self.batch_size*self.device_count)
        return self.train_dataloader_context.__enter__()

    def val_dataloader(self):
        if self.val_dataloader_context:
            self.val_dataloader_context.__exit__(None, None, None)
        self.val_dataloader_context = self.val_converter.make_torch_dataloader(
                                                                              # transform_spec=self._get_transform_spec(), 
                                                                              num_epochs=None, 
                                                                              cur_shard=self.device_id, 
                                                                              shard_count=self.device_count,
                                                                              batch_size=self.batch_size*self.device_count)
        return self.val_dataloader_context.__enter__()
      
      
    def test_dataloader(self):
        if self.test_dataloader_context:
            self.test_dataloader_context.__exit__(None, None, None)
        self.test_dataloader_context = self.test_converter.make_torch_dataloader(
                                                                              # transform_spec=self._get_transform_spec(), 
                                                                              num_epochs=1, 
                                                                              cur_shard=self.device_id, 
                                                                              shard_count=self.device_count,
                                                                              batch_size=self.batch_size*self.device_count)
        return self.test_dataloader_context.__enter__()
    
    
    def teardown(self, stage=None):
        # Close all readers (especially important for distributed training to prevent errors)
        if self.train_dataloader_context is not None:
            self.train_dataloader_context.__exit__(None, None, None)
        if self.val_dataloader_context is not None:
            self.val_dataloader_context.__exit__(None, None, None)
        if self.test_dataloader_context is not None:
            self.test_dataloader_context.__exit__(None, None, None)
    
    # def preprocess(self, img):
    
    #     image = Image.open(io.BytesIO(img))
    #     transform = transforms.Compose([
    #       transforms.Resize(256),
    #       transforms.CenterCrop(224),
    #       transforms.ToTensor(),
    #       transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    #     ])
    
    # return transform(image)
  
    
    # def _transform_rows(self, batch):
    #     # To keep things simple, use the same transformation both for training and validation
    #     batch["features"] = batch["content"].map(lambda x: self.preprocess(x).numpy())
    #     batch = batch.drop(labels=["content"], axis=1)
    #     return batch

    
    # def _get_transform_spec(self):
    #     return TransformSpec(self._transform_rows, 
    #                          edit_fields=[("features", np.float32, (3, 224, 224), False)], 
    #                          selected_fields=["features", "label"])



### Training Function
Function to pass to TorchDistributor

In [0]:
def main_training_loop(num_tasks, num_proc_per_task):
    """

    Main train and test loop

    """
    # add imports inside pl_train for pickling to work
    from torch import optim, nn, utils, Tensor
    import lightning as L
    import os
    import torch
    from petastorm.spark import SparkDatasetConverter, make_spark_converter
    ## Petastorm requires an intermediate cache directory in order to store processed results
    # Set a cache directory on DBFS FUSE for intermediate data.
    CACHE_DIR = "file:///dbfs/tmp/petastorm/cache"
    spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, CACHE_DIR)
#     import mlflow

    ############################
    ##### Setting up MLflow ####
    # We need to do this so that different processes that will be able to find mlflow
#     os.environ['DATABRICKS_HOST'] = db_host
#     os.environ['DATABRICKS_TOKEN'] = db_token

    # NCCL P2P can cause issues with incorrect peer settings, so let's turn this off to scale for now
    os.environ['NCCL_P2P_DISABLE'] = '1'

    MAX_EPOCH_COUNT = 2
    batch_size = 32
    STEPS_PER_EPOCH = 200

#     mlf_logger = pl.loggers.MLFlowLogger(experiment_name=experiment_path)

    # init the Lightning Model
    m = TwoTower(n_users, n_items, cfg, embedding_size=32)

    # init the Lightning DataModule
    train_converter = make_spark_converter(train_df)
    val_converter = make_spark_converter(val_df)
    test_converter = make_spark_converter(test_df)
    datamodule = PySparkDataModule(train_converter, val_converter, test_converter, batch_size=batch_size)
    print(f"Data Modules created...")

    # train the model
    if num_tasks == 1 and num_proc_per_task == 1:
        kwargs = {}
    else:
        kwargs = {"strategy" : "ddp"}
    
    '''
        couple of peculiar things need to be done - explained here:
        https://community.cloud.databricks.com/?o=4375727927284138#notebook/4209209923978200/command/4209209923978205
    '''
    from lightning.pytorch.loggers import CSVLogger
    import time
    version = time.strftime("%Y-%m-%d__%H-%M")
    csv_logger = CSVLogger(
      '/dbfs/mnt/my-destiny-ebook-uploads-duq/recommendations/20230411000000/two-tower/logs', 
      name='lightning_logs', version=version
      )


    trainer = L.Trainer(
      accelerator='auto',
      devices=num_proc_per_task,
      num_nodes=num_tasks,
#       gpus=gpus,
#       logger=mlf_logger,
      logger=csv_logger,
      max_epochs=MAX_EPOCH_COUNT,
      limit_train_batches=STEPS_PER_EPOCH,  # this is the way to end the epoch
      log_every_n_steps=1,
      val_check_interval=STEPS_PER_EPOCH,  # this value must be the same as `limit_train_batches`
      num_sanity_val_steps=0,  # this must be zero to prevent a Petastorm error about Data Loader not being read completely
      limit_val_batches=1,  # any value would work here but there is point in validating on repeated set of data
      reload_dataloaders_every_n_epochs=1,  # need to set this to 1
#       callbacks=callbacks,
      default_root_dir=log_path,
      **kwargs
    )

    print(f"Fitting...")
    trainer.fit(model=m, datamodule=datamodule)

    print(f"Testing...")
    trainer.test(
      model=m, 
      dataloaders=datamodule.test_dataloader(),
      # datamodule=datamodule
      )

    return {
      'model': m, 
      'best_model_checkpoint': trainer.checkpoint_callback.best_model_path,
      'trainer': trainer,
      'datamodule': datamodule,
      }
    

### Train the model locally
Note that nnodes = 1 and nproc_per_node = 1.

In [0]:
NUM_TASKS = 1
NUM_PROC_PER_TASK = 1
 
results_dict = main_training_loop(NUM_TASKS, NUM_PROC_PER_TASK)

  self._filesystem = pyarrow.localfs
Converting floating-point columns to float32
The median size 11642874 B (< 50 MB) of the parquet files is too small. Total size: 93149286 B. Increase the median file size by calling df.repartition(n) or df.coalesce(n), which might help improve the performance. Parquet files: file:/dbfs/tmp/petastorm/cache/20230503160456-appid-app-20230503140425-0000-54265627-61a8-4e9a-9ba2-1c50561d225e/part-00003-tid-8165197909285845295-02d73c27-decf-4f39-86b7-405d54d9865e-756-1.c000.parquet, ...
Converting floating-point columns to float32
The median size 5058187 B (< 50 MB) of the parquet files is too small. Total size: 40462545 B. Increase the median file size by calling df.repartition(n) or df.coalesce(n), which might help improve the performance. Parquet files: file:/dbfs/tmp/petastorm/cache/20230503160555-appid-app-20230503140425-0000-151b5e19-4fe2-4c92-8401-f68ed10d1d3d/part-00003-tid-7153029547148232378-167dbac3-e30f-4650-ba70-bb215d3b15aa-774-1.c000.parquet

Data Modules created...


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_val_batches=1)` was configured so 1 batch will be used.


Fitting...



  | Name        | Type       | Params
-------------------------------------------
0 | user_emb    | Embedding  | 320 K 
1 | item_emb    | Embedding  | 320 K 
2 | item_layers | Sequential | 1.1 K 
3 | user_layers | Sequential | 1.1 K 
4 | sigmoid     | Sigmoid    | 0     
5 | criterion   | BCELoss    | 0     
-------------------------------------------
642 K     Trainable params
0         Non-trainable params
642 K     Total params
2.569     Total estimated model params size (MB)
  self._filesystem = pyarrow.localfs


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


Testing...


Testing: 0it [00:00, ?it/s]

### Single node multi-GPU setup
For the distributor API, you want to set num_processes to the total amount of GPUs that you plan on using. For single node multi-gpu, this is limited by the number of GPUs available on the driver node.

As mentioned before, single node multi-gpu (with NUM_PROC GPUs) setup involves setting trainer = pl.Trainer(accelerator='gpu', devices=NUM_PROC, num_nodes=1, **kwargs)

In [0]:
from pyspark.ml.torch.distributor import TorchDistributor

In [0]:
NUM_TASKS = 1
NUM_PROC_PER_TASK = NUM_GPUS_PER_WORKER
NUM_PROC = NUM_TASKS * NUM_PROC_PER_TASK
 
results_dict = TorchDistributor(num_processes=NUM_PROC, local_mode=True, use_gpu=USE_GPU).run(main_training_loop, NUM_TASKS, NUM_PROC_PER_TASK) 

[0;31m---------------------------------------------------------------------------[0m
[0;31mNameError[0m                                 Traceback (most recent call last)
File [0;32m<command-3500542499642573>:5[0m
[1;32m      2[0m NUM_PROC_PER_TASK [38;5;241m=[39m NUM_GPUS_PER_WORKER
[1;32m      3[0m NUM_PROC [38;5;241m=[39m NUM_TASKS [38;5;241m*[39m NUM_PROC_PER_TASK
[0;32m----> 5[0m (model, ckpt_path) [38;5;241m=[39m TorchDistributor(num_processes[38;5;241m=[39mNUM_PROC, local_mode[38;5;241m=[39m[38;5;28;01mTrue[39;00m, use_gpu[38;5;241m=[39mUSE_GPU)[38;5;241m.[39mrun(main_training_loop, NUM_TASKS, NUM_PROC_PER_TASK)

[0;31mNameError[0m: name 'TorchDistributor' is not defined