In [None]:
# Mount Google Drive at /content/drive (relies on the Colab-only google.colab package).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install split-folders

Collecting split-folders
  Downloading split_folders-0.5.1-py3-none-any.whl.metadata (6.2 kB)
Downloading split_folders-0.5.1-py3-none-any.whl (8.4 kB)
Installing collected packages: split-folders
Successfully installed split-folders-0.5.1


In [None]:
import torch
import torch.nn as nn
from torchvision.models import DenseNet121_Weights, densenet121
from transformers import ViTModel, ViTConfig
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import time
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

import kagglehub
import shutil
import os
import timm

In [None]:
# Download the MINC-2500 dataset through kagglehub; the returned value is used
# below as the local root directory of the downloaded copy.
liewyousheng_minc2500_path = kagglehub.dataset_download('liewyousheng/minc2500')

Downloading from https://www.kaggle.com/api/v1/datasets/download/liewyousheng/minc2500?dataset_version_number=1...


100%|██████████| 2.10G/2.10G [00:12<00:00, 184MB/s]


Extracting files...


In [None]:
# Root directory of the downloaded dataset; the next cell builds its source path from it.
base_path = liewyousheng_minc2500_path

In [None]:
# Directory that holds one sub-folder of images per material class.
source_dir = os.path.join(base_path, "minc-2500", "images")

# The 23 material classes used for training.
selected_classes = ['brick', 'carpet', 'ceramic', 'fabric', 'foliage', 'food', 'glass', 'hair', 'leather',
                    'metal', 'mirror', 'other', 'painted', 'paper', 'plastic', 'polishedstone', 'skin',
                    'sky', 'stone', 'tile', 'wallpaper', 'water', 'wood']

new_dir = "/kaggle/working/selected_images"

os.makedirs(new_dir, exist_ok=True)

missing_classes = []
for class_name in selected_classes:
    class_dir = os.path.join(source_dir, class_name)
    if os.path.isdir(class_dir):
        # dirs_exist_ok=True keeps this cell re-runnable: without it copytree
        # raises FileExistsError as soon as the destination folder already exists.
        shutil.copytree(class_dir, os.path.join(new_dir, class_name), dirs_exist_ok=True)
    else:
        missing_classes.append(class_name)

# A silently skipped class would only surface later as a class-count mismatch.
if missing_classes:
    print(f"WARNING: class folders not found under {source_dir}: {missing_classes}")

In [None]:
import splitfolders
# Split the per-class folders in new_dir into subsets under /kaggle/working/Splitted
# using an 85/5/10 ratio and a fixed seed (1337).
# NOTE(review): the ratio order is presumably (train, val, test), matching the
# train/val/test paths used later in this notebook - confirm against split-folders docs.
splitfolders.ratio(new_dir, output="/kaggle/working/Splitted", seed=1337, ratio=(0.85, 0.05, 0.10))

print("Selected classes have been split into train, validation, and test sets.")

Copying files: 57500 files [00:07, 7471.20 files/s]

Selected classes have been split into train, validation, and test sets.





In [None]:
# No need to change these paths: they point at the output of the split cell.
train_path, val_path, test_path = (
    '/kaggle/working/Splitted/train',
    '/kaggle/working/Splitted/val',
    '/kaggle/working/Splitted/test',
)

In [None]:
# Side length (in pixels) every image is resized to.
IMAGE_SIZE = 224
# The train, validation and test loaders all use the same batch size.
TRAIN_BATCH_SIZE = VAL_BATCH_SIZE = TEST_BATCH_SIZE = 512

In [None]:
# Per-channel mean / std passed to transforms.Normalize in the transform cells.
IMAGENET_MEAN, IMAGENET_STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

In [None]:
# Training pipeline: resize, random geometric augmentation, then tensor
# conversion and normalisation with the IMAGENET_* statistics.
train_transforms = transforms.Compose(
    [
        transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
        transforms.RandomRotation(degrees=30),
        # degrees=0: this step only applies the random 0.8x-1.2x rescale.
        transforms.RandomAffine(degrees=0, scale=(0.8, 1.2)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

In [None]:
# Evaluation pipeline: no augmentation, only resize + tensor + normalise.
val_test_transforms = transforms.Compose(
    [
        transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

# One ImageFolder per split; only the training split gets the augmenting transforms.
train_dataset = datasets.ImageFolder(root=train_path, transform=train_transforms)
val_dataset = datasets.ImageFolder(root=val_path, transform=val_test_transforms)
test_dataset = datasets.ImageFolder(root=test_path, transform=val_test_transforms)

In [None]:
# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE, shuffle=False, pin_memory=True)

훈련 데이터 로더 준비 완료. 배치 크기: 512
검증 데이터 로더 준비 완료. 배치 크기: 512
테스트 데이터 로더 준비 완료. 배치 크기: 512


In [None]:
class FrequencyFeatureExtractor(nn.Module):
    """Turn an image batch into a feature vector derived from its 2-D FFT.

    The input is averaged over dim 1 into a single channel, transformed with a
    centred 2-D FFT, and the log-magnitude spectrum is passed through a small
    conv stack, globally pooled and linearly projected to ``output_dim``.

    Args:
        output_dim: size of the returned feature vector per sample.
        image_size: stored on the instance; not used by ``forward``.
    """

    def __init__(self, output_dim, image_size=224):
        super().__init__()
        self.image_size = image_size

        # Two conv/ReLU/max-pool stages followed by global average pooling (64 channels out).
        conv_stack = [
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.spectrum_processor = nn.Sequential(*conv_stack)
        self.projection = nn.Linear(64, output_dim)

    def forward(self, x):
        """Return a ``(batch, output_dim)`` tensor of frequency features for ``x``."""
        # Collapse dim 1 to a single channel so the spectrum is computed once per image.
        grayscale = x.mean(dim=1, keepdim=True)

        # fftshift moves the zero-frequency component to the centre of the spectrum.
        spectrum = torch.fft.fftshift(torch.fft.fft2(grayscale, dim=(-2, -1)), dim=(-2, -1))
        # log(1 + |F|) compresses the dynamic range of the magnitudes.
        log_magnitude = torch.log(1 + torch.abs(spectrum))

        pooled = self.spectrum_processor(log_magnitude).flatten(1)
        return self.projection(pooled)

In [None]:
import torch
import torch.nn as nn
import timm
from torchvision.models import DenseNet121_Weights, densenet121
import os

class FineTunedFusionClassifier(nn.Module):
    """Classifier that fuses ViT, DenseNet-121 and FFT-spectrum features.

    Three feature vectors of size ``vit_embed_dim`` are computed from the same
    input batch (ViT CLS token, projected DenseNet features, FFT features),
    concatenated, and fed to an MLP head that outputs ``num_classes`` logits.
    Most backbone weights are frozen; see ``_set_trainable_layers``.

    Args:
        num_classes: number of output logits.
        vit_model_name: timm model name used to build the ViT backbone.
        vit_weight_path: optional path to a state-dict checkpoint for the ViT.
            When it is missing or does not exist, the ViT keeps its random
            initialisation (it is created with ``pretrained=False``) and a
            warning is printed.
    """

    def __init__(self, num_classes=23, vit_model_name='vit_base_patch16_224', vit_weight_path=None):
        super(FineTunedFusionClassifier, self).__init__()

        # ViT backbone, built without pretrained weights and with num_classes=0.
        self.vit_model = timm.create_model(
            vit_model_name,
            pretrained=False,
            num_classes=0,
        )
        self.vit_embed_dim = self.vit_model.num_features  # 768 for the default model

        # Load custom ViT weights.
        if vit_weight_path and os.path.exists(vit_weight_path):
            print(f"INFO: Custom ViT weights loading from {vit_weight_path}...")
            # map_location="cpu" lets a checkpoint saved on a GPU load on any host;
            # the caller moves the whole model to its device afterwards.
            # NOTE(review): torch.load unpickles the file - only load checkpoints
            # from a trusted source.
            custom_weights = torch.load(vit_weight_path, map_location="cpu")
            load_result = self.vit_model.load_state_dict(custom_weights, strict=False)
            # strict=False hides mismatches, so report them explicitly.
            if load_result.missing_keys or load_result.unexpected_keys:
                print(
                    f"WARNING: ViT checkpoint mismatch - "
                    f"{len(load_result.missing_keys)} missing keys, "
                    f"{len(load_result.unexpected_keys)} unexpected keys."
                )
        else:
            # Without a checkpoint the (mostly frozen) ViT stays randomly initialised.
            print(
                f"WARNING: no ViT checkpoint found at {vit_weight_path!r}; "
                "the ViT backbone keeps its random initialisation (pretrained=False)."
            )

        # DenseNet-121 with ImageNet weights; only its feature extractor is used.
        self.cnn_model = densenet121(weights=DenseNet121_Weights.IMAGENET1K_V1)
        self.cnn_features = self.cnn_model.features
        self.cnn_feature_dim = self.cnn_model.classifier.in_features  # 1024

        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Project the pooled CNN features to the ViT embedding size.
        self.cnn_projection = nn.Sequential(
            nn.Linear(self.cnn_feature_dim, self.vit_embed_dim),
            nn.BatchNorm1d(self.vit_embed_dim),
            nn.GELU()
        )

        # FFT feature extractor
        self.fft_extractor = FrequencyFeatureExtractor(output_dim=self.vit_embed_dim)

        # Three feature vectors of size vit_embed_dim are concatenated.
        total_fused_dim = 3 * self.vit_embed_dim

        self.classifier = nn.Sequential(
            nn.Linear(total_fused_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Linear(512, num_classes)
        )

        self._set_trainable_layers()

    def _set_trainable_layers(self):
        """Freeze the backbones, then re-enable gradients for the selected layers."""
        # Start from fully frozen ViT, DenseNet features and FFT extractor.
        for param in self.vit_model.parameters():
            param.requires_grad = False
        for param in self.cnn_features.parameters():
            param.requires_grad = False
        # NOTE(review): this also freezes the FFT extractor's conv stack, which has
        # no pretrained weights - confirm that fixed random filters are intended.
        for param in self.fft_extractor.parameters():
            param.requires_grad = False

        # Train only the last ViT block.
        layers_to_train = 1
        for block in self.vit_model.blocks[-layers_to_train:]:
            for param in block.parameters():
                param.requires_grad = True

        # Train the ViT's final layer norm.
        for param in self.vit_model.norm.parameters():
            param.requires_grad = True

        # Train only the last DenseBlock of DenseNet and the final norm.
        for param in self.cnn_features.denseblock4.parameters():
            param.requires_grad = True
        for param in self.cnn_features.norm5.parameters():
            param.requires_grad = True

        # Train the CNN projection.
        for param in self.cnn_projection.parameters():
            param.requires_grad = True

        # Train the projection layer inside the FFT extractor.
        if hasattr(self.fft_extractor, 'projection'):
            for param in self.fft_extractor.projection.parameters():
                param.requires_grad = True

        # Train the classification head.
        for param in self.classifier.parameters():
            param.requires_grad = True

    def forward(self, x):
        """Return ``(batch, num_classes)`` logits for the image batch ``x``."""
        # 1. ViT: take the token at index 0 (CLS token).
        vit_features = self.vit_model.forward_features(x)
        cls_token = vit_features[:, 0, :]  # (B, vit_embed_dim)

        # 2. DenseNet: global-average-pool the feature map, then project.
        cnn_features_map = self.cnn_features(x)
        cnn_features_vec = torch.flatten(self.global_avg_pool(cnn_features_map), 1)
        cnn_frequency_features = self.cnn_projection(cnn_features_vec)  # (B, vit_embed_dim)

        # 3. FFT
        fft_frequency_features = self.fft_extractor(x)  # (B, vit_embed_dim)

        # 4. Fusion: concatenate along the feature dimension.
        fused_features = torch.cat((cls_token, cnn_frequency_features, fft_frequency_features), dim=1)

        # 5. Classifier
        logits = self.classifier(fused_features)

        return logits

In [None]:
# Number of material classes (the selected_classes list has 23 entries).
NUM_CLASSES = 23
# Epoch budget and initial learning rate for the optimizer.
NUM_EPOCHS, LEARNING_RATE = 60, 1e-4

In [None]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"device: {device}")

사용 장치: cuda


In [None]:
# Replace with the path to the .pt checkpoint of the ViT pretrained on the MINC dataset.

VIT_WEIGHT_PATH = "YOUR_PATH"

In [None]:
# Build the fusion classifier, then move all of its parameters to the target device.
model = FineTunedFusionClassifier(num_classes=NUM_CLASSES, vit_weight_path=VIT_WEIGHT_PATH)
model = model.to(device)

In [None]:
# Materialise the trainable parameters as a list. A filter() object is a one-shot
# iterator: once the optimizer has consumed it, any later cell iterating over it
# (e.g. the parameter count) would see nothing.
learnable_params = [p for p in model.parameters() if p.requires_grad]

In [None]:
# Cross-entropy loss on the logits; Adam only receives the trainable parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=learnable_params, lr=LEARNING_RATE)

In [None]:
# Multiply the learning rate by 0.1 once the monitored value (passed to
# scheduler.step) has not decreased for more than `patience` epochs.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

In [None]:
# learnable_params = []
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         learnable_params.append(param)
#         print(f"✅ 학습 가능: {name} ({param.numel()}개)")

In [None]:
# Count directly from the model instead of re-iterating `learnable_params`:
# that iterable has already been handed to the optimizer and may be exhausted,
# which would report 0 trainable parameters.
trainable_params_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params_count = sum(p.numel() for p in model.parameters())
print(f"총 파라미터 수: {total_params_count / 1e6:.2f}M")
print(f"최종 학습 가능한 파라미터 수: {trainable_params_count / 1e6:.2f}M")
print(f"최종 학습 가능한 파라미터 비율: {trainable_params_count / total_params_count * 100:.2f}%")

총 파라미터 수: 97.53M
최종 학습 가능한 파라미터 수: 14.01M
최종 학습 가능한 파라미터 비율: 14.37%


In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and report average metrics.

    Args:
        model: module called on each input batch; put into train mode here.
        loader: iterable yielding ``(inputs, labels)`` batches.
        criterion: loss callable taking ``(outputs, labels)``.
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)`` as Python floats: the sample-weighted mean
        loss and the fraction of correct argmax predictions.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    for inputs, labels in tqdm(loader, desc="Training"):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        # Weight the batch loss by its size so a smaller last batch is averaged correctly.
        total_loss += loss.item() * batch_size
        preds = outputs.argmax(dim=1)
        correct_predictions += (preds == labels).sum().item()
        total_samples += batch_size

    if total_samples == 0:
        # Fail with a clear message instead of an opaque ZeroDivisionError.
        raise ValueError("train_one_epoch received an empty loader; no samples to train on.")

    epoch_loss = total_loss / total_samples
    epoch_acc = correct_predictions / total_samples
    return epoch_loss, epoch_acc

In [None]:
def validate_model(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        model: module called on each input batch; put into eval mode here.
        loader: iterable yielding ``(inputs, labels)`` batches.
        criterion: loss callable taking ``(outputs, labels)``.
        device: device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)`` as Python floats: the sample-weighted mean
        loss and the fraction of correct argmax predictions.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc="Validation"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_size = inputs.size(0)
            # Weight the batch loss by its size so a smaller last batch is averaged correctly.
            total_loss += loss.item() * batch_size
            preds = outputs.argmax(dim=1)
            correct_predictions += (preds == labels).sum().item()
            total_samples += batch_size

    if total_samples == 0:
        # Fail with a clear message instead of an opaque ZeroDivisionError.
        raise ValueError("validate_model received an empty loader; no samples to evaluate.")

    epoch_loss = total_loss / total_samples
    epoch_acc = correct_predictions / total_samples
    return epoch_loss, epoch_acc

In [None]:
# Train for NUM_EPOCHS epochs, validating after each one and saving a checkpoint
# whenever the validation accuracy improves on the best value seen so far.
best_val_acc = 0.0
print("\n--- 학습 시작 ---")

# range(NUM_EPOCHS), not NUM_EPOCHS + 1: the extra iteration trained one epoch
# too many and printed "Epoch 61/60".
for epoch in range(NUM_EPOCHS):
    start_time = time.time()

    # train
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)

    # validation
    val_loss, val_acc = validate_model(model, val_loader, criterion, device)

    # The plateau scheduler monitors the validation loss (mode='min').
    scheduler.step(val_loss)

    epoch_duration = time.time() - start_time

    # Read after scheduler.step, so this is the rate the next epoch will use.
    current_lr = optimizer.param_groups[0]['lr']
    print(f"\n[Epoch {epoch+1}/{NUM_EPOCHS}] Time: {epoch_duration:.2f}s LR: {current_lr:.1e}")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # One file per improving epoch; earlier checkpoints are kept.
        model_save_path = f'best_material_classifier_epoch_{epoch+1}.pth'
        torch.save(model.state_dict(), model_save_path)
        print(f"  >>> 최적 모델 저장: {model_save_path} (Acc: {best_val_acc:.4f})")

print("\n--- 학습 완료 ---")
print(f"최고 검증 정확도: {best_val_acc:.4f}")