In [91]:
import torch
import torch.nn as nn
import torch.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import scipy.io as sio
from scipy.io import loadmat
import os
import numpy as np

In [92]:
class SqueezeExcitation(nn.Module):
    def __init__(self, filters, ratio=16):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(filters, filters // ratio),
            nn.ReLU(),
            nn.Linear(filters // ratio, filters),
            nn.Sigmoid()
        )

    def forward(self, x):
        se = self.se(x)
        return x * se.unsqueeze(-1).unsqueeze(-1)

class ResidualBlock(nn.Module):
    def __init__(self, filters, kernel_size=3, stride=1):
        super().__init__()
        self.shortcut = nn.Sequential()
        if stride != 1 or filters != filters:
            self.shortcut = nn.Sequential(
                nn.Conv2d(filters, filters, kernel_size=1, stride=stride),
                nn.BatchNorm2d(filters)
            )

        self.conv = nn.Sequential(
            nn.Conv2d(filters, filters, kernel_size, stride=stride, padding=1),
            nn.BatchNorm2d(filters),
            nn.ReLU(),
            nn.Conv2d(filters, filters, kernel_size, padding=1),
            nn.BatchNorm2d(filters),
            SqueezeExcitation(filters)
        )

    def forward(self, x):
        return F.relu(self.conv(x) + self.shortcut(x))

class ImageRestorationModel(nn.Module):
    def __init__(self):
        super().__init__(dim_model)
        self.conv1 = nn.Sequential(
            nn.Conv2d(dim_model, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        self.residual_blocks = nn.Sequential(
            ResidualBlock(32),
            ResidualBlock(32),
            nn.MaxPool2d(2),
            ResidualBlock(64),
            ResidualBlock(64),
            nn.MaxPool2d(2),
            ResidualBlock(128),
            ResidualBlock(128),
            nn.MaxPool2d(2)
        )

        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 2 * 2, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Dropout(0.5),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.5),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.residual_blocks(x)
        x = self.fc(x)
        return x

In [102]:
class DeformableAttention(nn.Module):
    
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.query_linear = nn.Linear(embed_dim, embed_dim)
        self.key_linear = nn.Linear(embed_dim, embed_dim)
        self.value_linear = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, query, key, value):
        batch_size = query.size(0)
        query = self.query_linear(query).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = self.key_linear(key).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = self.value_linear(value).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        attention_scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        output = torch.matmul(attention_weights, value).transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        return output

class ImageRestorationHead(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.activation = nn.ReLU()

    def forward(self, x):
        return self.activation(self.conv(x))

class DeformableTransformerEncoder(nn.Module):
    def __init__(self, d_model, nhead, num_layers, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.deform_attn = DeformableAttention(d_model, nhead)
        self.ffn = nn.Linear(d_model, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()

    def forward(self, src):
        src = self.norm1(src + self.dropout(self.self_attn(src, src)))
        src = self.norm2(src + self.dropout(self.deform_attn(src, src)))
        src = self.relu(src + self.dropout(self.ffn(src)))
        return src

class DeformableTransformerDecoder(nn.Module):
    def __init__(self, d_model, nhead, num_layers):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead)
        self.deform_attn = DeformableAttention(d_model, nhead)
        self.ffn = nn.Linear(d_model, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()

    def forward(self, tgt, memory):
        tgt = self.norm1(tgt + self.dropout(self.self_attn(tgt, tgt)))
        tgt = self.norm2(tgt + self.dropout(self.deform_attn(tgt, memory)))
        tgt = self.relu(tgt + self.dropout(self.ffn(tgt)))
        tgt = self.norm3(tgt)
        return tgt

In [103]:
class DAT(nn.Module):
    def __init__(self, d_model, nhead, num_encoder_layers, num_decoder_layers):
        super().__init__()
        self.encoder = DeformableTransformerEncoder(d_model, nhead, num_encoder_layers)
        self.decoder = DeformableTransformerDecoder(d_model, nhead, num_decoder_layers)
        self.image_restoration_head = ImageRestorationHead(3, d_model)
        self.norm = nn.LayerNorm(d_model)  
        self.dropout = nn.Dropout(0.1)

    def forward(self, src):
     
        memory = self.encoder(src)
        memory = self.norm(memory)
        memory = self.dropout(memory)

        tgt = torch.zeros_like(src)
        output = self.decoder(tgt, memory, src)
        output = self.norm(output)
        output = self.dropout(output)

        
        restored_image = self.image_restoration_head(output)

        return restored_image

model = DAT(d_model=256, nhead=8, num_encoder_layers=6, num_decoder_layers=6)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001) 
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
criterion = nn.MSELoss()

In [104]:
'''class SSIM(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        return ssim(input, target)

class PSNR(nn.Module):
    #def __init__(self):
        super().__init__()

    def forward(self, input, target):
        mse = nn.functional.mse_loss(input, target)
        psnr = 10 * torch.log10(1 / mse)
        return psnr


psnr_metric = PSNR()
ssim_metric = SSIM()'''

'class SSIM(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, input, target):\n        return ssim(input, target)\n\nclass PSNR(nn.Module):\n    #def __init__(self):\n        super().__init__()\n\n    def forward(self, input, target):\n        mse = nn.functional.mse_loss(input, target)\n        psnr = 10 * torch.log10(1 / mse)\n        return psnr\n\n\npsnr_metric = PSNR()\nssim_metric = SSIM()'

In [105]:
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

class UHDMDataset(Dataset):
    def __init__(self, root_dir: str, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_labels = []
        self._load_images()

    def _load_images(self):
        for img_name in os.listdir(self.root_dir):
            if 'moire' in img_name and img_name.endswith('.jpg'):
                base_name = img_name.replace('_moire.jpg', '')
                moire_img_path = os.path.join(self.root_dir, img_name)
                label_img_path = os.path.join(self.root_dir, f'{base_name}_gt.jpg')
                
                if os.path.isfile(label_img_path):
                    self.image_labels.append((moire_img_path, label_img_path))

    def __len__(self) -> int:
        return len(self.image_labels)

    def __getitem__(self, idx: int):
        moire_img_path, label_img_path = self.image_labels[idx]
        moire_image = Image.open(moire_img_path).convert('RGB')
        label_image = Image.open(label_img_path).convert('RGB')
        
        if self.transform:
            moire_image = self.transform(moire_image)
            label_image = self.transform(label_image)
        
        return moire_image, label_image

# Initialize dataset and dataloader
train_set = UHDMDataset("UHDM/train", transform)
dataloader = DataLoader(train_set, batch_size=32, shuffle=True)


In [106]:
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(dataloader.dataset)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}')

    # Evaluation phase
    model.eval()
    psnr_values = []
    ssim_values = []
    vif_values = []
    moire_values = []

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            for i in range(images.size(0)):
                img = outputs[i].cpu().numpy().transpose(1, 2, 0) * 255.0
                gt = labels[i].cpu().numpy().transpose(1, 2, 0) * 255.0
                
                img = np.clip(img, 0, 255).astype(np.uint8)
                gt = np.clip(gt, 0, 255).astype(np.uint8)

                psnr_value = psnr(gt, img)
                ssim_value, _ = ssim(gt, img, full=True, multichannel=True)
                vif_value = vif(gt, img)
                moire_value = moire_metric(gt, img)

                psnr_values.append(psnr_value)
                ssim_values.append(ssim_value)
                vif_values.append(vif_value)
                moire_values.append(moire_value)

    avg_psnr = np.mean(psnr_values)
    avg_ssim = np.mean(ssim_values)
    avg_vif = np.mean(vif_values)
    avg_moire = np.mean(moire_values)

    print(f'Epoch [{epoch + 1}/{num_epochs}] - PSNR: {avg_psnr:.4f}, SSIM: {avg_ssim:.4f}, VIF: {avg_vif:.4f}, Moire: {avg_moire:.4f}')

# Save the trained model
torch.save(model.state_dict(), 'model.pth')
print('Model saved to model.pth')

TypeError: MultiheadAttention.forward() missing 1 required positional argument: 'value'