## Setup

## Import

In [1]:
from data_utils import * 
from model_utils import *
from utils import *
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from itertools import product
import os
import json
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
import tensorflow as tf
from tensorflow.keras.utils import to_categorical

  from .autonotebook import tqdm as notebook_tqdm


## FedAvg

In [2]:
dataset = 'cifar10'
num_classes = 10 if dataset == 'cifar10' else 100 
pub_num_classes = 100 if num_classes == 10 else 10
datadir = '../data'
partition = 'iid' 
n_parties = 5
beta = 0.5
# The public set is shard 0 of an iid split of the CIFAR-100 train set into this many shards.
n_public_shards = 10

# The public-label folding below (CIFAR-100 label // 10 -> 0..9) only matches the private
# model's output space when the private dataset is CIFAR-10; fail loudly if `dataset` changes.
assert dataset == 'cifar10', "public-set label folding assumes CIFAR-10 as the private dataset"

# Private data, partitioned across n_parties clients (uses `dataset`, not a hard-coded name).
(X_train, y_train, X_test, y_test, net_dataidx_map) = partition_data(
    dataset, datadir=datadir, partition=partition, n_parties=n_parties, beta=beta)
# Public data: CIFAR-100, split iid into n_public_shards.
(X_train_public, y_train_public, X_test_public, y_test_public, net_dataidx_map_public) = partition_data(
    'cifar100', datadir=datadir, partition='iid', n_parties=n_public_shards, beta=0.5)

# Fold the public labels into 10 buckets (label // 10) so they share the 10-way output space.
y_train_public = y_train_public // 10
y_test_public = y_test_public // 10

y_train_cat = to_categorical(y_train, num_classes=num_classes)
y_test_cat = to_categorical(y_test, num_classes=num_classes)

# Use num_classes (not pub_num_classes) because the folded public labels live in 0..9.
y_train_public_cat = to_categorical(y_train_public, num_classes=num_classes)
y_test_public_cat = to_categorical(y_test_public, num_classes=num_classes)

print(X_train.shape, y_train.shape, X_test.shape, y_test.shape, y_train_cat.shape, y_test_cat.shape)
print(X_train_public.shape, y_train_public.shape, X_test_public.shape, y_test_public.shape, y_train_public_cat.shape, y_test_public_cat.shape)

# Only the first public shard is shared with the clients.
public_set = (X_train_public[net_dataidx_map_public[0]], y_train_public_cat[net_dataidx_map_public[0]])
# Each client gets its own training shard; every client is evaluated on the full test set.
local_sets = [(X_train[net_dataidx_map[i]], y_train_cat[net_dataidx_map[i]]) for i in range(n_parties)]
test_sets = [(X_test, y_test_cat) for _ in range(n_parties)]

print(public_set[0].shape, public_set[1].shape)
for i, ((X_local, y_local), (X_eval, y_eval)) in enumerate(zip(local_sets, test_sets)):
    print('client ', i, ' ', X_local.shape, y_local.shape)
    print(X_eval.shape, y_eval.shape)
    print()


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
(50000, 32, 32, 3) (50000,) (10000, 32, 32, 3) (10000,) (50000, 10) (10000, 10)
(50000, 32, 32, 3) (50000,) (10000, 32, 32, 3) (10000,) (50000, 10) (10000, 10)
(5000, 32, 32, 3) (5000, 10)
client  0   (10000, 32, 32, 3) (10000, 10)
(10000, 32, 32, 3) (10000, 10)

client  1   (10000, 32, 32, 3) (10000, 10)
(10000, 32, 32, 3) (10000, 10)

client  2   (10000, 32, 32, 3) (10000, 10)
(10000, 32, 32, 3) (10000, 10)

client  3   (10000, 32, 32, 3) (10000, 10)
(10000, 32, 32, 3) (10000, 10)

client  4   (10000, 32, 32, 3) (10000, 10)
(10000, 32, 32, 3) (10000, 10)



In [3]:
aggregation_method = 'weights'
aug = False
weighting = 'uniform'
private = False
hyperparameter_tuning = False
C = 1
initial_pub_alignment_epochs = 4
client_num = 5  # number of clients handed to the server; must be <= len(local_sets)
# The federated rounds are currently switched off: only the per-client local benchmark runs.
# Set to True to run fl_params['tot_T'] rounds before benchmarking.
run_federated_rounds = False

assert client_num <= len(local_sets), (
    f"client_num={client_num} but only {len(local_sets)} local shards were built")

fl_params = {
    'client_num': client_num,
    'tot_T': 30, 
    'C': C,
    'local_sets': local_sets,
    'test_sets': test_sets,
    'public_set': public_set,
    'batch_size': 32,
    'epochs': 6, 
    'lr': 0.003,
    'aggregate': aggregation_method, # 'weights', 'grads', 'compressed_soft_labels', 'soft_labels'
    'hyperparameter_tuning': hyperparameter_tuning, 
    'weighting': weighting, # 'uniform', 'performance_based'
    'default_client_id': 1, 
    'augment': aug, 
    'private': private,
    'max_grad_norm': 1.0,
    'delta': 1e-4,
    'epsilon': 5,
    'local_benchmark_epochs': 180, 
    'initial_pub_alignment_epochs': initial_pub_alignment_epochs, 
    'temperature': 0.6, 
}

N_pub = len(fl_params['public_set'][0])
exp_path = f"../fl_results/{dataset}/DP{private}/N_pub{N_pub}/Agg{fl_params['aggregate']}_C{fl_params['C']}_HT{fl_params['hyperparameter_tuning']}_Aug{fl_params['augment']}_W{fl_params['weighting']}"
fl_params['exp_path'] = exp_path
server = FLServer(fl_params)

if run_federated_rounds:
    FL_acc = []
    for t in range(fl_params['tot_T']):
        avg_acc, min_acc, max_acc, avg_train_acc = server.global_update(verbose = True)
        FL_acc.append(avg_acc)
        print(f"Round {t} accuracy: avg: {avg_acc} min: {min_acc}  max: {max_acc} avg_train: {avg_train_acc}")
    if FL_acc:  # tot_T may be 0
        print("Final accuracy: ", FL_acc[-1])
    print()

client_accs = []
for c, client in enumerate(server.clients):
    acc = client.local_benchmark(verbose = True)
    client_accs.append(acc)
    print(f"Client {c} local benchmark accuracy: {acc}")
print("Client accuracies: ", client_accs, "  ", np.mean(client_accs))
print()

server.save_assets()


Epoch 0 : test acc = 0.11%, test loss = 2.2961
Epoch 1 : test acc = 0.10%, test loss = 2.2888
Epoch 2 : test acc = 0.13%, test loss = 2.2760
Epoch 3 : test acc = 0.18%, test loss = 2.2506
Epoch 4 : test acc = 0.22%, test loss = 2.1986
Epoch 5 : test acc = 0.23%, test loss = 2.1256
Epoch 6 : test acc = 0.25%, test loss = 2.0704
Epoch 7 : test acc = 0.27%, test loss = 2.0300
Epoch 8 : test acc = 0.28%, test loss = 1.9937
Epoch 9 : test acc = 0.29%, test loss = 1.9672
Epoch 10 : test acc = 0.29%, test loss = 1.9603
Epoch 11 : test acc = 0.29%, test loss = 1.9356
Epoch 12 : test acc = 0.32%, test loss = 1.9055
Epoch 13 : test acc = 0.31%, test loss = 1.8881
Epoch 14 : test acc = 0.33%, test loss = 1.8596
Epoch 15 : test acc = 0.33%, test loss = 1.8422
Epoch 16 : test acc = 0.35%, test loss = 1.8212
Epoch 17 : test acc = 0.34%, test loss = 1.8010
Epoch 18 : test acc = 0.35%, test loss = 1.7835
Epoch 19 : test acc = 0.35%, test loss = 1.7637
Epoch 20 : test acc = 0.36%, test loss = 1.7480
Ep

KeyboardInterrupt: 

## FedMD

In [None]:
dataset = 'cifar10'
num_classes = 10 if dataset == 'cifar10' else 100 
pub_num_classes = 100 if num_classes == 10 else 10
datadir = '../data'
partition = 'iid' 
n_parties = 10
beta = 0.5
# The public set is shard 0 of an iid split of the CIFAR-100 train set into this many shards.
n_public_shards = 10

# The public-label folding below (CIFAR-100 label // 10 -> 0..9) only matches the private
# model's output space when the private dataset is CIFAR-10; fail loudly if `dataset` changes.
assert dataset == 'cifar10', "public-set label folding assumes CIFAR-10 as the private dataset"

# Private data, partitioned across n_parties clients (uses `dataset`, not a hard-coded name).
(X_train, y_train, X_test, y_test, net_dataidx_map) = partition_data(
    dataset, datadir=datadir, partition=partition, n_parties=n_parties, beta=beta)
# Public data: CIFAR-100, split iid into n_public_shards.
(X_train_public, y_train_public, X_test_public, y_test_public, net_dataidx_map_public) = partition_data(
    'cifar100', datadir=datadir, partition='iid', n_parties=n_public_shards, beta=0.5)

# Fold the public labels into 10 buckets (label // 10) so they share the 10-way output space.
y_train_public = y_train_public // 10
y_test_public = y_test_public // 10

y_train_cat = to_categorical(y_train, num_classes=num_classes)
y_test_cat = to_categorical(y_test, num_classes=num_classes)

# Use num_classes (not pub_num_classes) because the folded public labels live in 0..9.
y_train_public_cat = to_categorical(y_train_public, num_classes=num_classes)
y_test_public_cat = to_categorical(y_test_public, num_classes=num_classes)

print(X_train.shape, y_train.shape, X_test.shape, y_test.shape, y_train_cat.shape, y_test_cat.shape)
print(X_train_public.shape, y_train_public.shape, X_test_public.shape, y_test_public.shape, y_train_public_cat.shape, y_test_public_cat.shape)

# Only the first public shard is shared with the clients.
public_set = (X_train_public[net_dataidx_map_public[0]], y_train_public_cat[net_dataidx_map_public[0]])
# Each client gets its own training shard; every client is evaluated on the full test set.
local_sets = [(X_train[net_dataidx_map[i]], y_train_cat[net_dataidx_map[i]]) for i in range(n_parties)]
test_sets = [(X_test, y_test_cat) for _ in range(n_parties)]

print(public_set[0].shape, public_set[1].shape)
for i, ((X_local, y_local), (X_eval, y_eval)) in enumerate(zip(local_sets, test_sets)):
    print('client ', i, ' ', X_local.shape, y_local.shape)
    print(X_eval.shape, y_eval.shape)
    print()


In [None]:
aggregation_method = 'soft_labels'
aug = False
weighting = 'uniform'
private = False
hyperparameter_tuning = False
C = 1
initial_pub_alignment_epochs = 4
# NOTE(review): the data cell for this section builds n_parties = 10 shards, but the server
# is given client_num = 5, so presumably only 5 (smaller) shards are used -- unlike the
# FedAvg section (n_parties = 5). Confirm this is intended before comparing results.
client_num = 5
# The per-client local benchmark is switched off for this method; set to True to run it.
run_local_benchmark = False

assert client_num <= len(local_sets), (
    f"client_num={client_num} but only {len(local_sets)} local shards were built")

fl_params = {
    'client_num': client_num,
    'tot_T': 30, 
    'C': C,
    'local_sets': local_sets,
    'test_sets': test_sets,
    'public_set': public_set,
    'batch_size': 32,
    'epochs': 6, 
    'lr': 0.003,
    'aggregate': aggregation_method, # 'weights', 'grads', 'compressed_soft_labels', 'soft_labels'
    'hyperparameter_tuning': hyperparameter_tuning, 
    'weighting': weighting, # 'uniform', 'performance_based'
    'default_client_id': 1, 
    'augment': aug, 
    'private': private,
    'max_grad_norm': 1.0,
    'delta': 1e-4,
    'epsilon': 5,
    'local_benchmark_epochs': 180, 
    'initial_pub_alignment_epochs': initial_pub_alignment_epochs, 
    'temperature': 0.6, 
}

N_pub = len(fl_params['public_set'][0])
exp_path = f"../fl_results/{dataset}/DP{private}/N_pub{N_pub}/Agg{fl_params['aggregate']}_C{fl_params['C']}_HT{fl_params['hyperparameter_tuning']}_Aug{fl_params['augment']}_W{fl_params['weighting']}"
fl_params['exp_path'] = exp_path
server = FLServer(fl_params)

FL_acc = []
for t in range(fl_params['tot_T']):
    avg_acc, min_acc, max_acc, avg_train_acc = server.global_update()
    FL_acc.append(avg_acc)
    print(f"Round {t} accuracy: avg: {avg_acc} min: {min_acc}  max: {max_acc} avg_train: {avg_train_acc}")
if FL_acc:  # tot_T may be 0
    print("Final accuracy: ", FL_acc[-1])
print() 

if run_local_benchmark:
    client_accs = []
    for c, client in enumerate(server.clients):
        acc = client.local_benchmark()
        client_accs.append(acc)
        print(f"Client {c} local benchmark accuracy: {acc}")
    print("Client accuracies: ", client_accs, "  ", np.mean(client_accs))
    print()

server.save_assets()


## CFedAKD

In [None]:
dataset = 'cifar10'
num_classes = 10 if dataset == 'cifar10' else 100 
pub_num_classes = 100 if num_classes == 10 else 10
datadir = '../data'
partition = 'iid' 
n_parties = 10
beta = 0.5
# The public set is shard 0 of an iid split of the CIFAR-100 train set into this many shards.
n_public_shards = 10

# The public-label folding below (CIFAR-100 label // 10 -> 0..9) only matches the private
# model's output space when the private dataset is CIFAR-10; fail loudly if `dataset` changes.
assert dataset == 'cifar10', "public-set label folding assumes CIFAR-10 as the private dataset"

# Private data, partitioned across n_parties clients (uses `dataset`, not a hard-coded name).
(X_train, y_train, X_test, y_test, net_dataidx_map) = partition_data(
    dataset, datadir=datadir, partition=partition, n_parties=n_parties, beta=beta)
# Public data: CIFAR-100, split iid into n_public_shards.
(X_train_public, y_train_public, X_test_public, y_test_public, net_dataidx_map_public) = partition_data(
    'cifar100', datadir=datadir, partition='iid', n_parties=n_public_shards, beta=0.5)

# Fold the public labels into 10 buckets (label // 10) so they share the 10-way output space.
y_train_public = y_train_public // 10
y_test_public = y_test_public // 10

y_train_cat = to_categorical(y_train, num_classes=num_classes)
y_test_cat = to_categorical(y_test, num_classes=num_classes)

# Use num_classes (not pub_num_classes) because the folded public labels live in 0..9.
y_train_public_cat = to_categorical(y_train_public, num_classes=num_classes)
y_test_public_cat = to_categorical(y_test_public, num_classes=num_classes)

print(X_train.shape, y_train.shape, X_test.shape, y_test.shape, y_train_cat.shape, y_test_cat.shape)
print(X_train_public.shape, y_train_public.shape, X_test_public.shape, y_test_public.shape, y_train_public_cat.shape, y_test_public_cat.shape)

# Only the first public shard is shared with the clients.
public_set = (X_train_public[net_dataidx_map_public[0]], y_train_public_cat[net_dataidx_map_public[0]])
# Each client gets its own training shard; every client is evaluated on the full test set.
local_sets = [(X_train[net_dataidx_map[i]], y_train_cat[net_dataidx_map[i]]) for i in range(n_parties)]
test_sets = [(X_test, y_test_cat) for _ in range(n_parties)]

print(public_set[0].shape, public_set[1].shape)
for i, ((X_local, y_local), (X_eval, y_eval)) in enumerate(zip(local_sets, test_sets)):
    print('client ', i, ' ', X_local.shape, y_local.shape)
    print(X_eval.shape, y_eval.shape)
    print()


In [None]:
aggregation_method = 'compressed_soft_labels'
aug = True
weighting = 'uniform'
private = False
hyperparameter_tuning = False
C = 1
initial_pub_alignment_epochs = 4
# NOTE(review): the data cell for this section builds n_parties = 10 shards, but the server
# is given client_num = 5, so presumably only 5 of them are used -- confirm this is intended.
client_num = 5

assert client_num <= len(local_sets), (
    f"client_num={client_num} but only {len(local_sets)} local shards were built")

fl_params = {
    'client_num': client_num,
    'tot_T': 30, 
    'C': C,
    'local_sets': local_sets,
    'test_sets': test_sets,
    'public_set': public_set,
    'batch_size': 32,
    'epochs': 6, 
    'lr': 0.003,
    'aggregate': aggregation_method, # 'weights', 'grads', 'compressed_soft_labels', 'soft_labels'
    'hyperparameter_tuning': hyperparameter_tuning, 
    'weighting': weighting, # 'uniform', 'performance_based'
    'default_client_id': 1, 
    'augment': aug, 
    'private': private,
    'max_grad_norm': 1.0,
    'delta': 1e-4,
    'epsilon': 5,
    'local_benchmark_epochs': 180, 
    'initial_pub_alignment_epochs': initial_pub_alignment_epochs, 
    'temperature': 0.6, 
}

N_pub = len(fl_params['public_set'][0])
exp_path = f"../fl_results/{dataset}/DP{private}/N_pub{N_pub}/Agg{fl_params['aggregate']}_C{fl_params['C']}_HT{fl_params['hyperparameter_tuning']}_Aug{fl_params['augment']}_W{fl_params['weighting']}"
fl_params['exp_path'] = exp_path
server = FLServer(fl_params)

FL_acc = []
for t in range(fl_params['tot_T']):
    avg_acc, min_acc, max_acc, avg_train_acc = server.global_update()
    FL_acc.append(avg_acc)
    print(f"Round {t} accuracy: avg: {avg_acc} min: {min_acc}  max: {max_acc} avg_train: {avg_train_acc}")
if FL_acc:  # tot_T may be 0
    print("Final accuracy: ", FL_acc[-1])
print() 

client_accs = []
for c, client in enumerate(server.clients):
    acc = client.local_benchmark()
    client_accs.append(acc)
    print(f"Client {c} local benchmark accuracy: {acc}")
print("Client accuracies: ", client_accs, "  ", np.mean(client_accs))
print()

server.save_assets()


## FedAKD

In [None]:
dataset = 'cifar10'
num_classes = 10 if dataset == 'cifar10' else 100 
pub_num_classes = 100 if num_classes == 10 else 10
datadir = '../data'
partition = 'iid' 
n_parties = 10
beta = 0.5
# The public set is shard 0 of an iid split of the CIFAR-100 train set into this many shards.
n_public_shards = 10

# The public-label folding below (CIFAR-100 label // 10 -> 0..9) only matches the private
# model's output space when the private dataset is CIFAR-10; fail loudly if `dataset` changes.
assert dataset == 'cifar10', "public-set label folding assumes CIFAR-10 as the private dataset"

# Private data, partitioned across n_parties clients (uses `dataset`, not a hard-coded name).
(X_train, y_train, X_test, y_test, net_dataidx_map) = partition_data(
    dataset, datadir=datadir, partition=partition, n_parties=n_parties, beta=beta)
# Public data: CIFAR-100, split iid into n_public_shards.
(X_train_public, y_train_public, X_test_public, y_test_public, net_dataidx_map_public) = partition_data(
    'cifar100', datadir=datadir, partition='iid', n_parties=n_public_shards, beta=0.5)

# Fold the public labels into 10 buckets (label // 10) so they share the 10-way output space.
y_train_public = y_train_public // 10
y_test_public = y_test_public // 10

y_train_cat = to_categorical(y_train, num_classes=num_classes)
y_test_cat = to_categorical(y_test, num_classes=num_classes)

# Use num_classes (not pub_num_classes) because the folded public labels live in 0..9.
y_train_public_cat = to_categorical(y_train_public, num_classes=num_classes)
y_test_public_cat = to_categorical(y_test_public, num_classes=num_classes)

print(X_train.shape, y_train.shape, X_test.shape, y_test.shape, y_train_cat.shape, y_test_cat.shape)
print(X_train_public.shape, y_train_public.shape, X_test_public.shape, y_test_public.shape, y_train_public_cat.shape, y_test_public_cat.shape)

# Only the first public shard is shared with the clients.
public_set = (X_train_public[net_dataidx_map_public[0]], y_train_public_cat[net_dataidx_map_public[0]])
# Each client gets its own training shard; every client is evaluated on the full test set.
local_sets = [(X_train[net_dataidx_map[i]], y_train_cat[net_dataidx_map[i]]) for i in range(n_parties)]
test_sets = [(X_test, y_test_cat) for _ in range(n_parties)]

print(public_set[0].shape, public_set[1].shape)
for i, ((X_local, y_local), (X_eval, y_eval)) in enumerate(zip(local_sets, test_sets)):
    print('client ', i, ' ', X_local.shape, y_local.shape)
    print(X_eval.shape, y_eval.shape)
    print()


In [None]:
aggregation_method = 'soft_labels'
aug = True
weighting = 'uniform'
private = False
hyperparameter_tuning = False
C = 1
initial_pub_alignment_epochs = 4
# NOTE(review): the data cell for this section builds n_parties = 10 shards, but the server
# is given client_num = 5, so presumably only 5 of them are used -- confirm this is intended.
client_num = 5

assert client_num <= len(local_sets), (
    f"client_num={client_num} but only {len(local_sets)} local shards were built")

fl_params = {
    'client_num': client_num,
    'tot_T': 30, 
    'C': C,
    'local_sets': local_sets,
    'test_sets': test_sets,
    'public_set': public_set,
    'batch_size': 32,
    'epochs': 6, 
    'lr': 0.003,
    'aggregate': aggregation_method, # 'weights', 'grads', 'compressed_soft_labels', 'soft_labels'
    'hyperparameter_tuning': hyperparameter_tuning, 
    'weighting': weighting, # 'uniform', 'performance_based'
    'default_client_id': 1, 
    'augment': aug, 
    'private': private,
    'max_grad_norm': 1.0,
    'delta': 1e-4,
    'epsilon': 5,
    'local_benchmark_epochs': 180, 
    'initial_pub_alignment_epochs': initial_pub_alignment_epochs, 
    'temperature': 0.6, 
}

N_pub = len(fl_params['public_set'][0])
exp_path = f"../fl_results/{dataset}/DP{private}/N_pub{N_pub}/Agg{fl_params['aggregate']}_C{fl_params['C']}_HT{fl_params['hyperparameter_tuning']}_Aug{fl_params['augment']}_W{fl_params['weighting']}"
fl_params['exp_path'] = exp_path
server = FLServer(fl_params)

FL_acc = []
for t in range(fl_params['tot_T']):
    avg_acc, min_acc, max_acc, avg_train_acc = server.global_update()
    FL_acc.append(avg_acc)
    print(f"Round {t} accuracy: avg: {avg_acc} min: {min_acc}  max: {max_acc} avg_train: {avg_train_acc}")
if FL_acc:  # tot_T may be 0
    print("Final accuracy: ", FL_acc[-1])
print() 

client_accs = []
for c, client in enumerate(server.clients):
    acc = client.local_benchmark()
    client_accs.append(acc)
    print(f"Client {c} local benchmark accuracy: {acc}")
print("Client accuracies: ", client_accs, "  ", np.mean(client_accs))
print()

server.save_assets()


## Central training

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Define the device to use for training (GPU if available, else CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Training transforms: random flip/crop augmentation, then per-channel (x - 0.5) / 0.5,
# which maps ToTensor's [0, 1] range to [-1, 1].
transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.RandomCrop(32, padding=4),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Evaluation transforms: same normalisation but no random augmentation, so test metrics
# are measured on the real images and are reproducible across runs.
test_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# NOTE(review): root='./data' differs from the '../data' used by the FL cells, so CIFAR-10
# may be downloaded twice -- confirm which location is intended.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=test_transform)
# Evaluation order does not affect metrics, so the test loader is not shuffled.
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=2)

def softmax_with_temperature(logits, temperature=1.0):
    """Apply softmax to ``logits / temperature`` over the last dimension.

    Args:
        logits: tensor of unnormalised scores.
        temperature: strictly positive divisor; values below 1 sharpen the
            distribution, values above 1 flatten it.

    Returns:
        Tensor of the same shape as ``logits`` whose last dimension sums to 1.

    Raises:
        AssertionError: if ``temperature`` is not positive.
    """
    assert temperature > 0, "Temperature must be positive."
    scaled_logits = logits / temperature
    # Use nn.functional (nn is imported in this cell) rather than `F`, which is not
    # imported in this cell and would raise NameError on a fresh kernel.
    return nn.functional.softmax(scaled_logits, dim=-1)

# Define the network architecture
class Net(nn.Module):
    """Small three-block CNN producing 10 raw (unnormalised) logits per image.

    Each block is a 3x3 conv (padding 1) -> ReLU -> 2x2 max-pool, so a 32x32 input is
    reduced to 128 channels of 4x4 before the two fully connected layers.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Map a batch of images of shape (N, 3, 32, 32) to logits of shape (N, 10)."""
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = self.pool(nn.functional.relu(self.conv3(x)))
        # Flatten per sample, keeping the batch dimension. Unlike view(-1, 2048), this never
        # silently re-batches a wrongly sized input; fc1 raises a shape error instead.
        x = x.flatten(1)
        x = nn.functional.relu(self.fc1(x))
        return self.fc2(x)

net = Net().to(device)

# Loss and optimizer.
# CrossEntropyLoss(logits / T, y) equals NLLLoss(log(softmax(logits / T)), y), but it is
# computed with log-softmax internally, so it never evaluates log(0) = -inf when a
# probability underflows (which would make the loss inf and the gradients NaN).
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

num_epochs = 10
train_temperature = 0.7  # temperature applied to the logits inside the loss

net.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    running_accuracy = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()

        logits = net(inputs)
        loss = criterion(logits / train_temperature, labels)
        loss.backward()
        optimizer.step()

        # argmax is unchanged by softmax/temperature, so predictions come straight from logits;
        # .item() keeps the running sums as Python floats instead of device tensors.
        running_accuracy += labels.eq(logits.argmax(dim=-1)).float().mean().item()
        running_loss += loss.item()

    print('[%d, %5d] loss: %.3f  acc: %.2f' %
          (epoch + 1, i + 1, running_loss / len(trainloader), running_accuracy / len(trainloader)))

print('Finished Training')

In [None]:
from torch.utils.data import DataLoader 
from torch.utils.data import TensorDataset
import torch.nn as nn
import torch.optim as optim 
import torch.nn.functional as F
import torchvision.transforms as transforms

nllloss = nn.NLLLoss()
def my_train(model, train_loader, optimizer, privacy_engine = None, DELTA = None, device = None):
    """Run one training epoch of ``model`` over ``train_loader``.

    Args:
        model: classifier returning raw logits, assumed (batch, num_classes) -- TODO confirm.
        train_loader: iterable yielding batches whose items 0 and 1 are inputs and integer
            class labels.
        optimizer: optimizer over ``model``'s parameters.
        privacy_engine: optional privacy engine; if given, ``get_epsilon(DELTA)`` is queried.
        DELTA: delta passed to ``privacy_engine.get_epsilon`` (unused otherwise).
        device: device to train on; defaults to CUDA when available, else CPU.

    Returns:
        (mean batch accuracy, mean batch loss) over the epoch.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Always place the model on the training device (a no-op if it is already there);
    # previously an explicit `device` left the model wherever it was.
    model = model.to(device)
    model.train()

    accs = []
    losses = []
    for data in train_loader:
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()

        logits = model(inputs)
        # log_softmax is the numerically stable form of log(softmax(x)); the latter gives
        # log(0) = -inf (and NaN gradients) once a probability underflows.
        log_probs = F.log_softmax(logits, dim=-1)
        loss = nllloss(log_probs, labels)
        loss.backward()
        optimizer.step()

        # argmax of the logits equals argmax of the softmax probabilities
        n_correct = float(labels.eq(logits.argmax(-1)).sum())
        accs.append(n_correct / len(labels))
        losses.append(float(loss))

    if privacy_engine is not None:
        # NOTE(review): epsilon is computed but neither returned nor logged; the
        # (accuracy, loss) return value is kept unchanged for existing callers.
        epsilon = privacy_engine.get_epsilon(DELTA)

    return np.mean(accs), np.mean(losses)


def softmax_with_temperature(logits, temperature=1.0):
    """Temperature-scaled softmax over the last axis.

    ``logits`` is divided by ``temperature`` before normalising, so temperatures
    below 1 sharpen the resulting distribution and temperatures above 1 flatten it.

    Raises:
        AssertionError: if ``temperature`` is not strictly positive.
    """
    assert temperature > 0, "Temperature must be positive."
    return F.softmax(logits.div(temperature), dim=-1)
    
class Net(nn.Module):
    """Small three-block CNN producing 10 raw (unnormalised) logits per image.

    Each block is a 3x3 conv (padding 1) -> ReLU -> 2x2 max-pool, so a 32x32 input is
    reduced to 128 channels of 4x4 before the two fully connected layers. Softmax is
    applied by the caller, not here.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Map a batch of images of shape (N, 3, 32, 32) to logits of shape (N, 10)."""
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = self.pool(nn.functional.relu(self.conv3(x)))
        # Flatten per sample, keeping the batch dimension. Unlike view(-1, 2048), this never
        # silently re-batches a wrongly sized input; fc1 raises a shape error instead.
        x = x.flatten(1)
        x = nn.functional.relu(self.fc1(x))
        logits = self.fc2(x)
        return logits


# torchvision is otherwise only imported in the previous cell; import it here so this
# cell does not depend on that cell having been run.
import torchvision

model = Net() 

# Training transforms: random flip/crop augmentation, then per-channel (x - 0.5) / 0.5.
transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(),
     transforms.RandomCrop(32, padding=4),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# Evaluation transforms: same normalisation, no random augmentation, so test metrics are
# measured on the real images and are reproducible.
test_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=test_transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                           shuffle=True, num_workers=2)
# Evaluation order does not affect metrics, so the test loader is not shuffled.
test_loader = torch.utils.data.DataLoader(testset, batch_size=128,
                                          shuffle=False, num_workers=2)


optimizer = optim.Adam(model.parameters(), lr=0.001)