In [None]:
from lib.skin_disease.data_module import DataModule
from lib.skin_disease.model_factory import get_model_and_preprocess
from lib.skin_disease.trainer import Trainer


def train(
    model_name,
    epochs=10,
    data_dir='../../dataset/skin_diseases/train',
    lr=1e-4,
):
    """Build a model and its data module, check the split for leakage, and train.

    Args:
        model_name: Name passed to ``get_model_and_preprocess`` to select the
            backbone and its preprocessing config; also forwarded to ``Trainer``.
        epochs: Number of epochs forwarded to ``Trainer.run``.
        data_dir: Training image directory handed to ``DataModule``.
        lr: Learning rate forwarded to ``Trainer``.

    Returns:
        A ``(model, dm, trainer)`` tuple so the caller can evaluate the trained
        model with the same data module and trainer device.
    """
    model, preprocess = get_model_and_preprocess(model_name)

    # Image size and normalization stats come from the model's own
    # preprocessing config, so the data pipeline matches the backbone.
    dm = DataModule(
        data_dir=data_dir,
        img_size=preprocess.resize_size[0],
        mean=preprocess.mean,
        std=preprocess.std,
    )
    dm.setup()

    # Leakage check: map each split's subset indices back to file paths.
    # NOTE(review): assumes train_ds/val_ds expose `.subset` wrapping a dataset
    # with `.samples` of (path, label) pairs — confirm against DataModule.
    train_files = {
        dm.train_ds.subset.dataset.samples[i][0] for i in dm.train_ds.subset.indices
    }
    val_files = {
        dm.val_ds.subset.dataset.samples[i][0] for i in dm.val_ds.subset.indices
    }

    # Report the overlap once (the original computed and printed it twice).
    overlap = train_files & val_files
    print(f'중복된 파일 개수: {len(overlap)}')  # Should be 0 for a clean split.

    train_loader, val_loader = dm.get_loaders()

    trainer = Trainer(model, train_loader, val_loader, model_name, lr=lr)
    trainer.run(epochs=epochs)

    return model, dm, trainer

In [None]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

from lib.skin_disease.test_model import test_model


def report(
    model,
    dm,
    trainer,
    test_dir='../../dataset/skin_diseases/test',
    batch_size=32,
    num_workers=2,
):
    """Evaluate a trained model on the held-out test image folder.

    The test transform (resize, to-tensor, normalize) is built from the data
    module's ``img_size``, ``mean`` and ``std`` and applies no augmentation.

    Args:
        model: Trained model to evaluate.
        dm: Data module returned by ``train``; supplies the transform settings.
        trainer: Trainer returned by ``train``; only its ``device`` is used.
        test_dir: Root of the test set in ``ImageFolder`` layout
            (one subdirectory per class).
        batch_size: Batch size of the test loader.
        num_workers: Number of worker processes for the test loader.
    """
    test_transform = transforms.Compose(
        [
            transforms.Resize((dm.img_size, dm.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(dm.mean, dm.std),
        ]
    )

    # NOTE(review): ImageFolder derives class indices from the sorted
    # subdirectory names of test_dir; this assumes they match the training
    # classes exactly — confirm the folders are identical.
    test_dataset = datasets.ImageFolder(root=test_dir, transform=test_transform)
    # shuffle=False keeps evaluation order deterministic.
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    test_model(model, test_loader, test_dataset.classes, trainer.device)

In [None]:
# ResNet-34 run with the default number of training epochs, then test-set report.
model, dm, trainer = train(model_name='resnet34')
report(model=model, dm=dm, trainer=trainer)

In [None]:
# ConvNeXt-Tiny run, then test-set report.
# NOTE(review): trains for a single epoch, unlike the other runs — confirm
# whether this is intentional (e.g. a quick smoke run) before comparing results.
model, dm, trainer = train(model_name='convnext_tiny', epochs=1)
report(model=model, dm=dm, trainer=trainer)

In [None]:
# EfficientNetV2-S run with the default number of training epochs, then test-set report.
model, dm, trainer = train(model_name='efficientnet_v2_s')
report(model=model, dm=dm, trainer=trainer)

In [None]:
# import numpy as np
# from matplotlib import pyplot as plt


# def visualize_results(model, test_loader, classes, device, n=5):
#     model.eval()
#     plt.figure(figsize=(15, 5))

#     # Fetch one batch from the loader
#     imgs, labels = next(iter(test_loader))
#     imgs, labels = imgs.to(device), labels.to(device)
#     outputs = model(imgs)
#     preds = outputs.argmax(dim=1)

#     for i in range(n):
#         img = imgs[i].cpu().numpy().transpose((1, 2, 0))
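#         # Undo normalization (for visualization only).
#         # NOTE: these are hard-coded ImageNet stats; use dm.mean / dm.std so the
#         # inverse matches the normalization actually applied in report().
#         mean = np.array([0.485, 0.456, 0.406])
#         std = np.array([0.229, 0.224, 0.225])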
#         img = std * img + mean
#         img = np.clip(img, 0, 1)

#         plt.subplot(1, n, i + 1)
#         plt.imshow(img)
#         plt.title(f'True: {classes[labels[i]]}\nPred: {classes[preds[i]]}')
#         plt.axis('off')
#     plt.show()


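# # TODO: build a real test loader (e.g. the one constructed in report());
# # with an empty list, next(iter(test_loader)) raises StopIteration.
# test_loader = []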

# visualize_results(model, test_loader, dm.classes, trainer.device)