<a href="https://colab.research.google.com/github/cjfghk5697/Classificiation_BMD/blob/main/bmd_sample_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip3 install timm torchmetrics



In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import torch
import torch.optim as optim
import torch.nn as nn
import timm
import datetime
import csv

from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchmetrics import AUROC, ROC
from pathlib import Path

In [None]:
from google.colab import drive
drive.mount('/gdrive')

Drive already mounted at /gdrive; to attempt to forcibly remount, call drive.mount("/gdrive", force_remount=True).


In [55]:
train_dir = '/content/drive/MyDrive/bmd_dataset/train'
model_root = './model/'

# Train

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Model list

In [None]:
# all_densenet_models = timm.list_models('**')
# all_densenet_models

Model

In [None]:
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = timm.create_model('resnext101_32x8d', pretrained=True)
        self.backbone.classifier = nn.Identity()
        self.fc = nn.Linear(1000, 3)

    def forward(self, x):
        return self.fc(self.backbone(x))

Making a dataset

In [56]:
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = ImageFolder(train_dir, transform=train_transform)

# normal: 0, osteopenia:1, osteoporosis:2

In [53]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# start_idx = 0
# plt.figure(figsize=(16,10))

# for i in range(start_idx, start_idx+16):
#     img, label = dataset[i]
#     print(label)
#     plt.subplot(4, 4, i-start_idx+1)
#     plt.imshow(np.array(img).transpose((1,2,0)))
#     plt.axis('off')
    
# plt.show()

Train & Validation

In [57]:
net = Model()
net.train()
net = net.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=[
    {'params': net.backbone.parameters(), 'lr':0.001},
    {'params': net.fc.parameters(), 'lr':0.01},
], lr=0.001, momentum=0.9, weight_decay=5e-4, nesterov=True)

Downloading: "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth" to /root/.cache/torch/hub/checkpoints/resnext101_32x8d-8ba56ff5.pth


In [58]:
def save_model(model, acc, date_time, name):
    model_name = name + '.pth'
    model_path = Path(model_root + date_time)
    model_path.mkdir(parents=True, exist_ok=True)
    print('Saving model (Accuracy {:.2f}%) to {}'.format(acc*100, str(model_path / model_name)) )
    torch.save({'model_state_dict':model.state_dict(), 'acc':acc}, str(model_path / model_name) )

In [59]:
def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs, auc=False, device=device):
    tz = datetime.timezone(datetime.timedelta(hours=9)) # Timezone infomation
    date_time = datetime.datetime.now(tz).strftime('%Y-%m-%d-%H-%M-%S')
    
    best_val_acc = 0
    best_train_acc = 0
    epoch_train_acc = 0
    
    for epoch in range(num_epochs + 1):
        print('\n------------------------')
        print('EPOCH {}/{}'.format(epoch, num_epochs))
        print('------------------------')
        
        if auc == True:
            auc_roc_metric = AUROC(num_classes=3, average=None)
            roc_metric = ROC(num_classes=3)

        for phase in ['train', 'val']:
            if phase == 'train':
                net.train()
            else:
                net.eval()
                
            epoch_loss = 0.0
            epoch_corrects = 0
            
            total_loss = 0
            total_size = 0
            total_corrects = 0
            
            # 학습 전 성능 확인
            if (epoch == 0) and (phase == 'train'):
                continue
            
            num_iteration = len(dataloaders_dict[phase])
            
            for idx, (inputs, labels) in enumerate(tqdm(dataloaders_dict[phase])):
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase == 'train'):
                    inputs, labels = inputs.to(device), labels.to(device)
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                        
                    epoch_loss = loss.item()
                    total_loss += epoch_loss

                    epoch_corrects = torch.sum(preds == labels.data)
                    epoch_acc = epoch_corrects.double() / inputs.size(0)
                    total_corrects += epoch_corrects
                    total_size += inputs.size(0)
                    
                    if phase == 'train':
                        print('{} [{}/{}] LOSS: {:.4f} ACC: {:.4f}'.format(phase, idx+1, num_iteration, epoch_loss, epoch_acc))

                    if (phase == 'val') & (auc == True):
                        auc_roc_metric(outputs, labels)
                        roc_metric(outputs, labels)
            
            epoch_loss_avg = total_loss / num_iteration
            epoch_acc_avg = total_corrects / total_size
            
            if phase == 'train':
                if best_train_acc < epoch_acc_avg:
                    best_train_acc = epoch_acc_avg
                print('{} LOSS: {:.4f} ACC: {:.4f} BEST ACC: {:.4f}'.format(phase, epoch_loss_avg, epoch_acc_avg, best_train_acc))

            if phase == 'val':
                if (epoch_acc_avg == best_val_acc) & (epoch_acc_avg == best_train_acc):
                    save_model(net, epoch_acc_avg, date_time, 'best')
                elif best_val_acc < epoch_acc_avg:
                    best_val_acc = epoch_acc_avg
                    print('Best Validation Accuracy: {:.4f}'.format(epoch_acc_avg))
                    save_model(net, epoch_acc_avg, date_time, 'best')

                print('{} LOSS: {:.4f} ACC: {:.4f} BEST ACC: {:.4f}\n'.format(phase, epoch_loss_avg, epoch_acc_avg, best_val_acc))

                if auc == True:
                    for auroc, fpt, tpr, thresholds in zip(auc_roc_metric.compute(), *roc_metric.compute()):
                        print('auc: {:5.2f}'.format(auroc* 100))
                        size = min(len(fpt), len(tpr))
                        plt.plot(fpt[:size].cpu(), tpr[:size].cpu())
                        plt.show()
                        plt.cla()

In [None]:
num_epochs=50

batch_size = 32
num_workers = 4
view_auc = False

train_size = int(0.7*len(dataset))
val_size = len(dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=num_workers)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=num_workers)

dataloaders_dict = {"train": train_loader, "val": val_loader}

train_model(net, dataloaders_dict, criterion, optimizer, num_epochs, view_auc)


------------------------
EPOCH 0/50
------------------------


  cpuset_checked))


  0%|          | 0/29 [00:00<?, ?it/s]

Best Validation Accuracy: 0.3733
Saving model (Accuracy 37.33%) to model/2021-12-12-17-34-52/best.pth
val LOSS: 1.2119 ACC: 0.3733 BEST ACC: 0.3733


------------------------
EPOCH 1/50
------------------------


  0%|          | 0/66 [00:00<?, ?it/s]

train [1/66] LOSS: 1.2403 ACC: 0.3438
train [2/66] LOSS: 1.3478 ACC: 0.3438
train [3/66] LOSS: 1.5676 ACC: 0.2812
train [4/66] LOSS: 1.2412 ACC: 0.4062
train [5/66] LOSS: 1.5943 ACC: 0.4375
train [6/66] LOSS: 1.4112 ACC: 0.4062
train [7/66] LOSS: 1.3466 ACC: 0.4062
train [8/66] LOSS: 1.4830 ACC: 0.4688
train [9/66] LOSS: 1.1851 ACC: 0.5312
train [10/66] LOSS: 1.7769 ACC: 0.3125
train [11/66] LOSS: 1.2254 ACC: 0.5625
train [12/66] LOSS: 1.6359 ACC: 0.4688
train [13/66] LOSS: 1.5920 ACC: 0.4375
train [14/66] LOSS: 1.4272 ACC: 0.4375
train [15/66] LOSS: 1.1850 ACC: 0.4375
train [16/66] LOSS: 1.3453 ACC: 0.3438
train [17/66] LOSS: 1.3199 ACC: 0.4375
train [18/66] LOSS: 1.0740 ACC: 0.5625
train [19/66] LOSS: 1.1444 ACC: 0.4375
train [20/66] LOSS: 0.8726 ACC: 0.6562
train [21/66] LOSS: 1.1415 ACC: 0.5000
train [22/66] LOSS: 1.1865 ACC: 0.4688
train [23/66] LOSS: 0.9519 ACC: 0.6250
train [24/66] LOSS: 1.1245 ACC: 0.4375
train [25/66] LOSS: 1.1828 ACC: 0.3750
train [26/66] LOSS: 1.0359 ACC: 0.

  0%|          | 0/29 [00:00<?, ?it/s]

Best Validation Accuracy: 0.4711
Saving model (Accuracy 47.11%) to model/2021-12-12-17-34-52/best.pth
val LOSS: 1.0173 ACC: 0.4711 BEST ACC: 0.4711


------------------------
EPOCH 2/50
------------------------


  0%|          | 0/66 [00:00<?, ?it/s]

train [1/66] LOSS: 0.8046 ACC: 0.7188
train [2/66] LOSS: 0.8410 ACC: 0.5625
train [3/66] LOSS: 0.8283 ACC: 0.5938
train [4/66] LOSS: 0.9311 ACC: 0.4688
train [5/66] LOSS: 0.8815 ACC: 0.5938
train [6/66] LOSS: 0.8712 ACC: 0.5312
train [7/66] LOSS: 1.0334 ACC: 0.4688
train [8/66] LOSS: 0.8747 ACC: 0.5938
train [9/66] LOSS: 0.7753 ACC: 0.6562
train [10/66] LOSS: 0.8916 ACC: 0.5000
train [11/66] LOSS: 0.6914 ACC: 0.6562
train [12/66] LOSS: 0.8081 ACC: 0.6250
train [13/66] LOSS: 0.6962 ACC: 0.7500
train [14/66] LOSS: 0.7578 ACC: 0.5000
train [15/66] LOSS: 0.9756 ACC: 0.5000
train [16/66] LOSS: 1.0443 ACC: 0.5938
train [17/66] LOSS: 0.5415 ACC: 0.7500
train [18/66] LOSS: 0.9717 ACC: 0.5312
train [19/66] LOSS: 0.8623 ACC: 0.5938
train [20/66] LOSS: 0.5813 ACC: 0.7188
train [21/66] LOSS: 0.9647 ACC: 0.5312
train [22/66] LOSS: 0.7903 ACC: 0.7500
train [23/66] LOSS: 0.8703 ACC: 0.6875
train [24/66] LOSS: 0.9446 ACC: 0.5625
train [25/66] LOSS: 0.9437 ACC: 0.6875
train [26/66] LOSS: 0.8834 ACC: 0.