<a href="https://colab.research.google.com/github/358Xin/DL/blob/main/Network_Compression.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!gdown --id '1awF7pZ9Dz7X1jn1_QAiKN-_v56veCEKy' --output food-11.zip

Downloading...
From: https://drive.google.com/uc?id=1awF7pZ9Dz7X1jn1_QAiKN-_v56veCEKy
To: /content/food-11.zip
100% 963M/963M [00:06<00:00, 151MB/s]


In [2]:
!unzip -q food-11.zip

In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
from torch.utils.data import ConcatDataset, DataLoader, Subset
from torchvision.datasets import DatasetFolder
from tqdm.auto import tqdm

In [4]:
train_tfm = transforms.Compose([transforms.Resize((142, 142)), transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.RandomCrop(128), transforms.ToTensor()])
test_tfm = transforms.Compose([transforms.Resize((142, 142)), transforms.CenterCrop(128), transforms.ToTensor()])

In [5]:
batch_size = 64
train_set = DatasetFolder("food-11/training/labeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm)
valid_set = DatasetFolder("food-11/validation", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)
unlabeled_set = DatasetFolder("food-11/training/unlabeled", loader=lambda x: Image.open(x), extensions="jpg", transform=train_tfm)
test_set = DatasetFolder("food-11/testing", loader=lambda x: Image.open(x), extensions="jpg", transform=test_tfm)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [6]:
class StudentNet(nn.Module):
    def __init__(self):
      super(StudentNet, self).__init__()

      self.cnn = nn.Sequential(
        nn.Conv2d(3, 32, 3), 
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.Conv2d(32, 32, 3),  
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.MaxPool2d(2, 2, 0),     

        nn.Conv2d(32, 64, 3), 
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(2, 2, 0),     

        nn.Conv2d(64, 100, 3), 
        nn.BatchNorm2d(100),
        nn.ReLU(),
        nn.MaxPool2d(2, 2, 0),
        
        # Here we adopt Global Average Pooling for various input size.
        nn.AdaptiveAvgPool2d((1, 1)))
      self.fc = nn.Sequential(nn.Linear(100, 11))
      
    def forward(self, x):
      out = self.cnn(x)
      out = out.view(out.size()[0], -1)
      return self.fc(out)

In [7]:
from torchsummary import summary

student_net = StudentNet()
summary(student_net, (3, 128, 128), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 126, 126]             896
       BatchNorm2d-2         [-1, 32, 126, 126]              64
              ReLU-3         [-1, 32, 126, 126]               0
            Conv2d-4         [-1, 32, 124, 124]           9,248
       BatchNorm2d-5         [-1, 32, 124, 124]              64
              ReLU-6         [-1, 32, 124, 124]               0
         MaxPool2d-7           [-1, 32, 62, 62]               0
            Conv2d-8           [-1, 64, 60, 60]          18,496
       BatchNorm2d-9           [-1, 64, 60, 60]             128
             ReLU-10           [-1, 64, 60, 60]               0
        MaxPool2d-11           [-1, 64, 30, 30]               0
           Conv2d-12          [-1, 100, 28, 28]          57,700
      BatchNorm2d-13          [-1, 100, 28, 28]             200
             ReLU-14          [-1, 100,

In [8]:
def loss_fn_kd(outputs, labels, teacher_outputs, alpha=0.5):
  hard_loss = F.cross_entropy(outputs, labels) * (1 - alpha)
  soft_loss = 0
  return hard_loss + soft_loss

In [23]:
!gdown --id '1B8ljdrxYXJsZv2vmTequdPOofp3VF3NN' --output teacher_resnet18.bin
teacher_net = models.resnet18(pretrained=False, num_classes=11).cuda()
teacher_net.load_state_dict(torch.load(f'./teacher_resnet18.bin'))
teacher_net.eval()

Downloading...
From: https://drive.google.com/uc?id=1B8ljdrxYXJsZv2vmTequdPOofp3VF3NN
To: /content/teacher_resnet18.bin
100% 44.8M/44.8M [00:00<00:00, 239MB/s] 


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [26]:
import enum
device = 'cuda' if torch.cuda.is_available() else 'cpu'
student_net = student_net.to(device)
teacher_net = teacher_net.to(device)

do_semi = True

def get_pseudo_labels(dataset, model):
  loader = DataLoader(dataset, batch_size=batch_size*3, shuffle=False, pin_memory=True)
  pseudo_labels = []
  for batch in tqdm(loader):
    img, _ = batch
    with torch.no_grad():
      logits = model(img.to(device))
      pseudo_labels.append(logits.argmax(dim=-1).detach().cpu())
  pseudo_labels = torch.cat(pseudo_labels)
  for index, ((img,_), pseudo_label) in enumerate(zip(dataset.samples, pseudo_labels)):
    dataset.samples[index] = (img, pseudo_label.item())
  return dataset

if do_semi:
  unlabeled_set = get_pseudo_labels(unlabeled_set, teacher_net)
  concat_dataset = ConcatDataset([train_set, unlabeled_set])
  train_loader = DataLoader(concat_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, drop_last=True)

  0%|          | 0/36 [00:00<?, ?it/s]

In [29]:
from torchvision.transforms.autoaugment import TrivialAugmentWide
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(student_net.parameters(), lr=0.0003, weight_decay=1e-5)
n_epochs = 80

for epoch in range(n_epochs):
  student_net.train()
  
  train_loss = []
  train_acc = []

  for batch in tqdm(train_loader):
    imgs, labels = batch
    logits = student_net(imgs.to(device))
    with torch.no_grad():
      soft_labels = teacher_net(imgs.to(device))
    loss = loss_fn_kd(logits, labels.to(device), soft_labels)
    optimizer.zero_grad()
    loss.backward()
    grad_norm = nn.utils.clip_grad_norm_(student_net.parameters(), max_norm=10)
    optimizer.step()
    acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()
    
    train_loss.append(loss.item())
    train_acc.append(acc)
  
  train_loss = sum(train_loss) / len(train_loss)
  train_acc = sum(train_acc) / len(train_acc)
  print(f"[Train | {epoch+1:03d} / {n_epochs:03d}] loss={train_loss:.5f}, acc={train_acc:.5f}")

  student_net.eval()

  valid_loss = []
  valid_acc = []
  for batch in tqdm(valid_loader):
    imgs, labels = batch

    with torch.no_grad():
      logits = student_net(imgs.to(device))
      soft_labels = teacher_net(imgs.to(device))
    loss = loss_fn_kd(logits, labels.to(device), soft_labels)
    acc = (logits.argmax(dim=-1) == labels.to(device)).float().detach().cpu().view(-1).numpy()

    valid_loss.append(loss.item())
    valid_acc += list(acc)

  valid_loss = sum(valid_loss) / len(valid_loss)
  valid_acc = sum(valid_acc) / len(valid_acc)
  print(f"[Valid | {epoch+1:03d} / {n_epochs:03d}] loss={valid_loss:.5f}, acc={valid_acc:.5f}")

  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 001 / 080] loss=0.97769, acc=0.32315


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 001 / 080] loss=0.98278, acc=0.30455


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 002 / 080] loss=0.94477, acc=0.35542


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 002 / 080] loss=0.94180, acc=0.35606


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 003 / 080] loss=0.92999, acc=0.36465


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 003 / 080] loss=0.92250, acc=0.35152


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 004 / 080] loss=0.91177, acc=0.37855


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 004 / 080] loss=0.89456, acc=0.37576


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 005 / 080] loss=0.89496, acc=0.39265


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 005 / 080] loss=0.87644, acc=0.37576


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 006 / 080] loss=0.88638, acc=0.40625


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 006 / 080] loss=0.90200, acc=0.37424


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 007 / 080] loss=0.87765, acc=0.40858


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 007 / 080] loss=0.84892, acc=0.40455


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 008 / 080] loss=0.86629, acc=0.41406


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/pytho

  0%|          | 0/11 [00:01<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
  File "/usr/lib/python3.7/multiprocessi

[Valid | 008 / 080] loss=0.86787, acc=0.40606


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 009 / 080] loss=0.85728, acc=0.42248


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 009 / 080] loss=0.84883, acc=0.40303


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 010 / 080] loss=0.85264, acc=0.42857


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 010 / 080] loss=0.82342, acc=0.40909


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 011 / 080] loss=0.84050, acc=0.43506


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 011 / 080] loss=0.83018, acc=0.44545


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 012 / 080] loss=0.83606, acc=0.43872


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 012 / 080] loss=0.84694, acc=0.40758


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 013 / 080] loss=0.82975, acc=0.44521


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 013 / 080] loss=0.81176, acc=0.44242


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 014 / 080] loss=0.82426, acc=0.44866


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 014 / 080] loss=0.80092, acc=0.42576


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 015 / 080] loss=0.81876, acc=0.45302


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 015 / 080] loss=0.91110, acc=0.34242


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 016 / 080] loss=0.81223, acc=0.46185


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 016 / 080] loss=0.86449, acc=0.40909


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 017 / 080] loss=0.80881, acc=0.46266


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 017 / 080] loss=0.81755, acc=0.45303


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 018 / 080] loss=0.80097, acc=0.47220


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process


  0%|          | 0/11 [00:01<?, ?it/s]

Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
    self._shutdown_workers()
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
BrokenPipeError: [Errno 32] Broken pipe
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert 

[Valid | 018 / 080] loss=0.85244, acc=0.41970


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 019 / 080] loss=0.79879, acc=0.47047


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 019 / 080] loss=0.79762, acc=0.43788


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 020 / 080] loss=0.79138, acc=0.47727


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 020 / 080] loss=0.77085, acc=0.48485


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 021 / 080] loss=0.78941, acc=0.47900


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 021 / 080] loss=0.74989, acc=0.47879


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 022 / 080] loss=0.78503, acc=0.48255


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 022 / 080] loss=0.77945, acc=0.43030


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 023 / 080] loss=0.78055, acc=0.48580


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 023 / 080] loss=0.76719, acc=0.45000


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 024 / 080] loss=0.77361, acc=0.49209


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 024 / 080] loss=0.73414, acc=0.50000


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 025 / 080] loss=0.76966, acc=0.48935


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 025 / 080] loss=0.76146, acc=0.47576


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 026 / 080] loss=0.77171, acc=0.49036


  0%|          | 0/11 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/pytho

[Valid | 026 / 080] loss=0.83779, acc=0.43939


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 027 / 080] loss=0.75941, acc=0.49888


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 027 / 080] loss=0.74520, acc=0.48182


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 028 / 080] loss=0.76249, acc=0.49746


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 028 / 080] loss=0.69109, acc=0.54394


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 029 / 080] loss=0.75988, acc=0.49604


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 029 / 080] loss=0.73391, acc=0.47879


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 030 / 080] loss=0.75510, acc=0.49878


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 030 / 080] loss=0.76052, acc=0.48485


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 031 / 080] loss=0.75130, acc=0.50558


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 031 / 080] loss=0.84954, acc=0.40303


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 032 / 080] loss=0.75401, acc=0.51096


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 032 / 080] loss=0.75705, acc=0.46970


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 033 / 080] loss=0.74463, acc=0.51299


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 033 / 080] loss=0.77759, acc=0.46364


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 034 / 080] loss=0.74199, acc=0.50913


  0%|          | 0/11 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/pytho

[Valid | 034 / 080] loss=0.74881, acc=0.48939


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 035 / 080] loss=0.74239, acc=0.50923


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 035 / 080] loss=0.86098, acc=0.40758


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 036 / 080] loss=0.73909, acc=0.51136


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 036 / 080] loss=0.83864, acc=0.42121


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 037 / 080] loss=0.73680, acc=0.51613


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 037 / 080] loss=0.72231, acc=0.48788


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 038 / 080] loss=0.73371, acc=0.51552


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 038 / 080] loss=0.81873, acc=0.43030


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 039 / 080] loss=0.73099, acc=0.51390


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 039 / 080] loss=0.68724, acc=0.54848


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 040 / 080] loss=0.72967, acc=0.52212


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 040 / 080] loss=0.72693, acc=0.49242


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 041 / 080] loss=0.72872, acc=0.52283


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 041 / 080] loss=0.76550, acc=0.45152


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 042 / 080] loss=0.72753, acc=0.52161


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 042 / 080] loss=0.75145, acc=0.48485


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 043 / 080] loss=0.72597, acc=0.52252


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/d

  0%|          | 0/11 [00:01<?, ?it/s]

[Valid | 043 / 080] loss=0.69067, acc=0.53636


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 044 / 080] loss=0.72436, acc=0.52252


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 044 / 080] loss=0.69658, acc=0.53182


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 045 / 080] loss=0.71733, acc=0.53206


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 045 / 080] loss=0.72182, acc=0.51515


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 046 / 080] loss=0.71342, acc=0.53044


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 046 / 080] loss=0.78160, acc=0.45303


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 047 / 080] loss=0.71417, acc=0.53287


  0%|          | 0/11 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Traceback (most recent call last):
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
  File "/usr/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/local/lib/pytho

[Valid | 047 / 080] loss=0.75273, acc=0.46364


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 048 / 080] loss=0.71275, acc=0.53155


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 048 / 080] loss=0.71460, acc=0.53030


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 049 / 080] loss=0.70790, acc=0.53409


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 049 / 080] loss=0.71764, acc=0.51515


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 050 / 080] loss=0.70780, acc=0.53429


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 050 / 080] loss=0.67496, acc=0.52273


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 051 / 080] loss=0.70780, acc=0.53724


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 051 / 080] loss=0.73048, acc=0.48788


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 052 / 080] loss=0.70431, acc=0.53744


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/pytho

  0%|          | 0/11 [00:01<?, ?it/s]

Traceback (most recent call last):
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
  File "/usr/lib/python3.7/multiprocessing/queues.py", line 242, in _feed
    send_bytes(obj)
Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/lib/python3.7/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
    self._shutdown_workers()
BrokenPipeError: [Errno 32] Broken pipe
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert 

[Valid | 052 / 080] loss=0.66894, acc=0.54545


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 053 / 080] loss=0.70592, acc=0.53977


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 053 / 080] loss=0.75424, acc=0.47879


  0%|          | 0/154 [00:10<?, ?it/s]

[Train | 054 / 080] loss=0.70346, acc=0.53500


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 054 / 080] loss=0.67117, acc=0.54091


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 055 / 080] loss=0.69962, acc=0.54200


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 055 / 080] loss=0.82557, acc=0.45758


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 056 / 080] loss=0.69831, acc=0.54363


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 056 / 080] loss=0.67159, acc=0.55000


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 057 / 080] loss=0.69603, acc=0.54495


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 057 / 080] loss=0.70461, acc=0.53030


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 058 / 080] loss=0.69280, acc=0.54616


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 058 / 080] loss=0.68011, acc=0.49848


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 059 / 080] loss=0.69241, acc=0.54464


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 059 / 080] loss=0.68045, acc=0.55000


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 060 / 080] loss=0.69054, acc=0.54180


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 060 / 080] loss=0.65574, acc=0.54545


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 061 / 080] loss=0.68639, acc=0.55286


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 061 / 080] loss=0.66751, acc=0.55758


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 062 / 080] loss=0.68853, acc=0.54728


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 062 / 080] loss=0.73909, acc=0.50000


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 063 / 080] loss=0.68482, acc=0.55022


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 063 / 080] loss=0.64723, acc=0.55909


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 064 / 080] loss=0.68522, acc=0.54576


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 064 / 080] loss=0.68276, acc=0.53939


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 065 / 080] loss=0.67644, acc=0.55591


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 065 / 080] loss=0.73581, acc=0.50303


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 066 / 080] loss=0.67881, acc=0.55560


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 066 / 080] loss=0.71413, acc=0.49545


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 067 / 080] loss=0.67617, acc=0.55682


  0%|          | 0/11 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/pytho

[Valid | 067 / 080] loss=0.78615, acc=0.44545


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 068 / 080] loss=0.67739, acc=0.55398


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 068 / 080] loss=0.69535, acc=0.52121


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 069 / 080] loss=0.67835, acc=0.55509


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 069 / 080] loss=0.81132, acc=0.47121


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 070 / 080] loss=0.67558, acc=0.55651


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 070 / 080] loss=0.63501, acc=0.56061


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 071 / 080] loss=0.67347, acc=0.55824


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 071 / 080] loss=0.73195, acc=0.49848


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 072 / 080] loss=0.67208, acc=0.56199


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 072 / 080] loss=0.72430, acc=0.51061


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 073 / 080] loss=0.67292, acc=0.56067


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 073 / 080] loss=0.66611, acc=0.51667


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 074 / 080] loss=0.67077, acc=0.56138


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 074 / 080] loss=0.68321, acc=0.51212


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 075 / 080] loss=0.66877, acc=0.55763


  0%|          | 0/11 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.7/multiprocessing/process.py", line 151, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fdd18b8de60>
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1358, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py", line 1341, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/pytho

[Valid | 075 / 080] loss=0.70078, acc=0.51667


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 076 / 080] loss=0.66529, acc=0.56159


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 076 / 080] loss=0.71511, acc=0.50909


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 077 / 080] loss=0.66584, acc=0.56372


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 077 / 080] loss=0.70457, acc=0.52121


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 078 / 080] loss=0.66296, acc=0.56392


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 078 / 080] loss=0.63561, acc=0.58333


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 079 / 080] loss=0.66392, acc=0.56534


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 079 / 080] loss=0.81653, acc=0.44394


  0%|          | 0/154 [00:00<?, ?it/s]

[Train | 080 / 080] loss=0.66087, acc=0.56037


  0%|          | 0/11 [00:00<?, ?it/s]

[Valid | 080 / 080] loss=0.72973, acc=0.54091


In [30]:
student_net.eval()
predictions = []

for batch in tqdm(test_loader):
  imgs, labels = batch
  with torch.no_grad():
    logits = student_net(imgs.to(device))
  predictions.extend(logits.argmax(dim=-1).cpu().numpy().tolist())

  0%|          | 0/53 [00:00<?, ?it/s]

In [31]:
with open('predict.csv', 'w') as f:
  f.write('Id, Category\n')
  for i, pred in enumerate(predictions):
    f.write(f"{i}, {pred}\n")