In [None]:
import random, math, torch
import torch.nn.functional as F
import torchvision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torchvision
import datetime
from sklearn.metrics import f1_score, classification_report

class Config:
    """Expose the entries of a (possibly nested) dictionary as attributes.

    Every key becomes an attribute of the instance. Values that are
    themselves dictionaries are wrapped in a nested ``Config`` so that
    ``cfg.outer.inner`` works; all other values are stored unchanged.
    """

    def __init__(self, dictionary):
        for name, item in dictionary.items():
            wrapped = Config(item) if isinstance(item, dict) else item
            setattr(self, name, wrapped)

    def to_dict(self):
        """Return a plain dictionary, recursively unwrapping nested ``Config`` values."""
        return {
            name: item.to_dict() if isinstance(item, Config) else item
            for name, item in self.__dict__.items()
        }
                
def log_classifier_metrics(model, metrics, epoch, mode, writer):
    """Log classifier metrics to ``writer``, print a summary and return them.

    A sample counts as a predicted positive when its score in ``metrics[1]``
    is strictly greater than 0.5. Samples whose label in ``metrics[2]``
    equals 1 are positives; every other sample is a negative for the
    confusion counts (``loss/neg`` averages only samples labelled exactly 0).

    Args:
        model: Object providing ``named_parameters()``. Only used when
            ``mode == 'Train'`` to histogram parameters and gradients.
        metrics: 2-D array indexed as ``metrics[0]`` (per-sample loss),
            ``metrics[1]`` (per-sample score) and ``metrics[2]`` (per-sample
            label). Presumably a ``(3, n_samples)`` torch tensor -- confirm
            against the caller.
        epoch: Step value forwarded to every ``writer`` call and printed.
        mode: Label such as ``'Train'``; lower-cased to build the tag prefix.
        writer: Object exposing ``add_scalar`` and ``add_histogram``.

    Returns:
        dict: Loss entries as Python floats; metric entries (percentages)
        keep whatever type the array arithmetic produces. No guard is
        applied to empty classes, so a zero denominator propagates as the
        array library's NaN/inf result.
    """
    labels = metrics[2]

    metrics_dict = {
        'loss/all': metrics[0].mean().item(),
        'loss/pos': metrics[0, labels == 1].mean().item(),
        'loss/neg': metrics[0, labels == 0].mean().item()
    }

    pos_mask = labels == 1
    neg_mask = ~pos_mask

    true_pos = (metrics[1, pos_mask] > 0.5).sum()
    true_neg = (metrics[1, neg_mask] <= 0.5).sum()
    false_pos = (metrics[1, neg_mask] > 0.5).sum()
    false_neg = (metrics[1, pos_mask] <= 0.5).sum()

    recall = true_pos / (true_pos + false_neg)
    precision = true_pos / (true_pos + false_pos)

    def f_beta(beta):
        # General F-beta score; beta > 1 favours recall, beta < 1 precision.
        return (1 + beta**2) * (precision * recall) / (beta**2 * precision + recall)

    # Local names deliberately avoid ``f1_score`` so the sklearn function
    # imported at module level is not shadowed inside this function.
    metrics_dict.update({
        'metric/recall': 100 * recall,
        'metric/precision': 100 * precision,
        'metric/f1_score': 100 * f_beta(1),
        'metric/f_5_score': 100 * f_beta(.5),
        'metric/f2_score': 100 * f_beta(2),
        'metric/accuracy': 100 * (true_pos + true_neg) / (true_pos + true_neg + false_pos + false_neg),
        'metric/accuracy_pos': 100 * true_pos / (true_pos + false_neg),
        'metric/accuracy_neg': 100 * true_neg / (true_neg + false_pos)
    })

    tag_prefix = mode.lower() + '/'
    for k, v in metrics_dict.items():
        writer.add_scalar(tag_prefix + k, v, epoch)

    if mode == 'Train':
        for name, param in model.named_parameters():
            writer.add_histogram(f"{name}_weights", param, epoch)
            # Parameters that are frozen, unused in the forward pass, or not
            # yet back-propagated have ``grad is None``; there is nothing to
            # histogram for them.
            if param.grad is not None:
                writer.add_histogram(f"{name}_grad", param.grad, epoch)

    log_format = (
        f"Epoch {epoch}, {datetime.datetime.now().strftime('%Y-%m-%d, %H:%M:%S')}: "
        f"{mode} Loss {metrics_dict['loss/all']:.2f}, "
        f"{mode} Loss/Pos {metrics_dict['loss/pos']:.2f}, "
        f"{mode} Loss/Neg {metrics_dict['loss/neg']:.2f}, "
        f"\n{mode} Metric/Recall {metrics_dict['metric/recall']:.2f}%, "
        f"{mode} Metric/Precision {metrics_dict['metric/precision']:.2f}%, "
        f"{mode} Metric/F1 Score {metrics_dict['metric/f1_score']:.2f}%, "
        f"{mode} Metric/F.5 Score {metrics_dict['metric/f_5_score']:.2f}%, "
        f"{mode} Metric/F2 Score {metrics_dict['metric/f2_score']:.2f}%, "
        f"{mode} Metric/Accuracy {metrics_dict['metric/accuracy']:.2f}%, "
        f"{mode} Metric/Accuracy_pos {metrics_dict['metric/accuracy_pos']:.2f}%, "
        f"{mode} Metric/Accuracy_neg {metrics_dict['metric/accuracy_neg']:.2f}%"
    )
    print(log_format)
    print('-' * 50)

    return metrics_dict

def log_segmenter_metrics(model, metrics, epoch, mode, writer):
    """Log segmentation metrics to ``writer``, print a summary and return them.

    ``metrics[0]`` is averaged as the loss; ``metrics[1]`` .. ``metrics[4]``
    are summed and used as the true-positive, true-negative, false-positive
    and false-negative totals, in that order. ``model`` is accepted for
    signature parity with the classifier logger but is not used here.
    """
    tp, tn, fp, fn = (metrics[row].sum() for row in range(1, 5))

    recall = tp / (tp + fn)
    precision = tp / (tp + fp)

    def f_beta(beta):
        # F-beta score from the precision/recall computed above.
        return (1 + beta**2) * (precision * recall) / (beta**2 * precision + recall)

    metrics_dict = {
        'loss/all': metrics[0].mean().item(),
        'metric/recall': 100 * recall,
        'metric/precision': 100 * precision,
        'metric/f1_score': 100 * f_beta(1),
        'metric/f_5_score': 100 * f_beta(.5),
        'metric/f2_score': 100 * f_beta(2),
        'metric/accuracy': 100 * (tp + tn) / (tp + tn + fp + fn),
        'metric/accuracy_pos': 100 * tp / (tp + fn),
        'metric/accuracy_neg': 100 * tn / (tn + fp),
    }

    tag_prefix = mode.lower() + '/'
    for key, value in metrics_dict.items():
        writer.add_scalar(tag_prefix + key, value, epoch)

    # (printed label, dictionary key) for each percentage shown on the
    # second line of the console summary.
    shown = (
        ('Recall', 'metric/recall'),
        ('Precision', 'metric/precision'),
        ('F1 Score', 'metric/f1_score'),
        ('F.5 Score', 'metric/f_5_score'),
        ('F2 Score', 'metric/f2_score'),
        ('Accuracy', 'metric/accuracy'),
        ('Accuracy_pos', 'metric/accuracy_pos'),
        ('Accuracy_neg', 'metric/accuracy_neg'),
    )
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d, %H:%M:%S')
    header = f"Epoch {epoch}, {timestamp}: {mode} Loss {metrics_dict['loss/all']:.2f}, \n"
    body = ", ".join(
        f"{mode} Metric/{label} {metrics_dict[key]:.2f}%" for label, key in shown
    )
    print(header + body)
    print('-' * 50)

    return metrics_dict

def show_nodule(sample, label, irc_center, cmap='gray', figsize=(10, 4)):
    """
    Plot the three central orthogonal slices of a candidate volume.

    The slices are always taken through the middle of ``sample``;
    ``irc_center`` is only used to label the panels.

    Args:
        sample: Volume with a leading singleton axis that is squeezed away;
            the remainder must be 3-D (its shape is unpacked into three sizes).
        label: Truthy for a nodule, falsy otherwise; selects the figure title.
        irc_center: Three values (index, row, column) shown in the panel
            titles; squeezed before unpacking.
        cmap (str): Colormap passed to ``imshow``.
        figsize (Tuple[int, int]): Figure size passed to ``plt.subplots``.
    """
    volume = sample.squeeze(0)
    center = irc_center.squeeze()
    mid_i, mid_r, mid_c = np.array(volume.shape) // 2
    idx, row_pos, col_pos = center

    def rescale(arr):
        # Min-max scale to [0, 1]; the small constant avoids division by zero
        # on a constant slice.
        return (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)

    # (panel title, 2-D slice) in left-to-right display order.
    panels = (
        (f'Index {idx}', np.transpose(volume, (2, 1, 0))[:, :, mid_i]),
        (f'Row {row_pos}', np.transpose(volume, (0, 2, 1))[:, :, mid_r]),
        (f'Column {col_pos}', volume[:, :, mid_c]),
    )
    figure_title = 'Nodule' if label else 'Not Nodule'

    fig, axes = plt.subplots(1, 3, figsize=figsize)
    for axis, (title, plane) in zip(axes, panels):
        axis.imshow(rescale(plane), cmap=cmap)
        axis.set_title(title)

    fig.suptitle(figure_title)
    plt.tight_layout()
    plt.show()
    
def augment_candidates_3d(candidates, label, augment):
    """Randomly augment a batch of 3-D candidates with one shared affine transform.

    A single 4x4 transform is drawn and applied to every item of the batch
    (TODO: draw an independent transform per candidate). Each augmentation
    is enabled by the presence of the corresponding attribute on ``augment``:

    * ``flip`` (truthy): each of the three axes is mirrored with probability 0.5.
    * ``offset``: per-axis translation drawn uniformly from ``[-offset, offset]``
      in normalised grid coordinates.
    * ``scale``: per-axis scale factor drawn from ``[1 - scale, 1 + scale]``.
    * ``rotate`` (truthy): rotation by a random angle in the plane of the
      first two grid coordinates.
    * ``noise``: Gaussian noise with standard deviation ``noise`` is added.
    * ``mixup``: samples with ``label == 1`` are blended with a random
      permutation of themselves using a ``Beta(mixup, mixup)`` weight.

    Args:
        candidates: 5-D floating tensor as required by ``F.affine_grid`` /
            ``F.grid_sample`` for volumetric data. Any floating dtype is
            accepted; the sampling grid is built in the same dtype.
        label: Tensor compared against 1 to select positives for mixup
            (presumably one entry per candidate -- confirm with the caller).
        augment: Object whose attributes configure the augmentation.

    Returns:
        tuple: ``(augmented_candidates, label)``; ``label`` is returned unchanged.
    """
    device = candidates.device
    transform_t = torch.eye(4).to(device)

    for i in range(3):
        if hasattr(augment, 'flip') and augment.flip:
            if random.random() > 0.5:
                transform_t[i, i] *= -1

        if hasattr(augment, 'offset'):
            offset_float = augment.offset
            random_float = (random.random() * 2 - 1)
            transform_t[i, 3] = offset_float * random_float

        if hasattr(augment, 'scale'):
            scale_float = augment.scale
            random_float = (random.random() * 2 - 1)
            transform_t[i, i] *= 1.0 + scale_float * random_float

    if hasattr(augment, 'rotate') and augment.rotate:
        angle_rad = random.random() * math.pi * 2
        s = math.sin(angle_rad)
        c = math.cos(angle_rad)

        rotation_t = torch.tensor([
            [c, -s, 0, 0],
            [s, c, 0, 0],
            [0, 0, 1, 0],
            [0, 0, 0, 1],
        ]).to(device)

        transform_t = transform_t @ rotation_t

    # grid_sample requires the grid and the input to share a dtype, so build
    # the grid in the input's dtype instead of hard-coding float32.
    grid_dtype = candidates.dtype if candidates.is_floating_point() else torch.float32
    affine_t = F.affine_grid(
        transform_t[:3].to(grid_dtype).expand(candidates.shape[0], -1, -1),
        candidates.size(),
        align_corners=False,
    )

    augmented_chunk = F.grid_sample(
        candidates,
        affine_t,
        padding_mode='border',
        align_corners=False,
    )

    if hasattr(augment, 'noise'):
        noise_t = torch.randn_like(augmented_chunk)
        noise_t *= augment.noise

        augmented_chunk += noise_t

    if hasattr(augment, 'mixup'):
        alpha = augment.mixup
        lambda_ = torch.distributions.Beta(alpha, alpha).sample().item()

        positives = label == 1
        positive_chunk = augmented_chunk[positives]
        index = torch.randperm(positive_chunk.size(0)).to(device)

        # Blend each positive with another randomly chosen positive.
        augmented_chunk[positives] = lambda_ * positive_chunk + \
            (1 - lambda_) * positive_chunk[index]

    return augmented_chunk, label

def augment_candidates_2d(candidates, label, augment):
    """
    Randomly augment a batch of 2-D images together with their label maps.

    The same flips and the same affine sampling grid are applied to
    ``candidates`` and ``label`` so both stay spatially aligned. ``label`` is
    resampled with ``F.grid_sample``'s default interpolation, so a binary
    mask generally comes back with non-binary values near edges.

    Each augmentation is enabled by the presence of an attribute on ``augment``:
    ``flip`` (truthy: independent horizontal/vertical flips, probability 0.5
    each), ``scale`` (isotropic factor in ``[1 - scale, 1 + scale]``),
    ``rotate`` (truthy: random in-plane rotation), ``offset`` (x/y translation
    in ``[-offset, offset]`` in normalised grid coordinates) and ``noise``
    (Gaussian noise of that standard deviation, added with probability 0.5
    to the images only).

    Args:
        candidates: Tensor of shape (batch_size, channels, height, width),
            any floating dtype.
        label: 4-D floating tensor with the same batch size, transformed
            alongside ``candidates`` (presumably a segmentation mask --
            confirm with the caller).
        augment: Object containing augmentation parameters

    Returns:
        tuple: (augmented_candidates, augmented_label)
    """
    batch_size = candidates.shape[0]
    device = candidates.device

    # Initialize identity transform matrix
    transform_t = torch.eye(3).to(device)

    # Handle flipping (same decision for image and label)
    if hasattr(augment, 'flip') and augment.flip:
        if random.random() > 0.5:
            candidates = T.functional.hflip(candidates)
            label = T.functional.hflip(label)
        if random.random() > 0.5:
            candidates = T.functional.vflip(candidates)
            label = T.functional.vflip(label)

    # Handle scaling
    if hasattr(augment, 'scale'):
        scale_factor = 1.0 + augment.scale * (random.random() * 2 - 1)
        transform_t[0, 0] *= scale_factor
        transform_t[1, 1] *= scale_factor

    # Handle rotation
    if hasattr(augment, 'rotate') and augment.rotate:
        angle_rad = random.random() * math.pi * 2
        rotation_matrix = torch.tensor([
            [math.cos(angle_rad), -math.sin(angle_rad), 0],
            [math.sin(angle_rad), math.cos(angle_rad), 0],
            [0, 0, 1]
        ]).to(device)
        transform_t = transform_t @ rotation_matrix

    # Handle translation/offset
    if hasattr(augment, 'offset'):
        offset_x = augment.offset * (random.random() * 2 - 1)
        offset_y = augment.offset * (random.random() * 2 - 1)
        transform_t[0, 2] = offset_x
        transform_t[1, 2] = offset_y

    # Convert the 3x3 matrix to the batched 2x3 form expected by
    # F.affine_grid. grid_sample requires grid and input to share a dtype,
    # so the matrix is cast to the input's dtype rather than left as float32.
    theta_dtype = candidates.dtype if candidates.is_floating_point() else torch.float32
    affine_matrix = transform_t[:2].to(theta_dtype).unsqueeze(0).repeat(batch_size, 1, 1)

    # Apply affine transformation
    grid = F.affine_grid(
        affine_matrix,
        candidates.size(),
        align_corners=False
    )

    augmented_chunk = F.grid_sample(
        candidates,
        grid,
        padding_mode='border',
        align_corners=False
    )

    # The label may be stored in a different floating dtype than the images.
    label_grid = grid.to(label.dtype) if label.is_floating_point() else grid
    label = F.grid_sample(
        label,
        label_grid,
        padding_mode='border',
        align_corners=False
    )

    # Add noise if specified
    if hasattr(augment, 'noise'):
        if random.random() > 0.5:
            noise = torch.randn_like(augmented_chunk) * augment.noise
            augmented_chunk += noise

    return augmented_chunk, label

class FeatureMapLogger(nn.Module):
    """Capture the outputs of every ``nn.Conv3d`` layer of a model and log them.

    Forward hooks are registered on all ``nn.Conv3d`` sub-modules of
    ``model``. Calling the logger runs the model once, then writes one image
    grid per hooked layer to ``writer`` (the middle slice along the third
    axis from the end, first item of the batch) and optionally shows the
    grids with matplotlib.

    Args:
        model: Module to inspect.
        writer: Object exposing ``add_image(tag, image, step)``.
        visualize: When truthy, also display the grids in a matplotlib figure.
    """

    def __init__(self, model, writer, visualize):
        super().__init__()
        self.model = model
        self.writer = writer
        self.visualize = visualize
        self.hooks = []
        self.feature_maps = {}

        # Register hooks for each layer we want to visualize
        for name, layer in model.named_modules():
            if isinstance(layer, nn.Conv3d):
                self.hooks.append(
                    layer.register_forward_hook(
                        self.named_hook(name)
                    )
                )

    def named_hook(self, name):
        """Return a forward hook that stores the layer output under ``name``."""
        def fun_hook(_, __, output):
            self._hook_fn(name, output)
        return fun_hook

    def _hook_fn(self, layer_name, output):
        # Store the feature maps (latest forward pass wins)
        self.feature_maps[layer_name] = output

    @staticmethod
    def _subplot_rows(n_maps, n_cols):
        """Return the number of subplot rows needed to fit ``n_maps`` panels."""
        return (n_maps + n_cols - 1) // n_cols

    def forward(self, x, step=0):
        """Run the model on ``x``, log the captured feature maps, return ``x``.

        The model output itself is discarded; the input is returned unchanged.
        """
        # Forward pass through the model (fills self.feature_maps via hooks)
        self.model(x)

        n_maps = len(self.feature_maps)
        # Nothing to draw when no Conv3d output was captured; plt.subplots
        # cannot create a figure with zero rows.
        plot = bool(self.visualize) and n_maps > 0

        if plot:
            col = 2
            # Ceiling division so an odd number of layers still gets an axis each.
            row = self._subplot_rows(n_maps, col)
            fig, axs = plt.subplots(row, col, figsize=(10, 10), squeeze=False)
            axs = axs.flatten()

        for i, (name, feature_map) in enumerate(self.feature_maps.items()):
            # Get first volumetric image in batch
            depth_middle = feature_map.shape[-3] // 2
            feature_map = feature_map[0, :, depth_middle]

            # Create grid of feature maps, one tile per channel
            grid = torchvision.utils.make_grid(
                feature_map.unsqueeze(1),
                normalize=True,
                nrow=8,
                padding=0
            )

            # Log to TensorBoard
            self.writer.add_image(
                f'Feature Maps/{name}/{step}',
                grid,
                step
            )

            if plot:
                grid_np = grid.cpu().numpy().transpose(1, 2, 0)
                axs[i].imshow(grid_np)
                axs[i].set_title(f'{name}')
                axs[i].set_axis_off()

        if plot:
            # Hide the spare panel left over when the layer count is odd.
            for spare in axs[n_maps:]:
                spare.set_axis_off()
            fig.suptitle('Feature Maps: Middle Slice of The Depth')
            plt.tight_layout()
            plt.show()

        return x

    def close(self):
        """Remove every forward hook registered by this logger."""
        for hook in self.hooks:
            hook.remove()

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from mpl_toolkits.mplot3d import Axes3D

def create_3d_tomograph(data, slice_spacing=1, colormap='viridis',
                        transparency=0.15, threshold=0.3):
    """
    Create a 3D visualization of tomographic data with controllable transparency.

    The volume is drawn as a stack of semi-transparent horizontal planes, one
    per slice. Axis 0 of ``data`` is the slice axis (it is moved to the last
    position before plotting).

    Parameters:
    data: 3D torch tensor or numpy array; singleton axes are squeezed away
        when more than three dimensions are given
    slice_spacing: distance along the Z axis between consecutive slices
    colormap: name of the matplotlib colormap used to colour the voxels
    transparency: alpha multiplier applied to the normalized intensity
        (0 = fully transparent, 1 = alpha equals intensity)
    threshold: minimum normalized intensity (0-1) to display; voxels at or
        below it are made fully transparent
    """
    # Accept both tensors (detached and moved to host memory) and array-likes.
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()
    else:
        data = np.asarray(data)

    if data.ndim > 3:
        data = data.squeeze()

    # Move the slice axis last: (slices, rows, cols) -> (rows, cols, slices).
    data = np.transpose(data, (1, 2, 0))

    # Normalize data to 0-1 range; a constant volume maps to all zeros
    # instead of dividing by zero.
    low = np.min(data)
    span = np.max(data) - low
    if span > 0:
        data_normalized = (data - low) / span
    else:
        data_normalized = np.zeros(data.shape, dtype=float)

    n_rows, n_cols, n_slices = data_normalized.shape

    # Coordinate matrices with the same (n_rows, n_cols) shape as one slice,
    # so the per-voxel face colours line up even for non-square slices.
    x, y = np.meshgrid(np.arange(n_cols), np.arange(n_rows))

    # Create the 3D figure
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='3d')

    cmap = matplotlib.colormaps[colormap]

    # Plot each slice with adaptive transparency
    for i in range(n_slices):
        slice_data = data_normalized[:, :, i]

        # RGBA colour per voxel
        colors = cmap(slice_data)

        # Alpha follows intensity above the threshold; everything else is
        # fully transparent so low-intensity voxels do not hide the interior.
        colors[:, :, 3] = np.where(slice_data > threshold,
                                   slice_data * transparency,
                                   0)

        # Flat plane at the height of this slice
        height = np.full(x.shape, i * slice_spacing, dtype=float)
        ax.plot_surface(x, y, height,
                        facecolors=colors,
                        rstride=1, cstride=1)

    # Customize the visualization
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_zlabel('Z axis (Slices)')
    ax.set_title('3D Tomograph Visualization')

    # Set the viewing angle to better see internal structures
    ax.view_init(elev=20, azim=45)

    plt.show()

def log_masked_image(data, model, writer, epoch, visualize):
    """Overlay ground-truth and predicted masks on one image, log it and score it.

    Only the first sample of the batch is used. The predicted mask is the
    model output thresholded at 0.5, the ground-truth mask is ``data[1][0]``
    thresholded at 0.5, and the background image is ``data[0][0, 1]``
    (presumably the middle channel of a multi-slice input -- confirm with
    the dataset). Ground truth is painted red and the prediction yellow; the
    two overlays are logged side by side (ground truth on the left).

    Args:
        data: Pair ``(inputs, targets)`` of batched tensors.
        model: Callable returning a sequence/tensor whose first element is
            the prediction for the single-sample batch passed in.
        writer: Object exposing ``add_image(tag, image, step)`` that accepts
            channel-first images.
        epoch: Step used in the tag and passed to the writer.
        visualize: When truthy, also show both overlays with matplotlib.

    Returns:
        Macro-averaged F1 score between the flattened masks.
    """
    from matplotlib.colors import to_rgb

    def create_masked_image(image, mask, alpha=1, mask_color='red'):
        # Paint ``mask_color`` over ``image`` wherever ``mask`` is set and
        # return an (H, W, 3) array clipped to [0, 1].
        color_rgb = to_rgb(mask_color)
        rgb_image = np.stack([image] * 3, axis=-1)
        colored_mask = np.zeros_like(rgb_image)
        for i in range(3):
            colored_mask[..., i] = mask * color_rgb[i]

        blended = rgb_image.copy()
        mask_3d = np.stack([mask] * 3, axis=-1)
        blended[mask_3d] = (1 - alpha) * rgb_image[mask_3d] + alpha * colored_mask[mask_3d]
        blended = np.clip(blended, 0, 1)
        return blended

    # Inference only: no autograd graph is needed for a thresholded mask.
    with torch.no_grad():
        prediction = model(data[0][0:1])[0]

    mask_p = (prediction > 0.5).squeeze().cpu()
    mask_gt = (data[1][0] > 0.5).squeeze().cpu()
    image = data[0][0, 1].cpu()
    image_gt = create_masked_image(image, mask_gt, mask_color='red')
    image_p = create_masked_image(image, mask_p, mask_color='yellow')

    # Place the overlays side by side and convert HWC -> CHW for the writer.
    # (Transposing with (2, 1, 0) would also swap rows and columns, logging
    # each image flipped about its diagonal relative to the imshow output.)
    combined = np.concatenate([image_gt, image_p], axis=1)
    writer.add_image(f'Masked Images/{epoch}', combined.transpose(2, 0, 1), epoch)

    if visualize:
        fig, axs = plt.subplots(1, 2)
        axs = axs.flatten()
        axs[0].imshow(image_gt, cmap='gray')
        axs[0].set_title('Ground Truth MAsk')
        axs[0].set_axis_off()

        axs[1].imshow(image_p, cmap='gray')
        axs[1].set_title('Predicted Mask')
        axs[1].set_axis_off()

        plt.tight_layout()
        plt.show()

    y_pred = mask_p.flatten()
    y_true = mask_gt.flatten()
    f1_macro = f1_score(y_true, y_pred, average='macro')
    print(f'F1 Macro: {f1_macro:.2f}')
    print(classification_report(y_true, y_pred))
    return f1_macro

class F1MacroLoss(nn.Module):
    """Differentiable ("soft") macro-F1 loss computed from raw model scores.

    The scores are turned into probabilities inside ``forward`` (sigmoid for
    a single output column, softmax over dim 1 otherwise), so callers must
    pass logits, not probabilities. Counts are accumulated over dim 0 only.

    Args:
        epsilon: Small constant added to denominators to avoid division by zero.
    """

    def __init__(self, epsilon=1e-7):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, y_pred, y_true):
        """
        Calculate differentiable F1 macro loss.

        Args:
            y_pred: Raw scores (logits) of shape (batch_size, num_classes)
            y_true: One-hot encoded target of the same shape; float, integer
                and bool tensors are accepted

        Returns:
            f1_macro_loss: ``1 - mean(per-class soft F1)``
        """
        # Assert tensors are of the same shape
        assert y_pred.shape == y_true.shape, "Predictions and targets must have the same shape"

        # Convert logits to probabilities; they stay soft (no thresholding)
        # so the loss remains differentiable.
        probs = torch.sigmoid(y_pred) if y_pred.shape[1] == 1 else torch.softmax(y_pred, dim=1)

        # ``1 - y_true`` is not defined for bool tensors, so bring
        # non-floating targets to the probability dtype first.
        if not y_true.is_floating_point():
            y_true = y_true.to(probs.dtype)

        # Soft true positives, false positives and false negatives per class
        tp = torch.sum(y_true * probs, dim=0)
        fp = torch.sum((1 - y_true) * probs, dim=0)
        fn = torch.sum(y_true * (1 - probs), dim=0)

        # Precision and recall per class
        precision = tp / (tp + fp + self.epsilon)
        recall = tp / (tp + fn + self.epsilon)

        # F1 score per class
        f1_per_class = 2 * (precision * recall) / (precision + recall + self.epsilon)

        # Macro average over classes
        f1_macro = torch.mean(f1_per_class)

        # Return loss (1 - F1 score)
        return 1 - f1_macro

    def __repr__(self):
        return f"F1MacroLoss(epsilon={self.epsilon})"