<a href="https://colab.research.google.com/github/hassanSattariNia/FederatedLearning/blob/main/workingSplitModelLLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
pip install torch transformers



In [2]:
from transformers import AlbertModel, AlbertTokenizer
import torch

In [3]:
# load original Albert-v2
model_name = "albert-base-v2"
model = AlbertModel.from_pretrained(model_name)
tokenizer = AlbertTokenizer.from_pretrained(model_name)
print("Model Architecture:")
print(model)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/684 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/47.4M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/760k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.31M [00:00<?, ?B/s]

Model Architecture:
AlbertModel(
  (embeddings): AlbertEmbeddings(
    (word_embeddings): Embedding(30000, 128, padding_idx=0)
    (position_embeddings): Embedding(512, 128)
    (token_type_embeddings): Embedding(2, 128)
    (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0, inplace=False)
  )
  (encoder): AlbertTransformer(
    (embedding_hidden_mapping_in): Linear(in_features=128, out_features=768, bias=True)
    (albert_layer_groups): ModuleList(
      (0): AlbertLayerGroup(
        (albert_layers): ModuleList(
          (0): AlbertLayer(
            (full_layer_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (attention): AlbertAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (attention_dropout): Dropout(p=0, i



In [4]:
from transformers import AlbertModel
import torch.nn as nn

# return trainable parameter of model
def count_parameters(module):
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

def split_model_comprehensive(model, num_clients=4):
    # Comprehensive list of all modules in the ALBERT-v2 model
    modules = [
        ('embeddings.word_embeddings', model.embeddings.word_embeddings),
        ('embeddings.position_embeddings', model.embeddings.position_embeddings),
        ('embeddings.token_type_embeddings', model.embeddings.token_type_embeddings),
        ('embeddings.LayerNorm', model.embeddings.LayerNorm),
        ('encoder.embedding_hidden_mapping_in', model.encoder.embedding_hidden_mapping_in),
        ('encoder.albert_layer_groups.0.albert_layers.0.full_layer_layer_norm',
         model.encoder.albert_layer_groups[0].albert_layers[0].full_layer_layer_norm),
        ('encoder.albert_layer_groups.0.albert_layers.0.attention.query',
         model.encoder.albert_layer_groups[0].albert_layers[0].attention.query),
        ('encoder.albert_layer_groups.0.albert_layers.0.attention.key',
         model.encoder.albert_layer_groups[0].albert_layers[0].attention.key),
        ('encoder.albert_layer_groups.0.albert_layers.0.attention.value',
         model.encoder.albert_layer_groups[0].albert_layers[0].attention.value),
        ('encoder.albert_layer_groups.0.albert_layers.0.attention.dense',
         model.encoder.albert_layer_groups[0].albert_layers[0].attention.dense),
        ('encoder.albert_layer_groups.0.albert_layers.0.attention.LayerNorm',
         model.encoder.albert_layer_groups[0].albert_layers[0].attention.LayerNorm),
        ('encoder.albert_layer_groups.0.albert_layers.0.attention.dropout',
         model.encoder.albert_layer_groups[0].albert_layers[0].attention.attention_dropout),
        ('encoder.albert_layer_groups.0.albert_layers.0.attention.output_dropout',
         model.encoder.albert_layer_groups[0].albert_layers[0].attention.output_dropout),
        ('encoder.albert_layer_groups.0.albert_layers.0.ffn',
         model.encoder.albert_layer_groups[0].albert_layers[0].ffn),
        ('encoder.albert_layer_groups.0.albert_layers.0.ffn_output',
         model.encoder.albert_layer_groups[0].albert_layers[0].ffn_output),
        ('encoder.albert_layer_groups.0.albert_layers.0.activation',
         model.encoder.albert_layer_groups[0].albert_layers[0].activation),
        ('encoder.albert_layer_groups.0.albert_layers.0.dropout',
         model.encoder.albert_layer_groups[0].albert_layers[0].dropout),
        ('pooler', model.pooler),
        ('pooler_activation', model.pooler_activation)
    ]

    # Calculate total parameters
    total_params = sum(count_parameters(module) for _, module in modules)
    print(f'total params of list modules is ${total_params}')
    target_params_per_client = total_params // num_clients
    print(f'expected params of one client ${target_params_per_client}')

    client_modules = [[] for _ in range(num_clients)]
    current_client = 0
    current_client_params = 0

    for name, module in modules:
        module_params = count_parameters(module)
        # Check if adding this module exceeds the target per client and we haven't reached the last client
        if current_client_params + module_params > target_params_per_client and current_client < num_clients - 1:
            current_client += 1
            current_client_params = 0

        # Assign the module to the current client
        client_modules[current_client].append((name, module))
        current_client_params += module_params

    return client_modules

# Load ALBERT model
model = AlbertModel.from_pretrained("albert-base-v2")

# Split the model between 4 clients
client_models = split_model_comprehensive(model, num_clients=4)




total params of list modules is $11683584
expected params of one client $2920896


In [8]:
len(client_models)

4

In [None]:
# Display splitting information
for i, parts in enumerate(client_models):
    print(f"Client {i+1}:")
    client_total_params = 0
    for name, module in parts:
        num_params = count_parameters(module)
        client_total_params += num_params
        print(f"  - {name}: {num_params:,} parameters")
    print(f"  Total client parameters: {client_total_params:,}")
    print()

# Calculate total parameters
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total model parameters: {total_params:,}")

In [4]:
print(client_models[1])

[('embeddings.word_embeddings', Embedding(30000, 128, padding_idx=0))]


In [5]:
from datasets import load_dataset
from transformers import AlbertTokenizer, AlbertForSequenceClassification
import torch

# Load the GLUE MRPC dataset
dataset = load_dataset("glue", "mrpc")

# Load the tokenizer and model
model_name = "albert-base-v2"
tokenizer = AlbertTokenizer.from_pretrained(model_name)
model = AlbertForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Tokenize the dataset with padding and truncation
def preprocess_function(examples):
    return tokenizer(examples['sentence1'], examples['sentence2'], truncation=True, padding='max_length', max_length=128)

tokenized_datasets = dataset.map(preprocess_function, batched=True)

# Use the tokenized datasets for training and evaluation
train_dataset = tokenized_datasets['train']
eval_dataset = tokenized_datasets['validation']

Some weights of AlbertForSequenceClassification were not initialized from the model checkpoint at albert-base-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
import torch
from torch.utils.data import DataLoader
from transformers import AlbertTokenizer, AlbertForSequenceClassification
from datasets import load_dataset

# Load the GLUE MRPC dataset
dataset = load_dataset("glue", "mrpc")

# Load the tokenizer and model
model_name = "albert-base-v2"
tokenizer = AlbertTokenizer.from_pretrained(model_name)
model = AlbertForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Tokenize the dataset with padding and truncation
def preprocess_function(examples):
    return tokenizer(examples['sentence1'], examples['sentence2'], truncation=True, padding='max_length', max_length=128)

tokenized_datasets = dataset.map(preprocess_function, batched=True)

# Remove unnecessary columns to simplify the dataset for training
tokenized_datasets = tokenized_datasets.remove_columns(['sentence1', 'sentence2', 'idx'])

# Create DataLoader with batch size of 16
train_dataset = tokenized_datasets['train']
eval_dataset = tokenized_datasets['validation']

# Custom collate function to ensure correct batching
def collate_fn(batch):
    # Convert lists of dicts to dict of lists, then to tensor
    input_ids = torch.tensor([item['input_ids'] for item in batch])
    attention_mask = torch.tensor([item['attention_mask'] for item in batch])
    token_type_ids = torch.tensor([item['token_type_ids'] for item in batch])
    labels = torch.tensor([item['label'] for item in batch])

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'token_type_ids': token_type_ids,
        'labels': labels
    }

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
eval_dataloader = DataLoader(eval_dataset, batch_size=16, collate_fn=collate_fn)

# Define a device (GPU if available)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# Function to process data through client 1
def forward_client_1(batch):
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    token_type_ids = batch['token_type_ids'].to(device)

    # Get the embedding output
    embedding_output = model.albert.embeddings(
        input_ids=input_ids,
        token_type_ids=token_type_ids
    )

    return embedding_output

# Function to process data through client 2
def forward_client_2(embedding_output, attention_mask):
    # Reshape attention_mask to match the expected size: (batch_size, 1, 1, sequence_length)
    attention_mask = attention_mask[:, None, None, :].to(device)

    # Continue processing from the encoder layers of client 2
    encoder_output = model.albert.encoder(
        embedding_output,
        attention_mask=attention_mask
    )

    return encoder_output

# Loop over the dataloader and process batches
for batch in train_dataloader:
    # Batch is already a dictionary of tensors
    batch = {k: v.to(device) for k, v in batch.items()}

    # Client 1 processes the input
    output_client_1 = forward_client_1(batch)
    print("Output from Client 1 (to be used as input for Client 2):")
    print(output_client_1)

    # Client 2 processes the output from Client 1
    output_client_2 = forward_client_2(output_client_1, batch['attention_mask'])
    print("Output from Client 2:")
    print(output_client_2)


Some weights of AlbertForSequenceClassification were not initialized from the model checkpoint at albert-base-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
         [-1.3136,  0.6259,  0.9811,  ..., -2.2829, -1.1925, -1.0033],
         [-1.8989,  0.7085,  1.4909,  ..., -1.4112,  4.9418, -3.1911],
         ...,
         [ 4.2331,  1.7357, -4.6993,  ..., -1.6946,  0.4600,  2.9973],
         [ 4.8969,  2.0867, -5.0083,  ..., -1.6072,  0.6653,  3.0128],
         [ 5.7153,  2.8192, -5.3392,  ..., -1.6013,  0.5094,  2.8236]],

        [[-1.4039,  0.4480,  2.3831,  ...,  0.7478, -1.2142, -1.1283],
         [-1.8554,  0.8159,  1.8363,  ..., -2.6651, -1.9490,  0.4298],
         [-1.4058, -0.2801, -0.7050,  ...,  1.4865, -0.7262, -1.0955],
         ...,
         [ 4.2331,  1.7357, -4.6993,  ..., -1.6946,  0.4600,  2.9973],
         [ 4.8969,  2.0867, -5.0083,  ..., -1.6072,  0.6653,  3.0128],
         [ 5.7153,  2.8192, -5.3392,  ..., -1.6013,  0.5094,  2.8236]],

        ...,

        [[-1.4039,  0.4480,  2.3831,  ...,  0.7478, -1.2142, -1.1283],
         [ 0.5826,  0.4550, -0.1953, 