In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from torch.amp import autocast, GradScaler  

from model.convnextv2 import ConvNeXtV2
from model.convnextv2_moe import ConvNeXtV2_MoE
from model.convnextv2_moe_grn import ConvNeXtV2_MoE_GRN 

# 파라미터 설정

## model
# input_dim = output_dim = 3072  # cifar10 이미지 크기
# hidden_dim = 784
num_classes = 10     # MNIST 클래스 수
# num_experts = 5
# topk = 2
# noise_std = 0.1

## train
batch_size = 256  
lambda_cov = 0.1  # 공분산 손실의 가중치
epochs = 10       

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
convnext = ConvNeXtV2(num_classes=num_classes)
convnext_moe = ConvNeXtV2_MoE(num_classes=num_classes)
convnext_moe_grn = ConvNeXtV2_MoE_GRN(num_classes=num_classes)

In [3]:
from torchinfo import summary

summary(convnext, input_size=(1, 3, 32, 32), depth=3, col_names=["input_size", "output_size", "num_params"])

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
ConvNeXtV2                               [1, 3, 32, 32]            [1, 10]                   --
├─ModuleList: 1-7                        --                        --                        (recursive)
│    └─Sequential: 2-1                   [1, 3, 32, 32]            [1, 96, 8, 8]             --
│    │    └─Conv2d: 3-1                  [1, 3, 32, 32]            [1, 96, 8, 8]             4,704
│    │    └─LayerNorm: 3-2               [1, 96, 8, 8]             [1, 96, 8, 8]             192
├─ModuleList: 1-8                        --                        --                        (recursive)
│    └─Sequential: 2-2                   [1, 96, 8, 8]             [1, 96, 8, 8]             --
│    │    └─Block: 3-3                   [1, 96, 8, 8]             [1, 96, 8, 8]             79,968
│    │    └─Block: 3-4                   [1, 96, 8, 8]             [1, 96, 8, 8]             79,968
│    

In [4]:
summary(convnext_moe, input_size=(1, 3, 32, 32), col_names=["input_size", "output_size", "num_params"])

Layer (type:depth-idx)                        Input Shape               Output Shape              Param #
ConvNeXtV2_MoE                                [1, 3, 32, 32]            [1, 10]                   --
├─ModuleList: 1-7                             --                        --                        (recursive)
│    └─Sequential: 2-1                        [1, 3, 32, 32]            [1, 96, 8, 8]             --
│    │    └─Conv2d: 3-1                       [1, 3, 32, 32]            [1, 96, 8, 8]             4,704
│    │    └─LayerNorm: 3-2                    [1, 96, 8, 8]             [1, 96, 8, 8]             192
├─ModuleList: 1-8                             --                        --                        (recursive)
│    └─ModuleList: 2-2                        --                        --                        --
│    │    └─Block: 3-3                        [1, 96, 8, 8]             [1, 96, 8, 8]             80,162
│    │    └─Block: 3-4                        [1, 96, 8, 8] 

In [5]:
summary(convnext_moe_grn, input_size=(1, 3, 32, 32), depth=3, col_names=["input_size", "output_size", "num_params"])

Layer (type:depth-idx)                        Input Shape               Output Shape              Param #
ConvNeXtV2_MoE_GRN                            [1, 3, 32, 32]            [1, 10]                   --
├─ModuleList: 1-7                             --                        --                        (recursive)
│    └─Sequential: 2-1                        [1, 3, 32, 32]            [1, 96, 8, 8]             --
│    │    └─Conv2d: 3-1                       [1, 3, 32, 32]            [1, 96, 8, 8]             4,704
│    │    └─LayerNorm: 3-2                    [1, 96, 8, 8]             [1, 96, 8, 8]             192
├─ModuleList: 1-8                             --                        --                        (recursive)
│    └─ModuleList: 2-2                        --                        --                        --
│    │    └─Block: 3-3                        [1, 96, 8, 8]             [1, 96, 8, 8]             80,166
│    │    └─Block: 3-4                        [1, 96, 8, 8] 

In [6]:
# GPU 사용 여부 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train(model, train_loader, optimizer, criterion, epochs=1):
    model.to(device)
    model.train()
    scaler = GradScaler()  # GradScaler 초기화
    for epoch in range(epochs):
        total_loss = 0
        for images, labels in tqdm(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            with autocast('cuda'):  # autocast 사용
                if isinstance(model, (ConvNeXtV2_MoE)):
                    outputs, l_aux = model(images)
                    loss = criterion(outputs, labels) + l_aux
                else:
                    outputs = model(images)
                    loss = criterion(outputs, labels)

            scaler.scale(loss).backward()  # 손실 스케일링 후 역전파
            scaler.step(optimizer)         # Optimizer 스텝
            scaler.update()                # 스케일러 업데이트

            total_loss += loss.item()
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss / len(train_loader):.4f}")

# 테스트 함수 (변경 없음)
def test(model, test_loader):
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(test_loader):
            images = images.to(device)
            labels = labels.to(device)
            if isinstance(model, (ConvNeXtV2_MoE)):
                outputs, _ = model(images)
            else:
                outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"테스트 정확도: {100 * correct / total:.2f}%")
    return 100 * correct / total


In [7]:
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 데이터셋 로드 및 전처리
transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6,1), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=1., scale=(0.02, 0.33)),
])

train_dataset = datasets.CIFAR10(root='./cifar10_data/', train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root='./cifar10_data/', train=False, transform=transform, download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
optimizer = optim.AdamW(convnext.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
train(convnext, train_loader, optimizer, criterion, epochs=epochs)
test(convnext, test_loader)

100%|██████████| 196/196 [02:16<00:00,  1.44it/s]


Epoch [1/10], Loss: 2.3876


100%|██████████| 196/196 [02:15<00:00,  1.44it/s]


Epoch [2/10], Loss: 1.7853


100%|██████████| 196/196 [02:16<00:00,  1.44it/s]


Epoch [3/10], Loss: 1.5552


100%|██████████| 196/196 [02:15<00:00,  1.45it/s]


Epoch [4/10], Loss: 1.3572


100%|██████████| 196/196 [02:15<00:00,  1.44it/s]


Epoch [5/10], Loss: 1.1718


100%|██████████| 196/196 [02:15<00:00,  1.45it/s]


Epoch [6/10], Loss: 1.0568


100%|██████████| 196/196 [02:15<00:00,  1.44it/s]


Epoch [7/10], Loss: 0.9388


  0%|          | 0/196 [00:00<?, ?it/s]

In [14]:
optimizer = optim.AdamW(convnext_moe.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
train(convnext_moe, train_loader, optimizer, criterion, epochs=epochs)
test(convnext_moe, test_loader)

100%|██████████| 196/196 [02:33<00:00,  1.28it/s]


Epoch [1/30], Loss: 2.7283


  1%|          | 1/196 [00:01<04:18,  1.32s/it]

In [None]:
optimizer = optim.AdamW(convnext_moe_grn.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
train(convnext_moe_grn, train_loader, optimizer, criterion, epochs=epochs)
test(convnext_moe_grn, test_loader)