In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GateFusionModel(nn.Module):
    """Gated fusion of two 1-D feature vectors.

    Each input is lifted to several channels by a 1-D convolution, every
    (channel of input 1, channel of input 2) pair is concatenated, the pairs
    are re-weighted by an MLP gate, the ``k`` pairs with the largest mean
    activation are kept, mixed by a 1x1 convolution, averaged, and projected
    back to the length of the first input.

    Args:
        m: Number of conv channels produced from the first input.
        n: Number of conv channels produced from the second input.
        k: Number of (m, n) combinations kept by the top-k gate; must satisfy
            ``1 <= k <= m * n``.
        input_dim_1: Length of the first input vector (default 768).
        input_dim_2: Length of the second input vector (default 67).

    Raises:
        ValueError: If ``k`` is outside ``[1, m * n]``.
    """

    def __init__(self, m=5, n=5, k=5, input_dim_1=768, input_dim_2=67):
        super().__init__()

        # torch.topk cannot select more candidates than exist; fail at
        # construction time instead of deep inside forward().
        if not 1 <= k <= m * n:
            raise ValueError(
                f"k must be in [1, m * n] = [1, {m * n}], got k={k}"
            )

        # Lengths of the two input vectors.
        self.input_dim_1 = input_dim_1
        self.input_dim_2 = input_dim_2
        self.m = m
        self.n = n
        self.k = k
        fused_dim = input_dim_1 + input_dim_2

        # 1-D convolutions; kernel 5 with padding 2 keeps the sequence length.
        self.conv1d_1 = nn.Conv1d(in_channels=1, out_channels=m, kernel_size=5, stride=1, padding=2)
        self.conv1d_2 = nn.Conv1d(in_channels=1, out_channels=n, kernel_size=5, stride=1, padding=2)

        # Scalar gate in (0, 1) for every concatenated combination.
        self.mlp = nn.Sequential(
            nn.Linear(fused_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

        # Final 1x1 convolution mixes the k selected combinations (as channels).
        self.conv1d_final = nn.Conv1d(in_channels=k, out_channels=2 * k, kernel_size=1)
        self.linear = nn.Linear(fused_dim, input_dim_1)

    def forward(self, c1, c2):
        """Fuse ``c1`` and ``c2``.

        Args:
            c1: Tensor of shape (batch_size, 1, input_dim_1).
            c2: Tensor of shape (batch_size, 1, input_dim_2).

        Returns:
            Tensor of shape (batch_size, 1, input_dim_1).
        """
        # Step 1: Conv1d transformation
        feats_1 = self.conv1d_1(c1)  # (batch_size, m, input_dim_1)
        feats_2 = self.conv1d_2(c2)  # (batch_size, n, input_dim_2)

        # Step 2: Concatenate all m * n combinations.
        # expand() creates broadcast views; torch.cat materialises the result
        # once, so no intermediate repeated copies are allocated.
        feats_1 = feats_1.unsqueeze(2).expand(-1, -1, self.n, -1)  # (batch_size, m, n, input_dim_1)
        feats_2 = feats_2.unsqueeze(1).expand(-1, self.m, -1, -1)  # (batch_size, m, n, input_dim_2)
        combined_features = torch.cat([feats_1, feats_2], dim=-1)  # (batch_size, m, n, fused_dim)
        combined_features = combined_features.flatten(1, 2)  # (batch_size, m*n, fused_dim)
        fused_dim = combined_features.size(-1)

        # Step 3: Compute weights using the MLP gate, scaled by the mean activation
        mlp_weights = self.mlp(combined_features)  # (batch_size, m*n, 1)
        avg_weights = combined_features.mean(dim=-1, keepdim=True)  # (batch_size, m*n, 1)
        weights = mlp_weights * avg_weights  # (batch_size, m*n, 1)

        # Step 4: Residual combination of scaled and original features
        res_features = combined_features + combined_features * weights  # (batch_size, m*n, fused_dim)

        # Step 5: Rank combinations by mean activation and keep the top-k
        gating_weights = res_features.mean(dim=-1)  # (batch_size, m*n)
        _, topk_indices = torch.topk(gating_weights, self.k, dim=-1)  # (batch_size, k)
        gather_index = topk_indices.unsqueeze(-1).expand(-1, -1, fused_dim)
        topk_res_features = torch.gather(res_features, dim=1, index=gather_index)  # (batch_size, k, fused_dim)

        # Step 6: Final convolution layer, then average over its output channels
        final_output = self.conv1d_final(topk_res_features)  # (batch_size, 2*k, fused_dim)
        final_output = final_output.mean(dim=1, keepdim=True)  # (batch_size, 1, fused_dim)

        return self.linear(final_output)  # (batch_size, 1, input_dim_1)

# Smoke test: push one random batch through the fusion model and show the output shape.
f1, f2 = torch.randn(8, 1, 768), torch.randn(8, 1, 67)
gate_model = GateFusionModel()
out = gate_model(f1, f2)
print(out.shape)  # torch.Size([8, 1, 768])

torch.Size([8, 1, 768])


In [None]:
# dataset
import pandas as pd
import scipy.ndimage
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms
import random
import matplotlib.pyplot as plt
import os

class CustomDataset(Dataset):
    """Image / mask / tabular-feature dataset for one split.

    Images are read from ``<img_dir>/images/<split>`` and the matching masks
    from ``<img_dir>/masks/<split>`` (same file name). Per-image feature
    vectors come from ``files/<split CSV>``; the first CSV column is the file
    name (with ``.nii.gz`` mapped to ``.png``) and the remaining columns are
    the features.

    Args:
        img_dir: Root folder containing the ``images`` and ``masks`` trees.
        train_val_test: Split name, one of ``'train'``, ``'val'``, ``'test'``.

    Raises:
        KeyError: If ``train_val_test`` is not a known split name.
    """

    def __init__(self, img_dir='DATA_FOLDER', train_val_test='train'):
        train_val_test_dict = {'train': 'trainset_normalized.csv', 'val': 'valset_normalized.csv', 'test': 'testset_normalized.csv'}
        self.labelfile = os.path.join('files', train_val_test_dict[train_val_test])
        self.label_array = np.array(pd.read_csv(self.labelfile))
        # Map "<name>.png" -> feature row (every column after the file name).
        self.label_dict = {}
        for row in self.label_array:
            self.label_dict[row[0].replace('.nii.gz', '.png')] = row[1:]

        self.cc_img_root = os.path.join(img_dir, 'images', train_val_test)
        # Explicit mask root: rewriting the image path with str.replace would
        # also alter any other "images" substring in img_dir or the file name.
        self.cc_mask_root = os.path.join(img_dir, 'masks', train_val_test)
        # os.listdir returns entries in arbitrary order; sort so that an index
        # always refers to the same sample across runs and machines.
        self.images = sorted(os.listdir(self.cc_img_root))
        self.image_transform = transforms.Compose([
            transforms.Resize((256, 256)),  # resize to 256x256
            transforms.ToTensor(),
            transforms.Normalize((0.5), (0.5)),  # (x - 0.5) / 0.5; note (0.5) is a plain float
        ])
        self.mask_transform = transforms.Compose([
            # Nearest-neighbour keeps mask values as discrete labels.
            transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
        ])
        # Remember which split this dataset serves (previously stored the builtin `type`).
        self.type = train_val_test

    def __len__(self):
        """Return the number of image files in the split folder."""
        return len(self.images)

    def get_random_box_mask(self, image):
        """Return a (1, H, W) mask that is 0 inside a random box and 1 elsewhere.

        H and W are taken from the last two dimensions of ``image``. Because
        ``random.randint`` is inclusive on both ends, the box may be empty, in
        which case the returned mask is all ones.
        """
        y, x = image.shape[-2:]
        mask = torch.zeros((1, y, x))

        top = random.randint(0, y)
        left = random.randint(0, x)
        height = random.randint(0, y - top)
        width = random.randint(0, x - left)

        mask[:, top:top+height, left:left+width] = 1

        return 1-mask

    def gaussian_smooth_and_normalize(self, hard_labels, sigma=1.0):
        """
        Apply Gaussian smoothing to a hard segmentation mask and min-max
        normalize the result to the range [0, 1].

        Parameters:
        - hard_labels (np.array): The hard label segmentation map.
        - sigma (float): Standard deviation of the Gaussian kernel for smoothing.

        Returns:
        - soft_labels (np.array): float32 map with smoothed boundaries, scaled
          to [0, 1]. A constant input (zero value range) yields all zeros
          instead of NaN from a 0/0 division.
        """
        # Apply Gaussian filter to the hard labels
        smoothed_labels = scipy.ndimage.gaussian_filter(hard_labels.astype(np.float32), sigma=sigma)

        lowest = smoothed_labels.min()
        value_range = smoothed_labels.max() - lowest
        if value_range == 0:
            # Uniform mask: nothing to normalize, avoid dividing by zero.
            return np.zeros_like(smoothed_labels)

        # Normalize the smoothed values to the range [0, 1]
        return (smoothed_labels - lowest) / value_range

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns a dict with:
        - ``"image_target"``: transformed RGB image tensor.
        - ``"image_cond"``: Gaussian-smoothed mask, repeated over 3 channels.
        - ``"mass_cond"``: binary map of mask pixels equal to 2, repeated over 3 channels.
        - ``"additional_feature"``: (1, F) feature row from the CSV, or a
          (1, 67) zero vector when the file has no CSV entry.
        """
        filename = self.images[idx]
        cc_img_path = os.path.join(self.cc_img_root, filename)
        cc_mask_path = os.path.join(self.cc_mask_root, filename)

        # convert() returns a fully loaded copy, so the files can be closed right away.
        with Image.open(cc_img_path) as img_file:
            cc_image = self.image_transform(img_file.convert("RGB"))
        with Image.open(cc_mask_path) as mask_file:
            cc_mask = self.mask_transform(mask_file.convert("L"))

        hardlabel = np.array(cc_mask)
        softlabel = self.gaussian_smooth_and_normalize(hardlabel, sigma=1)

        data = {}
        data["image_target"] = cc_image
        data["image_cond"] = torch.FloatTensor(softlabel).unsqueeze(0).repeat(3, 1, 1)
        data["mass_cond"] = torch.FloatTensor((hardlabel == 2).astype(np.float32)).unsqueeze(0).repeat(3, 1, 1)
        if filename in self.label_dict:
            features = self.label_dict[filename].astype(np.float32)
            data["additional_feature"] = torch.FloatTensor(features).unsqueeze(0)
        else:
            # No tabular record for this image: fall back to an all-zero feature row.
            data["additional_feature"] = torch.zeros(1, 67)
        return data
    
# Sanity check: build the training split and inspect the tensors of the first sample.
dataset = CustomDataset(img_dir='DATA_FOLDER', train_val_test='train')
print(len(dataset))
data = dataset[0]
# Expected: three (3, 256, 256) tensors followed by a (1, 67) feature row.
for key in ("image_target", "image_cond", "mass_cond", "additional_feature"):
    print(data[key].shape)

6704
torch.Size([3, 256, 256])
torch.Size([3, 256, 256])
torch.Size([3, 256, 256])
torch.Size([1, 67])
