# Binom Split Dataset

In [1]:
!pip install numpy torch tifffile



In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

from HistBinomDataset import HistogramBinomDataset

import matplotlib.pyplot as plt

In [3]:
# --- Training schedule ---
MAX_EPOCHS = 10      # upper bound on training epochs
BATCH_SIZE = 32      # samples per DataLoader batch

# --- Data pipeline ---
CROP_SIZE = 256      # crop size handed to the dataset
NUM_WORKERS = 4      # DataLoader worker processes
BINOM_SPLIT = 0.6    # probability handed to the dataset's binomial split

# --- Paths ---
ROOT_DIR = 'files'   # dataset root directory

In [4]:
# Training set: binomial splitting and augmentations switched on.
train_dataset = HistogramBinomDataset(
    mode='train',
    root_dir=ROOT_DIR,
    crop_size=CROP_SIZE,
    virt_size=10000,               # reported dataset length (samples per epoch)
    binomial_split=True,
    binomial_prob=BINOM_SPLIT,     # probability used for the binomial split
    apply_augmentations=True,      # turn augmentations on
)

# Test set: no binomial split and no augmentations.
test_dataset = HistogramBinomDataset(
    mode='test',
    root_dir=ROOT_DIR,
    crop_size=CROP_SIZE,
    virt_size=100,                 # reported dataset length (samples per epoch)
    binomial_split=False,
    binomial_prob=BINOM_SPLIT,     # still passed, although the split is disabled here
    apply_augmentations=False,     # augmentations are off for evaluation
)

In [5]:
# Training loader: reshuffled every epoch; the trailing partial batch is
# dropped so every training batch holds exactly BATCH_SIZE samples.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    shuffle=True,
    drop_last=True,
)

# Test loader: fixed order, and the trailing partial batch is kept.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    shuffle=False,
    drop_last=False,
)

In [6]:
# Sanity check: fetch one test batch, then report tensor shapes and dataset sizes.
img = next(iter(test_loader))
for caption, key in (
    ("Batch shape:", 'histogram'),
    ("Noisy shape:", 'noisy'),
    ("Clean shape:", 'clean'),
):
    print(caption, img[key].shape)
print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")

Batch shape: torch.Size([32, 3, 17, 256, 256])
Noisy shape: torch.Size([32, 3, 256, 256])
Clean shape: torch.Size([32, 3, 256, 256])
Train dataset size: 10000
Test dataset size: 100


# NEURAL NETWORK !!!

In [7]:
# CONVOLUTION
def conv3x3(in_channels, out_channels, stride=1, padding=1, bias=True):
    """Build a 2-D convolution with a fixed 3x3 kernel.

    With the default ``stride=1`` and ``padding=1`` the spatial size of the
    input is preserved.
    """
    layer = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
    )
    return layer

In [8]:
class DownConv(nn.Module):
    """
    Encoder stage: three 3x3 convolutions with a residual add, optionally
    followed by a 2x2 max pool.

    The output of the first convolution (no activation) is the residual
    branch; it is added to the output of the third convolution before the
    final ReLU.
    """
    def __init__(self, in_channels, out_channels, pooling=True):
        super().__init__()
        self.pooling = pooling
        self.conv1 = conv3x3(in_channels, out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.conv3 = conv3x3(out_channels, out_channels)
        # The pooling layer only exists when pooling is requested.
        if pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(out, before_pool)``.

        ``before_pool`` is the stage activation prior to pooling (used as the
        skip connection); ``out`` is the pooled activation, or the very same
        tensor when pooling is disabled.
        """
        residual = self.conv1(x)
        hidden = F.relu(self.conv2(residual))
        features = F.relu(residual + self.conv3(hidden))  # residual add
        if not self.pooling:
            return features, features
        return self.pool(features), features

In [9]:
class UpConv(nn.Module):
    """
    Residual Up Convolution Block.

    Upsamples the decoder feature map by 2 with ConvTranspose2d, merges it
    with the encoder skip activation ('concat' along channels, or
    element-wise 'add'), then applies three 3x3 convolutions with a residual
    connection around the last two.
    """
    def __init__(self, in_channels, out_channels, merge_mode='concat', up_mode='transpose'):
        """
        Args:
            in_channels: channels of the decoder feature map before upsampling
            out_channels: channels after upsampling; the skip activation is
                expected to carry the same number of channels
            merge_mode: 'add' or 'concat'
            up_mode: only 'transpose' is supported

        Raises:
            ValueError: on an unsupported merge_mode or up_mode.
        """
        super().__init__()
        # Explicit exceptions instead of assert: asserts vanish under `python -O`.
        if merge_mode not in ('add', 'concat'):
            raise ValueError("merge_mode must be 'add' or 'concat'")
        if up_mode != 'transpose':
            raise ValueError("only 'transpose' up_mode supported for now")

        self.merge_mode = merge_mode

        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

        # Concatenation doubles the channel count seen by the first convolution.
        merged_channels = out_channels * 2 if merge_mode == 'concat' else out_channels
        self.conv1 = conv3x3(merged_channels, out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.conv3 = conv3x3(out_channels, out_channels)

    def forward(self, from_down, from_up):
        """
        Args:
            from_down: skip activation coming from the encoder
            from_up: feature map coming from the previous (deeper) stage

        Returns:
            The merged and refined feature map with `out_channels` channels
            and the spatial size of `from_down`.
        """
        from_up = self.upconv(from_up)

        # A stride-2 pool floors odd sizes, so doubling can come out one pixel
        # short of the skip activation. Pad (or crop, for negative values) so
        # both maps share the same height and width before merging.
        diff_h = from_down.size(-2) - from_up.size(-2)
        diff_w = from_down.size(-1) - from_up.size(-1)
        if diff_h or diff_w:
            from_up = F.pad(
                from_up,
                [diff_w // 2, diff_w - diff_w // 2,
                 diff_h // 2, diff_h - diff_h // 2],
            )

        if self.merge_mode == 'concat':
            x = torch.cat((from_up, from_down), dim=1)
        else:  # add
            x = from_up + from_down
        x_skip = self.conv1(x)
        x = F.relu(self.conv2(x_skip))
        x = F.relu(self.conv3(x) + x_skip)  # residual add
        return x

In [10]:
class UNet(nn.Module):
    """
    Residual U-Net operating on per-pixel histograms.

    The (B, C, bins, H, W) input is flattened to (B, C * bins, H, W), passed
    through `depth` DownConv stages (all but the last one pool) and
    `depth - 1` UpConv stages that consume the encoder skip activations in
    reverse order, and finally projected by a 1x1 convolution.
    """
    def __init__(self, in_channels=3, n_bins=33, out_mode='mean',
                 merge_mode='concat', depth=4, start_filters=64):
        """
        Args:
            in_channels: input channels (usually 3 for RGB); also the number
                of output channels
            n_bins: number of histogram bins
            out_mode: 'mean' or 'distribution'
            merge_mode: 'add' (residual) or 'concat' (default)
            depth: number of encoder stages (must be >= 1)
            start_filters: number of filters in first conv block

        Raises:
            ValueError: on an invalid merge_mode, out_mode or depth.
        """
        super().__init__()
        # Validate everything up front, before any layer is allocated.
        if merge_mode not in ('add', 'concat'):
            raise ValueError("merge_mode must be 'add' or 'concat'")
        if out_mode not in ('mean', 'distribution'):
            raise ValueError("Invalid out_mode. Use 'mean' or 'distribution'.")
        if depth < 1:
            raise ValueError("depth must be at least 1")

        self.in_channels = in_channels
        self.out_mode = out_mode
        self.n_bins = n_bins
        self.merge_mode = merge_mode
        self.depth = depth
        self.start_filters = start_filters

        # Input channels after flattening histogram dimension
        self.input_channels = in_channels * n_bins

        # Encoder (DownConvs): filters double per stage; the deepest stage
        # does not pool.
        self.down_convs = nn.ModuleList()
        in_ch = self.input_channels
        for i in range(depth):
            out_ch = start_filters * (2 ** i)
            pooling = (i < depth - 1)
            self.down_convs.append(DownConv(in_ch, out_ch, pooling=pooling))
            in_ch = out_ch

        # Decoder (UpConvs): mirrors the encoder from deepest to shallowest.
        self.up_convs = nn.ModuleList()
        for i in reversed(range(depth - 1)):
            in_ch = start_filters * (2 ** (i + 1))
            out_ch = start_filters * (2 ** i)
            self.up_convs.append(UpConv(in_ch, out_ch, merge_mode=merge_mode))

        # Final 1x1 conv: one value per channel ('mean') or one value per
        # channel and bin ('distribution'). Uses in_channels rather than a
        # literal 3 so non-RGB inputs get a matching output.
        if out_mode == 'mean':
            final_channels = in_channels
        else:
            final_channels = in_channels * n_bins
        self.final = nn.Conv2d(start_filters, final_channels, kernel_size=1)

    def forward(self, x):
        """
        Input x: (B, C, bins, H, W) with C * bins == in_channels * n_bins.

        Returns:
            (B, in_channels, H, W) for out_mode='mean', or
            (B, in_channels, n_bins, H, W), softmax-normalised over the bin
            axis, for out_mode='distribution'.

        Raises:
            ValueError: if x is not 5-D or its channel/bin layout does not
                match the configuration of the network.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected a 5-D input (B, C, bins, H, W), got shape {tuple(x.shape)}"
            )
        B, C, bins, H, W = x.shape
        if C * bins != self.input_channels:
            raise ValueError(
                f"input has C * bins = {C * bins} flattened channels, "
                f"but the network was built for {self.input_channels}"
            )
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(B, C * bins, H, W)

        encoder_outs = []

        # Encoder path
        for down in self.down_convs:
            x, before_pool = down(x)
            encoder_outs.append(before_pool)

        # Decoder path: the deepest encoder output is already `x`, so the
        # skips are the remaining outputs, deepest first.
        for up, skip in zip(self.up_convs, reversed(encoder_outs[:-1])):
            x = up(skip, x)

        # Final conv
        out = self.final(x)

        if self.out_mode == 'distribution':
            out = out.reshape(B, self.in_channels, self.n_bins, H, W)
            out = F.softmax(out, dim=2)  # softmax over bins

        return out

In [11]:
# Select the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using... ", device)

Using...  cpu


In [12]:
def compute_mean_from_histogram(hist, normalize=False):
    """Collapse the bin axis of a histogram tensor into a weighted sum.

    Each bin is weighted by its centre, with centres evenly spaced over
    [0, 1] (assumes the bins cover a normalised [0, 1] value range).

    Args:
        hist: tensor of shape (B, C, bins, H, W).
        normalize: when True, divide by the per-pixel sum over the bin axis,
            so histograms holding raw counts also yield a mean in [0, 1].
            Pixels whose bins sum to zero yield 0. The default (False) keeps
            the plain weighted sum, which is the mean only if every
            per-pixel histogram already sums to one.

    Returns:
        Tensor of shape (B, C, H, W).

    Raises:
        ValueError: if `hist` is not 5-dimensional.
    """
    if hist.dim() != 5:
        raise ValueError(
            f"expected hist of shape (B, C, bins, H, W), got {tuple(hist.shape)}"
        )
    bins = hist.shape[2]
    # Allocate the bin centres directly on hist's device (no CPU round trip).
    bin_centers = torch.linspace(0, 1, bins, device=hist.device)
    # Multiply hist by bin centers and sum over bins axis
    mean = torch.sum(hist * bin_centers.view(1, 1, bins, 1, 1), dim=2)
    if normalize:
        total = hist.sum(dim=2).to(mean.dtype)
        # The clamp only guards empty pixels against division by zero.
        mean = mean / total.clamp_min(torch.finfo(mean.dtype).tiny)
    return mean

In [13]:
def train_epoch(model, loader, optimizer, criterion, device=device):
    """Run one training pass over `loader` and return the mean batch loss.

    For every batch the first histogram of the pair stored under
    ``batch['histogram']`` is fed to the model, and the prediction is compared
    against the bin-weighted mean of the second histogram.

    Args:
        model: network mapping an input histogram to a (B, 3, H, W) mean.
        loader: iterable of batches; each batch maps 'histogram' to a pair
            (h1, h2) of histogram tensors.
        optimizer: optimizer updating the parameters of `model`.
        criterion: loss taking (prediction, target), e.g. nn.MSELoss().
        device: device the histograms are moved to before the forward pass.

    Returns:
        The average of the per-batch loss values.

    Raises:
        TypeError: if 'histogram' is not a pair of tensors (for instance when
            the dataset was built without the binomial split).
        ValueError: if the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch in loader:
        hist = batch['histogram']  # tuple (h1, h2)
        # Indexing a plain tensor with [0] / [1] would silently pick two
        # *samples* of the batch instead of the two halves of the split.
        if not isinstance(hist, (list, tuple)) or len(hist) != 2:
            raise TypeError(
                "batch['histogram'] must be a pair (h1, h2); "
                "was the dataset created with binomial_split=True?"
            )

        # Use h1 as input, h2 as target
        input_hist = hist[0].to(device)
        target_hist = hist[1].to(device)

        optimizer.zero_grad()

        output_mean = model(input_hist)                # output shape: (B, 3, H, W)
        target_mean = compute_mean_from_histogram(target_hist)  # shape: (B, 3, H, W)

        loss = criterion(output_mean, target_mean)    # criterion = nn.MSELoss()

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Counting batches avoids needing len(loader) and lets an empty loader
    # fail with a clear message rather than a ZeroDivisionError.
    if num_batches == 0:
        raise ValueError(
            "loader yielded no batches (dataset smaller than batch_size with drop_last=True?)"
        )
    return total_loss / num_batches

In [14]:
# Build the network and move it to the selected device: the training and
# evaluation code moves every input batch to `device`, so the parameters must
# live there too (otherwise a CUDA run fails with a device mismatch).
model = UNet(
    in_channels=3,
    n_bins=17,           # matches the 17 bins of the histogram batches printed above
    out_mode='mean',     # or 'distribution'
    merge_mode='concat', # or 'add'
    depth=4,
    start_filters=64
).to(device)
print(model)

UNet(
  (down_convs): ModuleList(
    (0): DownConv(
      (conv1): Conv2d(51, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): DownConv(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (2): DownConv(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1,

In [15]:
# Optimizer & loss
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

In [2]:
num_epochs = MAX_EPOCHS
# NOTE: no validation pass exists in this notebook, so "best" is judged on the
# training loss; the variable keeps its original name.
best_val_loss = float('inf')

for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_loader, optimizer, criterion)

    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f}")

    # Checkpoint only when the loss improves, so "best_unet.pth" really holds
    # the best weights seen so far rather than whatever the last epoch produced.
    if train_loss < best_val_loss:
        best_val_loss = train_loss
        torch.save(model.state_dict(), "best_unet.pth")
        print("Saved best model.")

NameError: name 'train_epoch' is not defined

In [None]:
def test_and_plot(model, test_loader, device=device, num_samples=3):
    """Plot input, prediction and target side by side for a few test samples.

    Two batch layouts are supported for ``batch['histogram']``:

    * a pair (h1, h2): h1 is the model input and the bin-weighted mean of h2
      is the target (same convention as training);
    * a single (B, C, bins, H, W) tensor, which is what the test loader built
      with ``binomial_split=False`` yields: the tensor is the model input
      and ``batch['clean']`` is shown as the target.

    Args:
        model: trained network returning a (B, 3, H, W) mean image.
        test_loader: loader yielding dict batches as described above.
        device: device used for inference.
        num_samples: total number of samples (figures) to plot.
    """
    model.eval()
    shown = 0

    with torch.no_grad():
        for batch in test_loader:
            if shown >= num_samples:
                break

            hist = batch['histogram']
            if isinstance(hist, (list, tuple)):
                input_hist = hist[0].to(device)
                target_mean = compute_mean_from_histogram(hist[1].to(device))
            else:
                # A plain tensor is one histogram per sample; indexing it with
                # [0] / [1] would pick two samples, not an input/target pair.
                input_hist = hist.to(device)
                # NOTE(review): assumes 'clean' is a (B, 3, H, W) image on the
                # same value scale as the histogram mean — confirm in the dataset.
                target_mean = batch['clean'].to(device)

            # Model prediction (mean RGB)
            pred_mean = model(input_hist)  # (B, 3, H, W)
            input_mean = compute_mean_from_histogram(input_hist)

            panels = []
            for title, tensor in (
                ('Input Mean', input_mean),
                ('Predicted Mean', pred_mean),
                ('Target Mean', target_mean),
            ):
                # imshow expects float RGB data in [0, 1]; clamp floats so
                # out-of-range predictions do not trigger clipping warnings.
                if tensor.is_floating_point():
                    tensor = tensor.clamp(0, 1)
                # (B, C, H, W) -> (B, H, W, C) numpy array for plotting
                panels.append((title, tensor.permute(0, 2, 3, 1).cpu().numpy()))

            for b in range(input_hist.shape[0]):
                if shown >= num_samples:
                    break
                fig, axs = plt.subplots(1, 3, figsize=(12, 4))
                for ax, (title, images) in zip(axs, panels):
                    ax.imshow(images[b])
                    ax.set_title(title)
                    ax.axis('off')

                plt.show()
                shown += 1

In [None]:
# Qualitative check of the trained model on the test loader.
test_and_plot(
    model=model,
    test_loader=test_loader,
    device=device,
    num_samples=2,
)