### Image classification 
- [venv] anaconda/py310
- [date] 2024/12/23
- [posi] 

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.transforms.functional import crop
from torch.utils.data import TensorDataset
from torch.optim import Adam
from torchvision import transforms, datasets
from torchvision.models import efficientnet_b2, EfficientNet_B2_Weights


from skorch import NeuralNetClassifier
from skorch.callbacks import Checkpoint
from skorch.helper import predefined_split
from sklearn.model_selection import train_test_split
import numpy as np
from torch.utils.data import DataLoader, Subset

In [None]:
# 1. Data loading and augmentation settings
batch_size = 32          # mini-batch size passed to the skorch net
data_dir = './glasses'   # root directory of the dataset (read with ImageFolder)

In [None]:
# 2. Custom transform: crop a fixed rectangle out of the image
class CustomCrop:
    """Transform that crops the same fixed rectangle out of every image.

    The rectangle is described by its upper-left corner (``top``, ``left``)
    and its size (``height``, ``width``), all forwarded unchanged to
    ``torchvision.transforms.functional.crop``.
    """

    def __init__(self, top, left, height, width):
        self.top = top
        self.left = left
        self.height = height
        self.width = width

    def __call__(self, image):
        """Return ``image`` cropped to the configured rectangle."""
        return crop(image, self.top, self.left, self.height, self.width)

    def __repr__(self):
        # Makes the crop parameters visible when the enclosing Compose /
        # dataset is printed, instead of an opaque object address.
        return (
            f"{type(self).__name__}(top={self.top}, left={self.left}, "
            f"height={self.height}, width={self.width})"
        )


crop_transform = CustomCrop(top=0, left=90, height=1050, width=1500)

In [None]:
# Data augmentation pipeline for training.
# RandomRotation(10) and ColorJitter(brightness=0.2, contrast=0.2) are
# currently disabled.
_train_steps = [
    crop_transform,
    transforms.Resize((300, 300)),
    transforms.RandomHorizontalFlip(),
    # Scale range is collapsed to the single value 0.87; output is 260x260.
    transforms.RandomResizedCrop(260, scale=(0.87, 0.87)),
    transforms.ToTensor(),
    # Per-channel mean/std -- presumably the ImageNet statistics expected by
    # the pretrained backbone; confirm if the backbone changes.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
train_transforms = transforms.Compose(_train_steps)

In [None]:
# Deterministic preprocessing for validation: same fixed crop and
# normalisation as training, resized straight to 260x260 with no random
# augmentation.
_valid_steps = [
    crop_transform,
    transforms.Resize((260, 260)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
valid_transforms = transforms.Compose(_valid_steps)

In [None]:
# Load the whole dataset (no transform attached at this point)
dataset = datasets.ImageFolder(data_dir, transform=None)
# Notebook display: the dataset summary and the class names found on disk
(dataset, dataset.classes)

In [None]:
# Train/validation split over sample indices
_all_indices = np.arange(len(dataset))
train_idx, valid_idx = train_test_split(
    _all_indices,
    test_size=0.1,             # hold out 10% of the samples for validation
    stratify=dataset.targets,  # keep the class ratio in both splits
    shuffle=True,
    random_state=42,
)

In [None]:
# Build the train/validation subsets, each with its own transform.
#
# A Subset only holds a reference to its parent dataset. Wrapping the SAME
# ImageFolder twice and then assigning `.dataset.transform` on each subset
# makes the second assignment overwrite the first, so both splits would end
# up using the validation transform and the training augmentation would never
# run. Each split therefore gets its own ImageFolder instance. ImageFolder
# enumerates the directory in sorted order, so the indices computed on
# `dataset` address the same samples here.
train_subset = Subset(
    datasets.ImageFolder(root=data_dir, transform=train_transforms),
    train_idx,
)
valid_subset = Subset(
    datasets.ImageFolder(root=data_dir, transform=valid_transforms),
    valid_idx,
)

In [None]:
# Sanity check: fetch the first training sample and show the subset size,
# the transformed image shape and its label.
_first_sample = train_subset[0]
image, label = _first_sample
(len(train_subset), image.shape, label)

In [None]:
def _subset_to_arrays(subset):
    """Load every (image, label) pair of ``subset`` once and return NumPy arrays.

    Iterating the subset decodes and transforms each image, so images and
    labels are collected in a single pass instead of one pass per array.

    Returns:
        (images, labels): ``images`` stacks the per-sample tensors along a new
        leading axis; ``labels`` is a 1-D array of the class indices.
    """
    images, labels = [], []
    for image, label in subset:
        images.append(image)
        labels.append(label)
    # torch.stack avoids NumPy walking each tensor element by element.
    return torch.stack(images).numpy(), np.array(labels)


X_train, y_train = _subset_to_arrays(train_subset)
X_valid, y_valid = _subset_to_arrays(valid_subset)

In [None]:
# Wrap the in-memory arrays as PyTorch TensorDatasets (labels as int64)
_train_images = torch.tensor(X_train)
_train_labels = torch.tensor(y_train, dtype=torch.long)
train_dataset = TensorDataset(_train_images, _train_labels)

_valid_images = torch.tensor(X_valid)
_valid_labels = torch.tensor(y_valid, dtype=torch.long)
valid_dataset = TensorDataset(_valid_images, _valid_labels)

In [None]:
# 2. EfficientNet-B2 model definition (final classifier layer replaced)
class EfficientNetB2Model(nn.Module):
    """torchvision EfficientNet-B2 with its last linear layer resized.

    Args:
        pretrained: When True, initialise the backbone from
            ``EfficientNet_B2_Weights.DEFAULT``; when False, build it without
            pretrained weights.
        num_classes: Number of output units of the replacement linear layer.
    """

    def __init__(self, pretrained=True, num_classes=2):
        super().__init__()
        # Honour the `pretrained` flag instead of always loading the weights.
        weights = EfficientNet_B2_Weights.DEFAULT if pretrained else None
        self.base_model = efficientnet_b2(weights=weights)
        # Swap the final linear layer for one with `num_classes` outputs
        # (1408 -> 2 with the defaults).
        head_in_features = self.base_model.classifier[1].in_features
        self.base_model.classifier[1] = nn.Linear(head_in_features, num_classes)

    def forward(self, x):
        """Return the backbone output for ``x`` as raw scores (no softmax applied)."""
        return self.base_model(x)


# Instantiate the model
model = EfficientNetB2Model(pretrained=True, num_classes=2)

In [None]:
# 3. Skorch NeuralNetClassifier definition
_device = 'cuda' if torch.cuda.is_available() else 'cpu'   # use the GPU when one is available
_checkpoint = Checkpoint(f_params='best_params.pt')        # writes module parameters to this file
net = NeuralNetClassifier(
    module=model,
    criterion=nn.CrossEntropyLoss,
    optimizer=Adam,
    optimizer__lr=1e-3,
    max_epochs=1000,
    batch_size=batch_size,
    iterator_train__shuffle=True,
    iterator_valid__shuffle=False,
    # Validate on the pre-built validation set instead of an internal split.
    train_split=predefined_split(valid_dataset),
    callbacks=[_checkpoint],
    device=_device,
)

In [None]:
# Train the net; the labels are already inside train_dataset, hence y=None.
net.fit(X=train_dataset, y=None)

In [None]:
# Restore previously saved weights into `model`.
# The file is expected to have been written with:
#   torch.save(model.state_dict(), "best_params_0001.pt")
# map_location lets a checkpoint saved from CUDA tensors load on a CPU-only
# machine; load_state_dict then copies the values into the existing parameters.
model.load_state_dict(
    torch.load(
        "best_params_0001.pt",
        weights_only=True,
        map_location='cuda' if torch.cuda.is_available() else 'cpu',
    )
)

In [None]:
# Predictions on the validation set, displayed next to the true labels
res = net.predict(X=valid_dataset)
(res, y_valid)

In [None]:
# 5. Model evaluation: score on the validation arrays (assigned to `accuracy`)
accuracy = net.score(X_valid, y_valid)
print(f'Validation Accuracy: {accuracy:.2f}')

In [None]:
# Predictions on the training set, displayed next to the true labels
res = net.predict(X=train_dataset)
(res, y_train)