In [None]:

import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch.utils
from torchvision.transforms import v2
import torchvision.transforms as T
from torchvision import transforms
import torch
from torchvision.datasets import ImageFolder
import koreanize_matplotlib 
from torchmetrics.classification import F1Score, BinaryF1Score, Accuracy
from torchmetrics.classification import BinaryConfusionMatrix
from torchinfo import summary
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim 
from torchvision.models import alexnet, AlexNet_Weights
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
from sklearn.preprocessing import *
from sklearn.model_selection import train_test_split
import os

# Set device to GPU if available
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Seed PyTorch's global RNG (CPU and every CUDA device) so that the new head's
# weight init, the random augmentations and DataLoader shuffling repeat across runs.
# NOTE(review): bit-exact GPU reproducibility additionally needs deterministic cuDNN settings.
SEED = 42
torch.manual_seed(SEED)


In [None]:

# Set Image Data paths (ImageFolder layout: one sub-directory per class)
TRAIN_IMG_DIR = "./data/train/"
# Backward-compatible alias: despite its name, this has always been the TRAINING directory.
TEST_IMG_DIR = TRAIN_IMG_DIR
VAL_IMG_DIR = "./data/Validation/"

# Standard ImageNet channel mean/std, matching the ImageNet-pretrained AlexNet weights.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training transform: includes random augmentation.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),                          # PIL image -> float tensor in [0, 1]
        # NOTE(review): ColorJitter() with default arguments (all 0) is a no-op;
        # pass brightness/contrast/saturation/hue to actually jitter colors.
        transforms.ColorJitter(),
        transforms.RandomResizedCrop(size=(224, 224)),  # random crop, resized to 224x224
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

# Validation transform: deterministic (no random crop), so validation metrics
# reflect the model rather than the augmentation draw of that epoch.
val_transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.CenterCrop(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
)

# Load Datasets
TRAINDS = ImageFolder(root=TRAIN_IMG_DIR, transform=transform)
VALDS = ImageFolder(root=VAL_IMG_DIR, transform=val_transform)


In [None]:

# Inspect the image dataset: class-to-index mapping and label distribution.
print(f"[Image Dataset]\n{TRAINDS}")
print(f'[classes] {TRAINDS.classes}, {TRAINDS.class_to_idx}')

# Show a short preview instead of dumping every target and file path.
N_PREVIEW = 10
print(f'[Targets] {TRAINDS.targets[:N_PREVIEW]} ... ({len(TRAINDS.targets)} total)')
print(f'[imgs] (first {N_PREVIEW})')
for img_path, img_target in TRAINDS.imgs[:N_PREVIEW]:
    print((img_path, img_target))

# Fake/Real Counts. ImageFolder assigns class indices in sorted folder-name order,
# so this assumes the fake folder is index 0 and the real folder index 1 —
# confirm against the [classes] line printed above.
# (A named variable is used instead of `_`, which IPython reserves for the last output.)
FAKE = TRAINDS.targets.count(0)
REAL = TRAINDS.targets.count(1)
print(f"FAKE : {FAKE}개, REAL : {REAL}개")

# Total count of training targets
print(f"Total Training Samples: {len(TRAINDS.targets)}")


In [None]:

# DataLoader with specified batch size (shared by training and validation)
BATCH_SIZE = 128
# Shuffle the training set each epoch. Keep validation in a fixed order: shuffling
# it adds nothing to evaluation and makes per-epoch validation results depend on
# a random batch composition.
TRAINDL = DataLoader(TRAINDS, batch_size=BATCH_SIZE, shuffle=True)
VALDL = DataLoader(VALDS, batch_size=BATCH_SIZE, shuffle=False)


In [None]:

# Start from AlexNet initialised with its default pretrained weights.
model = alexnet(weights=AlexNet_Weights.DEFAULT)

# Swap the last classifier layer (index 6) for a fresh 2-output linear layer
# that keeps the same input width.
num_features = model.classifier[6].in_features
model.classifier[6] = nn.Linear(in_features=num_features, out_features=2)

# Place the model on the device selected in the setup cell (Module.to moves it
# in place and returns the same module).
model = model.to(DEVICE)

# Layer/parameter summary, shown as the cell output.
summary(model)


In [None]:

# Fine-tune the whole network: pretrained backbone as well as the new head.
for param in model.parameters():
    param.requires_grad = True

# Training setup
EPOCH = 10          # maximum number of epochs
LR = 0.001          # Adam learning rate
patience = 7        # early stopping: epochs allowed without a new best validation loss
LR_PATIENCE = 3     # ReduceLROnPlateau patience; kept below `patience` so the LR
                    # can actually be reduced before early stopping ends the run

# The head emits 2 logits; training uses column 1 as the single logit of class 1.
# BCEWithLogitsLoss applies the sigmoid internally and is numerically more stable
# than torch.sigmoid followed by nn.BCELoss (same loss value otherwise).
# NOTE(review): logit column 0 is never trained; a 1-output head or
# nn.CrossEntropyLoss over both logits would be cleaner.
reqLoss = nn.BCEWithLogitsLoss()

optimizer = optim.Adam(model.parameters(), lr=LR)
# The scheduler is stepped with the validation LOSS, so it must use mode="min"
# (mode="max" would cut the LR whenever the loss kept decreasing).
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=LR_PATIENCE)

# Checkpoint locations, overwritten whenever validation F1 reaches a new best.
SAVE_PATH = './models/project/BCF/'
SAVE_FILE = SAVE_PATH + 'model_train_wb.pth'   # state_dict only
SAVE_MODEL = SAVE_PATH + 'model_all.pth'       # whole pickled module

# [train, validation] histories, one entry per completed epoch.
LOSS_HISTORY, SCORE_HISTORY = [[], []], [[], []]


def _run_epoch(net, loader, loss_fn, device, opt=None):
    """Run one pass over `loader`; return (mean batch loss, epoch-level binary F1).

    If `opt` is given the network is put in train mode and updated after every
    batch; otherwise it runs in eval mode with gradients disabled.
    F1 is computed once from counts accumulated over the whole epoch instead of
    averaging per-batch F1 values (which is biased, e.g. by a small last batch).
    """
    training = opt is not None
    net.train(training)
    # Metric state must live on the same device as the tensors it accumulates.
    f1_metric = BinaryF1Score().to(device)
    total_loss = 0.0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = net(inputs)[:, 1]
            loss = loss_fn(logits, labels.float())
            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()
            total_loss += loss.item()
            # Probabilities are thresholded at 0.5 by BinaryF1Score (its default).
            f1_metric.update(torch.sigmoid(logits.detach()), labels)
    mean_loss = total_loss / max(len(loader), 1)  # guard against an empty loader
    return mean_loss, f1_metric.compute().item()


# Training Loop
best_val_score = float("-inf")
best_val_loss = float("inf")
epochs_without_improvement = 0

for epoch in range(EPOCH):
    avg_train_loss, avg_train_score = _run_epoch(model, TRAINDL, reqLoss, DEVICE, optimizer)
    avg_val_loss, avg_val_score = _run_epoch(model, VALDL, reqLoss, DEVICE)

    # Log results
    LOSS_HISTORY[0].append(avg_train_loss)
    LOSS_HISTORY[1].append(avg_val_loss)
    SCORE_HISTORY[0].append(avg_train_score)
    SCORE_HISTORY[1].append(avg_val_score)

    current_lr = optimizer.param_groups[0]["lr"]
    print(f'[{epoch + 1}/{EPOCH}] [TRAIN] LOSS: {avg_train_loss:.4f}, SCORE: {avg_train_score:.4f}')
    print(f'[VALID] LOSS: {avg_val_loss:.4f}, SCORE: {avg_val_score:.4f}, LR: {current_lr:.2e}')

    # Step scheduler on the validation loss.
    scheduler.step(avg_val_loss)

    # Save model if validation score improves. This runs before the early-stopping
    # check so an improvement in the final epoch is not discarded.
    if avg_val_score > best_val_score:
        best_val_score = avg_val_score
        os.makedirs(SAVE_PATH, exist_ok=True)
        torch.save(model.state_dict(), SAVE_FILE)
        # NOTE(review): reloading a whole pickled module needs
        # torch.load(..., weights_only=False) on recent PyTorch releases.
        torch.save(model, SAVE_MODEL)

    # Early stopping on validation loss, tracked independently of the scheduler
    # (the scheduler resets its own counter each time it reduces the LR).
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"Stopping early at epoch {epoch + 1} due to no improvement!")
            break


In [None]:

# Visualization of loss and F1 score
def _plot_train_val(history, title, ylabel, train_label, val_label):
    """Draw train (blue) and validation (orange) curves from a [train, val] history pair.

    The x-axis is each entry's list index, labelled "Epochs".
    """
    train_curve, val_curve = history[0], history[1]
    _fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(train_curve, label=train_label, color='blue', marker='o')
    ax.plot(val_curve, label=val_label, color='orange', marker='o')
    ax.set_title(title)
    ax.set_xlabel("Epochs")
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)
    plt.show()


def plot_loss(LOSS_HISTORY):
    """Plot training vs. validation loss per epoch."""
    _plot_train_val(LOSS_HISTORY, "Train vs Validation Loss", "Loss",
                    'Train Loss', 'Validation Loss')


def plot_f1_score(SCORE_HISTORY):
    """Plot training vs. validation F1 score per epoch."""
    _plot_train_val(SCORE_HISTORY, "Train vs Validation F1 Score", "F1 Score",
                    'Train F1 Score', 'Validation F1 Score')


# Plot results
plot_loss(LOSS_HISTORY)
plot_f1_score(SCORE_HISTORY)
