# 目標
1. Solve image classification with CNN
2. Improve the performance with data augmentation
3. Understand how to utilze unlabeled data and how it benefits

In [54]:
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms

from PIL import Image
# "ConcatDataset" and "Subset" are possibly useful when doing semi-supervised learning.
from torch.utils.data import ConcatDataset, DataLoader, Subset, Dataset
from torchvision.datasets import DatasetFolder

# For progress bar
from tqdm.auto import tqdm


# Dataset, Data Loader and Transforms

In [4]:
# 資料預處理
train_transformer = transforms.Compose([
    # Resize the image into a fixed shape (height = width = 128)
    transforms.Resize((128, 128)),
    # You may add some transforms here.
    # ToTensor() should be the last one of the transforms.
    transforms.ToTensor(),
])

test_transformer = transforms.Compose([
    transforms.Resize((128,128)),
    transforms.ToTensor()
])

In [67]:
batch_size = 128

# Dataset: 用DatasetFolder可以將folder內的檔案存入一個dataset
# loader代表要開啟的方式
# lambda x: Image.open(x) => 每個檔案都使用PIL的Image.open()方法來開啟檔案
# transformer: 做資料預處理
train_set = DatasetFolder('./data/food-11/training/labeled', loader=lambda x:Image.open(x), extensions='jpg', transform=train_transformer)
valid_set = DatasetFolder('./data/food-11/validation', loader=lambda x:Image.open(x), extensions='jpg', transform=test_transformer)
unlabeled_set = DatasetFolder('./data/food-11/training/unlabeled', loader=lambda x:Image.open(x), extensions='jpg', transform=train_transformer)
test_set = DatasetFolder('./data/food-11/testing', loader=lambda x:Image.open(x), extensions='jpg', transform=test_transformer)


# DataLoader
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

# Model

In [73]:
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()

        self.cnn_layers = nn.Sequential(
            # input = [batch_size, 3(通道數rgb), 128, 128(圖片size)]
            # nn.Conv2d(input通道數, output通道數, kernel_size, step, padding)
            #   功能：設定成2D conventional layer
            # nn.BatchNorm2d(input通道數)
            #   功能：在batch做一次normalize
            # 
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0), # 池化(圖片size / 2)

            # input = [batch_size, 64, 64, 64]
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2, 0),

            # input = [batch_size, 128, 32, 32]
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(4, 4, 0), # 池化(圖片size / 4)

            # output = [batch_size, 256, 8, 8]      
        )

        self.fc_layers = nn.Sequential(
            # input_features = 通道數 * 圖片size
            nn.Linear(256 * 8 * 8, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 11)
        )

    def forward(self, x):
        x = self.cnn_layers(x)
        x = x.flatten(1)
        x = self.fc_layers(x)
        return x

# Training

In [56]:
def get_pseudo_labels(dataset, model, threshold=0.65):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    model.eval()
    softmax = nn.Softmax(dim=-1)

    psuedo_labels = []
    # iterate over the dataset bybatches
    for batch in tqdm(data_loader):
        img, _ = batch
        with torch.no_grad():
            logits = model(img.to(device))

        # Obtain the probability distributions
        probs = softmax(logits)

        # Filter data and construct a new dataset
        # if probs中最有可能的label的機率 > threshold
        # then 將img和其class加入psuedo_labels
        if torch.max(probs) > threshold:
            psuedo_label = torch.argmax(probs).item()
            psuedo_labels.append(img, psuedo_label)

    dataset = []
    for img, psuedo_label in psuedo_labels:
        dataset.append((img.cpu(), psuedo_label))


    model.train()
    return dataset

In [76]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device: ' + device)

model = Classifier().to(device)

criterion = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr = 0.0003, weight_decay=1e-5)

n_epochs = 80

do_semi = False

for epoch in range(n_epochs):
    if do_semi:
        pseudo_set = get_pseudo_labels(unlabeled_set, model=model)
        concat_dataset = ConcatDataset([train_set, pseudo_set])
        train_loader = DataLoader(concat_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)

    model.train()

    train_loss = []
    train_accs = []

    for batch in tqdm(train_loader):
        imgs, labels = batch
        logits = model(imgs.to(device))
        loss = criterion(logits, labels.to(device))
        optimizer.zero_grad()
        loss.backward()
        grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
        optimizer.step()

        # 計算出logits中最大的label與實際label相等的img
        # 回傳一個bool array並將其轉換成float(1. or 0.)
        # 最後取其平均(正確數量/所有數量)
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()
        
        # 紀錄本次batch的loss與accuracy
        train_loss.append(loss.item())
        train_accs.append(acc)

    # avg_loss = 所有batch的loss加總 / batch次數
    # avg_accs = 所有batch的acc加總 / batch次數
    train_loss = sum(train_loss) / len(train_loss)
    train_accs = sum(train_accs) / len(train_accs)

    print('[Train | {0:03d}/{1:03d}] Loss = {2:.5f}, Accurancy = {3:.5f}'.format(epoch, n_epochs, train_loss, train_accs))

    # Validation
    model.eval()

    valid_loss = []
    valid_accs = []

    for batch in tqdm(valid_loader):
        imgs, labels = batch
        with torch.no_grad():
            logits = model(imgs.to(device))

        loss = criterion(logits, labels.to(device))
        acc = (logits.argmax(dim=-1) == labels.to(device)).float().mean()

        valid_loss.append(loss.item())
        valid_accs.append(acc)

    valid_loss = sum(valid_loss) / len(valid_loss)
    valid_accs = sum(valid_accs) / len(valid_accs)

    print('[ Validation | {0:03d}/{1:03d}] Loss = {2:.5f}, Accurancy = {3:.5f}'.format(epoch, n_epochs, valid_loss, valid_accs))



Device: cpu


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 000/080] Loss = 2.20289, Accurancy = 0.21125


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 000/080] Loss = 2.71937, Accurancy = 0.10156


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 001/080] Loss = 1.88868, Accurancy = 0.33625


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 001/080] Loss = 2.18790, Accurancy = 0.22708


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 002/080] Loss = 1.71412, Accurancy = 0.40312


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 002/080] Loss = 1.83537, Accurancy = 0.34375


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 003/080] Loss = 1.57161, Accurancy = 0.46687


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 003/080] Loss = 1.75476, Accurancy = 0.38125


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 004/080] Loss = 1.43303, Accurancy = 0.51375


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 004/080] Loss = 1.71889, Accurancy = 0.38932


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 005/080] Loss = 1.28545, Accurancy = 0.56406


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 005/080] Loss = 1.92107, Accurancy = 0.36771


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 006/080] Loss = 1.25711, Accurancy = 0.57344


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 006/080] Loss = 1.95168, Accurancy = 0.38568


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 007/080] Loss = 1.10373, Accurancy = 0.62719


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 007/080] Loss = 1.71260, Accurancy = 0.42005


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 008/080] Loss = 1.00337, Accurancy = 0.67969


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 008/080] Loss = 1.56278, Accurancy = 0.47969


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 009/080] Loss = 0.89816, Accurancy = 0.70625


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 009/080] Loss = 1.63912, Accurancy = 0.45990


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 010/080] Loss = 0.82362, Accurancy = 0.73531


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 010/080] Loss = 1.91758, Accurancy = 0.40833


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 011/080] Loss = 0.71548, Accurancy = 0.77094


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 011/080] Loss = 1.83045, Accurancy = 0.43021


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 012/080] Loss = 0.62671, Accurancy = 0.80781


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 012/080] Loss = 1.60910, Accurancy = 0.48307


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 013/080] Loss = 0.57008, Accurancy = 0.82812


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 013/080] Loss = 1.78182, Accurancy = 0.45729


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 014/080] Loss = 0.45694, Accurancy = 0.86469


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 014/080] Loss = 1.60204, Accurancy = 0.49062


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 015/080] Loss = 0.36440, Accurancy = 0.90844


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 015/080] Loss = 1.64954, Accurancy = 0.50104


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 016/080] Loss = 0.32655, Accurancy = 0.92125


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 016/080] Loss = 1.71846, Accurancy = 0.50104


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 017/080] Loss = 0.29197, Accurancy = 0.92313


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 017/080] Loss = 2.02611, Accurancy = 0.45417


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 018/080] Loss = 0.25332, Accurancy = 0.93969


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 018/080] Loss = 2.18821, Accurancy = 0.44141


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 019/080] Loss = 0.24435, Accurancy = 0.93000


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 019/080] Loss = 1.87168, Accurancy = 0.49167


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 020/080] Loss = 0.19736, Accurancy = 0.95500


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 020/080] Loss = 2.10321, Accurancy = 0.45130


  0%|          | 0/25 [00:00<?, ?it/s]

[Train | 021/080] Loss = 0.13723, Accurancy = 0.97875


  0%|          | 0/6 [00:00<?, ?it/s]

[ Validation | 021/080] Loss = 2.20400, Accurancy = 0.44922


  0%|          | 0/25 [00:00<?, ?it/s]

KeyboardInterrupt: 

# Test

In [None]:
model.eval()

predictions = []

for batch in tqdm(test_loader):
    imgs, labels = batch

    with torch.no_grad():
        logits = model(img.to(device))

    predictions.append(logits.argmax(dim=-1).cpu().numpy().tolist())

with open('predcit.csv', 'w') as f:
    f.write("Id, Category\n")

    for i, pred in predictions:
        f.write('{}, {}\n'.format(i, pred))