In [117]:
import torch

inputs = torch.tensor(
  [[0.4, 0.153, 0.845], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.05, 0.80, 0.55]], # step     (x^6)
)

In [118]:
inputs.shape

torch.Size([4, 3])

In [157]:
batch = torch.stack((inputs, inputs), dim=0)
batch.shape

torch.Size([2, 4, 3])

In [210]:
import torch
import torch.nn as nn

import torch

import torch

def generate_unique_row_binary_tensor(n):
    tensor = torch.randint(0, 2, (n, n))

    # Keep checking for duplicates until all rows are unique
    while True:
        # Convert rows to tuples for easy uniqueness check
        rows = [tuple(row.tolist()) for row in tensor]

        # Find duplicates by counting occurrences
        duplicates = [row for row in set(rows) if rows.count(row) > 1]

        if not duplicates:
            break  # all rows unique

        # For each duplicate, randomly pick one occurrence to change
        for dup in duplicates:
            indices = [i for i, row in enumerate(rows) if row == dup]
            # Keep one occurrence intact, change others
            for idx in indices[1:]:
                tensor[idx] = torch.randint(0, 2, (n,))

    return tensor


class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length=4, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.context_length = context_length

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(0.5)
        self.register_buffer('mask', torch.zeros(context_length, context_length)) # New

    def forward(self, x):
        """
        x: Tensor of shape (batch_size, num_tokens, d_in)
        Returns:
            context_vec: Tensor of shape (batch_size, num_tokens, d_out)
        """
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        attn_scores = queries @ keys.transpose(1, 2) # Changed transpose
        attn_scores.masked_fill_(  # New, _ ops are in-place
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)  # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_size
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights) # New
        context_vec = attn_weights @ values
        return context_vec


class OnCloudMNISTModel(nn.Module):
    def __init__(self, d_in=196, d_out=64, context_length=4):
        super().__init__()
        self.name = "OnCloudMNISTModel"

        # Pass required dimensions
        self.features = SelfAttention(d_in, d_out, context_length)

        # Classifier assumes output shape: (batch_size, context_length, d_out)
        self.classifier = nn.Sequential(
            nn.Linear(context_length * d_out, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Dropout(0.5),
            nn.Linear(128, 10)
        )

    def forward(self, x):  # x: (batch_size, context_length=4, d_in=196)
        b, num_tokens, d_in = x.shape
    
        if self.training:
            p = 0.2 
            mask = (torch.rand(b, num_tokens, device=x.device) > p).float()
            x = x * mask.unsqueeze(-1)  
            
        if num_tokens < self.features.context_length:
            pad_len = self.features.context_length - num_tokens
            pad_tensor = torch.zeros(b, pad_len, d_in, device=x.device, dtype=x.dtype)
            x = torch.cat([x, pad_tensor], dim=1)  # pad on token dim

        x = self.features(x)  # → (batch_size, 4, d_out)
        x = x.view(x.size(0), -1)  # Flatten → (batch_size, 4*d_out)
        x = self.classifier(x)     # → (batch_size, 10)
        return x
    
    

In [211]:
import torch 
d_in = inputs.shape[1] # the input embedding size, d=3
d_out = 2 # the output embedding size, d=2
torch.manual_seed(123)

context_length = batch.shape[1]
ca = SelfAttention(d_in, d_out, context_length, 0.0)

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (192x196 and 3x2)

In [212]:
import torch

# Simulate input: batch_size=64, 4 sources, 196 features each
batch = torch.randn(64, 3, 196)

# Create model
model = OnCloudMNISTModel(d_in=196, d_out=64, context_length=4)

# Forward pass
logits = model(batch)

print("logits.shape:", logits.shape)  # → torch.Size([64, 10])

logits.shape: torch.Size([64, 10])


In [213]:
import os
import json 

from federated_inference.common.environment import  DataMode, TransformType
from federated_inference.simulations.simulation import Simulation
from federated_inference.simulations.utils import *
from federated_inference.configs.model_configs import OnCloudModelConfiguration

class OnCloudVerticalSimulation(Simulation): 
    def __init__(self, seed, version, data_config, transform_config, model, transform_type: TransformType = TransformType.FULL_STRIDE_PARTITION, exist=False):
        self.seed = seed
        self.version = version
        self.data_config = data_config
        self.transform_config = transform_config
        self.server_model_config = OnCloudModelConfiguration
        self.data_mode = DataMode.VERTICAL
        self.transform_type = transform_type
        self.dataset =  self.load_data(data_config)
        self.client_datasets, self.transformation = self.transform_data(self.dataset, data_mode = self.data_mode, transform_config = transform_config, transform_type = self.transform_type)


In [215]:
from federated_inference.common.utils import set_seed
from federated_inference.configs.data_config import DataConfiguration
from federated_inference.configs.transform_config import DataTransformConfiguration
DATASET = 'MNIST'
VERSION = "attention_v1"
seed = 4
data_config = DataConfiguration(DATASET)
transform_config = DataTransformConfiguration()
simulation = OnCloudVerticalSimulation(seed, VERSION, data_config, transform_config, OnCloudMNISTModel, exist=False)


MNIST training data loaded.
MNIST test data loaded.


In [216]:
def combine_batch_columns(batch_list):
    """
    batch_list: list of batches from each dataset loader
    Each batch is usually a tuple (inputs, targets)
    """       
    inputs_list = [b[0] for b in batch_list]  # get inputs from each batch
    flattened = [x.view(x.size(0), -1)for x in inputs_list]
    inputs_tensor = torch.stack(flattened, dim=1)
    targets_list = [b[1] for b in batch_list] 
    combined_targets = targets_list[0] 
    return inputs_tensor, combined_targets


In [217]:
from torch.utils.data import DataLoader, Subset
trainsets =  [client.train_dataset for client in simulation.client_datasets]
testsets = [client.test_dataset for client in simulation.client_datasets]
testloader = [DataLoader(testdata, batch_size=54, shuffle=False) for testdata in testsets]
trainloader = [DataLoader(traindata, batch_size=64, shuffle=False)  for traindata in trainsets]
for batch_idx, batches in enumerate(zip(*trainloader)):
    inputs_tensor, combined_targets = combine_batch_columns(batches)
    print(inputs_tensor.shape)
    break

torch.Size([64, 4, 196])


In [218]:
import torch.optim as optim

model = OnCloudMNISTModel(d_in=196, d_out=64, context_length=4)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 10  # or whatever you want

for epoch in range(num_epochs):
    model.train()
    for batch_idx, batches in enumerate(zip(*trainloader)):
        inputs_tensor, combined_targets = combine_batch_columns(batches)

        outputs = model(inputs_tensor)
        loss = criterion(outputs, combined_targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f"Epoch {epoch} Batch {batch_idx} Loss: {loss.item():.4f}")

    # Validation phase
    model.eval()
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for test_batches in zip(*testloader):
            inputs_tensor, combined_targets = combine_batch_columns(test_batches)
            outputs = model(inputs_tensor)
            preds = outputs.argmax(dim=1)
            total_correct += (preds == combined_targets).sum().item()
            total_samples += combined_targets.size(0)

    accuracy = total_correct / total_samples
    print(f"Epoch {epoch} Validation Accuracy: {accuracy:.4f}")


Epoch 0 Batch 0 Loss: 2.4311
Epoch 0 Batch 100 Loss: 1.0218
Epoch 0 Batch 200 Loss: 0.9316
Epoch 0 Batch 300 Loss: 0.8708
Epoch 0 Batch 400 Loss: 0.6706
Epoch 0 Batch 500 Loss: 0.7131
Epoch 0 Batch 600 Loss: 0.9207
Epoch 0 Batch 700 Loss: 0.7305
Epoch 0 Batch 800 Loss: 0.7004
Epoch 0 Batch 900 Loss: 0.8032
Epoch 0 Validation Accuracy: 0.8770
Epoch 1 Batch 0 Loss: 0.6349
Epoch 1 Batch 100 Loss: 0.7741
Epoch 1 Batch 200 Loss: 0.6285
Epoch 1 Batch 300 Loss: 0.7795
Epoch 1 Batch 400 Loss: 0.5425
Epoch 1 Batch 500 Loss: 0.5838
Epoch 1 Batch 600 Loss: 0.7232
Epoch 1 Batch 700 Loss: 0.7094
Epoch 1 Batch 800 Loss: 0.8272
Epoch 1 Batch 900 Loss: 0.7676
Epoch 1 Validation Accuracy: 0.9074
Epoch 2 Batch 0 Loss: 0.6425
Epoch 2 Batch 100 Loss: 0.5870
Epoch 2 Batch 200 Loss: 0.7176
Epoch 2 Batch 300 Loss: 0.4657
Epoch 2 Batch 400 Loss: 0.6159
Epoch 2 Batch 500 Loss: 0.8573
Epoch 2 Batch 600 Loss: 0.6005
Epoch 2 Batch 700 Loss: 0.6772
Epoch 2 Batch 800 Loss: 0.7446
Epoch 2 Batch 900 Loss: 0.6402
Epoc

KeyboardInterrupt: 

In [None]:
import os

import torch.nn as nn
import torch
from torch.utils.data import DataLoader, Subset
import numpy as np

from federated_inference.common.early_stopper import EarlyStopper
from federated_inference.common.cost_calculator import CostCalculator
from federated_inference.common.environment import Member


class OnCloudVerticalServer():

    def __init__(self, 
            idx, 
            seed,
            model_config,
            data_config,
            log: bool = True, 
            log_interval: int = 100,
            save_interval: int = 20
        ):
        self.idx = idx
        self.seed = seed
        self.model_config = model_config
        self.data_config = data_config
        self.seed = seed
        self.n_epoch = model_config.N_EPOCH
        self.device = model_config.DEVICE
        self.member_type = Member.SERVER
        self.model = model_config.MODEL
        self.optimizer = model_config.OPTIMIZER
        self.criterion = model_config.CRITERION
        self.costs = []

        self.log = log
        self.log_interval = log_interval
        self.save_interval = save_interval

        if self.log: 
            self.train_losses = []
            self.test_losses = []
            self.accuracies = []



    def _to_loader(self, trainsets, testsets, batch_size_train, batch_size_val, batch_size_test, train_shuffle, val_shuffle, test_shuffle, train_ratio):
        # TODO refactoing to use self
        if True:
            print("shuffle training_data")
            # Assuming all trainsets have the same length
            dataset_length = len(trainsets[0])
            assert all(len(trainset) == dataset_length for trainset in trainsets), "All trainsets must be the same length"

            self.train_set_indices = np.arange(dataset_length)
            
            np.random.shuffle(self.train_set_indices)

            train_end = round(train_ratio * dataset_length)
            train_indices = self.train_set_indices[:train_end]
            val_indices = self.train_set_indices[train_end:]

            traindatas = [Subset(trainset, train_indices) for trainset in trainsets]
            valdatas = [Subset(trainset, val_indices) for trainset in trainsets]
        else:
            traindatas = [Subset(trainset, range(round(train_ratio*len(trainset)))) for trainset in trainsets]
            valdatas = [Subset(trainset, range(round(train_ratio*len(trainset)), len(trainset))) for trainset in trainsets]
        self.trainloader = [DataLoader(traindata, batch_size=batch_size_train, shuffle=False)  for traindata in traindatas]
        self.valloader = [DataLoader(valdata, batch_size=batch_size_val, shuffle=False) for valdata in valdatas]
        self.testloader = [DataLoader(testdata, batch_size=batch_size_test, shuffle=False) for testdata in testsets]


    def _pred_loader(self, testsets, batch_size_test, test_shuffle):
        return  [DataLoader(testdata, batch_size=batch_size_test, shuffle=False) for testdata in testsets]

    def train(self, epoch):
        self.model.train()
        for batch_idx, batches in enumerate(zip(*self.trainloader)):
            # batches is tuple of batch from each loader
            data, target = self.combine_batch_columns(batches, expect_target=True)
            data = data.to(self.device).float()
            target = target.to(self.device).long()
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            if self.log and batch_idx % self.log_interval == 0:
                print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(self.trainloader[0].dataset)} '
                    f'({100. * batch_idx / len(self.trainloader[0]):.0f}%)]\tLoss: {loss.item():.6f}')

            if batch_idx % self.save_interval == 0:
                self.train_losses.append(loss.item())
        val_loss = 0
        self.model.eval()
        with torch.no_grad():
            for batch_idx, batches in enumerate(zip(*self.valloader)):
                data, target = self.combine_batch_columns(batches, expect_target=True)
                data = data.to(self.device).float()
                target = target.to(self.device).long()
                output = self.model(data)
                val_loss += self.criterion(output, target).item()
        val_loss /= len(self.valloader[0].dataset)
        if self.early_stopper.best_loss is None or val_loss < self.early_stopper.best_loss:
            print("Validation loss improved. Saving model...")
            self.save()
        self.early_stopper(val_loss)

    def test(self):
        self.model.eval()
        test_loss = 0
        correct = 0

        with torch.no_grad():
            for batch_idx, batches in enumerate(zip(*self.testloader)):
                # batches is tuple of batch from each loader
                data, target = self.combine_batch_columns(batches)
                data = data.to(self.device).float()
                target = target.to(self.device).long()
                output = self.model(data)
                test_loss += self.criterion(output, target).item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(self.testloader[0].dataset)
        accuracy = 100. * correct / len(self.testloader[0].dataset)
        if self.log: 
            self.test_losses.append(test_loss)
            self.accuracies.append(accuracy)
        print(f'\nTest set: Average loss per sample: {test_loss:.4f}, Accuracy: {correct}/{len(self.testloader[0].dataset)} '
            f'({accuracy:.0f}%)\n')
            
    def run_training(self, trainset, testset):
        self.early_stopper = EarlyStopper()
        self._to_loader(trainset, testset, 
            self.model_config.BATCH_SIZE_TRAIN,
            self.model_config.BATCH_SIZE_VAL, 
            self.model_config.BATCH_SIZE_TEST, 
            self.model_config.TRAIN_SHUFFLE,
            self.model_config.VAL_SHUFFLE, 
            self.model_config.TEST_SHUFFLE,
            self.model_config.TRAIN_RATIO)
        self.test()
        for epoch in range(1, self.model_config.N_EPOCH + 1):
            self.train(epoch)
            self.test()

    def save(self):
        result_path = f"./results/oncloud/{self.model_config.version}/{self.data_config.DATASET_NAME}/{self.seed}"
        os.makedirs(result_path, exist_ok=True)
        model_path = os.path.join(result_path, f'model_server_{self.idx}.pth').replace("\\", "/")
        optimizer_path = os.path.join(result_path, f'optimizer_server_{self.idx}.pth').replace("\\", "/")
        torch.save(self.model.state_dict(), model_path)
        torch.save(self.optimizer.state_dict(), optimizer_path)


    def load(self):
        result_path = f"./results/oncloud/{self.model_config.version}/{self.data_config.DATASET_NAME}/{self.seed}"
        model_path = os.path.join(result_path, f'model_server_{self.idx}.pth').replace("\\", "/")
        optimizer_path = os.path.join(result_path, f'optimizer_server_{self.idx}.pth').replace("\\", "/")
        network_state_dict = torch.load(model_path)
        self.model.load_state_dict(network_state_dict)
        optimizer_state_dict = torch.load(optimizer_path)
        self.optimizer.load_state_dict(optimizer_state_dict)