In [1]:
import torch
import torchvision
from torchvision import transforms
from pascal_voc_multi_label_dataset import PascalVOCMultiLabelDataset



In [2]:
# 데이터 전처리 정의
data_transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [3]:
# PASCAL VOC 데이터셋 다운로드 및 로드
voc_train = torchvision.datasets.VOCDetection(root='./data', year='2012', image_set='train', download=True, transform=data_transform)
voc_val = torchvision.datasets.VOCDetection(root='./data', year='2012', image_set='val', download=True, transform=data_transform)

Using downloaded and verified file: ./data/VOCtrainval_11-May-2012.tar
Extracting ./data/VOCtrainval_11-May-2012.tar to ./data
Using downloaded and verified file: ./data/VOCtrainval_11-May-2012.tar
Extracting ./data/VOCtrainval_11-May-2012.tar to ./data


In [4]:
train_dataset = PascalVOCMultiLabelDataset(voc_train)
val_dataset = PascalVOCMultiLabelDataset(voc_val)

In [5]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)


In [6]:
# !pip install timm
import timm
import torch.nn as nn
import torchvision.models as models

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# 사전 훈련된 ResNet-50 모델 불러오기
# Swin Transformer 모델 불러오기
model = timm.create_model('swin_base_patch4_window12_384', pretrained=True)
num_classes = 20
# 출력 레이어 수정하기
in_features = model.head.in_features # 기존 출력 레이어의 input feature 개수
out_features = num_classes # 새로운 출력 레이어의 output feature 개수
model.head = torch.nn.Linear(in_features, out_features)

# model = timm.create_model('vit_tiny_r_s16_p8_224', pretrained=True)

# # 출력 계층 수정
# num_classes = 20  # PASCAL VOC의 클래스 개수
# model.fc = nn.Sequential(
#     nn.Linear(model.fc.in_features, num_classes),
#     nn.Sigmoid()  # 멀티 레이블 분류를 위한 Sigmoid 활성화 함수 사용
# )

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [8]:
import torch.optim as optim
# 손실 함수 및 옵티마이저 정의
#criterion = nn.BCELoss()
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [9]:
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

SwinTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): Sequential(
    (0): BasicLayer(
      (blocks): Sequential(
        (0): SwinTransformerBlock(
          (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (attn): WindowAttention(
            (qkv): Linear(in_features=128, out_features=384, bias=True)
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=128, out_features=128, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
            (softmax): Softmax(dim=-1)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=128, out_features=512, bias=True)
            (act): GELU(approximate='none')
           

In [10]:
# 훈련 및 검증 루프
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
best_val_loss = float('inf')  # 최소 검증 손실 추적을 위한 변수

for epoch in range(num_epochs):
    # 훈련
    model.train()
    train_loss = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    # 검증
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item()
    val_loss /= len(val_loader)

    # 모델 저장
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model_swin.pth')

    print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")


Epoch 1/10, Train Loss: 0.0872, Validation Loss: 0.0537
Epoch 2/10, Train Loss: 0.0399, Validation Loss: 0.0527
Epoch 3/10, Train Loss: 0.0274, Validation Loss: 0.0546
Epoch 4/10, Train Loss: 0.0161, Validation Loss: 0.0593
Epoch 5/10, Train Loss: 0.0134, Validation Loss: 0.0773
Epoch 6/10, Train Loss: 0.0140, Validation Loss: 0.0628
Epoch 7/10, Train Loss: 0.0119, Validation Loss: 0.0752
Epoch 8/10, Train Loss: 0.0108, Validation Loss: 0.0747
Epoch 9/10, Train Loss: 0.0077, Validation Loss: 0.0781
Epoch 10/10, Train Loss: 0.0082, Validation Loss: 0.0794


In [11]:
import timm
timm.list_models('*vit*')

['convit_base',
 'convit_small',
 'convit_tiny',
 'crossvit_9_240',
 'crossvit_9_dagger_240',
 'crossvit_15_240',
 'crossvit_15_dagger_240',
 'crossvit_15_dagger_408',
 'crossvit_18_240',
 'crossvit_18_dagger_240',
 'crossvit_18_dagger_408',
 'crossvit_base_240',
 'crossvit_small_240',
 'crossvit_tiny_240',
 'gcvit_base',
 'gcvit_small',
 'gcvit_tiny',
 'gcvit_xtiny',
 'gcvit_xxtiny',
 'levit_128',
 'levit_128s',
 'levit_192',
 'levit_256',
 'levit_256d',
 'levit_384',
 'maxvit_base_224',
 'maxvit_large_224',
 'maxvit_nano_rw_256',
 'maxvit_pico_rw_256',
 'maxvit_rmlp_nano_rw_256',
 'maxvit_rmlp_pico_rw_256',
 'maxvit_rmlp_small_rw_224',
 'maxvit_rmlp_small_rw_256',
 'maxvit_rmlp_tiny_rw_256',
 'maxvit_small_224',
 'maxvit_tiny_224',
 'maxvit_tiny_pm_256',
 'maxvit_tiny_rw_224',
 'maxvit_tiny_rw_256',
 'maxvit_xlarge_224',
 'maxxvit_rmlp_nano_rw_256',
 'maxxvit_rmlp_small_rw_256',
 'maxxvit_rmlp_tiny_rw_256',
 'mobilevit_s',
 'mobilevit_xs',
 'mobilevit_xxs',
 'mobilevitv2_050',
 'mobi