In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, SubsetRandomSampler
import torchvision
from torchvision import datasets, transforms
import numpy as np
from sklearn.mixture import GaussianMixture

# 데이터 준비
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

data_size = len(train_data)
indices = np.arange(data_size)
np.random.shuffle(indices)
split = int(data_size * 0.9)

train_sampler = SubsetRandomSampler(indices[:split])
valid_sampler = SubsetRandomSampler(indices[split:])

batch_size = 128
train_loader = DataLoader(train_data, batch_size=batch_size, sampler=train_sampler)
valid_loader = DataLoader(train_data, batch_size=batch_size, sampler=valid_sampler)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 29431049.86it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [15]:
class GMMLoss(nn.Module):
    def __init__(self, num_components, num_classes):
        super(GMMLoss, self).__init__()
        self.num_components = num_components
        self.num_classes = num_classes

    def forward(self, features, labels):
        loss = 0
        for c in range(self.num_classes):
            class_features = features[labels == c]
            if len(class_features) < self.num_components:  # 이 부분 추가
                continue
            gmm = GaussianMixture(n_components=self.num_components).fit(class_features.detach().cpu().numpy())
            log_probabilities = gmm.score_samples(class_features.detach().cpu().numpy())
            proba = torch.tensor(log_probabilities, dtype=torch.float32).to(features.device)

            intra_gmm_distances = torch.mean(proba)

            all_features_log_probs = torch.tensor(gmm.score_samples(features.detach().cpu().numpy()), dtype=torch.float32).to(features.device)
            all_features_log_probs = torch.tensor(gmm.score_samples(features.detach().cpu().numpy()), dtype=torch.float32).to(features.device)
            proba_all_features = all_features_log_probs
            inter_gmm_distances = torch.mean(proba_all_features - proba.view(-1, 1))

            loss += intra_gmm_distances - inter_gmm_distances

        loss /= self.num_classes
        return loss

In [16]:
# 모델 정의
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()

        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.25)
        )
        
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.25)
        )

        self.conv_block3 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.25)
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.classifier(x)
        return x

In [19]:
# 학습 및 검증
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CNN().to(device)

# 특징 추출 함수 정의
def extract_intermediate_features(model, x):
    x = model.conv_block1(x)
    x = model.conv_block2(x)
    x_return = model.conv_block3(x)  # 여기서 특징 추출
    x = model.classifier(x_return.view(x_return.size(0), -1))
    return x, x_return

# 기본 손실 함수 및 GMM 손실 함수를 동시에 사용
criterion = nn.CrossEntropyLoss()
gmm_loss = GMMLoss(num_components=5, num_classes=10)
optimizer = optim.Adam(model.parameters(), lr=0.001)
epochs = 100

best_loss = float('inf')  # 이 부분 추가
for epoch in range(epochs):
    # 학습 및 검증 코드
    for phase in ['train', 'val']:
        if phase == 'train':
            data_loader = train_loader
            model.train()
        else:
            data_loader = valid_loader
            model.eval()

        running_loss = 0.0
        corrects = 0

        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            with torch.set_grad_enabled(phase == 'train'):
                outputs, features = extract_intermediate_features(model, inputs)  # 바뀐 부분
                loss = criterion(outputs, labels) + gmm_loss(features.mean([2, 3]), labels)  # GMM 손실 함수 추가
                _, preds = torch.max(outputs, 1)

                if phase == 'train':
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(data_loader.dataset)
        epoch_acc = corrects.double() / len(data_loader.dataset)

        if phase == 'train':
            print(f"Epoch {epoch+1}/{epochs} - Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
            if epoch > 10:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = 0.0005
            if epoch > 20:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = 0.0001
            if epoch > 30:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = 0.00005
        else:
            print(f"Epoch {epoch+1}/{epochs} - Val Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
            if epoch_loss < best_loss:
                torch.save(model.state_dict(), 'best_cnn_model.pt')
                best_loss = epoch_loss
        print()

Epoch 1/5 - Train Loss: 8991269.8149 Acc: 0.3857

Epoch 1/5 - Val Loss: 887965.1637 Acc: 0.0537

Epoch 2/5 - Train Loss: 9984119.3982 Acc: 0.5270

Epoch 2/5 - Val Loss: 868092.8924 Acc: 0.0591

Epoch 3/5 - Train Loss: 9701433.2014 Acc: 0.5871

Epoch 3/5 - Val Loss: 848104.1409 Acc: 0.0698



KeyboardInterrupt: ignored