In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from torch.utils.tensorboard import SummaryWriter

from typing import Tuple, Optional
from abc import ABC, abstractmethod
import io, datetime, hashlib
import gc
import warnings

from luna16_dsets import ClassifierDataset, SegmenterDataset
from luna16_model import Classifier, Segmenter
from luna16_util import *

In [None]:
# TODO:
# - Find a better way to calculate the ground-truth masks.
# - Experiment with different loss functions.
# - [bug] The segmenter saves the last epoch instead of the best-scoring epoch.
# - [bug] Memory overflows the first time the segmenter is kicked off for training.

In [None]:
# Template Pattern Implementation
class BaseTrainer(ABC):
    """Base trainer for the LUNA16 lung nodule detection models.

    Implements the shared skeleton (setup plus the epoch loop) and defers the
    task-specific steps to hooks that subclasses implement.

    Subclasses must assign ``self.Dataset`` and ``self.Model`` before calling
    ``super().__init__``, because ``_setup`` instantiates both.

    Args:
        config: Object exposing at least ``batch_size``, ``num_workers``,
            ``batch_norm``, ``model_type`` and ``epochs`` as attributes.
    """
    def __init__(self, config):
        # NOTE: this silences every warning for the whole process.
        warnings.filterwarnings("ignore")
        self.config = config
        self._setup()

    def _init_cache(self):
        """Optional hook run before the data loaders are built; no-op by default."""
        pass

    @abstractmethod
    def _init_criterion(self):
        """Return the loss callable that ``_setup`` stores in ``self.criterion``."""
        pass

    @abstractmethod
    def _pre_epoch(self, epoch):
        """Hook called at the start of every epoch, before any training batch."""
        pass

    @abstractmethod
    def _post_epoch(self, data, epoch):
        """Hook called once per epoch after validation.

        Args:
            data: The last batch iterated in the epoch (the last validation
                batch whenever the validation loader is non-empty).
            epoch: 1-based epoch number.
        """
        pass

    @abstractmethod
    def _process_batch(self, batch, i, mode='train'):
        """Run one batch and return its loss tensor.

        Args:
            batch: Item yielded by the data loader; ``batch[0]`` holds the inputs.
            i: 0-based index of the batch within its loader.
            mode: ``'train'`` (the default) or ``'val'``.

        Returns:
            A tensor; ``train`` calls ``.sum().item()`` on it for the progress line.
        """
        pass

    def _setup(self):
        """Build device, data loaders, model, optimizer, grad scaler and criterion."""
        self.device = torch.device('cuda'
                                   if torch.cuda.is_available()
                                   else 'cpu')

        # On GPU the configured batch size is treated as per-device and scaled
        # by the number of visible devices.
        if self.device.type == 'cuda':
            self.batch_size = torch.cuda.device_count() * self.config.batch_size
        else:
            self.batch_size = self.config.batch_size

        self._init_cache()

        self.train_loader = DataLoader(
            self.Dataset(mode='train', config=self.config),
            batch_size=self.batch_size,
            num_workers=self.config.num_workers,
            pin_memory=True,
            shuffle=True
        )

        self.val_loader = DataLoader(
            self.Dataset(mode='val', config=self.config),
            batch_size=self.batch_size,
            num_workers=self.config.num_workers,
            pin_memory=True
        )

        model = self.Model(batch_norm=self.config.batch_norm)
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        self.model = model.to(self.device)

        self.optimizer = optim.Adam(self.model.parameters())
        self.scaler = GradScaler()
        self.criterion = self._init_criterion()

    def train(self):
        """Main training loop.

        Runs ``config.epochs`` epochs; each epoch is a training pass, a
        validation pass under ``torch.no_grad()`` and the ``_post_epoch`` hook.
        The TensorBoard writer is closed even when an epoch raises.
        """
        print("Training started")

        timestamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M')
        self.writer = SummaryWriter(f"runs/{self.config.model_type}_{timestamp}")

        try:
            for epoch in range(1, self.config.epochs + 1):
                # Training phase
                self.model.train()
                loss_acc = 0
                samples_acc = 0
                start_time = datetime.datetime.now().replace(microsecond=0)
                self._pre_epoch(epoch)

                for i, batch in enumerate(self.train_loader):
                    loss = self._process_batch(batch, i)
                    loss_acc += loss.sum().item()
                    samples_acc += len(batch[0])
                    elapsed = datetime.datetime.now().replace(microsecond=0) - start_time
                    log_format = (
                        f"Epoch {epoch}, {elapsed} "
                        f"{i+1}/{len(self.train_loader)}, Loss={loss_acc/samples_acc:.4f}")
                    # '\r' keeps the progress on a single, overwritten line.
                    print(log_format, end='\r')

                print()
                gc.collect()

                # Validation phase
                self.model.eval()
                with torch.no_grad():
                    for i, batch in enumerate(self.val_loader):
                        _ = self._process_batch(batch, i, mode='val')
                    gc.collect()

                # `batch` is whatever the loops above iterated last.
                self._post_epoch(batch, epoch)
        finally:
            # Flush and release the event file even if an epoch failed.
            self.writer.close()

class ClassificationTrainer(BaseTrainer):
    """Trainer that pairs ``ClassifierDataset`` with the ``Classifier`` model.

    Per-sample metrics are collected into ``metrics_train`` / ``metrics_val``
    (rows: 0 = loss, 1 = column 1 of the model's second output, 2 = label) and
    the checkpoint with the highest validation recall is kept in ``self.state``.
    """
    def __init__(self, config):
        # Must be set before super().__init__, which instantiates both.
        self.Dataset = ClassifierDataset
        self.Model = Classifier
        super().__init__(config)

    def _init_cache(self):
        """Iterate the dataset once in 'cache' mode when ``config.cache_in`` is set."""
        if self.config.cache_in:
            print("Caching the dataset")
            for _ in DataLoader(self.Dataset(mode='cache', config=self.config),
                                batch_size=self.config.batch_size,
                                num_workers=self.config.num_workers):
                pass

    def _init_criterion(self):
        """Per-sample cross entropy (``reduction='none'``) so each loss can be logged."""
        return nn.CrossEntropyLoss(reduction='none')

    def _pre_epoch(self, epoch):
        """Apply balanced decay (if enabled) and allocate fresh metric buffers."""
        if self.config.enable_balanced_decay:
            self.train_loader.dataset.balanced_decay(epoch)

        self.metrics_train = torch.zeros((self.config.n_metrics, len(self.train_loader.dataset)))
        self.metrics_val = torch.zeros((self.config.n_metrics, len(self.val_loader.dataset)))
        # Reset the best score only at the start of a run; resetting it every
        # epoch would make any positive recall count as a new "best".
        if epoch == 1 or not hasattr(self, 'best_score'):
            self.best_score = 0
            self.state = None

    def _snapshot_weights(self):
        """Return a CPU copy of the unwrapped model's state dict.

        ``state_dict()`` tensors share storage with the live parameters, so
        without a copy the stored "best" weights would keep changing as
        training continues.
        """
        model = self.model.module if isinstance(self.model, nn.DataParallel) else self.model
        live = model.state_dict()
        frozen = type(live)((name, tensor.detach().cpu().clone())
                            for name, tensor in live.items())
        if hasattr(live, '_metadata'):
            # Keep the versioning metadata that load_state_dict consults.
            frozen._metadata = live._metadata
        return frozen

    def _post_epoch(self, data, epoch):
        """Log metrics and feature maps, track the best recall, save after the last epoch.

        Args:
            data: Batch whose inputs (``data[0]``) are passed to the feature-map logger.
            epoch: 1-based epoch number.
        """
        log_classifier_metrics(self.model, self.metrics_train, epoch, 'Train', self.writer)
        metrics_dict = log_classifier_metrics(self.model, self.metrics_val, epoch, 'Val', self.writer)
        featurelogger = FeatureMapLogger(self.model, self.writer, self.config.visualize)
        featurelogger(data[0], epoch)
        featurelogger.close()
        if metrics_dict['metric/recall'] > self.best_score:
            self.best_score = metrics_dict['metric/recall']
            self.state = {
                'hyperparameters': self.config.to_dict(),
                'model': self._snapshot_weights(),
                'current_epoch': epoch,
                'metrics_val': metrics_dict,
                'timestamp': datetime.datetime.now().strftime('%Y_%m_%d_%H_%M')
            }
        if self.config.save_state and self.config.epochs == epoch:
            if self.state is None:
                # Recall never rose above 0, so there is nothing to serialize.
                print('No checkpoint saved: validation recall never exceeded 0.')
                return
            buffer = io.BytesIO()
            torch.save(self.state, buffer)
            payload = buffer.getvalue()
            # The file is named after the SHA-1 of its own content.
            sha1 = hashlib.sha1(payload).hexdigest()
            with open(f'classifier.{sha1}', 'wb') as f:
                f.write(payload)

    def _process_batch(self, batch, i, mode='train'):
        """Process a single batch during training or validation.

        Args:
            batch: ``(inputs, labels, ...)`` as yielded by the loader.
            i: 0-based batch index, used to place the metrics in the epoch buffer.
            mode: ``'train'`` runs the optimizer step; ``'val'`` only evaluates.

        Returns:
            The per-sample loss tensor.
        """
        x, y = batch[0].to(self.device), batch[1].to(self.device)
        if hasattr(self.config, 'augment') and mode == 'train':
            x, _ = augment_candidates_3d(x, y, self.config.augment)

        with autocast():
            y_pred, prob = self.model(x)
            loss = self.criterion(y_pred, y)

        if mode == 'train':
            self.optimizer.zero_grad()
            self.scaler.scale(loss.mean()).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

        n_samples = len(x)
        # Offset by the loader's batch size, not this batch's length: the final
        # batch may be shorter and would otherwise overwrite earlier slots.
        start = i * self.batch_size
        end = start + n_samples
        metrics = torch.zeros((self.config.n_metrics, n_samples))
        metrics[0] = loss.detach()
        # Detach so the epoch-long buffer does not keep every batch's autograd graph.
        # Column 1 is presumably the positive-class probability; verify against Classifier.
        metrics[1] = prob.detach()[:, 1]
        metrics[2] = y
        if mode == 'train':
            self.metrics_train[:, start:end] = metrics
        elif mode == 'val':
            self.metrics_val[:, start:end] = metrics

        return loss

class SegmentationTrainer(BaseTrainer):
    """Trainer that pairs ``SegmenterDataset`` with the ``Segmenter`` model.

    Per-sample metrics are collected into ``metrics_train`` / ``metrics_val``
    (rows: 0 = loss, 1 = TP, 2 = TN, 3 = FP, 4 = FN, all computed from the soft
    predictions) and the checkpoint with the highest score returned by
    ``log_masked_image`` is kept in ``self.state``.
    """
    def __init__(self, config):
        # Must be set before super().__init__, which instantiates both.
        self.Dataset = SegmenterDataset
        self.Model = Segmenter
        super().__init__(config)

    def _init_cache(self):
        """No caching step for the segmenter."""
        pass

    def _init_criterion(self):
        """Use ``F1MacroLoss`` as the training criterion."""
        return F1MacroLoss()

    def _pre_epoch(self, epoch):
        """Allocate fresh metric buffers; reset best-score tracking at the start of a run."""
        self.metrics_train = torch.zeros((self.config.n_metrics, len(self.train_loader.dataset)))
        self.metrics_val = torch.zeros((self.config.n_metrics, len(self.val_loader.dataset)))
        # Resetting the best score every epoch would make any positive score
        # count as a new "best", so the last epoch would always be saved.
        if epoch == 1 or not hasattr(self, 'best_score'):
            self.best_score = 0
            self.state = None

    def _snapshot_weights(self):
        """Return a CPU copy of the unwrapped model's state dict.

        ``state_dict()`` tensors share storage with the live parameters, so
        without a copy the stored "best" weights would keep changing as
        training continues.
        """
        model = self.model.module if isinstance(self.model, nn.DataParallel) else self.model
        live = model.state_dict()
        frozen = type(live)((name, tensor.detach().cpu().clone())
                            for name, tensor in live.items())
        if hasattr(live, '_metadata'):
            # Keep the versioning metadata that load_state_dict consults.
            frozen._metadata = live._metadata
        return frozen

    def _post_epoch(self, data, epoch):
        """Log metrics and a masked image, track the best score, save after the last epoch.

        Args:
            data: Batch handed to ``log_masked_image``.
            epoch: 1-based epoch number.
        """
        log_segmenter_metrics(self.model, self.metrics_train, epoch, 'Train', self.writer)
        metrics_dict = log_segmenter_metrics(self.model, self.metrics_val, epoch, 'Val', self.writer)
        f1_macro = log_masked_image(data, self.model, self.writer, epoch, self.config.visualize)
        if f1_macro > self.best_score:
            self.best_score = f1_macro
            self.state = {
                'hyperparameters': self.config.to_dict(),
                'model': self._snapshot_weights(),
                'current_epoch': epoch,
                'metrics_val': metrics_dict,
                'timestamp': datetime.datetime.now().strftime('%Y_%m_%d_%H_%M')
            }
        if self.config.save_state and self.config.epochs == epoch:
            if self.state is None:
                # The score never rose above 0, so there is nothing to serialize.
                print('No checkpoint saved: validation score never exceeded 0.')
                return
            buffer = io.BytesIO()
            torch.save(self.state, buffer)
            payload = buffer.getvalue()
            # The file is named after the SHA-1 of its own content.
            sha1 = hashlib.sha1(payload).hexdigest()
            print(f'Epoch {self.state["current_epoch"]} is saved!')
            with open(f'segmenter.{sha1}', 'wb') as f:
                f.write(payload)

    def _process_batch(self, batch, i, mode='train'):
        """Process a single batch during training or validation.

        Args:
            batch: ``(inputs, masks, ...)`` as yielded by the loader.
            i: 0-based batch index, used to place the metrics in the epoch buffer.
            mode: ``'train'`` runs the optimizer step; ``'val'`` only evaluates.

        Returns:
            The loss tensor produced by the criterion.
        """
        x, y = batch[0].to(self.device), batch[1].to(self.device)
        if hasattr(self.config, 'augment') and mode == 'train':
            x, y = augment_candidates_2d(x, y, self.config.augment)

        with autocast():
            y_pred = self.model(x)
            loss = self.criterion(y_pred, y)

        if mode == 'train':
            self.optimizer.zero_grad()
            self.scaler.scale(loss.mean()).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

        # Sum in float32: under autocast the prediction may be half precision,
        # whose range a per-image pixel sum can exceed.
        y_hat = y_pred.detach().float()
        pos_mask = y == 1
        neg_mask = ~pos_mask
        # Soft confusion counts, reduced over dims 2 and 3
        # (presumably the spatial axes; verify against Segmenter's output).
        true_pos = (y_hat * pos_mask).sum(dim=(2, 3))
        true_neg = ((1 - y_hat) * neg_mask).sum(dim=(2, 3))
        false_pos = neg_mask.sum(dim=(2, 3)) - true_neg
        false_neg = pos_mask.sum(dim=(2, 3)) - true_pos

        n_samples = len(x)
        # Offset by the loader's batch size, not this batch's length: the final
        # batch may be shorter and would otherwise overwrite earlier slots.
        start = i * self.batch_size
        end = start + n_samples
        metrics = torch.zeros((self.config.n_metrics, n_samples))
        metrics[0] = loss.detach()
        metrics[1] = true_pos.squeeze()
        metrics[2] = true_neg.squeeze()
        metrics[3] = false_pos.squeeze()
        metrics[4] = false_neg.squeeze()
        if mode == 'train':
            self.metrics_train[:, start:end] = metrics
        elif mode == 'val':
            self.metrics_val[:, start:end] = metrics

        return loss

#
# Factory Pattern Implementation
class TrainingApp:
    """Factory that maps ``config.model_type`` to the matching trainer class."""
    @staticmethod
    def create_trainer(config):
        """Instantiate the trainer selected by ``config.model_type``.

        Args:
            config: Object with a ``model_type`` attribute, either
                ``'classification'`` or ``'segmentation'``; it is passed
                through to the trainer's constructor.

        Returns:
            The constructed trainer instance.

        Raises:
            ValueError: If ``config.model_type`` is not a known trainer type.
        """
        trainers = {
            'classification': ClassificationTrainer,
            'segmentation': SegmentationTrainer
        }

        trainer = trainers.get(config.model_type)
        if trainer is None:
            # Fail with a clear message rather than "'NoneType' is not callable".
            raise ValueError(
                f"Unknown model_type {config.model_type!r}; "
                f"expected one of {sorted(trainers)}"
            )
        return trainer(config)

#
def main():
    """Assemble the hyper-parameters, build the matching trainer and run training."""
    # Augmentation settings, forwarded to the trainer through the 'augment' entry.
    augmentation = dict(
        flip=True,
        offset=0.1,
        scale=0.2,
        rotate=True,
        noise=0.1,
        mixup=0.4,
    )

    hyper_parameters = dict(
        model_type='segmentation',  # 'segmentation' | 'classification'
        window='full_range',
        save_state=True,
        normalize=True,
        batch_norm=False,
        batch_size=64,
        num_workers=4,
        cache_in=True,
        visualize=True,
        epochs=100,
        n_metrics=5,
        balanced=1,
        enable_balanced_decay=True,
        augment=augmentation,
    )

    # Wrap the plain dict in Config, let the factory pick the trainer, then train.
    TrainingApp.create_trainer(Config(hyper_parameters)).train()

# Entry point: start training when this code runs as the main module.
# NOTE: in a notebook kernel __name__ is also "__main__", so executing this
# cell launches the full training run.
if __name__ == "__main__":
    main()

In [None]:
# Archive the TensorBoard logs written under runs/ into runs.zip (shell command).
!zip -r runs.zip runs/*