In [13]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from skimage import io
import numpy as np
from timm.layers import BlurPool2d
from torch.nn.utils import weight_norm

###############################################
# 1. Define the SEM Dataset
###############################################
class SEMDataset(Dataset):
    def __init__(self, images_dir, labels_csv, transform=None):
        """
        Args:
            images_dir (str): Directory containing the .pt image files.
            labels_csv (str): Path to CSV file with columns "filename" and "snr_db".
            transform (callable, optional): Optional transform to be applied on an image tensor.
        """
        self.images_dir = images_dir
        self.labels_df = pd.read_csv(labels_csv)
        self.transform = transform

    def __len__(self):
        return len(self.labels_df)

    def __getitem__(self, idx):
        row = self.labels_df.iloc[idx]
        filename = row['filename']      # Name of the .pt file
        label_db = row['snr_db']          # SNR value in dB
        image_path = os.path.join(self.images_dir, filename)
        # Use weights_only=True for safe loading.
        image_tensor = torch.load(image_path, weights_only=True)  # Expected shape: (1, H, W)
        if self.transform:
            image_tensor = self.transform(image_tensor)
        return image_tensor, label_db

class EvoNormB0(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(EvoNormB0, self).__init__()
        # Learnable scaling and shifting parameters
        self.gamma = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.beta  = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        # Learnable parameter v that modulates the input (for nonlinearity)
        self.v     = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.eps   = eps

    def forward(self, x):
        # x shape: (B, C, H, W)
        # Compute instance variance (over spatial dimensions)
        inst_var = x.var(dim=[2, 3], keepdim=True, unbiased=False)
        # Compute batch variance (over batch and spatial dimensions)
        batch_var = x.var(dim=[0, 2, 3], keepdim=True, unbiased=False)
        # Compute the denominator using a max between the batch std and a modulated version of x plus instance std
        denom = torch.max(torch.sqrt(batch_var + self.eps), self.v * x + torch.sqrt(inst_var + self.eps))
        return (x / denom) * self.gamma + self.beta

class ECALayer(nn.Module):
    def __init__(self, channels, gamma=2, b=1):
        super().__init__()
        t = int(abs((np.log2(channels) + b) / gamma))
        k = t if t % 2 else t + 1
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=k//2, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv(y.squeeze(-1).transpose(-1, -2))
        y = y.transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)


###############################################
# 2. Define the CNN Model that Fuses Statistical Features
###############################################
class CNNCombinedFeatureModel(nn.Module):
    def __init__(self):
        super(CNNCombinedFeatureModel, self).__init__()
        # Define a simple CNN architecture.
        # The network will receive a 3-channel input:
        #   Channel 1: Raw image (normalized to [0,1])
        #   Channel 2: Local variance map
        #   Channel 3: Log-PSD map
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)  # Halve spatial dimensions.
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Assuming input images are 256x256, after 3 pooling layers: 256/8 = 32.
        self.fc = nn.Sequential(
            nn.Linear(64 * 32 * 32, 128),
            nn.ReLU(),
            nn.Linear(128, 1)  # Predict SNR (in dB)
        )
    
    def forward(self, x):
        # x: (B, 1, H, W) -> Compute statistical maps and stack channels.
        B, C, H, W = x.size()
        # Compute Local Variance Map
        local_var = self.compute_local_variance(x)  # (B, 1, H, W)
        # Compute PSD Map (in log scale)
        psd_map = self.compute_psd_map(x)           # (B, 1, H, W)
        # Stack the raw image, local variance, and PSD map along the channel dimension.
        x_combined = torch.cat([x, local_var, psd_map], dim=1)  # (B, 3, H, W)
        
        # Pass the combined 3-channel input through the CNN
        out = self.conv1(x_combined)
        out = self.conv2(out)
        out = self.conv3(out)
        out = out.view(B, -1)
        out = self.fc(out)
        return out
    
    def compute_local_variance(self, x):
        # Compute the local variance using a sliding window (kernel_size=7)
        kernel_size = 7
        pad = kernel_size // 2
        # Compute local mean with average pooling
        mean = F.avg_pool2d(x, kernel_size=kernel_size, stride=1, padding=pad)
        mean_sq = F.avg_pool2d(x * x, kernel_size=kernel_size, stride=1, padding=pad)
        variance = mean_sq - mean * mean
        return variance
    
    def compute_psd_map(self, x, eps=1e-8, fft_shift=True, normalize=True):
        """
        Compute the log power spectral density (PSD) map from an image.
    
        Args:
            x (Tensor): Input tensor with shape (B, C, H, W).
            eps (float): Small constant for numerical stability in the log transform.
            fft_shift (bool): If True, applies FFT shift to center the zero-frequency component.
            normalize (bool): If True, normalizes the log PSD map per image to zero mean and unit variance.
    
        Returns:
            Tensor: Log PSD map with shape (B, 1, H, W).
        """
        B, C, H, W = x.shape
    
        # Convert to grayscale if input has multiple channels (assumes RGB)
        if C > 1:
            x = 0.299 * x[:, 0:1] + 0.587 * x[:, 1:2] + 0.114 * x[:, 2:3]
    
        # Compute the 2D FFT on the last two dimensions
        fft = torch.fft.fft2(x)
        if fft_shift:
            fft = torch.fft.fftshift(fft, dim=(-2, -1))
    
        # Compute the Power Spectral Density (PSD)
        psd = torch.abs(fft) ** 2
    
        # Apply logarithmic transformation for numerical stability
        psd_log = torch.log(psd + eps)
    
        # Optionally, normalize the PSD map per image
        if normalize:
            mean = psd_log.mean(dim=(-2, -1), keepdim=True)
            std = psd_log.std(dim=(-2, -1), keepdim=True) + eps
            psd_log = (psd_log - mean) / std
    
        # Ensure the output tensor has a channel dimension: (B, 1, H, W)
        if psd_log.dim() == 3:
            psd_log = psd_log.unsqueeze(1)
    
        return psd_log


###############################################
# 3. Create Dataset and DataLoader
###############################################
images_dir = r'C:\Users\lewka\deep_learning\SEM Deep Learning Multiclass Noise Level Classification with Data Augmentation (IEEE)\SEM_images\Cropped\Biofilm SEM Dataset\Noisy'
labels_csv = r"C:\Users\lewka\deep_learning\SEM Deep Learning Multiclass Noise Level Classification with Data Augmentation (IEEE)\SEM_images\Cropped\Biofilm SEM Dataset\Label\labels.csv"

dataset = SEMDataset(images_dir, labels_csv)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

###############################################
# 4. Instantiate Model, Loss, and Optimizer
###############################################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNNCombinedFeatureModel().to(device)
criterion = nn.MSELoss()  # Mean Squared Error Loss for regression
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

###############################################
# 5. Training Loop
###############################################
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for images, labels in dataloader:
        images = images.to(device)
        # Ensure images have shape (B, 1, H, W); if not, add a channel dimension.
        if images.ndim == 3:
            images = images.unsqueeze(1)
        # Convert labels to tensor.
        labels = torch.as_tensor(labels, dtype=torch.float32, device=device).unsqueeze(1)
        
        optimizer.zero_grad()
        predicted_snr = model(images)
        loss = criterion(predicted_snr, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() * images.size(0)
    
    epoch_loss /= len(dataset)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.6f}")

###############################################
# 6. Evaluate on a Batch After Training
###############################################
model.eval()
with torch.no_grad():
    for images, labels in dataloader:
        images = images.to(device)
        if images.ndim == 3:
            images = images.unsqueeze(1)
        predicted_snr = model(images)
        print("Predicted SNR values (dB):", predicted_snr.squeeze().detach().cpu().numpy())
        print("Ground truth SNR (dB):", labels)
        break


Epoch 1/100, Loss: 113.748747
Epoch 2/100, Loss: 27.313970
Epoch 3/100, Loss: 17.735036
Epoch 4/100, Loss: 15.739919
Epoch 5/100, Loss: 15.790888
Epoch 6/100, Loss: 13.105866
Epoch 7/100, Loss: 15.149063
Epoch 8/100, Loss: 14.091651



KeyboardInterrupt



In [2]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from skimage import io
import numpy as np
from timm.layers import BlurPool2d
from torch.nn.utils import weight_norm
from tqdm import tqdm
###############################################
# 1. Define the SEM Dataset
###############################################
class SEMDataset(Dataset):
    def __init__(self, images_dir, labels_csv, transform=None):
        """
        Args:
            images_dir (str): Directory containing the .pt image files.
            labels_csv (str): Path to CSV file with columns "filename" and "snr_db".
            transform (callable, optional): Optional transform to be applied on an image tensor.
        """
        self.images_dir = images_dir
        self.labels_df = pd.read_csv(labels_csv)
        self.transform = transform

    def __len__(self):
        return len(self.labels_df)

    def __getitem__(self, idx):
        row = self.labels_df.iloc[idx]
        filename = row['filename']      # Name of the .pt file
        label_db = row['snr_db']          # SNR value in dB
        image_path = os.path.join(self.images_dir, filename)
        # Use weights_only=True for safe loading.
        image_tensor = torch.load(image_path, weights_only=True)  # Expected shape: (1, H, W)
        if self.transform:
            image_tensor = self.transform(image_tensor)
        return image_tensor, label_db

class EvoNormB0(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super(EvoNormB0, self).__init__()
        # Learnable scaling and shifting parameters
        self.gamma = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.beta  = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        # Learnable parameter v that modulates the input (for nonlinearity)
        self.v     = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.eps   = eps

    def forward(self, x):
        # x shape: (B, C, H, W)
        # Compute instance variance (over spatial dimensions)
        inst_var = x.var(dim=[2, 3], keepdim=True, unbiased=False)
        # Compute batch variance (over batch and spatial dimensions)
        batch_var = x.var(dim=[0, 2, 3], keepdim=True, unbiased=False)
        # Compute the denominator using a max between the batch std and a modulated version of x plus instance std
        denom = torch.max(torch.sqrt(batch_var + self.eps), self.v * x + torch.sqrt(inst_var + self.eps))
        return (x / denom) * self.gamma + self.beta

class ECALayer(nn.Module):
    def __init__(self, channels, gamma=2, b=1):
        super().__init__()
        t = int(abs((np.log2(channels) + b) / gamma))
        k = t if t % 2 else t + 1
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=k//2, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv(y.squeeze(-1).transpose(-1, -2))
        y = y.transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)


class PSDGuidedFusion(nn.Module):
    def __init__(self, channels=16):
        super().__init__()
        self.psd_proj = nn.Conv2d(channels, channels, kernel_size=1)
        self.var_proj = nn.Conv2d(channels, channels, kernel_size=1)
        self.raw_proj = nn.Conv2d(channels, channels, kernel_size=1)
        
        # Learn frequency-dependent fusion weights
        self.freq_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // 4, 1),
            nn.GELU(),
            nn.Conv2d(channels // 4, 3, 1),
            nn.Softmax(dim=1)
        )

    def forward(self, raw, var, psd):
        # Project to common space
        p_psd = self.psd_proj(psd)
        p_var = self.var_proj(var)
        p_raw = self.raw_proj(raw)
        
        # Frequency-adaptive weighting
        att_weights = self.freq_att(p_psd)  # Expected shape: [B, 3, 1, 1]
        att0 = att_weights[:, 0:1, :, :]  # Shape: [B, 1, 1, 1]
        att1 = att_weights[:, 1:2, :, :]
        att2 = att_weights[:, 2:3, :, :]
        
        fused = att0 * p_raw + att1 * p_var + att2 * p_psd
        
        return fused


###############################################
# 2. Define the CNN Model that Fuses Statistical Features
###############################################
class CNNFusionModel(nn.Module):
    def __init__(self):
        super(CNNFusionModel, self).__init__()
        fusion_channels = 16

        # Projection layers to map 1-channel maps to fusion_channels
        self.raw_proj_conv = nn.Conv2d(1, fusion_channels, kernel_size=3, padding=1)
        self.var_proj_conv = nn.Conv2d(1, fusion_channels, kernel_size=3, padding=1)
        self.psd_proj_conv = nn.Conv2d(1, fusion_channels, kernel_size=3, padding=1)

        # PSD-guided fusion module
        self.psd_guided_fusion = PSDGuidedFusion(channels=fusion_channels)

        # CNN backbone: Note that the input now has fusion_channels instead of 3
        self.conv1 = nn.Sequential(
            nn.Conv2d(fusion_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)  # Halve spatial dimensions.
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # Assuming input images are 256x256, after 3 pooling layers: 256/8 = 32.
        self.fc = nn.Sequential(
            nn.Linear(64 * 32 * 32, 128),
            nn.ReLU(),
            nn.Linear(128, 1)  # Predict SNR (in dB)
        )

    def forward(self, x):
        # x: (B, 1, H, W) – raw image
        B, C, H, W = x.size()

        # Compute Local Variance Map (still single channel)
        local_var = self.compute_local_variance(x)  # (B, 1, H, W)

        # Compute PSD Map (log scale, single channel)
        psd_map = self.compute_psd_map(x)           # (B, 1, H, W)

        # Project each map to fusion_channels
        raw_16 = self.raw_proj_conv(x)         # (B, fusion_channels, H, W)
        var_16 = self.var_proj_conv(local_var)   # (B, fusion_channels, H, W)
        psd_16 = self.psd_proj_conv(psd_map)     # (B, fusion_channels, H, W)

        # Fuse the features using PSDGuidedFusion
        fused = self.psd_guided_fusion(raw_16, var_16, psd_16)  # (B, fusion_channels, H, W)

        # Pass the fused representation through the CNN backbone
        out = self.conv1(fused)
        out = self.conv2(out)
        out = self.conv3(out)
        out = out.view(B, -1)
        out = self.fc(out)
        return out

    def compute_local_variance(self, x):
        # Compute the local variance using a sliding window (kernel_size=7)
        kernel_size = 7
        pad = kernel_size // 2
        mean = F.avg_pool2d(x, kernel_size=kernel_size, stride=1, padding=pad)
        mean_sq = F.avg_pool2d(x * x, kernel_size=kernel_size, stride=1, padding=pad)
        variance = mean_sq - mean * mean
        return variance

    def compute_psd_map(self, x, eps=1e-8, fft_shift=True, normalize=True):
        """
        Compute the log power spectral density (PSD) map from an image.
    
        Args:
            x (Tensor): Input tensor with shape (B, C, H, W).
            eps (float): Small constant for numerical stability in the log transform.
            fft_shift (bool): If True, applies FFT shift to center the zero-frequency component.
            normalize (bool): If True, normalizes the log PSD map per image to zero mean and unit variance.
    
        Returns:
            Tensor: Log PSD map with shape (B, 1, H, W).
        """
        B, C, H, W = x.shape

        # Convert to grayscale if input has multiple channels (assumes RGB)
        if C > 1:
            x = 0.299 * x[:, 0:1] + 0.587 * x[:, 1:2] + 0.114 * x[:, 2:3]

        fft = torch.fft.fft2(x)
        if fft_shift:
            fft = torch.fft.fftshift(fft, dim=(-2, -1))
        psd = torch.abs(fft) ** 2
        psd_log = torch.log(psd + eps)

        if normalize:
            mean = psd_log.mean(dim=(-2, -1), keepdim=True)
            std = psd_log.std(dim=(-2, -1), keepdim=True) + eps
            psd_log = (psd_log - mean) / std

        if psd_log.dim() == 3:
            psd_log = psd_log.unsqueeze(1)

        return psd_log



###############################################
# 3. Create Dataset and DataLoader
###############################################
images_dir = r'C:\Users\lewka\deep_learning\SEM Deep Learning Multiclass Noise Level Classification with Data Augmentation (IEEE)\SEM_images\Cropped\Biofilm SEM Dataset\Noisy'
labels_csv = r"C:\Users\lewka\deep_learning\SEM Deep Learning Multiclass Noise Level Classification with Data Augmentation (IEEE)\SEM_images\Cropped\Biofilm SEM Dataset\Label\labels.csv"

dataset = SEMDataset(images_dir, labels_csv)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

###############################################
# 4. Instantiate Model, Loss, and Optimizer
###############################################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNNFusionModel().to(device)
criterion = nn.MSELoss()  # Mean Squared Error Loss for regression
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

###############################################
# 5. Training Loop
###############################################


num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    # Wrap the dataloader with tqdm for a progress bar per epoch.
    for images, labels in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        images = images.to(device)
        # Ensure images have shape (B, 1, H, W); if not, add a channel dimension.
        if images.ndim == 3:
            images = images.unsqueeze(1)
        # Convert labels to tensor.
        labels = torch.as_tensor(labels, dtype=torch.float32, device=device).unsqueeze(1)
        
        optimizer.zero_grad()
        predicted_snr = model(images)
        loss = criterion(predicted_snr, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item() * images.size(0)
    
    epoch_loss /= len(dataset)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.6f}")


###############################################
# 6. Evaluate on a Batch After Training
###############################################
model.eval()
with torch.no_grad():
    for images, labels in dataloader:
        images = images.to(device)
        if images.ndim == 3:
            images = images.unsqueeze(1)
        predicted_snr = model(images)
        print("Predicted SNR values (dB):", predicted_snr.squeeze().detach().cpu().numpy())
        print("Ground truth SNR (dB):", labels)
        break


Epoch 1/100: 100%|█████████████████████████████████████████████████████████████████████| 26/26 [00:02<00:00,  9.99it/s]


Epoch 1/100, Loss: 157.287478


Epoch 2/100: 100%|█████████████████████████████████████████████████████████████████████| 26/26 [00:02<00:00, 10.44it/s]


Epoch 2/100, Loss: 80.545043


Epoch 3/100: 100%|█████████████████████████████████████████████████████████████████████| 26/26 [00:02<00:00, 10.01it/s]


Epoch 3/100, Loss: 70.174612


Epoch 4/100: 100%|█████████████████████████████████████████████████████████████████████| 26/26 [00:03<00:00,  8.24it/s]


Epoch 4/100, Loss: 32.578082


Epoch 5/100: 100%|█████████████████████████████████████████████████████████████████████| 26/26 [00:03<00:00,  8.11it/s]


Epoch 5/100, Loss: 29.684199


Epoch 6/100: 100%|█████████████████████████████████████████████████████████████████████| 26/26 [00:03<00:00,  7.58it/s]


Epoch 6/100, Loss: 14.425571


Epoch 7/100: 100%|█████████████████████████████████████████████████████████████████████| 26/26 [00:03<00:00,  7.88it/s]


Epoch 7/100, Loss: 14.810117


Epoch 8/100: 100%|█████████████████████████████████████████████████████████████████████| 26/26 [00:03<00:00,  7.77it/s]


Epoch 8/100, Loss: 14.322982


Epoch 9/100: 100%|█████████████████████████████████████████████████████████████████████| 26/26 [00:03<00:00,  6.67it/s]


Epoch 9/100, Loss: 13.167660


Epoch 10/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:06<00:00,  4.13it/s]


Epoch 10/100, Loss: 17.869365


Epoch 11/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  2.90it/s]


Epoch 11/100, Loss: 13.019029


Epoch 12/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.01it/s]


Epoch 12/100, Loss: 12.971770


Epoch 13/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.43it/s]


Epoch 13/100, Loss: 14.345778


Epoch 14/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.09it/s]


Epoch 14/100, Loss: 12.197217


Epoch 15/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.30it/s]


Epoch 15/100, Loss: 12.775674


Epoch 16/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.35it/s]


Epoch 16/100, Loss: 14.685479


Epoch 17/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.29it/s]


Epoch 17/100, Loss: 12.450492


Epoch 18/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.28it/s]


Epoch 18/100, Loss: 14.224662


Epoch 19/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.07it/s]


Epoch 19/100, Loss: 14.037584


Epoch 20/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.02it/s]


Epoch 20/100, Loss: 12.704831


Epoch 21/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.09it/s]


Epoch 21/100, Loss: 12.204770


Epoch 22/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.34it/s]


Epoch 22/100, Loss: 10.537557


Epoch 23/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.23it/s]


Epoch 23/100, Loss: 10.299395


Epoch 24/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.42it/s]


Epoch 24/100, Loss: 9.911898


Epoch 25/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.44it/s]


Epoch 25/100, Loss: 13.385846


Epoch 26/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.30it/s]


Epoch 26/100, Loss: 13.157095


Epoch 27/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.10it/s]


Epoch 27/100, Loss: 9.422782


Epoch 28/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.38it/s]


Epoch 28/100, Loss: 9.131043


Epoch 29/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.39it/s]


Epoch 29/100, Loss: 8.718039


Epoch 30/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.25it/s]


Epoch 30/100, Loss: 6.702715


Epoch 31/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.37it/s]


Epoch 31/100, Loss: 7.309931


Epoch 32/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:09<00:00,  2.66it/s]


Epoch 32/100, Loss: 10.393392


Epoch 33/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.01it/s]


Epoch 33/100, Loss: 6.544455


Epoch 34/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.47it/s]


Epoch 34/100, Loss: 4.919628


Epoch 35/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.16it/s]


Epoch 35/100, Loss: 3.706364


Epoch 36/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.35it/s]


Epoch 36/100, Loss: 3.308496


Epoch 37/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.35it/s]


Epoch 37/100, Loss: 3.664414


Epoch 38/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.30it/s]


Epoch 38/100, Loss: 2.286100


Epoch 39/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.12it/s]


Epoch 39/100, Loss: 2.024236


Epoch 40/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.17it/s]


Epoch 40/100, Loss: 1.897711


Epoch 41/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:09<00:00,  2.74it/s]


Epoch 41/100, Loss: 1.527948


Epoch 42/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  2.94it/s]


Epoch 42/100, Loss: 1.123184


Epoch 43/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.35it/s]


Epoch 43/100, Loss: 1.187589


Epoch 44/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.29it/s]


Epoch 44/100, Loss: 0.820733


Epoch 45/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.45it/s]


Epoch 45/100, Loss: 0.447936


Epoch 46/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.44it/s]


Epoch 46/100, Loss: 0.409652


Epoch 47/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.42it/s]


Epoch 47/100, Loss: 0.305962


Epoch 48/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.45it/s]


Epoch 48/100, Loss: 0.585472


Epoch 49/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.42it/s]


Epoch 49/100, Loss: 0.417594


Epoch 50/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.32it/s]


Epoch 50/100, Loss: 0.294468


Epoch 51/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.40it/s]


Epoch 51/100, Loss: 0.172899


Epoch 52/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.44it/s]


Epoch 52/100, Loss: 0.122978


Epoch 53/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.43it/s]


Epoch 53/100, Loss: 0.105996


Epoch 54/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.42it/s]


Epoch 54/100, Loss: 0.059112


Epoch 55/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.15it/s]


Epoch 55/100, Loss: 0.042146


Epoch 56/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.41it/s]


Epoch 56/100, Loss: 0.100746


Epoch 57/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.43it/s]


Epoch 57/100, Loss: 0.074728


Epoch 58/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.44it/s]


Epoch 58/100, Loss: 0.078287


Epoch 59/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.41it/s]


Epoch 59/100, Loss: 0.034688


Epoch 60/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.44it/s]


Epoch 60/100, Loss: 0.020885


Epoch 61/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.46it/s]


Epoch 61/100, Loss: 0.015694


Epoch 62/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.39it/s]


Epoch 62/100, Loss: 0.014340


Epoch 63/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.45it/s]


Epoch 63/100, Loss: 0.015657


Epoch 64/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.45it/s]


Epoch 64/100, Loss: 0.015600


Epoch 65/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.47it/s]


Epoch 65/100, Loss: 0.014171


Epoch 66/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.45it/s]


Epoch 66/100, Loss: 0.017727


Epoch 67/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.39it/s]


Epoch 67/100, Loss: 0.024092


Epoch 68/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.45it/s]


Epoch 68/100, Loss: 0.021865


Epoch 69/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.46it/s]


Epoch 69/100, Loss: 0.016207


Epoch 70/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.46it/s]


Epoch 70/100, Loss: 0.027326


Epoch 71/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.42it/s]


Epoch 71/100, Loss: 0.026592


Epoch 72/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.39it/s]


Epoch 72/100, Loss: 0.033046


Epoch 73/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.45it/s]


Epoch 73/100, Loss: 0.025943


Epoch 74/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.46it/s]


Epoch 74/100, Loss: 0.020393


Epoch 75/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.46it/s]


Epoch 75/100, Loss: 0.017667


Epoch 76/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.43it/s]


Epoch 76/100, Loss: 0.008457


Epoch 77/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.45it/s]


Epoch 77/100, Loss: 0.035549


Epoch 78/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.46it/s]


Epoch 78/100, Loss: 0.028921


Epoch 79/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.46it/s]


Epoch 79/100, Loss: 0.016426


Epoch 80/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.38it/s]


Epoch 80/100, Loss: 0.010184


Epoch 81/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.37it/s]


Epoch 81/100, Loss: 0.005167


Epoch 82/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.42it/s]


Epoch 82/100, Loss: 0.007553


Epoch 83/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.42it/s]


Epoch 83/100, Loss: 0.009807


Epoch 84/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.43it/s]


Epoch 84/100, Loss: 0.038307


Epoch 85/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.33it/s]


Epoch 85/100, Loss: 0.024023


Epoch 86/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.38it/s]


Epoch 86/100, Loss: 0.017488


Epoch 87/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.43it/s]


Epoch 87/100, Loss: 0.011615


Epoch 88/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.43it/s]


Epoch 88/100, Loss: 0.022332


Epoch 89/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.42it/s]


Epoch 89/100, Loss: 0.029619


Epoch 90/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.26it/s]


Epoch 90/100, Loss: 0.057113


Epoch 91/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.11it/s]


Epoch 91/100, Loss: 0.045445


Epoch 92/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  2.99it/s]


Epoch 92/100, Loss: 0.020144


Epoch 93/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.10it/s]


Epoch 93/100, Loss: 0.010668


Epoch 94/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.23it/s]


Epoch 94/100, Loss: 0.013918


Epoch 95/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.17it/s]


Epoch 95/100, Loss: 0.012354


Epoch 96/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.07it/s]


Epoch 96/100, Loss: 0.022717


Epoch 97/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  2.91it/s]


Epoch 97/100, Loss: 0.016008


Epoch 98/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  3.14it/s]


Epoch 98/100, Loss: 0.089459


Epoch 99/100: 100%|████████████████████████████████████████████████████████████████████| 26/26 [00:08<00:00,  2.91it/s]


Epoch 99/100, Loss: 0.107685


Epoch 100/100: 100%|███████████████████████████████████████████████████████████████████| 26/26 [00:07<00:00,  3.43it/s]

Epoch 100/100, Loss: 0.261852
Predicted SNR values (dB): [ 9.209839  9.483712 29.783842 19.707016 14.166658 14.600001 29.96935
  9.494891 29.532959 25.529493 14.627373 24.756805 24.851286 19.67107
 19.087925 25.086676]
Ground truth SNR (dB): tensor([10, 10, 30, 20, 15, 15, 30, 10, 30, 25, 15, 25, 25, 20, 20, 25])



