In [1]:
import torch
import torch.nn as nn
import os
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image

# import albumentations as A
# from albumentations.pytorch import ToTensorV2

#Trainer Imports
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim


from matplotlib import pyplot as plt #for visualizing 

#From the local python files
from datasets import DenoisingPairedDataset
from trainer import DenoisingTrainer #What is going awn
# from trainer import perceptual_loss, combined_loss, DenoisingTrainer
import torch.nn.functional as F

# Root directories

In [2]:
# Dataset root; passed together with a split name to DenoisingPairedDataset further down.
root_dir = 'datasets/aquarium-data-cots/aquarium_pretrain'

# Names of the dataset splits (not read by the cells visible in this notebook).
splits = ['train', 'test','valid']

# NOTE(review): absolute, machine-specific directory. This name is reassigned to a
# checkpoint *file* path in the visualisation section below, so the value set here
# is not used by any visible cell -- confirm whether it is still needed.
model_path = '/home/ubuntu/cs230_VIVEKA/saved_models'

# Denoising Model Class

### Architecture 1: basic model.
Consider adding: dropout, batch-normalization layers, and skip connections.

In [79]:
import torch
import torch.nn as nn

class DenoisingCNN(nn.Module):
    """Baseline convolutional encoder-decoder for image denoising.

    Encoder: four 3x3 convolutions (3 -> 64 -> 128 -> 256 -> 512 channels),
    each followed by ReLU, with a 2x2 max-pool after the second and the fourth
    convolution (spatial size divided by 4 overall).
    Decoder: two stride-2 transposed convolutions (512 -> 256 -> 128) that
    restore the spatial size, then two 3x3 convolutions (128 -> 64 -> 3) and a
    sigmoid so that the output lies in [0, 1].
    """

    def __init__(self):
        super().__init__()

        def conv3x3(in_ch, out_ch):
            # Size-preserving 3x3 convolution.
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)

        def up2x(in_ch, out_ch):
            # Transposed convolution that doubles height and width.
            return nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)

        # Layers are listed (and therefore created) in application order, which
        # keeps the Sequential indices / state_dict keys unchanged.
        encoder_layers = [
            conv3x3(3, 64),
            nn.ReLU(inplace=True),
            conv3x3(64, 128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # first downsampling step
            conv3x3(128, 256),
            nn.ReLU(inplace=True),
            conv3x3(256, 512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # second downsampling step
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            up2x(512, 256),
            nn.ReLU(inplace=True),
            up2x(256, 128),
            nn.ReLU(inplace=True),
            conv3x3(128, 64),
            nn.ReLU(inplace=True),
            conv3x3(64, 3),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and decode it back to an image-shaped tensor."""
        return self.decoder(self.encoder(x))


### Architecture 1.5 (WIP)

In [34]:
class DenoisingCNN_v2(nn.Module):
    ''' Lighter encoder-decoder CNN (work in progress).

        Encoder: three stages of 32, 64 and 128 channels. Each stage is a
        3x3 conv + ReLU followed by a second 3x3 conv; the first two stages
        end with a 2x2 max-pool, so the encoder downsamples by 4 in total.
        Decoder: two stride-2 transposed convolutions (each followed by a
        3x3 conv) bring the feature map back to the input resolution, and a
        final 3x3 conv + Sigmoid produces 3 channels in [0, 1].

        The encoder previously pooled after every convolution (5 pools, /32)
        while the decoder only upsamples by 4, so the output was 1/8 of the
        input size and could not be compared with the clean target. Pooling
        now happens twice, mirroring the layout of DenoisingUNet.

        Dropout is currently disabled; the commented-out nn.Dropout lines mark
        where it was planned (0.2 in early stages, 0.4 at the end of the
        encoder, none before the output layer).

        Height and width of the input should be multiples of 4, otherwise the
        output is slightly smaller than the input because pooling floors.
    '''
    def __init__(self):
        super(DenoisingCNN_v2, self).__init__()
        # Encoder (downsamples by 4 overall)
        self.encoder = nn.Sequential(
            # Stage 1: 3 -> 32 channels, then H -> H/2
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            #nn.Dropout(0.2),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.MaxPool2d(2),

            # Stage 2: 32 -> 64 channels, then H/2 -> H/4
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            #nn.Dropout(0.2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.MaxPool2d(2),

            # Stage 3 (bottleneck): 64 -> 128 channels, no further pooling
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            #nn.Dropout(0.4)
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
        )

        # Decoder (upsamples by 4 overall)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),  # H/4 -> H/2
            nn.ReLU(inplace=True),
            #nn.Dropout(0.2),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),  # H/2 -> H
            nn.ReLU(inplace=True),
            #nn.Dropout(0.2),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        ''' Return the denoised image for a batch ``x`` of shape (N, 3, H, W). '''
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded


### Architecture 2
Find a pre-existing model that targets the specific issues in these images.

In [35]:
class DenoisingUNet(nn.Module):
    ''' U-Net style denoiser with two skip connections.

        Encoder: three stages (32, 64, 128 channels). Each stage is a 3x3
        conv + ReLU followed by a second 3x3 conv without activation. The
        first two stages are max-pooled (2x) and passed through dropout
        (p=0.2); the third stage is only followed by dropout (p=0.4).
        Decoder: two stride-2 transposed convolutions, each followed by
        dropout (p=0.2) and concatenated with the un-pooled encoder feature
        map of the same resolution.
        Output: 3x3 conv + Sigmoid, giving 3 channels in [0, 1].
    '''
    def __init__(self):
        super().__init__()

        def conv_relu(in_ch, out_ch):
            # 3x3 size-preserving convolution followed by ReLU.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        def same_width_conv(channels):
            # Extra 3x3 convolution that keeps the channel count.
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        def upconv_relu(in_ch, out_ch):
            # Transposed convolution doubling H and W, followed by ReLU.
            return nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2),
                nn.ReLU(inplace=True),
            )

        # ---- Encoder, stage 1 ----
        self.enc_conv1 = conv_relu(3, 32)
        self.enc_conv1a = same_width_conv(32)
        self.pool1 = nn.MaxPool2d(2)
        self.dropout1 = nn.Dropout(0.2)

        # ---- Encoder, stage 2 ----
        self.enc_conv2 = conv_relu(32, 64)
        self.enc_conv2a = same_width_conv(64)
        self.pool2 = nn.MaxPool2d(2)
        self.dropout2 = nn.Dropout(0.2)

        # ---- Encoder, stage 3 (no pooling) ----
        self.enc_conv3 = conv_relu(64, 128)
        self.enc_conv3a = same_width_conv(128)
        self.dropout3 = nn.Dropout(0.4)

        # ---- Decoder ----
        self.dec_conv1 = upconv_relu(128, 64)
        self.dropout4 = nn.Dropout(0.2)

        # Input is 64 upsampled channels + 64 skip channels from stage 2.
        self.dec_conv2 = upconv_relu(128, 32)
        self.dropout5 = nn.Dropout(0.2)

        # Input is 32 upsampled channels + 32 skip channels from stage 1.
        self.final_conv = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        ''' Run the encoder, then decode while concatenating the skip features. '''
        # Encoder: keep the un-pooled maps for the skip connections.
        skip_full = self.enc_conv1a(self.enc_conv1(x))
        stage2_in = self.dropout1(self.pool1(skip_full))

        skip_half = self.enc_conv2a(self.enc_conv2(stage2_in))
        stage3_in = self.dropout2(self.pool2(skip_half))

        bottom = self.dropout3(self.enc_conv3a(self.enc_conv3(stage3_in)))

        # Decoder: upsample, apply dropout, then join with the matching skip map.
        up_half = self.dropout4(self.dec_conv1(bottom))
        up_half = torch.cat((up_half, skip_half), dim=1)

        up_full = self.dropout5(self.dec_conv2(up_half))
        up_full = torch.cat((up_full, skip_full), dim=1)

        return self.final_conv(up_full)

### Architecture 3.1

In [36]:
class ImprovedDenoisingUNet(nn.Module):
    """Compact U-Net denoiser: one 3x3 conv per encoder stage, no dropout.

    Encoder: conv+ReLU stages with 32, 64 and 128 channels; the first two are
    followed by a 2x2 max-pool.
    Decoder: two stride-2 transposed convolutions; each result is
    concatenated with the un-pooled encoder map of the same resolution.
    Output: 3x3 convolution followed by a sigmoid (applied in ``forward``).
    """

    def __init__(self):
        super().__init__()

        def conv_relu(in_ch, out_ch):
            # 3x3 size-preserving convolution followed by ReLU.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        def upconv_relu(in_ch, out_ch):
            # Transposed convolution doubling H and W, followed by ReLU.
            return nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2),
                nn.ReLU(inplace=True),
            )

        # Encoder
        self.enc_conv1 = conv_relu(3, 32)
        self.pool1 = nn.MaxPool2d(2)

        self.enc_conv2 = conv_relu(32, 64)
        self.pool2 = nn.MaxPool2d(2)

        self.enc_conv3 = conv_relu(64, 128)

        # Decoder
        self.dec_conv1 = upconv_relu(128, 64)

        # 64 upsampled channels + 64 channels from the stage-2 skip connection.
        self.dec_conv2 = upconv_relu(64 + 64, 32)

        # 32 upsampled channels + 32 channels from the stage-1 skip connection.
        self.final_conv = nn.Conv2d(32 + 32, 3, kernel_size=3, padding=1)

    def forward(self, x):
        """Return the denoised image, with values squashed to [0, 1]."""
        # Encoder
        skip_full = self.enc_conv1(x)
        skip_half = self.enc_conv2(self.pool1(skip_full))
        bottom = self.enc_conv3(self.pool2(skip_half))

        # Decoder with skip connections
        up_half = torch.cat((self.dec_conv1(bottom), skip_half), dim=1)
        up_full = torch.cat((self.dec_conv2(up_half), skip_full), dim=1)

        logits = self.final_conv(up_full)
        return torch.sigmoid(logits)


### Architecture 3.2

In [37]:
class ResidualUNet(nn.Module):
    """Encoder-decoder denoiser whose skip connections are additive.

    Encoder: conv+ReLU stages with 32 and 64 channels, each followed by the
    shared 2x2 max-pool. Bottleneck: conv+ReLU with 128 channels.
    Decoder: two stride-2 transposed convolutions whose outputs are *added*
    to the un-pooled encoder maps of the same shape (instead of being
    concatenated). A 3x3 convolution and a sigmoid produce the final image.
    """

    def __init__(self):
        super().__init__()

        def conv_relu(in_ch, out_ch):
            # 3x3 size-preserving convolution followed by ReLU.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        def upconv_relu(in_ch, out_ch):
            # Transposed convolution doubling H and W, followed by ReLU.
            return nn.Sequential(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2),
                nn.ReLU(inplace=True),
            )

        # Encoder (a single pooling module is reused for both stages)
        self.enc_conv1 = conv_relu(3, 32)
        self.enc_conv2 = conv_relu(32, 64)
        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = conv_relu(64, 128)

        # Decoder: channel counts match the encoder maps they are added to
        self.dec_conv1 = upconv_relu(128, 64)
        self.dec_conv2 = upconv_relu(64, 32)
        self.final_conv = nn.Conv2d(32, 3, kernel_size=3, padding=1)

    def forward(self, x):
        """Return the denoised image, with values squashed to [0, 1]."""
        # Encoder
        feat_full = self.enc_conv1(x)
        feat_half = self.enc_conv2(self.pool(feat_full))

        # Bottleneck
        deepest = self.bottleneck(self.pool(feat_half))

        # Decoder with additive (residual) skip connections
        merged_half = self.dec_conv1(deepest) + feat_half
        merged_full = self.dec_conv2(merged_half) + feat_full

        return torch.sigmoid(self.final_conv(merged_full))

# Denoising Trainer

In [38]:
# from skimage.metrics import peak_signal_noise_ratio as psnr
# from skimage.metrics import structural_similarity as ssim
# import numpy as np

# class DenoisingTrainer:
#     def __init__(self, model, device, criterion, optimizer, save_path='best_model.pth'):
#         self.model = model
#         self.device = device
#         self.criterion = criterion
#         self.optimizer = optimizer
#         self.save_path = save_path

#         self.train_losses = []
#         self.val_losses = []
#         self.best_val_loss = float('inf')

#     def train(self, train_loader, valid_loader, num_epochs=10):
#         for epoch in range(num_epochs):
#             self.model.train()
#             running_train_loss = 0.0
#             total_train_samples = 0

#             for noisy_images, clean_images in train_loader:
#                 noisy_images = noisy_images.to(self.device)
#                 clean_images = clean_images.to(self.device)

#                 self.optimizer.zero_grad()
#                 outputs = self.model(noisy_images)
#                 loss = self.criterion(outputs, clean_images)
#                 loss.backward()
#                 self.optimizer.step()

#                 running_train_loss += loss.item() * noisy_images.size(0)
#                 total_train_samples += noisy_images.size(0)

#             epoch_train_loss = running_train_loss / total_train_samples
#             self.train_losses.append(epoch_train_loss)
#             print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_train_loss:.4f}')

#             self.model.eval()
#             running_val_loss = 0.0
#             total_val_samples = 0

#             with torch.no_grad():
#                 for noisy_images, clean_images in valid_loader:
#                     noisy_images = noisy_images.to(self.device)
#                     clean_images = clean_images.to(self.device)

#                     outputs = self.model(noisy_images)
#                     loss = self.criterion(outputs, clean_images)

#                     running_val_loss += loss.item() * noisy_images.size(0)
#                     total_val_samples += noisy_images.size(0)

#             epoch_val_loss = running_val_loss / total_val_samples
#             self.val_losses.append(epoch_val_loss)
#             print(f'Epoch {epoch+1}/{num_epochs}, Validation Loss: {epoch_val_loss:.4f}')

#             if epoch_val_loss < self.best_val_loss:
#                 self.best_val_loss = epoch_val_loss
#                 torch.save(self.model.state_dict(), self.save_path)
#                 print(f"Model saved with validation loss: {epoch_val_loss:.4f}")

#         #Plotting train and validation loss per epoch at the end of train() function
#         plt.figure(figsize=(10, 6))
#         plt.plot(range(1, len(self.train_losses) + 1), self.train_losses, label='Training Loss')
#         plt.plot(range(1, len(self.val_losses) + 1), self.val_losses, label='Validation Loss')
#         plt.xlabel('Epochs')
#         plt.ylabel('Loss')
#         plt.title('Training and Validation Loss per Epoch')
#         plt.legend()
#         plt.grid(True)
#         plt.show()

#     def evaluate(self, data_loader):
#         ''' Evaluates MSE loss, pSNR, and SSIM of the model output to the ground truth labels
#             using skimage.metrics for the latter two metrics.
#         '''
#         self.model.eval()
#         running_loss = 0.0
#         total_samples = 0
#         total_psnr = 0.0
#         total_ssim = 0.0

#         with torch.no_grad():
#             for noisy_images, clean_images in data_loader:
#                 noisy_images = noisy_images.to(self.device)
#                 clean_images = clean_images.to(self.device)

#                 outputs = self.model(noisy_images)
#                 loss = self.criterion(outputs, clean_images)

#                 running_loss += loss.item() * noisy_images.size(0)
#                 total_samples += noisy_images.size(0)

#                 # METRICS pSNR, SSIM ADDED HERE
#                 outputs_np = outputs.cpu().numpy().transpose(0, 2, 3, 1)
#                 clean_images_np = clean_images.cpu().numpy().transpose(0, 2, 3, 1)
                
#                 # debugging ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
#                 # print("Output shape:", outputs_np.shape)
#                 # print("Clean image shape:", clean_images_np.shape)
#                 # end debugging ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||

#                 for o, c in zip(outputs_np, clean_images_np):
#                     total_psnr += psnr(c, o, data_range=1.0)
#                     total_ssim += ssim(c, o, data_range=1.0, win_size=3, channel_axis=-1)

#         avg_loss = running_loss / total_samples
#         avg_psnr = total_psnr / total_samples
#         avg_ssim = total_ssim / total_samples

#         print(f"Avg Loss: {avg_loss:.4f}, Avg PSNR: {avg_psnr:.4f}, Avg SSIM: {avg_ssim:.4f}")
#         return avg_loss, avg_psnr, avg_ssim


# Main model training pipeline

##### create datasets, dataloaders, model, trainer

In [56]:
# Names of the tunable hyperparameters. NOTE(review): this list is not read by any
# cell visible in this notebook, and 'dropout_rate' is hard-coded inside the models.
hyperparams = ['batch_size', 'lr', 'num_epochs', 'dropout_rate']

In [57]:
# Transforms
# Per-channel mean / std; both names refer to the same list object.
# They are also read by visualize_output() when un-normalising images.
normalization_mean = normalization_std = [0.5, 0.5, 0.5]

# Resize -> tensor -> normalise with the statistics above.
transform_normalize = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=normalization_mean, std=normalization_std),
])

# Tensor conversion only, no resizing.
transform_regular = T.ToTensor()

# Resize -> tensor, no normalisation.
# IMPORTANT: size is taken from datasets.py transform in order to match
transform_resize = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])

create dataset

In [58]:
# One paired (noisy, clean) dataset per split, all using the resize-only transform.
train_data, valid_data, test_data = (
    DenoisingPairedDataset(root_dir=root_dir, split=split_name, transform=transform_resize)
    for split_name in ('train', 'valid', 'test')
)

create DataLoaders

In [78]:
batch_size = 32

# Only the training split is shuffled; validation and test keep dataset order.
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
valid_loader, test_loader = (
    DataLoader(dataset=held_out, batch_size=batch_size, shuffle=False)
    for held_out in (valid_data, test_data)
)

create Model, device

In [70]:
#define models
model_simple = DenoisingCNN()
# model_unet = DenoisingUNet()

# model_simple_v2 = DenoisingCNN_v2()

In [71]:
class PerceptualLoss(nn.Module):
    """Perceptual (feature-space) loss built on a VGG-style feature extractor.

    Both images are passed through the first layers of ``vgg``; the loss is
    the sum of the MSE between their activations at the selected layers.

    Args:
        vgg: indexable sequence of layers (e.g. ``torchvision`` VGG
            ``.features``). Its parameters are used as-is; freezing them and
            putting the network in eval mode is left to the caller.
        layers: indices (ints or numeric strings) of the layers whose outputs
            are compared. Defaults to ``['0', '5', '10', '19']``.

    Notes:
        * Every layer from index 0 up to the largest selected index is
          executed, so ReLUs and pooling layers between the selected
          convolutions are applied. Previously only the selected layers were
          chained directly, and only one of them contributed to the loss.
        * NOTE(review): if ``vgg`` uses in-place ReLUs, a tapped convolution
          output is rectified in place by the next layer, so the tap
          effectively holds the post-ReLU activation -- confirm this is
          acceptable.
        * NOTE(review): pretrained VGG weights normally expect
          ImageNet-normalised inputs; ``x`` and ``y`` are fed unchanged here.
    """

    def __init__(self, vgg, layers=None):
        super(PerceptualLoss, self).__init__()

        # Use specific layers from the VGG model
        if layers is None:
            layers = ['0', '5', '10', '19']  # You can choose different layers
        self.layers = layers

        # Indices (as ints) at which activations are collected.
        self._tap_indices = frozenset(int(i) for i in layers)

        # Keep the whole prefix of the network up to the deepest tap so that
        # intermediate layers (ReLU, pooling, other convs) are not skipped.
        deepest = max(self._tap_indices, default=-1)
        self.vgg_layers = nn.ModuleList([vgg[idx] for idx in range(deepest + 1)])

    def forward(self, x, y):
        """Return the summed MSE between the tapped features of ``x`` and ``y``."""
        x_features = self.extract_features(x)
        y_features = self.extract_features(y)

        # One MSE term per tapped layer.
        loss = 0.0
        for x_feat, y_feat in zip(x_features, y_features):
            loss += F.mse_loss(x_feat, y_feat)

        return loss

    def extract_features(self, x):
        """Run ``x`` through the stored layers and return the tapped activations.

        Returns:
            list of tensors, one per selected layer, ordered by layer index.
        """
        features = []
        for idx, layer in enumerate(self.vgg_layers):
            x = layer(x)
            # `idx` is the layer's position in the original network because
            # the stored prefix starts at layer 0.
            if idx in self._tap_indices:
                features.append(x)
        return features

In [72]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# `model` is the network that the optimizer and trainer cells below operate on.
model = model_simple

# Moves the parameters to `device`; as the last expression it also displays the module.
model.to(device)

Using device: cuda


DenoisingCNN(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
    (1): ReLU(inplace=True)
    (2): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
    (3): ReLU(inplace=True)
    (4): ConvTranspose2d(128, 64, kernel_size=(

In [73]:
import torchvision.models as models

# Load VGG-16 with its ImageNet weights and keep only the convolutional
# feature extractor (`.features`). The `weights=` argument replaces the
# deprecated `pretrained=True` flag and selects the same checkpoint.
vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features

# Set to evaluation mode (important)
vgg.eval()

# Move the model to the same device as your neural network
vgg = vgg.to(device)

# Freeze VGG parameters so they aren't updated during training
for param in vgg.parameters():
    param.requires_grad = False


In [74]:
# Number of training epochs (assigned again in the training cell).
num_epochs = 10

# Assuming 'model' is your denoising model, and 'train_loader' is your DataLoader
# Adam over the denoiser's parameters only; the VGG used by the loss is not included.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Perceptual Loss function
# NOTE(review): both `optimizer` and `perceptual_criterion` are created again in the
# next cell, which makes this cell redundant when the notebook is run top to bottom.
perceptual_criterion = PerceptualLoss(vgg)


In [75]:
# NOTE(review): this cell repeats the device / model / criterion / optimizer setup
# of the cells above; the objects created here are the ones the trainer receives.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = model_simple

model.to(device)

# criterion = nn.MSELoss()
# The perceptual loss is used in place of the plain MSE criterion above.
perceptual_criterion = PerceptualLoss(vgg)

optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)

Using device: cuda


In [76]:
# NOTE(review): `model` is the plain DenoisingCNN, yet the checkpoint file is named
# "..._unet.pth", and the visualisation section loads "saved_models/best_denoising_model.pth".
# Confirm which file name is intended so the model trained here is the one visualised.
trainer = DenoisingTrainer(model, device, perceptual_criterion, optimizer, save_path="saved_models/best_denoising_model_unet.pth")

##### Train model

In [80]:
# Number of passes over the training set.
num_epochs = 10

# Runs training with validation on `valid_loader`; the trainer was given a
# `save_path`, presumably used for checkpointing -- see trainer.py.
trainer.train(train_loader, valid_loader, num_epochs=num_epochs)

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 21.95 GiB of which 17.88 MiB is free. Process 5819 has 8.65 GiB memory in use. Process 6592 has 186.00 MiB memory in use. Including non-PyTorch memory, this process has 13.08 GiB memory in use. Of the allocated memory 12.86 GiB is allocated by PyTorch, and 9.05 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

##### Evaluate model

In [None]:
# Evaluate on the validation split; evaluate() returns three values, unpacked here
# as loss, PSNR and SSIM (presumably in that order -- confirm against trainer.py).
val_loss, val_psnr, val_ssim = trainer.evaluate(valid_loader)
#print(f"Validation - Loss: {val_loss:.4f}, PSNR: {val_psnr:.4f}, SSIM: {val_ssim:.4f}")

# Same metrics on the held-out test split.
test_loss, test_psnr, test_ssim = trainer.evaluate(test_loader)
#print(f"Test - Loss: {test_loss:.4f}, PSNR: {test_psnr:.4f}, SSIM: {test_ssim:.4f}")


# visualize dataset, model outputs

##### visualize X,Y pairs

In [16]:
def unnormalize(img_tensor, mean, std):
    """
    Invert a per-channel normalisation: ``channel * std + mean``.

    Works on a detached CPU copy, so ``img_tensor`` itself is left untouched.
    The tensor is iterated along its first dimension, and one (mean, std)
    pair is applied to each slice.
    """
    restored = img_tensor.detach().clone().cpu()
    for channel, offset, scale in zip(restored, mean, std):
        # In-place on the copy: undo the division by std, then the mean shift.
        channel.mul_(scale)
        channel.add_(offset)
    return restored

##### Visualize Y_Pred, Y pairs

In [27]:
def visualize_output(model, data, device, idx, normalize=True):
    ''' Plot the noisy input, the model output and the ground truth of one dataset item.

        data[idx] must yield a (noisy, clean) pair of image tensors; the noisy
        image is given a batch dimension and moved to `device` for the forward pass.
        If `normalize` is True, the noisy and clean images are passed through
        unnormalize() with the module-level normalization_mean / normalization_std
        before plotting. The model output is never un-normalised; it is only
        clipped to [0, 1].
        Raises IndexError when idx >= len(data).
    '''
    if len(data) <= idx:
        raise IndexError("idx out of bounds of dataset length")

    noisy_sample, clean_sample = data[idx]
    noisy_batch = noisy_sample.unsqueeze(0).to(device)

    # Inference only: no gradients are needed for plotting.
    with torch.no_grad():
        prediction = model(noisy_batch).cpu().squeeze(0)

    noisy_plot = noisy_batch.cpu().squeeze(0)
    clean_plot = clean_sample
    if normalize:
        noisy_plot = unnormalize(noisy_plot, normalization_mean, normalization_std)
        clean_plot = unnormalize(clean_plot, normalization_mean, normalization_std)

    # imshow expects channels last, hence the (C, H, W) -> (H, W, C) permutes.
    panels = (
        ("Noisy Input", noisy_plot.permute(1, 2, 0)),
        ("Model Output", prediction.permute(1, 2, 0).clip(0, 1)),
        ("Ground Truth", clean_plot.permute(1, 2, 0)),
    )

    plt.figure(figsize=(12, 4))
    for slot, (caption, image) in enumerate(panels, start=1):
        plt.subplot(1, 3, slot)
        plt.imshow(image)
        plt.title(caption)
    plt.show()

In [None]:
# Checkpoint to visualise. NOTE(review): this rebinds `model_path`, which was set
# to a directory near the top of the notebook, and the file name differs from the
# `save_path` given to the trainer -- confirm this is the intended checkpoint.
model_path = "saved_models/best_denoising_model.pth"
model_vis = DenoisingCNN()
# map_location lets a checkpoint saved from a GPU run load on whatever `device`
# is available now (e.g. a CPU-only machine).
model_vis.load_state_dict(torch.load(model_path, map_location=device))
model_vis.eval()
model_vis.to(device)

In [None]:
# normalize=False because train_data was built with transform_resize, which applies no normalisation.
visualize_output(model=model_vis, data=train_data, device=device, idx=26, normalize=False)

In [None]:
# Display the architecture of the baseline CNN.
model_simple

In [None]:
# NOTE(review): `model_unet` is only assigned in a commented-out line of the
# "define models" cell, so this raises NameError on a fresh kernel.
model_unet