In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor
import numpy as np
from PIL import Image
import onnx

# Custom Dataset
class CustomDataset(Dataset):
    def __init__(self, input_dir, target_dir):
        self.input_files = sorted([os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('_object_noise.tif')], key=lambda x: int(os.path.basename(x).split('_')[0]))
        self.target_files = sorted([os.path.join(target_dir, f) for f in os.listdir(target_dir) if f.endswith('_noise.tif')], key=lambda x: int(os.path.basename(x).split('_')[0]))

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        input_image = ToTensor()(np.array(Image.open(self.input_files[idx]), dtype=np.float32) )
        target_image = ToTensor()(np.array(Image.open(self.target_files[idx]), dtype=np.float32) )
        return input_image, target_image

# ResNet-like Neural Network
class SimpleResNet(nn.Module):
    def __init__(self):
        super(SimpleResNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(64, 1, kernel_size=3, padding=1)

    def forward(self, x):
        residual = x
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        x = self.relu(self.conv5(x))
        x = self.conv6(x)
        return x + residual

# Parameters
input_dir = r"E:\Deeplearning\C11800\High\N2N\object_noise"
target_dir = r"E:\Deeplearning\C11800\High\N2N\noise"
batch_size = 32
epochs = 40
learning_rate = 1e-3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset and DataLoader
dataset = CustomDataset(input_dir, target_dir)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# Model, Loss, Optimizer
model = SimpleResNet().to(device)

criterion = nn.L1Loss()  # MAE
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)

# Training Loop
for epoch in range(epochs):
    model.train()
    epoch_loss = 0

    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    scheduler.step()
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss/len(dataloader):.4f}")

# Save Model as ONNX
dummy_input = torch.randn(1, 1, 256, 256).to(device)
torch.onnx.export(model, dummy_input, "resnet_model.onnx", opset_version=11)


Epoch [1/20], Loss: 30.9067
Epoch [2/20], Loss: 27.2627
Epoch [3/20], Loss: 25.9991
Epoch [4/20], Loss: 25.6579
Epoch [5/20], Loss: 26.0223
Epoch [6/20], Loss: 25.0183
Epoch [7/20], Loss: 24.7995
Epoch [8/20], Loss: 24.6844
Epoch [9/20], Loss: 24.6942
Epoch [10/20], Loss: 24.4605
Epoch [11/20], Loss: 24.3531
Epoch [12/20], Loss: 24.1502
Epoch [13/20], Loss: 24.0782
Epoch [14/20], Loss: 24.0226
Epoch [15/20], Loss: 23.9732
Epoch [16/20], Loss: 23.9302
Epoch [17/20], Loss: 23.8866
Epoch [18/20], Loss: 23.8481
Epoch [19/20], Loss: 23.8115
Epoch [20/20], Loss: 23.7755


In [7]:


# Inference
model.eval()
test_image_path = "path_to_test_image.tif"

try:
    test_image = ToTensor()(np.array(Image.open(test_image_path), dtype=np.uint16) / 65535.0).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(test_image)

    output_image = output.squeeze().cpu().numpy()
    output_image = np.clip(output_image * 65535, 0, 65535).astype(np.uint16)
    Image.fromarray(output_image).save("output_image.tif")
    print("Inference complete. Output saved as 'output_image.tif'.")

except FileNotFoundError:
    print(f"Error: File not found at {test_image_path}")
except Exception as e:
    print(f"An error occurred during inference: {e}")

print("Training and inference complete.")


In [4]:
import onnx
from onnxruntime.transformers.onnx_model import OnnxModel

# Load ONNX model
model_path = "resnet_model.onnx"
pruned_model_path = "resnet_model_pruned.onnx"

model = OnnxModel(onnx.load(model_path))

# Prune unnecessary nodes (example: remove identity nodes)
model.prune()

# Save the pruned model
onnx.save(model.model, pruned_model_path)
print(f"Pruned model saved to {pruned_model_path}")


AttributeError: 'OnnxModel' object has no attribute 'prune'

In [5]:
from onnxruntime.quantization import quantize_dynamic, QuantType

# Input and output model paths
pruned_model_path = "resnet_model_pruned.onnx"
quantized_model_path = "resnet_model_quantized.onnx"

# Perform dynamic quantization
quantize_dynamic(
    model_input=pruned_model_path,
    model_output=quantized_model_path,
    weight_type=QuantType.QInt8  # Quantize weights to INT8
)

print(f"Quantized model saved to {quantized_model_path}")


ValidationError: Unable to open proto file: resnet_model_pruned.onnx. Please check if it is a valid proto. 