In [1]:
import os
import argparse
import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision import models, datasets
import numpy as np
from collections import defaultdict

# from modules import BYOL
# from modules.transformations import TransformsSimCLR
import copy
import random
from functools import wraps

import torch
from torch import nn
import torch.nn.functional as F

# helper functions

# distributed training
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP


from datetime import datetime

In [2]:
def cleanup():
    dist.destroy_process_group()

def mkdir(directory):
    if not os.path.exists(directory):
        os.makedirs(directory)

In [3]:
import torchvision 

class TransformsSimCLR:
    """
    A stochastic data augmentation module that transforms any given data example randomly 
    resulting in two correlated views of the same example,
    denoted x ̃i and x ̃j, which we consider as a positive pair.
    """

    def __init__(self, size):
        s = 1
        color_jitter = torchvision.transforms.ColorJitter(
            0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
        )
        self.train_transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.RandomResizedCrop(size=size),
                torchvision.transforms.RandomHorizontalFlip(),  # with 0.5 probability
                torchvision.transforms.RandomApply([color_jitter], p=0.8),
                torchvision.transforms.RandomGrayscale(p=0.2),
                torchvision.transforms.ToTensor(),
            ]
        )

        self.test_transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(size=size),
                torchvision.transforms.ToTensor(),
            ]
        )

    def __call__(self, x):
        return self.train_transform(x), self.train_transform(x)

In [4]:

input1 = torch.randn(4, 12)
input2 = torch.randn(4, 12)
print('Input1:', input1)
print('Input2:', input2)
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
print('Cosine:', cos)
output = cos(input1, input2)
print('Output:', output)

Input1: tensor([[-0.1128, -2.1092,  0.2635,  0.6669, -0.6763, -0.0313,  0.6924,  0.4403,
          1.7536, -0.7036,  0.1907,  0.5174],
        [-2.7859, -1.3126, -0.0597, -0.4849,  1.1314, -0.3457, -0.4029, -0.4896,
         -1.6431, -0.4440, -0.7539,  2.0112],
        [ 0.1758,  0.2646,  1.6288,  1.5116, -1.1957,  0.8327,  2.1143, -1.4118,
         -0.7418, -1.4383, -1.0075,  1.4252],
        [ 0.8099,  0.0328, -0.4675,  1.3027, -0.4971,  1.3294,  0.6015, -0.6023,
         -0.4650, -0.8172,  0.6054, -0.5021]])
Input2: tensor([[-0.5113, -0.8285,  1.0173, -1.1441,  0.8226,  0.9513,  0.7124,  2.0193,
          0.6536, -0.1931,  0.1525, -0.1074],
        [ 1.1626, -0.6248,  0.7782,  1.8841,  0.3065, -1.4666,  1.8618,  0.1381,
         -0.0877,  1.0392, -0.8166, -0.0641],
        [ 0.1579,  0.8694,  0.5434,  0.7253,  0.2851, -1.7096,  0.1727, -0.6824,
          0.2591,  2.0393,  1.8917,  1.0290],
        [-1.1561,  0.1702,  0.2000, -0.3092,  1.0136, -0.5957, -0.2376, -0.7837,
          0.3

In [5]:
m = nn.Softplus()
input = torch.randn(2)
print('Input: ', input)
output = m(input)
print('Output: ', output)

Input:  tensor([-1.4904,  0.3065])
Output:  tensor([0.2032, 0.8581])


In [6]:
x = torch.randn(4, 12)
print('X: ', x)
dim = 0
rotate = [x.shape[dim]-1] + list(range(0, x.shape[dim]-1))
# print(tuple(rotate))
# torch.permute(x, (2, 0, 1)).size()
# y = torch.permute(x, rotate)


y = torch.index_select(x, dim, torch.LongTensor(rotate))
print('Y: ', y)

X:  tensor([[ 0.7117, -0.9927, -0.2839, -1.3336,  0.1674,  0.0136, -0.4612, -0.0393,
          2.0642,  0.3666,  1.1870,  1.0393],
        [-0.6244, -1.0016,  0.7983,  0.9031,  0.3688, -1.1939, -0.0909, -0.1809,
          0.6790,  0.0811,  1.3697, -0.3409],
        [ 0.5057,  0.9305,  1.0444,  0.6630,  0.2778, -0.7578, -0.3167,  1.6050,
          1.6239, -1.8701, -0.3561,  0.0958],
        [-0.4307, -0.3881,  1.6332,  1.8315, -0.5780,  0.4772, -0.3420,  0.3544,
         -0.6579, -0.2821,  0.2395,  0.4580]])
Y:  tensor([[-0.4307, -0.3881,  1.6332,  1.8315, -0.5780,  0.4772, -0.3420,  0.3544,
         -0.6579, -0.2821,  0.2395,  0.4580],
        [ 0.7117, -0.9927, -0.2839, -1.3336,  0.1674,  0.0136, -0.4612, -0.0393,
          2.0642,  0.3666,  1.1870,  1.0393],
        [-0.6244, -1.0016,  0.7983,  0.9031,  0.3688, -1.1939, -0.0909, -0.1809,
          0.6790,  0.0811,  1.3697, -0.3409],
        [ 0.5057,  0.9305,  1.0444,  0.6630,  0.2778, -0.7578, -0.3167,  1.6050,
          1.6239, -1.

In [7]:
import math
import torchvision


def default(val, def_val):
    return def_val if val is None else val


def flatten(t):
    return t.reshape(t.shape[0], -1)


'''
Write code for singleton instance attribute 
Call initialize function only once - iff not exist
Then use the initialized instance 
'''
def singleton(cache_key):
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            instance = getattr(self, cache_key)
            if instance is not None:
                return instance

            instance = fn(self, *args, **kwargs)
            setattr(self, cache_key, instance)
            return instance

        return wrapper

    return inner_fn


# loss fn
def soft_cos(x,y, temperature=0.1):
    cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
    soft = nn.Softplus()
    result = soft(-1/temperature * cos(x,y))
    return result

def loss_fn(x, y, gpu, temperature=0.1):
    '''
    x - from online 
    y - from target 
    '''
    cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
    soft = nn.Softplus()
    positive = soft_cos(x,y, temperature)
    
    
    dim = 0
    rotate = [y.shape[dim]-1] + list(range(0, y.shape[dim]-1))
    rotate = torch.LongTensor(rotate)
    
    if x.is_cuda:
        rotate = rotate.cuda(gpu)
    
    print('X:', x.is_cuda, x.get_device())
    print('Y:', y.is_cuda, y.get_device())
    print('Rotate:', rotate.is_cuda, rotate.get_device())
    
    y = torch.index_select(y, dim, rotate)
    negative = soft_cos(x,y, temperature)
    
    result = positive - negative
    
    return result
    
    
    
def byol_loss(x,y):
    print('Output shape: ', x.shape, y.shape)
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    
    result = 2 - 2 * (x * y).sum(dim=-1)
    print('Loss: ', result)
    return result


# augmentation utils


class RandomApply(nn.Module):
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)


# exponential moving average


class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        if old is None:
            return new
        # EMA simple updating function 
        return old * self.beta + (1 - self.beta) * new


def update_moving_average(ema_updater, ma_model, current_model):
    for current_params, ma_params in zip(
        current_model.parameters(), ma_model.parameters()
    ):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)


# MLP class for projector and predictor


class MLP(nn.Module):
    def __init__(self, dim, projection_size, hidden_size=4096):
        super().__init__()
        # Projector still go through several step of input - hidden - output 
        # Not simply a direct linear mapping 
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, projection_size),
        )

    def forward(self, x):
        return self.net(x)


# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projecter and predictor nets


class NetWrapper(nn.Module):
    def __init__(self, net, projection_size, projection_hidden_size, layer=-2):
        super().__init__()

        # 1. Backbone network 
        self.net = net
        self.layer = layer

        # 2. Projector - map from backbone output to projection output 
        self.projector = None
        self.projection_size = projection_size
        self.projection_hidden_size = projection_hidden_size

        self.hidden = None
        self.hook_registered = False

    def _find_layer(self):
        if type(self.layer) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(self.layer, None)
        elif type(self.layer) == int:
            children = [*self.net.children()]
            return children[self.layer]
        return None

    def _hook(self, _, __, output):
        self.hidden = flatten(output)

    def _register_hook(self):
        layer = self._find_layer()
        assert layer is not None, f"hidden layer ({self.layer}) not found"
        handle = layer.register_forward_hook(self._hook)
        self.hook_registered = True

    @singleton("projector")
    def _get_projector(self, hidden):
        _, dim = hidden.shape
        # Projector simply is a MLP to map from backbone output to projection output 
        projector = MLP(dim, self.projection_size, self.projection_hidden_size)
        return projector.to(hidden)

    def get_representation(self, x):
        
        # 1. Not really understand why need to do this to get representation output 
        if not self.hook_registered:
            self._register_hook()

        if self.layer == -1:
            return self.net(x)

        _ = self.net(x)

        # 2. Using hook and a lot of thing to get hidden - instead of directory assign from backbone net
        hidden = self.hidden
        self.hidden = None
        assert hidden is not None, f"hidden layer {self.layer} never emitted an output"
        return hidden

    def forward(self, x):
        # 1. Get representation from backbone net 
        representation = self.get_representation(x)

        # 2. Get projection from projector 
        # 2.1. Get projector - why need to intialize new projector for every forward step ???
        # And why need to use singleton here ??? 
        projector = self._get_projector(representation)
        # 2.2. Get projection from projector 
        projection = projector(representation)
        return projection


# main class


class BYOL(nn.Module):
    def __init__(
        self,
        net,
        image_size,
        device,
        hidden_layer=-2,
        projection_size=256,
        projection_hidden_size=4096,
        augment_fn=None,
        moving_average_decay=0.99,
    ):
        super().__init__()
        
        self.device = device

        # 1. Online encoder 
        # Backbone network model 
        self.online_encoder = NetWrapper(
            net, projection_size, projection_hidden_size, layer=hidden_layer
        )

        # 2. Target encoder 
        # Backbone network model - Current None - will be copy from online encoder 
        self.target_encoder = None
        # 2.1. Target parameter updater - Exponential moving average 
        self.target_ema_updater = EMA(moving_average_decay)

        # 3. Online preditor 
        # Using to predict online output to target output 
        self.online_predictor = MLP(
            projection_size, projection_size, projection_hidden_size
        )

        # send a mock image tensor to instantiate singleton parameters
        self.forward(torch.randn(2, 3, image_size, image_size), torch.randn(2, 3, image_size, image_size))


    @singleton("target_encoder")
    def _get_target_encoder(self):
        target_encoder = copy.deepcopy(self.online_encoder)
        return target_encoder

    def reset_moving_average(self):
        del self.target_encoder
        self.target_encoder = None

    def update_moving_average(self):
        assert (
            self.target_encoder is not None
        ), "target encoder has not been created yet"

        # EMA update for target parameters using current target parameters and online parameters 
        update_moving_average(
            self.target_ema_updater, self.target_encoder, self.online_encoder
        )

    def forward(self, image_one, image_two):

        # Do forward with 2 augmented version of 1 image 
        # Forward with both online and target and take average of loss 
        # A little bit different from original formula - but still the same mearning 

        # 1. Online output - do normally through encoder and predictor 
        online_proj_one = self.online_encoder(image_one)
        online_proj_two = self.online_encoder(image_two)

        online_pred_one = self.online_predictor(online_proj_one)
        online_pred_two = self.online_predictor(online_proj_two)

        # 2. Target output - do with no grad - encoder no update by backward 
        with torch.no_grad():
            # Why at each forward step - target encoder is updated by same as online encoder ???
            # And why need to use singleton here ??? 
            target_encoder = self._get_target_encoder()
            target_proj_one = target_encoder(image_one)
            target_proj_two = target_encoder(image_two)

        # 3. Calculate loss - target output detach from backward 
        loss_one = loss_fn(online_pred_one, target_proj_two.detach(), self.device)
        loss_two = loss_fn(online_pred_two, target_proj_one.detach(), self.device)

        loss = loss_one + loss_two
        return -math.log(4) + loss.mean()

In [8]:




def main(gpu, args):

    # 0. Initialize distributed GPU training 
    rank = args.nr * args.gpus + gpu
    dist.init_process_group("nccl", rank=rank, world_size=args.world_size)

    torch.manual_seed(0)
    torch.cuda.set_device(gpu)

    # 1. dataset

    # Using pytorch datasets with custom-transform to generate augmentation 
    train_dataset = datasets.CIFAR10(
        args.dataset_dir,
        download=True,
        transform=TransformsSimCLR(size=args.image_size), # paper 224
    )

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=rank
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        drop_last=True,
        num_workers=args.num_workers,
        pin_memory=True,
        sampler=train_sampler,
    )

    # 2. model
    if args.resnet_version == "resnet18":
        resnet = models.resnet18(pretrained=False)
    elif args.resnet_version == "resnet50":
        resnet = models.resnet50(pretrained=False)
    else:
        raise NotImplementedError("ResNet not implemented")

    # 2.1. BYOL model 
    model = BYOL(resnet, image_size=args.image_size, device=gpu, hidden_layer="avgpool")
    model = model.cuda(gpu)

    mkdir(args.train_dir)

    # 2.2. Distributed data parallel
    model = DDP(model, device_ids=[gpu], find_unused_parameters=True)

    # 3. optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    # 4. TensorBoard writer

    if gpu == 0:
        writer = SummaryWriter()

    # 5. Solver
    global_step = 0

    # Training loop 
    for epoch in range(args.num_epochs):
        metrics = defaultdict(list)
        for step, ((x_i, x_j), _) in enumerate(train_loader):
            print('Step: ', step, ', batch-shape: ', x_i.shape, x_j.shape)

            # Get 2 augmented samples from same samples 
            # Logic in dataset 
            x_i = x_i.cuda(non_blocking=True)
            x_j = x_j.cuda(non_blocking=True)
            
            print('X_i: ', x_i.is_cuda, x_i.get_device())
            print('X_j: ', x_j.is_cuda, x_j.get_device())

            # Calculate loss 
            # Logic in detail BYOL model 
            loss = model(x_i, x_j)

            # Optimize and backward 
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update by exponential average moving 
            # Logic in detail BYOL model 
            model.module.update_moving_average()  # update moving average of target encoder

            if step % 10 == 0 and gpu == 0:
                print(f"Step [{step}/{len(train_loader)}]:\tLoss: {loss.item()}")

            if gpu == 0:
                writer.add_scalar("Loss/train_step", loss, global_step)
                metrics["Loss/train"].append(loss.item())
                global_step += 1
                
            break
        break

        if gpu == 0:
            # write metrics to TensorBoard
            for k, v in metrics.items():
                writer.add_scalar(k, np.array(v).mean(), epoch)

            if epoch % args.checkpoint_epochs == 0:
                if gpu == 0:
                    now = datetime.now()
                    print(now)
                    print(f"Saving model at epoch {epoch}")
                    torch.save(resnet.state_dict(), f"{args.train_dir}/model-{epoch}.pt")

                # let other workers wait until model is finished
                # dist.barrier()

    # save your improved network
    if gpu == 0:
        torch.save(resnet.state_dict(), f"{args.train_dir}/model-final.pt")

    cleanup()


In [9]:
from dotmap import DotMap

args = DotMap()

args.batch_size=192
args.checkpoint_epochs=5
args.dataset_dir='./datasets'
args.gpus=4
args.image_size=224
args.learning_rate=0.0003
args.nodes=1
args.nr=0
args.num_epochs=100
args.num_workers=8
args.resnet_version='resnet18'
args.train_dir='./train_dir_mi_202112-17'
args.world_size=1

In [10]:
# Master address for distributed data parallel
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "8011"

# Initialize the process and join up with the other processes.
# This is “blocking,” meaning that no process will continue until all processes have joined.
# mp.spawn(main, args=(args,), nprocs=args.gpus, join=True)
main(0, args)

Files already downloaded and verified
X: False -1
Y: False -1
Rotate: False -1
X: False -1
Y: False -1
Rotate: False -1
Step:  0 , batch-shape:  torch.Size([192, 3, 224, 224]) torch.Size([192, 3, 224, 224])
X_i:  True 0
X_j:  True 0
X: True 0
Y: True 0
Rotate: True 0
X: True 0
Y: True 0
Rotate: True 0
Step [0/260]:	Loss: -1.432499885559082


In [11]:
# Files already downloaded and verified
# Step:  0 , batch-shape:  torch.Size([192, 3, 224, 224]) torch.Size([192, 3, 224, 224])
# Step [0/260]:	Loss: 3.972233295440674

# Fix resnet18 - cifar10 accuracy 
# https://hd10.dev/posts/experiments_cifar10_part1/

In [13]:
gpu = 0

# 0. Initialize distributed GPU training 
rank = args.nr * args.gpus + gpu
dist.init_process_group("nccl", rank=rank, world_size=args.world_size)

torch.manual_seed(0)
torch.cuda.set_device(gpu)

# 1. dataset

# Using pytorch datasets with custom-transform to generate augmentation 
train_dataset = datasets.CIFAR10(
    args.dataset_dir,
    download=True,
    transform=TransformsSimCLR(size=args.image_size), # paper 224
)

train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset, num_replicas=args.world_size, rank=rank
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    drop_last=True,
    num_workers=args.num_workers,
    pin_memory=True,
    sampler=train_sampler,
)

# 2. model
if args.resnet_version == "resnet18":
    resnet = models.resnet18(pretrained=False)
elif args.resnet_version == "resnet50":
    resnet = models.resnet50(pretrained=False)
else:
    raise NotImplementedError("ResNet not implemented")

# 2.1. BYOL model 
model = BYOL(resnet, image_size=args.image_size, device=gpu, hidden_layer="avgpool")
model = model.cuda(gpu)

mkdir(args.train_dir)

# 2.2. Distributed data parallel
model = DDP(model, device_ids=[gpu], find_unused_parameters=True)

# 3. optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

# 4. TensorBoard writer

if gpu == 0:
    writer = SummaryWriter()

# 5. Solver
global_step = 0

# Training loop 
for epoch in range(args.num_epochs):
    metrics = defaultdict(list)
    for step, ((x_i, x_j), _) in enumerate(train_loader):
        print('Step: ', step, ', batch-shape: ', x_i.shape, x_j.shape)

        # Get 2 augmented samples from same samples 
        # Logic in dataset 
        x_i = x_i.cuda(non_blocking=True)
        x_j = x_j.cuda(non_blocking=True)

        print('X_i: ', x_i.is_cuda, x_i.get_device())
        print('X_j: ', x_j.is_cuda, x_j.get_device())

        # Calculate loss 
        # Logic in detail BYOL model 
        loss = model(x_i, x_j)

        # Optimize and backward 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update by exponential average moving 
        # Logic in detail BYOL model 
        model.module.update_moving_average()  # update moving average of target encoder

        if step % 10 == 0 and gpu == 0:
            print(f"Step [{step}/{len(train_loader)}]:\tLoss: {loss.item()}")

        if gpu == 0:
            writer.add_scalar("Loss/train_step", loss, global_step)
            metrics["Loss/train"].append(loss.item())
            global_step += 1

        break
    break

    if gpu == 0:
        # write metrics to TensorBoard
        for k, v in metrics.items():
            writer.add_scalar(k, np.array(v).mean(), epoch)

        if epoch % args.checkpoint_epochs == 0:
            if gpu == 0:
                now = datetime.now()
                print(now)
                print(f"Saving model at epoch {epoch}")
                torch.save(resnet.state_dict(), f"{args.train_dir}/model-{epoch}.pt")

            # let other workers wait until model is finished
            # dist.barrier()

# save your improved network
if gpu == 0:
    torch.save(resnet.state_dict(), f"{args.train_dir}/model-final.pt")

cleanup()

Files already downloaded and verified
X: False -1
Y: False -1
Rotate: False -1
X: False -1
Y: False -1
Rotate: False -1
Step:  0 , batch-shape:  torch.Size([192, 3, 224, 224]) torch.Size([192, 3, 224, 224])
X_i:  True 0
X_j:  True 0
X: True 0
Y: True 0
Rotate: True 0
X: True 0
Y: True 0
Rotate: True 0
Step [0/260]:	Loss: -1.432499885559082


In [14]:
print(model)

DistributedDataParallel(
  (module): BYOL(
    (online_encoder): NetWrapper(
      (net): ResNet(
        (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        (layer1): Sequential(
          (0): BasicBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (relu): ReLU(inplace=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
          (1): BasicBlock(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3

In [15]:
print(model.module)

BYOL(
  (online_encoder): NetWrapper(
    (net): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchN

In [None]:
# Use this instead of nn.Conv2d at all places
class WeightStdConv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                 padding, dilation, groups, bias)

    def forward(self, x):
        weight = self.weight
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
                                  keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.conv2d(x, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

In [17]:
print(resnet.conv1)

def cifar10_resnet18(resnet):
    resnet.conv1 = Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(3, 3), bias=False)
    
    

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)


In [18]:
print(resnet)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  