In [1]:
import torch
import torch.nn as nn
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torchvision.models import resnet50
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import copy
import warnings
# Silence every warning category for the rest of the session
# (note: this also hides library deprecation notices).
warnings.filterwarnings(action='ignore')

# Hyperparameter configuration shared by the data pipeline and the training loop.
CFG = dict(
    IMG_SIZE=224,        # side length of the square images fed to the network
    EPOCHS=10000,        # number of training epochs
    LEARNING_RATE=1e-6,  # Adam learning rate
    BATCH_SIZE=256,      # samples per mini-batch (train and validation loaders)
    SEED=41,             # seed for torch's RNG and the train/val split
)

# Seed torch before any model/data randomness is drawn.
torch.manual_seed(CFG['SEED'])

# Prefer the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Per-channel ImageNet mean / std used to normalize inputs in both pipelines.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Transforms for data augmentation (training only): random flip, small rotation,
# scale-jittered crop and color jitter, then tensor conversion + normalization.
train_transforms = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(CFG['IMG_SIZE'], scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

# Deterministic evaluation pipeline: fixed square resize, no augmentation.
_eval_side = CFG['IMG_SIZE']
val_transforms = transforms.Compose([
    transforms.Resize((_eval_side, _eval_side)),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

# Image root laid out as one sub-folder per class (ImageFolder convention).
DATA_ROOT = './INSECT_CLASSIFICATION/FINAL/'

# Two views of the same image folder: one with augmentation for training and one
# with the deterministic pipeline for validation. A single shared ImageFolder
# cannot serve both splits: Subset delegates to the wrapped dataset, so setting
# `val_dataset.dataset.transform` would also replace the *training* transform
# and silently disable augmentation.
dataset = ImageFolder(root=DATA_ROOT, transform=train_transforms)
val_source = ImageFolder(root=DATA_ROOT, transform=val_transforms)
assert val_source.classes == dataset.classes and val_source.targets == dataset.targets, \
    "train/val ImageFolder views disagree on class or sample order"
class_names = dataset.classes

# Stratified 80/20 split over sample indices so each class keeps its proportion.
train_idx, val_idx = train_test_split(
    range(len(dataset)), test_size=0.2, stratify=dataset.targets, random_state=CFG['SEED']
)

train_dataset = Subset(dataset, train_idx)   # augmented view
val_dataset = Subset(val_source, val_idx)    # deterministic view

train_loader = DataLoader(train_dataset, batch_size=CFG['BATCH_SIZE'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=CFG['BATCH_SIZE'])

# Per-class sample counts over the full dataset, read directly from the
# ImageFolder labels so no image is decoded or transformed just to be counted.
# (The split above is stratified, so the training subset has matching proportions.)
num_classes = len(class_names)
class_counts = torch.bincount(
    torch.as_tensor(dataset.targets, dtype=torch.long), minlength=num_classes
).tolist()

missing = [name for name, count in zip(class_names, class_counts) if count == 0]
if missing:
    raise ValueError(
        f"No samples found for classes {missing}; cannot compute inverse-frequency weights."
    )

# Inverse-frequency weights, scaled so the most frequent class gets weight 1.0.
max_count = max(class_counts)
class_weights = [max_count / count for count in class_counts]
class_weights = torch.tensor(class_weights).to(device)

In [2]:
# Pretrained ResNet-50 whose final fully-connected layer is replaced so it emits
# one logit per dataset class (derived from the data, not hard-coded).
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=ResNet50_Weights.DEFAULT`; kept here for compatibility with
# the installed version — confirm before switching.
num_outputs = len(class_names)
model = resnet50(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, num_outputs)
model = model.to(device)

# Weighted cross-entropy needs exactly one weight per output logit.
assert class_weights.numel() == num_outputs, \
    f"{class_weights.numel()} class weights for {num_outputs} classes"
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=CFG['LEARNING_RATE'])

# Training Function
def train(model, optimizer, train_loader, scheduler, device,
          save_path='data/save_data/saved/best_model3.pth'):
    """Train `model` for CFG["EPOCHS"] epochs, checkpointing on best validation accuracy.

    Each epoch runs one pass over `train_loader`, optionally steps `scheduler`,
    then evaluates on the notebook-level `val_loader`. Whenever validation
    accuracy strictly improves on the best seen so far, that epoch's train/val
    statistics are logged and the model's `state_dict` is saved to `save_path`.

    Args:
        model: network to train; moved to `device`.
        optimizer: optimizer over `model`'s parameters.
        train_loader: iterable of (image batch, label batch) pairs for training.
        scheduler: LR scheduler stepped once per epoch (called without a metric),
            or None to disable scheduling.
        device: device that the model and every batch are moved to.
        save_path: checkpoint file for the best model; its parent directory is
            created if it does not exist yet.

    Returns:
        Best validation accuracy in percent (``-inf`` if no epoch ran).

    Note:
        Reads the notebook globals `CFG`, `criterion` and `val_loader`.
    """
    import os

    model.to(device)
    best_acc = float('-inf')  # guarantees the first epoch is checkpointed

    # torch.save does not create missing directories; do it once up front.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for epoch in tqdm(range(1, CFG["EPOCHS"] + 1)):
        # ---- training pass ----
        model.train()
        running_loss = 0.0
        for img, label in train_loader:
            img, label = img.to(device), label.to(device)
            optimizer.zero_grad()

            logit = model(img)
            loss = criterion(logit, label)

            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        if scheduler is not None:
            scheduler.step()

        # ---- validation pass ----
        model.eval()
        vali_loss = 0.0
        correct = 0
        with torch.no_grad():  # no autograd graph is needed for evaluation
            for img, label in val_loader:
                img, label = img.to(device), label.to(device)

                logit = model(img)
                # .item() keeps the running sum a Python float, not a device tensor
                vali_loss += criterion(logit, label).item()
                pred = logit.argmax(dim=1)
                correct += (pred == label).sum().item()

        num_val = len(val_loader.dataset)
        vali_acc = 100 * correct / num_val

        if best_acc < vali_acc:
            best_acc = vali_acc
            # tqdm.write prints without corrupting the epoch progress bar
            tqdm.write('[%d] Train loss: %.10f' % (epoch, running_loss / len(train_loader)))
            tqdm.write('Val set: Loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
                vali_loss / len(val_loader), correct, num_val, vali_acc))
            torch.save(model.state_dict(), save_path)
            tqdm.write('Model Saved.')

    return best_acc

# Initiate training with a constant learning rate (no LR scheduler).
train(model, optimizer, train_loader, scheduler=None, device=device)

  0%|          | 0/10000 [00:00<?, ?it/s]