In [1]:
import os
import json
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader
from pathlib import Path
from sklearn.model_selection import train_test_split
from helper import *

In [2]:
LOG_PATH = Path('LOG')
SHARE_PATH = Path('share')
CHECKPOINT_PATH = Path('checkpoint')

In [3]:
if not os.path.exists(LOG_PATH):
    os.mkdir(LOG_PATH)

if not os.path.exists(CHECKPOINT_PATH):
    os.mkdir(CHECKPOINT_PATH)

In [4]:
logger = get_logger(LOG_PATH / 'conv_train.log', 'conv_train')

In [5]:
random_state = 777

root_path = SHARE_PATH / '1_train+val_210220 upload'
ann_path = root_path / 'Annotation_v2_Train+Val_210208.json'
save_path = CHECKPOINT_PATH / 'resnet50.pth'

image_size = 224
batch_size = 1024
lr = 0.01
epoch = 30
device = 'cuda'
num_classes = 3
n_splits = 5

In [6]:
train_transforms = transforms.Compose([transforms.Grayscale(),
                                       transforms.Resize((image_size, image_size)),
                                       transforms.ToTensor(),
                                       transforms.Normalize(0.5, 0.5)])

test_transforms = transforms.Compose([transforms.Grayscale(),
                                      transforms.Resize((image_size, image_size)),
                                      transforms.ToTensor(),
                                      transforms.Normalize(0.5, 0.5)])

In [7]:
with open(ann_path, 'r') as f:
    json_data = json.load(f)
    
patients = json_data['Patient']

In [8]:
train_patients, valid_patients = train_test_split(patients, test_size=0.2, random_state=random_state)

In [9]:
print(f"TRAIN Patients : {len(train_patients)}")
print(f"VALID Patients : {len(valid_patients)}")

TRAIN Patients : 3944
VALID Patients : 986


In [10]:
train_dataset = SleepConvDataset(train_patients, root_path, train_transforms)

train_loader = DataLoader(train_dataset,
                           batch_size=batch_size,
                           num_workers=8,
                           pin_memory=True,
                           shuffle=True)

train_labels = np.array(train_dataset.labels)

print(f"Wake Ratio : {sum(train_labels == 0) / len(train_labels)}")
print(f"NREM Ratio : {sum(train_labels == 1) / len(train_labels)}")
print(f"REM Ratio : {sum(train_labels == 2) / len(train_labels)}")

Wake Ratio : 0.23357731121557443
NREM Ratio : 0.6199952320914652
REM Ratio : 0.1464274566929604


In [11]:
valid_dataset = SleepConvDataset(valid_patients, root_path, test_transforms)

valid_loader = DataLoader(valid_dataset,
                           batch_size=batch_size,
                           num_workers=8,
                           pin_memory=True,
                           shuffle=True)

valid_labels = np.array(valid_dataset.labels)

print(f"Wake Ratio : {sum(valid_labels == 0) / len(valid_labels)}")
print(f"NREM Ratio : {sum(valid_labels == 1) / len(valid_labels)}")
print(f"REM Ratio : {sum(valid_labels == 2) / len(valid_labels)}")

Wake Ratio : 0.23315098759330227
NREM Ratio : 0.6206709221784599
REM Ratio : 0.14617809022823788


In [12]:
early_stopping = EarlyStopping(verbose=True, path=save_path)

train_total = len(train_dataset)
valid_total = len(valid_dataset)

model = get_resnet50(num_classes, pretrained=True)
model = nn.DataParallel(model)
model = model.to(device)

criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=1e-5, momentum=0.9)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20], gamma=0.1)

for e in range(0, epoch):
    train_correct, train_loss = train(model, train_loader, optimizer, criterion, device=device)
    train_acc = train_correct / train_total
    train_loss = train_loss / train_total

    valid_correct, valid_loss = valid(model, valid_loader, criterion, device=device)
    valid_acc = valid_correct / valid_total
    valid_loss = valid_loss / valid_total

    scheduler.step()

    logger.info("===============================================================")
    logger.info("===============================================================")
    logger.info(f"||    EPOCH : {epoch} / {e}]   ||")
    logger.info(f"|| [TRAIN ACC : {train_acc}] || [TRAIN LOSS : {train_loss}] ||")
    logger.info(f"|| [VALID ACC : {valid_acc}] || [VALID LOSS : {valid_loss}] ||")
    logger.info("===============================================================")
    logger.info("===============================================================")

    early_stopping(valid_loss, model)

    if early_stopping.early_stop:
        logger.info("Early stopping")
        break

    model.load_state_dict(torch.load(save_path))

100%|██████████| 2782/2782 [3:53:13<00:00,  5.03s/it]  
100%|██████████| 698/698 [56:31<00:00,  4.86s/it]  
||    [ FOLD : 5 / <_io.TextIOWrapper name='share/1_train+val_210220 upload/Annotation_v2_Train+Val_210208.json' mode='r' encoding='UTF-8'> || EPOCH : 30 / 0]   ||
|| [TRAIN ACC : 0.911825642008715] || [TRAIN LOSS : 0.00022137594338020848] ||
|| [VALID ACC : 0.914813689511923] || [VALID LOSS : 0.00021617001478766343] ||


Validation loss decreased (inf --> 0.000216).  Saving model ...


100%|██████████| 2782/2782 [3:42:23<00:00,  4.80s/it]  
100%|██████████| 698/698 [56:00<00:00,  4.81s/it]  
||    [ FOLD : 5 / <_io.TextIOWrapper name='share/1_train+val_210220 upload/Annotation_v2_Train+Val_210208.json' mode='r' encoding='UTF-8'> || EPOCH : 30 / 1]   ||
|| [TRAIN ACC : 0.9275425363798794] || [TRAIN LOSS : 0.0001812317547181674] ||
|| [VALID ACC : 0.9074839228183268] || [VALID LOSS : 0.00023316275679613878] ||


EarlyStopping counter: 1 out of 7


100%|██████████| 2782/2782 [3:41:43<00:00,  4.78s/it]  
100%|██████████| 698/698 [56:22<00:00,  4.85s/it]  
||    [ FOLD : 5 / <_io.TextIOWrapper name='share/1_train+val_210220 upload/Annotation_v2_Train+Val_210208.json' mode='r' encoding='UTF-8'> || EPOCH : 30 / 2]   ||
|| [TRAIN ACC : 0.9275523671191265] || [TRAIN LOSS : 0.00018128208463056425] ||
|| [VALID ACC : 0.916954513442106] || [VALID LOSS : 0.00021260532517453738] ||


Validation loss decreased (0.000216 --> 0.000213).  Saving model ...


100%|██████████| 2782/2782 [3:52:18<00:00,  5.01s/it]  
100%|██████████| 698/698 [59:23<00:00,  5.11s/it]  
||    [ FOLD : 5 / <_io.TextIOWrapper name='share/1_train+val_210220 upload/Annotation_v2_Train+Val_210208.json' mode='r' encoding='UTF-8'> || EPOCH : 30 / 3]   ||
|| [TRAIN ACC : 0.9332120641427648] || [TRAIN LOSS : 0.00016604541172869248] ||
|| [VALID ACC : 0.9126532635313653] || [VALID LOSS : 0.00022453152528725795] ||


EarlyStopping counter: 1 out of 7


100%|██████████| 2782/2782 [3:48:08<00:00,  4.92s/it]  
100%|██████████| 698/698 [56:38<00:00,  4.87s/it]  
||    [ FOLD : 5 / <_io.TextIOWrapper name='share/1_train+val_210220 upload/Annotation_v2_Train+Val_210208.json' mode='r' encoding='UTF-8'> || EPOCH : 30 / 4]   ||
|| [TRAIN ACC : 0.9332211926863513] || [TRAIN LOSS : 0.00016596950932911832] ||
|| [VALID ACC : 0.9101554022550759] || [VALID LOSS : 0.00023740606143987192] ||


EarlyStopping counter: 2 out of 7


100%|██████████| 2782/2782 [3:41:57<00:00,  4.79s/it]  
100%|██████████| 698/698 [56:08<00:00,  4.83s/it]  
||    [ FOLD : 5 / <_io.TextIOWrapper name='share/1_train+val_210220 upload/Annotation_v2_Train+Val_210208.json' mode='r' encoding='UTF-8'> || EPOCH : 30 / 5]   ||
|| [TRAIN ACC : 0.9332317256212588] || [TRAIN LOSS : 0.00016598055886465167] ||
|| [VALID ACC : 0.9164700627685656] || [VALID LOSS : 0.00021863206051301793] ||


EarlyStopping counter: 3 out of 7


100%|██████████| 2782/2782 [3:43:10<00:00,  4.81s/it]  
100%|██████████| 698/698 [56:51<00:00,  4.89s/it]  
||    [ FOLD : 5 / <_io.TextIOWrapper name='share/1_train+val_210220 upload/Annotation_v2_Train+Val_210208.json' mode='r' encoding='UTF-8'> || EPOCH : 30 / 6]   ||
|| [TRAIN ACC : 0.9331593994682272] || [TRAIN LOSS : 0.00016597013046261343] ||
|| [VALID ACC : 0.904559016873165] || [VALID LOSS : 0.00024814171341343093] ||
  0%|          | 0/2782 [00:00<?, ?it/s]

EarlyStopping counter: 4 out of 7


100%|██████████| 2782/2782 [3:43:34<00:00,  4.82s/it]  
100%|██████████| 698/698 [56:08<00:00,  4.83s/it]  
||    [ FOLD : 5 / <_io.TextIOWrapper name='share/1_train+val_210220 upload/Annotation_v2_Train+Val_210208.json' mode='r' encoding='UTF-8'> || EPOCH : 30 / 7]   ||
|| [TRAIN ACC : 0.9331053304023686] || [TRAIN LOSS : 0.00016606420061472785] ||
|| [VALID ACC : 0.9145434612460464] || [VALID LOSS : 0.00022404497308462193] ||


EarlyStopping counter: 5 out of 7


 20%|█▉        | 544/2782 [43:37<2:59:30,  4.81s/it]


KeyboardInterrupt: 