# In[ ]:
import copy
import csv
import os
import time

import numpy as np
import torch
from tqdm import tqdm


# def batch_segmentation_iou(outputs, labels):

#     SMOOTH = 1e-6
#     outputs = outputs.unsqueeze(dim=1)
    
#     intersection = (outputs & labels).float().sum((2, 3))  
#     union = (outputs | labels).float().sum((2, 3))    
#     iou = (intersection + SMOOTH) / (union + SMOOTH)

#     return iou.to("cpu").numpy()

def _binary_iou(y_true, y_pred, smooth=1e-6):
    """Return the intersection-over-union of two boolean arrays.

    `smooth` is added to numerator and denominator so that two empty masks
    score 1 instead of dividing by zero.
    """
    intersection = np.count_nonzero(np.logical_and(y_true, y_pred))
    union = np.count_nonzero(np.logical_or(y_true, y_pred))
    return (intersection + smooth) / (union + smooth)


def _batch_metrics(metrics, scores, targets, threshold):
    """Evaluate the supported metrics on one batch of flattened predictions.

    Returns a dict mapping metric name to its value for this batch. Only the
    'f1_score' and 'auroc' entries of `metrics` are called; IoU is always
    computed internally under the name 'iou'.
    """
    y_true = targets > 0
    y_pred = scores > threshold
    values = {}
    for name, metric in metrics.items():
        if name == 'f1_score':
            values[name] = metric(y_true, y_pred)
        elif name == 'auroc':
            # AUROC needs the raw scores; thresholding them first would
            # collapse the ROC curve to a single operating point.
            values[name] = metric(y_true.astype('uint8'), scores)
    values['iou'] = _binary_iou(y_true, y_pred)
    return values


def _run_phase(model, criterion, loader, optimizer, metrics, device, phase,
               threshold):
    """Run one pass over `loader` in the given phase ('Train' or 'Test').

    Returns ``(epoch_loss, per_batch)``: `epoch_loss` is the mean of the batch
    losses weighted by batch size (NaN if the loader yields nothing), and
    `per_batch` maps each metric name to the list of its per-batch values.
    """
    training = phase == 'Train'
    model.train(training)  # train() for 'Train', eval() otherwise
    running_loss = 0.0
    n_samples = 0
    per_batch = {}
    for sample in tqdm(loader):
        inputs = sample['image'].to(device)
        masks = sample['mask'].to(device)
        if training:
            optimizer.zero_grad()
        # Track autograd history only while training.
        with torch.set_grad_enabled(training):
            outputs = model(inputs)
            loss = criterion(outputs['out'], masks)
            if training:
                loss.backward()
                optimizer.step()
        batch_size = masks.shape[0]
        # Weight by batch size so a short final batch does not skew the mean.
        running_loss += loss.item() * batch_size
        n_samples += batch_size
        scores = outputs['out'].detach().cpu().numpy().ravel()
        targets = masks.detach().cpu().numpy().ravel()
        batch_values = _batch_metrics(metrics, scores, targets, threshold)
        for name, value in batch_values.items():
            per_batch.setdefault(name, []).append(value)
    epoch_loss = running_loss / n_samples if n_samples else float('nan')
    return epoch_loss, per_batch


def train_model(model, criterion, dataloaders, optimizer, metrics, bpath,
                num_epochs, threshold=0.1):
    """Train `model` and keep the weights with the lowest epoch Test loss.

    Each epoch runs a 'Train' pass (with backprop) followed by a 'Test' pass
    (without gradients), and appends one row of epoch-averaged results to
    ``<bpath>/log.csv``. The file is overwritten when training starts.

    Args:
        model: module whose output is a mapping containing an 'out' tensor.
        criterion: loss, called as ``criterion(outputs['out'], masks)``.
        dataloaders: mapping with 'Train' and 'Test' iterables, each yielding
            dicts with 'image' and 'mask' tensors.
        optimizer: optimizer stepped once per training batch.
        metrics: mapping of metric name to callable. 'f1_score' is called as
            ``metric(y_true_bool, y_pred_bool)`` and 'auroc' as
            ``metric(y_true_uint8, raw_scores)``. IoU is always computed
            internally and logged as ``<phase>_iou`` (a 'iou' callable, if
            given, is not called). Other entries are not evaluated, and their
            columns are logged as NaN.
        bpath: directory in which ``log.csv`` is written.
        num_epochs: number of Train/Test epochs to run.
        threshold: score above which a pixel counts as predicted foreground
            for f1_score and IoU (default 0.1).

    Returns:
        `model`, with the state dict from the epoch with the lowest mean Test
        loss loaded (the initial weights if no epoch produced a finite loss).
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')
    # Use gpu if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # IoU is always logged, so give it a column even when the caller's
    # metrics dict does not name it (DictWriter rejects unknown keys).
    metric_names = list(metrics)
    if 'iou' not in metric_names:
        metric_names.append('iou')
    fieldnames = ['epoch', 'Train_loss', 'Test_loss'] + \
        [f'Train_{m}' for m in metric_names] + \
        [f'Test_{m}' for m in metric_names]
    log_path = os.path.join(bpath, 'log.csv')
    with open(log_path, 'w', newline='') as csvfile:
        csv.DictWriter(csvfile, fieldnames=fieldnames).writeheader()

    for epoch in range(1, num_epochs + 1):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)
        row = dict.fromkeys(fieldnames)
        row['epoch'] = epoch
        # Each epoch has a training and validation phase
        for phase in ['Train', 'Test']:
            epoch_loss, per_batch = _run_phase(
                model, criterion, dataloaders[phase], optimizer, metrics,
                device, phase, threshold)
            row[f'{phase}_loss'] = epoch_loss
            print('{} Loss: {:.4f}'.format(phase, epoch_loss))
            for name in metric_names:
                values = per_batch.get(name)
                row[f'{phase}_{name}'] = (
                    float(np.mean(values)) if values else float('nan'))
        print(row)
        with open(log_path, 'a', newline='') as csvfile:
            csv.DictWriter(csvfile, fieldnames=fieldnames).writerow(row)
        # Checkpoint on the epoch-mean Test loss, not the last batch's loss.
        if row['Test_loss'] < best_loss:
            best_loss = row['Test_loss']
            best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Lowest Loss: {:.4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model