# Clone the repo 
(commented out because it run locally)

In [15]:
# !rm -rf early_exit
# !git clone https://github.com/Ilias-Paralikas/early_exit.git

# Import lirbraries

In [16]:
from early_exit import EarlyExitNetwork,train_model,test_all_exits_accuracy
import torch.nn as nn
from copy import deepcopy

In [17]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

# Get your pretrained model

In [18]:
pretrained_model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_vgg19_bn", pretrained=True)


Using cache found in C:\Users\paral/.cache\torch\hub\chenyaofo_pytorch-cifar-models_master


# Examine model structure
you will need see your model layers and determine the splits

In [19]:
pretrained_model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256

# Create the core network nn.Sequential
the way that you do this depends on your model

In [20]:
splits = [7,17]
device_side_net = deepcopy(pretrained_model.features[:splits[0]])
server_side_net = deepcopy(pretrained_model.features[splits[0]:splits[1]])
cloud_side_net = nn.Sequential(
    deepcopy(pretrained_model.features[splits[1]:]),
    nn.Flatten(),
    deepcopy(pretrained_model.classifier)
)

core_net = nn.Sequential(device_side_net,server_side_net,cloud_side_net)

# Define the exits
here the exits are just linear layers, it does not have to be that way, you can define the exits however you want, but the dimensions have to match

In [21]:
def get_exit_layers(neurons_in_layer,number_of_classes=10):
    layers = [nn.Flatten()]
    for i in range(len(neurons_in_layer)-1):
        layers.append(nn.Linear(neurons_in_layer[i],neurons_in_layer[i+1]))
        
    layers.append(nn.Linear(neurons_in_layer[-1],number_of_classes))
    return nn.Sequential(*layers)


In [22]:
import torch
# this is the input shape. change it for each use case
x = torch.randn(1,3,32,32)



exit_0 =device_side_net(x)
exit_0_shape = device_side_net(x).flatten().shape
print(exit_0_shape)

exit_1 = server_side_net(exit_0)
exit_1_shape = server_side_net(exit_0).flatten().shape
print(exit_1_shape)

torch.Size([16384])
torch.Size([16384])


# Create the exits

In [23]:
neurons_in_exit_0 = [exit_0_shape[0],512,512]
neurons_in_exit_1 = [exit_1_shape[0],512,512]

exits =nn.ModuleList([get_exit_layers(neurons_in_exit_0),get_exit_layers(neurons_in_exit_1)])

# Define dataset

In [24]:
import torchvision
from torchvision import  transforms
import torch
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 32


trainset = torchvision.datasets.CIFAR10(root='data', train=True,
                                       download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


# Get the early exit network
the class EarlyExitNetwork, pieces together the core network and the exits

In [26]:
eenet=  EarlyExitNetwork(core_net,exits).to(device)
criterion = nn.BCEWithLogitsLoss()
lr = 1e-4

# Usage
The EarlyExitNetwork class allows the user to either get the results of all the exits, or a specific exti chosen.

In [12]:
x = torch.randn(1,3,32,32).to(device)
print('Output the results for all the extis:\t\t',eenet(x).shape)
print('Output the results for the exit specified:\t',eenet(x,exit_chosen=0).shape)


Output the results for all the extis:		 torch.Size([1, 3, 10])
Output the results for the exit specified:	 torch.Size([1, 10])


# Train the whole network

the test function tests all the exits at the same time (not how the deployed early network works, where only one exit is taken).
We can see the correct samples of the first, the second and the third exit, while on the right the number of the total samples.
You can infer the acc of each exit by dividing each number.

In [13]:
optimizer = torch.optim.Adam(eenet.parameters(),lr=lr)
print("Before :",test_all_exits_accuracy(eenet,testloader))
for e in range(2):
    train_model(eenet,trainloader,criterion,optimizer) 
    print(f'Trained Exit {e} :{test_all_exits_accuracy(eenet,testloader)}')   

Before : ([965, 1123, 3975], 10000)


100%|██████████| 1563/1563 [00:24<00:00, 63.66it/s, loss=0.594]


Trained Exit 0 :([6420, 8234, 8648], 10000)


100%|██████████| 1563/1563 [00:23<00:00, 65.21it/s, loss=0.366]


Trained Exit 1 :([6697, 8556, 9042], 10000)


# Train only one exit
when tou want to train only one exit, you first need to specify in the optimizer only the weights of that exit.
also you need to pass the exit chosen parameter on the train function, just like we did before

In [14]:
exit_taken=0
optimizer = torch.optim.Adam(eenet.exits[exit_taken].parameters(),lr=lr)
print("Before :",test_all_exits_accuracy(eenet,testloader))
train_model(eenet,trainloader,criterion,optimizer,exit_chosen=exit_taken)    
print(f'Trained Exit {e} :{test_all_exits_accuracy(eenet,testloader)}')  

Before : ([6697, 8556, 9042], 10000)


100%|██████████| 1563/1563 [00:08<00:00, 193.32it/s, loss=0.218] 


Trained Exit 1 :([6810, 8556, 9042], 10000)


Note that the result on the first exit changed, the other stayed the same

In [15]:
exit_taken=1
optimizer = torch.optim.Adam(eenet.exits[exit_taken].parameters(),lr=lr)

print("Before :",test_all_exits_accuracy(eenet,testloader))
train_model(eenet,trainloader,criterion,optimizer,exit_chosen=exit_taken)    
print(f'Trained Exit {e} :{test_all_exits_accuracy(eenet,testloader)}')  

Before : ([6810, 8556, 9042], 10000)


100%|██████████| 1563/1563 [00:10<00:00, 144.16it/s, loss=0.0746] 


Trained Exit 1 :([6810, 8780, 9042], 10000)


In [16]:
exit_to_train=2
optimizer = torch.optim.Adam(eenet.exits[exit_to_train].parameters(), lr=lr)
print("Before :",test_all_exits_accuracy(eenet,testloader))
for e in range(1):
    train_model(eenet,trainloader,criterion,optimizer,exit_chosen=exit_to_train)    
    print(f'Trained Exit {e} :{test_all_exits_accuracy(eenet,testloader)}')   

Before : ([6810, 8780, 9042], 10000)


100%|██████████| 1563/1563 [00:18<00:00, 83.66it/s, loss=0.0882]  


Trained Exit 0 :([6810, 8780, 9023], 10000)


In [12]:
from early_exit import seperate_networks, SegmentedEarlyExitNetwork,test_networks

In [27]:
networks = seperate_networks(eenet,thresholds=[0.9,0.9]).to(device)  

In [29]:
networks[0]

SegmentedEarlyExitNetwork(
  (network): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (exit): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=16384, out_features=512, bias=True)
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): Linear(in_features=512, out_features=10, bias=True)
  )
  (confidence_function): Softmax(dim=0)
)

In [19]:
test_networks(networks,testset)

(0.8897,
 {0: {'correct': 3483, 'total': 3802},
  1: {'correct': 4360, 'total': 4678},
  2: {'correct': 1054, 'total': 1520}})

In [30]:
import os
model_folder = 'edge_heavy'
os.makedirs(model_folder, exist_ok=True)
for i, net in enumerate(networks):
    net = net.to('cpu')
    model_file = os.path.join(model_folder, f"model_{i}.pth")
    torch.save(net, model_file)