In [1]:
import torch
import torch.nn as nn

In [2]:
class Generator(nn.Module):
    """Encoder/decoder network mapping a 1-channel image to a 3-channel image.

    Pipeline: ``encoder1`` (1 -> 256 channels, two stride-2 convolutions),
    ``self_attention1``, ``encoder2`` (256 -> 1024 channels, two more stride-2
    convolutions), ``self_attention2``, then ``decoder`` (four stride-2
    transposed convolutions, a final 3-channel projection and ``Tanh``, so
    output values lie in [-1, 1]).
    """

    @staticmethod
    def _down(in_ch, out_ch, stride):
        """Return the layers [3x3 Conv2d, BatchNorm2d, in-place ReLU]."""
        return [
            nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True),
        ]

    @staticmethod
    def _up(in_ch, out_ch):
        """Return the layers [4x4 stride-2 ConvTranspose2d, BatchNorm2d, in-place ReLU]."""
        return [
            nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True),
        ]

    def __init__(self):
        super().__init__()

        # Children are created in the same order as they run in forward().
        self.encoder1 = nn.Sequential(
            *self._down(1, 64, 1),
            *self._down(64, 128, 2),
            *self._down(128, 256, 2),
        )
        self.self_attention1 = SelfAttention(256)

        self.encoder2 = nn.Sequential(
            *self._down(256, 512, 2),
            *self._down(512, 1024, 2),
        )
        self.self_attention2 = SelfAttention(1024)

        self.decoder = nn.Sequential(
            *self._up(1024, 512),
            *self._up(512, 256),
            *self._up(256, 128),
            *self._up(128, 64),
            # Size-preserving projection to 3 output channels.
            nn.ConvTranspose2d(64, 3, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run ``x`` through the encoders, attention blocks and decoder.

        The encoders halve the spatial size four times and the decoder doubles
        it four times, so the output has the same height and width as the
        input when both are multiples of 16.
        """
        stages = (
            self.encoder1,
            self.self_attention1,
            self.encoder2,
            self.self_attention2,
            self.decoder,
        )
        for stage in stages:
            x = stage(x)
        return x


In [3]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing one sigmoid score per image.

    Six 4x4 stride-2 convolutions (each followed by BatchNorm and LeakyReLU)
    halve the spatial size six times and end with 16 channels. The linear head
    expects 16 * 2 * 4 features per sample, i.e. a final 2 x 4 feature map
    (e.g. a 128 x 256 input). Self-attention is applied after ``conv1`` and
    after ``conv2``.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),
            # NOTE(review): the original author questioned whether this first
            # BatchNorm belongs here; it is kept as found.
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 16, 4, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.self_attention1 = SelfAttention(64)
        self.self_attention2 = SelfAttention(128)
        # Registered but not applied in forward(); kept so the set of
        # parameters / state_dict keys stays the same.
        self.self_attention3 = SelfAttention(16)

        # 16 channels * 2 * 4 spatial positions -> one score in (0, 1).
        self.liner = nn.Sequential(
            nn.Linear(16 * 2 * 4, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of sigmoid scores for image batch ``x``."""
        x = self.conv1(x)
        x = self.self_attention1(x)
        x = self.conv2(x)
        x = self.self_attention2(x)
        x = self.conv3(x)
        # Flatten per sample and keep the batch dimension. The previous
        # view(-1, 16*2*4) could silently fold extra spatial positions into the
        # batch dimension for unexpected input sizes; now such inputs fail with
        # a shape error in the linear layer instead.
        x = torch.flatten(x, 1)
        x = self.liner(x)
        return x


In [4]:
class FeatureMatchingLoss(nn.Module):
    """Sum, over paired feature tensors, of their mean absolute difference."""

    def __init__(self):
        super().__init__()

    def forward(self, real_features, fake_features):
        """Return the summed per-pair L1 distance.

        ``real_features`` and ``fake_features`` are iterated in lockstep with
        ``zip``, so extra entries in the longer one are ignored. With no pairs
        the result is the Python float ``0.0``.
        """
        per_pair_l1 = (
            (real_feat - fake_feat).abs().mean()
            for real_feat, fake_feat in zip(real_features, fake_features)
        )
        # Start from 0.0 so an empty input still yields a float.
        return sum(per_pair_l1, 0.0)
    
   

In [5]:
class SelfAttention(nn.Module):
    """Self-attention over the spatial positions of a 2-D feature map.

    Queries and keys are 1x1-convolved down to ``in_channels // 8`` channels
    (at least 1); values keep ``in_channels`` channels. The attended values
    are scaled by the learnable scalar ``gamma`` (initialised to zero, so the
    module starts as the identity) and added back to the input.

    The attention matrix has shape ``(batch, H*W, H*W)``, so memory grows
    quadratically with the number of spatial positions.

    Args:
        in_channels: number of channels of the input feature map.
    """

    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        # Floor division yields 0 for in_channels < 8, which would create an
        # empty projection and break forward(); clamp to at least one channel.
        reduced_channels = max(in_channels // 8, 1)
        self.query_conv = nn.Conv2d(in_channels, reduced_channels, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, reduced_channels, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        # Weight applied to the attention output before the residual add.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Apply attention to ``x`` of shape ``(batch, channels, height, width)``.

        Returns a tensor with the same shape as ``x``.
        """
        batch_size, channels, height, width = x.size()

        # flatten(2) merges H and W into one axis and, unlike view, also works
        # when the conv output is not contiguous (e.g. channels_last).
        # Swap the second and third dimensions for the dot product: (B, N, C').
        proj_query = self.query_conv(x).flatten(2).permute(0, 2, 1)
        proj_key = self.key_conv(x).flatten(2)  # (B, C', N)

        energy = torch.bmm(proj_query, proj_key)  # (B, N, N) query-key dot products
        attention = torch.softmax(energy, dim=-1)  # each row sums to 1 over key positions

        proj_value = self.value_conv(x).flatten(2)  # (B, C, N)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))

        out = out.reshape(batch_size, channels, height, width)  # back to the input size

        out = self.gamma * out + x

        return out

In [6]:
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image
import os


# Dataset
class FusionDataset(Dataset):
    """Paired infrared / visible-light image dataset.

    Images are read from ``<root_dir>/infrared`` and ``<root_dir>/visible``.
    Each folder's file names are sorted and paired by position, so the two
    folders must list corresponding images in the same order.

    Args:
        root_dir: directory containing the ``infrared`` and ``visible`` folders.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.inf_dir = os.path.join(self.root_dir, 'infrared')  # infrared
        self.vis_dir = os.path.join(self.root_dir, 'visible')  # visible
        self.inf_files = sorted(os.listdir(self.inf_dir))
        self.vis_files = sorted(os.listdir(self.vis_dir))
        self.transform = transforms.Compose([
            # Crop a 128 x 256 (height x width) region from the centre.
            transforms.CenterCrop([128, 256]),
            # PIL image (0-255, HWC) -> float tensor (0.0-1.0, CHW).
            transforms.ToTensor(),
            # (image - 0.5) / 0.5 maps [0, 1] to [-1, 1].
            transforms.Normalize((0.5,), (0.5,))
        ])

    def __len__(self):
        """Return the number of usable (infrared, visible) pairs."""
        # Use the shorter listing so every index has a file in both folders;
        # otherwise a missing visible file raises IndexError in __getitem__.
        return min(len(self.inf_files), len(self.vis_files))

    def __getitem__(self, idx):
        """Return ``(inf_tensor, vis_tensor)`` for the pair at position ``idx``."""
        inf_path = os.path.join(self.inf_dir, self.inf_files[idx])
        vis_path = os.path.join(self.vis_dir, self.vis_files[idx])
        # Context managers close the underlying files as soon as the tensors exist.
        with Image.open(inf_path) as inf_image:
            # Infrared is forced to single-channel grayscale.
            inf_tensor = self.transform(inf_image.convert('L'))
        with Image.open(vis_path) as vis_image:
            # NOTE(review): the visible image keeps its on-disk mode, so its
            # channel count depends on the file; confirm all files are RGB if
            # downstream code expects 3 channels.
            vis_tensor = self.transform(vis_image)
        return inf_tensor, vis_tensor