<a href="https://colab.research.google.com/github/jiruneko/3Dpeg/blob/master/EfficientNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms 
import tqdm
import matplotlib.pyplot as plt

In [23]:
class Swish(nn.Module):
  """x*sigmoid(x)"""
  def __init__(self):
    super().__init__()
  def forward(self, x):
    return x*torch.sigmoid(x)

In [24]:
class DWBlock(nn.Module):
  # 出力が同じサイズになるようにパディングはk//2 (ただしkは奇数)
  def __init__(self, in_c, k, s, p, bias=False):
    if p != k//2:
      print("output may not be the same spatial size as input")
    super().__init__()
    self.dw = nn.Conv2d(in_c, in_c, kernel_size=k, stride=s, padding=p, groups=in_c, bias=bias)
    self.bn = nn.BatchNorm2d(in_c)
    self.act = Swish()

  def forward(self, x):
    out = self.act(self.bn(self.dw(x)))
    return out

In [25]:
class PWBlock(nn.Module):
  def __init__(self, in_c, out_c, bias=False, act="swish"):
    super().__init__()
    self.dw = nn.Conv2d(in_c, out_c, kernel_size=1, bias=bias)
    self.bn = nn.BatchNorm2d(out_c)
    if act=='swish':
      self.act = Swish()
    elif act is None:
      self.act = nn.Identity()

  def forward(self, x):
    out = self.act(self.bn(self.dw(x)))
    return out

In [26]:
class SEBlock(nn.Module):
  def __init__(self, in_c, h=8):
    super().__init__()
    # Squeeze
    self.gap = nn.AdaptiveAvgPool2d(1) 
    # Excitation
    self.fc1 = nn.Linear(in_c, in_c//h, bias=False) 
    self.act1 = Swish()
    self.fc2 = nn.Linear(in_c//h, in_c, bias=False)
    self.act2 = nn.Sigmoid()

  def forward(self, x):
    out = self.gap(x).squeeze(-1).squeeze(-1)
    out = self.act1(self.fc1(out))
    out = self.act2(self.fc2(out)).unsqueeze(-1).unsqueeze(-1)
    return out*x

In [27]:
class MBConv(nn.Module):
  def __init__(self, in_c, out_c, k=5, s=1, expansion=1):
    super().__init__()
    self.s = s
    self.in_c = in_c
    self.out_c = out_c
    self.pw1 = PWBlock(in_c, in_c*expansion, bias=False)
    self.dw = DWBlock(in_c*expansion, k=k, s=s, p=k//2, bias=False)
    self.se = SEBlock(in_c*expansion)
    self.pw2 = PWBlock(in_c*expansion, out_c, bias=False, act=None)

  def forward(self, x):
    out = self.pw2(self.se(self.dw(self.pw1(x))))
    if self.s == 1 and self.in_c==self.out_c:
      out = out+x
    return out

In [28]:
class EfficientNetB0(nn.Module):
  def __init__(self, n_c=3, n_classes=10):
    super().__init__()
    self.first = nn.Sequential(
        nn.Conv2d(n_c,32,3,1,1,bias=False),
        nn.BatchNorm2d(32),
        Swish()
    )
    self.mb1 = MBConv(32, 16, 3, expansion=1)
    self.mb6_1 = nn.Sequential(
        MBConv(16, 24, k=3, s=1, expansion=6),
        MBConv(24, 24, k=3, s=1, expansion=6)
    )
    self.mb6_2 = nn.Sequential(
        MBConv(24, 40, k=5, s=1, expansion=6),
        MBConv(40, 40, k=5, s=1, expansion=6)
    )
    #画像サイズ:32->16
    self.mb6_3 = nn.Sequential(
        MBConv(40, 80, k=3, s=2, expansion=6),
        MBConv(80, 80, k=3, s=1, expansion=6),
        MBConv(80, 80, k=3, s=1, expansion=6),
    )
    self.mb6_4 = nn.Sequential(
        MBConv(80, 112, k=5, s=1, expansion=6),
        MBConv(112, 112, k=5, s=1, expansion=6),
        MBConv(112, 112, k=5, s=1, expansion=6),
    )
    #画像サイズ: 16->8
    self.mb6_5 = nn.Sequential(
        MBConv(112, 192, k=5, s=2, expansion=6),
        MBConv(192, 192, k=5, s=1, expansion=6),
        MBConv(192, 192, k=5, s=1, expansion=6),
        MBConv(192, 192, k=5, s=1, expansion=6),
    )

    self.mb6_6 = nn.Sequential(
        MBConv(192, 320, k=3, s=1, expansion=6),
    )
    self.pw = PWBlock(320, 1280)
    self.gap = nn.AdaptiveAvgPool2d(1)
    self.dropout = nn.Dropout(0.2)
    self.fc = nn.Linear(1280, n_classes)

  def forward(self, x):
    out = self.first(x)
    out = self.mb1(out)
    out = self.mb6_1(out)
    out = self.mb6_2(out)
    out = self.mb6_3(out)
    out = self.mb6_4(out)
    out = self.mb6_5(out)
    out = self.mb6_6(out)
    out = self.pw(out)
    out = self.gap(out).view(x.size(0), -1)
    out = self.dropout(out)
    out = self.fc(out)
    return out

In [29]:
train_transform = transforms.Compose([
                                    transforms.RandomCrop(size=32, padding=4),
                                    transforms.RandomHorizontalFlip(p=0.5),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
])
train_ds = torchvision.datasets.CIFAR10('.', train=True, transform=train_transform, download=True)
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=128, num_workers=4, pin_memory=True)

Files already downloaded and verified


In [30]:
test_transform = transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
])
test_ds = torchvision.datasets.CIFAR10('.', train=False, transform=test_transform, download=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=1024, num_workers=4, pin_memory=True)

Files already downloaded and verified


In [31]:
net = EfficientNetB0(n_c=3, n_classes=10)

In [32]:
criterion = nn.CrossEntropyLoss()

In [33]:
optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 40], gamma=0.1)

In [34]:
history = {'loss':[], 'acc':[], 'val_loss':[], 'val_acc':[]}
epochs = 50
device ='cuda:0' if torch.cuda.is_available() else 'cpu'
net = net.to(device)
for epoch in range(epochs):
    net.train()
    epoch_acc, epoch_loss = 0., 0.
    num_imgs = 0.
    for img, label in tqdm.auto.tqdm(train_dl):
        img, label = img.to(device), label.to(device)
        optimizer.zero_grad()
        output = net(img)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.data
        epoch_acc += sum(label == output.argmax(0))
        num_imgs += img.size(0)*1.0 # floatにするため1.0を掛けている。

    scheduler.step()
    epoch_loss /= num_imgs
    epoch_acc /= num_imgs
    history['loss'].append(epoch_loss)
    history['acc'].append(epoch_acc)    
    print(f'Epoch: {epoch}, loss:{epoch_loss}, acc:{epoch_acc}')

  0%|          | 0/391 [00:00<?, ?it/s]

RuntimeError: ignored