<a href="https://colab.research.google.com/github/hassanSattariNia/FederatedLearning/blob/main/spliteTo8Client.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from transformers import AlbertModel, AlbertConfig
import numpy as np

def analyze_albert_structure():
    # Load ALBERT model
    model = AlbertModel.from_pretrained('albert-base-v2')

    # Get total number of parameters
    total_params = sum(p.numel() for p in model.parameters())

    # Analyze memory requirements
    param_size = 0
    buffer_size = 0

    for param in model.parameters():
        param_size += param.nelement() * param.element_size()

    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()

    size_in_mb = (param_size + buffer_size) / 1024**2

    # Analyze layer structure
    layer_info = {}
    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # If it's a leaf module
            num_params = sum(p.numel() for p in module.parameters())
            layer_info[name] = {
                'parameters': num_params,
                'memory_mb': (num_params * 4) / (1024**2),  # Assuming float32
                'type': module.__class__.__name__
            }

    return {
        'total_parameters': total_params,
        'total_size_mb': size_in_mb,
        'layer_info': layer_info,
        'config': model.config,
        'model': model
    }

def print_model_analysis(analysis):
    print("\n=== ALBERT Model Analysis ===")
    print(f"Total Parameters: {analysis['total_parameters']:,}")
    print(f"Total Size in MB: {analysis['total_size_mb']:.2f}")

    print("\n=== Layer-wise Analysis ===")
    for name, info in analysis['layer_info'].items():
        if info['parameters'] > 0:  # Only show layers with parameters
            print(f"\nLayer: {name}")
            print(f"Type: {info['type']}")
            print(f"Parameters: {info['parameters']:,}")
            print(f"Memory (MB): {info['memory_mb']:.2f}")

    print("\n=== Model Configuration ===")
    config = analysis['config']
    print(f"Hidden Size: {config.hidden_size}")
    print(f"Intermediate Size: {config.intermediate_size}")
    print(f"Number of Hidden Layers: {config.num_hidden_layers}")
    print(f"Number of Attention Heads: {config.num_attention_heads}")

# Function to identify potential partition points
def suggest_partition_points(analysis, num_devices):
    total_params = analysis['total_parameters']
    target_size = total_params / num_devices

    current_size = 0
    partition_suggestions = []

    for name, info in analysis['layer_info'].items():
        current_size += info['parameters']
        if current_size >= target_size:
            partition_suggestions.append(name)
            current_size = 0

    return partition_suggestions

if __name__ == "__main__":
    # Analyze model
    analysis = analyze_albert_structure()
    print_model_analysis(analysis)

    # Example: Suggest partition points for 4 devices
    print("\n=== Suggested Partition Points (4 devices) ===")
    partition_points = suggest_partition_points(analysis, 4)
    for i, point in enumerate(partition_points, 1):
        print(f"Partition {i}: Cut after {point}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/684 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/47.4M [00:00<?, ?B/s]


=== ALBERT Model Analysis ===
Total Parameters: 11,683,584
Total Size in MB: 44.58

=== Layer-wise Analysis ===

Layer: embeddings.word_embeddings
Type: Embedding
Parameters: 3,840,000
Memory (MB): 14.65

Layer: embeddings.position_embeddings
Type: Embedding
Parameters: 65,536
Memory (MB): 0.25

Layer: embeddings.token_type_embeddings
Type: Embedding
Parameters: 256
Memory (MB): 0.00

Layer: embeddings.LayerNorm
Type: LayerNorm
Parameters: 256
Memory (MB): 0.00

Layer: encoder.embedding_hidden_mapping_in
Type: Linear
Parameters: 99,072
Memory (MB): 0.38

Layer: encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm
Type: LayerNorm
Parameters: 1,536
Memory (MB): 0.01

Layer: encoder.albert_layer_groups.0.albert_layers.0.attention.query
Type: Linear
Parameters: 590,592
Memory (MB): 2.25

Layer: encoder.albert_layer_groups.0.albert_layers.0.attention.key
Type: Linear
Parameters: 590,592
Memory (MB): 2.25

Layer: encoder.albert_layer_groups.0.albert_layers.0.attention.value
T

In [2]:
import torch
from transformers import AlbertModel
from dataclasses import dataclass
from typing import List, Dict, Tuple
import numpy as np

@dataclass
class LayerProfile:
    name: str
    parameters: int
    memory_mb: float
    flops: int
    critical_path: bool
    dependencies: List[str]

class AlbertPartitioner:
    def __init__(self, num_clients=8):
        self.num_clients = num_clients
        self.model = AlbertModel.from_pretrained('albert-base-v2')
        self.config = self.model.config

    def profile_layers(self) -> Dict[str, LayerProfile]:
        layers = {}
        sequence_length = 512  # Standard sequence length

        # Profile embeddings
        embed_params = sum(p.numel() for p in self.model.embeddings.parameters())
        embed_flops = self.config.embedding_size * sequence_length
        layers['embeddings'] = LayerProfile(
            name='embeddings',
            parameters=embed_params,
            memory_mb=embed_params * 4 / (1024**2),
            flops=embed_flops,
            critical_path=True,
            dependencies=[]
        )

        # Profile each transformer layer
        for i in range(self.config.num_hidden_layers):
            # Attention layer
            attention_params = sum(p.numel() for p in
                                self.model.encoder.albert_layer_groups[0].albert_layers[0].attention.parameters())
            attention_flops = (sequence_length ** 2) * self.config.hidden_size * 4

            layers[f'attention_{i}'] = LayerProfile(
                name=f'attention_{i}',
                parameters=attention_params,
                memory_mb=attention_params * 4 / (1024**2),
                flops=attention_flops,
                critical_path=True,
                dependencies=[f'embeddings'] if i == 0 else [f'ffn_{i-1}']
            )

            # FFN layer
            ffn_params = sum(p.numel() for p in
                           self.model.encoder.albert_layer_groups[0].albert_layers[0].ffn.parameters())
            ffn_flops = sequence_length * self.config.hidden_size * self.config.intermediate_size * 2

            layers[f'ffn_{i}'] = LayerProfile(
                name=f'ffn_{i}',
                parameters=ffn_params,
                memory_mb=ffn_params * 4 / (1024**2),
                flops=ffn_flops,
                critical_path=True,
                dependencies=[f'attention_{i}']
            )

        # Profile pooler
        pooler_params = sum(p.numel() for p in self.model.pooler.parameters())
        pooler_flops = self.config.hidden_size ** 2
        layers['pooler'] = LayerProfile(
            name='pooler',
            parameters=pooler_params,
            memory_mb=pooler_params * 4 / (1024**2),
            flops=pooler_flops,
            critical_path=True,
            dependencies=[f'ffn_{self.config.num_hidden_layers-1}']
        )

        return layers

    def create_optimal_partitions(self) -> List[Dict]:
        layers = self.profile_layers()

        # Strategy for 8 clients:
        # 1. Client 0: Embeddings (heavy memory, low compute)
        # 2-6. Clients 1-5: 2-3 transformer layers each (balanced compute)
        # 7. Client 6: Remaining transformer layers
        # 8. Client 7: Pooler and final operations

        partitions = [{
            'client_id': i,
            'layers': [],
            'total_params': 0,
            'total_flops': 0,
            'memory_mb': 0.0,
            'dependencies': set()
        } for i in range(self.num_clients)]

        # Assign embeddings to first client
        partitions[0]['layers'].append(layers['embeddings'])
        partitions[0]['total_params'] += layers['embeddings'].parameters
        partitions[0]['total_flops'] += layers['embeddings'].flops
        partitions[0]['memory_mb'] += layers['embeddings'].memory_mb

        # Distribute transformer layers
        transformer_layers = [(k, v) for k, v in layers.items()
                            if 'attention' in k or 'ffn' in k]
        layers_per_client = len(transformer_layers) // (self.num_clients - 2)

        for i, (name, layer) in enumerate(transformer_layers):
            client_id = 1 + (i // layers_per_client)
            if client_id >= self.num_clients - 1:
                client_id = self.num_clients - 2

            partitions[client_id]['layers'].append(layer)
            partitions[client_id]['total_params'] += layer.parameters
            partitions[client_id]['total_flops'] += layer.flops
            partitions[client_id]['memory_mb'] += layer.memory_mb
            partitions[client_id]['dependencies'].update(layer.dependencies)

        # Assign pooler to last client
        partitions[-1]['layers'].append(layers['pooler'])
        partitions[-1]['total_params'] += layers['pooler'].parameters
        partitions[-1]['total_flops'] += layers['pooler'].flops
        partitions[-1]['memory_mb'] += layers['pooler'].memory_mb
        partitions[-1]['dependencies'].update(layers['pooler'].dependencies)

        return partitions

def print_partition_analysis(partitions):
    print("\n=== 8-Client Partition Analysis ===")
    total_params = sum(p['total_params'] for p in partitions)
    total_flops = sum(p['total_flops'] for p in partitions)
    total_memory = sum(p['memory_mb'] for p in partitions)

    print(f"\nTotal Model Statistics:")
    print(f"Total Parameters: {total_params:,}")
    print(f"Total Estimated FLOPs: {total_flops:,}")
    print(f"Total Memory Usage: {total_memory:.2f} MB")

    for partition in partitions:
        print(f"\nClient {partition['client_id']}:")
        print(f"Parameters: {partition['total_params']:,} ({partition['total_params']/total_params*100:.1f}%)")
        print(f"FLOPs: {partition['total_flops']:,} ({partition['total_flops']/total_flops*100:.1f}%)")
        print(f"Memory: {partition['memory_mb']:.2f} MB ({partition['memory_mb']/total_memory*100:.1f}%)")
        print("Layers:")
        for layer in partition['layers']:
            print(f"  - {layer.name}")
            if partition['dependencies']:
                print(f"    Dependencies: {partition['dependencies']}")

if __name__ == "__main__":
    partitioner = AlbertPartitioner(num_clients=8)
    partitions = partitioner.create_optimal_partitions()
    print_partition_analysis(partitions)


=== 8-Client Partition Analysis ===

Total Model Statistics:
Total Parameters: 61,211,904
Total Estimated FLOPs: 38,655,361,024
Total Memory Usage: 233.50 MB

Client 0:
Parameters: 3,906,048 (6.4%)
FLOPs: 65,536 (0.0%)
Memory: 14.90 MB (6.4%)
Layers:
  - embeddings

Client 1:
Parameters: 9,452,544 (15.4%)
FLOPs: 6,442,450,944 (16.7%)
Memory: 36.06 MB (15.4%)
Layers:
  - attention_0
    Dependencies: {'ffn_0', 'attention_1', 'embeddings', 'attention_0'}
  - ffn_0
    Dependencies: {'ffn_0', 'attention_1', 'embeddings', 'attention_0'}
  - attention_1
    Dependencies: {'ffn_0', 'attention_1', 'embeddings', 'attention_0'}
  - ffn_1
    Dependencies: {'ffn_0', 'attention_1', 'embeddings', 'attention_0'}

Client 2:
Parameters: 9,452,544 (15.4%)
FLOPs: 6,442,450,944 (16.7%)
Memory: 36.06 MB (15.4%)
Layers:
  - attention_2
    Dependencies: {'ffn_2', 'attention_2', 'ffn_1', 'attention_3'}
  - ffn_2
    Dependencies: {'ffn_2', 'attention_2', 'ffn_1', 'attention_3'}
  - attention_3
    Depende

In [3]:
pip install datasets

Collecting datasets
  Downloading datasets-3.0.2-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Downloading datasets-3.0.2-py3-none-any.whl (472 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m472.7/472.7 kB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m14.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading

In [7]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AlbertModel,
    AlbertTokenizer,
    AdamW
)
from datasets import load_dataset
from dataclasses import dataclass

@dataclass
class Client0Config:
    batch_size: int = 32
    max_length: int = 128
    learning_rate: float = 2e-5

class MRPCDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {
            key: val[idx].clone().detach()
            for key, val in self.encodings.items()
        }
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

class Client0Trainer:
    def __init__(self, config: Client0Config):
        self.tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
        self.model = AlbertModel.from_pretrained('albert-base-v2')
        self.embeddings = self.model.embeddings
        self.config = config

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.embeddings.to(self.device)

    def prepare_data(self):
        # Load MRPC dataset
        dataset = load_dataset('glue', 'mrpc')
        train_texts = list(zip(dataset['train']['sentence1'], dataset['train']['sentence2']))
        train_labels = dataset['train']['label']

        # Tokenize data
        train_encodings = self.tokenizer(
            train_texts,
            truncation=True,
            padding='max_length',
            max_length=self.config.max_length,
            return_tensors='pt',
            return_token_type_ids=True
        )

        # Create custom dataset
        self.train_dataset = MRPCDataset(train_encodings, train_labels)

        # Create dataloader
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True
        )

    def train_step(self, batch):
        # Forward pass through embeddings only
        input_ids = batch['input_ids'].to(self.device)
        token_type_ids = batch['token_type_ids'].to(self.device)

        # Get embeddings output - ALBERT embeddings take input_ids and token_type_ids
        outputs = self.embeddings(
            input_ids=input_ids,
            token_type_ids=token_type_ids
        )

        return outputs

    def train_epoch(self):
        self.embeddings.train()

        for batch_idx, batch in enumerate(self.train_loader):
            # Get embeddings output
            OutputClient1 = self.train_step(batch)

            # Save outputs periodically
            if batch_idx % 100 == 0:
                print(f"Processed batch {batch_idx}")
                # Save structure of the output for debugging
                print(f"Output shape: {OutputClient1.shape}")
                self.save_outputs(OutputClient1, batch_idx)

    def save_outputs(self, outputs, batch_idx):
        # Save outputs with batch index
        output_path = f'client0_outputs_batch_{batch_idx}.pt'
        torch.save(outputs, output_path)
        print(f"Saved outputs to {output_path}")

if __name__ == "__main__":
    config = Client0Config()
    trainer = Client0Trainer(config)
    trainer.prepare_data()
    trainer.train_epoch()

Be aware, overflowing tokens are not returned for the setting you have chosen, i.e. sequence pairs with the 'longest_first' truncation strategy. So the returned list will always be empty even if some tokens have been removed.


Processed batch 0
Output shape: torch.Size([32, 128, 128])
Saved outputs to client0_outputs_batch_0.pt
Processed batch 100
Output shape: torch.Size([32, 128, 128])
Saved outputs to client0_outputs_batch_100.pt
