In [2]:
import torch
import numpy as np
from mlp_mixerECG import MLPMixerForECG
import os
import pickle
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

In [None]:
# Smoke test: push one all-ones 3x224x224 image through a freshly built model
# and report how many trainable weights the model has.
img = torch.ones([1, 3, 224, 224])

model = MLPMixerForECG(
    in_channels=3,
    image_size=224,
    patch_size=16,
    num_classes=1000,
    dim=512,
    depth=8,
    token_dim=256,
    channel_dim=2048,
)

# Only parameters with requires_grad=True are counted; the total is expressed in millions.
trainable = [p for p in model.parameters() if p.requires_grad]
parameters = sum(np.prod(p.size()) for p in trainable) / 1_000_000
print('Trainable Parameters: %.3fM' % parameters)

out_img = model(img)

# NOTE(review): the expected shape depends on MLPMixerForECG's output head, which is
# not visible in this notebook; with num_classes=1000 it is presumably [B, 1000] -- confirm.
print("Shape of out :", out_img.shape)

In [None]:
class CIFAR10Dataset(Dataset):
    """CIFAR-10 images and labels read from the pickled batch files in ``data_dir``.

    Each batch file is a pickled dict whose ``b'data'`` entry holds flattened
    images (reshaped here to 3x32x32 and stored channel-last) and whose
    ``b'labels'`` entry holds the matching class indices.

    Args:
        data_dir: Directory containing ``data_batch_<i>`` and ``test_batch``.
        transform: Optional callable applied to each PIL image in ``__getitem__``.
        train: If True read the training batches, otherwise read ``test_batch``.
        num_train_batches: How many training files (``data_batch_1`` ..
            ``data_batch_<n>``) to load when ``train`` is True. The default of 2
            keeps this notebook's original behaviour; the standard CIFAR-10
            archive ships five training batches, so pass 5 for the full set.

    Raises:
        ValueError: If ``train`` is True and ``num_train_batches`` is below 1.
    """

    def __init__(self, data_dir, transform=None, train=True, num_train_batches=2):
        if train and num_train_batches < 1:
            raise ValueError(
                f'num_train_batches must be >= 1, got {num_train_batches}'
            )

        self.transform = transform
        self.labels = []

        # Choose the training batch files or the single test batch file.
        if train:
            file_names = [f'data_batch_{i}' for i in range(1, num_train_batches + 1)]
        else:
            file_names = ['test_batch']

        chunks = []
        for file_name in file_names:
            with open(os.path.join(data_dir, file_name), 'rb') as f:
                # SECURITY: pickle.load can execute arbitrary code; only point
                # data_dir at batch files obtained from a trusted source.
                batch = pickle.load(f, encoding='bytes')
            chunks.append(batch[b'data'])
            self.labels += batch[b'labels']

        # Rows are flat 3*32*32 vectors: restore (N, C, H, W), then move the
        # channel axis last (N, H, W, C) because Image.fromarray expects HWC.
        self.data = np.vstack(chunks).reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is converted to a PIL image and, when a transform was
        supplied, passed through it before being returned.
        """
        img, label = self.data[idx], self.labels[idx]
        img = Image.fromarray(img)
        if self.transform:
            img = self.transform(img)
        return img, label

# Preprocessing and dataset definitions.
# Images are upscaled to 224x224, converted to tensors, and normalised per
# channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# The default below is a machine-specific absolute path; set the CIFAR10_DIR
# environment variable to point at the extracted batch files on other machines.
data_dir = os.environ.get(
    'CIFAR10_DIR',
    'C:/Users/hongi/Desktop/MLP-Mixer-pytorch-master/cifar-10-batches-py',
)
train_dataset = CIFAR10Dataset(data_dir=data_dir, transform=transform, train=True)
test_dataset = CIFAR10Dataset(data_dir=data_dir, transform=transform, train=False)

# Split the training data 80/20 into training and validation subsets.
# A fixed seed makes the partition identical across kernel restarts, so the
# validation accuracy and the saved "best" checkpoint are comparable between runs.
SPLIT_SEED = 42
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(
    train_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [None]:
# Print the sample counts of the training, validation and test sets, in that order.
for split in (train_dataset, val_dataset, test_dataset):
    print(len(split))

# Pull the first training sample and show the shape of its transformed image.
first_data_point, first_label = train_dataset[0]
print(first_data_point.shape)

In [None]:
batch_size = 64

# Only the training loader shuffles; validation and test keep a fixed order.
eval_loader_kwargs = dict(batch_size=batch_size, shuffle=False)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, **eval_loader_kwargs)
test_loader = DataLoader(test_dataset, **eval_loader_kwargs)

In [None]:
import torch.nn as nn

# Build the classifier that is actually trained: 10 outputs, one per CIFAR-10 class.
# The notebook only imports MLPMixerForECG; the previous reference to an
# undefined `MLPMixer` raised NameError on a fresh kernel.
model = MLPMixerForECG(in_channels=3, image_size=224, patch_size=16, num_classes=10,
                       dim=512, depth=8, token_dim=256, channel_dim=2048)

# Cross-entropy loss on the model outputs, optimised with Adam at lr=0.001.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
num_epochs = 1000
best_val_accuracy = 0  # highest validation accuracy seen so far

for epoch in range(num_epochs):
    # ---- Training pass: one optimiser step per mini-batch ----
    model.train()
    batch_losses = []
    train_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Training]')

    for inputs, labels in train_bar:
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

    # Mean of the per-batch losses (each batch counts equally).
    epoch_loss = sum(batch_losses) / len(train_loader)
    print(f"Epoch {epoch+1}, Training Loss: {epoch_loss:.10f}")

    # ---- Validation pass: no gradients, model in eval mode ----
    model.eval()
    val_batch_losses = []
    val_correct = 0
    val_total = 0
    val_bar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Validation]')
    with torch.no_grad():
        for images, labels in val_bar:
            outputs = model(images)
            val_batch_losses.append(criterion(outputs, labels).item())

            # Predicted class = index of the largest output along dim 1.
            predicted = torch.max(outputs, 1).indices
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    # Average validation loss per batch and accuracy as a percentage.
    val_epoch_loss = sum(val_batch_losses) / len(val_loader)
    val_accuracy = 100 * val_correct / val_total
    print(f'Epoch {epoch+1}, Validation Loss: {val_epoch_loss:.10f}, Validation Accuracy: {val_accuracy:.4f}%')

    # Checkpoint the weights whenever validation accuracy strictly improves.
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best_model_weights.pth')
        print(f'New best model saved with accuracy: {best_val_accuracy:.4f}%')

print('Finished Training')
