This notebook lets you train RNet on your own custom dataset.
To do so, simply replace the images inside the data/train and data/val folders with your own.

In [None]:
# Install required libraries
!pip install torch
!pip install torchvision
!pip install numpy
!pip install tqdm
!pip install albumentations==0.4.6
!pip install natsort
!pip install opencv-python
import platform
if platform.python_version() >= "3.10":
    import glob
else:
    %pip install glob2
    import glob2 as glob

In [2]:
# Importing required libraries

import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from utils.model import RNet
from utils.utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    get_test_loader,
    check_accuracy,
    save_predictions_as_imgs,
)
from PIL import Image
import numpy as np
import glob
import cv2
import torchvision
import re
from torchvision import transforms
from natsort import natsorted

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# ---------------------------------------------------------------------------
# Run configuration. Every value below may be tuned freely.
# ---------------------------------------------------------------------------

# Mode switch read by main():
#   True  -> load the pre-trained checkpoint instead of training
#   False -> train RNet yourself
LOAD_MODEL = True

# Device selection: use the GPU when torch reports one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Optimisation settings.
LEARNING_RATE = 1e-4
NUM_EPOCHS = 5
BATCH_SIZE = 256

# Data-loading settings.
NUM_WORKERS = 4
PIN_MEMORY = True

# Size every image is resized to before entering the network.
IMAGE_HEIGHT = 80  # 720 originally
IMAGE_WIDTH = 120  # 1280 originally

# Dataset locations.
# NOTE(review): the training mask directory points at data/val and the
# validation image directory at data/train - confirm this layout is intended.
TRAIN_IMG_DIR = "data/train"
TRAIN_MASK_DIR = "data/val"
VAL_IMG_DIR = "data/train"
VAL_MASK_DIR = "data/val"

In [None]:
# Train function
def train_fn(loader, model, optimizer, loss_fn, scaler):
    """Run one pass over ``loader`` and update ``model`` after every batch.

    Args:
        loader: Iterable yielding ``(data, targets)`` batches; it is wrapped
            in a tqdm progress bar.
        model: Callable applied to each input batch to produce predictions.
        optimizer: Optimizer whose gradients are cleared and stepped per batch.
        loss_fn: Callable invoked as ``loss_fn(predictions, targets)``.
        scaler: Gradient scaler providing ``scale``, ``step`` and ``update``.

    Returns:
        None. The current batch loss is shown in the progress bar postfix.
    """
    progress = tqdm(loader)

    for inputs, labels in progress:
        inputs = inputs.to(DEVICE)

        # Cast to float and insert an axis at dim 1 before moving to DEVICE
        # (presumably so targets match the prediction shape - TODO confirm).
        labels = labels.float()
        labels = labels.unsqueeze(1)
        labels = labels.to(DEVICE)

        # Forward pass and loss under CUDA autocast.
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            batch_loss = loss_fn(outputs, labels)

        # Backward pass: clear old gradients, backprop the scaled loss, then
        # let the scaler drive the optimizer step and refresh its scale.
        optimizer.zero_grad()
        scaled_loss = scaler.scale(batch_loss)
        scaled_loss.backward()
        scaler.step(optimizer)
        scaler.update()

        # Surface the latest loss value on the progress bar.
        progress.set_postfix(loss=batch_loss.item())

# Main function for starting the training
# Training will start only if you set LOAD_MODEL = False
def main():
    """Set up transforms, model and loaders, then load a checkpoint or train.

    When ``LOAD_MODEL`` is true, ``model/checkpoint.pth`` is loaded into a
    freshly built ``RNet`` and the function returns without training.
    Otherwise the model is trained for ``NUM_EPOCHS`` epochs; after every
    epoch ``save_checkpoint`` is called, and from the second epoch onwards
    ``check_accuracy`` is run on the validation loader.

    Returns:
        None.
    """
    # Training pipeline: resize, random rotation/flips, scale pixels by 1/255
    # (mean 0, std 1), then convert to a tensor.
    train_transform = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    # Validation pipeline: same resize/normalisation, no random augmentation.
    val_transforms = A.Compose(
        [
            A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
    )

    model = RNet(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transforms,
        NUM_WORKERS,
        PIN_MEMORY,
    )

    if LOAD_MODEL:
        # map_location remaps stored tensors onto the active device, so a
        # checkpoint saved on a GPU can still be loaded on a CPU-only host.
        load_checkpoint(
            torch.load("model/checkpoint.pth", map_location=DEVICE),
            model,
        )
        return

    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        print(f"Epoch: {epoch+1}/{NUM_EPOCHS}")
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

        # Save model
        # NOTE(review): this dict is built but `model` is what gets passed to
        # save_checkpoint below - confirm against utils.save_checkpoint which
        # of the two it actually expects.
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(model)

        # Check accuracy (skipped after the very first epoch).
        if epoch > 0:
            check_accuracy(val_loader, model, device=DEVICE)

In [None]:
# Start training
# Entry-point guard: main() is called only when this code runs as the
# top-level program (e.g. a notebook cell or script), not when imported.
if __name__ == "__main__":
    main()