# In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import pandas as pd
import scipy.ndimage as ndimage
from pathlib import Path
from dataclasses import dataclass
from typing import List, Optional
import datetime
import warnings
import gc
from glob import glob

from luna16_dsets import ClassifierDataset, SegmenterDataset
from luna16_model import Classifier, Segmenter
from luna16_util import show_nodule, Config

# In [None]:
class InferenceApp:
    """LUNA16 lung nodule detection inference pipeline.

    For every CT series found in the configured subset the pipeline
    segments the scan, turns the connected components of the predicted
    mask into candidate positions, and runs the classifier on each
    candidate.

    The dataset location, checkpoint paths and candidate schema are class
    attributes so they can be overridden in a subclass.
    """

    DATA_ROOT = '/kaggle/input/luna16'
    CLASSIFIER_CHECKPOINT = '/kaggle/input/luna16-models/classifier.1946e06a83592641ed139455bd1a77154906bf0c'
    SEGMENTER_CHECKPOINT = '/kaggle/input/luna16-models/segmenter.ace47f0d1a9f7e595a0541132d50b1deac7f40f9'
    CANDIDATE_COLUMNS = ['seriesuid', 'coordX', 'coordY', 'coordZ', 'class', 'diameter_mm']

    def __init__(self, config: "Config"):
        """Store the configuration, pick the device and load both models.

        Args:
            config: Project configuration object. The attributes read by
                this class are ``batch_size``, ``subset``, ``batch_norm``,
                ``mode``, ``num_workers``, ``window``, ``normalize`` and
                ``visualize``.
        """
        self.config = config
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self._setup_models()

    def _setup_models(self) -> None:
        """Compute the batch size, list the series and load the models."""
        per_device = self.config.batch_size
        # Scale the batch with the number of GPUs because the models are
        # wrapped in DataParallel when more than one is present.
        self.batch_size = (
            torch.cuda.device_count() * per_device
            if torch.cuda.is_available()
            else per_device
        )

        mhd_files = Path(self.DATA_ROOT).glob(
            f'subset{self.config.subset}/subset*/*.mhd'
        )
        # Sorted so that the processing (and output row) order does not
        # depend on filesystem enumeration order.
        self.seriesuids = sorted(p.stem for p in mhd_files)

        self.classifier = self._load_model(
            Classifier(batch_norm=self.config.batch_norm),
            self.CLASSIFIER_CHECKPOINT,
        )
        self.segmenter = self._load_model(
            Segmenter(batch_norm=self.config.batch_norm),
            self.SEGMENTER_CHECKPOINT,
        )

    def _load_model(self, model: nn.Module, checkpoint_path: str) -> nn.Module:
        """Load weights into ``model`` and prepare it for inference.

        Args:
            model: Freshly constructed module matching the checkpoint.
            checkpoint_path: File holding a dict whose ``'model'`` entry is
                the state dict.

        Returns:
            The module on ``self.device`` in evaluation mode, wrapped in
            ``nn.DataParallel`` when more than one GPU is visible.
        """
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # host instead of failing while deserialising CUDA tensors.
        state = torch.load(checkpoint_path, map_location=self.device)
        model.load_state_dict(state['model'])
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model = model.to(self.device)
        # This class only ever runs inference, so never leave the model in
        # training mode regardless of which method is called first.
        model.eval()
        return model

    def get_data_loader(self, seriesuid: Optional[str] = None,
                        classify: bool = False,
                        data: Optional[pd.DataFrame] = None) -> DataLoader:
        """Create the DataLoader for the requested task.

        Args:
            seriesuid: Series to segment; used when ``classify`` is False.
            classify: When True build a ``ClassifierDataset`` over ``data``
                in ``'val'`` mode, otherwise a ``SegmenterDataset`` for
                ``seriesuid`` in ``config.mode``.
            data: Candidate table handed to ``ClassifierDataset``.

        Returns:
            A non-shuffling DataLoader using ``self.batch_size``.
        """
        if classify:
            dataset = ClassifierDataset(
                data=data,
                mode='val',
                config=self.config,
                subset=self.config.subset,
            )
        else:
            dataset = SegmenterDataset(
                seriesuid=seriesuid,
                mode=self.config.mode,
                config=self.config,
                subset=self.config.subset,
            )

        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.config.num_workers,
            # Pinned host memory only helps host-to-GPU copies; requesting
            # it without CUDA just triggers a warning.
            pin_memory=self.device.type == 'cuda',
        )

    @staticmethod
    def _find_centers(volume: np.ndarray, mask: np.ndarray) -> List[tuple]:
        """Return the intensity-weighted centre of every component in ``mask``.

        The mask is eroded once, split into connected components, and the
        centre of mass of ``volume`` is computed inside each component.

        Args:
            volume: Intensity array used as weights.
            mask: Boolean array with the same shape as ``volume``.

        Returns:
            One index-space coordinate tuple per component; empty when
            nothing survives the erosion.
        """
        eroded = ndimage.binary_erosion(mask, iterations=1)
        label_array, count = ndimage.label(eroded)
        if count == 0:
            return []
        # NOTE(review): the weights are the raw input intensities; if the
        # normalisation can yield negative values the weighted centre may
        # fall outside its component - confirm against the dataset code.
        return ndimage.center_of_mass(
            volume,
            labels=label_array,
            index=np.arange(1, count + 1),
        )

    @classmethod
    def _build_candidates(cls, seriesuid: str, coords: list) -> pd.DataFrame:
        """Build the candidate table for one series in a single step.

        Args:
            seriesuid: Identifier written into every row.
            coords: Sequence of 3-component coordinates.

        Returns:
            DataFrame with ``CANDIDATE_COLUMNS``; ``class`` and
            ``diameter_mm`` are filled with the placeholder ``-1``.
        """
        rows = [[seriesuid, xyz[0], xyz[1], xyz[2], -1, -1] for xyz in coords]
        return pd.DataFrame(rows, columns=cls.CANDIDATE_COLUMNS)

    @classmethod
    def _combine_results(cls, results: List[pd.DataFrame]) -> pd.DataFrame:
        """Concatenate per-series tables, tolerating an empty list.

        ``pd.concat`` raises ``ValueError`` on an empty list, which happens
        when no series were found; return an empty table instead.
        """
        if not results:
            return pd.DataFrame(columns=[*cls.CANDIDATE_COLUMNS, 'class_p'])
        return pd.concat(results, ignore_index=True)

    def _segment(self, seriesuid: str) -> tuple:
        """Run the segmenter over one series.

        Returns:
            ``(volume, mask)`` as NumPy arrays: the stacked channel-1 input
            slices and the segmenter output thresholded at 0.5.
        """
        loader = self.get_data_loader(seriesuid)
        slices, masks = [], []

        for x, _, _, _ in loader:
            x = x.to(self.device)
            # Channel 1 is kept as the image for each sample (presumably
            # the centre slice of a context stack - TODO confirm).
            center = x[:, 1].cpu()
            # Reshape to the slice shape rather than a bare squeeze(): a
            # bare squeeze also drops the batch axis when the final batch
            # holds a single sample, which breaks the concatenation below.
            masks.append(self.segmenter(x).cpu().reshape(center.shape))
            slices.append(center)

        mask = torch.cat(masks, dim=0) > 0.5
        volume = torch.cat(slices, dim=0)
        # Hand SciPy real ndarrays instead of relying on implicit
        # tensor-to-array conversion.
        return volume.numpy(), mask.numpy()

    @torch.no_grad()
    def process_series(self, seriesuid: str) -> pd.DataFrame:
        """Process a single CT series.

        Gradient tracking is disabled here so the method can be called on
        its own, not only through ``infer``.

        Args:
            seriesuid: Identifier of the series to process.

        Returns:
            One row per candidate with ``CANDIDATE_COLUMNS`` plus a boolean
            ``class_p`` column (classifier score above 0.5).
        """
        # Segmentation and candidate extraction
        volume, mask = self._segment(seriesuid)
        centers = self._find_centers(volume, mask)

        # Classification
        ct = SegmenterDataset._get_ct(
            seriesuid,
            self.config.window,
            self.config.normalize,
            self.config.subset,
        )
        candidates = self._build_candidates(
            seriesuid, [ct.irc2xyz(irc) for irc in centers]
        )

        if candidates.empty:
            # Keep the schema identical to the non-empty case.
            candidates['class_p'] = pd.Series(dtype=bool)
            return candidates

        loader = self.get_data_loader(classify=True, data=candidates)
        predictions = []

        for batch in loader:
            x = batch[0].to(self.device)
            # Column 1 of the classifier's second output is used as the
            # score (presumably the positive-class probability - TODO
            # confirm against the model definition).
            pred = self.classifier(x)[1][:, 1].cpu()
            predictions.append(pred)

            if self.config.visualize:
                show_nodule(batch[0][0].squeeze(1).cpu(), pred[0], batch[3][0].cpu())

        candidates['class_p'] = torch.cat(predictions, dim=0).numpy() > 0.5
        return candidates

    def infer(self) -> pd.DataFrame:
        """Run the inference pipeline on all series.

        Returns:
            The concatenated candidate tables of every series, or an empty
            table when there are no series.
        """
        # NOTE: this silences warnings process-wide, not just for this call.
        warnings.filterwarnings("ignore")
        print("Starting inference...")

        timestamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M')
        writer = SummaryWriter(f"runs/{self.config.mode}_{timestamp}")

        results = []
        self.segmenter.eval()
        self.classifier.eval()

        try:
            with torch.no_grad():
                for seriesuid in self.seriesuids:
                    results.append(self.process_series(seriesuid))
                    gc.collect()
        finally:
            # Close the writer even if a series fails part-way through.
            writer.close()

        return self._combine_results(results)

def main():
    """Build the inference configuration, run the pipeline and print the result."""
    settings = dict(
        mode='full_val',  # 'full_val' | 'inference'
        window='full_range',
        subset=0,
        normalize=True,
        batch_norm=False,
        batch_size=8,
        balanced=0,
        num_workers=4,
        cache_in=False,
        visualize=True,
        n_metrics=5,
    )

    # Wrap the plain settings in the project's Config, then run every
    # series through the pipeline and show the combined candidate table.
    pipeline = InferenceApp(Config(settings))
    print(pipeline.infer())

# Run the pipeline only when this file is executed directly, not when it is
# imported as a module.
if __name__ == "__main__":
    main()