In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


# Generate a sample dataset

In [None]:
# Synthetic two-modality dataset: random image tensors, random RNA vectors and
# random binary labels. It only exists to exercise the training pipeline.
SEED = 42
torch.manual_seed(SEED)  # seed the global torch RNG so the random tensors below are reproducible

num_samples = 1000
img_shape = (3,28,28)   # per-sample image shape; the leading 3 is the channel count
rna_dim = 20000         # length of each RNA feature vector

images = torch.randn(num_samples, *img_shape)
rna_seq = torch.randn(num_samples, rna_dim)
labels = torch.randint(0,2,(num_samples,))   # integer labels drawn from {0, 1}

# Each dataset item is an (image, rna, label) triple; batches are shuffled every epoch.
dataset = TensorDataset(images, rna_seq, labels)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Client 1: image encoder



In [None]:
class ImageModel(nn.Module):
    """Convolutional encoder that maps a 3-channel image batch to embeddings.

    Args:
        emb_dim: Size of the embedding produced for each image.

    The forward pass takes a tensor of shape (batch, 3, H, W) and returns a
    tensor of shape (batch, emb_dim).
    """

    def __init__(self, emb_dim=128):
        super().__init__()
        conv_stack = [
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
            nn.ReLU(),
            # Global average pooling leaves one value per channel.
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
        ]
        self.cnn = nn.Sequential(*conv_stack)
        self.fc = nn.Linear(in_features=64, out_features=emb_dim)

    def forward(self, x):
        """Encode an image batch into embeddings of size ``emb_dim``."""
        feature_maps = self.cnn(x)
        batch_size = feature_maps.size(0)
        # Collapse the (64, 1, 1) pooled maps into a flat 64-vector per sample.
        pooled = feature_maps.view(batch_size, -1)
        return self.fc(pooled)

# Client 2: RNA encoder

In [None]:
class RNAModel(nn.Module):
    """Two-layer MLP encoder for flat RNA feature vectors.

    Args:
        input_dim: Number of features in each input vector.
        emb_dim: Size of the embedding produced for each sample.

    The forward pass takes a tensor of shape (batch, input_dim) and returns a
    tensor of shape (batch, emb_dim).
    """

    def __init__(self, input_dim, emb_dim=128):
        super().__init__()
        hidden_dim = 512
        mlp_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, emb_dim),
        ]
        self.net = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Encode a batch of RNA vectors into embeddings of size ``emb_dim``."""
        embedding = self.net(x)
        return embedding


# Server_Gated_Fusion

In [None]:
class ServerModel(nn.Module):
    """Server head that fuses two embeddings with a learned gate and classifies.

    Args:
        emb_dim: Size of each incoming embedding (both modalities share it).

    The forward pass returns one logit per sample, shape (batch, 1).
    """

    def __init__(self, emb_dim=128):
        super().__init__()
        # The gate sees both embeddings and emits one value in (0, 1) per feature.
        self.gate = nn.Sequential(
            nn.Linear(2 * emb_dim, emb_dim),
            nn.Sigmoid(),
        )
        self.classifier = nn.Sequential(
            nn.Linear(emb_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, h_img, h_rna):
        """Blend ``h_img`` and ``h_rna`` feature-wise and return the class logit."""
        joint = torch.cat((h_img, h_rna), dim=1)
        mix = self.gate(joint)
        # Convex combination: mix weights the image side, (1 - mix) the RNA side.
        fused = mix * h_img + (1 - mix) * h_rna
        logit = self.classifier(fused)
        return logit



# Server (Attention + Gated) Fusion

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionServerModel(nn.Module):
    """Server head that fuses two embeddings with softmax attention weights.

    Each modality gets a scalar score from its own linear layer. The two
    scores are softmax-normalised per sample and used to take a weighted sum
    of the embeddings, which is then classified.

    Args:
        emb_dim: Size of each incoming embedding (both modalities share it).

    The forward pass returns one logit per sample, shape (batch, 1).
    """

    def __init__(self, emb_dim=128):
        super().__init__()
        # One scalar scorer per modality.
        self.attn_img = nn.Linear(emb_dim, 1)
        self.attn_rna = nn.Linear(emb_dim, 1)

        # Classification head applied to the fused embedding.
        self.classifier = nn.Sequential(
            nn.Linear(emb_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, h_img, h_rna):
        """Attention-weight ``h_img`` and ``h_rna`` and return the class logit."""
        # Stack the per-modality scores side by side: [batch, 2].
        stacked_scores = torch.cat((self.attn_img(h_img), self.attn_rna(h_rna)), dim=1)
        attn = F.softmax(stacked_scores, dim=1)

        # Slicing with ranges keeps a trailing dim of 1 so the weights broadcast
        # across the embedding dimension.
        w_img = attn[:, 0:1]
        w_rna = attn[:, 1:2]
        fused = w_img * h_img + w_rna * h_rna  # [batch, emb_dim]

        return self.classifier(fused)


# Models, loss, and optimizers for the two clients and the server fusion model

In [None]:
# Two client-side encoders plus the server-side fusion head.
img_model = ImageModel()
rna_model = RNAModel(input_dim=rna_dim)
server_model = AttentionServerModel()

# The server emits raw logits, so the loss applies the sigmoid itself.
criterion = nn.BCEWithLogitsLoss()

# One Adam optimizer per model, all with the same learning rate.
opt_img, opt_rna, opt_server = (
    optim.Adam(module.parameters(), lr=1e-3)
    for module in (img_model, rna_model, server_model)
)

In [None]:
# Joint training of both client encoders and the server fusion head.
# Every batch runs both encoders, fuses their embeddings on the server,
# and backpropagates one loss through all three models.
num_epochs = 10

for epoch in range(num_epochs):
    running_loss = 0.0
    for batch in loader:
        imgs, rna, y = batch

        # Client 1: encode the image modality.
        h_img = img_model(imgs)

        # Client 2: encode the RNA modality.
        h_rna = rna_model(rna)

        # Server: attention fusion and classification.
        logits = server_model(h_img, h_rna)
        # squeeze(1) removes only the logit dimension, so a batch of size 1
        # keeps shape [1] and still matches the target shape.
        loss = criterion(logits.squeeze(1), y.float())

        # Clear stale gradients on all three models before backprop.
        opt_server.zero_grad()
        opt_img.zero_grad()
        opt_rna.zero_grad()

        loss.backward()

        opt_server.step()
        opt_img.step()
        opt_rna.step()

        # Weight the batch-mean loss by batch size so the epoch average is
        # per-sample even when the last batch is smaller.
        running_loss += loss.item() * imgs.size(0)

    epoch_loss = running_loss / len(dataset)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f}")

Epoch 1/10 - Loss: 0.0000
Epoch 2/10 - Loss: 1.0000
Epoch 3/10 - Loss: 2.0000
Epoch 4/10 - Loss: 3.0000
Epoch 5/10 - Loss: 4.0000
Epoch 6/10 - Loss: 5.0000
Epoch 7/10 - Loss: 6.0000
Epoch 8/10 - Loss: 7.0000
Epoch 9/10 - Loss: 8.0000
Epoch 10/10 - Loss: 9.0000
