# 숫자 멀티 라벨 분류기 실습

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import random
from PIL import Image
from tqdm.notebook import tqdm

import torch
import torchvision.transforms as T
import torchvision.models as models

from sklearn.metrics import f1_score

In [None]:
# Fix every RNG seed so that runs are reproducible.
SEED = 42
# NOTE: changing PYTHONHASHSEED at runtime only affects child processes started
# afterwards; this interpreter's str hashing was already randomized at startup.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# cuDNN autotuning (benchmark=True) picks kernels by timing them, which can vary
# between runs and defeats the seeding above; request deterministic kernels instead.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# Pick the best available accelerator: Apple-silicon GPU (MPS), then CUDA, else CPU.
# A hard-coded 'mps' fails on any machine without Metal support.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# CSV 파일 확인

In [None]:
# Load the annotation table and display it.
csv_path = '/Users/kimhongseok/cv_79_projects/part1/chapter3/2/data/annotations.csv'
data_df = pd.read_csv(csv_path)
data_df

In [None]:
# Peek at the image paths listed in the annotation table.
data_df.loc[:, 'filepath'].to_list()

In [None]:
# Sanity-check the label-string parsing on a single row (row 4, column 2 —
# presumably the `classes` column; verify against the CSV header).
raw_classes = data_df.iloc[4, 2]
tmp = list(map(int, raw_classes.replace('[', '').replace(']', '').split(',')))
tmp

# CustomDataset

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Multi-label image dataset built from an annotations DataFrame.

    Each row of ``data_df`` supplies a ``filepath`` (joined onto ``root_dir``)
    and a ``classes`` string such as ``"[1, 3]"``. Items are returned as
    ``(transformed_image, target)`` where ``target`` is a float multi-hot
    vector of length ``len(classes)``.

    Args:
        root_dir: Directory the ``filepath`` entries are joined onto.
        classes: List of class labels; its length sets the target size and
            each parsed label value is used directly as a target index.
        data_df: DataFrame with ``filepath`` and ``classes`` columns.
        transforms: Callable applied to the loaded RGB PIL image.
    """

    def __init__(self, root_dir, classes, data_df, transforms):
        super().__init__()
        self.transforms = transforms
        self.classes = classes
        # (absolute-or-joined image path, list of int labels) per row.
        self.data = [
            (os.path.join(root_dir, path), self._parse_classes(raw))
            for path, raw in zip(data_df['filepath'], data_df['classes'])
        ]

    @staticmethod
    def _parse_classes(raw):
        """Parse a label string like ``"[1, 3]"`` into ``[1, 3]``; ``"[]"`` gives ``[]``."""
        body = raw.replace('[', '').replace(']', '').strip()
        if not body:
            return []
        return [int(x) for x in body.split(',')]

    def _encode_target(self, labels):
        """Return a float multi-hot vector of length ``len(self.classes)``.

        Duplicate labels (e.g. an image containing two 7s) still produce 1.0,
        keeping targets in {0, 1} as BCEWithLogitsLoss and multilabel F1 expect.
        Summing one-hot rows would instead yield 2.0 for a repeated label.
        """
        target = torch.zeros(len(self.classes), dtype=torch.float)
        if labels:
            target[torch.tensor(labels, dtype=torch.long)] = 1.0
        return target

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        path, labels = self.data[idx]
        # The context manager closes the file; convert() returns a loaded copy.
        with Image.open(path) as src:
            img = src.convert('RGB')
        img = self.transforms(img)
        return img, self._encode_target(labels)

In [None]:
# Resize every image to (height=112, width=224), then convert it to a tensor.
transforms = T.Compose([T.Resize((112, 224)), T.ToTensor()])

# Ten digit classes, 0 through 9.
classes = list(range(10))

total_dataset = CustomDataset(
    '/Users/kimhongseok/cv_79_projects/part1/chapter3/2/data',
    classes,
    data_df,
    transforms,
)

In [None]:
# Show the first sample with its multi-hot target vector as the title.
first_img, first_target = total_dataset[0]
plt.figure(figsize=(3, 3))
plt.imshow(first_img.permute(1, 2, 0))  # CHW -> HWC for matplotlib
plt.title(first_target)
plt.show()

In [None]:
# 80 / 10 / 10 train / valid / test split.
total_num = len(total_dataset)
train_num = int(total_num * 0.8)
valid_num = int(total_num * 0.1)
# The test split takes the remainder so the three lengths always sum to total_num;
# truncating all three with int() can drop samples, and random_split then raises.
test_num = total_num - train_num - valid_num

train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
    total_dataset, [train_num, valid_num, test_num]
)

In [None]:
# Batches of 20; only the training split is reshuffled every epoch.
train_dataloader, valid_dataloader, test_dataloader = (
    torch.utils.data.DataLoader(split, batch_size=20, shuffle=reshuffle)
    for split, reshuffle in (
        (train_dataset, True),
        (valid_dataset, False),
        (test_dataset, False),
    )
)

# train

In [None]:
def training(model, train_dataloader, train_dataset, criterion, optimizer, threshold, epoch, num_epochs):
    """Run one training epoch and report loss and micro-F1.

    Batches are moved to the module-level ``device``.

    Args:
        model: Network producing one logit per class.
        train_dataloader: Yields ``(images, labels)`` batches.
        train_dataset: Unused; kept for signature compatibility with callers.
        criterion: Loss on ``(logits, labels)``; assumed to return the batch
            mean (default ``reduction='mean'``) — the epoch loss reweights it
            by batch size.
        optimizer: Optimizer stepping the trainable parameters.
        threshold: Sigmoid probability above which a class counts as predicted.
        epoch: Zero-based epoch index (for the progress-bar label).
        num_epochs: Total epochs (for the progress-bar label).

    Returns:
        ``(model, train_loss, train_f1)`` where ``train_loss`` is the per-sample
        mean loss over the epoch and ``train_f1`` is micro-averaged F1.
    """
    model.train()
    running_loss = 0.0
    n_seen = 0
    total_preds = []
    total_labels = []

    # The label is constant for the whole epoch, so set it once.
    tbar = tqdm(train_dataloader, desc=f'Epoch/Epochs [{epoch+1}/{num_epochs}]')
    for images, labels in tbar:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so a short last batch does not skew the epoch mean.
        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        n_seen += batch_size

        # Predictions need no autograd graph.
        preds = (torch.sigmoid(outputs.detach()) > threshold).float()
        total_labels.extend(labels.cpu().numpy())
        total_preds.extend(preds.cpu().numpy())

    train_loss = running_loss / n_seen
    train_f1 = f1_score(total_labels, total_preds, average='micro')

    return model, train_loss, train_f1

def evalutation(model, valid_dataloader, valid_dataset, criterion, threshold, epoch, num_epochs):
    """Evaluate ``model`` on a dataloader without updating weights.

    The (misspelled) name is kept because ``training_loop`` calls it by it.
    Batches are moved to the module-level ``device``.

    Args:
        model: Network producing one logit per class.
        valid_dataloader: Yields ``(images, labels)`` batches.
        valid_dataset: Unused; kept for signature compatibility with callers.
        criterion: Loss on ``(logits, labels)``; assumed to return the batch
            mean (default ``reduction='mean'``).
        threshold: Sigmoid probability above which a class counts as predicted.
        epoch: Zero-based epoch index (for the progress-bar label).
        num_epochs: Total epochs (for the progress-bar label).

    Returns:
        ``(model, valid_loss, valid_f1)`` where ``valid_loss`` is the per-sample
        mean loss and ``valid_f1`` is micro-averaged F1.
    """
    model.eval()
    running_loss = 0.0
    n_seen = 0
    total_labels = []
    total_preds = []

    with torch.no_grad():
        tbar = tqdm(valid_dataloader, desc=f'Epoch/Epochs [{epoch+1}/{num_epochs}]')
        for images, labels in tbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            preds = (torch.sigmoid(outputs) > threshold).float()

            # Size-weighted so the last, possibly smaller, batch counts correctly.
            batch_size = images.size(0)
            running_loss += loss.item() * batch_size
            n_seen += batch_size

            total_labels.extend(labels.cpu().numpy())
            total_preds.extend(preds.cpu().numpy())

    valid_loss = running_loss / n_seen
    valid_f1 = f1_score(total_labels, total_preds, average='micro')

    return model, valid_loss, valid_f1

def training_loop(model, train_dataloader, train_dataset, valid_dataloader, valid_dataset, criterion, optimizer, threshold, num_epochs):
    """Alternate one training and one validation pass per epoch.

    Moves ``model`` to the module-level ``device`` first, prints a one-line
    summary each epoch, and returns the model plus the per-epoch histories
    ``(train_loss, train_f1, valid_loss, valid_f1)`` as four lists.
    """
    model.to(device)
    history = {'train_loss': [], 'train_f1': [], 'valid_loss': [], 'valid_f1': []}

    for epoch in range(num_epochs):
        model, train_loss, train_f1 = training(model, train_dataloader, train_dataset, criterion, optimizer, threshold, epoch, num_epochs)
        model, valid_loss, valid_f1 = evalutation(model, valid_dataloader, valid_dataset, criterion, threshold, epoch, num_epochs)

        epoch_metrics = (train_loss, train_f1, valid_loss, valid_f1)
        for key, value in zip(history, epoch_metrics):
            history[key].append(value)

        print(f'Train Loss: {train_loss}, Train F1: {train_f1}, Valid Loss: {valid_loss}, Valid F1: {valid_f1}')

    return (
        model,
        history['train_loss'],
        history['train_f1'],
        history['valid_loss'],
        history['valid_f1'],
    )

# model

In [None]:
# ImageNet-pretrained ResNet-50. `weights=` replaces the deprecated `pretrained=True`
# (torchvision >= 0.13); IMAGENET1K_V1 is the checkpoint `pretrained=True` loaded.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
model

In [None]:
# Freeze every pretrained weight; only the new head (next cell) will be trained.
for weight in model.parameters():
    weight.requires_grad_(False)

In [None]:
# Replace the ImageNet classification head with one logit per class.
# Reading in_features (2048 for ResNet-50) keeps this right for other ResNet depths,
# and len(classes) ties the output size to the dataset's class list.
model.fc = torch.nn.Linear(model.fc.in_features, len(classes))

# A freshly built Linear is already trainable; set explicitly for clarity.
for param in model.fc.parameters():
    param.requires_grad = True

# 학습

In [None]:
# Independent sigmoid + binary cross-entropy per class (multi-label setup).
criterion = torch.nn.BCEWithLogitsLoss()

# Hand Adam only the trainable (head) parameters. Frozen ones never receive a
# gradient and Adam skips parameters whose grad is None, so updates are unchanged.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=0.001)

model, train_loss_list, train_f1_list, valid_loss_list, valid_f1_list = training_loop(
    model,
    train_dataloader,
    train_dataset,
    valid_dataloader,
    valid_dataset,
    criterion,
    optimizer,
    threshold=0.5,
    num_epochs=10,
)

# test

In [None]:
# Collect thresholded predictions and ground-truth labels over the test split.
model.eval()
total_preds, total_labels = [], []

with torch.no_grad():
    tbar = tqdm(test_dataloader)
    for images, labels in tbar:
        outputs = model(images.to(device))
        # Same 0.5 threshold used for training/validation.
        preds = (torch.sigmoid(outputs) > 0.5).int()

        total_labels.extend(labels.cpu().numpy())
        total_preds.extend(preds.cpu().numpy())

In [None]:
# sklearn's signature is f1_score(y_true, y_pred). Micro-F1 happens to be symmetric,
# but the swapped order would be wrong for precision/recall or other averages.
f1_score(total_labels, total_preds, average='micro')

In [None]:
# Show up to 20 test images titled with their true and predicted class indices.
# test_dataloader uses shuffle=False, so total_labels/total_preds line up with
# test_dataset indices.
n_show = min(20, len(test_dataset))  # avoid IndexError on a small test split
plt.figure(figsize=(20, 20))

for i in range(n_show):
    plt.subplot(4, 5, i + 1)
    plt.imshow(test_dataset[i][0].permute(1, 2, 0))  # CHW -> HWC
    real = [j for j, flag in enumerate(total_labels[i]) if flag == 1]
    preds = [j for j, flag in enumerate(total_preds[i]) if flag == 1]
    plt.title(f'True: {real}\nPred: {preds}')

plt.show()