In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as compare_psnr

class Zero(nn.Module):
    def __init__(self, stride):
        super(Zero, self).__init__()
        self.stride = stride

    def forward(self, x):
        if self.stride == 1:
            return x.mul(0.)
        return x[:, :, ::self.stride, ::self.stride].mul(0.)


class SepConv(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super(SepConv, self).__init__()
        self.op = nn.Sequential(
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in, affine=affine),
            nn.ReLU(inplace=True),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.op(x)

class FactorizedReduce(nn.Module):
    def __init__(self, C_in, C_out, affine=True):
        super(FactorizedReduce, self).__init__()
        assert C_out % 2 == 0
        self.relu = nn.ReLU(inplace=True)
        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        x = self.relu(x)
        out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out
    
class ChannelAttention(nn.Module):
    def __init__(self, channel, reduction=16, symbol=None, *args, **kwargs):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

PRIMITIVES = [
    'none',
    'avg_pool_3x3',
    'max_pool_3x3',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'skip_connect',
    'channel_attention'
]

OPS = {
    'none': lambda C, stride, affine: Zero(stride),
    'avg_pool_3x3': lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
    'max_pool_3x3': lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1),
    'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
    'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
    'skip_connect': lambda C, stride, affine: nn.Identity() if stride == 1 else FactorizedReduce(C, C, affine),
    'channel_attention': lambda C, stride, affine: ChannelAttention(C)
}

class MixedOp(nn.Module):
    def __init__(self, C, stride):
        super(MixedOp, self).__init__()
        self._ops = nn.ModuleList()
        for primitive in PRIMITIVES:
            op = OPS[primitive](C, stride, False)
            if 'pool' in primitive:
                op = nn.Sequential(op, nn.BatchNorm2d(C, affine=False))
            self._ops.append(op)

    def forward(self, x, weights):
        return sum(op(w * x) for w, op in zip(weights, self._ops))



class Cell(nn.Module):
    multiplier = 1
    def __init__(self, C_prev_prev, C_prev, C, reduction, reduction_prev):
        super(Cell, self).__init__()
        self.reduction = reduction

        # Define operations
        # For simplicity, let's just assume we have 2 nodes for operations.
        self._ops = nn.ModuleList()
        for _ in range(2):
            op = MixedOp(C, stride=2 if reduction else 1)
            self._ops.append(op)

        self._initialize_alphas()

    def _initialize_alphas(self):
        k = sum(1 for i in range(2))
        num_ops = len(PRIMITIVES)
        self.alphas_normal = nn.Parameter(1e-3 * torch.randn(k, num_ops))
        self.alphas_reduce = nn.Parameter(1e-3 * torch.randn(k, num_ops))

    def forward(self, s0, s1, weights):
        s2 = sum(self._ops[i](s1, weights) for i in range(2))
        return s2
    
class DenoisingDARTS(nn.Module):
    def __init__(self, C=16, num_cells=3):
        super(DenoisingDARTS, self).__init__()
        self.C = C
        self.num_cells = num_cells

        # Initial convolution
        self.stem = nn.Sequential(
            nn.Conv2d(1, C, 3, padding=1),
            nn.BatchNorm2d(C)
        )

        C_prev_prev, C_prev, C_curr = C, C, C

        self.alphas_normal = nn.Parameter(torch.randn(1, len(OPS)))
        self.alphas_reduce = nn.Parameter(torch.randn(1, len(OPS)))

        self.cells = nn.ModuleList()
        reduction_prev = False
        for i in range(num_cells):
            cell = Cell(C_prev_prev, C_prev, C_curr, reduction=False, reduction_prev=reduction_prev)
            self.cells += [cell]
            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            reduction_prev = False

        self.output_conv = nn.Conv2d(C_prev, 3, 1)

    def forward(self, x):
        s0 = s1 = self.stem(x)
        for i, cell in enumerate(self.cells):
            if cell.reduction:
                weights = F.softmax(self.alphas_reduce, dim=-1)
            else:
                weights = F.softmax(self.alphas_normal, dim=-1)
            s0, s1 = s1, cell(s0, s1, weights)
        out = self.output_conv(s1)

        return torch.sigmoid(out) 


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.MSELoss()

# Create the model
model = DenoisingDARTS()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Assuming noisy_image is your corrupted image you want to restore
noise_type = 'gaussian'
resolution = 64
noise_level = '0.09'
epochs = 1000
noisy_image = torch.tensor(np.load(f'/home/joe/nas-for-dip/phantoms/{noise_type}/res_{resolution}/nl_{noise_level}/p_{45}.npy')).float().to(device)
input_noise = torch.randn(noisy_image.size()).to(noisy_image.device)  # Random noise matching the image size
noisy_image = noisy_image.unsqueeze(1)  # Add channel dimension
input_noise = input_noise.unsqueeze(1)  # Add channel dimension

# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
noisy_image = noisy_image.to(device)
input_noise = input_noise.to(device)


# Training loop
for epoch in range(epochs):
    denoised_image = model(input_noise)  # Pass noise through the model
    loss = criterion(denoised_image, noisy_image)
    # compute PSNR
    psnr = compare_psnr(noisy_image.cpu().detach().numpy(), denoised_image.cpu().detach().numpy())
    print(f'Epoch {epoch} | Loss: {loss.item()} | PSNR: {psnr}')
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


RuntimeError: The size of tensor a (7) must match the size of tensor b (64) at non-singleton dimension 3

: 

In [None]:
from collections import OrderedDict

import torch
import torch.nn as nn


class UNet(nn.Module):

    def __init__(self, in_channels=3, out_channels=1, init_features=32, depth=6):
        super(UNet, self).__init__()

        self.depth = depth

        features = init_features

        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)
        return torch.sigmoid(self.conv(dec1))

    @staticmethod
    def _block(in_channels, features, name):
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )

In [7]:
from collections import OrderedDict

import torch
import torch.nn as nn


class UNet(nn.Module):

    def __init__(self, in_channels=1, out_channels=1, init_features=32, depth=4):
        super(UNet, self).__init__()

        self.depth = depth

        features = init_features
        self.pools = nn.ModuleList()
        self.encoders = nn.ModuleList()
        self.encoders.append(UNet._block(in_channels, features, name="enc1"))
        self.pools.append(nn.MaxPool2d(kernel_size=2, stride=2))

        for i in range(depth-1):
            self.encoders.append(UNet._block(features, features * 2, name=f"enc{i+2}"))
            self.pools.append(nn.MaxPool2d(kernel_size=2, stride=2))
            features *= 2

        self.bottleneck = UNet._block(features, features * 2, name="bottleneck")

        self.upconvs = nn.ModuleList()
        self.decoders = nn.ModuleList()
        for i in range(depth):
            self.upconvs.append(nn.ConvTranspose2d(
                features * 2, features, kernel_size=2, stride=2
            ))
            self.decoders.append(UNet._block(features * 2, features, name=f"dec{i+1}"))
            features //= 2
        

        self.conv = nn.Conv2d(
            in_channels=features*2, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        print(f'in: {x.shape}')
        skips = []
        for i in range(self.depth):
            x = self.encoders[i](x)
            print(f'enc{i+1}: {x.shape}')
            skips.append(x)
            x = self.pools[i](x)
            print(f'pool{i+1}: {x.shape}')
            
        x = self.bottleneck(x)
        print(f'bottleneck: {x.shape}\n\n')

        for i in range(self.depth):
            x = self.upconvs[i](x)
            print(f'upconv{i+1}: {x.shape}')
            x = torch.cat((x, skips[-i-1]), dim=1)
            print(f'cat{i+1}: {x.shape}')
            x = self.decoders[i](x)
            print(f'dec{i+1}: {x.shape}')
        return torch.sigmoid(self.conv(x))

    @staticmethod
    def _block(in_channels, features, name):
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )
    
    def test(self):
        x = torch.randn(1, 1, 64, 64)
        out = self.forward(x)
        print(out.shape)
        assert out.shape == (1, 1, 64, 64)
        print('Test passed')

in: torch.Size([1, 1, 64, 64])
enc1: torch.Size([1, 32, 64, 64])
pool1: torch.Size([1, 32, 32, 32])
enc2: torch.Size([1, 64, 32, 32])
pool2: torch.Size([1, 64, 16, 16])
enc3: torch.Size([1, 128, 16, 16])
pool3: torch.Size([1, 128, 8, 8])
enc4: torch.Size([1, 256, 8, 8])
pool4: torch.Size([1, 256, 4, 4])
bottleneck: torch.Size([1, 512, 4, 4])


upconv1: torch.Size([1, 256, 8, 8])
cat1: torch.Size([1, 512, 8, 8])
dec1: torch.Size([1, 256, 8, 8])
upconv2: torch.Size([1, 128, 16, 16])
cat2: torch.Size([1, 256, 16, 16])
dec2: torch.Size([1, 128, 16, 16])
upconv3: torch.Size([1, 64, 32, 32])
cat3: torch.Size([1, 128, 32, 32])
dec3: torch.Size([1, 64, 32, 32])
upconv4: torch.Size([1, 32, 64, 64])
cat4: torch.Size([1, 64, 64, 64])
dec4: torch.Size([1, 32, 64, 64])
torch.Size([1, 1, 64, 64])
Test passed
