# PyTorch Lightning: aux targets + weighted loss + thresholds

In this notebook, I convert the notebook by [vslaykovsky](https://www.kaggle.com/code/vslaykovsky/train-pytorch-aux-targets-weighted-loss-thres) to PyTorch Lightning, which allows us to switch easily between single-GPU and multi-GPU training.

## 1. Imports, Constants, Dependencies

In [1]:
! pip install pytorch-lightning --quiet

[0m

In [2]:
import gc
import os

# import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn
import torch
from PIL import Image
from sklearn.preprocessing import LabelEncoder
import sys
sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master')
from timm import create_model, list_models
from timm.data import create_transform
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

from tqdm import tqdm

import wandb

from kaggle_secrets import UserSecretsClient
# Fetch the Weights & Biases API key through the Kaggle secrets client so the
# key is never written into the notebook itself.
user_secrets = UserSecretsClient()
wb_key = user_secrets.get_secret("WANDB_API_KEY")

# Authenticate this session with W&B using the retrieved key.
wandb.login(key=wb_key)

[34m[1mwandb[0m: Currently logged in as: [33mdoanthinhvo[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [3]:
# DEBUG MODE: when True, the small demo/smoke-test branches further down are executed.
DEBUG = False


# GLOBAL VARIABLES (values that stay fixed for the whole notebook)

# Directory holding the pre-processed PNG images, laid out as <patient_id>/<image_id>.png.
# (Previously assigned twice with an f-string that had no placeholders.)
TRAIN_IMAGES_PATH = '/kaggle/input/rsna-cut-off-empty-space-from-images'
# Name of the binary target column in train.csv.
TARGET = 'cancer'
# Metadata columns predicted by the auxiliary classification heads.
CATEGORY_AUX_TARGETS = ['site_id', 'laterality', 'view', 'implant', 'biopsy', 'invasive', 'BIRADS', 'density', 'difficult_negative_case', 'machine_id', 'age']
RSNA_2022_PATH = '../input/rsna-breast-cancer-detection'
MAX_TRAIN_BATCHES = 40000
MAX_EVAL_BATCHES = 400
MODELS_PATH = '/kaggle/input/wandb-models/models'
# Worker processes used by each DataLoader.
NUM_WORKERS = 2
PREDICT_MAX_BATCHES = 1e9
# Number of cross-validation folds the data is split into.
N_FOLDS = 5
# Folds to train in this run; override with e.g. FOLDS=0,2 in the environment.
FOLDS = np.array(os.environ.get('FOLDS', '0,1,2,3,4').split(',')).astype(int)
WANDB_SWEEP_PROJECT = 'rsna-breast-cancer-sweeps'

class CFG:
    """Hyper-parameters of a training run.

    Values wrapped in ``os.environ.get`` can be overridden from the environment
    (LR, EPOCHS, MODEL, DROPOUT, AUG); everything else is a fixed constant.
    """

    # ----- model -----
    MODEL_TYPE = os.environ.get('MODEL', 'seresnext50_32x4d')
    DROPOUT = float(os.environ.get('DROPOUT', 0.0))

    # ----- optimiser and learning-rate schedule -----
    ADAMW = False
    ADAMW_DECAY = 0.024
    ONE_CYCLE = True
    ONE_CYCLE_PCT_START = 0.1
    # An earlier default of 0.0008 was also tried.
    ONE_CYCLE_MAX_LR = float(os.environ.get('LR', '0.0004'))

    # ----- training loop -----
    EPOCHS = int(os.environ.get('EPOCHS', 3))
    # An earlier batch size of 32 was also tried.
    BATCH_SIZE = 16
    CHECKPOINT_PATH = "./checkpoints"

    # ----- loss weighting -----
    AUX_LOSS_WEIGHT = 94
    POSITIVE_TARGET_WEIGHT = 20

    # ----- augmentation -----
    AUG = os.environ.get('AUG', 'true').lower() == 'true'
    AUTO_AUG_M = 10
    AUTO_AUG_N = 2
    TTA = False
    
# The run name encodes every hyper-parameter that distinguishes one experiment
# from another, so runs can be told apart in the W&B UI.
WANDB_RUN_NAME = '_'.join([
    'plot_lr',
    CFG.MODEL_TYPE,
    f'lr{CFG.ONE_CYCLE_MAX_LR}',
    f'ep{CFG.EPOCHS}',
    f'bs{CFG.BATCH_SIZE}',
    f'pw{CFG.POSITIVE_TARGET_WEIGHT}',
    f'aux{CFG.AUX_LOSS_WEIGHT}',
    'adamw' if CFG.ADAMW else 'adam',
    'aug' if CFG.AUG else 'noaug',
    f'drop{CFG.DROPOUT}',
])
WANDB_PROJECT = 'RSNA-breast-cancer-v1'
print('run', WANDB_RUN_NAME, 'folds', FOLDS)

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Seed the random number generators through Lightning's helper.
pl.seed_everything(seed=42)


# DEBUG=True
# if DEBUG:
#     FOLDS = np.array(os.environ.get('FOLDS', '0,1,2,3,4').split(',')).astype(int)


run plot_lr_seresnext50_32x4d_lr0.0004_ep3_bs16_pw20_aux94_adam_aug_drop0.0 folds [0 1 2 3 4]


42

In [4]:
# # W&B Logger
# wandb_logger = WandbLogger(
#     project='RSNA-Lightning-torch-converted', 
#     job_type='train', 
# #     config=(CFG.__dict__),
# )
    

## 2.   Loading train/eval/test DF

In [5]:
original_df = pd.read_csv("/kaggle/input/rsna-breast-cancer-detection/train.csv")
df = original_df
# df = original_df.head(10000)
# ========== stratifiedGroupKFold ===========
from sklearn.model_selection import StratifiedGroupKFold

# Stratify on the cancer label and group by patient, so that all images of one
# patient land in the same fold.
split = StratifiedGroupKFold(N_FOLDS)

# The splitter yields *positional* row indices. Collect them in a plain integer
# array instead of writing through label-based `df.loc`, which is only correct
# while the frame happens to carry a default RangeIndex. This also avoids the
# temporary float/NaN column followed by a cast.
fold_of_row = np.full(len(df), -1, dtype=int)
for k, (_, test_idx) in enumerate(split.split(df, df.cancer, groups=df.patient_id)):
    fold_of_row[test_idx] = k
assert (fold_of_row >= 0).all(), "every row must be assigned to exactly one fold"
df['split'] = fold_of_row

# Sanity check: the positive rate should be similar in every fold.
df.groupby('split').cancer.mean()

split
0    0.021127
1    0.021096
2    0.021205
3    0.021209
4    0.021203
Name: cancer, dtype: float64

In [6]:
# Impute missing ages with the mean age. Assign the result back to the column:
# an in-place `fillna` on `df.age` is a chained operation that does not modify
# `df` when pandas Copy-on-Write is active.
df['age'] = df['age'].fillna(df['age'].mean())
# Discretise age into 10 quantile bins labelled 0..9 so it can serve as a
# categorical auxiliary target.
df['age'] = pd.qcut(df['age'], 10, labels=range(10), retbins=False).astype(int)
df

Unnamed: 0,site_id,patient_id,image_id,laterality,view,age,cancer,biopsy,invasive,BIRADS,implant,density,machine_id,difficult_negative_case,split
0,2,10006,462822612,L,CC,5,0,0,0,,0,,29,False,3
1,2,10006,1459541791,L,MLO,5,0,0,0,,0,,29,False,3
2,2,10006,1864590858,R,MLO,5,0,0,0,,0,,29,False,3
3,2,10006,1874946579,R,CC,5,0,0,0,,0,,29,False,3
4,2,10011,220375232,L,CC,3,0,0,0,0.0,0,,21,True,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
54701,1,9973,1729524723,R,MLO,0,0,0,0,1.0,0,C,49,False,2
54702,1,9989,63473691,L,MLO,5,0,0,0,,0,C,216,False,0
54703,1,9989,1078943060,L,CC,5,0,0,0,,0,C,216,False,0
54704,1,9989,398038886,R,MLO,5,0,0,0,0.0,0,C,216,True,0


In [7]:
# Replace each auxiliary target column with integer class ids; a fresh
# LabelEncoder is fitted independently on every column.
df[CATEGORY_AUX_TARGETS] = pd.DataFrame(
    {column: LabelEncoder().fit_transform(df[column]) for column in CATEGORY_AUX_TARGETS},
    index=df.index,
)

## 3. Dataset class & transform function

In [8]:
import torchvision

def get_transforms(aug=False):
    """Return a callable that turns a PIL image into a normalised tensor.

    The image is converted to RGB, passed through the geometric transforms,
    converted to a (channel, height, width) tensor and normalised with
    mean=0.2179 / std=0.0529.

    Args:
        aug: when True, apply a random horizontal flip, a random rotation of
            up to +/-5 degrees and a random resized crop to 1024x512; when
            False, apply a random horizontal flip and a plain resize to
            1024x512.

    Returns:
        A function ``img -> torch.Tensor``.

    NOTE(review): the ``aug=False`` branch still flips at random, so its output
    is not deterministic — confirm this is intended.
    """
    tv = torchvision.transforms

    if aug:
        geometric = [
            tv.RandomHorizontalFlip(0.5),
            tv.RandomRotation(degrees=(-5, 5)),
            tv.RandomResizedCrop((1024, 512), scale=(0.8, 1), ratio=(0.45, 0.55)),
        ]
    else:
        geometric = [
            tv.RandomHorizontalFlip(0.5),
            tv.Resize((1024, 512)),
        ]

    to_normalised_tensor = [
        tv.ToTensor(),  # channel, height, width
        tv.Normalize(mean=0.2179, std=0.0529),
    ]
    pipeline = tv.Compose(geometric + to_normalised_tensor)

    def apply(img):
        return pipeline(img.convert('RGB'))

    return apply

# if DEBUG:
#     tfm = get_transforms(aug=True)
#     img = Image.open(f"{TRAIN_IMAGES_PATH}/10006/1459541791.png")
#     print(img.size)
#     plt.imshow(np.array(img), cmap='gray')
#     plt.show()

#     plt.figure(figsize=(20, 20))
#     for i in range(8):
#         v = tfm(img).permute(1, 2, 0)
#         v -= v.min()
#         v /= v.max()
#         # plt.imshow(v)
#         # break
#         plt.subplot(2, 4, i + 1).imshow(v)
#     plt.tight_layout()

In [9]:
class BreastCancerDataSet(torch.utils.data.Dataset):
    """Map-style dataset that loads one PNG per row of a metadata frame.

    Args:
        df: frame with at least ``patient_id`` and ``image_id`` columns. When it
            also contains the ``TARGET`` column, labels are returned too.
        path: root directory; images are read from
            ``<path>/<patient_id>/<image_id>.png``.
        transforms: optional callable applied to the loaded PIL image.
    """

    def __init__(self, df, path, transforms=None):
        super().__init__()
        self.df = df
        self.path = path
        self.transforms = transforms

    def __getitem__(self, i):
        """Return sample ``i``.

        Returns:
            ``(img, cancer_target, cat_aux_targets)`` when the frame has the
            ``TARGET`` column, otherwise just ``img``. If the image cannot be
            opened, the path and error are printed and ``None`` is returned.
        """
        # Materialise the row once instead of re-running `.iloc[i]` per field.
        row = self.df.iloc[i]

        path = f'{self.path}/{row.patient_id}/{row.image_id}.png'
        try:
            img = Image.open(path).convert('RGB')
        except Exception as ex:
            # Best effort: report the unreadable file and signal it with None.
            print(path, ex)
            return None

        if self.transforms is not None:
            img = self.transforms(img)

        # If not in test.
        if TARGET in self.df.columns:
            cancer_target = torch.as_tensor(row[TARGET])
            # A row of a mixed-dtype frame is an object-dtype, label-indexed
            # Series; convert it to an explicit int64 array rather than letting
            # torch interpret the Series itself.
            cat_aux_targets = torch.as_tensor(
                row[CATEGORY_AUX_TARGETS].to_numpy(dtype='int64')
            )
            return img, cancer_target, cat_aux_targets

        return img

    def __len__(self):
        """Number of rows (images) in the dataset."""
        return len(self.df)

In [10]:
class BreastCancerDataModule(pl.LightningDataModule):
    """Lightning data module wrapping the train and validation frames of one fold.

    Args:
        df_train: metadata rows used for training.
        df_valid: metadata rows used for validation.
    """

    def __init__(self, df_train, df_valid):
        super().__init__()
        self.df_train = df_train
        self.df_valid = df_valid

    def setup(self, stage=None):
        """Create the train and validation datasets."""
        self.train_dataset = BreastCancerDataSet(
            self.df_train,
            path=TRAIN_IMAGES_PATH,
            transforms=get_transforms(aug=True),
        )

        # Validation must not use the training augmentations (rotation, random
        # crop): they add noise to `val_loss`, which is the metric the
        # checkpoint callback monitors.
        self.valid_dataset = BreastCancerDataSet(
            self.df_valid,
            path=TRAIN_IMAGES_PATH,
            transforms=get_transforms(aug=False),
        )

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return DataLoader(
            self.train_dataset,
            batch_size=CFG.BATCH_SIZE,
            num_workers=NUM_WORKERS,
            shuffle=True,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Sequential (unshuffled) loader over the validation dataset."""
        return DataLoader(
            self.valid_dataset,
            batch_size=CFG.BATCH_SIZE,
            num_workers=NUM_WORKERS,
            shuffle=False,
            pin_memory=True,
        )
# Smoke test (only when DEBUG is on): build a data module from the first ten
# rows (5 for train, 5 for validation) and create its datasets.
if DEBUG:
    df_train_demo = df.iloc[:5]
    df_valid_demo = df.iloc[5:10]
    demo_data_module = BreastCancerDataModule(df_train_demo, df_valid_demo)
    demo_data_module.setup()

## 4. Model

In [11]:
class BSDRSNAModule(pl.LightningModule):
    """Backbone with one binary cancer head and one classification head per aux target.

    Args:
        aux_classes: iterable with the number of classes of each auxiliary
            target, in the order of ``CATEGORY_AUX_TARGETS``.
        len_dl_train: number of training batches per epoch; used to size the
            OneCycleLR schedule.
        model_type: timm architecture name of the backbone.
        dropout: ``drop_rate`` forwarded to ``create_model``.

    Training uses Lightning's automatic optimization: ``training_step`` returns
    a differentiable loss and Lightning runs backward, the optimizer step and
    the scheduler step. (The module previously disabled automatic optimization
    without performing those steps itself, so no parameter was ever updated.)
    """

    def __init__(self, aux_classes, len_dl_train, model_type=CFG.MODEL_TYPE, dropout=0.):
        super().__init__()
        self.len_dl_train = len_dl_train

        # back bone
        self.model = create_model(model_type, pretrained=True, num_classes=0, drop_rate=dropout)
        self.backbone_dim = self._probe_backbone_dim()

        # head for cancer:
        self.nn_cancer = torch.nn.Sequential(
            torch.nn.Linear(self.backbone_dim, 1),
        )

        # heads for aux categories:
        self.nn_aux = torch.nn.ModuleList([
            torch.nn.Linear(self.backbone_dim, int(n)) for n in aux_classes
        ])

    def _probe_backbone_dim(self):
        """Return the backbone's feature size, measured with one dummy forward pass."""
        # Probe in eval mode and without autograd so the dummy batch neither
        # updates normalisation statistics nor builds a graph.
        was_training = self.model.training
        self.model.eval()
        with torch.no_grad():
            dim = self.model(torch.randn(1, 3, 512, 512)).shape[-1]
        self.model.train(was_training)
        return dim

    def forward(self, x):
        """Return ``(cancer_logits, aux_logits)``.

        ``cancer_logits`` has shape [batch_size]; ``aux_logits`` is a list with
        one [batch_size, n_classes] tensor per auxiliary head.
        """
        features = self.model(x)
        # Drop only the trailing singleton dimension: a bare `.squeeze()` would
        # also remove the batch dimension when the batch holds a single sample.
        cancer = self.nn_cancer(features).squeeze(-1)
        aux = [head(features) for head in self.nn_aux]
        return cancer, aux

    def _compute_losses(self, batch):
        """Compute ``(total, cancer, aux)`` losses for an ``(img, y_cancer, y_aux)`` batch."""
        images, y_cancer, y_aux = batch[0], batch[1], batch[2]
        cancer_logits, aux_logits = self(images)

        # Compute the binary loss in float32 with target and weight on the
        # same device and dtype as the logits.
        cancer_logits = cancer_logits.float()
        pos_weight = torch.tensor(
            [CFG.POSITIVE_TARGET_WEIGHT],
            dtype=cancer_logits.dtype,
            device=cancer_logits.device,
        )
        cancer_loss = torch.nn.functional.binary_cross_entropy_with_logits(
            cancer_logits,
            y_cancer.to(cancer_logits.dtype),
            pos_weight=pos_weight,
        )
        aux_loss = torch.mean(torch.stack([
            torch.nn.functional.cross_entropy(aux_logits[i], y_aux[:, i].long())
            for i in range(y_aux.shape[-1])
        ]))
        loss = cancer_loss + CFG.AUX_LOSS_WEIGHT * aux_loss
        return loss, cancer_loss, aux_loss

    def training_step(self, batch, batch_idx):
        """Return the differentiable training loss and log its components."""
        # Do not call `.item()` on any term here: that would detach it from the
        # graph and the corresponding head would receive no gradient.
        loss, cancer_loss, aux_loss = self._compute_losses(batch)

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('cancer_loss', cancer_loss, on_epoch=True)
        self.log('aux_loss', aux_loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation losses; returns the total loss."""
        val_loss, cancer_loss, aux_loss = self._compute_losses(batch)

        self.log('val_loss', val_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('val_cancer_loss', cancer_loss, on_epoch=True)
        self.log('val_aux_loss', aux_loss, on_epoch=True)
        return val_loss

    def predict_step(self, batch, batch_idx):
        """Return ``(cancer_probability, aux_probabilities)`` for the images in ``batch[0]``."""
        cancer, aux = self.forward(batch[0])
        aux_probs = [torch.softmax(a, dim=-1) for a in aux]
        return torch.sigmoid(cancer), aux_probs

    def configure_optimizers(self):
        """Create the optimizer and a per-batch OneCycleLR schedule."""
        # Honour the optimizer choice that the W&B run name advertises.
        if CFG.ADAMW:
            optimizer = torch.optim.AdamW(self.parameters(), weight_decay=CFG.ADAMW_DECAY)
        else:
            optimizer = torch.optim.Adam(self.parameters())

        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=CFG.ONE_CYCLE_MAX_LR,
            epochs=CFG.EPOCHS,
            steps_per_epoch=self.len_dl_train,
            pct_start=CFG.ONE_CYCLE_PCT_START,
        )
        # The schedule is sized as epochs * steps_per_epoch, so it must advance
        # after every batch; Lightning's default interval is once per epoch.
        return {
            'optimizer': optimizer,
            'lr_scheduler': {'scheduler': scheduler, 'interval': 'step', 'frequency': 1},
        }

#     def on_after_backward(self):
#         global_step = self.global_step
#         if (global_step % 25 == 0) and (global_step > 0):
#             for name, param in self.named_parameters():
#                 self.logger.experiment.add_histogram(name, param, global_step)
#         wandb.watch(self.model, log = 'gradients')
#         lr = scheduler.get_last_lr()[0] if scheduler else config.ONE_CYCLE_MAX_LR
#         self.log('lr', lr)

# Number of classes per auxiliary target: the largest encoded value in each
# column plus one.
AUX_TARGET_NCLASSES = df[CATEGORY_AUX_TARGETS].max() + 1
# NOTE: this re-assigns the DEBUG flag for everything below this cell.
DEBUG = False
if DEBUG:
    # Smoke test: run prediction on the first five rows with a freshly built model.
    modelModule = BSDRSNAModule(model_type=CFG.MODEL_TYPE, aux_classes=AUX_TARGET_NCLASSES, dropout=0., len_dl_train=1)
    df_train_demo = df.iloc[:5]
    df_valid_demo = df.iloc[5:10]
    train_dataset = BreastCancerDataSet(df_train_demo, path=TRAIN_IMAGES_PATH, transforms= get_transforms(aug=True))
    dataloader = DataLoader(train_dataset, batch_size=CFG.BATCH_SIZE, num_workers=2, shuffle=True, pin_memory=True)
    trainer = pl.Trainer()
    predictions = trainer.predict(modelModule, dataloader)


In [12]:
# predictions

## 5. Train

In [None]:
# Train one model per requested fold; every fold is logged as its own W&B run,
# and all runs are grouped under the shared run name.
TRAIN = True
if TRAIN:
    for fold in FOLDS:
        name = f"{WANDB_RUN_NAME}-{fold}"

        # each name for each fold
        with wandb.init(project=WANDB_PROJECT, name=name, group=WANDB_RUN_NAME) as run:
            gc.collect()
            # Created inside the active run, so the logger attaches to it.
            wandb_logger = WandbLogger()

            # Hold out the current fold for validation, train on the rest.
            train_df = df.query('split != @fold')
            valid_df = df.query('split == @fold')

            data_module = BreastCancerDataModule(train_df, valid_df)

            # Batches per epoch (ceiling division); sizes the OneCycleLR schedule.
            len_dl_train = len(train_df) // CFG.BATCH_SIZE + (1 if len(train_df) % CFG.BATCH_SIZE != 0 else 0)

            # Keep only the checkpoint with the lowest validation loss for this fold.
            checkpoint_callback = ModelCheckpoint(
                dirpath=CFG.CHECKPOINT_PATH,
                filename=f"fold_{fold}_{CFG.MODEL_TYPE}",
                save_top_k=1,
                verbose=True,
                monitor="val_loss",
                mode="min",
            )
            lr_monitor = LearningRateMonitor(logging_interval='step')

            # Pass the configured dropout: the run name reports CFG.DROPOUT, so
            # the model must actually be built with it (it was hard-coded to 0.).
            modelModule = BSDRSNAModule(
                model_type=CFG.MODEL_TYPE,
                aux_classes=AUX_TARGET_NCLASSES,
                dropout=CFG.DROPOUT,
                len_dl_train=len_dl_train,
            )
            trainer = pl.Trainer(
                logger=wandb_logger,
                callbacks=[checkpoint_callback, lr_monitor],
                max_epochs=CFG.EPOCHS,
                accelerator="gpu",
                devices=1,
                log_every_n_steps=1,
                precision=16,
            )

            trainer.fit(modelModule, data_module)


  "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

## 6. WandB sweep

## 7. Cross Validation