In [1]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"]="1,2,3"

In [2]:
from datasets import load_dataset

dataset = load_dataset("food101", split="train[:5000]")

In [3]:
from collections import Counter
Counter(dataset['label'])

Counter({6: 750, 79: 750, 81: 750, 53: 750, 10: 750, 20: 750, 77: 500})

In [4]:
splits = dataset.train_test_split(test_size=0.1)
train_ds = splits["train"]
val_ds = splits["test"]

In [5]:
from collections import Counter
sub1 = train_ds.filter(lambda example: example['label'] <= 34)
sub2 = train_ds.filter(lambda example: 35 <= example['label'] <= 78)
sub3 = train_ds.filter(lambda example: example['label'] >= 79)
Counter(sub1['label']), Counter(sub2['label']), Counter(sub3['label'])

Filter:   0%|          | 0/4500 [00:00<?, ? examples/s]

Filter:   0%|          | 0/4500 [00:00<?, ? examples/s]

KeyboardInterrupt: 

In [None]:
len(sub1), len(sub2), len(sub3)

(2020, 1125, 1355)

In [None]:
import transformers
import accelerate
import peft

In [None]:
model_checkpoint = "google/vit-base-patch16-224-in21k"  # pre-trained model from which to fine-tune

In [None]:
from transformers import AutoImageProcessor
access_token = "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"

image_processor = AutoImageProcessor.from_pretrained(model_checkpoint, token=access_token)
image_processor

ViTImageProcessor {
  "_valid_processor_keys": [
    "images",
    "do_resize",
    "size",
    "resample",
    "do_rescale",
    "rescale_factor",
    "do_normalize",
    "image_mean",
    "image_std",
    "return_tensors",
    "data_format",
    "input_data_format"
  ],
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [None]:
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
train_transforms = Compose(
    [
        RandomResizedCrop(image_processor.size["height"]),
        RandomHorizontalFlip(),
        ToTensor(),
        normalize,
    ]
)

val_transforms = Compose(
    [
        Resize(image_processor.size["height"]),
        CenterCrop(image_processor.size["height"]),
        ToTensor(),
        normalize,
    ]
)


def preprocess_train(example_batch):
    """Apply train_transforms across a batch."""
    example_batch["pixel_values"] = [train_transforms(image.convert("RGB")) for image in example_batch["image"]]
    return example_batch


def preprocess_val(example_batch):
    """Apply val_transforms across a batch."""
    example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
    return example_batch

In [None]:
sub1.set_transform(preprocess_train)
sub2.set_transform(preprocess_train)
sub3.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)

In [None]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
    )

In [None]:
labels = dataset.features["label"].names
label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label

id2label[2]

'baklava'

In [None]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,  # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
)
print_trainable_parameters(model)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 85876325 || all params: 85876325 || trainable%: 100.00


In [None]:
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=64,
    lora_alpha=16,
    target_modules=["query", "value"],
    lora_dropout=0.1,
    bias="none",
    modules_to_save=["classifier"],
    use_dovera=True
)
lora_model = get_peft_model(model, config)
print_trainable_parameters(lora_model)

CAI QUAN QUE GI VAY TROI
trainable params: 116069 || all params: 88351690 || trainable%: 0.13


In [None]:
from transformers import TrainingArguments, Trainer


model_name = model_checkpoint.split("/")[-1]
batch_size = 128

args = TrainingArguments(
    f"{model_name}-finetuned-lora-food101",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-3,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=batch_size,
    fp16=True,
    num_train_epochs=5,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=True,
    label_names=["labels"],
)

In [None]:
import numpy as np
import evaluate

metric = evaluate.load("accuracy")


# the compute_metrics function takes a Named Tuple as input:
# predictions, which are the logits of the model as Numpy arrays,
# and label_ids, which are the ground-truth labels as Numpy arrays.
def compute_metrics(eval_pred):
    """Computes accuracy on a batch of predictions"""
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)

In [None]:
import torch


def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

In [None]:
import copy

server_model = copy.deepcopy(lora_model)

In [None]:
from serverbase import *
from userbase import *
from serverDoVeRA import *
from userDoVeRA import *

In [None]:
import random
from tqdm import tqdm

sub_train_ds = [sub1, sub2, sub3]

torch.manual_seed(42)
torch.cuda.manual_seed(42)

server = ServerMLP(model=server_model, test_loader=val_ds)

user_list = []

# Create users
for i in range(3):
    user_i = UserMLP(train_loader=sub_train_ds[i], model=server_model, user_id=i, local_epochs=5)
    user_list.append(user_i)

for t in tqdm(range(20), desc=f"Progress"):
    # Distribute initial model to users
    distributed_layers = None
    if t%5==0:
        distributed_layers=['lora_d', 'lora_b', 'lora_magnitude_vector']
    else:
        distributed_layers=['lora_d', 'lora_b']
    server.distribute_model(user_list, distributed_layers=distributed_layers)
    
    # Sub-sample users
    sub_user_list = random.sample(user_list, int(1 * 3))

    # Check the sub-sampled user and train model
    users_loss = 0.0
    for user in sub_user_list:
        trainer = Trainer(
                    user.model,
                    args,
                    train_dataset=user.train_loader,
                    eval_dataset=val_ds,
                    tokenizer=image_processor,
                    compute_metrics=compute_metrics,
                    data_collator=collate_fn,
                )
        train_results = trainer.train()
        # user_loss = user.user_train()
        users_loss += train_results.training_loss
    # Aggregate weights on server
    server.aggregate_weights(sub_user_list, layer_require_updated=None)

    # Calulate avg loss on selected users
    train_loss =  users_loss / len(sub_user_list)    
    # val_loss = server.model_eval()
    # val_results = server.

    # wandb.log({"train_loss": train_loss, "val_loss": val_loss})

    # print(f'Test accuracy of server: {server.compute_accuracy()}%')
    trainer = Trainer(
                server.model,
                args,
                train_dataset=server.test_loader,
                eval_dataset=val_ds,
                tokenizer=image_processor,
                compute_metrics=compute_metrics,
                data_collator=collate_fn,
            )
    val_results = trainer.evaluate(val_ds)
    print('Test accuracy of server:', val_results['eval_loss'], val_results['eval_accuracy'])

Progress:   0%|          | 0/20 [00:00<?, ?it/s]Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkenlvq[0m ([33mquanla[0m). Use [1m`wandb login --relogin`[0m to force relogin




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.322387,0.446
2,No log,4.108954,0.45
3,No log,3.951221,0.448
4,No log,3.847821,0.448
5,No log,3.796746,0.45




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.405945,0.274
2,No log,4.154013,0.286
3,No log,4.051695,0.286




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.440296,0.24
2,No log,4.233181,0.238
3,No log,4.149625,0.238




Progress:   5%|▌         | 1/20 [01:58<37:30, 118.45s/it]

Test accuracy of server: 4.228321552276611 0.562




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.985749,0.286
2,No log,3.788388,0.286
3,No log,3.725399,0.286




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.847843,0.448
2,No log,3.646564,0.45
3,No log,3.503483,0.45
4,No log,3.413286,0.45
5,No log,3.370059,0.45




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.289484,0.238
2,No log,4.093505,0.24
3,No log,4.019078,0.24




Progress:  10%|█         | 2/20 [03:51<34:36, 115.37s/it]

Test accuracy of server: 3.877647638320923 0.832




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.837726,0.286
2,No log,3.691926,0.286
3,No log,3.657856,0.284




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.416033,0.45
2,No log,3.251698,0.45
3,No log,3.146103,0.45
4,No log,3.085384,0.45
5,No log,3.058176,0.45




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.975045,0.24
2,No log,3.859359,0.242
3,No log,3.835681,0.242




Progress:  15%|█▌        | 3/20 [05:59<34:19, 121.15s/it]

Test accuracy of server: 3.600407600402832 0.77




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.724376,0.286
2,No log,3.645191,0.284
3,No log,3.642065,0.284




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.216419,0.45
2,No log,3.087229,0.45
3,No log,3.012102,0.45
4,No log,2.972937,0.45
5,No log,2.956697,0.452




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.828246,0.242
2,No log,3.869496,0.244
3,No log,3.915938,0.244




Progress:  20%|██        | 4/20 [07:51<31:17, 117.37s/it]

Test accuracy of server: 3.147721529006958 0.796




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.657004,0.284
2,No log,3.650265,0.284
3,No log,3.675957,0.284




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,2.907468,0.452
2,No log,2.898367,0.452
3,No log,2.911034,0.452
4,No log,2.927593,0.452
5,No log,2.938037,0.452




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.930786,0.244
2,No log,4.102675,0.244
3,No log,4.193608,0.244




Progress:  25%|██▌       | 5/20 [09:48<29:22, 117.48s/it]

Test accuracy of server: 2.9490928649902344 0.8




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.643429,0.284
2,No log,3.702088,0.284
3,No log,3.750992,0.284




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,2.899513,0.452
2,No log,2.925422,0.452
3,No log,2.9613,0.452
4,No log,2.991807,0.454
5,No log,3.008717,0.454




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.042904,0.244
2,No log,4.254306,0.244
3,No log,4.357758,0.244




Progress:  30%|███       | 6/20 [11:52<27:55, 119.70s/it]

Test accuracy of server: 2.6074185371398926 0.824




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.675577,0.284
2,No log,3.789173,0.284
3,No log,3.856541,0.284




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.074028,0.454
2,No log,3.16445,0.454
3,No log,3.240704,0.454
4,No log,3.293996,0.454
5,No log,3.320901,0.454




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.184128,0.244
2,No log,4.422729,0.244
3,No log,4.534659,0.244




Progress:  35%|███▌      | 7/20 [14:02<26:36, 122.83s/it]

Test accuracy of server: 2.4242331981658936 0.838




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.749372,0.284
2,No log,3.905318,0.284
3,No log,3.985589,0.286




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.186675,0.454
2,No log,3.28988,0.454
3,No log,3.373604,0.454
4,No log,3.430799,0.454
5,No log,3.459415,0.454




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.346608,0.244
2,No log,4.603447,0.244
3,No log,4.720017,0.244




Progress:  40%|████      | 8/20 [15:47<23:25, 117.10s/it]

Test accuracy of server: 2.1412336826324463 0.928




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.117825,0.286
2,No log,4.335837,0.286
3,No log,4.431668,0.286




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.318336,0.454
2,No log,3.43078,0.454
3,No log,3.51969,0.454
4,No log,3.579273,0.454
5,No log,3.608798,0.454




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.530387,0.244
2,No log,4.795323,0.244
3,No log,4.912351,0.244




Progress:  45%|████▌     | 9/20 [17:44<21:30, 117.31s/it]

Test accuracy of server: 1.9849661588668823 0.932




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.219655,0.286
2,No log,4.438811,0.286
3,No log,4.533627,0.286




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.41177,0.454
2,No log,3.528339,0.454
3,No log,3.618924,0.454
4,No log,3.679017,0.454
5,No log,3.708714,0.454




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.699993,0.244
2,No log,4.954964,0.244
3,No log,5.068688,0.244




Progress:  50%|█████     | 10/20 [19:34<19:08, 114.82s/it]

Test accuracy of server: 1.84334135055542 0.932




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.378231,0.286
2,No log,4.60167,0.286
3,No log,4.694962,0.286




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.554253,0.454
2,No log,3.672608,0.454
3,No log,3.763432,0.454
4,No log,3.822829,0.454
5,No log,3.852233,0.454




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.886064,0.244
2,No log,5.136817,0.244
3,No log,5.245406,0.244




Progress:  55%|█████▌    | 11/20 [21:22<16:56, 112.91s/it]

Test accuracy of server: 1.7128007411956787 0.94




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.542483,0.286
2,No log,4.766006,0.286
3,No log,4.856469,0.286




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.702384,0.454
2,No log,3.821187,0.454
3,No log,3.910113,0.454
4,No log,3.967793,0.454
5,No log,3.996372,0.454




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,5.072602,0.244
2,No log,5.303948,0.244
3,No log,5.404283,0.244




Progress:  60%|██████    | 12/20 [23:44<16:13, 121.65s/it]

Test accuracy of server: 1.5941435098648071 0.94




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.70779,0.286
2,No log,4.928408,0.286
3,No log,5.015733,0.286




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.84978,0.454
2,No log,3.965217,0.454
3,No log,4.049588,0.454
4,No log,4.10419,0.454
5,No log,4.131332,0.454




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,5.249622,0.244
2,No log,5.459412,0.244
3,No log,5.552148,0.244




Progress:  65%|██████▌   | 13/20 [26:36<15:59, 137.00s/it]

Test accuracy of server: 1.4877160787582397 0.94




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.872155,0.286
2,No log,5.089622,0.286
3,No log,5.17363,0.286




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,3.99755,0.454
2,No log,4.107982,0.454
3,No log,4.189948,0.454
4,No log,4.24199,0.454
5,No log,4.267822,0.454




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,5.417564,0.244
2,No log,5.602376,0.244
3,No log,5.685908,0.244




Progress:  70%|███████   | 14/20 [28:48<13:33, 135.58s/it]

Test accuracy of server: 1.3934117555618286 0.942




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.97088,0.286
2,No log,5.180496,0.286
3,No log,5.261568,0.286




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.09777,0.454
2,No log,4.201382,0.454
3,No log,4.279629,0.454
4,No log,4.329475,0.454
5,No log,4.354292,0.454




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,5.496678,0.244
2,No log,5.666804,0.244
3,No log,5.75307,0.244




Progress:  75%|███████▌  | 15/20 [31:09<11:25, 137.10s/it]

Test accuracy of server: 1.314041256904602 0.936




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,5.135483,0.286
2,No log,5.33903,0.286
3,No log,5.41124,0.286




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.232525,0.454
2,No log,4.332161,0.454
3,No log,4.405799,0.454
4,No log,4.451321,0.454
5,No log,4.474176,0.454




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,5.666445,0.244
2,No log,5.813049,0.244
3,No log,5.87753,0.244




Progress:  80%|████████  | 16/20 [33:21<09:02, 135.53s/it]

Test accuracy of server: 1.2423850297927856 0.932




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,5.271558,0.286
2,No log,5.471929,0.288
3,No log,5.546984,0.288




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.359395,0.454
2,No log,4.452202,0.454
3,No log,4.522038,0.454
4,No log,4.565124,0.456
5,No log,4.586819,0.456




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,5.776156,0.244
2,No log,5.9048,0.244
3,No log,5.977019,0.244




Progress:  85%|████████▌ | 17/20 [36:28<07:32, 150.95s/it]

Test accuracy of server: 1.1130298376083374 0.902




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,5.598305,0.288
2,No log,5.79208,0.288
3,No log,5.858023,0.288




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.662455,0.454
2,No log,4.741493,0.456
3,No log,4.801684,0.456
4,No log,4.8379,0.456
5,No log,4.856132,0.458




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,5.930638,0.244
2,No log,6.046054,0.244
3,No log,6.100858,0.244




Progress:  90%|█████████ | 18/20 [39:03<05:04, 152.17s/it]

Test accuracy of server: 1.031442642211914 0.91




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,5.742377,0.288
2,No log,5.932167,0.288
3,No log,5.992795,0.288




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.948723,0.456
2,No log,5.018692,0.456
3,No log,5.070617,0.456
4,No log,5.101059,0.456
5,No log,5.116775,0.458




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,6.032182,0.244
2,No log,6.137784,0.244
3,No log,6.200125,0.244




Progress:  95%|█████████▌| 19/20 [41:10<02:24, 144.81s/it]

Test accuracy of server: 0.9657821655273438 0.906




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,5.780768,0.288
2,No log,5.962944,0.288
3,No log,6.027776,0.288




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,5.161922,0.456
2,No log,5.223159,0.456
3,No log,5.268758,0.456
4,No log,5.295624,0.458
5,No log,5.309768,0.458




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,6.118042,0.244
2,No log,6.214193,0.244
3,No log,6.274756,0.244




Progress: 100%|██████████| 20/20 [43:44<00:00, 131.22s/it]

Test accuracy of server: 0.9161735773086548 0.902





In [None]:
for name, param in server.model.named_parameters():
    if 'lora_A' in name:
        print(param)

Parameter containing:
tensor([[ 0.0482,  0.0367, -0.0963,  ..., -0.0196, -0.0073,  0.0835],
        [ 0.0753,  0.0296,  0.3870,  ...,  0.0915, -0.1257,  0.2687],
        [ 0.1344, -0.0804,  0.0930,  ..., -0.0355,  0.0894, -0.1117],
        ...,
        [ 0.0441, -0.1476, -0.0203,  ...,  0.0372,  0.0316,  0.0318],
        [ 0.1445, -0.0555,  0.1027,  ...,  0.1428, -0.2191,  0.0384],
        [ 0.0161, -0.1385, -0.0650,  ..., -0.1662, -0.0134,  0.2207]],
       device='cuda:0')
Parameter containing:
tensor([[ 0.2360,  0.0753,  0.0278,  ..., -0.0995, -0.2957,  0.0006],
        [ 0.0500, -0.1484, -0.0281,  ..., -0.2330,  0.2297,  0.0018],
        [-0.1127,  0.1647,  0.0546,  ..., -0.0336,  0.0650,  0.0381],
        ...,
        [-0.1072, -0.0417,  0.1510,  ...,  0.0572,  0.2524,  0.0731],
        [ 0.0997,  0.0985, -0.0479,  ...,  0.0011, -0.0696, -0.0756],
        [ 0.0856,  0.0298,  0.0103,  ..., -0.0602,  0.0542,  0.1415]],
       device='cuda:0')
Parameter containing:
tensor([[-0.1002, 

In [None]:
for name, param in user_list[1].model.named_parameters():
    if 'lora_magnitude_vector' in name:
        print(param)

Parameter containing:
tensor([ 4.1328,  4.8465,  1.9314,  4.6300,  2.5196,  1.8456,  1.6916,  1.5190,
         1.9023,  3.7239,  3.8517,  4.5108,  4.1425,  3.8542,  4.2502,  4.4578,
         4.2153,  1.9018,  3.6703,  1.5419,  4.4364,  4.8452,  2.4419,  4.6928,
         4.6546,  1.8184,  2.6153,  4.6135,  3.8796,  4.5364,  3.0146,  4.5480,
         1.6393,  4.8610,  4.5819,  4.5055,  4.4623,  4.8105,  4.6904,  4.5075,
         2.7510,  4.8122,  3.3314,  1.9345,  4.3869,  2.8901,  4.4952,  4.5352,
         4.4980,  2.6335,  4.9921,  1.5320,  4.6076,  4.5066,  4.1976,  1.4844,
         4.6008,  2.5931,  4.5314,  2.6907,  2.2086,  4.0997,  4.3878,  4.6990,
         3.4612,  2.7654,  3.2220,  2.4769,  3.2182,  3.3780,  3.5620,  3.9553,
         3.3946,  3.3289,  2.2029,  3.4923,  2.1527,  2.3129,  3.0007,  2.6384,
         2.1320,  3.3802,  2.1779,  3.3383,  2.8287,  2.8484,  2.3815,  3.2646,
         2.3721,  3.3938,  3.3008,  3.0403,  2.4546,  3.4806,  3.3650,  2.5224,
         3.1667,  

Parameter containing:
tensor([1.3469, 1.4202, 1.3631, 1.3907, 1.7067, 1.5173, 1.8584, 1.6430, 1.5597,
        1.2731, 1.6389, 1.2397, 1.1889, 1.5640, 1.8289, 1.2930, 1.4520, 1.3819,
        1.3828, 1.3416, 1.5730, 1.4889, 1.4466, 1.4621, 1.5338, 1.5366, 1.4582,
        1.4571, 1.8639, 1.3100, 1.5222, 1.4744, 1.7201, 1.7365, 1.5724, 1.5046,
        1.9973, 1.8605, 1.3138, 1.8786, 1.4033, 1.5081, 1.4328, 1.6287, 1.8883,
        1.6416, 1.5492, 1.3418, 1.6273, 1.7898, 1.2332, 1.5412, 1.5340, 1.4721,
        1.5361, 1.7177, 1.5701, 1.3649, 1.3292, 1.5137, 1.5084, 1.5541, 1.4502,
        1.7148, 1.4496, 1.2554, 1.4589, 1.3348, 1.3227, 1.4200, 1.4675, 1.2957,
        1.4672, 1.3805, 1.5847, 1.7317, 1.2600, 1.5859, 1.4225, 1.2606, 1.2538,
        1.3560, 1.2916, 1.2162, 1.2998, 1.2795, 1.2444, 1.4016, 1.3997, 1.4847,
        1.3869, 1.3848, 1.3301, 1.5362, 1.3143, 1.4807, 1.3256, 1.3761, 1.7320,
        1.3143, 1.6668, 1.4772, 1.4922, 1.2774, 1.3062, 1.6310, 1.3575, 1.4061,
        1.5308, 1.

In [None]:
trainer = Trainer(
                lora_model,
                args,
                train_dataset=sub_train_ds[0],
                eval_dataset=val_ds,
                tokenizer=image_processor,
                compute_metrics=compute_metrics,
                data_collator=collate_fn,
            )
train_results = trainer.train()
val_results = trainer.evaluate(val_ds)
print('Test accuracy of server:', val_results['eval_loss'], val_results['eval_accuracy'])



Epoch,Training Loss,Validation Loss,Accuracy
1,No log,4.322387,0.446
2,No log,4.108954,0.45
3,No log,3.951221,0.448
4,No log,3.847821,0.448
5,No log,3.796746,0.45




Test accuracy of server: 4.108953952789307 0.45


In [None]:
for name, param in server.model.named_parameters():
    if 'lora_d' in name:
        param.data = torch.zeros_like(param)

for name, param in server.model.named_parameters():
    if 'lora_b' in name:
        print(param)

Parameter containing:
tensor([[ 0.0188],
        [-0.0416],
        [-0.0221],
        [ 0.0112],
        [ 0.0509],
        [ 0.0688],
        [ 0.0300],
        [ 0.0580],
        [-0.0425],
        [ 0.0019],
        [ 0.0391],
        [ 0.0482],
        [ 0.0342],
        [-0.0066],
        [ 0.0326],
        [-0.0406],
        [-0.0217],
        [-0.0125],
        [ 0.0647],
        [-0.0551],
        [ 0.0170],
        [-0.0632],
        [ 0.0355],
        [ 0.0102],
        [ 0.0284],
        [ 0.0105],
        [ 0.0887],
        [-0.0551],
        [-0.0370],
        [-0.0255],
        [-0.0438],
        [-0.0269],
        [ 0.0052],
        [-0.0426],
        [ 0.0729],
        [-0.0366],
        [ 0.0609],
        [ 0.0079],
        [-0.0251],
        [-0.0307],
        [-0.0618],
        [ 0.0460],
        [-0.0132],
        [ 0.0270],
        [ 0.0036],
        [-0.0444],
        [ 0.0184],
        [ 0.0163],
        [-0.0191],
        [-0.0582],
        [-0.0803],
        [

In [None]:
for param in user_list[0].model.parameters():
    # if 'lora_b' in name:
        print(param)

Parameter containing:
tensor([[[-1.5620e-03,  6.5624e-03, -2.4414e-01,  2.3681e-04,  3.7727e-01,
           4.5454e-02, -1.2689e-02,  1.1674e-02,  3.1769e-02, -2.0883e-01,
          -6.4701e-03, -7.8645e-03, -8.6769e-03, -7.4107e-03, -1.2021e-02,
          -4.5716e-03, -7.1202e-04,  6.3879e-02,  3.0223e-02, -4.8832e-03,
          -5.3233e-02, -4.3242e-03, -4.8654e-03, -1.9290e-02, -5.5730e-04,
           4.9587e-02,  2.4419e-03, -9.7914e-03,  3.8656e-02, -3.0388e-02,
          -5.2411e-04,  1.7129e-02, -1.5828e-02,  7.6080e-03,  2.1919e-02,
           1.7543e-03,  4.4304e-02,  3.4365e-03, -1.7939e-03,  7.7395e-03,
           5.8823e-03,  7.7494e-03,  1.8474e-02,  3.7538e-03,  5.7981e-02,
           4.0637e-01, -3.4466e-03,  1.0692e-02,  2.8214e-02, -1.4306e-03,
          -8.9247e-03, -4.5078e-02,  6.4855e-03,  5.8085e-04,  1.0369e-03,
          -5.9847e-03, -4.3196e-03, -6.6326e-03,  1.0140e-02, -6.7442e-02,
          -1.4103e-02,  7.0461e-04, -2.7727e-03, -1.2431e-03,  3.2974e-02,
   