In [1]:
from lib.training import train
from lib.data import get_data_chest_x_ray_image
from lib.utils import get_device

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
class ExperimentConfig:
    """Bundle of factories that describe one training experiment.

    Attributes (all stored exactly as passed in):
        name: run identifier; ``experiment`` uses it to name the output
            folder and the per-fold checkpoint files.
        model_fn: zero-argument callable returning a fresh model.
        loss_fn: zero-argument callable returning a loss object.
        optimizer_fn: callable that receives the model parameters and
            returns an optimizer over them.
    """

    def __init__(self, name, model_fn, loss_fn, optimizer_fn):
        (self.name, self.model_fn, self.loss_fn, self.optimizer_fn) = (
            name, model_fn, loss_fn, optimizer_fn)

In [None]:
import os
from torch.utils.data import Dataset, DataLoader, Subset
from lib.data import TransformDataset

def prepare_dataloaders(train_dataset : Dataset, val_dataset : Dataset,
                        batch_size: int = 128, num_workers=None):
    """Wrap the train and validation datasets in DataLoaders.

    Args:
        train_dataset: dataset for the training split; reshuffled every epoch.
        val_dataset: dataset for the validation split; iterated in a fixed
            order so evaluation is reproducible.
        batch_size: samples per batch for both loaders (default 128).
        num_workers: worker processes per loader. When None (default), uses
            one fewer than the number of CPUs, but never fewer than 1.

    Returns:
        tuple: ``(train_dataloader, val_dataloader)``.
    """
    if num_workers is None:
        # os.cpu_count() returns None when the CPU count cannot be determined.
        num_workers = max(1, (os.cpu_count() or 1) - 1)

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

    train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    # Shuffling the validation split only adds nondeterminism; keep a fixed order.
    val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    return train_dataloader, val_dataloader

def get_my_metrics(device, n_classes):
    """Build the metric collection used during training.

    Registers, in this order: multiclass accuracy (torchmetrics' default
    averaging), then macro-averaged precision, recall and F1 score.

    Args:
        device: forwarded to ``MetricCollection(device=...)``.
        n_classes: number of target classes for every metric.

    Returns:
        The populated ``lib.metrics.MetricCollection``.
    """
    import torchmetrics
    from lib.metrics import MetricCollection

    collection = MetricCollection(device=device)
    shared = dict(task='multiclass', num_classes=n_classes)

    collection.register('accuracy', torchmetrics.Accuracy(**shared))
    macro_metrics = (
        ('precision', torchmetrics.Precision),
        ('recall', torchmetrics.Recall),
        ('f1_score', torchmetrics.F1Score),
    )
    for label, metric_cls in macro_metrics:
        collection.register(label, metric_cls(**shared, average='macro'))

    return collection

def experiment(config : ExperimentConfig, data_dict, device, metrics, freeze=False):
    """Run k-fold cross-validation for one experiment configuration.

    For each ``(train_idx, val_idx)`` pair in ``data_dict['folds']``:
    - build a fresh model, optimizer and loss from ``config``;
    - split ``data_dict['base_dataset']`` into train/val subsets wrapped with
      the train/val transforms;
    - run ``train`` and save the checkpoint to
      ``f'{config.name}_folder/{config.name}_fold={k}.pt'`` (k is 1-based).

    Args:
        config: experiment name plus model/loss/optimizer factories.
        data_dict: mapping providing 'folds', 'base_dataset',
            'train_transform' and 'val_transform'.
        device: device the model is moved to; also forwarded to ``train``.
        metrics: metric collection forwarded to ``train`` for every fold.
            NOTE(review): the same object is reused for all folds; confirm
            that ``train`` resets its state at the start of each run.
        freeze: if True call ``model.freeze()``, else ``model.unfreeze()``.

    Returns:
        list: the history returned by ``train`` for each fold, in fold order.
    """
    save_path = f'{config.name}_folder'
    os.makedirs(save_path, exist_ok=True)

    folds = data_dict['folds']
    base_dataset = data_dict['base_dataset']
    train_transform = data_dict['train_transform']
    val_transform = data_dict['val_transform']

    histories = []
    for fold_idx, (train_idx, val_idx) in enumerate(folds):
        fold_no = fold_idx + 1
        print(f"\n--- Fold {fold_no} ---")

        model = config.model_fn().to(device=device)
        if freeze:
            model.freeze()
        else:
            model.unfreeze()

        optimizer = config.optimizer_fn(model.parameters())
        loss_fn = config.loss_fn()

        # Transforms are applied per split so validation never sees train-time augmentation.
        train_dataset = TransformDataset(Subset(base_dataset, train_idx), train_transform)
        val_dataset = TransformDataset(Subset(base_dataset, val_idx), val_transform)
        train_dataloader, val_dataloader = prepare_dataloaders(train_dataset, val_dataset)

        save_name = f'{config.name}_fold={fold_no}.pt'

        history, model = train(
            model, train_dataloader, val_dataloader,
            loss_fn, optimizer,
            save_path=save_path, save_name=save_name,
            device=device, metrics=metrics, verbose=True
        )
        # Previously the history was discarded, losing the per-fold results.
        histories.append(history)

        # Release this fold's model/optimizer before the next fold allocates new ones.
        del model, optimizer

    return histories

In [3]:
# Load the chest X-ray data (img_size=(224, 224), kfold=5). Later cells read its
# 'folds', 'base_dataset', 'train_transform', 'val_transform' and 'classes' keys.
data_dict = get_data_chest_x_ray_image(
    img_size=(224, 224),
    kfold=5,
)

In [4]:
# Compute device returned by lib.utils.get_device (its selection policy is
# defined there and not visible in this notebook).
device = get_device()

In [None]:
from lib.models import MyResnet18
from torch import nn
import torch.optim as opt

n_classes = len(data_dict['classes'])


def make_resnet18_config(name, n_classes):
    """Build a ResNet-18 ExperimentConfig: cross-entropy loss, Adam with lr=1e-3."""
    return ExperimentConfig(
        name=name,
        model_fn=lambda: MyResnet18(n_classes=n_classes),
        loss_fn=nn.CrossEntropyLoss,
        optimizer_fn=lambda params: opt.Adam(params, lr=1e-3),
    )


# Full fine-tuning.
metrics = get_my_metrics(device, n_classes)
resnet18_config = make_resnet18_config("resnet18", n_classes)
experiment(resnet18_config, data_dict, device, metrics, freeze=False)

# Frozen variant. A fresh metric collection ensures no metric state carries
# over from the previous run. `device` was previously omitted here, which bound
# `metrics` to the `device` parameter and raised TypeError.
freeze_metrics = get_my_metrics(device, n_classes)
resnet18_freeze_config = make_resnet18_config("resnet18_freeze", n_classes)
experiment(resnet18_freeze_config, data_dict, device, freeze_metrics, freeze=True)