In [1]:
import numpy as np
import pandas as pd
from dataclasses import dataclass
import os
import torch
import torchvision.models
from torch import nn
from torch.utils.data import Dataset as TorchDataset, DataLoader
import torch.optim as TorchOptimizers
import torchvision.transforms.v2 as T
from torchvision.models import ResNet50_Weights
from torchvision.models.segmentation import deeplabv3_resnet50
from torchinfo import summary as torch_summary
import matplotlib.pyplot as plt
from PIL import Image, ImageFile
import math
import time
import wandb
from sklearn.model_selection import train_test_split
import json
from typing import Callable

In [2]:
@dataclass
class Config:
    """Paths, device and hyperparameters for one execution environment.

    The dataclass fields differ per environment (local machine, Kaggle, ...).
    Values shared by every environment are set in ``__post_init__``; derived
    tensors and the image transform pipeline are created by ``init``.
    """
    train_csv_filepath: str
    test_csv_filepath: str
    submission_filepath: str
    images_root_folder: str
    image_masks_root_folder: str
    training_output_folder: str
    saved_weights_filepath: str
    device: str

    def __post_init__(self):
        """ For configuration variables that are shared across environments """
        self.num_classes = 12
        self.batch_size = 32
        self.starting_learning_rate = 4e-4
        self.max_epochs = 100
        self.patience = 20
        self.seed = 1234
        self.num_workers = 4
        # Match every CUDA device string ('cuda', 'cuda:0', ...), not only the bare 'cuda'.
        self.pin_memory = self.num_workers > 0 and str(self.device).startswith('cuda')

    # noinspection PyAttributeOutsideInit
    def init(self, training):
        """Adjust configuration setup for training vs inference.

        Args:
            training: truthy when preparing a training run; in that case the
                training output folder is created if it does not exist.

        Side effects:
            Sets ``self.training``, the ImageNet mean/std tensors (CPU copies and
            copies on ``self.device``, each also reshaped to (3, 1, 1)) and
            ``self.image_transforms``.
        """
        self.training = training

        if self.training:
            os.makedirs(self.training_output_folder, exist_ok=True)

        self.imagenet_mean_cpu_tensor = torch.tensor(imagenet_mean_array)
        self.imagenet_std_cpu_tensor = torch.tensor(imagenet_std_array)
        self.channelwise_imagenet_mean_cpu_tensor = self.imagenet_mean_cpu_tensor.view(3, 1, 1)
        self.channelwise_imagenet_std_cpu_tensor = self.imagenet_std_cpu_tensor.view(3, 1, 1)
        # Built on this instance's device rather than through the module-level
        # `config`, so init() works on any Config regardless of which one is
        # currently selected globally.
        self.imagenet_mean_gpu_tensor = torch.tensor(imagenet_mean_array, device=self.device)
        self.imagenet_std_gpu_tensor = torch.tensor(imagenet_std_array, device=self.device)
        self.channelwise_imagenet_mean_gpu_tensor = self.imagenet_mean_gpu_tensor.view(3, 1, 1)
        self.channelwise_imagenet_std_gpu_tensor = self.imagenet_std_gpu_tensor.view(3, 1, 1)

        # PIL image -> float32 image scaled to [0, 1] -> per-channel normalisation.
        self.image_transforms = T.Compose([
            T.ToImage(),
            T.ToDtype(torch.float32, scale=True),
            T.Normalize(self.imagenet_mean_cpu_tensor, self.imagenet_std_cpu_tensor),
        ])

# Module-level handle to the active configuration. The helper functions in this
# notebook read it at call time; it stays None until a run cell assigns one of
# the environment configs.
config: Config = None
""" Set to environment-relevant config before training/inference """;

In [3]:
# Configuration for a local checkout: relative paths, CPU device.
local_config = Config(
    device='cpu',
    train_csv_filepath='data/train.csv',
    test_csv_filepath='data/test.csv',
    images_root_folder='data/images/',
    image_masks_root_folder='data/masks/',
    submission_filepath='data/submission.csv',
    training_output_folder='data_gen/training_output/',
    saved_weights_filepath='data_gen/training_output/model_weights.pth',
)

# Configuration for a Kaggle kernel: inputs under /kaggle/input, outputs under
# /kaggle/working, CUDA device.
kaggle_config = Config(
    device='cuda',
    train_csv_filepath='/kaggle/input/opencv-pytorch-segmentation-project-round2/train.csv',
    test_csv_filepath='/kaggle/input/opencv-pytorch-segmentation-project-round2/test.csv',
    images_root_folder='/kaggle/input/opencv-pytorch-segmentation-project-round2/imgs/imgs/',
    image_masks_root_folder='/kaggle/input/opencv-pytorch-segmentation-project-round2/masks/masks/',
    submission_filepath='/kaggle/working/submission.csv',
    training_output_folder='/kaggle/working/training_output/',
    saved_weights_filepath='/kaggle/working/training_output/model_weights.pth',
)

In [4]:
# Per-channel mean / standard deviation used to normalise (and de-normalise)
# 3-channel images, stored as float32 arrays of length 3.
imagenet_mean_array = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
imagenet_std_array = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)

def gpu_tensor(numpy_array):
    """Copy ``numpy_array`` into a new tensor placed on ``config.device``."""
    target_device = config.device
    return torch.tensor(numpy_array, device=target_device)

def visualize_image(image_tensor):
    """Display a normalised (3, H, W) image tensor with matplotlib.

    The tensor must live on the same device as the config's channel-wise
    "gpu" mean/std tensors, because it is de-normalised with them before
    being moved to the CPU.
    """
    channel_mean = config.channelwise_imagenet_mean_gpu_tensor
    channel_std = config.channelwise_imagenet_std_gpu_tensor

    # Undo the normalisation and keep values inside the displayable [0, 1] range.
    restored = denormalize(image_tensor, channel_mean, channel_std).clamp(0, 1)

    # Channels-first -> channels-last, then to 8-bit pixel values.
    pixels = restored.permute(1, 2, 0).cpu().numpy()
    pixels = (pixels * 255).astype('uint8')

    plt.imshow(pixels)
    plt.axis('off')
    plt.show()

def visualize_mask(mask_tensor):
    """Display a channels-first (C, H, W) mask tensor with matplotlib.

    Mask values are multiplied by ``255 / config.num_classes`` and shown as
    8-bit intensities. The tensor may live on any device; it is copied to the
    CPU first.
    """
    intensity_per_class = 255 / config.num_classes
    channels_last = mask_tensor.permute(1, 2, 0).cpu().numpy()
    pixels = (channels_last * intensity_per_class).astype('uint8')

    plt.imshow(pixels)
    plt.axis('off')
    plt.show()

def normalize(tensor, mean, std):
    """Return ``(tensor - mean) / std``; operands must broadcast against each other."""
    centered = tensor - mean
    return centered / std

def denormalize(tensor, mean, std):
    """Return ``tensor * std + mean``, the inverse of ``normalize``."""
    rescaled = tensor * std
    return rescaled + mean

def load_pil_image_from_id(image_id) -> ImageFile.ImageFile:
    """Open ``<config.images_root_folder><image_id>.jpg`` with PIL.

    Args:
        image_id: identifier of the image; formatted into the file name, so
            non-string ids (for example integers read from a CSV column) work.

    Returns:
        The opened PIL image. The caller is responsible for closing it.
    """
    # f-string formatting instead of `+` so a non-str id does not raise TypeError.
    return Image.open(f'{config.images_root_folder}{image_id}.jpg')

In [5]:
# eq=False keeps identity-based equality and hashing: the generated __eq__
# would compare numpy arrays (ambiguous truth value) and make instances unhashable.
@dataclass(eq=False)
class ImageSegmentationDataset(TorchDataset):
    """Dataset yielding ``(image, mask)`` pairs for semantic segmentation.

    Images are read from ``<config.images_root_folder><id>.jpg`` and masks from
    ``<config.image_masks_root_folder><id>.png``.

    Attributes are documented inline below.
    """
    # Identifiers of the samples; one (image, mask) pair per entry.
    image_ids: np.ndarray
    # Callable applied to the RGB PIL image only (never to the mask).
    image_transforms: Callable

    def __len__(self):
        """Number of samples."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A tuple ``(image, mask)`` where ``image`` is the output of
            ``image_transforms`` and ``mask`` is an int64 tensor of shape
            (1, H, W) holding the raw pixel values of the mask file.
        """
        image_id = self.image_ids[idx]
        image_path = f'{config.images_root_folder}{image_id}.jpg'
        mask_path = f'{config.image_masks_root_folder}{image_id}.png'

        # Context managers close the file handles; convert() forces the pixel
        # data to be read and guarantees three channels for the normalisation.
        with Image.open(image_path) as image_file:
            image = self.image_transforms(image_file.convert('RGB'))

        with Image.open(mask_path) as mask_file:
            mask_array = np.array(mask_file)

        # NOTE(review): assumes the PNG stores one class id per pixel. If the
        # file has several channels, only the first is used — confirm against
        # the dataset.
        if mask_array.ndim == 3:
            mask_array = mask_array[..., 0]

        # The mask must keep its integer class ids: it is deliberately NOT
        # passed through image_transforms, which would rescale and normalise it.
        # NOTE(review): geometric augmentations added to image_transforms would
        # therefore not be mirrored on the mask.
        mask = torch.as_tensor(mask_array, dtype=torch.long).unsqueeze(0)

        return image, mask

In [6]:
# test_model = deeplabv3_resnet50(weights=DeepLabV3_ResNet50_Weights.DEFAULT)
# torch_summary(test_model)

In [7]:
# test_model

In [8]:
# aux_classifier_children = test_model.aux_classifier.children()
# for layer in aux_classifier_children:
#     print(layer)

In [9]:
def create_deeplab_v3_model() -> nn.Module:
    """Build a DeepLabV3 / ResNet-50 model with ``config.num_classes`` outputs.

    The backbone parameters are frozen, the classifier and auxiliary
    classifier parameters are trainable, and the model is moved to
    ``config.device`` before being returned.
    """
    # Pretrained backbone weights are requested only for training runs;
    # otherwise the backbone is created without weights.
    if config.training:
        backbone_weights = ResNet50_Weights.DEFAULT
    else:
        backbone_weights = None

    network = deeplabv3_resnet50(
        weights=None,
        weights_backbone=backbone_weights,
        num_classes=config.num_classes,
        aux_loss=True,
    )

    # (submodule, trainable) pairs, applied in the same order as listed.
    trainability = (
        (network.backbone, False),
        (network.classifier, True),
        (network.aux_classifier, True),
    )
    for submodule, trainable in trainability:
        for parameter in submodule.parameters():
            parameter.requires_grad = trainable

    network.to(config.device)

    return network

In [10]:
def dice_coefficient_score(pred_masks, true_masks):
    """Per-sample Dice coefficient ``2 * |A ∩ B| / (|A| + |B|)``.

    Args:
        pred_masks: tensor of shape (batch, height, width) with 0/1 (or bool,
            or soft probability) values.
        true_masks: tensor of the same shape as ``pred_masks``.

    Returns:
        Float tensor of shape (batch,). A sample whose two masks are both empty
        scores 1.0 instead of NaN.
    """
    if not torch.is_floating_point(pred_masks):
        pred_masks = pred_masks.float()
    if not torch.is_floating_point(true_masks):
        true_masks = true_masks.float()

    intersection = (pred_masks * true_masks).sum(dim=(1, 2))
    total = pred_masks.sum(dim=(1, 2)) + true_masks.sum(dim=(1, 2))

    # Avoid 0 / 0: two empty masks agree perfectly.
    both_empty = total == 0
    safe_total = torch.where(both_empty, torch.ones_like(total), total)
    score = (2 * intersection) / safe_total
    return torch.where(both_empty, torch.ones_like(score), score)


def _extract_logits(model_output):
    """Return the main prediction of a model that may output a dict.

    torchvision segmentation models return a mapping with key 'out' (and 'aux'
    when built with an auxiliary head); plain tensors are passed through.
    NOTE(review): the auxiliary output is ignored, so the auxiliary head gets
    no gradient — add a weighted term to the loss if it should be trained.
    """
    if isinstance(model_output, dict):
        return model_output['out']
    return model_output


def _split_targets(masks, logits):
    """Return ``(loss_target, class_index_target)`` for a batch of masks.

    Masks shaped like the logits are treated as dense per-class targets and
    passed to the loss unchanged. Otherwise masks are treated as class ids of
    shape (N, H, W) or (N, 1, H, W) and converted to int64 (N, H, W).
    """
    if masks.shape == logits.shape:
        return masks, masks.argmax(dim=1)
    class_indices = masks.squeeze(1) if masks.ndim == 4 and masks.size(1) == 1 else masks
    class_indices = class_indices.long()
    return class_indices, class_indices


def _batch_dice_sum(logits, class_targets):
    """Sum over the batch of the per-sample Dice score averaged over classes.

    NOTE(review): a class absent from both prediction and target counts as a
    perfect 1.0 for that sample — confirm this matches the target metric.
    """
    predicted = logits.argmax(dim=1)
    per_class = [
        dice_coefficient_score(predicted == class_id, class_targets == class_id)
        for class_id in range(logits.size(1))
    ]
    return torch.stack(per_class, dim=0).mean(dim=0).sum().item()


def _print_cuda_memory_usage():
    """Print CUDA memory statistics; does nothing when not running on CUDA."""
    device = torch.device(config.device)
    if device.type != 'cuda' or not torch.cuda.is_available():
        return
    allocated = torch.cuda.memory_allocated(device) / 1024**3
    reserved = torch.cuda.memory_reserved(device) / 1024**3
    print(f'Memory allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB')


def _run_epoch(phase, start_time, model, loader, loss_function, optimizer=None):
    """Run one pass over ``loader``; optimise when ``optimizer`` is given.

    Loss and score are accumulated batch by batch as Python floats, so no
    prediction maps are kept in memory beyond the current batch.

    Returns:
        ``(mean_loss, mean_score)``, both averaged over the samples seen.
    """
    is_training = optimizer is not None
    running_loss = 0.0
    running_score = 0.0
    num_samples = 0

    num_batches = len(loader)
    for batch_number, (x, y) in enumerate(loader):
        print(f't={time.time() - start_time:.2f}: Loading {phase} batch {batch_number + 1}/{num_batches}')

        x = x.to(config.device, non_blocking=True)
        y = y.to(config.device, non_blocking=True)

        if batch_number == 0:
            if is_training:
                _print_cuda_memory_usage()
            print('First image:')
            visualize_image(x[0])
            print('First mask:')
            visualize_mask(y[0])

        if is_training:
            optimizer.zero_grad()

        logits = _extract_logits(model(x))
        loss_targets, class_targets = _split_targets(y, logits)
        loss = loss_function(logits, loss_targets)

        if is_training:
            loss.backward()
            optimizer.step()

        batch_size = x.size(0)
        running_loss += loss.item() * batch_size
        with torch.no_grad():
            running_score += _batch_dice_sum(logits, class_targets)
        num_samples += batch_size

    if num_samples == 0:
        raise ValueError(f'The {phase} loader produced no samples')

    return running_loss / num_samples, running_score / num_samples


def train_one_epoch(start_time, model, loader, optimizer, loss_function):
    """Train ``model`` for one epoch.

    Args:
        start_time: reference timestamp (``time.time()``) for progress messages.
        model: segmentation model returning logits of shape (N, C, H, W), either
            directly or under the 'out' key of a dict.
        loader: iterable of ``(images, masks)`` batches.
        optimizer: optimizer stepping the model parameters.
        loss_function: callable ``(logits, targets) -> scalar loss``. Targets are
            int64 class ids of shape (N, H, W), unless the masks already have
            the shape of the logits, in which case they are passed unchanged.

    Returns:
        ``(epoch_loss, epoch_score)`` as Python floats: mean loss per sample and
        mean class-averaged Dice score per sample.
    """
    model.train()
    return _run_epoch('training', start_time, model, loader, loss_function, optimizer)


@torch.no_grad()
def validate_one_epoch(start_time, model, loader, loss_function):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Arguments and return value match ``train_one_epoch`` (minus the optimizer).
    """
    model.eval()
    return _run_epoch('validation', start_time, model, loader, loss_function)

In [11]:
def _plot_score_history(history):
    """Plot the per-epoch train/validation scores and return the figure."""
    train_scores = history['train_score']
    val_scores = history['val_score']
    epochs = list(range(1, len(train_scores) + 1))

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(epochs, train_scores, label='train_score', marker='o')
    ax.plot(epochs, val_scores, label='val_score', marker='o')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Dice score')
    ax.set_title('Train vs Validation Score per Epoch')
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    return fig


def _save_checkpoint(model, state_dict, filename):
    """Save ``state_dict`` under the training output folder.

    Returns the written path, or None when no state dict was recorded.
    """
    if state_dict is None:
        print(f'No weights were recorded for {filename}; nothing saved')
        return None
    path = os.path.join(config.training_output_folder, filename)
    model.load_state_dict(state_dict)
    torch.save(model.state_dict(), path)
    return path


def train():
    """Train the segmentation model with early stopping and log to Weights & Biases.

    Reads the image ids from ``config.train_csv_filepath`` (column 'ImageID'),
    holds out 20% for validation, and trains for at most ``config.max_epochs``
    epochs. Training stops once the validation score has not improved for
    ``config.patience`` epochs.

    Side effects:
        Writes 'best_model_weights.pth' (best validation score),
        'best_loss_model_weights.pth' (best validation loss) and 'history.json'
        to ``config.training_output_folder`` and uploads them with wandb.
    """
    config.init(training=True)

    start_time = time.time()
    print('t=0: Starting data prep and model loading')

    # Seed before the model is built so the randomly initialised heads, and not
    # only the batch order, are reproducible.
    torch.manual_seed(config.seed)

    run_settings = {
        'batch_size': config.batch_size,
        'learning_rate': config.starting_learning_rate,
        'max_epochs': config.max_epochs,
        'seed': config.seed,
        'model': 'deeplabv3_resnet50',
        'optimizer': 'Adam',
        'loss': 'CrossEntropyLoss',
    }

    # The context manager finishes the run even when training raises.
    with wandb.init(
        project='drone_image_segmentation',
        name=f'run={int(start_time)}',
        config=run_settings,
    ) as run:
        train_df = pd.read_csv(config.train_csv_filepath)
        train_array = train_df['ImageID'].to_numpy()
        train_ids, val_ids = train_test_split(train_array, test_size=0.2, random_state=config.seed)

        model = create_deeplab_v3_model()

        wandb.watch(model, log='gradients', log_freq=100)

        train_dataset = ImageSegmentationDataset(train_ids, config.image_transforms)
        val_dataset = ImageSegmentationDataset(val_ids, config.image_transforms)

        def loader(ds, shuffle):
            return DataLoader(
                ds,
                shuffle=shuffle,
                batch_size=config.batch_size,
                num_workers=config.num_workers,
                pin_memory=config.pin_memory,
            )

        train_loader = loader(train_dataset, shuffle=True)
        val_loader = loader(val_dataset, shuffle=False)

        # Mutually exclusive classes: cross-entropy over the class dimension,
        # with integer class-id masks as targets.
        loss_function = nn.CrossEntropyLoss()
        # Frozen parameters never receive gradients; keep them out of the optimizer.
        trainable_parameters = [p for p in model.parameters() if p.requires_grad]
        optimizer = TorchOptimizers.Adam(trainable_parameters, lr=config.starting_learning_rate)

        best_val_loss = float('inf')
        best_val_loss_epoch = -1
        best_val_score = float('-inf')
        best_val_score_epoch = -1

        history = dict(
            train_loss=[],
            val_loss=[],
            train_score=[],
            val_score=[],
            best_val_score_epoch=dict(),
            best_val_loss_epoch=dict(),
        )

        training_start_time = time.time()
        print(f't={training_start_time - start_time:.2f}: Starting training')

        best_state_dict = None
        best_loss_state_dict = None
        epochs_since_best = 0

        for epoch in range(1, config.max_epochs + 1):
            epoch_start_time = time.time()
            print(f't={epoch_start_time - start_time:.2f}: Starting epoch {epoch}')
            train_loss, train_score = train_one_epoch(start_time, model, train_loader, optimizer, loss_function)
            val_loss, val_score = validate_one_epoch(start_time, model, val_loader, loss_function)

            history['train_loss'].append(train_loss)
            history['val_loss'].append(val_loss)
            history['train_score'].append(train_score)
            history['val_score'].append(val_score)

            print(f'================ Epoch {epoch:03d} stats ==================')
            print(f'train_loss: {train_loss:.4f}  val_loss: {val_loss:.4f}')
            print(f'train_score: {train_score:.4f}  val_score: {val_score:.4f}')
            print('===================================================')

            wandb.log(
                {
                    'epoch': epoch,
                    'train_loss': train_loss,
                    'val_loss': val_loss,
                    'train_score': train_score,
                    'val_score': val_score,
                }
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_val_loss_epoch = epoch
                # CPU copies so the snapshots do not occupy device memory.
                best_loss_state_dict = {k: v.cpu().clone() for k, v in model.state_dict().items()}

            if val_score > best_val_score:
                best_val_score = val_score
                best_val_score_epoch = epoch
                epochs_since_best = 0
                best_state_dict = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            else:
                epochs_since_best += 1
                if epochs_since_best >= config.patience:
                    print(f'No validation score improvement for {epochs_since_best} epochs; stopping early')
                    break

        history['best_val_score_epoch']['epoch'] = best_val_score_epoch
        history['best_val_score_epoch']['val_score'] = best_val_score
        history['best_val_loss_epoch']['epoch'] = best_val_loss_epoch
        history['best_val_loss_epoch']['val_loss'] = best_val_loss

        print()
        print('==================== Results ======================')
        print(f'Best val score epoch: {best_val_score_epoch}')
        print(f'Best val score: {best_val_score:.4f}')
        print(f'Best val loss epoch: {best_val_loss_epoch}')
        print(f'Best val loss: {best_val_loss:.2f}')
        print('===================================================')
        print()

        run.summary['best_val_score'] = best_val_score
        run.summary['best_val_score_epoch'] = best_val_score_epoch
        run.summary['best_val_loss'] = best_val_loss
        run.summary['best_val_loss_epoch'] = best_val_loss_epoch

        score_figure = _plot_score_history(history)
        wandb.log({"score_curve": wandb.Image(score_figure)})
        plt.show()

        # A state dict is None only if its metric never improved (e.g. NaN every
        # epoch); _save_checkpoint skips it instead of crashing.
        saved_paths = [
            _save_checkpoint(model, best_state_dict, 'best_model_weights.pth'),
            _save_checkpoint(model, best_loss_state_dict, 'best_loss_model_weights.pth'),
        ]

        history_path = os.path.join(config.training_output_folder, 'history.json')
        with open(history_path, 'w') as json_file:
            json.dump(history, json_file, indent=4)
        saved_paths.append(history_path)

        for saved_path in saved_paths:
            if saved_path is not None:
                wandb.save(saved_path)

In [None]:
# from kaggle_secrets import UserSecretsClient
# user_secrets = UserSecretsClient()
# wandb_key = user_secrets.get_secret("wandb_key")
# !wandb login $wandb_key

In [12]:
# Select the local environment configuration (relative paths, CPU device) as the
# active module-level config, then start training.
config = local_config
train()