In [2]:
import argparse
import logging
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

# from utils.dice_score import dice_loss
%run DiceLoss.ipynb import DiceLoss
%run Evaluater.ipynb import Evaluater
from monai.networks.nets import BasicUNet as BU
%run DosePredictionDataset.ipynb import DosePredictionDataset

  warn(f"Failed to load image Python extension: {e}")


importing Jupyter notebook from DICOMReader.ipynb


In [2]:
# dir_img = Path('./data/imgs/')
# dir_mask = Path('./data/masks/')
# dir_checkpoint = Path('./checkpoints/')    
        
class Net_Trainer:

    def train_net(net,
                  device,
                  input_dir,
                  dose_dir,
                  dir_checkpoint,
                  epochs: int = 5,
                  batch_size: int = 1,
                  learning_rate: float = 0.001,
                  val_percent: float = 0.1,
                  save_checkpoint: bool = True,
                  amp: bool = False,
                  ):

        # create dataloaders for training and validation
        dataset = DosePredictionDataset(num_patients, input_dir, dose_dir)
        val_num = int(len(dataset) * val_percent)
        train_num = len(dataset) - val_num
        train_set, val_set = random_split(dataset, [train_num, val_num], generator=torch.Generator().manual_seed(42))    
        loader_args = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
        train_loader = DataLoader(train_set, shuffle=True, **loader_args)
        val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

        # Initialize logging
        experiment = wandb.init(project='StandardUNet', resume='allow', anonymous='must')
        experiment.config.update(dict(epochs=epochs, 
                                      batch_size=batch_size, 
                                      learning_rate=learning_rate,
                                      val_percent=val_percent, 
                                      save_checkpoint=save_checkpoint, 
                                      amp=amp))

        logging.info(f'''Starting training:
            Epochs:          {epochs}
            Batch size:      {batch_size}
            Learning rate:   {learning_rate}
            Training size:   {n_train}
            Validation size: {n_val}
            Checkpoints:     {save_checkpoint}
            Device:          {device.type}
            Mixed Precision: {amp}
        ''')

        # Set up optimizer/loss/learning rate/scheduler
        optimizer = optim.RMSprop(net.parameters(), 
                                  lr=learning_rate, 
                                  weight_decay=1e-8, 
                                  momentum=0.9)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score
        criterion = nn.CrossEntropyLoss()
        global_step = 0

        # training begins
        for epoch in range(epochs):
            net.train()
            epoch_loss = 0
            with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as progressBar:
                for images, labels in train_loader:
#                     images = batch['image']
#                     true_masks = batch['mask']

                    images = images.to(device=device, dtype=torch.float32)
                    labels = labels.to(device=device, dtype=torch.float32)

#                     with torch.cuda.amp.autocast(enabled=amp):
                    image_pred = net(images)
                    loss = criterion(image_pred, labels) + DiceLoss().forward(image_pred,
                                                                              labels)
#                                + dice_loss(F.softmax(image_pred, dim=1).float(),
#                                            F.one_hot(labels, net.n_classes).permute(0, 3, 1, 2).float(),
#                                            multiclass=True)

                    optimizer.zero_grad()
#                     grad_scaler.scale(loss).backward()
                    loss.backwards()
#                     grad_scaler.step(optimizer)
#                     grad_scaler.update()
                    optimizer.step()

                    progressBar.update(image.shape[0])
                    global_step += 1
                    epoch_loss += loss.item()
                    experiment.log({
                        'train loss': loss.item(),
                        'step': global_step,
                        'epoch': epoch
                    })
                    progressBar.set_postfix(**{'loss (batch)': loss.item()})
                    
                        

#                     # Evaluation round
#                     division_step = (train_num // (10 * batch_size))
#                     if division_step > 0:
#                         if global_step % division_step == 0:
#                             histograms = {}
#                             for tag, value in net.named_parameters():
#                                 tag = tag.replace('/', '.')
#                                 histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
#                                 histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

#                             val_score = Evaluater().evaluate(net, val_loader, device)
#                             scheduler.step(val_score)

#                             logging.info('Validation Dice score: {}'.format(val_score))
#                             experiment.log({
#                                 'learning rate': optimizer.param_groups[0]['lr'],
#                                 'validation Dice': val_score,
#                                 'images': wandb.Image(images[0].cpu()),
#                                 'labels': {
#                                     'true': wandb.Image(labels[0].float().cpu()),
#                                     'pred': wandb.Image(torch.softmax(image_pred, dim=1).argmax(dim=1)[0].float().cpu()),
#                                 },
#                                 'step': global_step,
#                                 'epoch': epoch,
#                                 **histograms
#                             })

            if save_checkpoint:
                Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
                torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1)))
                logging.info(f'Checkpoint {epoch + 1} saved!')


    def get_args():
        parser = argparse.ArgumentParser(description='Train the BasicUNet on images and labels')
        parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
        parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
        parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=0.00001,
                            help='Learning rate', dest='lr')
        parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
        parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
        parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                            help='Percent of the data that is used as validation (0-100)')
        parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')

        return parser.parse_args()


    if __name__ == '__main__':
        args = get_args()

        logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logging.info(f'Using device {device}')

        # Change here to adapt to your data
        # n_channels=3 for RGB images
        # n_classes is the number of probabilities you want to get per pixel
        net = BU(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128))

        logging.info(f'Network:\n'
                     f'\t{2} input channels\n')
#                      f'\t{net.n_classes} output channels (classes)\n'
#                      f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

        if args.load:
            net.load_state_dict(torch.load(args.load, map_location=device))
            logging.info(f'Model loaded from {args.load}')

        net.to(device=device)
        try:
            train_net(net=net,
                      epochs=args.epochs,
                      batch_size=args.batch_size,
                      learning_rate=args.lr,
                      device=device,
                      img_scale=args.scale,
                      val_percent=args.val / 100,
                      amp=args.amp)
        except KeyboardInterrupt:
            torch.save(net.state_dict(), 'INTERRUPTED.pth')
            logging.info('Saved interrupt')
            sys.exit(0)

INFO: Using device cpu
INFO: Network:
	2 input channels



BasicUNet features: (64, 128, 256, 512, 1024, 128).


UnpicklingError: invalid load key, '{'.

In [29]:
for epoch in range(100):
    with tqdm(total=200, desc=f'Epoch {epoch + 1}/{100}') as progressBar:
        for j in range(20):
            progressBar.update(j*10)
            progressBar.set_postfix(**{'loss (batch)': epoch})

Epoch 1/100: 1900it [00:00, 145792.75it/s, loss (batch)=0]                      
Epoch 2/100: 1900it [00:00, 221495.25it/s, loss (batch)=1]                      
Epoch 3/100: 1900it [00:00, 244146.25it/s, loss (batch)=2]                      
Epoch 4/100: 1900it [00:00, 285183.85it/s, loss (batch)=3]                      
Epoch 5/100: 1900it [00:00, 341658.20it/s, loss (batch)=4]                      
Epoch 6/100: 1900it [00:00, 375656.53it/s, loss (batch)=5]                      
Epoch 7/100: 1900it [00:00, 411801.24it/s, loss (batch)=6]                      
Epoch 8/100: 1900it [00:00, 385338.12it/s, loss (batch)=7]                      
Epoch 9/100: 1900it [00:00, 452717.01it/s, loss (batch)=8]                      
Epoch 10/100: 1900it [00:00, 490682.69it/s, loss (batch)=9]                     
Epoch 11/100: 1900it [00:00, 465353.44it/s, loss (batch)=10]                    
Epoch 12/100: 1900it [00:00, 517546.28it/s, loss (batch)=11]                    
Epoch 13/100: 1900it [00:00,