In [1]:
!pip install --upgrade wandb

Requirement already up-to-date: wandb in c:\users\lloyd\anaconda3\lib\site-packages (0.12.11)


In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import wandb
import os
from typing import Any, Dict, List
import copy
import random
import wandb

In [None]:
class Leader:
    """Local trainer that owns a model, an optimizer and a data loader.

    A leader trains ``model`` on the dataset handed to :meth:`setData` and
    records the loss and accuracy of the most recent epoch. When a ``server``
    object is supplied, the size of the attached dataset is added to
    ``server.total_data``.
    """

    def __init__(self,
                 leader_id: Any,
                 model: torch.nn.Module,
                 loss: torch.nn.modules.loss._Loss,
                 optimizer: torch.optim.Optimizer,
                 optimizer_conf: Dict,
                 batch_size: int,
                 epochs: int,
                 server=None) -> None:
        """Store the training configuration and build the optimizer.

        Args:
            leader_id: Identifier for this leader; stored as-is.
            model: Module whose parameters are optimized.
            loss: Callable invoked as ``loss(y_pred, label)``.
            optimizer: Optimizer *class* (or factory), not an instance; it is
                called as ``optimizer(model.parameters(), **optimizer_conf)``.
            optimizer_conf: Keyword arguments forwarded to ``optimizer``.
            batch_size: Batch size used for the data loader.
            epochs: Number of passes over the data per ``update_weights`` call.
            server: Optional object exposing a numeric ``total_data``
                attribute that ``setData`` increments.
        """
        self.leader_id = leader_id
        self.model = model
        self.loss = loss
        self.optimizer = optimizer(self.model.parameters(), **optimizer_conf)
        self.batch_size = batch_size
        self.epochs = epochs
        self.server = server

        # Metrics of the most recently completed epoch; None until trained.
        self.accuracy = None
        self.total_loss = None

        # Populated by setData().
        self.data = None
        self.data_loader = None

    def setData(self, data):
        """Attach ``data``, build a shuffled loader, and report its size.

        Args:
            data: Dataset accepted by ``torch.utils.data.DataLoader``; it must
                support ``len()`` when a server is attached.
        """
        self.data = data
        self.data_loader = torch.utils.data.DataLoader(self.data,
                                                       batch_size=self.batch_size,
                                                       shuffle=True)
        # `server` defaults to None in the constructor, so only report the
        # dataset size when there is actually a server to report to.
        if self.server is not None:
            self.server.total_data += len(self.data)

    def update_weights(self):
        """Train the model for ``self.epochs`` epochs on the attached data.

        After each epoch ``self.total_loss`` holds the mean of the per-batch
        losses and ``self.accuracy`` holds the fraction of labels predicted
        correctly (argmax over dim 1 of the model output).

        Raises:
            RuntimeError: If ``setData`` has not been called yet.
        """
        if self.data_loader is None:
            raise RuntimeError("setData() must be called before update_weights()")

        # Move batches to wherever the model's parameters live instead of
        # relying on a notebook-global `device` defined in some other cell.
        model_device = next(self.model.parameters()).device

        # Make sure layers such as dropout / batch-norm are in training mode,
        # even if the model was switched to eval mode elsewhere.
        self.model.train()

        for _ in range(self.epochs):
            total_loss = 0
            total_batches = 0
            total_correct = 0
            total_samples = 0

            for feature, label in self.data_loader:
                feature = feature.to(model_device)
                label = label.to(model_device)

                y_pred = self.model(feature)
                y_pred_decode = torch.argmax(y_pred, dim=1)

                total_correct += y_pred_decode.eq(label).sum().item()
                # Count the labels actually seen: the final batch can be
                # smaller than batch_size, so batches * batch_size overcounts.
                total_samples += label.numel()
                loss = self.loss(y_pred, label)

                self.optimizer.zero_grad(set_to_none=True)
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()
                total_batches += 1

            if total_batches == 0:
                # Empty loader: nothing was measured this epoch, so keep the
                # previous metrics rather than dividing by zero.
                continue

            self.total_loss = total_loss / total_batches
            self.accuracy = total_correct / total_samples
