In [1]:
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader, Subset, random_split
import torchvision
import torch
import torch.nn as nn
import torch.optim as optim

DATA_PATH = '/tmp/data/cifar10'
# NOTE(review): DUMP_FILE_NAME and NUM_CLIENTS are not used by the cells in this
# notebook section; presumably intended for a federated (non-IID) split — confirm or remove.
DUMP_FILE_NAME = '/tmp/data/fed-data-NonIDD.pkl'

BATCH_SIZE = 64
NUM_CLIENTS = 10
NUM_CLASS = 10
SEED = 42  # seeds weight init, dropout and shuffle order so Restart & Run All is repeatable

# Seed the global torch RNG (also seeds CUDA devices) before anything random happens.
# cuDNN kernels may still introduce small run-to-run differences on GPU.
torch.manual_seed(SEED)

# Same preprocessing for train and test: convert to tensor, then per-channel normalisation.
# NOTE(review): (0.2023, 0.1994, 0.2010) is the widely copied CIFAR-10 "std" triple; the
# per-pixel dataset std is reportedly closer to (0.247, 0.243, 0.261) — confirm which is intended.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        (0.4914, 0.4822, 0.4465),
        (0.2023, 0.1994, 0.2010)),
])

cifar10_train = torchvision.datasets.CIFAR10(
    root=DATA_PATH,
    train=True,
    transform=transform,
    download=True,
)

cifar10_test = torchvision.datasets.CIFAR10(
    root=DATA_PATH,
    train=False,
    transform=transform,
    download=True,
)

# A dedicated generator keeps the shuffle order independent of how many random numbers
# model initialisation happens to consume from the global RNG.
train_loader = DataLoader(
    cifar10_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
test_loader = DataLoader(cifar10_test, batch_size=BATCH_SIZE, shuffle=False)

  from .autonotebook import tqdm as notebook_tqdm


Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Define VGG block
def vgg_block(num_convs, in_channels, out_channels):
    """Build one VGG stage: ``num_convs`` x (3x3 conv -> BatchNorm -> ReLU), then a 2x2 max-pool.

    Args:
        num_convs: Number of conv/BatchNorm/ReLU triples in the stage.
        in_channels: Channels entering the first convolution.
        out_channels: Channels produced by every convolution in the stage.

    Returns:
        An ``nn.Sequential``. The 3x3, padding=1 convolutions keep the spatial size;
        the final stride-2 max-pool halves height and width (rounding down).
    """
    stage = [
        module
        for idx in range(num_convs)
        for module in (
            # Only the first conv consumes `in_channels`; later ones map out -> out.
            nn.Conv2d(in_channels if idx == 0 else out_channels, out_channels,
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
    ]
    return nn.Sequential(*stage, nn.MaxPool2d(kernel_size=2, stride=2))

# Define the VGG11bn model
class VGG11bn(nn.Module):
    """VGG-11 with batch normalisation, sized for CIFAR-style 3-channel images.

    The feature extractor follows the VGG-11 layout
    (64, M, 128, M, 256, 256, M, 512, 512, M, 512, 512, M). Five stride-2 pools
    shrink a 32x32 image to a 1x1 map. An adaptive average pool then forces the map
    to 1x1 for any input of at least 32x32, so the 512-input classifier always fits.
    For 32x32 inputs the pool is an identity, and it adds no parameters.

    Args:
        num_classes: Size of the output logit vector.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            vgg_block(1, 3, 64),
            vgg_block(1, 64, 128),
            vgg_block(2, 128, 256),
            vgg_block(2, 256, 512),
            vgg_block(2, 512, 512),
        )
        # Collapse whatever spatial extent remains to 1x1 so flatten yields 512 features.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Map a batch of images (N, 3, H, W), H and W >= 32, to logits (N, num_classes)."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

In [4]:
# Training settings.
# Re-running this cell builds a freshly initialised model and a new optimizer,
# discarding any training done so far.
LEARNING_RATE = 1e-3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Reuse the NUM_CLASS config constant instead of repeating the literal 10.
model = VGG11bn(num_classes=NUM_CLASS).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


Epoch 1, Loss: 1.8966399873309123
Epoch 2, Loss: 1.4988476328380274
Epoch 3, Loss: 1.1953544190625096
Epoch 4, Loss: 0.9570493989283472
Epoch 5, Loss: 0.7900260783674772
Epoch 6, Loss: 0.6771278785698859
Epoch 7, Loss: 0.5832006204920961
Epoch 8, Loss: 0.49418837670475013
Epoch 9, Loss: 0.41230202577722347
Epoch 10, Loss: 0.3505897023584074


In [6]:

# Training loop.
# NOTE: this continues training whatever `model` currently holds. Re-run the
# "Training settings" cell first to start again from freshly initialised weights.
num_epochs = 10  # Adjust as needed
for epoch in range(num_epochs):
    model.train()  # dropout active, batch-norm running statistics updated
    running_loss = 0.0
    samples_seen = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # `criterion` (CrossEntropyLoss, default reduction='mean') returns the batch mean.
        # Weight it by batch size so the final, smaller batch does not count like a full one.
        batch_n = labels.size(0)
        running_loss += loss.item() * batch_n
        samples_seen += batch_n
    # max(..., 1) avoids ZeroDivisionError if the loader yields no batches.
    print(f"Epoch {epoch+1}, Loss: {running_loss / max(samples_seen, 1)}")

Epoch 1, Loss: 0.2936335784880935
Epoch 2, Loss: 0.24922560699893842
Epoch 3, Loss: 0.21101587172121267
Epoch 4, Loss: 0.17824730002666678
Epoch 5, Loss: 0.14595967465463808
Epoch 6, Loss: 0.13911952501486824
Epoch 7, Loss: 0.1590074803692091
Epoch 8, Loss: 0.10324907992237611
Epoch 9, Loss: 0.09589396336215937
Epoch 10, Loss: 0.08508980711576675


In [5]:
# Testing for accuracy
model.eval()  # dropout off, batch-norm uses its running statistics
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        # Predicted class = index of the largest logit. This avoids the discouraged `.data`.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Report the counted number of images rather than a hardcoded 10000; guard an empty loader.
accuracy = 100 * correct / total if total else float('nan')
print(f'Accuracy of the network on the {total} test images: {accuracy} %')

Accuracy of the network on the 10000 test images: 80.88 %
