In [22]:
import os 
import gc
import timm
import copy
import random as rd 
import numpy as np
import pandas as pd
from PIL import Image
import seaborn as sns
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import Subset
from torch.optim import Adam, SGD 
from torch.utils.data import Subset, Dataset, DataLoader, ConcatDataset

from albumentations.pytorch import ToTensorV2
import albumentations as A

from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score, accuracy_score, precision_score, recall_score 
from sklearn.model_selection import train_test_split
from collections import defaultdict, Counter

In [23]:
# Select the compute device.
# FORCE_CPU=True keeps all work on the CPU; set it to False to use CUDA or
# Apple MPS when available (falling back to CPU otherwise).
# `device` is always a plain string ("cuda" / "mps" / "cpu") so comparisons such
# as `device == "cpu"` and `.to(device)` behave the same on every branch.
FORCE_CPU = True

if FORCE_CPU:
    device = "cpu"
elif torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Using device: ", device)

Using device:  cpu


In [24]:
# Reproducibility: seed every random-number source this notebook relies on
# (PyTorch CPU, NumPy's global RNG, Python's `random`, and all CUDA devices).
seed = 42
for _seeder in (torch.manual_seed, np.random.seed, rd.seed, torch.cuda.manual_seed_all):
    _seeder(seed)
del _seeder

# Ask cuDNN for deterministic kernels and disable its autotuner, which may
# choose different algorithms between runs.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [25]:
# Image sizes as (height, width). UPSCALE_FACTOR must match the Expander's
# upsampling factor (4), so the target size is derived rather than hard-coded.
LR_SIZE = (40, 45)
UPSCALE_FACTOR = 4
HR_SIZE = (LR_SIZE[0] * UPSCALE_FACTOR, LR_SIZE[1] * UPSCALE_FACTOR)  # (160, 180)

# NOTE(review): ToTensorV2 only converts HWC arrays to CHW tensors; it does not
# rescale pixel values. If images arrive here as uint8 in [0, 255], the targets
# will not match the model's sigmoid output range [0, 1]. Confirm that the
# Dataset scales images, or add A.ToFloat(max_value=255) before ToTensorV2.
transform_input = A.Compose([
    A.Resize(*LR_SIZE),
    ToTensorV2()
])

transform_target = A.Compose([
    A.Resize(*HR_SIZE),
    ToTensorV2()
])

In [26]:
class ResidualBlock(nn.Module):
    """Two conv/batch-norm stages wrapped in an identity skip connection.

    Both convolutions are 3x3 with ``padding=1`` and keep ``channels`` fixed,
    so the output has the same shape as the input and the skip connection can
    be added directly.

    Args:
        channels: number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        # Sub-module creation order determines seeded parameter initialisation
        # and state_dict key order, so it is kept stable.
        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(num_features=channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=channels, out_channels=channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(num_features=channels)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``relu(x + bn2(conv2(relu(bn1(conv1(x))))))``."""
        branch = self.relu1(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return self.relu2(branch + x)


In [27]:
class Expander(nn.Module):
    """Convolutional network that upsamples an image by ``scale_factor``.

    Pipeline: 5x5 conv -> 3 residual blocks -> bilinear upsample -> 3x3 conv
    -> 2 residual blocks -> 3x3 conv -> sigmoid. All convolutions preserve
    spatial size, so only the upsample changes it. With the defaults, a
    ``(N, 3, 40, 45)`` input yields a ``(N, 3, 160, 180)`` output.

    Args:
        in_channels: channels of the input image (default 3).
        out_channels: channels of the output image (default 3).
        features: width of the internal feature maps (default 64).
        scale_factor: spatial upsampling factor passed to ``nn.Upsample``
            (default 4).

    The output passes through a sigmoid, so it lies in (0, 1). Training
    targets should be scaled to the same range.

    Sub-module attribute names and creation order determine state_dict keys
    and seeded initialisation, so they are kept stable.
    """

    def __init__(self, in_channels=3, out_channels=3, features=64, scale_factor=4):
        super().__init__()

        # Initial feature extraction block (5x5, padding=2 keeps H and W).
        self.conv1 = nn.Conv2d(in_channels, features, kernel_size=5, padding=2)
        self.relu1 = nn.ReLU(inplace=True)

        # Residual blocks at the input resolution.
        self.resblock1 = ResidualBlock(features)
        self.resblock2 = ResidualBlock(features)
        self.resblock3 = ResidualBlock(features)

        # Bilinear upsampling of the feature maps to the target resolution.
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)

        # Refinement at the output resolution.
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU(inplace=True)

        self.resblock4 = ResidualBlock(features)
        self.resblock5 = ResidualBlock(features)

        # Project back to image channels; the sigmoid bounds pixel values to (0, 1).
        self.conv3 = nn.Conv2d(features, out_channels, kernel_size=3, padding=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a ``(N, in_channels, H, W)`` tensor to ``(N, out_channels, H*s, W*s)``."""
        # Initial convolution
        x = self.relu1(self.conv1(x))

        # Residual blocks at input resolution
        x = self.resblock1(x)
        x = self.resblock2(x)
        x = self.resblock3(x)

        # Upsample the feature maps
        x = self.upsample(x)

        # Continue feature extraction and residual blocks after upsampling
        x = self.relu2(self.conv2(x))
        x = self.resblock4(x)
        x = self.resblock5(x)

        # Output layer constrained to (0, 1)
        x = self.conv3(x)
        x = self.sigmoid(x)

        return x

In [28]:
from torchviz import make_dot

# Trace one forward pass of a dummy low-resolution image to draw the autograd graph.
# Shape is (batch_size, channels, height, width), matching the 40x45 model input.
dummy_input = torch.randn(1, 3, 40, 45)

# Initialize the model
model = Expander()

# Trace in eval mode. In train mode, BatchNorm would update its running
# statistics from this random input and silently alter `model` before any
# later use. Gradients are still tracked, so the autograd graph is built.
model.eval()
output = model(dummy_input)
model.train()  # restore the default training mode

# Visualize the model and write model_structure.png
graph = make_dot(output, params=dict(model.named_parameters()))
graph.render("model_structure", format="png")

'model_structure.png'