In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import timm
import matplotlib.pyplot as plt




# Synthetic (Dummy) Dataset

In [2]:
# Example dataset: random tensors standing in for paired image / RNA inputs.
# Seed the global RNG first so the tensors below (and the DataLoader's shuffle
# order, which draws its base seed from the global RNG) are reproducible under
# Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

num_samples = 1000
img_shape = (3,224,224)  # per-sample image shape
rna_dim = 20000  # per-sample RNA feature count

images = torch.randn(num_samples, *img_shape)
rna_seq = torch.randn(num_samples, rna_dim)
labels = torch.randint(0,2,(num_samples,))  # integer targets drawn from {0, 1}

# Each dataset item is an (image, rna, label) triple.
dataset = TensorDataset(images, rna_seq, labels)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Client 1 Model

In [3]:

class ImageModelViT(nn.Module):
    """Image encoder: a timm ``vit_tiny_patch16_224`` whose head emits an embedding.

    Args:
        emb_dim: Output width of the replacement linear head.
        pretrained: Forwarded to ``timm.create_model``. The default ``True``
            loads pretrained backbone weights (which may require a download);
            pass ``False`` to build the architecture with random weights,
            e.g. when working offline.
    """

    def __init__(self, emb_dim=128, pretrained=True):
        super().__init__()
        self.vit = timm.create_model('vit_tiny_patch16_224', pretrained=pretrained)
        # Replace the stock head with a freshly initialised projection to the
        # embedding width; the input width is read from the head being replaced.
        self.vit.head = nn.Linear(self.vit.head.in_features, emb_dim)

    def forward(self, x):
        """Run ``x`` through the ViT and return the output of the new head."""
        return self.vit(x)


# Client 2 Model

In [4]:


class RNATransformer(nn.Module):
    """RNA encoder: linear projection to ``emb_dim`` followed by a Transformer encoder.

    The projected vector is fed to the encoder as a sequence of length one, so
    the average pool over the sequence axis simply returns that single token.

    Args:
        input_dim: Width of the incoming feature vector.
        emb_dim: Transformer model width and size of the returned embedding.
        nhead: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, input_dim, emb_dim=128, nhead=8, num_layers=2):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, emb_dim)
        layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=nhead, dim_feedforward=512)
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.emb_dim = emb_dim

    def forward(self, x):
        """Encode ``x`` (assumed ``(batch, input_dim)``) into ``(batch, emb_dim)``."""
        # The encoder is sequence-first, so add a leading length-1 sequence axis.
        tokens = self.input_proj(x).unsqueeze(0)          # (1, batch, emb)
        encoded = self.transformer(tokens)                # (1, batch, emb)
        # Move to (batch, emb, seq) for the 1-D pool over the sequence axis.
        pooled = self.pool(encoded.permute(1, 2, 0))      # (batch, emb, 1)
        return pooled.squeeze(-1)


# Server Model

In [5]:

class AttentionServerModel(nn.Module):
    """Server head: softmax-gated fusion of two embeddings, then an MLP emitting one logit.

    Args:
        emb_dim: Width of both incoming embeddings.
    """

    def __init__(self, emb_dim=128):
        super().__init__()
        # One scalar scorer per modality; their softmax gives the fusion weights.
        self.attn_img = nn.Linear(emb_dim, 1)
        self.attn_rna = nn.Linear(emb_dim, 1)

        # MLP applied to the fused embedding.
        hidden = 64
        self.classifier = nn.Sequential(
            nn.Linear(emb_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, h_img, h_rna):
        """Fuse ``h_img`` and ``h_rna`` and return the classifier output (raw logit)."""
        # Column 0 scores the image embedding, column 1 the RNA embedding.
        raw_scores = torch.cat((self.attn_img(h_img), self.attn_rna(h_rna)), dim=1)
        gate = torch.softmax(raw_scores, dim=1)
        img_gate, rna_gate = gate[:, :1], gate[:, 1:2]

        # Convex combination of the two embeddings (the gates sum to 1 per row).
        fused = img_gate * h_img + rna_gate * h_rna
        return self.classifier(fused)


In [6]:
# Build the two client encoders and the server head. The RNA encoder and the
# server rely on their default emb_dim, which matches the 128 passed explicitly
# to the image encoder.
img_model = model = ImageModelViT(emb_dim=128)  # `model` is a second name for the same image encoder
rna_model = RNATransformer(input_dim=rna_dim)
server_model = AttentionServerModel()

# The server emits raw logits, hence the with-logits BCE variant.
criterion = nn.BCEWithLogitsLoss()

# One independent Adam optimizer per party, all with the same learning rate.
opt_img, opt_rna, opt_server = (
    optim.Adam(party.parameters(), lr=1e-3)
    for party in (img_model, rna_model, server_model)
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/22.9M [00:00<?, ?B/s]



In [7]:
# Empty metric accumulators, one fresh list per name. Judging by the names they
# are meant for server-fusion and per-client curves; nothing in the visible
# cells appends to them.
server_Fusion_loss_history, server_Fusion_acc_history = [], []
client_1_loss_history, client_2_loss_history = [], []

# Vertical Training Loop (Clients + Server)

In [None]:

num_epochs = 10

# Store metrics (one entry per epoch).
# NOTE: despite their names, img_loss_history / rna_loss_history hold the mean
# per-batch L2 norm of the gradient reaching each client's embedding, not a loss.
server_loss_history = []
server_acc_history = []
img_loss_history = []
rna_loss_history = []

for epoch in range(num_epochs):
    # Enforce training mode (dropout etc.) in case an earlier cell called eval().
    img_model.train()
    rna_model.train()
    server_model.train()

    running_loss_server = 0.0
    running_loss_img = 0.0
    running_loss_rna = 0.0
    correct = 0
    total = 0

    for imgs, rna, y in loader:
        # Client 1: Image embedding
        h_img = img_model(imgs)
        # The embedding is normally a non-leaf tensor, so autograd discards its
        # .grad after backward() unless retain_grad() is requested; without it
        # the gradient-norm metric below is always 0. requires_grad_() is kept
        # so a fully frozen encoder (leaf output) still receives a gradient.
        h_img.requires_grad_()
        h_img.retain_grad()

        # Client 2: RNA embedding
        h_rna = rna_model(rna)
        h_rna.requires_grad_()
        h_rna.retain_grad()

        # Server - Attention Gated fusion & classification
        logits = server_model(h_img, h_rna)
        # squeeze(-1) drops only the logit axis, so a batch of size 1 keeps its
        # batch dimension and still matches y's shape.
        loss = criterion(logits.squeeze(-1), y.float())

        # Backprop
        opt_server.zero_grad()
        opt_img.zero_grad()
        opt_rna.zero_grad()
        loss.backward()

        opt_server.step()
        opt_img.step()
        opt_rna.step()

        batch_size = imgs.size(0)
        # Weight by batch size so the epoch value is a true per-sample mean.
        running_loss_server += loss.item() * batch_size

        # For client models, approximate their contribution by partial gradient norm
        running_loss_img += h_img.grad.norm().item() if h_img.grad is not None else 0
        running_loss_rna += h_rna.grad.norm().item() if h_rna.grad is not None else 0

        # Threshold the sigmoid at 0.5; detach so metrics never touch the graph.
        pred = torch.sigmoid(logits.detach()).round()
        correct += (pred.squeeze(-1) == y).sum().item()
        total += y.size(0)

    # Epoch metrics
    epoch_loss_server = running_loss_server / len(dataset)
    epoch_loss_img = running_loss_img / len(loader)
    epoch_loss_rna = running_loss_rna / len(loader)
    epoch_acc_server = correct / total

    server_loss_history.append(epoch_loss_server)
    server_acc_history.append(epoch_acc_server)
    img_loss_history.append(epoch_loss_img)
    rna_loss_history.append(epoch_loss_rna)

    print(f"Epoch {epoch+1}/{num_epochs} - Server Loss: {epoch_loss_server:.4f}, Acc: {epoch_acc_server:.4f}")


In [None]:

# Side-by-side training curves: server loss (left) and server accuracy (right).
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12,5))

ax_loss.plot(server_loss_history, label='Server Loss', marker='o')
ax_loss.set(xlabel='Epoch', ylabel='Loss', title='Server Loss per Epoch')
ax_loss.grid(True)
ax_loss.legend()

ax_acc.plot(server_acc_history, label='Server Accuracy', marker='o', color='green')
ax_acc.set(xlabel='Epoch', ylabel='Accuracy', title='Server Accuracy per Epoch')
ax_acc.grid(True)
ax_acc.legend()

plt.show()


