<a href="https://colab.research.google.com/github/nooblette/DeepLearning/blob/main/ShuffleNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torchvision import datasets, transforms, utils
import pandas as pd
from torchsummary import summary
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
import time
import tqdm

%matplotlib inline

# Seed the torch (CPU and CUDA) and NumPy RNGs so weight initialization and
# data shuffling are reproducible across runs.
torch.manual_seed(0)
torch.cuda.manual_seed(0)
np.random.seed(0)

# PIL image -> float tensor in [0, 1] -> per-channel (x - 0.5) / 0.5, i.e. [-1, 1]
# -> resized to 32x32 (CIFAR-10 images are already 32x32, so this is a no-op there).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        transforms.Resize((32, 32)),
    ]
)

# Training split first, then test split (same order as before, so the
# download/verification messages appear in the same order).
train_dataset, test_dataset = (
    datasets.CIFAR10(root="./data", train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

batch_size = 256

# Only the training loader reshuffles every epoch; evaluation order is fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

class ShuffleBlock(nn.Module):
    """Channel-shuffle layer used between grouped convolutions in ShuffleNet.

    The C input channels are viewed as ``groups`` rows of ``C // groups``
    columns and transposed. Input channel ``i * (C // groups) + j`` therefore
    lands at output position ``j * groups + i``. With ``groups == 1`` the
    input is returned unchanged (as a reshaped view).
    """

    def __init__(self, groups):
        super().__init__()
        self.groups = groups

    def forward(self, x):
        """Shuffle channels: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W].

        ``x`` must be 4-D and its channel count divisible by ``groups``;
        otherwise the ``view`` below raises.
        """
        batch, channels, height, width = x.size()
        grouped = x.view(batch, self.groups, channels // self.groups, height, width)
        # Swapping the group axis with the per-group channel axis is the shuffle.
        return grouped.transpose(1, 2).reshape(batch, channels, height, width)


class Bottleneck(nn.Module):
    """ShuffleNet unit: grouped 1x1 conv -> shuffle -> 3x3 depthwise conv -> grouped 1x1 conv.

    Args:
        in_planes: channels of the block input.
        out_planes: channels produced by the residual branch.
        stride: 2 for a downsampling unit; any other value uses an identity
            shortcut (in practice 1).
        groups: number of groups for the 1x1 convolutions. The first 1x1 conv
            falls back to ``groups=1`` when ``in_planes == 24``.

    When ``stride == 2`` the shortcut is a 3x3/stride-2 average pool whose
    output is concatenated with the branch, so the block emits
    ``out_planes + in_planes`` channels. Otherwise the shortcut is the
    identity and is added to the branch, which requires
    ``in_planes == out_planes``.
    """

    def __init__(self, in_planes, out_planes, stride, groups):
        super().__init__()
        self.stride = stride

        mid_planes = out_planes // 4
        # A 24-channel input (the width of ShuffleNet's stem) gets an ungrouped
        # first 1x1 conv; its shuffle is then a no-op.
        first_groups = groups if in_planes != 24 else 1

        # Submodules are created in the same order as before so that parameter
        # initialization (RNG consumption) and state_dict keys are unchanged.
        self.conv1 = nn.Conv2d(in_planes, mid_planes, kernel_size=1, groups=first_groups, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_planes)
        self.shuffle1 = ShuffleBlock(groups=first_groups)
        # groups == in_channels makes this a depthwise 3x3 convolution.
        self.conv2 = nn.Conv2d(mid_planes, mid_planes, kernel_size=3,
                               stride=stride, padding=1, groups=mid_planes, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_planes)
        self.conv3 = nn.Conv2d(mid_planes, out_planes, kernel_size=1, groups=groups, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        if stride == 2:
            self.shortcut = nn.Sequential(nn.AvgPool2d(3, stride=2, padding=1))
        else:
            self.shortcut = nn.Sequential()  # identity

    def forward(self, x):
        branch = self.shuffle1(F.relu(self.bn1(self.conv1(x))))
        # NOTE(review): the ShuffleNet paper omits the ReLU after the depthwise
        # conv; it is kept here to preserve existing behavior — confirm intent.
        branch = F.relu(self.bn2(self.conv2(branch)))
        branch = self.bn3(self.conv3(branch))

        shortcut = self.shortcut(x)
        if self.stride == 2:
            merged = torch.cat([branch, shortcut], 1)  # channel-wise concat
        else:
            merged = branch + shortcut
        return F.relu(merged)

class ShuffleNet(nn.Module):
    """ShuffleNet classifier for 3-channel images (configured for CIFAR-10).

    Layout: 1x1 stem conv to 24 channels -> three stages of ``Bottleneck``
    units -> global average pooling -> linear classifier. The first unit of
    each stage uses stride 2, so each stage halves the spatial size.

    Args:
        cfg: optional dict with keys ``'out_planes'`` (three stage widths),
            ``'num_blocks'`` (three stage depths) and ``'groups'`` (groups of
            the 1x1 convolutions). Defaults to ``ShuffleNet.cfg``.
        num_classes: number of output logits (default 10).
    """

    # Default configuration: output channels per stage, units per stage, and
    # the group count for the grouped 1x1 convolutions.
    cfg = {'out_planes': [200, 400, 800], 'num_blocks': [4, 8, 4], 'groups': 2}

    def __init__(self, cfg=None, num_classes=10):
        super(ShuffleNet, self).__init__()
        cfg = ShuffleNet.cfg if cfg is None else cfg
        out_planes = cfg['out_planes']
        num_blocks = cfg['num_blocks']
        groups = cfg['groups']

        # Stem. Bottleneck uses an ungrouped first 1x1 conv when
        # in_planes == 24, which is what this stem width triggers.
        self.conv1 = nn.Conv2d(3, 24, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(24)
        self.in_planes = 24  # running channel count, updated by _make_layer
        self.layer1 = self._make_layer(out_planes[0], num_blocks[0], groups)
        self.layer2 = self._make_layer(out_planes[1], num_blocks[1], groups)
        self.layer3 = self._make_layer(out_planes[2], num_blocks[2], groups)
        self.linear = nn.Linear(out_planes[2], num_classes)

    def _make_layer(self, out_planes, num_blocks, groups):
        """Build one stage of ``num_blocks`` units that ends at ``out_planes`` channels.

        The first unit downsamples and concatenates its ``self.in_planes``-channel
        shortcut, so its branch only produces ``out_planes - self.in_planes``
        channels. Later units keep ``out_planes`` channels. Updates
        ``self.in_planes`` to ``out_planes``.
        """
        layers = []
        for i in range(num_blocks):
            stride = 2 if i == 0 else 1
            cat_planes = self.in_planes if i == 0 else 0
            layers.append(Bottleneck(self.in_planes, out_planes - cat_planes, stride=stride, groups=groups))
            self.in_planes = out_planes
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pooling. For 32x32 inputs the maps are 4x4 here, so
        # this equals the former avg_pool2d(out, 4). It also stays correct for
        # other input sizes, where a fixed 4x4 window crashed or pooled only a corner.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)

def accuracy(y, label, printable=False):
    """Return top-1 accuracy, in percent, of scores ``y`` against ``label``.

    Args:
        y: score tensor with one row per sample; the predicted class is the
            argmax over dim 1.
        label: tensor of ground-truth class indices, one per row of ``y``.
        printable: if True, also print the batch accuracy.

    Returns:
        float: ``100 * correct / y.shape[0]`` (raises ZeroDivisionError for
        an empty batch).
    """
    with torch.no_grad():
        n_samples = y.shape[0]
        n_correct = (y.argmax(dim=1) == label).sum().item()
        batch_acc = 100 * n_correct / n_samples
        if printable:
            print(f'Accuracy of the network on the test images (batch_size : {n_samples}): {batch_acc}%')
        return batch_acc
        
def train(max_epoch=10):
    """Train the global ``shuffleNet`` on ``train_loader`` for ``max_epoch`` epochs.

    Relies on the module-level ``shuffleNet``, ``optimizer``, ``CEloss``,
    ``train_loader`` and ``device``. After each epoch it prints the epoch's
    mean loss and accuracy, both weighted by samples rather than by batches,
    plus the wall-clock time.

    Args:
        max_epoch: number of passes over ``train_loader`` (default 10).
    """
    # summary(shuffleNet, input_size=(3, 32, 32), batch_size=batch_size, device=device)

    # Ensure BatchNorm uses batch statistics even if the model was left in eval mode.
    shuffleNet.train()
    for epoch in range(max_epoch):
        loss_sum = 0.0
        acc_sum = 0.0
        n_seen = 0
        start_time = time.time()
        print(f"Epoch {epoch} starts.")

        for data, label in tqdm.tqdm(train_loader):
            data = data.to(device)
            label = label.to(device)
            batch_n = label.size(0)

            optimizer.zero_grad()
            y = shuffleNet(data)

            # Weight each per-batch mean by the batch size so the smaller final
            # batch does not skew the epoch average.
            acc_sum += accuracy(y, label, printable=False) * batch_n

            loss = CEloss(y, label)
            # .item() yields a host-side Python float, so no autograd history or
            # device tensor is retained (and no tensor goes through NumPy).
            # Assumes CEloss uses the default 'mean' reduction.
            loss_sum += loss.item() * batch_n
            n_seen += batch_n

            loss.backward()

            optimizer.step()

        n_seen = max(n_seen, 1)  # avoid ZeroDivisionError on an empty loader
        print("\n")
        print(f"{epoch} epoch loss : {loss_sum / n_seen}")
        print(f"{epoch} epoch accuracy : {acc_sum / n_seen}")
        print(f"{epoch} epoch time : {time.time() - start_time } (s)")
        print("\n")

def test():
    """Evaluate the global ``shuffleNet`` on ``test_loader`` and print the results.

    Relies on the module-level ``shuffleNet``, ``CEloss``, ``test_loader`` and
    ``device``. The model runs in eval mode, so BatchNorm uses its running
    statistics and is not updated by test data. Evaluation happens under
    ``torch.no_grad()``, and the model's previous train/eval mode is restored
    afterwards. Prints the sample-weighted mean loss and accuracy over the
    whole test set, and the elapsed time.
    """
    start_time = time.time()

    loss_sum = 0.0
    acc_sum = 0.0
    n_seen = 0

    was_training = shuffleNet.training
    shuffleNet.eval()
    try:
        with torch.no_grad():
            for data, label in tqdm.tqdm(test_loader):
                data = data.to(device)
                label = label.to(device)
                batch_n = label.size(0)

                y = shuffleNet(data)

                # Weight per-batch means by batch size: the final test batch
                # is smaller than the others.
                acc_sum += accuracy(y, label) * batch_n
                loss_sum += CEloss(y, label).item() * batch_n
                n_seen += batch_n
    finally:
        shuffleNet.train(was_training)

    n_seen = max(n_seen, 1)  # avoid ZeroDivisionError on an empty loader
    print("\n")
    print(f"Test loss : {loss_sum / n_seen}")
    print(f"Test accuracy : {acc_sum / n_seen}")
    print(f"Single epoch Time : {time.time() - start_time} (s)")
    print("\n")

# Build the model, loss and optimizer on the selected device, then run the
# training loop followed by one evaluation pass over the test set.
shuffleNet = ShuffleNet().to(device)
CEloss = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(
    params=shuffleNet.parameters(),
    momentum=0.9,
    lr=0.01,
)

train()
test()


Files already downloaded and verified
Files already downloaded and verified
Epoch 0 starts.


100%|██████████| 196/196 [22:06<00:00,  6.77s/it]




0 epoch loss : 1.9499588012695312
0 epoch accuracy : 33.182397959183675
0 epoch time : 1326.5453009605408 (s)


Epoch 1 starts.


100%|██████████| 196/196 [22:08<00:00,  6.78s/it]




1 epoch loss : 1.4620800018310547
1 epoch accuracy : 48.704958545918366
1 epoch time : 1328.8966012001038 (s)


Epoch 2 starts.


100%|██████████| 196/196 [21:59<00:00,  6.73s/it]




2 epoch loss : 1.2048498392105103
2 epoch accuracy : 57.11734693877551
2 epoch time : 1319.2097301483154 (s)


Epoch 3 starts.


100%|██████████| 196/196 [22:16<00:00,  6.82s/it]




3 epoch loss : 1.083083152770996
3 epoch accuracy : 62.0344387755102
3 epoch time : 1336.7220633029938 (s)


Epoch 4 starts.


100%|██████████| 196/196 [22:08<00:00,  6.78s/it]




4 epoch loss : 0.9734275341033936
4 epoch accuracy : 65.89405293367346
4 epoch time : 1328.2540402412415 (s)


Epoch 5 starts.


100%|██████████| 196/196 [21:54<00:00,  6.71s/it]




5 epoch loss : 0.8989843130111694
5 epoch accuracy : 68.41119260204081
5 epoch time : 1314.2624323368073 (s)


Epoch 6 starts.


100%|██████████| 196/196 [22:25<00:00,  6.86s/it]




6 epoch loss : 0.8184847831726074
6 epoch accuracy : 71.3719706632653
6 epoch time : 1345.2457571029663 (s)


Epoch 7 starts.


100%|██████████| 196/196 [22:05<00:00,  6.76s/it]




7 epoch loss : 0.7149566411972046
7 epoch accuracy : 75.0015943877551
7 epoch time : 1325.8969197273254 (s)


Epoch 8 starts.


100%|██████████| 196/196 [22:19<00:00,  6.83s/it]




8 epoch loss : 0.6474006772041321
8 epoch accuracy : 77.1727519132653
8 epoch time : 1339.5041732788086 (s)


Epoch 9 starts.


100%|██████████| 196/196 [21:54<00:00,  6.71s/it]




9 epoch loss : 0.6040862202644348
9 epoch accuracy : 78.55907206632654
9 epoch time : 1314.2175059318542 (s)




100%|██████████| 40/40 [01:33<00:00,  2.33s/it]



Test loss ; 0.9601673126220703
Test train accuracy : 69.560546875
Single epoch Time : 93.0175895690918 (s)





