<a href="https://colab.research.google.com/github/mohammad-rahbari/federated-learning_visual-classification/blob/main/notebooks/Centralized_model_visual_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Importing DINO and installing its dependencies

In [1]:
# @title Clon the DINO ripo
!git clone https://github.com/facebookresearch/dino.git

Cloning into 'dino'...
remote: Enumerating objects: 175, done.[K
remote: Total 175 (delta 0), reused 0 (delta 0), pack-reused 175 (from 1)[K
Receiving objects: 100% (175/175), 24.47 MiB | 44.03 MiB/s, done.
Resolving deltas: 100% (100/100), done.


In [2]:
# @title Installing required dependencies regarding DINO
%cd dino
!pip install -r requirements.txt
!pip install timm

/content/dino
[31mERROR: Could not open requirements file: [Errno 2] No such file or directory: 'requirements.txt'[0m[31m



# preprocessing the CIFAR-100 dataset

feature size in CIFAR is 32x32 but DINO requires 224x224 in the input layer.

In first step we upscale the dataset and then we add randomization to it

In last step of transformation we normalize data usind mean value and standard division of ImageNet



In [3]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split,DataLoader

In [4]:
# transform_train = transforms.Compose([
#     transforms.Resize(256),
#     transforms.RandomCrop(224),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=(0.485, 0.456, 0.406),
#                          std=(0.229, 0.224, 0.225))
# ])

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225))
])





In [5]:
# train_dataset = torchvision.datasets.CIFAR100(root='./data', train=True,
#                                         download=True, transform=transform)


# train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)
# test_loader = DataLoader(test_daataset, batch_size=64, shuffle=False, num_workers=2)

full_train = torchvision.datasets.CIFAR100(
    root='./data', train=True, download=True, transform=transform)

test_dataset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                       download=True, transform=transform)

#Decide the split sizes
train_frac = 0.8   # e.g. 80% train
val_frac   = 0.2   #     20% validation


n_total = len(full_train)                # 50 000
n_train = int(train_frac * n_total)      # 40 000
n_val   = n_total - n_train              # 10 000


train_dataset, val_dataset = random_split(full_train, [n_train, n_val])
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True,  num_workers=2)
val_loader   = DataLoader(val_dataset,   batch_size=64, shuffle=False, num_workers=2)
test_loader  = DataLoader(test_dataset,  batch_size=64, shuffle=False, num_workers=2)



100%|██████████| 169M/169M [00:01<00:00, 91.8MB/s]


# Loading and preparing the pretrained DINO model *(DINO-DeiT_Small)*

In [6]:
# @title loadig the model
model_name = "dino_vits16" #@param["dino_resnet50", "dino_vits16", "dino_xcit_small_12_p16"]
import torch.hub

dino_model = torch.hub.load('facebookresearch/dino:main', model_name)


Downloading: "https://github.com/facebookresearch/dino/zipball/main" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth" to /root/.cache/torch/hub/checkpoints/dino_deitsmall16_pretrain.pth


100%|██████████| 82.7M/82.7M [00:00<00:00, 256MB/s]


In [7]:
# @title Model Configuration

import torch
import torch.nn as nn

class DinoClassifire(nn.Module):
  def __init__(self, dino_model, num_classes:int=100, device=None):
    super(DinoClassifire, self).__init__()
    self.backbone = dino_model

    #We need to freaze thhe parameters of bakbone first so we can train only on the head layer(output layer)
    for param in self.backbone.parameters():
      param.requiers_grad = False

    #determine the Device
    if device is None:
      device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    self.backbone.to(device)

    #To detect the output feature dimontion of backbone we run  Dummy forward pass
    with torch.no_grad():
      dummy_input = torch.randn(1,3,224,224).to(device)
      dummy_out = self.backbone(dummy_input)


      #If the output is 3D (B, T, D), we assume first token is the [CLS] token.
      if dummy_out.dim() == 3:
        dummy_feature = dummy_out[:,0]
      else:
        dummy_feature = dummy_out
      feature_dim = dummy_feature.shape[1]
      print("Detected feature dimontion:", feature_dim)


      #Difineing the classification Head
      self.head = nn.Linear(feature_dim, num_classes)


      #Ensure the head is trainable.

      for param in self.head.parameters():
        param.requires_grad = True

  def forward(self,x):

    #pass the input through the backbone
    features = self.backbone(x)

    if isinstance(features, tuple):
      features = features[0]
    elif isinstance(features, dict):
      features = features.get("x_norm_clstoken", next(iter(features.values())))


    # If featers are retuened as (B, T, D), use the first token
    if features.dim() == 3:
      features = features[:,0]

    logits = self.head(features)

    return logits




  @torch.no_grad()
  def SGDM(self, buffer,grad_mask ,weight_decay=0.0,lr=1e-3,momentum=0.9,damping=0.0, nesterov=False, max_=False):
    for param in self.head.parameters():
      if param.grad is None:
        continue
      grad = param.grad

      if weight_decay != 0:
        grad = grad.add(param, alpha=weight_decay)

      pid = id(param)
      buf = buffer.get(pid)
      if buf is None:
        buf = torch.zeros_like(param)
        buffer[pid] = buf


      if momentum != 0 :
        buf.mul_(momentum).add_(grad, alpha=(1 - damping))
        update = grad.add(buf, alpha= momentum) if nesterov else buf

      else:

        update = grad

      gm = grad_mask
      mask = gm.get(pid)

      if mask is None:
        mask = torch.ones_like(update)
      else:
        mask = mask.to(device=update.device, dtype=update.dtype)
        if mask.numel() == 1:
          mask = mask.expand_as(update)
        elif mask.shape != update.shape:

          mask = mask.expand_as(update)

      update = update * mask

      param.add_(update, alpha=(lr if max_ else -lr))

      buffer[pid] = buf
    return buffer


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DinoClassifire(dino_model=dino_model, num_classes=100, device=device)
model.to(device)

Detected feature dimontion: 384


DinoClassifire(
  (backbone): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0-11): 12 x Block(
        (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=384, out_features=1152, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=384, out_features=384, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=384, out_features=1536, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=1536, out_features=384, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (norm): L

In [8]:
# @title Caluculate Fisher Information Matrix

import math

def claculate_sparsity(S=0.9,n=5, a= 0.4):
  St = []
  Ct = [ S * (  (1-math.exp(-a*t)) / (1-math.exp(-a*n))  ) for t in  range(1,n+1) ]
  for t in  range(n):
    ct0 = 0 if t==0 else Ct[t-1]
    St.append( 1- (1-Ct[t])/(1-ct0)  )
  return St




def calculate_fisher_mask(model,train_set, sparsity=0.9, n=5):
  sparsity_lst = claculate_sparsity(S=sparsity,n=n)
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  criterion = nn.CrossEntropyLoss()
  head_params = [p for p in model.head.parameters()]
  param_ids = [id(p) for p in head_params]


  fisher_scores = {id_p:torch.zeros_like(p, device= device) for id_p,p in  zip(param_ids, head_params)}
  last_mask = {id_p:torch.ones_like(p, device= device) for id_p,p in  zip(param_ids, head_params)}



  model.eval()

  for i in range(n):

    for v in fisher_scores.values():
      v.zero_()

    for images, labels in train_set:
      images, labels = images.to(device), labels.to(device)
      outputs = model(images)
      loss = criterion(outputs, labels)

      grads = torch.autograd.grad(
          loss,
          head_params,
          create_graph=False,
          retain_graph=False
      )

      for p, g in zip(head_params,grads):
        pid = id(p)
        fisher_scores[pid] += g.detach().pow(2) * last_mask[pid]

    all_scores = torch.cat([
          fisher_scores[id(p)].mul(last_mask[id(p)]).reshape(-1)
        for p in  head_params
        ])

    non_zero  = all_scores[all_scores != 0]

    if non_zero.numel() == 0:
      new_mask = {id(p):torch.zeros_like(p,device=device) for p in head_params }
      last_mask = new_mask
      continue

    total_nz = non_zero.numel()
    keep = int( (1 - sparsity_lst[i])* total_nz)
    keep = min(keep, total_nz)

    if keep == 0:
      threshold = non_zero.max() + 1
    elif keep == total_nz:
      threshold = non_zero.min() - 1

    else:
      kth_smallest = total_nz - keep + 1
      threshold, _ = torch.kthvalue(non_zero, k= kth_smallest)

    new_mask = {}

    for p in head_params:
      pid = id(p)
      masked_scores = fisher_scores[pid] * last_mask[pid]
      current_mask = (masked_scores >= threshold).float() * last_mask[pid]
      new_mask[pid]  = current_mask
      last_mask[pid] = current_mask

  return new_mask







# Config the loss, optimizer and training loop

In [9]:
is_it_validation = False #@param {type:"boolean"}
num_epochs = 10 #@param {type:"integer"}


In [10]:
# @title SGDM train
criterion = nn.CrossEntropyLoss()

for p in model.backbone.parameters():
  p.requires_grad = False
model.backbone.eval()

for p in model.head.parameters():
    p.requires_grad = True

model.head.train()

buffer = {}
loss_hist = []

for epoch_num in range(num_epochs):
  running_loss = 0.0
  print(f"Epoch: {epoch_num}" )
  num_batches  = 0


  grad_mask = calculate_fisher_mask(model=model,train_set=train_loader  )


  for images, labels in train_loader:
    images = images.to(device)
    labels =  labels.to(device).long()

    outputs = model(images)
    loss = criterion(outputs,labels)

    model.zero_grad(set_to_none=True)
    loss.backward()
    buffer = model.SGDM(buffer=buffer,grad_mask= grad_mask, momentum=0.9)

    running_loss += float(loss.item())
    num_batches += 1


    avg_loss = running_loss / max(1, num_batches)
  loss_hist.append(avg_loss)
  print(f"Epoch: {epoch_num} | Loss: {avg_loss:.4f}")

Epoch: 0 Batch number: 0
Epoch: 0 | Loss: 3.8544
Epoch: 1 Batch number: 0


KeyboardInterrupt: 

In [None]:
# @title Defult training setting

import torch.optim as optim



criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.head.parameters(),lr=1e-3)
optimizer = optim.SGD(model.head.parameters(),lr=1e-3, momentum=0.9 )


losses = []
message = ""
for epoch in range(num_epochs):
  model.train()
  running_loss = 0.0

  for index, (images, lables) in enumerate(train_loader):
    images = images.to(device).requires_grad_(True)
    lables =  lables.to(device)

    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs,lables)



    loss.backward()
    optimizer.step()
    running_loss += loss.item() +images.size(0)
  losses.append(loss.item())
  epoch_loss = running_loss / len(train_loader.dataset)
  lrimprovement = 0
  if len(losses) > 2:
    lrimprovement = (losses[-2]- losses[-1])/losses[-2] * 100
    wimprovement = (losses[0]- losses[-1])/losses[0] * 100

  i1  = f"last round improvement:{lrimprovement:.3f}%"if len(losses)>=2 else ""
  i2  = f"Improvement from begining:{wimprovement:.3f}%"if len(losses)>2 else ""
  message = f"Epoch: {epoch}, Loss {epoch_loss:.4f}, \n{i1}\n{i2}"
  print(message)

# Evaluation

In [None]:

def evaluation(model, data_loader):
  criterion = nn.CrossEntropyLoss()
  model.eval()
  correct = 0
  total = 0
  test_loss = 0

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  with torch.no_grad():
    for images, labels in  data_loader:
      images, labels = images.to(device), labels.to(device)
      outputs = model(images)

      _, prediction = torch.max(outputs.data,1)
      loss = criterion(outputs, labels)
      test_loss += loss.item() * labels.size(0)

      total += labels.size(0)
      correct += (prediction == labels).sum().item()
    accuracy = 100 * correct / total
    loss = test_loss / total
    return accuracy, loss

acc, loss_v = evaluation(model, test_loader)
print(f"Evaluation:\n accurace:{acc:.4f}, loss:{loss_v:.4f}")


Evaluation:
 accurace:57.0900, loss:1.6316


In [None]:
from google.colab import drive
drive.mount('/content/drive')

torch.save(model.state_dict(), '/content/drive/MyDrive/MLDL_FederatedLearning/models/centralized/model.pth')

Mounted at /content/drive
