<a href="https://colab.research.google.com/github/wandb/edu/blob/main/model-registry-201/Logging_Models_PyTorch_Lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
<!--- @wandbcode{modelreg201-pytorch-lightning-colab} -->

<img src="https://wandb.me/logo-im-png" width="400" alt="Weights & Biases" />

<!--- @wandbcode{wandb201-pytorch-lightning-colab} -->

## Model Registry Tutorial
The model registry is a central place to house and organize all the model tasks, and their associated artifacts, being worked on across an organization. With it you can:
- Manage model checkpoints
- Document your models with rich model cards
- Maintain a history of all the models being used/deployed
- Facilitate clean hand-offs and stage management of models
- Tag and organize various model tasks
- Set up automatic notifications when models progress

This tutorial walks through how to track the model development lifecycle for a simple image classification task.


### 🛠️ Install `wandb`


In [None]:
!pip install -q wandb onnx pytorch-lightning

## Log in to W&B
- You can log in explicitly using `wandb login` or `wandb.login()` (see below).
- Alternatively, you can set environment variables. Several environment variables change the behavior of W&B logging; the most important are:
    - `WANDB_API_KEY` - find this in the "Settings" section under your profile
    - `WANDB_BASE_URL` - the URL of the W&B server
- Find your API token under "Profile" -> "Settings" in the W&B App.

![api_token](https://drive.google.com/uc?export=view&id=1Xn7hnn0rfPu_EW0A_-32oCXqDmpA0-kx)
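
For non-interactive environments such as CI jobs, the environment-variable route might look like the sketch below. The helper name `configure_wandb_env` and the placeholder key are illustrative assumptions, not part of the W&B SDK; `WANDB_API_KEY` and `WANDB_BASE_URL` are the variables named above.

```python
import os


def configure_wandb_env(api_key, base_url=None):
    """Set the W&B environment variables before `wandb` is initialized.

    `configure_wandb_env` is a hypothetical helper, not part of the wandb SDK.
    """
    os.environ["WANDB_API_KEY"] = api_key
    if base_url is not None:
        # Only needed when talking to a dedicated or self-hosted W&B server
        os.environ["WANDB_BASE_URL"] = base_url


# Placeholder value; substitute the key from your own settings page
configure_wandb_env("<your-api-key>")
```

Setting the variables before `wandb.init()` runs means no interactive prompt is needed.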

In [None]:
import wandb
wandb.login()

## Log Data and Model Checkpoints as Artifacts  
W&B Artifacts let you track and version arbitrary serialized data (e.g. datasets, model checkpoints, evaluation results). When you create an artifact, you give it a name and a type, and that artifact is permanently linked to the experimental system of record. If the underlying data changes and you log that data asset again, W&B automatically creates a new version by checksumming its contents. W&B Artifacts can be thought of as a lightweight abstraction layer on top of shared unstructured file systems.
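
The idea of checksum-based versioning can be illustrated with a small, self-contained sketch. This is a conceptual toy, not W&B's implementation: the use of SHA-256, the `versions` list, and the `log_if_changed` helper are all assumptions made for illustration.

```python
import hashlib


def content_digest(data: bytes) -> str:
    """Return a hex digest that changes whenever the bytes change."""
    return hashlib.sha256(data).hexdigest()


versions = []  # digests of the versions "logged" so far (toy stand-in)


def log_if_changed(data: bytes) -> str:
    """Record a new version only when the content differs from the latest one."""
    digest = content_digest(data)
    if not versions or versions[-1] != digest:
        versions.append(digest)
    return f"v{len(versions) - 1}"


print(log_if_changed(b"label,path\ncat,a.jpg\n"))              # v0
print(log_if_changed(b"label,path\ncat,a.jpg\n"))              # v0 (unchanged)
print(log_if_changed(b"label,path\ncat,a.jpg\ndog,b.jpg\n"))   # v1 (changed)
```

Logging identical bytes twice leaves the version unchanged, while any change to the contents produces a new version.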

### Anatomy of an artifact

The `Artifact` class corresponds to an entry in the W&B Artifact registry. An artifact has:
* a name
* a type
* metadata
* description
* files, directory of files, or references

Example usage:
```
run = wandb.init(project = "my-project")
artifact = wandb.Artifact(name = "my_artifact", type = "data")
artifact.add_file("/path/to/my/file.txt")
run.log_artifact(artifact)
run.finish()
```

In this tutorial, the first thing we will do is download a training dataset and log it as an artifact to be used downstream in the training job.

In [None]:
#@title Enter your W&B project and entity

# FORM VARIABLES
PROJECT_NAME = "model-registry-201" #@param {type:"string"}
ENTITY = "wandb"#@param {type:"string"}

src_url = "https://storage.googleapis.com/wandb_datasets/nature_100.zip"
src_zip = "nature_100.zip"
DATA_SRC = "nature_100"
IMAGES_PER_LABEL = 10
BALANCED_SPLITS = {"train" : 8, "val" : 1, "test": 1}

In [None]:
%%capture
!curl -SL $src_url > $src_zip
!unzip $src_zip

In [None]:
import wandb
import pandas as pd
import os

with wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type='log_datasets') as run:
    # Build an index of (image_path, label) pairs; the label is the class sub-directory name
    img_paths = []
    for root, dirs, files in os.walk('nature_100', topdown=False):
        for name in files:
            img_path = os.path.join(root, name)
            label = img_path.split('/')[1]
            img_paths.append([img_path, label])

    index_df = pd.DataFrame(columns=['image_path', 'label'], data=img_paths)
    index_df.to_csv('index.csv', index=False)

    train_art = wandb.Artifact(name='Nature_100', type='raw_images', description='nature image dataset with 10 classes, 10 images per class')
    train_art.add_dir('nature_100')

    # Also adding a csv indicating the labels of each image
    train_art.add_file('index.csv')
    run.log_artifact(train_art)

### Using Artifact names and aliases to easily hand off and abstract data assets
- By simply referring to the `name:alias` combination of a dataset or model, we can better standardize components of a workflow
- For instance, you can build PyTorch `Dataset`s or `DataModule`s which take W&B Artifact names and aliases as arguments and load the corresponding data
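
A minimal sketch of handling `name:alias` strings in plain Python is shown below. The helper names `split_name_alias` and `qualified_artifact_path` are hypothetical and not part of the W&B SDK, and defaulting a missing alias to `latest` is an assumption of this sketch.

```python
def split_name_alias(name_alias, default_alias="latest"):
    """Split 'Nature_100:latest' into ('Nature_100', 'latest')."""
    name, _, alias = name_alias.partition(":")
    # Fall back to the default alias when none is given (sketch assumption)
    return name, alias or default_alias


def qualified_artifact_path(entity, project, name_alias):
    """Build an '<entity>/<project>/<name>:<alias>' reference string."""
    name, alias = split_name_alias(name_alias)
    return f"{entity}/{project}/{name}:{alias}"


print(qualified_artifact_path("wandb", "model-registry-201", "Nature_100:latest"))
# wandb/model-registry-201/Nature_100:latest
```

Passing such strings around, rather than file paths, keeps each workflow component independent of where the data physically lives.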

You can now see all the metadata associated with this dataset, the W&B runs consuming it, and the whole lineage of upstream and downstream artifacts!

![artifact lineage](https://drive.google.com/uc?export=view&id=1fEEddXMkabgcgusja0g8zMz8whlP2Y5P)

In [None]:
import math
import os

import pandas as pd
import pytorch_lightning as pl
import torch
from skimage import io
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

class NatureDataset(Dataset):
    def __init__(self,
                 wandb_run,
                 artifact_name_alias="Nature_100:latest",
                 local_target_dir="Nature_100:latest",
                 transform=None):
        self.local_target_dir = local_target_dir
        self.transform = transform

        # Pull down the artifact locally to load it into memory
        art = wandb_run.use_artifact(artifact_name_alias)
        path_at = art.download(root=self.local_target_dir)

        self.ref_df = pd.read_csv(os.path.join(self.local_target_dir, 'index.csv'))
        self.class_names = self.ref_df.iloc[:, 1].unique().tolist()
        self.idx_to_class = {k: v for k, v in enumerate(self.class_names)}
        self.class_to_idx = {v: k for k, v in enumerate(self.class_names)}

    def __len__(self):
        return len(self.ref_df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.ref_df.iloc[idx, 0]

        image = io.imread(img_path)
        label = self.ref_df.iloc[idx, 1]
        label = torch.tensor(self.class_to_idx[label], dtype=torch.long)

        if self.transform:
            image = self.transform(image)

        return image, label


class NatureDatasetModule(pl.LightningDataModule):
    def __init__(self,
                 wandb_run,
                 artifact_name_alias: str = "Nature_100:latest",
                 local_target_dir: str = "Nature_100:latest",
                 batch_size: int = 16,
                 input_size: int = 224,
                 seed: int = 42):
        super().__init__()
        self.wandb_run = wandb_run
        self.artifact_name_alias = artifact_name_alias
        self.local_target_dir = local_target_dir
        self.batch_size = batch_size
        self.input_size = input_size
        self.seed = seed

    def setup(self, stage=None):
        self.nature_dataset = NatureDataset(wandb_run=self.wandb_run,
                                            artifact_name_alias=self.artifact_name_alias,
                                            local_target_dir=self.local_target_dir,
                                            transform=transforms.Compose([transforms.ToTensor(),
                                                                          transforms.CenterCrop(self.input_size),
                                                                          transforms.Normalize((0.485, 0.456, 0.406),
                                                                                               (0.229, 0.224, 0.225))]))

        nature_length = len(self.nature_dataset)
        train_size = math.floor(0.8 * nature_length)
        # Use the remainder so the two sizes always sum to the dataset length
        val_size = nature_length - train_size
        self.nature_train, self.nature_val = random_split(self.nature_dataset,
                                                          [train_size, val_size],
                                                          generator=torch.Generator().manual_seed(self.seed))
        return self

    def train_dataloader(self):
        return DataLoader(self.nature_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.nature_val, batch_size=self.batch_size)

    def predict_dataloader(self):
        pass

    def teardown(self, stage: str):
        pass
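
When `random_split` is given integer lengths, it expects them to add up to the dataset length. A minimal sketch of deriving an 80/20 split that satisfies this for any dataset size follows; the helper name `split_sizes` is hypothetical.

```python
import math


def split_sizes(n_items, train_fraction=0.8):
    """Return (train_size, val_size) that always sum to n_items."""
    train_size = math.floor(train_fraction * n_items)
    val_size = n_items - train_size  # the remainder goes to validation
    return train_size, val_size


print(split_sizes(100))  # (80, 20)
print(split_sizes(101))  # (80, 21)
```

Flooring both fractions independently would give 80 + 20 = 100 for a 101-item dataset, leaving one item unassigned, which is why the validation size is taken as the remainder here.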

## Model Training

### Writing the Model Class and Validation Function

In [None]:
import torch
import wandb
from torch.nn import Linear, CrossEntropyLoss, functional as F
from torch.optim import Adam
from torchmetrics.functional import accuracy
from pytorch_lightning import LightningModule
from torchvision import models
import pandas as pd
import os
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

class LogPredictionsCallback(Callback):
    def __init__(self):
      super().__init__()

    def on_validation_epoch_start(self, trainer, pl_module):
      self.batch_dfs = []
      self.image_list = []
      self.val_table = wandb.Table(columns=['image', 'ground_truth', 'prediction'])


    def on_validation_batch_end(
      self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
      """Called when the validation batch ends."""

      # Append validation predictions and ground truth to log in confusion matrix
      x, _ = batch
      preds, y = outputs
      # Move tensors to the CPU first; .numpy() fails on GPU tensors
      preds, y = preds.cpu().numpy(), y.cpu().numpy()
      self.batch_dfs.append(pd.DataFrame({"Ground Truth": y, "Predictions": preds}))

      # Add wandb.Image to a table to log at the end of validation
      x = x.cpu().numpy().transpose(0, 2, 3, 1)
      for x_i, y_i, y_pred in zip(x, y, preds):
        y_i, y_pred = int(y_i), int(y_pred)  # plain ints for captions and table cells
        self.image_list.append(wandb.Image(x_i, caption=f'Ground Truth: {y_i} - Prediction: {y_pred}'))
        self.val_table.add_data(wandb.Image(x_i), y_i, y_pred)


    def on_validation_epoch_end(self, trainer, pl_module):
      # Collect statistics for whole validation set and log
      class_names = trainer.datamodule.nature_dataset.class_names
      val_df = pd.concat(self.batch_dfs)
      wandb.log({"eval/val_table": self.val_table,
                 "eval/images_over_time": self.image_list,
                 "eval/conf_matrix": wandb.plot.confusion_matrix(y_true = val_df["Ground Truth"].tolist(),
                                                                       preds=val_df["Predictions"].tolist(),
                                                                       class_names=class_names)}, step=trainer.global_step)

      del self.batch_dfs
      del self.val_table

def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False


def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    # Initialize these variables which will be set in this if statement. Each of these
    #   variables is model specific.
    model_ft = None
    input_size = 0

    if model_name == "squeezenet":
        """ Squeezenet
        """
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    else:
        # Fail loudly instead of calling exit(), which would kill the notebook kernel
        raise ValueError(f"Invalid model name: {model_name!r}")

    return model_ft, input_size

class NatureLitModule(LightningModule):
    def __init__(self,
                 model_name,
                 num_classes=10,
                 feature_extract=True,
                 lr=0.01):
        '''method used to define our model parameters'''
        super().__init__()

        self.model_name = model_name
        self.num_classes = num_classes
        self.feature_extract = feature_extract
        self.model, self.input_size = initialize_model(model_name=self.model_name,
                                                       num_classes=self.num_classes,
                                                       feature_extract=self.feature_extract)

        # loss
        self.loss = CrossEntropyLoss()

        # optimizer parameters
        self.lr = lr

        # save hyper-parameters to self.hparams (auto-logged by W&B)
        self.save_hyperparameters()

        # Record the gradients of all the layers
        wandb.watch(self.model)

    def forward(self, x):
        '''method used for inference input -> output'''
        x = self.model(x)

        return x

    def training_step(self, batch, batch_idx):
        '''needs to return a loss from a single batch'''
        preds, y, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('train/loss', loss)
        self.log('train/accuracy', acc)

        return loss

    def validation_step(self, batch, batch_idx):
        '''used for logging metrics'''
        preds, y, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('eval/loss', loss)
        self.log('eval/accuracy', acc)

        # Let's return preds to use it in a custom callback
        return preds, y

    def test_step(self, batch, batch_idx):
        '''used for logging metrics'''
        preds, y, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('test/loss', loss)
        self.log('test/accuracy', acc)

    def configure_optimizers(self):
        '''defines model optimizer'''
        return Adam(self.parameters(), lr=self.lr)


    def _get_preds_loss_accuracy(self, batch):
        '''convenience function since train/valid/test steps are similar'''
        x, y = batch
        logits = self(x)
        preds = torch.argmax(logits, dim=1)
        loss = self.loss(logits, y)

        acc = accuracy(preds, y, task="multiclass", num_classes=self.num_classes)
        return preds, y, loss, acc
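
To make the prediction and accuracy logic in `_get_preds_loss_accuracy` concrete, here is a plain-Python sketch of the same two steps: take the arg-max over each row of logits, then compute the fraction of predictions that match the labels. The helper names and the example numbers are illustrative assumptions; the notebook itself uses `torch.argmax` and `torchmetrics`.

```python
def argmax(values):
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def batch_accuracy(logits, labels):
    """Return predicted classes and the fraction that match the labels."""
    preds = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return preds, correct / len(labels)


# Hypothetical logits for a batch of 3 images and 3 classes
logits = [[2.0, 0.1, -1.0],   # class 0 scores highest
          [0.3, 0.2, 1.5],    # class 2 scores highest
          [0.0, 0.9, 0.4]]    # class 1 scores highest
labels = [0, 2, 2]

preds, acc = batch_accuracy(logits, labels)
print(preds)  # [0, 2, 1]
print(acc)    # two of the three predictions are correct
```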

### Log Checkpoints Automatically with PyTorch Lightning

Pair `WandbLogger` with PyTorch Lightning's `ModelCheckpoint` callback to log checkpoints to W&B automatically:

```
wandb_logger = WandbLogger(log_model='all', checkpoint_name=f'image-segmentation-{wandb.run.id}')
```

- With `log_model="all"`, checkpoints are logged during training; with `log_model=True`, they are logged at the end of training.
- Optionally, use the W&B Artifacts API to implement your own checkpointing logic with PyTorch Lightning callbacks.

See more details on our PyTorch Lightning integration [here](https://docs.wandb.ai/guides/integrations/lightning).
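
Which checkpoints are retained depends on the `ModelCheckpoint` settings, such as `save_top_k`, `monitor`, and `mode`. The sketch below is a conceptual stand-in for top-k retention, not Lightning's implementation; the `keep_top_k` helper and the accuracy values are hypothetical.

```python
def keep_top_k(scores_by_epoch, k, mode="max"):
    """Return the epochs whose monitored score ranks in the top k.

    Conceptual stand-in only; this is not Lightning's implementation.
    """
    reverse = mode == "max"  # "max" suits accuracy, "min" suits loss
    ranked = sorted(scores_by_epoch, key=scores_by_epoch.get, reverse=reverse)
    return sorted(ranked[:k])


# Hypothetical per-epoch validation accuracies
val_accuracy = {0: 0.41, 1: 0.58, 2: 0.55, 3: 0.67, 4: 0.63}

print(keep_top_k(val_accuracy, k=2, mode="max"))  # [3, 4]
print(keep_top_k(val_accuracy, k=2, mode="min"))  # [0, 2]
```

The second call shows why the direction matters: ranking an accuracy metric with `mode="min"` keeps the worst epochs rather than the best ones.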

In [None]:
wandb.init(project=PROJECT_NAME,
            entity=ENTITY,
            job_type='training',
            config={'train_dataset': "Nature_100:latest",
                    "local_dataset_dir": "nature_100",
                    'model_name': 'squeezenet',
                    'lr': 1.0,
                    'gamma': 0.75,
                    'batch_size': 16,
                    'epochs': 5})

wandb.config['input_size'] = 224

wandb_logger = WandbLogger(log_model='all', checkpoint_name=f'image-segmentation-{wandb.run.id}')

# mode="max" because a higher accuracy is better (the default mode is "min")
checkpoint_callback = ModelCheckpoint(every_n_epochs=1, save_top_k=5, monitor="eval/accuracy", mode="max")

model = NatureLitModule(model_name=wandb.config['model_name']) # Access hyperparameters downstream to instantiate models/datasets

nature_module = NatureDatasetModule(wandb_run = wandb_logger.experiment,
                                    artifact_name_alias = wandb.config["train_dataset"],
                                    local_target_dir = wandb.config["local_dataset_dir"],
                                    batch_size=wandb.config['batch_size'],
                                    input_size=model.input_size)
nature_module.setup()

trainer = Trainer(logger=wandb_logger,  # W&B integration
                  accelerator="auto",
                  callbacks=[checkpoint_callback],
                  max_epochs=5,
                  log_every_n_steps=5)
trainer.fit(model, datamodule=nature_module)

wandb.finish()

### Manage all your model checkpoints for a project under one roof.

![model checkpoints](https://drive.google.com/uc?export=view&id=1z7nXRgqHTPYjfR1SoP-CkezyxklbAZlM)

### Note: Syncing with W&B Offline
If network communication is lost during training, you can always sync progress afterwards with `wandb sync`.

The W&B SDK caches all logged data in a local `wandb` directory; calling `wandb sync` uploads that local state to the web app.
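
As a sketch, the `wandb sync` commands for runs recorded offline could be assembled as shown below. The `offline-run-*` directory naming is an assumption about how offline runs are laid out inside the local `wandb` directory, and `offline_sync_commands` is a hypothetical helper.

```python
import glob
import os


def offline_sync_commands(wandb_dir="wandb"):
    """Build one `wandb sync` command per offline run directory found.

    Assumes offline runs live in `<wandb_dir>/offline-run-*` (sketch assumption).
    """
    pattern = os.path.join(wandb_dir, "offline-run-*")
    return [["wandb", "sync", run_dir] for run_dir in sorted(glob.glob(pattern))]


for command in offline_sync_commands():
    # Pass `command` to subprocess.run(...) to actually execute it
    print(" ".join(command))
```

Building the commands as lists keeps them safe to hand to `subprocess.run` without shell quoting.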

## Model Registry
After logging a bunch of checkpoints across multiple runs during experimentation, it is time to hand off the best checkpoint to the next stage of the workflow (e.g. testing, deployment).

The Model Registry is a central page that lives above individual W&B projects. It houses **Registered Models**, portfolios that store "links" to the valuable checkpoints living in individual W&B Projects.

The model registry offers a centralized place to house the best checkpoints for all your model tasks. Any `model` artifact you log can be "linked" to a Registered Model.

### Creating **Registered Models** and Linking through the UI
#### 1. Access your team's model registry by going to the team page and selecting `Model Registry`

![model registry](https://drive.google.com/uc?export=view&id=1ZtJwBsFWPTm4Sg5w8vHhRpvDSeQPwsKw)

#### 2. Create a new Registered Model.

![model registry](https://drive.google.com/uc?export=view&id=1RuayTZHNE0LJCxt1t0l6-2zjwiV4aDXe)

#### 3. Go to the artifacts tab of the project that holds all your model checkpoints

![model registry](https://drive.google.com/uc?export=view&id=1LfTLrRNpBBPaUb_RmBIE7fWFMG0h3e0E)

#### 4. Click "Link to Registry" for the model artifact version you want.

### Creating Registered Models and Linking through the **API**
You can [link a model via the API](https://docs.wandb.ai/guides/models) with `wandb.run.link_artifact`, passing in the artifact object, the name of the **Registered Model**, and any aliases you want to attach to it. **Registered Models** are entity (team) scoped in W&B, so only members of a team can see and access that team's **Registered Models**. Via the API, you refer to a Registered Model as `<entity>/model-registry/<registered-model-name>`. If the Registered Model doesn't exist, it will be created automatically.
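
The target path format can be captured in a small helper, sketched here. `registered_model_path` is a hypothetical name, not part of the W&B SDK; the entity and model name are the ones used in this tutorial.

```python
def registered_model_path(entity, registered_model_name):
    """Build the '<entity>/model-registry/<registered-model-name>' target."""
    if not entity or not registered_model_name:
        raise ValueError("entity and registered_model_name must be non-empty")
    return f"{entity}/model-registry/{registered_model_name}"


target = registered_model_path("wandb", "YOLOv8 Image Segmentation")
print(target)  # wandb/model-registry/YOLOv8 Image Segmentation
```

The resulting string is what gets passed as the second argument to `wandb.run.link_artifact`.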

In [None]:
last_run_id = "rmlp8vlj" #@param
wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type="registering_best_model")
best_model = wandb.use_artifact(f'{ENTITY}/{PROJECT_NAME}/image-segmentation-{last_run_id}:latest')
registered_model_name = "YOLOv8 Image Segmentation" #@param {type: "string"}
wandb.run.link_artifact(best_model, f'{ENTITY}/model-registry/{registered_model_name}', aliases=['staging'])
wandb.finish()

### What is "Linking"?
Linking to the registry creates a new version of that Registered Model, which is just a pointer to the artifact version living in the project. W&B deliberately separates the versioning of artifacts in a project from the versioning of a Registered Model: linking a model artifact version is equivalent to "bookmarking" that artifact version under a Registered Model task.

Typically during R&D/experimentation, researchers generate hundreds, if not thousands, of model checkpoint artifacts, but only one or two of them actually "see the light of day." Linking those checkpoints to a separate, versioned registry helps delineate the model development side from the model deployment/consumption side of the workflow. The globally understood version/alias of a model should be unpolluted by the experimental versions generated in R&D, so the version of a Registered Model increments with each newly "bookmarked" model rather than with each logged checkpoint.
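
The separation between the two version sequences can be pictured with a toy model. Everything here is illustrative: the lists, the `link` helper, and the version numbering are assumptions for the sketch, not W&B's data structures.

```python
# 100 checkpoints logged in a project during experimentation
project_versions = [f"checkpoint:v{i}" for i in range(100)]

# Each registry entry is only a pointer to one project artifact version
registry_versions = []


def link(artifact_version):
    """Bookmark a project artifact version; return the new registry version."""
    registry_versions.append(artifact_version)
    return f"v{len(registry_versions) - 1}"


print(link("checkpoint:v37"))  # v0
print(link("checkpoint:v92"))  # v1
print(len(project_versions), len(registry_versions))  # 100 2
```

A hundred project versions produce only two registry versions here, because the registry counter advances on linking, not on logging.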

## Create a Centralized Hub for all your models
- Add a model card, tags, and Slack notifications to your Registered Model
- Change aliases to reflect when models move through different phases
- Embed the model registry in reports for model documentation and regression reports. See this report as an [example](https://api.wandb.ai/links/wandb-smle/r82bj9at).

![model registry](https://drive.google.com/uc?export=view&id=1lKPgaw-Ak4WK_91aBMcLvUMJL6pDQpgO)


### Set up Slack Notifications when new models get linked to the registry

![model registry](https://drive.google.com/uc?export=view&id=1RsWCa6maJYD5y34gQ0nwWiKSWUCqcjT9)

## Consuming a Registered Model
You can now consume any registered model via the API by referring to the corresponding `name:alias`. Model consumers, whether they are engineers, researchers, or CI/CD processes, can go to the model registry as the central hub for all models that should "see the light of day": those that need to go through testing or move to production.

In [None]:
%%wandb -h 600

run = wandb.init(project=PROJECT_NAME, entity=ENTITY, job_type='inference')
# Refer to the Registered Model linked above by its name and the 'staging' alias
artifact = run.use_artifact(f'{ENTITY}/model-registry/{registered_model_name}:staging', type='model')
artifact_dir = artifact.download()
wandb.finish()
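
After the download, a consumer typically needs to locate the checkpoint file inside the artifact directory. The sketch below searches for `*.ckpt` files; the `find_checkpoints` helper, the `model.ckpt` file name, and the temporary directory standing in for `artifact.download()` are assumptions for illustration.

```python
import tempfile
from pathlib import Path


def find_checkpoints(artifact_dir):
    """Return all '*.ckpt' files under the downloaded artifact directory."""
    return sorted(Path(artifact_dir).rglob("*.ckpt"))


# A temporary directory stands in for the result of `artifact.download()`
with tempfile.TemporaryDirectory() as artifact_dir:
    (Path(artifact_dir) / "model.ckpt").touch()  # hypothetical checkpoint file
    found = [p.name for p in find_checkpoints(artifact_dir)]
    print(found)  # ['model.ckpt']
    # With a real checkpoint, it could then be restored, for example with
    # NatureLitModule.load_from_checkpoint(<path to the .ckpt file>)
```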