In [35]:
import autoreload

%load_ext autoreload
%autoreload 2

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
import time
import os
from collections import OrderedDict
from google.colab import drive
# Mount Google Drive into the Colab filesystem at /content/gdrive/.
drive.mount('/content/gdrive/')
# Enable autograd anomaly detection globally. This is a debugging aid; the
# PyTorch documentation notes that it slows down execution, so consider
# disabling it for real training runs.
torch.autograd.set_detect_anomaly(True)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount("/content/gdrive/", force_remount=True).


<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f1203a6ea90>

In [0]:
class Node():
    """Node(Choco_Gossip): x_i(t+1) = x_i(t) + gamma*Sum(w_ij*[xhat_j(t+1) - xhat_i(t+1)])

    One participant of the decentralized network. It owns a model, a data
    loader and a loss, and exposes the latest mini-batch gradient through
    ``self.curr_gt``.

    Attributes:
        neighbors: list of neighbor indices (filled in by the owner of the node).
        neighbor_wts: dict mapping ``str(neighbor index)`` to the edge weight.
        step_size: the ``gamma`` value passed to the constructor.
        dataloader: iterable yielding ``(inputs, targets)`` batches.
        model: the ``nn.Module`` owned by this node.
        x_i: OrderedDict snapshot of ``model.state_dict()`` taken at construction.
        model_params: list of the state-dict keys, in state-dict order.
        criterion: callable ``criterion(outputs, targets)`` returning the loss.
        dataiter: the current iterator over ``dataloader``.
        curr_gt: OrderedDict of the latest gradients (set by ``compute_gradient``).
    """

    def __init__(self, gamma, loader, model, criterion):
        self.neighbors = []
        self.neighbor_wts = {}

        self.step_size = gamma
        self.dataloader = loader
        self.model = model
        self.criterion = criterion

        # Detached copy of every state-dict entry, plus the key order.
        self.x_i = OrderedDict()
        self.model_params = []
        for (k, v) in self.model.state_dict().items():
            self.model_params.append(k)
            self.x_i[k] = v.clone().detach()

        self.dataiter = iter(self.dataloader)

    def _next_batch(self):
        """Return the next batch, restarting the iterator once the loader is exhausted."""
        try:
            return next(self.dataiter)
        except StopIteration:
            # End of a pass over the data: start a fresh pass instead of crashing.
            # An empty loader still raises StopIteration from the call below.
            self.dataiter = iter(self.dataloader)
            return next(self.dataiter)

    def compute_gradient(self, quantizer=None, ):
        """Compute the loss gradient on one batch and store it in ``self.curr_gt``.

        Args:
            quantizer: optional callable applied to each (cloned, detached)
                gradient tensor before it is stored.

        Returns:
            None. The result is an OrderedDict stored in ``self.curr_gt``, keyed
            by the integer position of each parameter in
            ``self.model.parameters()``; parameters without a gradient are
            omitted.
        """
        # Clear gradients left over from the previous call; no optimizer is
        # needed because this method never takes an optimizer step.
        self.model.zero_grad()

        inputs, targets = self._next_batch()
        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)
        loss.backward()

        gt = OrderedDict()
        for k, v in enumerate(self.model.parameters()):
            if v.grad is None:
                continue
            grad = v.grad.clone().detach()
            gt[k] = quantizer(grad) if quantizer is not None else grad

        self.curr_gt = gt

        return

    def assign_params(self, W):
        """Assign dict W to model (non-strict ``load_state_dict``)."""
        with torch.no_grad():
            self.model.load_state_dict(W, strict=False)

        return

    def update_model(self):
        """Load ``self.x_i`` into the model.

        Within this class ``x_i`` is only written in ``__init__``, so unless a
        caller updates it this restores the construction-time parameters.
        """
        ### Implement Algorithm ###

        ## Assign Parameters after obtaining Consensus##
        self.assign_params(self.x_i)

        return

In [0]:
class Network():
    """Define graph

    Builds one ``Node`` per row of the weight matrix ``W`` and runs a
    simulation in which every node steps along a weighted average of its own
    and its neighbors' gradients.

    Attributes:
        adjacency: the weight matrix ``W`` as passed in.
        num_nodes: ``W.shape[0]``.
        nodes: OrderedDict mapping ``str(i)`` to the i-th ``Node``.
    """

    def __init__(self, W, models, learning_rates, loaders, criterion):
        """Create the nodes and wire up their neighbor lists.

        Args:
            W: 2-D weight matrix indexed as ``W[i, j]``; node ``j`` is a
                neighbor of node ``i`` when ``j != i`` and ``W[i, j] > 0``.
            models: one model per node.
            learning_rates: one step size per node (stored as ``Node.step_size``).
            loaders: one data loader per node.
            criterion: loss callable shared by all nodes.
        """
        self.adjacency = W
        self.num_nodes = W.shape[0]

        self.nodes = OrderedDict()

        for i in range(self.num_nodes):
            node = Node(learning_rates[i], loaders[i], models[i], criterion)
            self.nodes[str(i)] = node
            for j in range(self.num_nodes):
                if j != i and W[i, j] > 0:
                    node.neighbors.append(j)
                    node.neighbor_wts[str(j)] = W[i, j]

    def simulate(self, iterations, epochs):
        """Run ``epochs * iterations`` rounds of neighbor-averaged gradient descent.

        Each round:
          1. every node computes a fresh mini-batch gradient;
          2. every node averages its own gradient (weight 1) with its
             neighbors' gradients (edge weights) and steps along the result
             using its own ``step_size``.

        The iteration index is printed each round as a progress indicator.
        """
        for _ in range(epochs):
            for j in range(iterations):
                print(j)
                nodes = [self.nodes[str(k)] for k in range(self.num_nodes)]

                # Phase 1: all gradients are computed before any parameter
                # changes, so every node averages gradients from the same round.
                for node in nodes:
                    node.compute_gradient()

                # Phase 2: weighted averaging and parameter update.
                for node in nodes:
                    for m, param in enumerate(node.model.parameters()):
                        if param.grad is None:
                            continue

                        gt_update = node.curr_gt[m]
                        wt_sum = 1  # the node's own gradient has weight 1
                        for n in node.neighbors:
                            weight = node.neighbor_wts[str(n)]
                            gt_update = gt_update + weight * nodes[n].curr_gt[m]
                            wt_sum = wt_sum + weight
                        gt_update = gt_update / wt_sum

                        # Use the per-node learning rate supplied at
                        # construction instead of a hard-coded constant.
                        with torch.no_grad():
                            param.sub_(node.step_size * gt_update)
                        
                        

In [54]:
# One pretrained ResNet-18 per simulated node.
models = [
    torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True)
    for _ in range(3)
]
criterion = nn.CrossEntropyLoss()

# Convert images to tensors, then normalize each channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Each node gets its own shuffled loader over the same CIFAR-10 training set.
trainloaders = [
    torch.utils.data.DataLoader(
        trainset,
        batch_size=4,
        shuffle=True,
        num_workers=2,
    )
    for _ in range(3)
]

# All-ones 3x3 weight matrix: every node is a neighbor of every other node.
net = Network(
    torch.ones([3, 3]),
    models,
    [1e-3, 1e-3, 1e-3],
    trainloaders,
    nn.CrossEntropyLoss(),
)

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.6.0
Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.6.0
Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.6.0


Files already downloaded and verified


In [0]:
# Rebuild the network from the existing models and loaders.
net = Network(
    torch.ones([3, 3]),
    models,
    [1e-3, 1e-3, 1e-3],
    trainloaders,
    nn.CrossEntropyLoss(),
)

In [56]:
# Run 10 iterations for a single epoch.
net.simulate(iterations=10, epochs=1)

0
1
2
3
4
5
6
7
8
9


In [95]:
# Display the architecture of the first node's model.
net.nodes[str(0)].model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  