<a href="https://colab.research.google.com/github/holoho/CP/blob/main/cp2_project_car_color.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive; every dataset, checkpoint and plot path used below lives under /content/drive.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
# Change the working directory to the training-data folder on Drive.
%cd /content/drive/MyDrive/AI/CP2/cartrain

In [None]:
#!unzip -qq -O cp949 "/content/drive/MyDrive/AI/CP2/cartrain/red.zip" -d  "/content/drive/MyDrive/AI/CP2/cartrain/red"
!unzip -O cp949 "/content/drive/MyDrive/AI/CP2/cartrain/black.zip" -d  "/content/drive/MyDrive/AI/CP2/cartrain/black"
!unzip -O cp949 "/content/drive/MyDrive/AI/CP2/cartrain/blue.zip" -d  "/content/drive/MyDrive/AI/CP2/cartrain/blue"
!unzip -O cp949 "/content/drive/MyDrive/AI/CP2/cartrain/grey.zip" -d  "/content/drive/MyDrive/AI/CP2/cartrain/grey"
!unzip -O cp949 "/content/drive/MyDrive/AI/CP2/cartrain/white.zip" -d  "/content/drive/MyDrive/AI/CP2/cartrain/white"

In [None]:
%cd /content/drive/MyDrive/AI/CP2/carvalid
!unzip -qq -O cp949 "/content/drive/MyDrive/AI/CP2/carvalid/red.zip" -d  "/content/drive/MyDrive/AI/CP2/carvalid/red"
!unzip -qq -O cp949 "/content/drive/MyDrive/AI/CP2/carvalid/black.zip" -d  "/content/drive/MyDrive/AI/CP2/carvalid/black"
!unzip -qq -O cp949 "/content/drive/MyDrive/AI/CP2/carvalid/blue.zip" -d  "/content/drive/MyDrive/AI/CP2/carvalid/blue"
!unzip -qq -O cp949 "/content/drive/MyDrive/AI/CP2/carvalid/grey.zip" -d  "/content/drive/MyDrive/AI/CP2/carvalid/grey"
!unzip -qq -O cp949 "/content/drive/MyDrive/AI/CP2/carvalid/white.zip" -d  "/content/drive/MyDrive/AI/CP2/carvalid/white"

In [None]:
from tqdm.auto import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch
import argparse
import torch.nn as nn
import torch.optim as optim
import time
import matplotlib
import matplotlib.pyplot as plt
# Use matplotlib's 'ggplot' style sheet for the figures produced in this notebook.
plt.style.use('ggplot')

In [None]:
TRAIN_DIR = '/content/drive/MyDrive/AI/CP2/cartrain/'
VALID_DIR = '/content/drive/MyDrive/AI/CP2/carvalid/'
IMAGE_SIZE = 224  # images are resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 32
NUM_WORKERS = 4
# Computation device: fall back to the CPU when CUDA is unavailable
# ('gpu' is not a valid torch device string and would fail in .to(device)).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Computation device: {device}\n")

In [None]:
# train transform
def get_train_transform(IMAGE_SIZE):
    """Build the augmentation + normalisation pipeline applied to training images.

    :param IMAGE_SIZE: side length; every image is resized to IMAGE_SIZE x IMAGE_SIZE.
    Returns a ``transforms.Compose`` that turns a PIL image into a normalised tensor.
    """
    # Per-channel statistics used by Normalize (the common ImageNet values).
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    # Augmentations operate on the PIL image, so they all precede ToTensor().
    # RandomGrayscale is deliberately absent: removing colour would erase the
    # very attribute (car colour) this model classifies.
    steps = [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(30),
        transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.5),
        transforms.RandomPerspective(distortion_scale=0.5, p=0.5),
        transforms.RandomPosterize(bits=2, p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ]
    return transforms.Compose(steps)

# Validation transforms
def get_valid_transform(IMAGE_SIZE):
    """Deterministic preprocessing for validation images (no augmentation).

    :param IMAGE_SIZE: side length; images are resized to IMAGE_SIZE x IMAGE_SIZE.
    Returns a ``transforms.Compose`` of resize -> tensor -> per-channel normalisation.
    """
    resize = transforms.Resize((IMAGE_SIZE, IMAGE_SIZE))
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    return transforms.Compose([resize, transforms.ToTensor(), normalize])

In [None]:
def get_datasets():
    """Load the training and validation splits with ``datasets.ImageFolder``.

    Class labels come from the sub-folder names of TRAIN_DIR and VALID_DIR.
    Returns ``(dataset_train, dataset_valid, class_names)``.
    Raises ValueError if the two splits do not expose the same class folders:
    each ImageFolder numbers its own folders, so a missing or extra folder in
    one split would silently shift that split's label indices.
    """
    dataset_train = datasets.ImageFolder(
        TRAIN_DIR,
        transform=get_train_transform(IMAGE_SIZE),
    )
    dataset_valid = datasets.ImageFolder(
        VALID_DIR,
        transform=get_valid_transform(IMAGE_SIZE),
    )
    if dataset_train.classes != dataset_valid.classes:
        raise ValueError(
            f"Training classes {dataset_train.classes} differ from "
            f"validation classes {dataset_valid.classes}"
        )
    return dataset_train, dataset_valid, dataset_train.classes

def get_data_loaders(dataset_train, dataset_valid):
    """
    Wrap the training and validation datasets in DataLoaders.
    :param dataset_train: training dataset; reshuffled every epoch.
    :param dataset_valid: validation dataset; iterated in a fixed order.
    Returns ``(train_loader, valid_loader)``, both using BATCH_SIZE and NUM_WORKERS.
    """
    shared = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    train_loader = DataLoader(dataset_train, shuffle=True, **shared)
    valid_loader = DataLoader(dataset_valid, shuffle=False, **shared)
    return train_loader, valid_loader

In [None]:
import torch.nn as nn
from torchvision import models

def build_model(pretrained=True, fine_tune=True, num_classes=5):
    """Create an EfficientNet-B0 whose final classifier layer outputs ``num_classes`` logits.

    :param pretrained: load torchvision's pretrained weights for the network.
    :param fine_tune: if True every parameter is trainable; if False the
        pretrained layers are frozen and only the new classifier layer trains.
    :param num_classes: number of outputs of the replacement classifier layer.
    Returns the model (not moved to any device).
    """
    if pretrained:
        print('[INFO]: Loading pre-trained weights')
    else:
        print('[INFO]: Not loading pre-trained weights')

    # torchvision >= 0.13 replaced the `pretrained=` flag with the `weights=` enum
    # (DEFAULT is the same ImageNet checkpoint pretrained=True loaded); fall back
    # to the legacy flag on older torchvision releases.
    weights_enum = getattr(models, 'EfficientNet_B0_Weights', None)
    if weights_enum is not None:
        model = models.efficientnet_b0(weights=weights_enum.DEFAULT if pretrained else None)
    else:
        model = models.efficientnet_b0(pretrained=pretrained)

    if fine_tune:
        print('[INFO]: Fine-tuning all layers...')
    else:
        print('[INFO]: Freezing hidden layers...')
    for params in model.parameters():
        params.requires_grad = bool(fine_tune)

    # Replace the last classification layer. It is created after the freezing
    # loop above, so it stays trainable even when fine_tune is False. Its input
    # width is read from the existing layer instead of hard-coding 1280.
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features=in_features, out_features=num_classes)
    return model

In [None]:
# Seed the Python, NumPy and PyTorch RNGs so training runs are repeatable.
import random
import numpy as np

seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Request deterministic cuDNN kernels. Benchmark mode must be OFF: when on,
# cuDNN auto-tunes and may choose a different algorithm from run to run,
# which defeats `deterministic = True`.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
def train(model, trainloader, optimizer, criterion):
    """Run one training epoch over ``trainloader``.

    Each batch is moved to the module-level ``device`` and followed by one
    optimizer step.
    Returns ``(epoch_loss, epoch_acc)``: the per-sample mean of ``criterion`` and
    the percentage of correct top-1 predictions over all samples seen.
    Raises ValueError if the loader yields no samples.
    """
    model.train()
    print('Training')
    running_loss = 0.0
    running_correct = 0
    seen = 0
    for image, labels in tqdm(trainloader, total=len(trainloader)):
        image = image.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        # Forward pass.
        outputs = model(image)
        loss = criterion(outputs, labels)
        # Backpropagation and weight update.
        loss.backward()
        optimizer.step()
        # Weight each batch by its size (assumes criterion uses reduction='mean',
        # the nn.CrossEntropyLoss default) so a smaller final batch does not skew
        # the epoch mean the way averaging batch means would.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        running_correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += batch_size

    if seen == 0:
        raise ValueError('trainloader yielded no samples')
    epoch_loss = running_loss / seen
    epoch_acc = 100. * (running_correct / seen)
    return epoch_loss, epoch_acc

In [None]:
# Validation function.
def validate(model, testloader, criterion, class_names):
    """Evaluate ``model`` on ``testloader`` without updating any weights.

    :param class_names: accepted for call compatibility; not used here.
    Returns ``(epoch_loss, epoch_acc)``: the per-sample mean of ``criterion`` and
    the percentage of correct top-1 predictions over all samples seen.
    Raises ValueError if the loader yields no samples.
    """
    model.eval()
    print('Validation')
    running_loss = 0.0
    running_correct = 0
    seen = 0

    with torch.no_grad():
        for image, labels in tqdm(testloader, total=len(testloader)):
            image = image.to(device)
            labels = labels.to(device)
            # Forward pass.
            outputs = model(image)
            loss = criterion(outputs, labels)
            # Weight each batch by its size (assumes reduction='mean', the
            # nn.CrossEntropyLoss default) so a short final batch is not over-weighted.
            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            running_correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size

    if seen == 0:
        raise ValueError('testloader yielded no samples')
    epoch_loss = running_loss / seen
    epoch_acc = 100. * (running_correct / seen)
    return epoch_loss, epoch_acc


In [None]:
def save_model(epochs, model, optimizer, criterion):
    """Write the final training state to ``model.pth`` on Drive.

    The checkpoint dict holds the epoch count, the model and optimizer state
    dicts, and the criterion object itself under the ``'loss'`` key.
    """
    checkpoint = {
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': criterion,
    }
    torch.save(checkpoint, "/content/drive/MyDrive/AI/CP2/model.pth")

def save_plots(train_acc, valid_acc, train_loss, valid_loss,
               out_dir='/content/drive/MyDrive/AI/CP2'):
    """Plot per-epoch accuracy and loss curves and save them as PNG files.

    :param train_acc: per-epoch training accuracy (percent).
    :param valid_acc: per-epoch validation accuracy (percent).
    :param train_loss: per-epoch training loss.
    :param valid_loss: per-epoch validation loss.
    :param out_dir: directory that receives ``accuracy.png`` and ``loss.png``.
    The figures are left open so the notebook still renders them inline.
    """
    # Accuracy plot.
    fig_acc, ax_acc = plt.subplots(figsize=(10, 7))
    ax_acc.plot(train_acc, color='red', linestyle='-', label='train accuracy')
    ax_acc.plot(valid_acc, color='blue', linestyle='-', label='validation accuracy')
    ax_acc.set(title='Accuracy per epoch', xlabel='Epochs', ylabel='Accuracy')
    ax_acc.legend()
    fig_acc.savefig(f"{out_dir}/accuracy.png")

    # Loss plot.
    fig_loss, ax_loss = plt.subplots(figsize=(10, 7))
    ax_loss.plot(train_loss, color='orange', linestyle='-', label='train loss')
    ax_loss.plot(valid_loss, color='green', linestyle='-', label='validation loss')
    ax_loss.set(title='Loss per epoch', xlabel='Epochs', ylabel='Loss')
    ax_loss.legend()
    fig_loss.savefig(f"{out_dir}/loss.png")
# Re-apply the 'ggplot' style sheet for the figures created after this cell.
plt.style.use('ggplot')
class SaveBestModel:
    """Checkpoint the model whenever the validation loss reaches a new minimum.

    Call the instance once per epoch with that epoch's validation loss. It
    overwrites ``save_path`` only when the loss is strictly lower than the best
    value seen so far.
    """

    def __init__(
        self, best_valid_loss=float('inf'),
        save_path='/content/drive/MyDrive/AI/CP2/best_model.pth'
    ):
        # Lowest validation loss observed so far.
        self.best_valid_loss = best_valid_loss
        # Destination file of the best checkpoint.
        self.save_path = save_path

    def __call__(
        self, current_valid_loss,
        epoch, model, optimizer, criterion
    ):
        """Save a checkpoint if ``current_valid_loss`` beats the best loss so far.

        :param current_valid_loss: validation loss of the epoch just finished.
        :param epoch: zero-based epoch index; stored 1-based in the checkpoint.
        :param model: model whose ``state_dict`` is saved.
        :param optimizer: optimizer whose ``state_dict`` is saved.
        :param criterion: only its class name is recorded.
        """
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"\nBest validation loss: {self.best_valid_loss}")
            print(f"\nSaving best model for epoch: {epoch+1}\n")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Store plain values rather than the criterion module: a pickled
                # nn.Module cannot be read by torch.load(weights_only=True),
                # which is the default since PyTorch 2.6.
                'loss': current_valid_loss,
                'criterion': type(criterion).__name__,
            }, self.save_path)

In [None]:
# Build both ImageFolder splits, report their sizes, then wrap them in loaders.
dataset_train, dataset_valid, dataset_classes = get_datasets()
print(
    f"[INFO]: Number of training images: {len(dataset_train)}\n"
    f"[INFO]: Number of validation images: {len(dataset_valid)}"
)
train_loader, valid_loader = get_data_loaders(dataset_train, dataset_valid)

In [None]:
# Show the class names found by ImageFolder; a name's list position is its integer label.
dataset_classes

In [None]:
# Free memory before training: collect unreferenced Python objects first, then
# release cached-but-unoccupied CUDA memory held by PyTorch's caching allocator.
import torch, gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
lr = 0.001
epochs = 50
# Fall back to the CPU when CUDA is unavailable ('gpu' is not a valid torch device).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Computation device: {device}")
print(f"Learning rate: {lr}")
print(f"Epochs to train for: {epochs}\n")

model = build_model(
    pretrained=True,
    fine_tune=True,
    num_classes=len(dataset_classes)
).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

# Optimizer.
optimizer = optim.Adam(model.parameters(), lr=lr)
# Loss function.
criterion = nn.CrossEntropyLoss()

# Tracks the lowest validation loss and checkpoints the model when it improves.
save_best_model = SaveBestModel()

# Per-epoch history used for the plots.
train_loss, valid_loss = [], []
train_acc, valid_acc = [], []

for epoch in range(epochs):
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")
    train_epoch_loss, train_epoch_acc = train(model, train_loader,
                                              optimizer, criterion)
    valid_epoch_loss, valid_epoch_acc = validate(model, valid_loader,
                                                 criterion, dataset_classes)
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)
    train_acc.append(train_epoch_acc)
    valid_acc.append(valid_epoch_acc)
    print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
    print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")
    # Save a checkpoint if this epoch reached the lowest validation loss so far.
    save_best_model(valid_epoch_loss, epoch, model, optimizer, criterion)
    print('-'*50)

save_model(epochs, model, optimizer, criterion)
save_plots(train_acc, valid_acc, train_loss, valid_loss)
print('TRAINING COMPLETE')

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd
import numpy as np



# Evaluate the best checkpoint on the validation set and plot a confusion matrix.
# pretrained=False: every weight is overwritten by the checkpoint below, so
# downloading pretrained weights first would be wasted work.
model = build_model(
    pretrained=False,
    fine_tune=True,
    num_classes=len(dataset_classes)
).to(device)

best_model_path = '/content/drive/MyDrive/AI/CP2/best_model.pth'
# map_location lets a GPU-saved checkpoint load on a CPU runtime. The file is
# written by this notebook (trusted) and may contain a pickled criterion, which
# the weights_only=True default of newer PyTorch (>= 2.6) refuses to load.
try:
    checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
except TypeError:  # torch < 1.13 has no `weights_only` argument
    checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# Inference mode: disables dropout and makes BatchNorm use its stored statistics.
model.eval()

y_pred = []
y_true = []
with torch.no_grad():
    for inputs, labels in valid_loader:
        outputs = model(inputs.to(device))
        # Predicted class = argmax of the logits (exp() would not change the winner).
        y_pred.extend(outputs.argmax(dim=1).cpu().numpy())
        y_true.extend(labels.numpy())

# Axis labels must follow the label indices ImageFolder assigned, so take them
# from the dataset rather than a hand-written tuple (which was in the wrong order).
classes = list(dataset_classes)

# Fix the label set so the matrix is always len(classes) x len(classes), even
# if some class never appears in y_true or y_pred.
cf_matrix = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
# Row-normalise: each cell is the fraction of that true class predicted as the
# column class (rows with no samples stay 0).
row_totals = cf_matrix.sum(axis=1, keepdims=True)
cf_norm = np.divide(cf_matrix, row_totals,
                    out=np.zeros(cf_matrix.shape, dtype=float),
                    where=row_totals != 0)
df_cm = pd.DataFrame(cf_norm, index=classes, columns=classes)
plt.figure(figsize=(12, 7))
ax = sn.heatmap(df_cm, annot=True)
ax.set(xlabel='Predicted', ylabel='True',
       title='Validation confusion matrix (row-normalised)')
# Absolute path, like the other saved artifacts, instead of the current working directory.
plt.savefig('/content/drive/MyDrive/AI/CP2/output.png')