In [1]:
import autoreload

%load_ext autoreload
%autoreload 2

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
import time
import os
from collections import OrderedDict

In [29]:
class Node():
    """A single worker in a decentralized (gossip) training network.

    The intended Choco-Gossip consensus update is
        x_i(t+1) = x_i(t) + gamma * Sum_j(w_ij * [xhat_j(t+1) - xhat_i(t+1)]),
    but ``update_model`` currently only loads the stored snapshot ``x_i`` back
    into the model; the consensus step itself is still a TODO.

    Args:
        gamma: step size for this node (``Network`` passes the node's learning rate here).
        loader: iterable yielding ``(inputs, targets)`` batches; it is re-iterated
            from the start whenever it runs out.
        model: the ``torch.nn.Module`` owned by this node.
        criterion: loss callable invoked as ``criterion(outputs, targets)``.
    """

    def __init__(self, gamma, loader, model, criterion):
        # Neighbor indices and their mixing weights (keyed by str(index)); filled in by Network.
        self.neighbors = []
        self.neighbor_wts = {}

        self.step_size = gamma
        self.dataloader = loader
        self.model = model
        self.criterion = criterion

        # Detached copy of every state_dict entry, plus the ordered list of its keys.
        self.x_i = OrderedDict()
        self.model_params = []
        for key, value in self.model.state_dict().items():
            self.model_params.append(key)
            self.x_i[key] = value.clone().detach()

        self.dataiter = iter(self.dataloader)

    def _next_batch(self):
        """Return the next ``(inputs, targets)`` batch, restarting the loader when exhausted.

        Raises:
            RuntimeError: if the loader yields no batches at all.
        """
        try:
            return next(self.dataiter)
        except StopIteration:
            # Begin a fresh pass over the data (a shuffling DataLoader reshuffles here).
            self.dataiter = iter(self.dataloader)
            try:
                return next(self.dataiter)
            except StopIteration:
                raise RuntimeError("Node data loader yielded no batches") from None

    def compute_gradient(self, quantizer=None):
        """Compute the loss gradient on one sampled batch and store it in ``self.curr_gt``.

        ``self.curr_gt`` maps each parameter's position in ``model.parameters()`` to a
        detached copy of its gradient. When a ``quantizer`` is supplied, that copy is
        passed through it first. Returns ``None``.

        Args:
            quantizer: optional callable applied to each gradient tensor.
        """
        inputs, targets = self._next_batch()

        outputs = self.model(inputs)
        loss = self.criterion(outputs, targets)

        # Equivalent to optimizer.zero_grad(set_to_none=True): drop stale gradients.
        for param in self.model.parameters():
            param.grad = None

        loss.backward()

        gt = OrderedDict()
        for idx, param in enumerate(self.model.parameters()):
            # A parameter that took no part in the loss has no gradient; treat it as zero.
            grad = param.grad if param.grad is not None else torch.zeros_like(param)
            grad = grad.clone().detach()
            gt[idx] = quantizer(grad) if quantizer is not None else grad

        self.curr_gt = gt

        return

    def assign_params(self, W):
        """Load the (possibly partial) state dict ``W`` into the model (``strict=False``)."""
        with torch.no_grad():
            self.model.load_state_dict(W, strict=False)

        return

    def update_model(self):
        """Load the stored snapshot ``x_i`` into the model.

        TODO: apply the Choco-Gossip consensus update to ``x_i`` before assigning it.
        """
        self.assign_params(self.x_i)

        return

In [30]:
class Network():
    """Graph of ``Node`` workers trained by neighbor-averaged SGD.

    Args:
        W: square weight matrix indexed as ``W[i, j]``. A positive entry makes node
            ``j`` a neighbor of node ``i`` with mixing weight ``W[i, j]``; the
            diagonal is ignored.
        models, learning_rates, loaders: per-node sequences, aligned with ``W``'s rows.
        criterion: loss function shared by every node.
    """

    def __init__(self, W, models, learning_rates, loaders, criterion):
        self.adjacency = W
        self.num_nodes = W.shape[0]

        # Nodes are keyed by str(index).
        self.nodes = OrderedDict()
        for i in range(self.num_nodes):
            node = Node(learning_rates[i], loaders[i], models[i], criterion)
            for j in range(self.num_nodes):
                if j != i and W[i, j] > 0:
                    node.neighbors.append(j)
                    node.neighbor_wts[str(j)] = W[i, j]
            self.nodes[str(i)] = node

    def simulate(self, iterations, epochs, lr=None):
        """Run ``epochs * iterations`` synchronous rounds of neighbor-averaged SGD.

        In each round every node first computes its gradient on a fresh batch. Only
        then is each node's model stepped by the weighted average of its neighbors'
        gradients, so all updates within a round use gradients from the same state.

        Args:
            iterations: rounds per epoch.
            epochs: number of epochs.
            lr: optional global step size. When ``None``, each node uses its own
                ``step_size`` (the learning rate it was constructed with).
        """
        for _ in range(epochs):
            for _ in range(iterations):
                for k in range(self.num_nodes):
                    self.nodes[str(k)].compute_gradient()

                for l in range(self.num_nodes):
                    node = self.nodes[str(l)]
                    if not node.neighbors:
                        # Isolated node: nothing to average (would otherwise divide 0 by 0).
                        continue
                    step = node.step_size if lr is None else lr
                    wt_sum = sum(node.neighbor_wts[str(n)] for n in node.neighbors)
                    # NOTE(review): the node's own gradient is not part of this average
                    # (neighbors exclude j == i); confirm this matches the intended algorithm.
                    with torch.no_grad():
                        for m, param in enumerate(node.model.parameters()):
                            gt_update = sum(node.neighbor_wts[str(n)] * self.nodes[str(n)].curr_gt[m]
                                            for n in node.neighbors)
                            param -= step * (gt_update / wt_sum)
                        
                        

In [None]:
# Experiment configuration: number of workers, per-node learning rate, batch size.
NUM_NODES = 3
LEARNING_RATE = 1e-3
BATCH_SIZE = 4
torch.manual_seed(0)  # reproducible shuffling / initialization

# NOTE(review): the ImageNet-pretrained ResNet-18 has a 1000-way output head while
# CIFAR-10 has 10 classes; the loss still works, but consider replacing `model.fc`.
models = [torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True) for _ in range(NUM_NODES)]
criterion = nn.CrossEntropyLoss()
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
# One independently shuffled loader per node, all over the same training set.
trainloaders = [torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                            shuffle=True, num_workers=2) for _ in range(NUM_NODES)]
net = Network(torch.ones([NUM_NODES, NUM_NODES]), models, [LEARNING_RATE] * NUM_NODES,
              trainloaders, criterion)

Using cache found in /home/harshv834/.cache/torch/hub/pytorch_vision_v0.6.0
Using cache found in /home/harshv834/.cache/torch/hub/pytorch_vision_v0.6.0


In [None]:
net.simulate(iterations=1, epochs=1)

In [19]:
# Debug check: gradient and parameter shapes of node 0's model should match.
for param in net.nodes['0'].model.parameters():
    print(param.grad.shape)
    print(param.data.shape)
    print("yes")

torch.Size([64, 3, 7, 7])
torch.Size([64, 3, 7, 7])
yes
torch.Size([64])
torch.Size([64])
yes
torch.Size([64])
torch.Size([64])
yes
torch.Size([64, 64, 3, 3])
torch.Size([64, 64, 3, 3])
yes
torch.Size([64])
torch.Size([64])
yes
torch.Size([64])
torch.Size([64])
yes
torch.Size([64, 64, 3, 3])
torch.Size([64, 64, 3, 3])
yes
torch.Size([64])
torch.Size([64])
yes
torch.Size([64])
torch.Size([64])
yes
torch.Size([64, 64, 3, 3])
torch.Size([64, 64, 3, 3])
yes
torch.Size([64])
torch.Size([64])
yes
torch.Size([64])
torch.Size([64])
yes
torch.Size([64, 64, 3, 3])
torch.Size([64, 64, 3, 3])
yes
torch.Size([64])
torch.Size([64])
yes
torch.Size([64])
torch.Size([64])
yes
torch.Size([128, 64, 3, 3])
torch.Size([128, 64, 3, 3])
yes
torch.Size([128])
torch.Size([128])
yes
torch.Size([128])
torch.Size([128])
yes
torch.Size([128, 128, 3, 3])
torch.Size([128, 128, 3, 3])
yes
torch.Size([128])
torch.Size([128])
yes
torch.Size([128])
torch.Size([128])
yes
torch.Size([128, 64, 1, 1])
torch.Size([128, 64, 1