In [2]:
%load_ext autoreload
%autoreload 2
from biotorch.benchmark.run import Benchmark

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [1]:
from biotorch.models.weight_mirroring.resnet import resnet18, resnet34, resnet50, wide_resnet101_2

In [20]:
from collections import deque, OrderedDict, defaultdict
def compute_angles_module(module):
    queue = deque()
    layers_alignment = OrderedDict()
    seen_keys = defaultdict(lambda: 0)
    
    # First pass to store module keys
    if len(module._modules.keys()) > 0:
        for module_keys in module._modules.keys():
            queue.append((module, module_keys))
            
    # If we receive only a layer
    else:
        layer = module
        if 'alignment' in layer.__dict__:
            angle = layer.compute_alignment()
            key_name = 'layer'
            layers_alignment[key_name] = angle
            
    # Approximate depth first traversal of the model using a deque
    while len(queue) > 0:
        module, module_key = queue.popleft()
        layer = getattr(module, module_key)
        if 'alignment' in layer.__dict__:
            angle = layer.compute_alignment()
            key_name = module_key + '_' + str(seen_keys[module_key])
            seen_keys[module_key] += 1 
            layers_alignment[key_name] = angle
        if len(layer._modules.keys()) > 0:
            # Reverse list as we are appending from the left side of the queue
            for key in list(layer._modules.keys())[::-1]:
                queue.appendleft((layer, key))
                
    return layers_alignment

In [21]:
model = resnet18(pretrained=False, num_classes=10)

In [22]:
import torch
x = torch.randn((10, 3, 224, 224))

In [23]:
linear_layer

Linear(in_features=512, out_features=10, bias=True)

In [31]:
linear_layer = model.fc
noise_amplitude = 0.1
for i in range(10000):
    with torch.no_grad():
        input_noise = noise_amplitude * torch.randn(12, 512)
        output_noise = linear_layer(input_noise)
        linear_layer.update_B(input_noise, 
                            output_noise, 
                            mirror_learning_rate=0.05, 
                            growth_control=True,
                            damping_factor=0.5)
        layers_alignment = compute_angles_module(linear_layer)
        print(layers_alignment)

OrderedDict([('layer', tensor(90.0252))])
OrderedDict([('layer', tensor(90.1967))])
OrderedDict([('layer', tensor(90.6656))])
OrderedDict([('layer', tensor(90.7723))])
OrderedDict([('layer', tensor(90.6465))])
OrderedDict([('layer', tensor(90.5991))])
OrderedDict([('layer', tensor(90.6646))])
OrderedDict([('layer', tensor(90.5998))])
OrderedDict([('layer', tensor(89.3459))])
OrderedDict([('layer', tensor(88.4419))])
OrderedDict([('layer', tensor(88.9015))])
OrderedDict([('layer', tensor(89.3344))])
OrderedDict([('layer', tensor(89.8514))])
OrderedDict([('layer', tensor(89.9274))])
OrderedDict([('layer', tensor(89.6659))])
OrderedDict([('layer', tensor(89.8931))])
OrderedDict([('layer', tensor(90.4793))])
OrderedDict([('layer', tensor(90.4039))])
OrderedDict([('layer', tensor(90.7702))])
OrderedDict([('layer', tensor(90.2368))])
OrderedDict([('layer', tensor(88.4143))])
OrderedDict([('layer', tensor(88.0673))])
OrderedDict([('layer', tensor(88.5788))])
OrderedDict([('layer', tensor(88.6

OrderedDict([('layer', tensor(90.0761))])
OrderedDict([('layer', tensor(89.9459))])
OrderedDict([('layer', tensor(89.7034))])
OrderedDict([('layer', tensor(89.1671))])
OrderedDict([('layer', tensor(88.8362))])
OrderedDict([('layer', tensor(88.9848))])
OrderedDict([('layer', tensor(89.8012))])
OrderedDict([('layer', tensor(90.0098))])
OrderedDict([('layer', tensor(90.4292))])
OrderedDict([('layer', tensor(90.8434))])
OrderedDict([('layer', tensor(90.6540))])
OrderedDict([('layer', tensor(89.9804))])
OrderedDict([('layer', tensor(89.6063))])
OrderedDict([('layer', tensor(89.5407))])
OrderedDict([('layer', tensor(89.2250))])
OrderedDict([('layer', tensor(89.8629))])
OrderedDict([('layer', tensor(89.7885))])
OrderedDict([('layer', tensor(89.9344))])
OrderedDict([('layer', tensor(90.7600))])
OrderedDict([('layer', tensor(91.1535))])
OrderedDict([('layer', tensor(90.0565))])
OrderedDict([('layer', tensor(89.5653))])
OrderedDict([('layer', tensor(89.8453))])
OrderedDict([('layer', tensor(89.4

OrderedDict([('layer', tensor(88.3574))])
OrderedDict([('layer', tensor(89.1921))])
OrderedDict([('layer', tensor(88.4807))])
OrderedDict([('layer', tensor(88.9864))])
OrderedDict([('layer', tensor(89.7577))])
OrderedDict([('layer', tensor(89.6129))])
OrderedDict([('layer', tensor(89.9290))])
OrderedDict([('layer', tensor(90.1770))])
OrderedDict([('layer', tensor(90.8983))])
OrderedDict([('layer', tensor(90.5347))])
OrderedDict([('layer', tensor(90.0085))])
OrderedDict([('layer', tensor(90.5835))])
OrderedDict([('layer', tensor(90.7747))])
OrderedDict([('layer', tensor(91.4132))])
OrderedDict([('layer', tensor(91.1614))])
OrderedDict([('layer', tensor(90.0600))])
OrderedDict([('layer', tensor(90.7419))])
OrderedDict([('layer', tensor(90.6156))])
OrderedDict([('layer', tensor(89.9374))])
OrderedDict([('layer', tensor(89.3798))])
OrderedDict([('layer', tensor(89.9042))])
OrderedDict([('layer', tensor(89.8622))])
OrderedDict([('layer', tensor(89.8682))])
OrderedDict([('layer', tensor(90.1

OrderedDict([('layer', tensor(92.0919))])
OrderedDict([('layer', tensor(90.9947))])
OrderedDict([('layer', tensor(90.9194))])
OrderedDict([('layer', tensor(90.9375))])
OrderedDict([('layer', tensor(90.8883))])
OrderedDict([('layer', tensor(90.2808))])
OrderedDict([('layer', tensor(90.0282))])
OrderedDict([('layer', tensor(90.5517))])
OrderedDict([('layer', tensor(89.3473))])
OrderedDict([('layer', tensor(89.4855))])
OrderedDict([('layer', tensor(90.1172))])
OrderedDict([('layer', tensor(90.4107))])
OrderedDict([('layer', tensor(90.2611))])
OrderedDict([('layer', tensor(89.5221))])
OrderedDict([('layer', tensor(89.6112))])
OrderedDict([('layer', tensor(89.9765))])
OrderedDict([('layer', tensor(90.3678))])
OrderedDict([('layer', tensor(90.5639))])
OrderedDict([('layer', tensor(90.2553))])
OrderedDict([('layer', tensor(90.0946))])
OrderedDict([('layer', tensor(89.6773))])
OrderedDict([('layer', tensor(88.5559))])
OrderedDict([('layer', tensor(89.3148))])
OrderedDict([('layer', tensor(90.0

OrderedDict([('layer', tensor(89.4073))])
OrderedDict([('layer', tensor(89.9771))])
OrderedDict([('layer', tensor(89.5414))])
OrderedDict([('layer', tensor(89.6360))])
OrderedDict([('layer', tensor(88.7384))])
OrderedDict([('layer', tensor(89.8032))])
OrderedDict([('layer', tensor(88.9159))])
OrderedDict([('layer', tensor(88.7402))])
OrderedDict([('layer', tensor(89.9980))])
OrderedDict([('layer', tensor(90.1617))])
OrderedDict([('layer', tensor(89.9295))])
OrderedDict([('layer', tensor(89.8037))])
OrderedDict([('layer', tensor(90.4281))])
OrderedDict([('layer', tensor(89.5175))])
OrderedDict([('layer', tensor(89.1238))])
OrderedDict([('layer', tensor(89.6762))])
OrderedDict([('layer', tensor(88.9873))])
OrderedDict([('layer', tensor(89.0909))])
OrderedDict([('layer', tensor(89.8335))])
OrderedDict([('layer', tensor(90.7508))])
OrderedDict([('layer', tensor(90.3884))])
OrderedDict([('layer', tensor(91.4009))])
OrderedDict([('layer', tensor(91.6641))])
OrderedDict([('layer', tensor(90.9

OrderedDict([('layer', tensor(90.8162))])
OrderedDict([('layer', tensor(91.6302))])
OrderedDict([('layer', tensor(91.5225))])
OrderedDict([('layer', tensor(90.9272))])
OrderedDict([('layer', tensor(90.2304))])
OrderedDict([('layer', tensor(89.8893))])
OrderedDict([('layer', tensor(89.2362))])
OrderedDict([('layer', tensor(89.1193))])
OrderedDict([('layer', tensor(90.0881))])
OrderedDict([('layer', tensor(90.5315))])
OrderedDict([('layer', tensor(91.2071))])
OrderedDict([('layer', tensor(90.9332))])
OrderedDict([('layer', tensor(90.9385))])
OrderedDict([('layer', tensor(91.6424))])
OrderedDict([('layer', tensor(90.9140))])
OrderedDict([('layer', tensor(91.5919))])
OrderedDict([('layer', tensor(91.5248))])
OrderedDict([('layer', tensor(90.7938))])
OrderedDict([('layer', tensor(89.5245))])
OrderedDict([('layer', tensor(90.2209))])
OrderedDict([('layer', tensor(90.6645))])
OrderedDict([('layer', tensor(90.5598))])
OrderedDict([('layer', tensor(90.4764))])
OrderedDict([('layer', tensor(90.0

OrderedDict([('layer', tensor(90.6141))])
OrderedDict([('layer', tensor(89.3200))])
OrderedDict([('layer', tensor(89.5414))])
OrderedDict([('layer', tensor(89.4807))])
OrderedDict([('layer', tensor(90.9275))])
OrderedDict([('layer', tensor(90.8294))])
OrderedDict([('layer', tensor(91.4469))])
OrderedDict([('layer', tensor(92.0056))])
OrderedDict([('layer', tensor(92.0220))])
OrderedDict([('layer', tensor(91.5721))])
OrderedDict([('layer', tensor(90.3044))])
OrderedDict([('layer', tensor(90.0080))])
OrderedDict([('layer', tensor(89.3812))])
OrderedDict([('layer', tensor(88.9646))])
OrderedDict([('layer', tensor(89.2429))])
OrderedDict([('layer', tensor(89.1464))])
OrderedDict([('layer', tensor(88.4207))])
OrderedDict([('layer', tensor(89.5680))])
OrderedDict([('layer', tensor(89.6833))])
OrderedDict([('layer', tensor(90.0247))])
OrderedDict([('layer', tensor(90.1094))])
OrderedDict([('layer', tensor(90.6363))])
OrderedDict([('layer', tensor(90.5810))])
OrderedDict([('layer', tensor(90.4

OrderedDict([('layer', tensor(90.4889))])
OrderedDict([('layer', tensor(90.3436))])
OrderedDict([('layer', tensor(90.0866))])
OrderedDict([('layer', tensor(90.0000))])
OrderedDict([('layer', tensor(89.9579))])
OrderedDict([('layer', tensor(90.1211))])
OrderedDict([('layer', tensor(90.7682))])
OrderedDict([('layer', tensor(89.4381))])
OrderedDict([('layer', tensor(89.5040))])
OrderedDict([('layer', tensor(89.9105))])
OrderedDict([('layer', tensor(90.2607))])
OrderedDict([('layer', tensor(90.2312))])
OrderedDict([('layer', tensor(90.4650))])
OrderedDict([('layer', tensor(90.1432))])
OrderedDict([('layer', tensor(89.8240))])
OrderedDict([('layer', tensor(90.3723))])
OrderedDict([('layer', tensor(89.4553))])
OrderedDict([('layer', tensor(90.3469))])
OrderedDict([('layer', tensor(90.3557))])
OrderedDict([('layer', tensor(90.4353))])
OrderedDict([('layer', tensor(89.9037))])
OrderedDict([('layer', tensor(89.4504))])
OrderedDict([('layer', tensor(90.5992))])
OrderedDict([('layer', tensor(90.0

OrderedDict([('layer', tensor(90.6419))])
OrderedDict([('layer', tensor(89.6769))])
OrderedDict([('layer', tensor(90.1602))])
OrderedDict([('layer', tensor(89.7860))])
OrderedDict([('layer', tensor(89.9751))])
OrderedDict([('layer', tensor(89.1108))])
OrderedDict([('layer', tensor(89.8834))])
OrderedDict([('layer', tensor(90.7995))])
OrderedDict([('layer', tensor(90.2306))])
OrderedDict([('layer', tensor(90.4241))])
OrderedDict([('layer', tensor(90.3097))])
OrderedDict([('layer', tensor(90.8059))])
OrderedDict([('layer', tensor(90.1644))])
OrderedDict([('layer', tensor(89.7008))])
OrderedDict([('layer', tensor(89.9053))])
OrderedDict([('layer', tensor(89.3193))])
OrderedDict([('layer', tensor(89.4614))])
OrderedDict([('layer', tensor(88.7710))])
OrderedDict([('layer', tensor(88.6937))])
OrderedDict([('layer', tensor(88.8133))])
OrderedDict([('layer', tensor(88.2460))])
OrderedDict([('layer', tensor(88.3523))])
OrderedDict([('layer', tensor(88.5824))])
OrderedDict([('layer', tensor(88.8

OrderedDict([('layer', tensor(90.5277))])
OrderedDict([('layer', tensor(90.5961))])
OrderedDict([('layer', tensor(91.0140))])
OrderedDict([('layer', tensor(89.4511))])
OrderedDict([('layer', tensor(89.1916))])
OrderedDict([('layer', tensor(89.3789))])
OrderedDict([('layer', tensor(89.2388))])
OrderedDict([('layer', tensor(89.6153))])
OrderedDict([('layer', tensor(90.4516))])
OrderedDict([('layer', tensor(89.9689))])
OrderedDict([('layer', tensor(90.0702))])
OrderedDict([('layer', tensor(90.2225))])
OrderedDict([('layer', tensor(89.4641))])
OrderedDict([('layer', tensor(90.2642))])
OrderedDict([('layer', tensor(90.7799))])
OrderedDict([('layer', tensor(91.7643))])
OrderedDict([('layer', tensor(91.5999))])
OrderedDict([('layer', tensor(90.1706))])
OrderedDict([('layer', tensor(90.6995))])
OrderedDict([('layer', tensor(91.2131))])
OrderedDict([('layer', tensor(90.5444))])
OrderedDict([('layer', tensor(89.9653))])
OrderedDict([('layer', tensor(90.0741))])
OrderedDict([('layer', tensor(90.2

OrderedDict([('layer', tensor(89.5895))])
OrderedDict([('layer', tensor(90.2064))])
OrderedDict([('layer', tensor(90.2451))])
OrderedDict([('layer', tensor(90.0864))])
OrderedDict([('layer', tensor(89.8254))])
OrderedDict([('layer', tensor(89.9455))])
OrderedDict([('layer', tensor(89.8347))])
OrderedDict([('layer', tensor(89.9979))])
OrderedDict([('layer', tensor(90.3744))])
OrderedDict([('layer', tensor(90.3670))])
OrderedDict([('layer', tensor(90.8507))])
OrderedDict([('layer', tensor(90.6374))])
OrderedDict([('layer', tensor(90.6898))])
OrderedDict([('layer', tensor(90.8649))])
OrderedDict([('layer', tensor(90.9463))])
OrderedDict([('layer', tensor(90.9314))])
OrderedDict([('layer', tensor(90.0492))])
OrderedDict([('layer', tensor(88.9099))])
OrderedDict([('layer', tensor(88.9272))])
OrderedDict([('layer', tensor(89.0232))])
OrderedDict([('layer', tensor(89.5197))])
OrderedDict([('layer', tensor(88.5587))])
OrderedDict([('layer', tensor(88.5726))])
OrderedDict([('layer', tensor(88.8

OrderedDict([('layer', tensor(90.9080))])
OrderedDict([('layer', tensor(90.7176))])
OrderedDict([('layer', tensor(90.5315))])
OrderedDict([('layer', tensor(90.5966))])
OrderedDict([('layer', tensor(90.6876))])
OrderedDict([('layer', tensor(90.8333))])
OrderedDict([('layer', tensor(91.0098))])
OrderedDict([('layer', tensor(90.5745))])
OrderedDict([('layer', tensor(90.8014))])
OrderedDict([('layer', tensor(90.3123))])
OrderedDict([('layer', tensor(89.9579))])
OrderedDict([('layer', tensor(89.4600))])
OrderedDict([('layer', tensor(90.3016))])
OrderedDict([('layer', tensor(89.7575))])
OrderedDict([('layer', tensor(89.2739))])
OrderedDict([('layer', tensor(88.5725))])
OrderedDict([('layer', tensor(89.0300))])
OrderedDict([('layer', tensor(88.9225))])
OrderedDict([('layer', tensor(88.5362))])
OrderedDict([('layer', tensor(89.6188))])
OrderedDict([('layer', tensor(89.7527))])
OrderedDict([('layer', tensor(88.8079))])
OrderedDict([('layer', tensor(88.9076))])
OrderedDict([('layer', tensor(88.6

OrderedDict([('layer', tensor(89.7602))])
OrderedDict([('layer', tensor(89.6391))])
OrderedDict([('layer', tensor(90.1671))])
OrderedDict([('layer', tensor(90.4121))])
OrderedDict([('layer', tensor(89.9101))])
OrderedDict([('layer', tensor(90.0836))])
OrderedDict([('layer', tensor(89.8040))])
OrderedDict([('layer', tensor(89.0289))])
OrderedDict([('layer', tensor(88.8561))])
OrderedDict([('layer', tensor(89.0918))])
OrderedDict([('layer', tensor(88.9674))])
OrderedDict([('layer', tensor(88.9153))])
OrderedDict([('layer', tensor(88.7134))])
OrderedDict([('layer', tensor(88.7704))])
OrderedDict([('layer', tensor(89.1044))])
OrderedDict([('layer', tensor(89.8005))])
OrderedDict([('layer', tensor(89.7944))])
OrderedDict([('layer', tensor(90.2712))])
OrderedDict([('layer', tensor(90.2545))])
OrderedDict([('layer', tensor(89.3881))])
OrderedDict([('layer', tensor(88.8826))])
OrderedDict([('layer', tensor(89.1615))])
OrderedDict([('layer', tensor(88.7452))])
OrderedDict([('layer', tensor(89.1

OrderedDict([('layer', tensor(88.8443))])
OrderedDict([('layer', tensor(88.0233))])
OrderedDict([('layer', tensor(89.0023))])
OrderedDict([('layer', tensor(89.1826))])
OrderedDict([('layer', tensor(89.1468))])
OrderedDict([('layer', tensor(89.4588))])
OrderedDict([('layer', tensor(89.6650))])
OrderedDict([('layer', tensor(90.4559))])
OrderedDict([('layer', tensor(90.0572))])
OrderedDict([('layer', tensor(90.1607))])
OrderedDict([('layer', tensor(90.1971))])
OrderedDict([('layer', tensor(90.5276))])
OrderedDict([('layer', tensor(91.3225))])
OrderedDict([('layer', tensor(90.9369))])
OrderedDict([('layer', tensor(90.4553))])
OrderedDict([('layer', tensor(89.3417))])
OrderedDict([('layer', tensor(88.7996))])
OrderedDict([('layer', tensor(88.6667))])
OrderedDict([('layer', tensor(89.7644))])
OrderedDict([('layer', tensor(89.0934))])
OrderedDict([('layer', tensor(90.0007))])
OrderedDict([('layer', tensor(89.7648))])
OrderedDict([('layer', tensor(89.6726))])
OrderedDict([('layer', tensor(89.2

OrderedDict([('layer', tensor(88.8537))])
OrderedDict([('layer', tensor(88.8725))])
OrderedDict([('layer', tensor(89.0280))])
OrderedDict([('layer', tensor(89.6461))])
OrderedDict([('layer', tensor(89.2701))])
OrderedDict([('layer', tensor(89.7690))])
OrderedDict([('layer', tensor(89.4187))])
OrderedDict([('layer', tensor(90.3298))])
OrderedDict([('layer', tensor(89.9845))])
OrderedDict([('layer', tensor(90.0134))])
OrderedDict([('layer', tensor(89.8092))])
OrderedDict([('layer', tensor(89.2785))])
OrderedDict([('layer', tensor(88.2677))])
OrderedDict([('layer', tensor(88.9300))])
OrderedDict([('layer', tensor(88.8660))])
OrderedDict([('layer', tensor(88.6414))])
OrderedDict([('layer', tensor(88.6009))])
OrderedDict([('layer', tensor(88.1183))])
OrderedDict([('layer', tensor(88.6461))])
OrderedDict([('layer', tensor(89.2394))])
OrderedDict([('layer', tensor(89.3259))])
OrderedDict([('layer', tensor(90.1094))])
OrderedDict([('layer', tensor(90.6431))])
OrderedDict([('layer', tensor(90.5

OrderedDict([('layer', tensor(91.3077))])
OrderedDict([('layer', tensor(91.2482))])
OrderedDict([('layer', tensor(90.2749))])
OrderedDict([('layer', tensor(90.6861))])
OrderedDict([('layer', tensor(90.0480))])
OrderedDict([('layer', tensor(89.8471))])
OrderedDict([('layer', tensor(90.8510))])
OrderedDict([('layer', tensor(90.6794))])
OrderedDict([('layer', tensor(89.9226))])
OrderedDict([('layer', tensor(89.8449))])
OrderedDict([('layer', tensor(89.1927))])
OrderedDict([('layer', tensor(89.0593))])
OrderedDict([('layer', tensor(89.1304))])
OrderedDict([('layer', tensor(89.3210))])
OrderedDict([('layer', tensor(89.9058))])
OrderedDict([('layer', tensor(89.7674))])
OrderedDict([('layer', tensor(89.8120))])
OrderedDict([('layer', tensor(90.1311))])
OrderedDict([('layer', tensor(90.8529))])
OrderedDict([('layer', tensor(90.2694))])
OrderedDict([('layer', tensor(91.3847))])
OrderedDict([('layer', tensor(91.4187))])
OrderedDict([('layer', tensor(91.4494))])
OrderedDict([('layer', tensor(91.9

OrderedDict([('layer', tensor(90.2481))])
OrderedDict([('layer', tensor(90.3863))])
OrderedDict([('layer', tensor(90.0555))])
OrderedDict([('layer', tensor(90.2892))])
OrderedDict([('layer', tensor(90.1260))])
OrderedDict([('layer', tensor(90.8082))])
OrderedDict([('layer', tensor(89.5804))])
OrderedDict([('layer', tensor(89.6578))])
OrderedDict([('layer', tensor(89.9648))])
OrderedDict([('layer', tensor(89.2327))])
OrderedDict([('layer', tensor(89.5550))])
OrderedDict([('layer', tensor(90.0490))])
OrderedDict([('layer', tensor(89.5428))])
OrderedDict([('layer', tensor(89.4224))])
OrderedDict([('layer', tensor(89.7042))])
OrderedDict([('layer', tensor(90.1149))])
OrderedDict([('layer', tensor(89.4058))])
OrderedDict([('layer', tensor(89.9149))])
OrderedDict([('layer', tensor(89.2175))])
OrderedDict([('layer', tensor(89.7086))])
OrderedDict([('layer', tensor(89.0949))])
OrderedDict([('layer', tensor(88.7041))])
OrderedDict([('layer', tensor(88.2074))])
OrderedDict([('layer', tensor(88.9

OrderedDict([('layer', tensor(90.3659))])
OrderedDict([('layer', tensor(90.8304))])
OrderedDict([('layer', tensor(90.4433))])
OrderedDict([('layer', tensor(90.8649))])
OrderedDict([('layer', tensor(90.8454))])
OrderedDict([('layer', tensor(91.0839))])
OrderedDict([('layer', tensor(90.3798))])
OrderedDict([('layer', tensor(90.4423))])
OrderedDict([('layer', tensor(90.1538))])
OrderedDict([('layer', tensor(89.6468))])
OrderedDict([('layer', tensor(89.2493))])
OrderedDict([('layer', tensor(88.9201))])
OrderedDict([('layer', tensor(88.4976))])
OrderedDict([('layer', tensor(90.0082))])
OrderedDict([('layer', tensor(89.2380))])
OrderedDict([('layer', tensor(88.5260))])
OrderedDict([('layer', tensor(89.6337))])
OrderedDict([('layer', tensor(89.6886))])
OrderedDict([('layer', tensor(89.6731))])
OrderedDict([('layer', tensor(89.9042))])
OrderedDict([('layer', tensor(89.6763))])
OrderedDict([('layer', tensor(90.7107))])
OrderedDict([('layer', tensor(89.8016))])
OrderedDict([('layer', tensor(89.9

OrderedDict([('layer', tensor(91.1682))])
OrderedDict([('layer', tensor(91.3425))])
OrderedDict([('layer', tensor(90.4970))])
OrderedDict([('layer', tensor(90.3034))])
OrderedDict([('layer', tensor(90.4676))])
OrderedDict([('layer', tensor(90.4829))])
OrderedDict([('layer', tensor(90.0294))])
OrderedDict([('layer', tensor(89.3857))])
OrderedDict([('layer', tensor(90.2052))])
OrderedDict([('layer', tensor(90.6574))])
OrderedDict([('layer', tensor(91.3421))])
OrderedDict([('layer', tensor(90.9973))])
OrderedDict([('layer', tensor(90.6215))])
OrderedDict([('layer', tensor(90.5393))])
OrderedDict([('layer', tensor(89.8619))])
OrderedDict([('layer', tensor(89.6235))])
OrderedDict([('layer', tensor(89.9984))])
OrderedDict([('layer', tensor(89.6907))])
OrderedDict([('layer', tensor(89.7612))])
OrderedDict([('layer', tensor(89.3657))])
OrderedDict([('layer', tensor(89.0610))])
OrderedDict([('layer', tensor(89.7812))])
OrderedDict([('layer', tensor(89.9468))])
OrderedDict([('layer', tensor(90.1

OrderedDict([('layer', tensor(89.8778))])
OrderedDict([('layer', tensor(90.1695))])
OrderedDict([('layer', tensor(89.6944))])
OrderedDict([('layer', tensor(90.0131))])
OrderedDict([('layer', tensor(90.6723))])
OrderedDict([('layer', tensor(91.0195))])
OrderedDict([('layer', tensor(91.1463))])
OrderedDict([('layer', tensor(91.2354))])
OrderedDict([('layer', tensor(91.0348))])
OrderedDict([('layer', tensor(91.2621))])
OrderedDict([('layer', tensor(90.3627))])
OrderedDict([('layer', tensor(90.2030))])
OrderedDict([('layer', tensor(90.9849))])
OrderedDict([('layer', tensor(90.8571))])
OrderedDict([('layer', tensor(90.9734))])
OrderedDict([('layer', tensor(91.1780))])
OrderedDict([('layer', tensor(91.2509))])
OrderedDict([('layer', tensor(91.0774))])
OrderedDict([('layer', tensor(90.9340))])
OrderedDict([('layer', tensor(91.0628))])
OrderedDict([('layer', tensor(90.7206))])
OrderedDict([('layer', tensor(90.3990))])
OrderedDict([('layer', tensor(89.6248))])
OrderedDict([('layer', tensor(90.0

OrderedDict([('layer', tensor(88.9580))])
OrderedDict([('layer', tensor(88.7780))])
OrderedDict([('layer', tensor(89.1179))])
OrderedDict([('layer', tensor(89.1942))])
OrderedDict([('layer', tensor(89.8220))])
OrderedDict([('layer', tensor(91.1619))])
OrderedDict([('layer', tensor(90.6332))])
OrderedDict([('layer', tensor(90.5647))])
OrderedDict([('layer', tensor(89.8670))])
OrderedDict([('layer', tensor(90.8426))])
OrderedDict([('layer', tensor(90.4363))])
OrderedDict([('layer', tensor(90.5319))])
OrderedDict([('layer', tensor(89.9081))])
OrderedDict([('layer', tensor(89.2817))])
OrderedDict([('layer', tensor(90.1665))])
OrderedDict([('layer', tensor(89.7891))])
OrderedDict([('layer', tensor(88.9113))])
OrderedDict([('layer', tensor(90.3579))])
OrderedDict([('layer', tensor(90.2302))])
OrderedDict([('layer', tensor(89.3578))])
OrderedDict([('layer', tensor(89.3526))])
OrderedDict([('layer', tensor(89.7221))])
OrderedDict([('layer', tensor(89.9331))])
OrderedDict([('layer', tensor(90.3

OrderedDict([('layer', tensor(89.3766))])
OrderedDict([('layer', tensor(89.3532))])
OrderedDict([('layer', tensor(89.4455))])
OrderedDict([('layer', tensor(90.1025))])
OrderedDict([('layer', tensor(90.2206))])
OrderedDict([('layer', tensor(90.1581))])
OrderedDict([('layer', tensor(89.9766))])
OrderedDict([('layer', tensor(89.4454))])
OrderedDict([('layer', tensor(89.3326))])
OrderedDict([('layer', tensor(89.4917))])
OrderedDict([('layer', tensor(89.1647))])
OrderedDict([('layer', tensor(89.0483))])
OrderedDict([('layer', tensor(88.6127))])
OrderedDict([('layer', tensor(88.4222))])
OrderedDict([('layer', tensor(88.8177))])
OrderedDict([('layer', tensor(89.1498))])
OrderedDict([('layer', tensor(88.7982))])
OrderedDict([('layer', tensor(88.5495))])
OrderedDict([('layer', tensor(89.2087))])
OrderedDict([('layer', tensor(89.6964))])
OrderedDict([('layer', tensor(89.4701))])
OrderedDict([('layer', tensor(89.0422))])
OrderedDict([('layer', tensor(89.7063))])
OrderedDict([('layer', tensor(88.9

OrderedDict([('layer', tensor(90.4362))])
OrderedDict([('layer', tensor(90.6645))])
OrderedDict([('layer', tensor(90.4292))])
OrderedDict([('layer', tensor(90.5241))])
OrderedDict([('layer', tensor(90.7858))])
OrderedDict([('layer', tensor(90.9051))])
OrderedDict([('layer', tensor(90.1824))])
OrderedDict([('layer', tensor(90.6310))])
OrderedDict([('layer', tensor(90.8386))])
OrderedDict([('layer', tensor(90.7227))])
OrderedDict([('layer', tensor(90.5362))])
OrderedDict([('layer', tensor(89.1210))])
OrderedDict([('layer', tensor(89.0211))])
OrderedDict([('layer', tensor(88.5651))])
OrderedDict([('layer', tensor(88.8822))])
OrderedDict([('layer', tensor(89.1850))])
OrderedDict([('layer', tensor(88.7478))])
OrderedDict([('layer', tensor(88.7025))])
OrderedDict([('layer', tensor(88.2674))])
OrderedDict([('layer', tensor(88.5232))])
OrderedDict([('layer', tensor(88.6714))])
OrderedDict([('layer', tensor(88.8621))])
OrderedDict([('layer', tensor(88.8737))])
OrderedDict([('layer', tensor(89.4

OrderedDict([('layer', tensor(89.8985))])
OrderedDict([('layer', tensor(90.3024))])
OrderedDict([('layer', tensor(90.5141))])
OrderedDict([('layer', tensor(89.8703))])
OrderedDict([('layer', tensor(89.2620))])
OrderedDict([('layer', tensor(88.6631))])
OrderedDict([('layer', tensor(88.8277))])
OrderedDict([('layer', tensor(89.4058))])
OrderedDict([('layer', tensor(88.5652))])
OrderedDict([('layer', tensor(88.3800))])
OrderedDict([('layer', tensor(89.2794))])
OrderedDict([('layer', tensor(88.5305))])
OrderedDict([('layer', tensor(89.0221))])
OrderedDict([('layer', tensor(89.2581))])
OrderedDict([('layer', tensor(89.5344))])
OrderedDict([('layer', tensor(89.9764))])
OrderedDict([('layer', tensor(90.3269))])
OrderedDict([('layer', tensor(90.1885))])
OrderedDict([('layer', tensor(89.9742))])
OrderedDict([('layer', tensor(89.9620))])
OrderedDict([('layer', tensor(89.9725))])
OrderedDict([('layer', tensor(90.3045))])
OrderedDict([('layer', tensor(90.1863))])
OrderedDict([('layer', tensor(89.9

OrderedDict([('layer', tensor(90.4601))])
OrderedDict([('layer', tensor(90.4789))])
OrderedDict([('layer', tensor(90.3346))])
OrderedDict([('layer', tensor(90.7889))])
OrderedDict([('layer', tensor(90.6704))])
OrderedDict([('layer', tensor(89.7368))])
OrderedDict([('layer', tensor(90.6116))])
OrderedDict([('layer', tensor(89.7907))])
OrderedDict([('layer', tensor(89.9067))])
OrderedDict([('layer', tensor(90.3326))])
OrderedDict([('layer', tensor(89.5087))])
OrderedDict([('layer', tensor(89.2760))])
OrderedDict([('layer', tensor(89.1010))])
OrderedDict([('layer', tensor(89.2145))])
OrderedDict([('layer', tensor(88.4000))])
OrderedDict([('layer', tensor(89.7090))])
OrderedDict([('layer', tensor(89.3157))])
OrderedDict([('layer', tensor(89.4686))])
OrderedDict([('layer', tensor(89.1869))])
OrderedDict([('layer', tensor(88.9205))])
OrderedDict([('layer', tensor(89.5811))])
OrderedDict([('layer', tensor(89.4904))])
OrderedDict([('layer', tensor(89.1996))])
OrderedDict([('layer', tensor(89.3

In [25]:
layers_alignment = compute_angles_module(model)
layers_alignment

OrderedDict([('conv1_0', tensor(90.0027)),
             ('conv1_1', tensor(90.0435)),
             ('conv2_0', tensor(90.5038)),
             ('conv1_2', tensor(89.9072)),
             ('conv2_1', tensor(90.2888)),
             ('conv1_3', tensor(90.1281)),
             ('conv2_2', tensor(90.0208)),
             ('0_0', tensor(89.1358)),
             ('conv1_4', tensor(89.8470)),
             ('conv2_3', tensor(90.2987)),
             ('conv1_5', tensor(90.0346)),
             ('conv2_4', tensor(90.0171)),
             ('0_1', tensor(90.4614)),
             ('conv1_6', tensor(90.0779)),
             ('conv2_5', tensor(89.9771)),
             ('conv1_7', tensor(90.0542)),
             ('conv2_6', tensor(89.9995)),
             ('0_2', tensor(89.9728)),
             ('conv1_8', tensor(90.0238)),
             ('conv2_7', tensor(90.0196)),
             ('fc_0', tensor(90.5137))])

In [10]:
layers_alignment = compute_angles_module(model)

In [11]:
print(layers_alignment)

OrderedDict([('conv1_0', tensor(90.6273)),
             ('conv1_1', tensor(90.6992)),
             ('conv2_0', tensor(90.4775)),
             ('conv1_2', tensor(90.2928)),
             ('conv2_1', tensor(89.5960)),
             ('conv1_3', tensor(89.8826)),
             ('conv2_2', tensor(89.9748)),
             ('0_0', tensor(90.7542)),
             ('conv1_4', tensor(89.9956)),
             ('conv2_3', tensor(89.8266)),
             ('conv1_5', tensor(90.0207)),
             ('conv2_4', tensor(90.1097)),
             ('0_1', tensor(90.2923)),
             ('conv1_6', tensor(90.0619)),
             ('conv2_5', tensor(89.8943)),
             ('conv1_7', tensor(90.0901)),
             ('conv2_6', tensor(90.0088)),
             ('0_2', tensor(89.7614)),
             ('conv1_8', tensor(89.9518)),
             ('conv2_7', tensor(90.0207)),
             ('fc_0', tensor(90.2587))])

In [100]:
from collections import defaultdict
seen_keys = defaultdict(lambda: 0)
seen_keys['conv2']

0

In [97]:
seen_keys['conv2']

NameError: name 'seen_keys' is not defined

In [95]:
layers_list

[[(Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
   biotorch.layers.weight_mirroring.conv.Conv2d)],
 [(BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
   torch.nn.modules.batchnorm.BatchNorm2d)],
 [(ReLU(inplace=True), torch.nn.modules.activation.ReLU)],
 [(MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
   torch.nn.modules.pooling.MaxPool2d)],
 [(Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
   biotorch.layers.weight_mirroring.conv.Conv2d)],
 [(BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
   torch.nn.modules.batchnorm.BatchNorm2d)],
 [(ReLU(inplace=True), torch.nn.modules.activation.ReLU)],
 [(Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
   biotorch.layers.weight_mirroring.conv.Conv2d)],
 [(BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
   torch.nn.modules.ba

In [None]:
queue.append()

In [29]:
x = [] 
compute_angle_layers(model, x)

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
tensor(90.7209)
Linear(in_features=512, out_features=10, bias=True)
tensor(89.3565)
conv1
bn1
relu
maxpool
layer1
0
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.0204)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.1279)
conv1
bn1
relu
conv2
bn2
relu2
1
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.3268)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.2096)
conv1
bn1
relu
conv2
bn2
relu2
layer2
0
Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
tensor(90.0323)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.0547)
conv1
bn1
relu
conv2
bn2
relu2
downsample
Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(89.5946)
0
1
1
Conv2d(128, 128, kernel_size=(3, 3), str

In [27]:
sum(x)

21

In [7]:
compute_angle_layers(model)

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
tensor(90.7209)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.0204)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.1279)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.3268)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.2096)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.0204)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.1279)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.3268)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.2096)
Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
tensor(90.0323)
Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=Fals

tensor(89.8488)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(89.9422)
Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
tensor(89.8325)
Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(89.7005)
Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(89.7005)
Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(89.7005)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.0091)
Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(89.7005)
Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(89.7005)
Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(89.7005)
Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(89.7005)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(89.9091)
Conv2d(256, 256, kernel_size=(3, 3), str

tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(89.9910)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(89.9887)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(89.9743)
Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
tensor(89.9938)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), str

tensor(89.9887)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(89.9743)
Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
tensor(89.9938)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(89.9910)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(89.9887)
Conv2d(512, 512, kernel_size=(3, 3), str

In [6]:
import torch
import torch.nn.functional as F
x = torch.randn(1, 3, 50, 50)
# define the W and B kernels
W = torch.randn([64, 3, 7, 7])
B = torch.randn([64, 3, 7, 7])
# run in the forward pass
stride = [3, 3]
padding = 3
y = F.conv2d(x, W, stride=stride, padding=padding)
# compute dB
dW = F.conv2d(torch.transpose(x, 0, 1), torch.transpose(y, 0, 1), dilation=stride, padding=padding)
dB = torch.transpose(F.interpolate(dW, size=[W.shape[2], W.shape[3]]), 0, 1)
print("x  = ", x.shape)
print("y  = ", y.shape)
print("B  = ", B.shape)
print("dW = ", dW.shape)
print("dB = ", dB.shape)

x  =  torch.Size([1, 3, 50, 50])
y  =  torch.Size([1, 64, 17, 17])
B  =  torch.Size([64, 3, 7, 7])
dW =  torch.Size([3, 64, 8, 8])
dB =  torch.Size([64, 3, 7, 7])


In [11]:
import torch
import torch.nn.functional as F
# define the W and B kernels
forward_kernel = torch.randn([16, 16, 3, 3])
backward_kernel = torch.randn([16, 16, 3, 3])
# downsample in the forward pass
noise_input  = torch.randn(1, 16, 12, 12)
noise_output = F.conv2d(noise_input, forward_kernel, stride=[1,1], padding=1)
# upsample in the backward pass
output = F.conv_transpose2d(noise_output, backward_kernel, stride=[1,1], padding=1)
assert(list(output.shape) == list(noise_input.shape))

In [12]:
output.size()

torch.Size([1, 16, 12, 12])

In [13]:
noise_input.size()

torch.Size([1, 16, 12, 12])

In [4]:
benchmark = Benchmark('/home/albert/projects/biotorch/biotorch/configs/mnist.yaml')

In [5]:
benchmark.run()

Preparing CIFAR10 Dataset and storing data in ./data
Files already downloaded and verified
Files already downloaded and verified
=> Creating model 'resnet18'
Converting ResNet-18 to Feedback Alignment mode


TypeError: super(type, obj): obj must be an instance or subtype of type

In [23]:
import torch
import torch.nn.functional as F
x = torch.randn(1, 3, 224, 224)
# define the W and B kernels
W = torch.randn([64, 3, 7, 7])
B = torch.randn([64, 3, 7, 7])
# downsample in the forward pass
y = F.conv2d(x, W, stride=2, padding=3)
# compute the dB update for B
dB = F.conv2d(torch.transpose(x, 0, 1), torch.transpose(y, 0, 1), stride=2, padding=3)

dB = torch.transpose(dB, 0, 1)
print("x t = ", torch.transpose(x, 0, 1).shape)
print("y t = ", torch.transpose(y, 0, 1).shape)
print("x  = ", x.shape)
print("y  = ", y.shape)
print("B  = ", B.shape)
print("dB = ", dB.shape)

x t =  torch.Size([3, 1, 224, 224])
y t =  torch.Size([64, 1, 112, 112])
x  =  torch.Size([1, 3, 224, 224])
y  =  torch.Size([1, 64, 112, 112])
B  =  torch.Size([64, 3, 7, 7])
dB =  torch.Size([64, 3, 60, 60])


In [4]:
m = benchmark.model

In [5]:
def compute_angle_layers(module):
        # Go through all of module nn.module (e.g. network or layer)
        for module_name in module._modules.keys():
            # Get layer
            layer = getattr(module, module_name)
            # Convert layer
            print(layer)
        # Iterate through immediate child modules
        for name, child_module in module.named_children():
            forward_layers(child_module)

In [6]:
forward_layers(m)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [14]:
biotorch.layers.fa

<module 'biotorch.layers.fa' from '/home/albert/projects/biotorch/biotorch/layers/fa/__init__.py'>

In [4]:
benchmark.model

BioModel(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runn

In [6]:
from torchvision import models

In [7]:
model = models.resnet18()

In [8]:
import torch

In [10]:
model = torch.nn.DataParallel(model, [1, 2])

In [14]:
torch.save(model.module, 'model.pth')

In [15]:
torch.load('model.pth')

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  