In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# (Re)load the autoreload extension and set mode 2 so that imported modules
# are re-imported before each cell runs; edits to local .py files take effect
# without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
import argparse
import os
import shutil
import time

from fastai.transforms import *
from fastai.dataset import *
from fastai.fp16 import *
from fastai.conv_learner import *
from pathlib import *
from fastai import io
import tarfile

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import models
import models.cifar10 as cifar10models
from distributed import DistributedDataParallel as DDP

from datetime import datetime

# print(models.cifar10.__dict__)
# Gather the architecture names exposed by each package: every lowercase,
# non-dunder attribute whose value is callable.
model_names = sorted(
    key for key, obj in vars(models).items()
    if key.islower() and not key.startswith("__") and callable(obj)
)

cifar10_names = sorted(
    key for key, obj in vars(cifar10models).items()
    if key.islower() and not key.startswith("__") and callable(obj)
)

# Combined list: CIFAR-10 names first, followed by the general ones.
model_names = cifar10_names + model_names

In [3]:
# Name of the architecture to build; used as the lookup key in the model
# packages when the network is instantiated.
arch = 'wrn_22'

In [7]:
# Resolve `arch` in the CIFAR-10 package when it is listed there, otherwise in
# the general `models` package, and instantiate it with default arguments.
model = vars(cifar10models if arch in cifar10_names else models)[arch]()
model

WideResNet(
  (features): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): NetworkBlock(
      (layer): Sequential(
        (0): BasicBlock(
          (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True)
          (relu1): ReLU(inplace)
          (conv1): Conv2d(16, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True)
          (relu2): ReLU(inplace)
          (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (convShortcut): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (1): BasicBlock(
          (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True)
          (relu1): ReLU(inplace)
          (conv1): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True)
    

In [9]:
# Build a second network: resolve `arch2` in the CIFAR-10 package when it is
# listed there, otherwise in `models`, and instantiate it with defaults.
arch2 = 'wrn_28'
model2 = vars(cifar10models if arch2 in cifar10_names else models)[arch2]()
model2

WideResNet(
  (features): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): NetworkBlock(
      (layer): Sequential(
        (0): BasicBlock(
          (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True)
          (relu1): ReLU(inplace)
          (conv1): Conv2d(16, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True)
          (relu2): ReLU(inplace)
          (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (convShortcut): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (1): BasicBlock(
          (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True)
          (relu1): ReLU(inplace)
          (conv1): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True)
    

In [11]:
# Inspect the mapping of parameter names (e.g. 'features.0.weight') to tensors.
# NOTE(review): this prints every tensor in full; listing only the keys would
# keep the output readable.
model.state_dict()

OrderedDict([('features.0.weight', 
              (0 ,0 ,.,.) = 
                0.0749  0.1061  0.0658
               -0.1342  0.0571 -0.1334
                0.0447  0.2141 -0.0325
              
              (0 ,1 ,.,.) = 
               -0.0634  0.0537 -0.2533
               -0.1042  0.0671 -0.0361
               -0.1395  0.0290 -0.2526
              
              (0 ,2 ,.,.) = 
                0.0848 -0.1804 -0.1632
               -0.0194  0.0951  0.1279
               -0.0144 -0.0961 -0.0811
              
              (1 ,0 ,.,.) = 
                0.0160  0.0030 -0.1306
               -0.2165 -0.1735 -0.0424
                0.1700  0.2723 -0.0645
              
              (1 ,1 ,.,.) = 
                0.2822  0.1359 -0.0898
                0.0735 -0.0578  0.0594
                0.0510 -0.0787 -0.1883
              
              (1 ,2 ,.,.) = 
                0.2040 -0.0097 -0.0221
                0.0342 -0.0403  0.1804
               -0.1329  0.0462 -0.0925
             

In [10]:
# `parameters()` returns a generator, so only the generator object is shown
# here; wrap it in list(...) to see the actual parameters.
model.parameters()

<generator object Module.parameters at 0x7f76cefc9200>

In [15]:
# Weights of the first model (`model`), used as the source for the transfer.
state_dict = model.state_dict()

In [29]:
# Weights of the second model (`model2`); this mapping is the transfer target.
state_dict2 = model2.state_dict()

In [30]:
# Copy every entry of `state_dict` into `state_dict2` in place: keys present in
# both take the value from `state_dict`, keys only in `state_dict2` are kept.
# NOTE(review): assumes tensors under shared keys have matching shapes and that
# `state_dict` holds no keys unknown to `model2` — TODO confirm, otherwise the
# subsequent load_state_dict call is expected to fail.
state_dict2.update(state_dict)

In [31]:
# Load the merged mapping back into `model2`.
model2.load_state_dict(state_dict2)

In [23]:
# Toy mapping in which every key carries the same value (8); it serves as the
# "source" side of the dict.update experiment.
d = dict.fromkeys(('banana', 'apple', 'pear', 'orange'), 8)

In [24]:
# Ordered copy of `d` with its entries arranged alphabetically by key.
od = OrderedDict((key, d[key]) for key in sorted(d))

In [25]:
# Toy "target" mapping: shares four keys with `d` and adds one extra key
# ('cucumber') that `d` does not have.
d2 = dict([
    ('banana', 3),
    ('apple', 4),
    ('cucumber', 5),
    ('pear', 1),
    ('orange', 2),
])

In [26]:
# Ordered copy of `d2` with its entries arranged alphabetically by key.
od2 = OrderedDict((key, d2[key]) for key in sorted(d2))

In [27]:
# Merge `od` into `od2` in place: shared keys take `od`'s values, while keys
# that exist only in `od2` ('cucumber') keep their value and position.
od2.update(od)

In [28]:
# Display the merged mapping to confirm which values were overwritten.
od2

OrderedDict([('apple', 8),
             ('banana', 8),
             ('cucumber', 5),
             ('orange', 8),
             ('pear', 8)])