In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from pathlib import Path

from tqdm import tqdm
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import torch.nn.functional as F
import torch.ao.quantization as tq

from torchinfo import summary
#print(torch.__version__)

# The device is automatically set to GPU if available, otherwise CPU
# device = torch.device("cpu")

# It is important that your model and all data are on the same device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### 1- Load data

In [2]:
def get_data(**kwargs):
    """
    Get the training and test data. The data files are assumed to be in the
    same directory as this script.

    Args:
    - kwargs: Additional arguments that you might find useful - not necessary

    Returns:
    - train_data_input: Tensor[N_train_samples, C, H, W]
    - train_data_label: Tensor[N_train_samples, C, H, W]
    - test_data_input: Tensor[N_test_samples, C, H, W]
    where N_train_samples is the number of training samples, N_test_samples is
    the number of test samples, C is the number of channels (1 for grayscale),
    H is the height of the image, and W is the width of the image.
    """
    # Load the training data
    train_data_input = np.load("dataset5/all_train/input/train_input_all.npy")
    print('train input shape: ', train_data_input.shape)
    train_data_label = np.load("dataset5/all_train/output/train_label_all.npy")
    # train_data_input = np.load("dataset5/training/input/train_input_3.npy")
    # print('train input shape: ', train_data_input.shape)
    # train_data_label = np.load("dataset5/training/output/train_label_3.npy")
    
    # Make the training data a tensor
    train_data_input = torch.tensor(train_data_input, dtype=torch.float32)
    train_data_label = torch.tensor(train_data_label, dtype=torch.float32)

    # Load the test data
    test_data_input = np.load("dataset5/all_test/input/test_input_all.npy")
    print('test input shape: ', test_data_input.shape)
    test_data_label = np.load("dataset5/all_test/output/test_label_all.npy")
    # test_data_input = np.load("dataset5/testing/input/test_input_3.npy")
    # print('test input shape: ', test_data_input.shape)
    # test_data_label = np.load("dataset5/testing/output/test_label_3.npy")
    
    # Make the test data a tensor
    test_data_input = torch.tensor(test_data_input, dtype=torch.float32)
    test_data_label = torch.tensor(test_data_label, dtype=torch.float32)

    # train_data = train_data / 255  # Normalize to [0,1]
    # test_data_input = test_data_input / 255  

    
    #print(test_data_input[1,:,:,:])

    # Visualize the training data if needed
    # Set to False if you don't want to save the images
    if True:
        # Create the output directory if it doesn't exist
        if not Path("dataset5/train_image_output").exists():
            Path("dataset5/train_image_output").mkdir()
        for i in tqdm(range(10), desc="Plotting train images"):
            # Show the training and the target image side by side
            plt.subplot(1, 2, 1)
            plt.imshow(train_data_input[i].squeeze(), cmap="gray")
            plt.title("Training Input")
            plt.subplot(1, 2, 2)
            plt.title("Training Label")
            plt.imshow(train_data_label[i].squeeze(), cmap="gray")

            plt.savefig(f"dataset5/train_image_output/image_{i}.png")
            plt.close()

    return train_data_input, train_data_label, test_data_input, test_data_label

# train_data_input, train_data_label, test_data_input, test_data_label = get_data()

### 2- Define model

In [3]:
class DoubleConv(nn.Module):
    """ [Conv2d => ReLU] x2 """
    def __init__(self, in_ch, out_ch, k_size=3):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=k_size, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=k_size, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

# (1, 128, 1024) -> (1, 64, 64)   
class Unet5(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__() # Initialize the parent class
        self.target_height = 128
        self.target_width = 128

        # Encoder
        self.enc1 = DoubleConv(in_channels, 32)
        self.enc2 = DoubleConv(32, 64)
        self.enc3 = DoubleConv(64, 128)
        self.enc4 = DoubleConv(128, 256)
        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=(1,4), stride=(1,4), padding=0),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )

        # Skip connections
        self.skip2  = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=(1,4), stride=(1,4), padding=0),
        )
        self.skip3  = nn.Sequential(    # because stride 8 not supported by DPU
            nn.Conv2d(128, 128, kernel_size=(1,4), stride=(1,4), padding=0),
        )
        self.skip4  = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=(1,4), stride=(1,4), padding=0),
        )

        # Decoder
        self.up4 = nn.ConvTranspose2d(256, 256, kernel_size=2, stride=2)
        self.dec4 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.up3 = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)
        self.dec3 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.up2 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        self.dec2 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        self.dec1 = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=3, padding=1),
        )
        # Output layer
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e1p = self.pool(e1)
        e2 = self.enc2(e1p) 
        e2p = self.pool(e2) 
        e3 = self.enc3(e2p)
        e3p = self.pool(e3)
        e4 = self.enc4(e3p)
        e4p = self.pool(e4)
        # Bottleneck
        b = self.bottleneck(e4p) # 256, 16, 64 -> 256, 16, 16
        # Skip connections with convolution for resizing
        s4 = self.skip4(e4)  # 256, 32, 128 -> 256, 32, 32 
        s3 = self.skip3(e3)
        s2 = self.skip2(e2)
        # Decoder
        d4 = self.up4(b) # 256, 16, 16 -> 256, 32, 32
        d4c = torch.cat([s4, d4], dim=1)  # Concatenate along the channel dimension
        d4d = self.dec4(d4c)  # 512, 32, 32 -> 128, 32, 32

        d3 = self.up3(d4d)
        d3c = torch.cat([s3, d3], dim=1)
        d3d = self.dec3(d3c)  
        
        d2 = self.up2(d3d)
        d2c = torch.cat([s2, d2], dim=1)  # Concatenate along the channel dimension
        d2d = self.dec2(d2c)  # 128, 128, 128 -> 32, 128, 128

        d1d = self.dec1(d2d)  # 32, 128, 128 -> 1, 128, 128
        
        return d1d

In [None]:
# dummy = Unet5()
# x = torch.randn(1, 1, 256, 1024)  # Example input (batch, channel, H, W)
# y = dummy(x)
# print(y.shape)  # Output segmentation mask shape (should be close to input dimensions)


torch.Size([1, 1, 128, 128])


### 3- Training function

In [6]:
def train_model(train_data_input, train_data_label, **kwargs):
    """
    Train the model. Fill in the details of the data loader, the loss function,
    the optimizer, and the training loop.

    Args:
    - train_data_input: Tensor[N_train_samples, C, H, W]
    - train_data_label: Tensor[N_train_samples, C, H, W]
    - kwargs: Additional arguments that you might find useful - not necessary

    Returns:
    - model: torch.nn.Module
    """
    model = Unet5()
    model.train()
    model.to(device)

    criterion = nn.MSELoss() 
    optimizer = torch.optim.Adam(model.parameters(), lr=10**-3, weight_decay=1e-5)
    batch_size = 32
    dataset = TensorDataset(train_data_input, train_data_label)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) # shuffle data since neighbour frames are very similar

    # Training loop
    n_epochs = 25
    best_loss = float('inf')
    for epoch in range(n_epochs):
        for x, y in tqdm(
            data_loader, desc=f"Training Epoch {epoch}", leave=False
        ):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad() # Clears old gradients from the previous step. PyTorch accumulates gradient by default
            output = model(x)
            loss = criterion(output, y)
            if loss < best_loss:
                best_loss = loss
                torch.save(model.state_dict(), "best_unet5.pth")

            loss.backward() # Computes the gradients of the loss with respect to the model parameters.
            optimizer.step() # Update model parameters using computed gradient
        torch.save(model.state_dict(), "unet5.pth")
        print(f"Epoch {epoch} loss: {loss.item()}")

    return model

In [None]:
checkpoint = {
    'epoch': current_epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': current_loss,
    # add more keys if needed
}
torch.save(checkpoint, "training_state.pth")


### 4- Testing function

In [None]:
def test_model(model, test_data_input, test_data_label):
    """
    Uses your model to predict the ouputs for the test data.

    Args:
    - model: torch.nn.Module
    - test_data_input: Tensor
    """
    model.eval()
    model.to(device)
    # model evaluated using MSE error
    mse_criterion = nn.MSELoss()
    total_mse = 0.0
    count = 0

    with torch.no_grad():
        test_data_input = test_data_input.to(device)
        test_data_label = test_data_label.to(device)
        # Predict the output batch-wise to avoid memory issues
        test_data_output = []
        # Can increase or decrease this batch size depending on your
        # memory requirements of your computer / model
        # This will not affect the performance of the model
        batch_size = 64
        for i in tqdm(
            range(0, test_data_input.shape[0], batch_size),
            desc="Predicting test output",
        ):
            output = model(test_data_input[i : i + batch_size])
            # Calculate MSE for the batch
            label = test_data_label[i : i + batch_size]
            mse = mse_criterion(output, label)
            total_mse += mse.item() * output.shape[0]
            count += output.shape[0]

            test_data_output.append(output.cpu())
        test_data_output = torch.cat(test_data_output)

    # Calculate the average MSE
    mean_mse = total_mse / count
    print(f"Mean Test MSE: {mean_mse:.4f}")
    # print(f"Mean Test L1 Loss: {mean_mse:.4f}")

    # Save the output
    test_data_output = test_data_output.numpy()
    # Ensure all values are in the range [0, 255]
    # save_data_clipped = np.clip(test_data_output, 0, 255)

    # Set to False if you don't want to save the images
    if True:
        # Create the output directory if it doesn't exist
        if not Path("dataset5/test_image_output").exists():
            Path("dataset5/test_image_output").mkdir()
        for i in tqdm(range(20), desc="Plotting test images"):
            # Show the training and the target image side by side
            plt.subplot(1, 3, 1)
            plt.title("Test Input")
            plt.imshow(test_data_input[i].squeeze().cpu().numpy(), cmap="gray")
            plt.subplot(1, 3, 2)
            plt.imshow(test_data_output[i].squeeze(), cmap="gray")
            plt.title("Test Output")
            plt.subplot(1, 3, 3)
            plt.imshow(test_data_label[i].squeeze().cpu().numpy(), cmap="gray")
            plt.title("Test Label")

            plt.tight_layout()
            plt.savefig(f"dataset5/test_image_output/image_{i}.png")
            plt.close()

### 5- Main: run training and testing

In [8]:
seed = 1
# Reproducibility
torch.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True

# Load the data
train_data_input, train_data_label, test_data_input, test_data_label = get_data()
# Train the model
model = train_model(train_data_input, train_data_label)

train input shape:  (5040, 1, 256, 1024)
test input shape:  (556, 1, 256, 1024)


Plotting train images: 100%|██████████| 10/10 [00:00<00:00, 10.43it/s]
                                                                   

Epoch 0 loss: 0.7852686643600464


                                                                   

Epoch 1 loss: 0.8400561213493347


                                                                   

Epoch 2 loss: 0.7969252467155457


                                                                   

Epoch 3 loss: 0.7687708735466003


                                                                   

Epoch 4 loss: 0.7635105848312378


                                                                   

Epoch 5 loss: 0.6967490911483765


                                                                   

Epoch 6 loss: 0.6684781908988953


                                                                   

Epoch 7 loss: 0.747794508934021


                                                                   

Epoch 8 loss: 0.6047333478927612


                                                                   

Epoch 9 loss: 0.6750996112823486


                                                                    

Epoch 10 loss: 0.601094663143158


                                                                    

Epoch 11 loss: 0.6048842668533325


                                                                    

Epoch 12 loss: 0.6302788257598877


                                                                    

Epoch 13 loss: 0.5760076642036438


                                                                    

Epoch 14 loss: 0.5842279195785522


                                                                    

Epoch 15 loss: 0.5436803102493286


                                                                    

Epoch 16 loss: 0.5867789387702942


                                                                    

Epoch 17 loss: 0.5723007321357727


                                                                    

Epoch 18 loss: 0.5131798982620239


                                                                    

Epoch 19 loss: 0.4732908308506012


                                                                    

Epoch 20 loss: 0.5075504183769226


                                                                    

Epoch 21 loss: 0.5429621934890747


                                                                    

Epoch 22 loss: 0.48356738686561584


                                                                    

Epoch 23 loss: 0.4542495906352997


                                                                    

Epoch 24 loss: 0.46527794003486633




### Result

In [11]:
best_model = Unet5()
best_model.load_state_dict(torch.load('best_unet5-MSE-24ep.pth', weights_only=True))
summary(best_model, input_size=(1, 1, 256, 1024))

Layer (type:depth-idx)                   Output Shape              Param #
Unet5                                    [1, 1, 128, 128]          65
├─DoubleConv: 1-1                        [1, 32, 256, 1024]        --
│    └─Sequential: 2-1                   [1, 32, 256, 1024]        --
│    │    └─Conv2d: 3-1                  [1, 32, 256, 1024]        320
│    │    └─BatchNorm2d: 3-2             [1, 32, 256, 1024]        64
│    │    └─ReLU: 3-3                    [1, 32, 256, 1024]        --
│    │    └─Conv2d: 3-4                  [1, 32, 256, 1024]        9,248
│    │    └─BatchNorm2d: 3-5             [1, 32, 256, 1024]        64
│    │    └─ReLU: 3-6                    [1, 32, 256, 1024]        --
├─MaxPool2d: 1-2                         [1, 32, 128, 512]         --
├─DoubleConv: 1-3                        [1, 64, 128, 512]         --
│    └─Sequential: 2-2                   [1, 64, 128, 512]         --
│    │    └─Conv2d: 3-7                  [1, 64, 128, 512]         18,496
│    │ 

In [None]:
test_model(best_model, test_data_input, test_data_label)

Predicting test output: 100%|██████████| 9/9 [01:31<00:00, 10.18s/it]


Mean Test MSE: 0.5673


Plotting test images: 100%|██████████| 20/20 [00:05<00:00,  3.66it/s]


In [15]:
# Calculate the size of the model on disk
size_mb = os.path.getsize("best_unet5-MSE-24ep.pth") / 1024**2
print(f"Model size on disk: {size_mb:.2f} MB")

Model size on disk: 27.57 MB


### Exporting model to ONNX

In [16]:
best_model = Unet5()
best_model.load_state_dict(torch.load('best_unet5-MSE-24ep.pth', weights_only=True))

dummy_input = torch.randn(1, 1, 256, 1024) 

torch.onnx.export(
    best_model,                       # Model to export
    dummy_input,                 # Example input
    "best_unet5-MSE-24ep.onnx",                # File name for output
    export_params=True,          # Store trained weights
    opset_version=11,            # ONNX opset (11 is widely supported)
    do_constant_folding=True,    # Optimize constant expressions
    input_names=["input"],       # Name of model inputs
    output_names=["output"],     # Name of model outputs
    dynamic_axes={               # Allow variable batch size
        "input": {0: "batch_size"},
        "output": {0: "batch_size"}
    }
)
print("Model exported to ONNX format as model.onnx")


Model exported to ONNX format as model.onnx


### Measuring inference time on laptop

In [23]:
import time

# Load your model (make sure it's in eval mode)
model = Unet5()
model.load_state_dict(torch.load('best_unet5-MSE-24ep.pth', weights_only=True))


# Prepare your test input tensor (match your model's expected input)
input_tensor = torch.randn(1, 1, 256, 1024)  

# Warm-up (optional, ensures any lazy initializations are done)
with torch.no_grad():
    _ = model(input_tensor)

# Start timer
start = time.time()

# Inference
with torch.no_grad():
    output = model(input_tensor)

# End timer
end = time.time()

# Print inference time in seconds
inf_time = end - start
print("Inference time: {:.6f} seconds".format(end - start))
print("Frames Per Second: {:.1f} FPS".format(1/inf_time))


Inference time: 0.167052 seconds
Frames Per Second: 6.0 FPS
