In [2]:
import torch
import torchvision
from torchvision import transforms
from pascal_voc_multi_label_dataset import PascalVOCMultiLabelDataset



In [3]:
# 데이터 전처리 정의
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [4]:
# PASCAL VOC 데이터셋 다운로드 및 로드
voc_train = torchvision.datasets.VOCDetection(root='./data', year='2012', image_set='train', download=True, transform=data_transform)
voc_val = torchvision.datasets.VOCDetection(root='./data', year='2012', image_set='val', download=True, transform=data_transform)

Downloading http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar to ./data/VOCtrainval_11-May-2012.tar


100.0%


Extracting ./data/VOCtrainval_11-May-2012.tar to ./data
Using downloaded and verified file: ./data/VOCtrainval_11-May-2012.tar
Extracting ./data/VOCtrainval_11-May-2012.tar to ./data


In [6]:
train_dataset = PascalVOCMultiLabelDataset(voc_train)
val_dataset = PascalVOCMultiLabelDataset(voc_val)

In [7]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)


In [9]:
# !pip install timm
import timm
import torch.nn as nn
import torchvision.models as models

Collecting timm
  Downloading timm-0.6.12-py3-none-any.whl (549 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m549.1/549.1 kB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hCollecting huggingface-hub
  Downloading huggingface_hub-0.13.3-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.8/199.8 kB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
Collecting tqdm>=4.42.1
  Downloading tqdm-4.65.0-py3-none-any.whl (77 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m77.1/77.1 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tqdm, huggingface-hub, timm
Successfully installed huggingface-hub-0.13.3 timm-0.6.12 tqdm-4.65.0


  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# 사전 훈련된 ResNet-50 모델 불러오기
model = timm.create_model('resnet50', pretrained=True)

# 출력 계층 수정
num_classes = 20  # PASCAL VOC의 클래스 개수
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, num_classes),
    nn.Sigmoid()  # 멀티 레이블 분류를 위한 Sigmoid 활성화 함수 사용
)

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth" to /home/jovyan/.cache/torch/hub/checkpoints/resnet50_a1_0-14fe96d1.pth


In [11]:
import torch.optim as optim
# 손실 함수 및 옵티마이저 정의
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [12]:
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act2): ReLU(inplace=True)
      (aa): Identity()
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     

In [13]:
# 훈련 및 검증 루프
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
best_val_loss = float('inf')  # 최소 검증 손실 추적을 위한 변수

for epoch in range(num_epochs):
    # 훈련
    model.train()
    train_loss = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    # 검증
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item()
    val_loss /= len(val_loader)

    # 모델 저장
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')

    print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")


Epoch 1/10, Train Loss: 0.2870, Validation Loss: 0.1754
Epoch 2/10, Train Loss: 0.1420, Validation Loss: 0.1003
Epoch 3/10, Train Loss: 0.1010, Validation Loss: 0.0811
Epoch 4/10, Train Loss: 0.0822, Validation Loss: 0.0722
Epoch 5/10, Train Loss: 0.0688, Validation Loss: 0.0698
Epoch 6/10, Train Loss: 0.0580, Validation Loss: 0.0682
Epoch 7/10, Train Loss: 0.0506, Validation Loss: 0.0676
Epoch 8/10, Train Loss: 0.0433, Validation Loss: 0.0688
Epoch 9/10, Train Loss: 0.0369, Validation Loss: 0.0689
Epoch 10/10, Train Loss: 0.0302, Validation Loss: 0.0713


In [4]:
import timm
timm.list_models('*convnext*')

['convnext_atto',
 'convnext_atto_ols',
 'convnext_base',
 'convnext_base_384_in22ft1k',
 'convnext_base_in22ft1k',
 'convnext_base_in22k',
 'convnext_femto',
 'convnext_femto_ols',
 'convnext_large',
 'convnext_large_384_in22ft1k',
 'convnext_large_in22ft1k',
 'convnext_large_in22k',
 'convnext_nano',
 'convnext_nano_ols',
 'convnext_pico',
 'convnext_pico_ols',
 'convnext_small',
 'convnext_small_384_in22ft1k',
 'convnext_small_in22ft1k',
 'convnext_small_in22k',
 'convnext_tiny',
 'convnext_tiny_384_in22ft1k',
 'convnext_tiny_hnf',
 'convnext_tiny_in22ft1k',
 'convnext_tiny_in22k',
 'convnext_xlarge_384_in22ft1k',
 'convnext_xlarge_in22ft1k',
 'convnext_xlarge_in22k']