In [2]:
%load_ext autoreload
%autoreload 2
from biotorch.benchmark.run import Benchmark

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [1]:
from biotorch.models.weight_mirroring.resnet import resnet18, resnet34, resnet50, wide_resnet101_2

In [2]:
from collections import deque, OrderedDict, defaultdict
def compute_angles_module(module):
    queue = deque()
    layers_alignment = OrderedDict()
    seen_keys = defaultdict(lambda: 0)
    
    # First pass to store module keys
    for module_keys in module._modules.keys():
        queue.append((module, module_keys))
    
    # Approximate depth first traversal of the model using a deque
    while len(queue) > 0:
        module, module_key = queue.popleft()
        layer = getattr(module, module_key)
        if 'alignment' in layer.__dict__:
            angle = layer.compute_alignment()
            key_name = module_key + '_' + str(seen_keys[module_key])
            seen_keys[module_key] += 1 
            layers_alignment[key_name] = angle
        if len(layer._modules.keys()) > 0:
            # Reverse list as we are appending from the left side of the queue
            for key in list(layer._modules.keys())[::-1]:
                queue.appendleft((layer, key))
                
    return layers_alignment

In [3]:
model = resnet18(pretrained=False, num_classes=10)

In [4]:
import torch
x = torch.randn((10, 3, 224, 224))

In [None]:
for i in range(1000):
    print(i)
    model.mirror_weights(x, mirror_learning_rate=0.1)
    layers_alignment = compute_angles_module(model)
    print(layers_alignment['conv1_0'], layers_alignment['fc_0'])

0
tensor(89.6799) tensor(89.0085)
1
tensor(89.6799) tensor(89.4610)
2
tensor(89.6799) tensor(89.7542)
3
tensor(89.6799) tensor(90.5147)
4
tensor(89.6799) tensor(90.0884)
5
tensor(89.6799) tensor(90.0422)
6
tensor(89.6799) tensor(90.2209)
7
tensor(89.6799) tensor(90.0883)
8
tensor(89.6799) tensor(90.2389)
9
tensor(89.6799) tensor(90.3282)
10
tensor(89.6799) tensor(90.3753)
11
tensor(89.6799) tensor(90.2336)
12
tensor(89.6799) tensor(90.0177)
13
tensor(89.6799) tensor(89.9604)
14
tensor(89.6799) tensor(89.9810)
15
tensor(89.6799) tensor(90.2687)
16
tensor(89.6799) tensor(90.0052)
17
tensor(89.6799) tensor(90.0854)
18
tensor(89.6799) tensor(90.4526)
19
tensor(89.6799) tensor(90.5064)
20
tensor(89.6799) tensor(90.6439)
21
tensor(89.6799) tensor(90.4352)
22
tensor(89.6799) tensor(90.4939)
23
tensor(89.6799) tensor(90.5969)
24
tensor(89.6799) tensor(90.6130)
25
tensor(89.6799) tensor(90.8903)
26
tensor(89.6799) tensor(90.6586)
27
tensor(89.6799) tensor(90.7271)
28
tensor(89.6799) tensor(90.8

tensor(89.6799) tensor(90.0986)
232
tensor(89.6799) tensor(90.0576)
233
tensor(89.6799) tensor(90.1055)
234
tensor(89.6799) tensor(90.1240)
235
tensor(89.6799) tensor(90.1286)
236
tensor(89.6799) tensor(90.1323)
237
tensor(89.6799) tensor(90.0858)
238
tensor(89.6799) tensor(90.0222)
239
tensor(89.6799) tensor(89.9897)
240
tensor(89.6799) tensor(90.0049)
241
tensor(89.6799) tensor(90.0440)
242
tensor(89.6799) tensor(89.9612)
243
tensor(89.6799) tensor(89.9572)
244
tensor(89.6799) tensor(89.9058)
245
tensor(89.6799) tensor(89.9515)
246
tensor(89.6799) tensor(89.9797)
247
tensor(89.6799) tensor(89.9710)
248
tensor(89.6799) tensor(90.0006)
249
tensor(89.6799) tensor(89.9948)
250
tensor(89.6799) tensor(89.9276)
251
tensor(89.6799) tensor(89.9239)
252
tensor(89.6799) tensor(89.9375)
253
tensor(89.6799) tensor(89.9781)
254
tensor(89.6799) tensor(90.0123)
255
tensor(89.6799) tensor(90.0374)
256
tensor(89.6799) tensor(90.1023)
257
tensor(89.6799) tensor(90.1196)
258
tensor(89.6799) tensor(90.05

tensor(89.6799) tensor(90.0395)
460
tensor(89.6799) tensor(90.0413)
461
tensor(89.6799) tensor(90.0481)
462
tensor(89.6799) tensor(90.0855)
463
tensor(89.6799) tensor(90.1082)
464
tensor(89.6799) tensor(90.0972)
465
tensor(89.6799) tensor(90.1648)
466
tensor(89.6799) tensor(90.0702)
467
tensor(89.6799) tensor(90.1187)
468
tensor(89.6799) tensor(90.1705)
469
tensor(89.6799) tensor(90.2234)
470
tensor(89.6799) tensor(90.1710)
471
tensor(89.6799) tensor(90.1244)
472
tensor(89.6799) tensor(90.2288)
473
tensor(89.6799) tensor(90.2594)
474
tensor(89.6799) tensor(90.2033)
475
tensor(89.6799) tensor(90.2281)
476
tensor(89.6799) tensor(90.2070)
477
tensor(89.6799) tensor(90.2311)
478
tensor(89.6799) tensor(90.1936)
479
tensor(89.6799) tensor(90.1742)
480
tensor(89.6799) tensor(90.1090)
481
tensor(89.6799) tensor(90.1123)
482
tensor(89.6799) tensor(90.1243)
483
tensor(89.6799) tensor(90.0381)
484
tensor(89.6799) tensor(90.0354)
485
tensor(89.6799) tensor(90.0226)
486
tensor(89.6799) tensor(89.99

In [24]:
layers_alignment = compute_angles_module(model)
layers_alignment

OrderedDict([('conv1_0', tensor(90.5185)),
             ('conv1_1', tensor(89.8839)),
             ('conv2_0', tensor(89.6150)),
             ('conv1_2', tensor(89.7460)),
             ('conv2_1', tensor(89.9676)),
             ('conv1_3', tensor(90.2857)),
             ('conv2_2', tensor(90.1112)),
             ('0_0', tensor(90.9471)),
             ('conv1_4', tensor(89.9271)),
             ('conv2_3', tensor(90.0976)),
             ('conv1_5', tensor(90.0693)),
             ('conv2_4', tensor(89.8859)),
             ('0_1', tensor(89.7417)),
             ('conv1_6', tensor(89.9281)),
             ('conv2_5', tensor(90.0481)),
             ('conv1_7', tensor(90.0098)),
             ('conv2_6', tensor(89.9928)),
             ('0_2', tensor(89.8712)),
             ('conv1_8', tensor(90.0136)),
             ('conv2_7', tensor(90.0754)),
             ('fc_0', tensor(89.6884))])

In [10]:
layers_alignment = compute_angles_module(model)

In [11]:
print(layers_alignment)

OrderedDict([('conv1_0', tensor(90.6273)),
             ('conv1_1', tensor(90.6992)),
             ('conv2_0', tensor(90.4775)),
             ('conv1_2', tensor(90.2928)),
             ('conv2_1', tensor(89.5960)),
             ('conv1_3', tensor(89.8826)),
             ('conv2_2', tensor(89.9748)),
             ('0_0', tensor(90.7542)),
             ('conv1_4', tensor(89.9956)),
             ('conv2_3', tensor(89.8266)),
             ('conv1_5', tensor(90.0207)),
             ('conv2_4', tensor(90.1097)),
             ('0_1', tensor(90.2923)),
             ('conv1_6', tensor(90.0619)),
             ('conv2_5', tensor(89.8943)),
             ('conv1_7', tensor(90.0901)),
             ('conv2_6', tensor(90.0088)),
             ('0_2', tensor(89.7614)),
             ('conv1_8', tensor(89.9518)),
             ('conv2_7', tensor(90.0207)),
             ('fc_0', tensor(90.2587))])

In [100]:
from collections import defaultdict
seen_keys = defaultdict(lambda: 0)
seen_keys['conv2']

0

In [97]:
seen_keys['conv2']

NameError: name 'seen_keys' is not defined

In [95]:
layers_list

[[(Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
   biotorch.layers.weight_mirroring.conv.Conv2d)],
 [(BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
   torch.nn.modules.batchnorm.BatchNorm2d)],
 [(ReLU(inplace=True), torch.nn.modules.activation.ReLU)],
 [(MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False),
   torch.nn.modules.pooling.MaxPool2d)],
 [(Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
   biotorch.layers.weight_mirroring.conv.Conv2d)],
 [(BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
   torch.nn.modules.batchnorm.BatchNorm2d)],
 [(ReLU(inplace=True), torch.nn.modules.activation.ReLU)],
 [(Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
   biotorch.layers.weight_mirroring.conv.Conv2d)],
 [(BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
   torch.nn.modules.ba

In [None]:
queue.append()

In [29]:
x = [] 
compute_angle_layers(model, x)

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
tensor(90.7209)
Linear(in_features=512, out_features=10, bias=True)
tensor(89.3565)
conv1
bn1
relu
maxpool
layer1
0
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.0204)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.1279)
conv1
bn1
relu
conv2
bn2
relu2
1
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.3268)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.2096)
conv1
bn1
relu
conv2
bn2
relu2
layer2
0
Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
tensor(90.0323)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.0547)
conv1
bn1
relu
conv2
bn2
relu2
downsample
Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(89.5946)
0
1
1
Conv2d(128, 128, kernel_size=(3, 3), str

In [27]:
sum(x)

21

In [7]:
compute_angle_layers(model)

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
tensor(90.7209)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.0204)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.1279)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.3268)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.2096)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.0204)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.1279)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.3268)
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.2096)
Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
tensor(90.0323)
Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=Fals

tensor(89.8488)
Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(89.9422)
Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
tensor(89.8325)
Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(89.7005)
Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(89.7005)
Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(89.7005)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(90.0091)
Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(89.7005)
Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(89.7005)
Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(89.7005)
Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(89.7005)
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(89.9091)
Conv2d(256, 256, kernel_size=(3, 3), str

tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(89.9910)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(89.9887)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(89.9743)
Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
tensor(89.9938)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), str

tensor(89.9887)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(89.9743)
Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
tensor(89.9938)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(89.9910)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
tensor(90.2159)
Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
tensor(89.9887)
Conv2d(512, 512, kernel_size=(3, 3), str

In [6]:
import torch
import torch.nn.functional as F
x = torch.randn(1, 3, 50, 50)
# define the W and B kernels
W = torch.randn([64, 3, 7, 7])
B = torch.randn([64, 3, 7, 7])
# run in the forward pass
stride = [3, 3]
padding = 3
y = F.conv2d(x, W, stride=stride, padding=padding)
# compute dB
dW = F.conv2d(torch.transpose(x, 0, 1), torch.transpose(y, 0, 1), dilation=stride, padding=padding)
dB = torch.transpose(F.interpolate(dW, size=[W.shape[2], W.shape[3]]), 0, 1)
print("x  = ", x.shape)
print("y  = ", y.shape)
print("B  = ", B.shape)
print("dW = ", dW.shape)
print("dB = ", dB.shape)

x  =  torch.Size([1, 3, 50, 50])
y  =  torch.Size([1, 64, 17, 17])
B  =  torch.Size([64, 3, 7, 7])
dW =  torch.Size([3, 64, 8, 8])
dB =  torch.Size([64, 3, 7, 7])


In [11]:
import torch
import torch.nn.functional as F
# define the W and B kernels
forward_kernel = torch.randn([16, 16, 3, 3])
backward_kernel = torch.randn([16, 16, 3, 3])
# downsample in the forward pass
noise_input  = torch.randn(1, 16, 12, 12)
noise_output = F.conv2d(noise_input, forward_kernel, stride=[1,1], padding=1)
# upsample in the backward pass
output = F.conv_transpose2d(noise_output, backward_kernel, stride=[1,1], padding=1)
assert(list(output.shape) == list(noise_input.shape))

In [12]:
output.size()

torch.Size([1, 16, 12, 12])

In [13]:
noise_input.size()

torch.Size([1, 16, 12, 12])

In [4]:
benchmark = Benchmark('/home/albert/projects/biotorch/biotorch/configs/mnist.yaml')

In [5]:
benchmark.run()

Preparing CIFAR10 Dataset and storing data in ./data
Files already downloaded and verified
Files already downloaded and verified
=> Creating model 'resnet18'
Converting ResNet-18 to Feedback Alignment mode


TypeError: super(type, obj): obj must be an instance or subtype of type

In [23]:
import torch
import torch.nn.functional as F
x = torch.randn(1, 3, 224, 224)
# define the W and B kernels
W = torch.randn([64, 3, 7, 7])
B = torch.randn([64, 3, 7, 7])
# downsample in the forward pass
y = F.conv2d(x, W, stride=2, padding=3)
# compute the dB update for B
dB = F.conv2d(torch.transpose(x, 0, 1), torch.transpose(y, 0, 1), stride=2, padding=3)

dB = torch.transpose(dB, 0, 1)
print("x t = ", torch.transpose(x, 0, 1).shape)
print("y t = ", torch.transpose(y, 0, 1).shape)
print("x  = ", x.shape)
print("y  = ", y.shape)
print("B  = ", B.shape)
print("dB = ", dB.shape)

x t =  torch.Size([3, 1, 224, 224])
y t =  torch.Size([64, 1, 112, 112])
x  =  torch.Size([1, 3, 224, 224])
y  =  torch.Size([1, 64, 112, 112])
B  =  torch.Size([64, 3, 7, 7])
dB =  torch.Size([64, 3, 60, 60])


In [4]:
m = benchmark.model

In [5]:
def compute_angle_layers(module):
        # Go through all of module nn.module (e.g. network or layer)
        for module_name in module._modules.keys():
            # Get layer
            layer = getattr(module, module_name)
            # Convert layer
            print(layer)
        # Iterate through immediate child modules
        for name, child_module in module.named_children():
            forward_layers(child_module)

In [6]:
forward_layers(m)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [14]:
biotorch.layers.fa

<module 'biotorch.layers.fa' from '/home/albert/projects/biotorch/biotorch/layers/fa/__init__.py'>

In [4]:
benchmark.model

BioModel(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runn

In [6]:
from torchvision import models

In [7]:
model = models.resnet18()

In [8]:
import torch

In [10]:
model = torch.nn.DataParallel(model, [1, 2])

In [14]:
torch.save(model.module, 'model.pth')

In [15]:
torch.load('model.pth')

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  