In [1]:
import torch
import torch.backends.cudnn as cudnn
import argparse
import logging
import pdb
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from torch.optim import Adam, SGD
from dataprepaug import *
from utils import *
from models import *
from functools import partial
from models.unet_model import UNet
from models.UnetPlusPlus import NestedUNet
#from .pspnet import *
#from .deeplab import *
import torch.optim as optim

In [7]:
def _run_train_epoch(net, loader, optimizer, grad_sc, device):
    """Run one optimisation pass over ``loader`` and return its metrics.

    Args:
        net: model being trained (already moved to ``device``).
        loader: iterable of ``(inputs, targets)`` batches that supports ``len()``.
        optimizer: optimizer stepping ``net``'s parameters.
        grad_sc: ``torch.cuda.amp.GradScaler`` used to scale the loss.
        device: device the batches are moved to.

    Returns:
        Tuple ``(mean Dice loss per batch, accuracy)`` for the epoch.
    """
    net.train()  # make sure train-mode layer behaviour is active after a validation pass
    nbatch = len(loader)
    epoch_loss = 0.0
    correct = 0.0
    size = 0
    for idx, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        batch_loss = dice_loss(outputs.softmax(1), targets, multiclass=True)
        grad_sc.scale(batch_loss).backward()
        grad_sc.step(optimizer)
        grad_sc.update()

        loss = batch_loss.item()
        epoch_loss += loss
        # Predictions compared per batch: dim0 * dim2 * dim3 of the output
        # (assumes NCHW outputs and channel-encoded targets -- TODO confirm).
        size += outputs.shape[0] * outputs.shape[2] * outputs.shape[3]
        correct += (outputs.argmax(1) == targets.argmax(1)).type(torch.float).sum().item()
        progress_bar(idx, nbatch, 'Loss: %.5f, Dice-Coef: %.5f' % (loss, 1 - loss))

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    return epoch_loss / max(nbatch, 1), correct / max(size, 1)


def _run_validation(net, loader, device, epoch):
    """Evaluate ``net`` on ``loader`` without gradients and return its metrics.

    Args:
        net: model to evaluate (already moved to ``device``).
        loader: iterable of ``(inputs, targets)`` batches that supports ``len()``.
        device: device the batches are moved to.
        epoch: current epoch index, used only in the log message.

    Returns:
        Tuple ``(mean Dice loss per batch, accuracy)`` over the loader.
    """
    net.eval()  # evaluation-mode layer behaviour for validation
    nbatch = len(loader)
    total_loss = 0.0
    correct = 0.0
    size = 0
    with torch.no_grad():
        for idx, (inputs, targets) in enumerate(loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            total_loss += dice_loss(outputs.softmax(1), targets, multiclass=True).item()
            correct += (outputs.argmax(1) == targets.argmax(1)).type(torch.float).sum().item()
            size += outputs.shape[0] * outputs.shape[2] * outputs.shape[3]

            running = total_loss / (idx + 1)
            logging.info('Validation (%d/%d): Epoch: %d  Loss: %.5f,  Dice-Coef:  %.5f'
                         % (idx, nbatch, epoch, running, 1 - running))

    return total_loss / max(nbatch, 1), correct / max(size, 1)


def trainmodel(config, checkpoint_dir=None):
    """Ray Tune function trainable: train ``NestedUNet`` and report per-epoch validation metrics.

    Every epoch trains on the training loader, evaluates on the validation
    loader, writes a checkpoint and then calls ``tune.report``. Reporting once
    per epoch is what gives the scheduler intermediate results to act on.

    Args:
        config: dict of hyper-parameters. Required keys: ``"layer_multipler"``,
            ``"lr"``, ``"batch_size"``. Optional keys (defaults reproduce the
            previously hard-coded values): ``"num_epochs"`` (10),
            ``"train_dir"``, ``"validation_dir"``.
        checkpoint_dir: directory holding a file named ``"checkpoint"`` with a
            ``(model_state_dict, optimizer_state_dict)`` tuple to resume from,
            or ``None`` to start from scratch.

    Reports:
        ``loss`` (mean validation Dice loss) and ``accuracy`` (validation
        accuracy) via ``tune.report`` after each epoch.
    """
    # Function-scope import: `os` is not imported explicitly by the notebook's
    # import cell, so do not rely on a star import leaking it into scope.
    import os

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    cudnn.benchmark = True
    get_logger()

    file_path_train = config.get(
        "train_dir", "/home/jay/Documents/courses/Aicourse/Brats/HYPER/train/")
    file_path_validation = config.get(
        "validation_dir", "/home/jay/Documents/courses/Aicourse/Brats/HYPER/validation/")
    num_epochs = int(config.get("num_epochs", 10))

    print('loading Data')
    net = NestedUNet(4, 4, layer_multipler=config["layer_multipler"])
    net = net.to(device)
    optimizer = SGD(net.parameters(), lr=config["lr"], momentum=0.9, weight_decay=1e-8)

    if checkpoint_dir:
        checkpoint = os.path.join(checkpoint_dir, "checkpoint")
        # map_location lets a checkpoint written on GPU be restored on a CPU-only worker.
        model_state, optimizer_state = torch.load(checkpoint, map_location=device)
        net.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    traindataload = data_load(file_path_train, batch_size=config["batch_size"], datatype='Train')
    validationload = data_load(file_path_validation, batch_size=config["batch_size"], datatype='Validation')

    # Only enable loss scaling when CUDA is actually in use.
    # NOTE(review): no autocast region is used here -- confirm whether mixed
    # precision was intended.
    grad_sc = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))

    for epoch in range(num_epochs):
        print(f'Epoch: {epoch}')
        print('Training Data Load')

        epoch_loss, train_acc = _run_train_epoch(net, traindataload, optimizer, grad_sc, device)
        print(f"Training epoch error: \n Acc: {(100*train_acc):>0.1f}%, Avg loss: {epoch_loss:>8f}, Dice: {(1-epoch_loss):>8f} \n")

        loss, acc = _run_validation(net, validationload, device, epoch)
        print(f"Validation error: \n Acc: {(100*acc):>0.1f}%, Avg loss: {loss:>8f}, Dice: {(1-loss):>8f} \n")

        # Save the checkpoint before reporting so it is associated with this step.
        with tune.checkpoint_dir(step=epoch) as ckpt_dir:
            path = os.path.join(ckpt_dir, "checkpoint")
            torch.save((net.state_dict(), optimizer.state_dict()), path)

        tune.report(loss=loss, accuracy=acc)

In [2]:
# Search-wide settings shared by the scheduler and the tune.run cells.
num_samples = 1        # number of hyper-parameter configurations to sample
max_num_epochs = 10    # upper bound handed to the scheduler as max_t
gpus_per_trial = 1

# Enable cuDNN benchmark mode for the convolution layers.
cudnn.benchmark = True

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Hyper-parameter search space; the keys are read by `trainmodel`.
config = {
    "layer_multipler": tune.choice([0.5, 1, 2, 3, 4, 5]),
    "lr": tune.loguniform(1e-4, 1e-1),
    "batch_size": tune.choice([2, 4, 8, 16]),
}

        

cuda:0


In [3]:
# ASHA scheduler minimising the reported "loss" metric, with max_t taken
# from max_num_epochs.
scheduler = ASHAScheduler(
    max_t=max_num_epochs,
    grace_period=1,
    reduction_factor=2,
    metric="loss",
    mode="min",
)


In [None]:
# Launch the hyper-parameter search.
print('Here')
reporter = CLIReporter(
    parameter_columns=["layer_multipler", "lr", "batch_size"],
    metric_columns=["loss", "accuracy", "training_iteration"],
)

# Request a GPU only when one is present; otherwise trials would stay
# pending forever waiting for a resource the host does not have.
trial_resources = {"gpu": gpus_per_trial if torch.cuda.is_available() else 0}

print('Run tune run line')
# `checkpoint_at_end` is intentionally not passed: `trainmodel` is a function
# trainable that writes its own checkpoints through `tune.checkpoint_dir`,
# and Ray Tune rejects `checkpoint_at_end` for such functions.
result = tune.run(
    trainmodel,
    config=config,
    resources_per_trial=trial_resources,
    num_samples=num_samples,
    scheduler=scheduler,
    progress_reporter=reporter,
)

Here
Run tune run line


TypeError: Remote functions cannot be called directly. Instead of running '__main__.my_function()', try '__main__.my_function.remote()'.