In [None]:
!pip install wandb



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import wandb
import time

In [None]:
# Hsigmoid Implementation
class Hsigmoid(nn.Module):
    """Hard sigmoid activation: ``relu6(x + 3) / 6``, which maps inputs into [0, 1].

    Args:
        inplace: forwarded to ``F.relu6``. The in-place clamp is applied to the
            temporary ``x + 3``, so the caller's tensor is never modified.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        shifted = x + 3.0
        clipped = F.relu6(shifted, inplace=self.inplace)
        return clipped / 6.0


# SEModule Implementation
class SEModule(nn.Module):
    """Squeeze-and-excitation gate over the channels of a 4-D tensor.

    The input is average-pooled to one value per channel. Those values go
    through Linear(channel -> channel // reduction) -> ReLU ->
    Linear(-> channel) -> Hsigmoid, and the resulting per-channel weights
    rescale the input.

    Args:
        channel: number of input channels.
        reduction: bottleneck divisor for the hidden layer width.
    """

    def __init__(self, channel, reduction=4):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            Hsigmoid(),
        )

    def forward(self, x):
        # The 4-way unpack rejects inputs that are not (batch, channels, H, W).
        batch, chans, _height, _width = x.size()
        pooled = self.avg_pool(x).view(batch, chans)
        gate = self.fc(pooled).view(batch, chans, 1, 1)
        return x * gate.expand_as(x)


# H-Swish Implementation
class Hswish(nn.Module):
    """Hard swish activation: ``x * relu6(x + 3) / 6``.

    Args:
        inplace: forwarded to ``F.relu6``. It clamps the temporary ``x + 3``
            in place, so ``x`` itself is left unchanged.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        gate = F.relu6(x + 3.0, inplace=self.inplace)
        return x * gate / 6.0


# Fuse Layer
class Fuse(nn.Module):
    """FuSe block: 1x1 expansion, two parallel 1-D depthwise convs, 1x1 projection.

    Forward order (as implemented here):
        conv1 (1x1, in -> exp) -> nl -> batchn1
        -> conv2_a (k x 1 depthwise) -> batchn2_a  \\
        -> conv2_b (1 x k depthwise) -> batchn2_b  /  concatenated on dim 1 (2*exp ch)
        -> [se] -> nl -> conv3 (1x1, 2*exp -> out) -> [batchn3]

    Args:
        in_channel: channels of the input tensor.
        exp_channel: channels after the 1x1 expansion (width of each branch).
        out_channel: channels produced by the final 1x1 conv.
        non_linearity: "relu" or "hswish". The same module instance is used
            at both activation points.
        kernel_size: length of the 1-D depthwise kernels. Odd sizes give both
            branches the same output size, which ``torch.cat`` requires.
        stride: stride of both depthwise convs.
        is_se: insert an ``SEModule`` over the concatenated branches.
        apply_bn: apply ``batchn3`` after the projection. ``batchn3`` is
            created regardless of this flag.

    Raises:
        ValueError: if ``non_linearity`` is not "relu" or "hswish".
    """

    def __init__(
        self,
        in_channel,
        exp_channel,
        out_channel,
        non_linearity,
        kernel_size,
        stride,
        is_se=True,
        apply_bn=True,
    ):
        super(Fuse, self).__init__()
        if non_linearity == "relu":
            self.nl = nn.ReLU(inplace=True)
        elif non_linearity == "hswish":
            self.nl = Hswish(inplace=True)
        else:
            # ValueError is a subclass of Exception, so existing handlers still match.
            raise ValueError(
                "Unsupported non_linearity %r; expected 'relu' or 'hswish'"
                % (non_linearity,)
            )
        self.is_se = is_se
        self.apply_bn = apply_bn

        # Defining trainable parameters

        self.conv1 = nn.Conv2d(
            in_channels=in_channel,
            out_channels=exp_channel,
            kernel_size=1,
            stride=1,
            bias=False,
        )

        self.batchn1 = nn.BatchNorm2d(num_features=exp_channel)

        # Two depthwise branches (groups == channels): a vertical (k x 1) and a
        # horizontal (1 x k) kernel, each padded only along its kernel axis.
        self.conv2_a = nn.Conv2d(
            in_channels=exp_channel,
            out_channels=exp_channel,
            kernel_size=(kernel_size, 1),
            stride=stride,
            padding=((kernel_size - 1) // 2, 0),
            groups=exp_channel,
            bias=False,
        )
        self.conv2_b = nn.Conv2d(
            in_channels=exp_channel,
            out_channels=exp_channel,
            kernel_size=(1, kernel_size),
            stride=stride,
            padding=(0, (kernel_size - 1) // 2),
            groups=exp_channel,
            bias=False,
        )

        self.batchn2_a = nn.BatchNorm2d(num_features=exp_channel)
        self.batchn2_b = nn.BatchNorm2d(num_features=exp_channel)

        if self.is_se:
            # Operates on the concatenation of both branches.
            self.se = SEModule(2 * exp_channel)

        self.conv3 = nn.Conv2d(
            in_channels=2 * exp_channel,
            out_channels=out_channel,
            kernel_size=1,
            stride=1,
            bias=False,
        )
        self.batchn3 = nn.BatchNorm2d(num_features=out_channel)

    def forward(self, x):
        x = self.conv1(x)
        x = self.nl(x)
        x = self.batchn1(x)

        x_1 = self.batchn2_a(self.conv2_a(x))
        x_2 = self.batchn2_b(self.conv2_b(x))

        x = torch.cat([x_1, x_2], dim=1)

        if self.is_se:
            x = self.se(x)

        x = self.nl(x)

        x = self.conv3(x)

        if self.apply_bn:
            x = self.batchn3(x)

        return x

In [None]:
# Fusenet Model
class FuseNet(nn.Module):
    """FuSeNet classifier whose forward stages can be selectively checkpointed.

    ``forward`` runs 20 numbered stages in this order::

        1 conv1, 2 hswish, 3 batchn1, 4-14 fuse1..fuse11,
        15 conv2, 16 hswish, 17 batchn2, 18 adap, 19 conv3, 20 hswish

    followed by flatten -> dropout -> linear. A stage whose number is in
    ``checkpoint_list`` runs through ``torch.utils.checkpoint.checkpoint``,
    so its activations are recomputed during backward instead of stored.
    All other stages run normally.

    Args:
        checkpoint_list: collection of stage numbers (1-20) to checkpoint.
        num_classes: output size of the final linear layer (default 100).
    """

    def __init__(self, checkpoint_list, num_classes=100):
        super(FuseNet, self).__init__()

        # Not read anywhere in this class; kept for code that may access it.
        self.layer_to_be_checkpointed = []

        self.checkpoint_list = checkpoint_list

        # Defining trainable parameters. Submodules are registered in the same
        # order and under the same names (conv1, batchn1, hswish, fuse1..fuse11,
        # conv2, ...) so parameter order and state_dict keys are stable.

        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=16,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False,
        )

        self.batchn1 = nn.BatchNorm2d(16)
        self.hswish = Hswish()

        # One row per Fuse block, fuse1..fuse11:
        # (in_channel, exp_channel, out_channel, non_linearity, kernel_size, stride, is_se)
        fuse_specs = [
            (16, 16, 16, "relu", 3, 2, True),
            (16, 72, 24, "relu", 3, 2, False),
            (24, 88, 24, "relu", 3, 1, False),
            (24, 96, 40, "hswish", 5, 2, True),
            (40, 240, 40, "hswish", 5, 1, True),
            (40, 240, 40, "hswish", 5, 1, True),
            (40, 120, 48, "hswish", 5, 1, True),
            (48, 144, 48, "hswish", 5, 1, True),
            (48, 288, 96, "hswish", 5, 2, True),
            (96, 576, 96, "hswish", 5, 1, True),
            (96, 576, 96, "hswish", 5, 1, True),
        ]
        for number, (cin, cexp, cout, nl, ksize, stride, use_se) in enumerate(
            fuse_specs, start=1
        ):
            setattr(
                self,
                "fuse%d" % number,
                Fuse(
                    in_channel=cin,
                    exp_channel=cexp,
                    out_channel=cout,
                    non_linearity=nl,
                    kernel_size=ksize,
                    stride=stride,
                    is_se=use_se,
                ),
            )

        self.conv2 = nn.Conv2d(
            in_channels=96, out_channels=576, kernel_size=1, stride=1, bias=False
        )
        self.batchn2 = nn.BatchNorm2d(576)

        self.adap = nn.AdaptiveAvgPool2d(1)

        self.conv3 = nn.Conv2d(
            in_channels=576, out_channels=1024, kernel_size=1, stride=1, bias=False
        )

        self.drop = nn.Dropout(p=0.2)

        self.lin = nn.Linear(in_features=1024, out_features=num_classes, bias=True)

    def _maybe_checkpoint(self, index, layer, x):
        """Return ``layer(x)``, checkpointing the call when ``index`` is in ``checkpoint_list``.

        The layer runs directly when autograd is disabled or ``x`` does not
        require grad. In that case the reentrant checkpoint returns an output
        that does not require grad, so the layer's parameters would silently
        receive no gradients (e.g. stage 1 when the model input itself does
        not require grad).
        """
        if (
            index in self.checkpoint_list
            and torch.is_grad_enabled()
            and x.requires_grad
        ):
            # use_reentrant=True keeps the historical default variant; passing it
            # explicitly avoids the warning newer PyTorch emits when it is omitted
            # (the keyword requires PyTorch >= 1.11).
            return checkpoint.checkpoint(layer, x, use_reentrant=True)
        return layer(x)

    def forward(self, x):
        # (stage number, module) in execution order; numbers match checkpoint_list.
        stages = [(1, self.conv1), (2, self.hswish), (3, self.batchn1)]
        stages += [(3 + n, getattr(self, "fuse%d" % n)) for n in range(1, 12)]
        stages += [
            (15, self.conv2),
            (16, self.hswish),
            (17, self.batchn2),
            (18, self.adap),
            (19, self.conv3),
            (20, self.hswish),
        ]
        for number, stage in stages:
            x = self._maybe_checkpoint(number, stage, x)

        x = x.flatten(start_dim=1)
        x = self.drop(x)
        x = self.lin(x)

        return x

In [None]:
def _initialize_weights(self):

    # weight initialization
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode="fan_out")
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

In [None]:
# Human-readable labels for FuseNet's forward stage numbers (the numbers used
# in checkpoint_list). Stage 2 is the first Hswish and stage 3 the first
# BatchNorm, matching the forward order conv1 -> hswish -> batchn1.
layers_names = {
    1: "Conv 1",
    2: "Hswish 1",
    3: "Batchn 1",
    4: "Fuse 1",
    5: "Fuse 2",
    6: "Fuse 3",
    7: "Fuse 4",
    8: "Fuse 5",
    9: "Fuse 6",
    10: "Fuse 7",
    11: "Fuse 8",
    12: "Fuse 9",
    13: "Fuse 10",
    14: "Fuse 11",
    15: "Conv 2",
    16: "Hswish 2",
    17: "Batchn 2",
    18: "Adaptive",
    19: "Conv 3",
    20: "Hswish 3",
}

In [None]:
# Stage numbers (see layers_names) to run under activation checkpointing.
cpt_layers = [2, 8, 17]
# Map each selected stage number to its label; unknown numbers raise KeyError.
cpt_names = list(map(layers_names.__getitem__, cpt_layers))

print(cpt_names)

['Batchn 1', 'Fuse 5', 'Batchn 2']


In [None]:
# One wandb run per checkpoint configuration; the run is named after the
# list of checkpointed stage labels.
wandb.init(
    project="fusenet-checkpoint-runs",
    name=str(cpt_names),
    reinit=True,
)

[34m[1mwandb[0m: Currently logged in as: [33momshri[0m (use `wandb login --relogin` to force relogin)


In [None]:
# Setting GPU Device
# The memory and timing measurements below use torch.cuda APIs, so a GPU is
# required. Raise instead of calling exit(): in a notebook exit() ends the
# kernel session rather than stopping "Run All" with a readable error.
if not torch.cuda.is_available():
    raise RuntimeError("No GPU Available!")
device = torch.device("cuda:0")
print("Running on the GPU")

Running on the GPU


In [None]:
fusenet = FuseNet(cpt_layers).to(device)
# _initialize_weights already walks fusenet.modules(), so call it once on the
# root. Passing it to fusenet.apply() would run it for every submodule and
# re-initialize each layer once per ancestor.
_initialize_weights(fusenet)
fusenet

FuseNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (batchn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (hswish): Hswish()
  (fuse1): Fuse(
    (nl): ReLU(inplace=True)
    (conv1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (batchn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2_a): Conv2d(16, 16, kernel_size=(3, 1), stride=(2, 2), padding=(1, 0), groups=16, bias=False)
    (conv2_b): Conv2d(16, 16, kernel_size=(1, 3), stride=(2, 2), padding=(0, 1), groups=16, bias=False)
    (batchn2_a): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (batchn2_b): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (se): SEModule(
      (avg_pool): AdaptiveAvgPool2d(output_size=1)
      (fc): Sequential(
        (0): Linear(in_features=32, out_features=8, bias=False)
        (1):

In [None]:
# Register the model with wandb. The trailing ';' hides the returned TorchGraph
# list repr. NOTE(review): watch() installs hooks on the model; confirm their
# overhead is acceptable for the backward timings measured below.
wandb.watch(fusenet);

[<wandb.wandb_torch.TorchGraph at 0x7f3a4b971b70>]

In [None]:
# Run hyper-parameters, recorded on the wandb run config. In the visible cells
# only batch_size and the Adam settings are read; test_batch_size and epochs
# are logged but unused.
config = wandb.config
for key, value in {
    "batch_size": 128,
    "test_batch_size": 1000,
    "epochs": 20,
    "lr": 0.001,
    "beta0": 0.9,
    "beta1": 0.999,
    "eps": 1e-08,
    "weight_decay": 0,
}.items():
    setattr(config, key, value)

In [None]:
# Classification loss and an Adam optimizer configured from the run config.
# NOTE(review): the optimizer is never stepped in the visible cells; only
# forward/backward passes are measured.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    fusenet.parameters(),
    lr=config.lr,
    betas=(config.beta0, config.beta1),
    eps=config.eps,
    weight_decay=config.weight_decay,
)

In [None]:
# Per-channel CIFAR-100 training-set mean/std. The values
# (0.4914, 0.4822, 0.4465) / (0.2023, 0.1994, 0.2010) are CIFAR-10 statistics
# and do not match this dataset.
CIFAR100_MEAN = (0.5071, 0.4865, 0.4409)
CIFAR100_STD = (0.2673, 0.2564, 0.2762)

transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
    ]
)
trainset = torchvision.datasets.CIFAR100(
    root="./data", train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=config.batch_size, shuffle=True, num_workers=2
)

Files already downloaded and verified


In [None]:
# One training batch (images, labels) moved to the GPU for the measurements.
# `input` shadows the builtin but is the name the later cells use.
input, labels = (tensor.to(device) for tensor in next(iter(trainloader)))

In [None]:
# Reset the CUDA peak-memory counters so later max_memory_allocated readings
# only reflect the passes that follow. reset_peak_memory_stats replaces the
# deprecated reset_max_memory_allocated.
torch.cuda.reset_peak_memory_stats(device)



In [None]:
# Forward pass on the example batch. The model was never switched to eval(),
# so it runs in its default training mode (BatchNorm batch stats, Dropout active).
outputs = fusenet(input)



In [None]:
# Cross-entropy between the logits and the integer class labels.
loss = criterion(input=outputs, target=labels)

In [None]:
# Peak bytes allocated on the measurement device since the counter reset.
# Uses `device` so the device is configured in one place.
mem_max = torch.cuda.max_memory_allocated(device)

In [None]:
# Time one backward pass. CUDA work runs asynchronously, so synchronize before
# reading each timestamp; otherwise the interval can end before the GPU has
# finished the backward computation.
torch.cuda.synchronize(device)
before_bwd_time = time.perf_counter()
loss.backward()
torch.cuda.synchronize(device)
after_bwd_time = time.perf_counter()

# Re-read the peak after backward. Checkpointed stages recompute their
# activations during backward, so a reading taken before backward can miss
# the true peak of the training step.
mem_max = torch.cuda.max_memory_allocated(device)

In [None]:
# Log the peak CUDA memory in MiB (bytes / 2**20).
BYTES_PER_MIB = 1024 ** 2
wandb.log({"max_memory_used": mem_max / BYTES_PER_MIB})

In [None]:
# Time 20 further backward passes on the same batch. Synchronize before each
# timestamp because CUDA kernels run asynchronously. Gradients accumulate into
# .grad across iterations; they are never used for an optimizer step here.
time_list = []
for _ in range(20):
    outputs = fusenet(input)
    loss = criterion(outputs, labels)
    torch.cuda.synchronize(device)
    before_bwd_time = time.perf_counter()
    loss.backward()
    torch.cuda.synchronize(device)
    after_bwd_time = time.perf_counter()
    time_list.append(after_bwd_time - before_bwd_time)



In [None]:
# Mean backward time in seconds, computed in double precision. This avoids a
# round-trip through a legacy float32 torch.FloatTensor.
mean_time = sum(time_list) / len(time_list)

In [None]:
# Log the mean backward-pass time (seconds).
wandb.log(dict(time_taken=mean_time))

In [None]:
# Finish the wandb run opened by wandb.init above so its logged metrics are
# uploaded before the next configuration starts a new run.
wandb.finish()

VBox(children=(Label(value=' 0.02MB of 0.02MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
max_memory_used,24.85205
_step,1.0
_runtime,7.0
_timestamp,1608034426.0
time_taken,0.02898


0,1
max_memory_used,▁
_step,▁█
_runtime,▁█
_timestamp,▁█
time_taken,▁
