In [None]:
%env CUDA_LAUNCH_BLOCKING=1
# --- Data locations ---------------------------------------------------------
# `training_path` is scanned for subjects and passed to SkullUSDataset along
# with `csv_name`; `testing_path` is the directory result figures go to.
training_path = 'training_set/'
testing_path = 'testing_set/'
csv_name = 'parameters.csv'
im_suffix = 'HC.png'

# --- Experiment settings ----------------------------------------------------
n_folds = 5        # number of cross-validation folds
val_split = 0.1    # fraction of each fold's training subjects kept for validation
train_size = 8     # batch size of the training DataLoader
test_size = 8      # batch size of the validation / testing DataLoaders
epochs = 50
patience = 50

# Each file whose name ends with `im_suffix` is treated as one subject.
# os.listdir gives no ordering guarantee, so sort to keep folds reproducible.
subjects = sorted(
    entry for entry in os.listdir(training_path) if entry.endswith(im_suffix)
)

# Image geometry. Index 1 of `im_size` yields the x half-extent and index 0
# the y half-extent; the half sizes are used later to map normalised ellipse
# parameters back to pixel coordinates.
im_size = (540, 800)
half_y, half_x = (extent / 2 for extent in im_size)

# Angles used to draw ellipse contours: `n_points` values evenly spaced over
# [0, 2*pi).
n_points = 1000
angles = np.arange(n_points) / n_points * 2 * np.pi

def _denormalise_ellipse(norm_params, scale_x, scale_y):
    """Convert normalised ellipse parameters to pixel units and radians.

    ``norm_params`` is unpacked as (a, b, x0, y0, theta). ``a`` and ``x0``
    are mapped with ``v * scale_x + scale_x``; ``b`` and ``y0`` with
    ``v * scale_y + scale_y``; ``theta`` is multiplied by pi.

    Returns:
        Tuple (a, b, x0, y0, theta) in pixel units / radians.
    """
    norm_a, norm_b, norm_x0, norm_y0, norm_theta = norm_params
    return (
        norm_a * scale_x + scale_x,
        norm_b * scale_y + scale_y,
        norm_x0 * scale_x + scale_x,
        norm_y0 * scale_y + scale_y,
        norm_theta * np.pi,
    )


def _ellipse_contour(a, b, x0, y0, theta, t):
    """Return (x, y) points of an ellipse with semi-axes a, b, rotated by
    ``theta`` and centred on (x0, y0), sampled at parametric angles ``t``."""
    ideal_x = a * np.cos(t)
    ideal_y = b * np.sin(t)
    cos_th, sin_th = np.cos(theta), np.sin(theta)
    return (
        ideal_x * cos_th - ideal_y * sin_th + x0,
        ideal_y * cos_th + ideal_x * sin_th + y0,
    )


fig = plt.figure(figsize=(32, 42))

# plt.savefig does not create missing directories.
os.makedirs(testing_path, exist_ok=True)

for i in range(n_folds):
    # Contiguous slice of the sorted subject list held out as this fold's
    # test set; everything else is available for training.
    fold_ini = len(subjects) * i // n_folds
    fold_end = len(subjects) * (i + 1) // n_folds
    testing_set = subjects[fold_ini:fold_end]
    training_set = subjects[fold_end:] + subjects[:fold_ini]

    # Optionally carve a validation set out of this fold's training subjects.
    # The held-out fold is never used for validation, so it remains an
    # independent generalisation test. With val_split == 0 the training set
    # doubles as the validation set.
    if val_split > 0:
        # Both slices must come from the same untruncated list so that the
        # validation and training subjects are disjoint and none are lost.
        n_val = int(len(training_set) * val_split)
        validation_set = training_set[:n_val]
        training_set = training_set[n_val:]
    else:
        validation_set = training_set

    net = SimpleUNet()

    print('< Training dataset >')
    train_dataset = SkullUSDataset(training_path, csv_name, sub_list=training_set)
    train_loader = DataLoader(train_dataset, train_size, True, num_workers=1)

    print('< Validation dataset >')
    val_dataset = SkullUSDataset(training_path, csv_name, sub_list=validation_set)
    val_loader = DataLoader(val_dataset, test_size, num_workers=1)

    # The test fold is a subset of the images in training_path.
    print('< Testing dataset >')
    test_dataset = SkullUSDataset(training_path, csv_name, sub_list=testing_set)
    test_loader = DataLoader(test_dataset, test_size, num_workers=1)

    n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print(
        'Training / validation / test samples = '
        '{:03d} / {:03d} / {:03d} ({:d} parameters)'.format(
            len(train_dataset), len(val_dataset), len(test_dataset), n_params
        )
    )

    net.fit(train_loader, val_loader, epochs=epochs, patience=patience)

    # Evaluate on the held-out fold and save one overlay figure per sample.
    net.eval()
    batch_j = 0  # counts samples (not batches); used in the output file name
    with torch.no_grad():
        for im, (true, _, params) in test_loader:
            pred_batch = torch.sigmoid(net(im.to(net.device))).cpu()
            for norm_p, gt, seg, im_j in zip(params.numpy(), true, pred_batch, im):
                seg_mask = seg > 0.5
                # NOTE(review): fit_ellipse is assumed to return
                # (a, b, x0, y0, theta) in pixel units / radians, matching the
                # denormalised ground truth — confirm against its definition.
                pred_a, pred_b, pred_x0, pred_y0, pred_theta = fit_ellipse(
                    seg.squeeze().numpy() > 0.5
                )
                true_a, true_b, true_x0, true_y0, true_theta = _denormalise_ellipse(
                    norm_p, half_x, half_y
                )
                print(true_x0, pred_x0, true_y0, pred_y0)

                fig.clear()
                # Stack [thresholded prediction, ground truth, min-max
                # normalised first image channel] along dim 0 and show it as
                # an image (presumably one channel each -> R, G, B).
                norm_im = (im_j[:1, ...] - torch.min(im_j)) / (
                    torch.max(im_j) - torch.min(im_j)
                )
                overlay = torch.cat(
                    [seg_mask.float(), gt.float(), norm_im.float()]
                )
                plt.imshow(np.moveaxis(overlay.cpu().numpy(), 0, -1))

                # Ground-truth ellipse and centre in green.
                true_x, true_y = _ellipse_contour(
                    true_a, true_b, true_x0, true_y0, true_theta, angles
                )
                plt.scatter(true_x, true_y, c='g')
                plt.scatter(true_x0, true_y0, c='g')

                # Ellipse fitted to the prediction and its centre in red.
                pred_x, pred_y = _ellipse_contour(
                    pred_a, pred_b, pred_x0, pred_y0, pred_theta, angles
                )
                plt.scatter(pred_x, pred_y, c='r')
                plt.scatter(pred_x0, pred_y0, c='r')

                plt.savefig(
                    os.path.join(
                        testing_path,
                        'unet_f{:02d}_b{:03d}.png'.format(i, batch_j)
                    )
                )

                batch_j += 1