In [None]:
# Pick the physical GPU this notebook may use. These variables are set here,
# ahead of the `import torch` further down, so that CUDA only ever sees the
# selected device (which then appears to the process as device 0).
import os

GPU_NUM = str(1)  # kept as a string because environment values must be str
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",   # enumerate GPUs in PCI bus order
    "CUDA_VISIBLE_DEVICES": GPU_NUM,     # expose only the chosen GPU
})

import shutil
import warnings
import contextlib
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from copy import deepcopy
from IPython.display import clear_output

# Notebook-wide display settings.
# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')
# NOTE(review): the style is chosen by position in matplotlib's list of
# available styles, so the resulting look presumably depends on the installed
# matplotlib version -- consider naming the style explicitly.
plt.style.use(plt.style.available[-3])
# Show images in grayscale unless a colormap is passed explicitly.
plt.rc('image', cmap='gray')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset

# Global torch configuration and device handles used by the rest of the notebook.
torch.set_default_dtype(torch.float32)

GPU = torch.device('cuda')
CPU = torch.device('cpu')

# Report CUDA status. The device name is only queried when CUDA is usable:
# torch.cuda.get_device_name raises without a CUDA device, which previously
# turned this diagnostic line into a crash on CPU-only machines.
cuda_available = torch.cuda.is_available()
cuda_device_name = torch.cuda.get_device_name(0) if cuda_available else 'no CUDA device'
print(cuda_available, ': ', cuda_device_name)
torch.cuda.empty_cache()

In [None]:
from utils.utils import *
from networks.networks import *
from options.hyper_parameters import HP

In [None]:
# Instantiate the project's hyper-parameter container and display its
# attributes as the cell output.
hp = HP()
vars(hp)

In [None]:
# Smoke test: push one all-zero 1x1x28x28 input through the network without
# tracking gradients.
# NOTE(review): `shape=True` presumably makes ResNet report intermediate tensor
# shapes -- confirm in networks.networks.
model = ResNet(shape=True)
x = torch.zeros(1, 1, 28, 28)
with torch.no_grad():
    res = model(x)

In [None]:
# Build the model, datasets, data loaders, optimizer and LR scheduler that the
# training cell consumes.
model = ResNet()
train_set = Dataset_Temp(hp, phase='train')
valid_set = Dataset_Temp(hp, phase='valid')

# Only the training data is shuffled. Validation is iterated in a fixed order
# so that per-epoch validation metrics (a mean over per-batch means) do not
# fluctuate merely because the batch composition changed.
train_loader = DataLoader(dataset=train_set, shuffle=True, batch_size=hp.batch_size,
                          num_workers=2, pin_memory=True)
valid_loader = DataLoader(dataset=valid_set, shuffle=False, batch_size=hp.batch_size,
                          num_workers=2, pin_memory=True)

# Index 0 = train, index 1 = valid; the training loop indexes this list by phase.
data_loader = [train_loader, valid_loader]

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=hp.scheduler_step, gamma=hp.scheduler_gamma)

In [None]:
# Training / validation loop with mixed precision, console monitoring and
# periodic checkpointing.
# Uses notebook globals defined in earlier cells (hp, model, data_loader,
# optimizer, scheduler) and the project helpers BCE, MAE, ACC and es.
if hp.multi_gpu:
    model = nn.DataParallel(model)
model = model.to(hp.device)

torch.cuda.empty_cache()
print(f'{"Model":<20}: {hp.name}')

BCE_func = BCE()
MAE_func = MAE()
ACC_func = ACC()

loss_keys = ['BCE', 'MAE']
rate_keys = ['ACC']

epoch_s, epoch_e = 1, hp.epochs+1

# Per-epoch metric tables, indexed [epoch, phase, metric]:
#   epoch  - 1-based, so row 0 is never written by the loop
#   phase  - 0 = Train, 1 = Valid
#   metric - order of loss_keys / rate_keys
loss_epoch = torch.zeros([epoch_e, 2, len(loss_keys)])
rate_epoch = torch.zeros([epoch_e, 2, len(rate_keys)])

if hp.epoch_load is not None:
    # Resume from a checkpoint that stores whole model / optimizer objects.
    # NOTE: torch.load unpickles arbitrary objects -- only load trusted files.
    # NOTE(review): this path is hard-coded while saving below uses
    # hp.path_model; confirm both point at the same directory.
    last_point = torch.load(f'../res/{hp.name}/model/last_point.pt')
    model.load_state_dict(last_point['model'].state_dict())
    optimizer.load_state_dict(last_point['optimizer'].state_dict())
    epoch_s = hp.epoch_load + 1
    epoch_save = last_point['epoch']
    # Restore the metric history of the already-finished epochs right away, so
    # that es() and every later save see it regardless of how the first resumed
    # epoch lines up with hp.monitoring_cycle.
    loss_epoch[:hp.epoch_load+1] = last_point['history']['loss'][:hp.epoch_load+1]
    rate_epoch[:hp.epoch_load+1] = last_point['history']['rate'][:hp.epoch_load+1]
    print(f'※ Continual Learning: {epoch_save} Epoch ※')

scaler = torch.cuda.amp.GradScaler()
for epoch in range(epoch_s, epoch_e):

    for i, phase in enumerate(['Train', 'Valid']):
        if phase=='Train':
            model.train()
            context_manager = contextlib.nullcontext()
        else:
            # Validation: eval mode and no autograd bookkeeping.
            model.eval()
            context_manager = torch.no_grad()

        # One row per batch; averaged into the epoch tables after the phase.
        loss_batch = torch.zeros([len(data_loader[i]), len(loss_keys)])
        rate_batch = torch.zeros([len(data_loader[i]), len(rate_keys)])
        with context_manager:
            for k, data in enumerate(data_loader[i]):
                X, y = data[0].to(hp.device), data[1].to(hp.device)

                with torch.cuda.amp.autocast(enabled=True):
                    # Forward
                    p = model(X)

                    # Loss
                    BCE_res = BCE_func(p, y)
                    MAE_res = MAE_func(p, y)

                    # Final Loss
                    loss_final = BCE_res+MAE_res

                    loss_batch[k, 0] = BCE_res.detach()
                    loss_batch[k, 1] = MAE_res.detach()

                if phase=='Train':
                    # Scaled backward pass / optimizer step for mixed precision.
                    optimizer.zero_grad()
                    scaler.scale(loss_final).backward()
                    scaler.step(optimizer)
                    scaler.update()

                # Rate
                with torch.no_grad():
                    ACC_res = ACC_func(p, y)
                    rate_batch[k, 0] = ACC_res

        # Epoch value = unweighted mean of the per-batch values.
        loss_epoch[epoch, i] = torch.mean(loss_batch, dim=0)
        rate_epoch[epoch, i] = torch.mean(rate_batch, dim=0)

    # Scheduler (stepped once per epoch)
    if scheduler is not None:
        scheduler.step()

    # Monitoring
    # es is a project helper; its result is indexed per metric and printed next
    # to the corresponding (train, valid) pair below.
    es_loss = es(loss_epoch, epoch, inverse=False)
    es_rate = es(rate_epoch, epoch, inverse=True)

    # Print the column header on the first executed epoch, which is not epoch 1
    # when training was resumed from a checkpoint.
    if epoch==epoch_s:
        print('===== Loss Monitoring =====')
        print(f'Loss: {loss_keys}', end=' ')
        print(f'rate: {rate_keys}', end=' ')
        print('(Train, Valid)')
    if epoch%hp.monitoring_cycle==0:
        print(f'{epoch:5.0f}/{hp.epochs:5.0f}', end=' ')
        for l in range(len(loss_keys)):
            loss_train = loss_epoch[epoch, 0, l]
            loss_valid = loss_epoch[epoch, 1, l]
            print(f'({loss_train:6.4f}, {loss_valid:6.4f})', end=f' {es_loss[l]} ')
        print('*', end=' ')
        for l in range(len(rate_keys)):
            rate_train = rate_epoch[epoch, 0, l]
            rate_valid = rate_epoch[epoch, 1, l]
            print(f'({rate_train:6.4f}, {rate_valid:6.4f})', end=f' {es_rate[l]} ')
        print()

        # Save
        # `history` references the live metric tables (no copy is made).
        history = {'loss':loss_epoch,
                   'rate':rate_epoch,
                   'loss_keys':loss_keys,
                   'rate_keys':rate_keys}

        # NOTE(review): this check is nested in the monitoring branch, so a
        # checkpoint is only written on epochs that are multiples of both
        # hp.monitoring_cycle and hp.save_cycle -- confirm that is intended.
        if epoch%hp.save_cycle==0:
            torch.save(history, f'{hp.path_model}/history.pt')
            torch.save(model, f'{hp.path_model}/model_{epoch}.pt')
            last_point = {'epoch':epoch,
                          'history':history,
                          'model':model,
                          'optimizer':optimizer}
            # NOTE(review): persisting last_point is disabled, so the resume
            # branch above can only load a last_point.pt written by other means.
            # torch.save(last_point, f'{hp.path_model}/last_point.pt')

torch.cuda.empty_cache()