In [76]:
import pandas as pd
import numpy as np
import os
import json

# local libs
import utils
import engine
import mask_rcnn_dataset

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision.io import read_image
from torchvision.ops.boxes import masks_to_boxes
from torchvision import tv_tensors
from torchvision.transforms.v2 import functional as F
from torchvision.models.detection.mask_rcnn import MaskRCNN
from torchvision.transforms import v2 as T
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.detection.backbone_utils import LastLevelMaxPool
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork

In [77]:
# Device used to load the checkpoint and to train the model.
# - 'mps' is avoided because of https://github.com/pytorch/pytorch/issues/78915
# - CUDA is not selected automatically; to use a GPU when one is present, set
#   TORCH_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
TORCH_DEVICE = 'cpu'

In [79]:
with open('Zoobot-backbone-transfer_config.json', 'r') as f:
    config = json.load(f)

# Later cells build paths as config['log_dir'] + '<name>', so guarantee a trailing separator.
config['log_dir'] = os.path.join(config['log_dir'], '')

# Output directory for logs and checkpoints, plus its 'logs_eval/' and 'logs_train/' sub-directories.
# makedirs(exist_ok=True) also creates missing parent directories and tolerates existing ones.
log_dir_existed = os.path.exists(config['log_dir'])
if not log_dir_existed:
    print('WARNING - Output directory for logs and checkpoints DOES NOT exist.')

for sub_dir in ('logs_eval/', 'logs_train/'):
    os.makedirs(config['log_dir'] + sub_dir, exist_ok=True)

if log_dir_existed:
    print('OK - Output directory for logs and checkpoints exists.')
else:
    print('OK - Output directory for logs and checkpoints created.')

# Inputs that must already exist: (config key, description used in the status messages)
required_inputs = [
    ('pretrained_ckpt', 'Pre-trained model checkpoint'),
    ('data_table', 'Parquet-file with training data'),
    ('image_dir', 'Directory with image data'),
]

missing_inputs = []
for key, description in required_inputs:
    if os.path.exists(config[key]):
        print(f'OK - {description} exists.')
    else:
        print(f'ERROR - {description} is MISSING.')
        missing_inputs.append(f'{key}={config[key]!r}')

# Stop here (after reporting every problem) instead of failing later inside torch.load / read_parquet.
if missing_inputs:
    raise FileNotFoundError('Missing required inputs: ' + ', '.join(missing_inputs))

OK - Output directory for logs and checkpoints exists.
OK - Pre-trained model checkpoint exists.
OK - Parquet-file with training data exists.
OK - Directory with image data exists.


In [80]:
def get_transform(train):
    """
    Build the torchvision v2 transform pipeline for the galaxy datasets.

    Args:
      train (bool): if True, add random horizontal and vertical flips
        (each with probability 0.5) for data augmentation

    Returns:
      torchvision.transforms.v2.Compose pipeline
    """
    transforms = []

    if train:
        # NOTE(review): v2 flips also flip tv_tensors BoundingBoxes/Mask in the target;
        # presumably MaskGalaxyDataset wraps its targets that way — verify.
        transforms.append(T.RandomHorizontalFlip(0.5))
        transforms.append(T.RandomVerticalFlip(0.5))

    transforms.append(T.PILToTensor())
    # Convert the image to float and rescale its value range (e.g. uint8 [0, 255] -> [0, 1]).
    # With a single dtype, ToDtype only converts images/videos, so masks keep their dtype.
    # Replaces the deprecated v2.ConvertImageDtype(torch.float).
    transforms.append(T.ToDtype(torch.float, scale=True))

    return T.Compose(transforms)

In [81]:
def get_dataloader_dict(train_df, val_df, image_dir, batch_size, num_workers=0):
    """
    Build the training and validation DataLoaders for Mask R-CNN.

    Args:
      train_df (pandas.DataFrame): catalogue rows used for training
      val_df (pandas.DataFrame): catalogue rows used for validation
      image_dir (str): directory with the image data, passed to MaskGalaxyDataset
      batch_size (int): number of samples per batch
      num_workers (int): DataLoader worker processes (0 = load in the main process)

    Returns:
      dict with keys 'train' and 'val' mapping to torch.utils.data.DataLoader objects
    """
    loaders = {}
    for split, dataframe in (('train', train_df), ('val', val_df)):
        is_train = split == 'train'
        dataset = mask_rcnn_dataset.MaskGalaxyDataset(
            dataframe=dataframe,
            image_dir=image_dir,
            # random flips only for the training split
            transforms=get_transform(train=is_train)
        )
        loaders[split] = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            # Shuffle only the training data; a fixed validation order keeps evaluation reproducible.
            shuffle=is_train,
            num_workers=num_workers,
            pin_memory=torch.cuda.is_available(),
            # NOTE(review): presumably keeps per-image (image, target) pairs unstacked,
            # as detection targets differ in size — verify in utils.collate_fn.
            collate_fn=utils.collate_fn
        )
    return loaders

In [82]:
# Mask R-CNN backbone: Zoobot ResNet-50 (3-channel input) followed by a Feature Pyramid Network
class Resnet50WithFPN(torch.nn.Module):
    """
    ResNet-50 feature extractor initialised from a Zoobot checkpoint, followed by an FPN.

    Exposes ``out_channels`` (FPN output width), which torchvision's MaskRCNN requires
    from a custom backbone.

    Args:
      ckpt (str): path to the Zoobot checkpoint; its 'state_dict' entry is loaded
        (strictly) into a torchvision ResNet-50 with a 281-output classification head
    """

    def __init__(self, ckpt):
        super().__init__()

        self.ckpt = ckpt
        m = torchvision.models.resnet50(num_classes=281)
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        checkpoint = torch.load(self.ckpt, map_location=torch.device(TORCH_DEVICE))
        # Map Zoobot key names onto torchvision ResNet-50 names:
        # strip 'encoder.' and rename the classifier 'head.1.0.' -> 'fc.'
        state_dict = {
            key.replace('encoder.', '').replace('head.1.0.', 'fc.'): value
            for key, value in checkpoint['state_dict'].items()
        }
        m.load_state_dict(state_dict)

        # Extract the 4 main stages (note: FRCNN needs return nodes named '0'..'3')
        self.body = create_feature_extractor(
            m,
            return_nodes={f'layer{i}': str(i - 1) for i in range(1, 5)}
        )

        # Dry run to get the number of channels of each stage for the FPN.
        # Run in eval mode: in training mode BatchNorm updates its running statistics
        # even under no_grad, which would overwrite the pre-trained values with noise stats.
        was_training = self.body.training
        self.body.eval()
        with torch.no_grad():
            out = self.body(torch.randn(2, 3, 224, 224))
        self.body.train(was_training)
        in_channels_list = [o.shape[1] for o in out.values()]

        # NOTE(review): the BatchNorm layers are regular nn.BatchNorm2d (resnet50 default), so
        # their running statistics still change in train mode even when get_model freezes the
        # backbone parameters; torchvision's own detection backbones use FrozenBatchNorm2d.

        # Build FPN
        self.out_channels = 256
        self.fpn = FeaturePyramidNetwork(
            in_channels_list, out_channels=self.out_channels,
            extra_blocks=LastLevelMaxPool())

    def forward(self, x):
        """Return the FPN feature maps computed from the backbone stages of ``x``."""
        x = self.body(x)
        x = self.fpn(x)
        return x

In [83]:
def get_model(
        ckpt,
        num_classes=3,
        trainable_layers=0,
        image_mean=(0.485, 0.456, 0.406),
        image_std=(0.229, 0.224, 0.225),
        min_size=800,
        max_size=1333,
    ):
    """
    Creates the Mask R-CNN model with a Zoobot ResNet-50 + FPN backbone

    Args:
      ckpt (str): path to checkpoint for the Zoobot backbone
      num_classes (int): number of classes the detector should output,
        must include a class for the background
      trainable_layers (int): number of blocks of the classification backbone,
        counted from top (layer4, layer3, layer2, layer1, conv1), that should be made trainable
        e.g. 0 - all blocks fixed, 5 - all blocks incl. 'backbone.body.conv1' and
        'backbone.body.bn1' trainable
      image_mean (sequence of float): per-channel mean used by the model's
        GeneralizedRCNNTransform to normalise input images
      image_std (sequence of float): per-channel standard deviation used by the model's
        GeneralizedRCNNTransform to normalise input images
      min_size (int): min_size passed to the model's GeneralizedRCNNTransform (resizing)
      max_size (int): max_size passed to the model's GeneralizedRCNNTransform (resizing)

    Returns:
      torchvision MaskRCNN model

    Raises:
      ValueError: if trainable_layers is outside [0, 5]
    """
    # Validate first so a bad value fails fast, before the checkpoint is loaded.
    if trainable_layers < 0 or trainable_layers > 5:
        raise ValueError(f"Trainable layers should be in the range [0, 5], got {trainable_layers}")

    # Build the model; MaskRCNN creates its GeneralizedRCNNTransform from these arguments,
    # so custom mean/stddev of image values are applied without replacing the transform.
    model = MaskRCNN(
        Resnet50WithFPN(ckpt=ckpt),
        num_classes=num_classes,
        min_size=min_size,
        max_size=max_size,
        image_mean=image_mean,
        image_std=image_std
    )

    # Backbone parts to unfreeze, ordered from the top of the ResNet down to the stem
    layers_to_train = [
        'backbone.body.layer4',
        'backbone.body.layer3',
        'backbone.body.layer2',
        'backbone.body.layer1',
        'backbone.body.conv1'
    ][:trainable_layers]

    if trainable_layers == 5:
        # the stem's BatchNorm is trained together with conv1
        layers_to_train.append('backbone.body.bn1')

    trainable_prefixes = tuple(layers_to_train)

    # Freeze the backbone body except the selected parts; parameters outside
    # 'backbone.body.' (FPN, RPN, ROI heads) are left untouched.
    for name, parameter in model.named_parameters():
        if name.startswith('backbone.body.'):
            parameter.requires_grad_(name.startswith(trainable_prefixes))

    return model

In [84]:
# Galaxy catalogue (Parquet, read with pyarrow); its 'train_group' column
# selects the training / validation rows for the data loaders.
df = pd.read_parquet(path=config['data_table'], engine='pyarrow')

In [69]:
# [START Training]
torch.cuda.empty_cache() # release cached, unused GPU memory

# load data: split the catalogue by its 'train_group' column
dataloader_dict = get_dataloader_dict(
    train_df=df[(df['train_group']=='training')],
    val_df=df[(df['train_group']=='validation')],
    image_dir=config['image_dir'],
    batch_size=config['batch_size']
)

# initialise Tensorboard writer (os.path.join works with or without a trailing separator)
tb_log_dir = os.path.join(config['log_dir'], 'logs_train', '')
writer = SummaryWriter(log_dir=tb_log_dir)

# get the model
zoobot = get_model(
    ckpt=config['pretrained_ckpt'],
    num_classes=config['num_classes'],
    trainable_layers=config['trainable_layers'],
    image_mean=config['image_mean'],
    image_std=config['image_std'],
)

# move model to the right device and use all available GPUs
# NOTE(review): without CUDA, DataParallel just calls the wrapped model, but it still prefixes
# every state_dict key with 'module.' — this affects the checkpoints saved below.
# NOTE(review): with >1 GPU the per-replica loss dicts are gathered into vectors; confirm
# engine.train_one_epoch reduces them before calling backward().
zoobot = nn.DataParallel(zoobot)
zoobot.to(TORCH_DEVICE)

# construct an optimizer over the parameters left trainable by get_model
params = [p for p in zoobot.parameters() if p.requires_grad]
# No LR scheduler is used with Adam. SGD alternative:
# torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)
# combined with torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1).
optimizer = torch.optim.Adam(params, lr=0.0001, weight_decay=0.00005)

for epoch in range(config['epochs']):
    # train for one epoch, printing every 10 iterations
    engine.train_one_epoch(
        zoobot,
        optimizer,
        dataloader_dict['train'],
        device=TORCH_DEVICE,
        epoch=epoch,
        print_freq=10,
        scaler=None,
        tb_writer=writer
    )

    # losses on the validation dataset
    engine.evaluate_loss(
        zoobot,
        dataloader_dict['val'],
        device=TORCH_DEVICE,
        epoch=epoch,
        tb_writer=writer
    )

    # evaluate on the validation dataset (this cell has no separate test split)
    engine.evaluate(
        zoobot,
        dataloader_dict['val'],
        device=TORCH_DEVICE,
        epoch=epoch,
        tb_writer=writer
    )

    # save model for each epoch
    model_save_path = os.path.join(config['log_dir'], 'MaskRCNN_Zoobot_epoch_{}.pth'.format(epoch))
    torch.save(zoobot.state_dict(), model_save_path)

# make sure the last TensorBoard events are written to disk
writer.flush()

# [END Training]

Epoch: [0]  [  0/245]  eta: 0:57:06  lr: 0.000001  loss: 5.8063 (5.8063)  loss_classifier: 1.1606 (1.1606)  loss_box_reg: 0.0134 (0.0134)  loss_mask: 3.9387 (3.9387)  loss_objectness: 0.6850 (0.6850)  loss_rpn_box_reg: 0.0086 (0.0086)  time: 13.9860  data: 0.0467
Epoch: [0]  [ 10/245]  eta: 0:49:16  lr: 0.000005  loss: 5.4770 (5.2499)  loss_classifier: 1.0153 (0.9660)  loss_box_reg: 0.0092 (0.0085)  loss_mask: 3.8018 (3.5862)  loss_objectness: 0.6812 (0.6801)  loss_rpn_box_reg: 0.0086 (0.0090)  time: 12.5800  data: 0.0557
Epoch: [0]  [ 20/245]  eta: 0:43:51  lr: 0.000009  loss: 3.6999 (4.3101)  loss_classifier: 0.5545 (0.6773)  loss_box_reg: 0.0097 (0.0109)  loss_mask: 2.5409 (2.9423)  loss_objectness: 0.6701 (0.6681)  loss_rpn_box_reg: 0.0097 (0.0115)  time: 11.5830  data: 0.0509
Epoch: [0]  [ 30/245]  eta: 0:40:36  lr: 0.000013  loss: 2.5536 (3.6099)  loss_classifier: 0.2433 (0.5192)  loss_box_reg: 0.0189 (0.0173)  loss_mask: 1.7367 (2.4104)  loss_objectness: 0.6311 (0.6489)  loss_rp