In [1]:
import torch
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from transformers import AutoModelForImageClassification, AutoImageProcessor
from tqdm import tqdm

In [2]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.mps.manual_seed(seed)

seed_everything(42) # Seed 고정

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")

NameError: name 'random' is not defined

In [None]:
# 1. 데이터 불러오기
df = pd.read_csv('./train.csv')
train_df, valid_df = train_test_split(df, test_size=0.1, random_state=42)

In [None]:
# 2. 레이블 인코딩
labels = df['label'].unique()
label_to_idx = {label: idx for idx, label in enumerate(labels)}
idx_to_label = {idx: label for label, idx in label_to_idx.items()}

In [None]:
# 3. 데이터셋 클래스 정의
class BirdDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_path = self.dataframe.iloc[idx]['img_path']
        image = Image.open(img_path).convert('RGB')
        label = self.dataframe.iloc[idx]['label']
        label_idx = label_to_idx[label]

        if self.transform:
            image = self.transform(image)

        return image, label_idx

In [None]:
# 4. 데이터 변환 정의
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

valid_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


In [None]:
# 5. 데이터셋 및 데이터로더 생성
train_dataset = BirdDataset(train_df, transform=train_transform)
valid_dataset = BirdDataset(valid_df, transform=valid_transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=32, shuffle=False)

In [None]:
from transformers import AutoModelForImageClassification, AutoConfig

# 6. 모델 정의 부분을 다음과 같이 수정
model_name = "microsoft/swinv2-large-patch4-window12to16-192to256-22kto1k-ft"

# 설정 로드
config = AutoConfig.from_pretrained(model_name)
config.num_labels = len(labels)
config.id2label = idx_to_label
config.label2id = label_to_idx

# 모델 로드 및 분류기 교체
model = AutoModelForImageClassification.from_pretrained(
    model_name,
    config=config,
    ignore_mismatched_sizes=True
)

# 새로운 분류기 초기화
model.classifier = torch.nn.Linear(model.classifier.in_features, len(labels))

model.to(device)
# 7. 손실 함수와 옵티마이저 정의
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

In [None]:
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    
    # tqdm을 사용하여 진행 상황 표시
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        
        # 진행 상황 업데이트
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    train_loss /= len(train_loader)
    
    # 검증
    model.eval()
    valid_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        pbar = tqdm(valid_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Valid]')
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            valid_loss += loss.item()
            
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            # 진행 상황 업데이트
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    valid_loss /= len(valid_loader)
    accuracy = correct / total
    
    print(f'Epoch {epoch+1}/{num_epochs}:')
    print(f'Train Loss: {train_loss:.4f}')
    print(f'Valid Loss: {valid_loss:.4f}')
    print(f'Valid Accuracy: {accuracy:.4f}')

In [None]:
# 9. 모델 저장
torch.save(model.state_dict(), 'bird_classifier.pth')
print("Model saved.")