In [1]:
# from common.experiment import KubemlExperiment, History, TrainOptions, TrainRequest
import pandas as pd
import glob
import numpy as np
import pickle
import matplotlib.pyplot as plt
import os

import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torchvision import models
import torch.utils.data as tdata
from torch.nn.functional import nll_loss, cross_entropy
import redisai as rai
import redis

from torch import optim
from datetime import datetime

torch.manual_seed(42) 

<torch._C.Generator at 0x7f73ccb3f930>

In [2]:
torch.manual_seed(42) 

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
valset = datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(trainset, batch_size=256)
val_loader= torch.utils.data.DataLoader(valset, batch_size=256)

# model = models.resnet18(pretrained= False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(device)


Files already downloaded and verified
Files already downloaded and verified


In [3]:
def validate(model, device, val_loader: tdata.DataLoader) -> (float, float):
    """Loop used to validate the network"""
    model.to(device)

    criterion =nn.CrossEntropyLoss()
    model.eval()

    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            test_loss += cross_entropy(output, target).item()  # sum up batch loss
            correct += predicted.eq(target).sum().item()

    test_loss /= len(val_loader)

    accuracy = 100. * correct / len(val_loader.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(val_loader.dataset),
        100. * correct / len(val_loader.dataset)))
    return accuracy, test_loss

In [4]:
model = models.resnet18(pretrained= False)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

checkpoint = torch.load(f"checkpoint_resnet18.pt")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])


In [5]:
validate(model, device, val_loader)
print("*"*50)


Test set: Average loss: 1.0934, Accuracy: 6516/10000 (65%)
**************************************************


In [6]:
con = rai.Client(host="localhost", port=6379)
for key in model.state_dict().keys():

    con.tensorset(f'{key}',model.state_dict()[key].cpu().numpy(), dtype='float')

loaded_model = models.resnet18(pretrained= False)
validate(loaded_model, device, val_loader)

Test set: Average loss: 6.8922, Accuracy: 0/10000 (0%)


(0.0, 6.8922225594520565)

In [7]:

for key in model.state_dict().keys():
        #print(name)
    layer_weight = con.tensorget(f'{key}')
    layer_weight_copied = np.copy(layer_weight)
    #print(type(classes))
    #counter += 1
    #print(classes.shape)
    
    loaded_model.state_dict()[key].copy_(torch.from_numpy(layer_weight_copied))
#print("redisAI counter:", counter)
#print(loaded_model.state_dict()['fc.bias'])
print("After loading" +"*"*100)
validate(loaded_model, device, val_loader)


After loading****************************************************************************************************
Test set: Average loss: 1.0934, Accuracy: 6516/10000 (65%)


(65.16, 1.093371932208538)

In [8]:
def compare_models(model_1, model_2):
    models_differ = 0
    for key_item_1, key_item_2 in zip(model_1.state_dict().items(), model_2.state_dict().items()):
        if torch.equal(key_item_1[1], key_item_2[1]):
            pass
        else:
            models_differ += 1
            if (key_item_1[0] == key_item_2[0]):
                print('Mismtach found at', key_item_1[0])
            else:
                raise Exception
    if models_differ == 0:
        print('Models match perfectly! :)')


compare_models(model, loaded_model)       

Models match perfectly! :)


In [9]:
loaded_optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

print("*"*50)

for key in optimizer.state_dict().keys():

    print(optimizer.state_dict()[key])
    print(type(optimizer.state_dict()[key]))

print("*"*50)

for key in loaded_optimizer.state_dict().keys():
    print(key)
    print(loaded_optimizer.state_dict()[key])

print("*"*50 + "before loading" + "*"*50)


print("*"*50 + "print each state" + "*"*50)

print(loaded_optimizer.state_dict()['param_groups'])
print(optimizer.state_dict()['state'][0])

**************************************************
{0: {'momentum_buffer': tensor([[[[-2.3149e-05,  1.8056e-04,  1.8430e-04,  ...,  1.2511e-04,
            1.9244e-04,  3.9053e-04],
          [-1.5863e-04, -1.0674e-04, -7.4275e-05,  ..., -1.6859e-04,
           -1.1081e-04,  2.7883e-04],
          [ 1.0104e-04,  1.3414e-04,  1.8439e-04,  ...,  3.5686e-05,
            6.7209e-05,  1.9917e-04],
          ...,
          [ 4.2883e-04,  2.5965e-04,  1.4181e-04,  ...,  1.5144e-04,
           -5.4976e-05, -1.0588e-05],
          [ 1.7816e-04,  3.3678e-06,  3.9705e-05,  ..., -7.7816e-05,
           -1.8988e-04,  3.8139e-05],
          [-1.7374e-04, -2.2337e-04, -6.5139e-05,  ..., -1.2895e-04,
           -1.6209e-04, -1.7993e-04]],

         [[ 6.7055e-05,  2.9954e-04,  1.9042e-04,  ...,  1.6498e-04,
            2.1844e-04,  3.7933e-04],
          [-1.5860e-06,  1.0806e-04,  7.9105e-05,  ...,  1.5599e-05,
           -7.4225e-05,  3.1515e-04],
          [ 2.7486e-04,  3.6979e-04,  4.6258e-04,  .

In [10]:
for key in loaded_optimizer.state_dict()['param_groups'][0].keys():
    print(key)
print(loaded_optimizer.state_dict()['param_groups'][0]['params'])


lr
momentum
dampening
weight_decay
nesterov
params
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]


In [11]:
#print(optimizer.state_dict()['state'][0])

for key in optimizer.state_dict()['state'][0].keys():
    print(key)
    print(type(key))



momentum_buffer
<class 'str'>


In [12]:


print(optimizer.state_dict()['param_groups'])
print(len(optimizer.state_dict()['param_groups']))


print(type(optimizer.state_dict()['state']))

print(optimizer.state_dict()['state'])

[{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0.0001, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]}]
1
<class 'dict'>
{0: {'momentum_buffer': tensor([[[[-2.3149e-05,  1.8056e-04,  1.8430e-04,  ...,  1.2511e-04,
            1.9244e-04,  3.9053e-04],
          [-1.5863e-04, -1.0674e-04, -7.4275e-05,  ..., -1.6859e-04,
           -1.1081e-04,  2.7883e-04],
          [ 1.0104e-04,  1.3414e-04,  1.8439e-04,  ...,  3.5686e-05,
            6.7209e-05,  1.9917e-04],
          ...,
          [ 4.2883e-04,  2.5965e-04,  1.4181e-04,  ...,  1.5144e-04,
           -5.4976e-05, -1.0588e-05],
          [ 1.7816e-04,  3.3678e-06,  3.9705e-05,  ..., -7.7816e-05,
           -1.8988e-04,  3.8139e-05],
          [-1.7374e-04, -2.2337e-04, -6.5139e-05,  ..., -1.2895e-04,


In [13]:

for i, param_group in enumerate(optimizer.state_dict()['param_groups']):
    print(param_group)
    for key in param_group.keys():
        print(key)
        # save optimizer state
        if key == 'params':
            # parameter number in state list
            for param_num in param_group[key]:
                print('param_num: ', param_num)
                # real parameter
                for param_key in optimizer.state_dict()['state'][param_num].keys():
                    print('param_key: ', param_key)
                    #print(optimizer.state_dict()['state'][i][param_key])
                    con.tensorset(f'{param_num},{param_key}',optimizer.state_dict()['state'][param_num][param_key].cpu().numpy(), dtype='float')

        # TODO: not sure how to save those training hypter-parameters which are not tensor type
        # save them to mongoDB?
        # else:
        #     print('others: ', key)
        #     con.tensorset(f'{key}',param_group[key], dtype='float')


{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0.0001, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]}
lr
momentum
dampening
weight_decay
nesterov
params
param_num:  0
param_key:  momentum_buffer
param_num:  1
param_key:  momentum_buffer
param_num:  2
param_key:  momentum_buffer
param_num:  3
param_key:  momentum_buffer
param_num:  4
param_key:  momentum_buffer
param_num:  5
param_key:  momentum_buffer
param_num:  6
param_key:  momentum_buffer
param_num:  7
param_key:  momentum_buffer
param_num:  8
param_key:  momentum_buffer
param_num:  9
param_key:  momentum_buffer
param_num:  10
param_key:  momentum_buffer
param_num:  11
param_key:  momentum_buffer
param_num:  12
param_key:  momentum_buffer
param_num:  13
param_key:  momentum_buffer
param_num:  14
para

In [14]:
# testcase 1
for i in range(62):
    tensor1 = optimizer.state_dict()['state'][i]['momentum_buffer']
    tensor2 = torch.from_numpy(con.tensorget(f'{i},momentum_buffer'))

    print(torch.sum(tensor1 - tensor2))

tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)


  after removing the cwd from sys.path.


In [15]:
print(dir(loaded_optimizer.state_dict()['state']))

['__class__', '__contains__', '__delattr__', '__delitem__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setitem__', '__sizeof__', '__str__', '__subclasshook__', 'clear', 'copy', 'fromkeys', 'get', 'items', 'keys', 'pop', 'popitem', 'setdefault', 'update', 'values']


In [42]:
print(optimizer.state_dict()['state'][1]['momentum_buffer'].shape)

print(optimizer.state_dict()['state'][1]['momentum_buffer'])

print(dir(optimizer.state_dict()['state']))

print(dir(loaded_optimizer.state_dict()['state']))

# you cannot udpate if the key is not in the dict
loaded_optimizer.state_dict()['state'].update({'1':{'momentum_buffer':torch.zeros([64,], dtype=torch.float32)}})
print(loaded_optimizer.state_dict()['state'])

loaded_optimizer.state_dict()['state'].fromkeys({'1':{'momentum_buffer':torch.zeros([64,], dtype=torch.float32)}})



print(loaded_optimizer.state_dict()['state'])

# you can change it if there if the 
optimizer.state_dict()['state'][1]['momentum_buffer'] = torch.zeros([64,], dtype=torch.float32)
print(optimizer.state_dict()['state'][1]['momentum_buffer'])


torch.Size([64])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
['__class__', '__contains__', '__delattr__', '__delitem__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setitem__', '__sizeof__', '__str__', '__subclasshook__', 'clear', 'copy', 'fromkeys', 'get', 'items', 'keys', 'pop', 'popitem', 'setdefault', 'update', 'values']
['__class__', '__contains__', '__delattr__', '__delitem__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__

In [17]:
# copy back to the new optimizer

for i, param_group in enumerate(loaded_optimizer.state_dict()['param_groups']):
    print(param_group)
    for key in param_group.keys():
        print(key)
        # save optimizer state
        if key == 'params':
            # parameter number in state list
            for param_num in param_group[key]:
                print('param_num: ', param_num)
                # TODO: this implementation only works for SDG with momentum
                param_key = 'momentum_buffer'
                #print(optimizer.state_dict()['state'][i][param_key])
                state_weight = con.tensorget(f'{param_num},{param_key}')
                new_state = {param_key:torch.from_numpy(state_weight)}
                #print({param_num:new_state})
                loaded_optimizer.state_dict()['state'][param_num] = new_state
        # TODO: not sure how to save those training hypter-parameters which are not tensor type
        # save them to mongoDB?
        # else:
        #     print('others: ', key)
        #     con.tensorset(f'{key}',param_group[key], dtype='float')




{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0.0001, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]}
lr
momentum
dampening
weight_decay
nesterov
params
param_num:  0
param_num:  1
param_num:  2
param_num:  3
param_num:  4
param_num:  5
param_num:  6
param_num:  7
param_num:  8
param_num:  9
param_num:  10
param_num:  11
param_num:  12
param_num:  13
param_num:  14
param_num:  15
param_num:  16
param_num:  17
param_num:  18
param_num:  19
param_num:  20
param_num:  21
param_num:  22
param_num:  23
param_num:  24
param_num:  25
param_num:  26
param_num:  27
param_num:  28
param_num:  29
param_num:  30
param_num:  31
param_num:  32
param_num:  33
param_num:  34
param_num:  35
param_num:  36
param_num:  37
param_num:  38
param_num:  39
param_num:  40
param_

In [18]:
print(loaded_optimizer.state_dict()['state'])

{}


In [19]:
for group in loaded_optimizer.param_groups:
    print(dir(group['params']))
    print(type(group['params']))
    print(len(group['params']))
    for i, p in enumerate(group['params']):
        state_weight = con.tensorget(f'{i},momentum_buffer')
        state_weight_tensor = torch.from_numpy(state_weight)
        p.cstate_weight_tensor


['__add__', '__class__', '__contains__', '__delattr__', '__delitem__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__iadd__', '__imul__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__mul__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__rmul__', '__setattr__', '__setitem__', '__sizeof__', '__str__', '__subclasshook__', 'append', 'clear', 'copy', 'count', 'extend', 'index', 'insert', 'pop', 'remove', 'reverse', 'sort']
<class 'list'>
62


AttributeError: 'Parameter' object has no attribute 'cstate_weight_tensor'

In [20]:
print(loaded_optimizer.state_dict())

{'state': {}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0.0001, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]}]}


In [23]:
old_state_dict = optimizer.state_dict()
new_state_dict = loaded_optimizer.state_dict()
# the old state dict will have references to the old parameters, in state_dict['param_groups'][xyz]['params'] and in state_dict['state']
# you now need to find the parameter mismatches between the old and new statedicts
# if your optimizer has multiple param groups, you need to loop over them, too (I use xyz as a placeholder here. mostly, you'll only have 1 anyways, so just replace xyz with 0
new_pars = [p for p in new_state_dict['param_groups'][0]['params'] if not p in old_state_dict['param_groups'][0]['params']]
old_pars = [p for p in old_state_dict['param_groups'][0]['params'] if not p in new_state_dict['param_groups'][0]['params']]
# then you remove all the outdated ones from the state dict
for pid in old_pars:
    old_state_dict['state'].pop(pid)
# and add a new state for each new parameter to the state:
for pid in new_pars:
    old_state_dict['param_groups'][0]['params'].append(pid)
    old_state_dict['state'][pid] = { ... }  # your new state def here, depending on your optimizer



In [25]:
print(old_state_dict)

{'state': {0: {'momentum_buffer': tensor([[[[-2.3149e-05,  1.8056e-04,  1.8430e-04,  ...,  1.2511e-04,
            1.9244e-04,  3.9053e-04],
          [-1.5863e-04, -1.0674e-04, -7.4275e-05,  ..., -1.6859e-04,
           -1.1081e-04,  2.7883e-04],
          [ 1.0104e-04,  1.3414e-04,  1.8439e-04,  ...,  3.5686e-05,
            6.7209e-05,  1.9917e-04],
          ...,
          [ 4.2883e-04,  2.5965e-04,  1.4181e-04,  ...,  1.5144e-04,
           -5.4976e-05, -1.0588e-05],
          [ 1.7816e-04,  3.3678e-06,  3.9705e-05,  ..., -7.7816e-05,
           -1.8988e-04,  3.8139e-05],
          [-1.7374e-04, -2.2337e-04, -6.5139e-05,  ..., -1.2895e-04,
           -1.6209e-04, -1.7993e-04]],

         [[ 6.7055e-05,  2.9954e-04,  1.9042e-04,  ...,  1.6498e-04,
            2.1844e-04,  3.7933e-04],
          [-1.5860e-06,  1.0806e-04,  7.9105e-05,  ...,  1.5599e-05,
           -7.4225e-05,  3.1515e-04],
          [ 2.7486e-04,  3.6979e-04,  4.6258e-04,  ...,  1.6331e-04,
            1.5774e-04, 

In [24]:
print(new_state_dict)

{'state': {}, 'param_groups': [{'lr': 0.1, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0.0001, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]}]}
