# Profile

In [2]:
import os
import multiprocessing
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

datadir = '/root/input/rsna-2022-cervical-spine-fracture-detection'
libdir = '/root/workspace/RSNA2022RAWE'
outputdir = '/root/workspace/RSNA2022RAWE'
otherdir = '/root/workspace/RSNA2022RAWE'
train_bs_ = 4
valid_bs_ = 8
num_workers_ = multiprocessing.cpu_count()

# CFG

In [3]:
class CFG:
    seed=42
    device='GPU'
    nprocs=1 # [1, 8]
    num_workers=num_workers_
    train_bs=train_bs_
    valid_bs=valid_bs_
    fold_num=5

    target_cols=["C1", "C2", "C3", "C4", "C5", "C6", "C7"]
    num_classes=7

    accum_iter=1
    max_grad_norm=1000
    print_freq=100
    normalize_mean=[0.4824, 0.4824, 0.4824] # [0.485, 0.456, 0.406] [0.4824, 0.4824, 0.4824]
    normalize_std=[0.22, 0.22, 0.22] # [0.229, 0.224, 0.225] [0.22, 0.22, 0.22]
    
    suffix="401" 
    fold_list=[0] 
    epochs=25
    model_arch="resnet50d" # tf_efficientnetv2_s, resnest50d
    img_size=320
    optimizer="AdamW"
    scheduler="CosineAnnealingLR"
    loss_fn="BCEWithLogitsLoss"
    scheduler_warmup="GradualWarmupSchedulerV3" 

    warmup_epo=1
    warmup_factor = 10
    T_max= epochs-warmup_epo-2 if scheduler_warmup=="GradualWarmupSchedulerV2" else \
           epochs-warmup_epo-1 if scheduler_warmup=="GradualWarmupSchedulerV3" else epochs-1 # CosineAnnealingLR
    
    seq_len = 24
    # lr=5e-4
    lr = 0.001
    min_lr=1e-6 
    weight_decay=0
    dropout=0.1

    gpu_parallel=False
    n_early_stopping=4
    debug=False
    multihead=False

# Import

In [4]:
!pip install -U scikit-image
!pip install timm
!pip install nibabel
! pip install python-gdcm
! pip install pylibjpeg pylibjpeg-libjpeg pydicom
!pip install -U albumentations
!pip install segmentation_models_pytorch
!pip install -q git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git

Collecting scikit-image
  Using cached scikit_image-0.19.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (13.5 MB)
Collecting tifffile>=2019.7.26
  Using cached tifffile-2021.11.2-py3-none-any.whl (178 kB)
Installing collected packages: tifffile, scikit-image
  Attempting uninstall: scikit-image
    Found existing installation: scikit-image 0.16.2
    Uninstalling scikit-image-0.16.2:
      Successfully uninstalled scikit-image-0.16.2
Successfully installed scikit-image-0.19.3 tifffile-2021.11.2
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.2.2[0m[39;49m -> [0m[32;49m22.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Collecting timm
  Using cached timm-0.6.11-py3-none-any.whl (548 kB)
Collecting torch>=1.7
  Using cached torch-1.13.0-cp37-cp37m-manylinux1_x86_64.whl (890.2 MB)
Collecting torchvision
  Using cached torchvision-0.14.0-cp37-cp37m-ma

In [5]:
import sys; 

package_paths = [f'{libdir}pytorch-image-models-master']
for pth in package_paths:
    sys.path.append(pth)
    
import ast
from glob import glob
import cv2
from skimage import io
import os
from datetime import datetime
import time
import random
from tqdm import tqdm
from contextlib import contextmanager
import math

import numpy as np
import pandas as pd
import sklearn
from sklearn.metrics import roc_auc_score, log_loss
from sklearn import metrics
from sklearn.model_selection import GroupKFold, StratifiedKFold
import torch
import torchvision
from torchvision import transforms
from torch import nn
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.sampler import SequentialSampler, RandomSampler
from torch.nn.modules.loss import _WeightedLoss
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
from warmup_scheduler import GradualWarmupScheduler
import timm
import warnings
import joblib
from scipy.ndimage.interpolation import zoom
import nibabel as nib
import pydicom as dicom
import gc 
from torch.nn import DataParallel



if CFG.device == 'TPU':
    !pip install -q pytorch-ignite
    import ignite.distributed as idist
elif CFG.device == 'GPU':
    from torch.cuda.amp import autocast, GradScaler

# helper

In [6]:
train_df = pd.read_pickle(f'{datadir}/vertebrae_df.pkl')
submission_df = pd.read_csv(f'{datadir}/sample_submission.csv')

train_df = train_df[train_df["StudyInstanceUID"] != "1.2.826.0.1.3680043.20574"].reset_index(drop=True)
train_df = train_df[train_df["StudyInstanceUID"] != "1.2.826.0.1.3680043.29952"].reset_index(drop=True)
train_df

gkf = GroupKFold(n_splits=CFG.fold_num)
folds = gkf.split(X=train_df, y=None, groups=train_df['StudyInstanceUID'])

In [7]:
if CFG.device == 'TPU':
    import os
    VERSION = "1.7"
    CP_V = "36" if ENV == "colab" else "37"
    wheel = f"torch_xla-{VERSION}-cp{CP_V}-cp{CP_V}m-linux_x86_64.whl"
    url = f"https://storage.googleapis.com/tpu-pytorch/wheels/{wheel}"
    !pip3 -q install cloud-tpu-client==0.10 $url
    os.system('export XLA_USE_BF16=1')
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp
    CFG.lr = CFG.lr * CFG.nprocs
    CFG.train_bs = CFG.train_bs // CFG.nprocs
    device = xm.xla_device()
    
elif CFG.device == "GPU":
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [8]:
print(device)

cuda


In [9]:
print(torch.version.cuda)

11.7


In [10]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True 

seed_everything(CFG.seed)


def get_score(y_true, y_pred):
    scores = []
    for i in range(y_true.shape[1]):
        score = roc_auc_score(y_true[:,i], y_pred[:,i])
        scores.append(score)
    avg_score = np.mean(scores)
    return avg_score, scores


@contextmanager
def timer(name):
    t0 = time.time()
    LOGGER.info(f'[{name}] start')
    yield
    LOGGER.info(f'[{name}] done in {time.time() - t0:.0f} s.')


def init_logger(log_file=outputdir+'stage2_train.log'):
    from logging import getLogger, INFO, FileHandler,  Formatter,  StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = init_logger(outputdir+f'/stage2_train{CFG.suffix}.log')

if CFG.device=='TPU' and CFG.nprocs==8:
    loginfo = xm.master_print
    cusprint = xm.master_print
else:
    loginfo = LOGGER.info
    cusprint = print



def get_timediff(time1,time2):
    minute_,second_ = divmod(time2-time1,60)
    return f"{int(minute_):02d}:{int(second_):02d}"  


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))


def get_img(path):
    im_bgr = cv2.imread(path)
    im_rgb = im_bgr[:, :, ::-1]
    return im_rgb

def load_dicom(path):
    """
    This supports loading both regular and compressed JPEG images. 
    See the first sell with `pip install` commands for the necessary dependencies
    """
    img = dicom.dcmread(path)
    img.PhotometricInterpretation = 'YBR_FULL'
    data = img.pixel_array
    data = data - np.min(data)
    if np.max(data) != 0:
        data = data / np.max(data)
    # data = (data * 255).astype(np.uint8)
    return data

# DataSet

In [11]:
class TrainDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        study_id = row["StudyInstanceUID"]
        slice_num_list = row['slice_num_list']
        slice_list = []
        for s_num in slice_num_list:
            path = f"{datadir}/train_images/{study_id}/{s_num}.dcm"
            img = load_dicom(path)
            if len(slice_list) == 0:
                imgh = img.shape[0]
                imgw = img.shape[1]
            elif img.shape != (imgh, imgw):
                img = cv2.resize(img,(imgh,imgw))

            slice_list.append(img)
        for _ in range(CFG.seq_len - len(slice_list)):
            slice_list.append(np.zeros((imgh,imgw)))

        image = np.stack(slice_list, axis=2) # 512*512*seq_len; 0-1

        assert image.shape == (imgh, imgw, CFG.seq_len)

        # transform
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        image = np.transpose(image, (2, 0, 1)) # seq_len*img_size*img_size; 0-1
        return torch.from_numpy(image), torch.tensor(row['label']).float()

In [12]:
from albumentations import (
    HorizontalFlip, VerticalFlip, IAAPerspective, ShiftScaleRotate, CLAHE, RandomRotate90,
    Transpose, ShiftScaleRotate, Blur, OpticalDistortion, GridDistortion, HueSaturationValue,
    IAAAdditiveGaussianNoise, GaussNoise, MotionBlur, MedianBlur, IAAPiecewiseAffine, RandomResizedCrop,
    IAASharpen, IAAEmboss, RandomBrightnessContrast, Flip, OneOf, Compose, Normalize, Cutout, CoarseDropout, ShiftScaleRotate, 
    CenterCrop, Resize, RandomCrop, GaussianBlur, JpegCompression, Downscale, ElasticTransform
)
import albumentations
from albumentations.pytorch import ToTensorV2

def get_transforms(*, data):
    if data == 'train':
        return Compose([
            RandomResizedCrop(CFG.img_size, CFG.img_size, scale=(0.9, 1), p=1), 
            HorizontalFlip(p=0.5),
            ShiftScaleRotate(p=0.5),
            HueSaturationValue(hue_shift_limit=10, sat_shift_limit=10, val_shift_limit=10, p=0.7),
            RandomBrightnessContrast(brightness_limit=(-0.2,0.2), contrast_limit=(-0.2, 0.2), p=0.7),
            CLAHE(clip_limit=(1,4), p=0.5),
            OneOf([
                OpticalDistortion(distort_limit=1.0),
                GridDistortion(num_steps=5, distort_limit=1.),
                ElasticTransform(alpha=3),
            ], p=0.2),
            OneOf([
                GaussNoise(var_limit=[10, 50]),
                GaussianBlur(),
                MotionBlur(),
                MedianBlur(),
            ], p=0.2),
            Resize(CFG.img_size, CFG.img_size),
            OneOf([
                JpegCompression(),
                Downscale(scale_min=0.1, scale_max=0.15),
            ], p=0.2),
            IAAPiecewiseAffine(p=0.2),
            IAASharpen(p=0.2),
            Cutout(max_h_size=int(CFG.img_size * 0.1), max_w_size=int(CFG.img_size * 0.1), num_holes=5, p=0.5),
            ])
    elif data == 'light_train':
        return Compose([
            Resize(CFG.img_size, CFG.img_size, interpolation=cv2.INTER_NEAREST),
            HorizontalFlip(p=0.5),
            VerticalFlip(p=0.5),
            ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=20, p=0.5),
            OneOf([
                # GaussNoise(),
                GaussianBlur(),
                MotionBlur(),
                # MedianBlur(),
            ], p=0.3),
            OneOf([
                GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
                OpticalDistortion(distort_limit=0.05, shift_limit=0.05, p=1.0),
                ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0)
            ], p=0.3),
            # CoarseDropout(max_holes=8, max_height=CFG.img_size[0]//20, max_width=CFG.img_size[1]//20,
            #              min_holes=5, fill_value=0, mask_fill_value=0, p=0.5),
            ], p=1.0)    
    elif data == 'valid':
        return Compose([
            Resize(CFG.img_size, CFG.img_size),
        ])

In [13]:
from pylab import rcParams
dataset_show = TrainDataset(
    train_df,
    transform=get_transforms(data='light_train') # None, get_transforms(data='check')
    )
rcParams['figure.figsize'] = 30,20
for i in range(2):
    f, axarr = plt.subplots(1,5)
    idx = np.random.randint(0, len(dataset_show))
    img, label= dataset_show[idx]
    # axarr[p].imshow(img) # transform=None
    axarr[0].imshow(img[0]); plt.axis('OFF');
    axarr[1].imshow(img[1]); plt.axis('OFF');
    axarr[2].imshow(img[2]); plt.axis('OFF');
    axarr[3].imshow(img[3]); plt.axis('OFF');
    axarr[4].imshow(img[4]); plt.axis('OFF');

# Model

In [14]:
import torch.nn as nn
from itertools import repeat

class SpatialDropout(nn.Module):
    def __init__(self, drop=0.5):
        super(SpatialDropout, self).__init__()
        self.drop = drop
        
    def forward(self, inputs, noise_shape=None):
        """
        @param: inputs, tensor
        @param: noise_shape, tuple
        """
        outputs = inputs.clone()
        if noise_shape is None:
            noise_shape = (inputs.shape[0], *repeat(1, inputs.dim()-2), inputs.shape[-1]) 
        
        self.noise_shape = noise_shape
        if not self.training or self.drop == 0:
            return inputs
        else:
            noises = self._make_noises(inputs)
            if self.drop == 1:
                noises.fill_(0.0)
            else:
                noises.bernoulli_(1 - self.drop).div_(1 - self.drop)
            noises = noises.expand_as(inputs)    
            outputs.mul_(noises)
            return outputs
            
    def _make_noises(self, inputs):
        return inputs.new().resize_(self.noise_shape)


import torch
from torch import nn
import torch.nn.functional as F

from typing import Dict, Optional
 
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor


    
class MLPAttentionNetwork(nn.Module):
 
    def __init__(self, hidden_dim, attention_dim=None):
        super(MLPAttentionNetwork, self).__init__()
 
        self.hidden_dim = hidden_dim
        self.attention_dim = attention_dim
        if self.attention_dim is None:
            self.attention_dim = self.hidden_dim
        # W * x + b
        self.proj_w = nn.Linear(self.hidden_dim, self.attention_dim, bias=True)
        # v.T
        self.proj_v = nn.Linear(self.attention_dim, 1, bias=False)
 
    def forward(self, x):
        """
        :param x: seq_len, batch_size, hidden_dim
        :return: batch_size * seq_len, batch_size * hidden_dim
        """
        batch_size, seq_len, _ = x.size()

        H = torch.tanh(self.proj_w(x)) # (batch_size, seq_len, hidden_dim)
        
        att_scores = torch.softmax(self.proj_v(H),axis=1) # (batch_size, seq_len)
        
        attn_x = (x * att_scores).sum(1) # (batch_size, hidden_dim)
        return attn_x

In [15]:
class RSNAClassifier(nn.Module):
    def __init__(self, model_arch, hidden_dim=256, seq_len=24, pretrained=False):
        super().__init__()
        self.seq_len = seq_len
        self.model = timm.create_model(model_arch, in_chans=1, pretrained=pretrained)

        if 'efficientnet' in CFG.model_arch:
            cnn_feature = self.model.classifier.in_features
            self.model.classifier = nn.Identity()
        elif "res" in CFG.model_arch:
            cnn_feature = self.model.fc.in_features
            self.model.global_pool = nn.Identity()
            self.model.fc = nn.Identity()
            self.pooling = nn.AdaptiveAvgPool2d(1)
        
        self.spatialdropout = SpatialDropout(CFG.dropout)
        self.gru = nn.GRU(cnn_feature, hidden_dim, 2, batch_first=True, bidirectional=True)
        self.mlp_attention_layer = MLPAttentionNetwork(2 * hidden_dim)
        self.logits = nn.Sequential(
            nn.Linear(hidden_dim*2, 128),
            nn.ReLU(),
            nn.Dropout(CFG.dropout),
            nn.Linear(128, 1)
        )

        for n, m in self.named_modules():
            if isinstance(m, nn.GRU):
                print(f"init {m}")
                for param in m.parameters():
                    if len(param.shape) >= 2:
                        nn.init.orthogonal_(param.data)
                    else:
                        nn.init.normal_(param.data)

    def forward(self, x): # (B, seq_len, H, W)
        bs = x.size(0) 
        x = x.reshape(bs*self.seq_len, 1, x.size(2), x.size(3)) # (B*seq_len, 1, H, W)
        features = self.model(x)   
        if "res" in CFG.model_arch:                             
            features = self.pooling(features).view(bs*self.seq_len, -1) # (B*seq_len, cnn_feature)
        features = self.spatialdropout(features)                # (B*seq_len, cnn_feature)
        # print(features.shape)
        features = features.reshape(bs, self.seq_len, -1)       # (B, seq_len, cnn_feature)
        features, _ = self.gru(features)                        # (B, seq_len, hidden_dim*2)
        atten_out = self.mlp_attention_layer(features)          # (B, hidden_dim*2)
        pred = self.logits(atten_out)                           # (B, 1)
        pred = pred.view(bs, -1)                                # (B, 1)
        return pred

In [16]:
model = RSNAClassifier(CFG.model_arch, hidden_dim=256, seq_len=24, pretrained=True)

init GRU(2048, 256, num_layers=2, batch_first=True, bidirectional=True)


In [17]:
def get_activation(activ_name: str="relu"):
    """"""
    act_dict = {
        "relu": nn.ReLU(inplace=True),
        "tanh": nn.Tanh(),
        "sigmoid": nn.Sigmoid(),
        "identity": nn.Identity()}
    if activ_name in act_dict:
        return act_dict[activ_name]
    else:
        raise NotImplementedError
        

class Conv2dBNActiv(nn.Module):
    """Conv2d -> (BN ->) -> Activation"""
    
    def __init__(
        self, in_channels, out_channels,
        kernel_size, stride, padding,
        bias=False, use_bn=True, activ="relu"
    ):
        """"""
        super(Conv2dBNActiv, self).__init__()
        layers = []
        layers.append(nn.Conv2d(
            in_channels, out_channels,
            kernel_size, stride, padding, bias=bias))
        if use_bn:
            layers.append(nn.BatchNorm2d(out_channels))
            
        layers.append(get_activation(activ))
        self.layers = nn.Sequential(*layers)
        
    def forward(self, x):
        """Forward"""
        return self.layers(x)
        
    
class SpatialAttentionBlock(nn.Module):
    """Spatial Attention for (C, H, W) feature maps"""
    
    def __init__(
        self, in_channels,
        out_channels_list,
    ):
        """Initialize"""
        super(SpatialAttentionBlock, self).__init__()
        self.n_layers = len(out_channels_list)
        channels_list = [in_channels] + out_channels_list
        assert self.n_layers > 0
        assert channels_list[-1] == 1
        
        for i in range(self.n_layers - 1):
            in_chs, out_chs = channels_list[i: i + 2]
            layer = Conv2dBNActiv(in_chs, out_chs, 3, 1, 1, activ="relu")
            setattr(self, f"conv{i + 1}", layer)
            
        in_chs, out_chs = channels_list[-2:]
        layer = Conv2dBNActiv(in_chs, out_chs, 3, 1, 1, activ="sigmoid")
        setattr(self, f"conv{self.n_layers}", layer)
    
    def forward(self, x):
        """Forward"""
        h = x
        for i in range(self.n_layers):
            h = getattr(self, f"conv{i + 1}")(h)
            
        h = h * x
        return h



class MultiHeadResNet200D(nn.Module):
    def __init__(self, out_dims_head=[3, 4, 3, 1],  pretrained=False):
        self.base_name = "resnet200d_320"
        self.n_heads = len(out_dims_head)
        super(MultiHeadResNet200D, self).__init__()
        
        # # load base model
        base_model = timm.create_model(self.base_name, num_classes=sum(out_dims_head), pretrained=False)
        in_features = base_model.num_features
        
        if pretrained:
            pretrained_model_path = CFG.student
            state_dict = dict()
            for k, v in torch.load(pretrained_model_path, map_location='cpu')["model"].items():
                if k[:6] == "model.":
                    k = k.replace("model.", "")
                state_dict[k] = v
            base_model.load_state_dict(state_dict)
        
        # # remove global pooling and head classifier
        base_model.reset_classifier(0, '')
        
        # # Shared CNN Bacbone
        self.backbone = base_model
        
        # # Multi Heads.
        for i, out_dim in enumerate(out_dims_head):
            layer_name = f"head_{i}"
            layer = nn.Sequential(
                SpatialAttentionBlock(in_features, [64, 32, 16, 1]),
                nn.AdaptiveAvgPool2d(output_size=1),
                nn.Flatten(start_dim=1),
                nn.Linear(in_features, in_features),
                nn.ReLU(inplace=True),
                nn.Dropout(0.5),
                nn.Linear(in_features, out_dim))
            setattr(self, layer_name, layer)

    def forward(self, x):
        h = self.backbone(x)
        hs = [getattr(self, f"head_{i}")(h) for i in range(self.n_heads)]
        y = torch.cat(hs, axis=1)
        return None, None, y

In [18]:
def train_one_epoch(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    if CFG.device == 'GPU':
        scaler = GradScaler()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to train mode
    model.train()
    start = end = time.time()
    for step, (images, labels) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.float)
        batch_size = labels.size(0)

        if CFG.device == 'GPU':
            with autocast():
                y_preds = model(images)
                y_preds = y_preds.squeeze(1)
                loss = criterion(y_preds, labels)
            # record loss
            losses.update(loss.item(), batch_size)
            if CFG.accum_iter > 1:
                loss = loss / CFG.accum_iter
            scaler.scale(loss).backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
            if (step + 1) % CFG.accum_iter == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
        elif CFG.device == 'TPU':
            y_preds = model(images)
            loss = criterion(y_preds, labels)
            # record loss
            losses.update(loss.item(), batch_size)
            if CFG.accum_iter > 1:
                loss = loss / CFG.accum_iter
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
            if (step + 1) % CFG.accum_iter == 0:
                xm.optimizer_step(optimizer, barrier=True)
                optimizer.zero_grad()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
            cusprint('Epoch: [{0}][{1}/{2}] '
                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                'Elapsed {remain:s} '
                'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                'Grad: {grad_norm:.4f}  '
                'LR: {lr:.7f}  '
                .format(
                epoch, step, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses,
                remain=timeSince(start, float(step+1)/len(train_loader)),
                grad_norm=grad_norm,
                lr=optimizer.param_groups[0]["lr"],
                ))

    return losses.avg, optimizer.param_groups[0]["lr"]

In [19]:
def valid_one_epoch(valid_loader, model, criterion, device):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to evaluation mode
    model.eval()
    trues = []
    preds = []
    start = end = time.time()
    for step, (images, labels) in enumerate(valid_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.float)
        batch_size = labels.size(0)
        # compute loss
        with torch.no_grad():
            y_preds = model(images)
            y_preds = y_preds.squeeze(1)
        loss = criterion(y_preds, labels)
        losses.update(loss.item(), batch_size)
        # record accuracy
        trues.append(labels.to('cpu').numpy())
        preds.append(y_preds.sigmoid().to('cpu').numpy())
        if CFG.accum_iter > 1:
            loss = loss / CFG.accum_iter
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(valid_loader)-1):
            cusprint('EVAL: [{0}/{1}] '
                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                'Elapsed {remain:s} '
                'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                .format(
                step, len(valid_loader), batch_time=batch_time,
                data_time=data_time, loss=losses,
                remain=timeSince(start, float(step+1)/len(valid_loader)),
                ))

    trues = np.concatenate(trues)
    predictions = np.concatenate(preds)
    return losses.avg, predictions, trues

# loss & optimizer & scheduler

In [20]:
class GradualWarmupSchedulerV3(GradualWarmupScheduler):
    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super(GradualWarmupSchedulerV3, self).__init__(optimizer, multiplier, total_epoch, after_scheduler)
    def get_lr(self):
        if self.last_epoch >= self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        if self.multiplier == 1.0:
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

# Training

In [21]:
def train_loop(df, fold, trn_idx, val_idx):
    loginfo(f"========== fold: {fold} training ==========")

    # ====================================================
    # loader
    # ====================================================
    train_folds = train_df.loc[trn_idx].reset_index(drop=True)
    valid_folds = train_df.loc[val_idx].reset_index(drop=True)

    train_dataset = TrainDataset(train_folds, transform=get_transforms(data='light_train'))
    valid_dataset = TrainDataset(valid_folds, transform=get_transforms(data='valid'))
    if CFG.device == 'GPU':
        train_loader = DataLoader(train_dataset, batch_size=CFG.train_bs, shuffle=True, num_workers=CFG.num_workers, pin_memory=True, drop_last=True)
        valid_loader = DataLoader(valid_dataset, batch_size=CFG.valid_bs, shuffle=False, num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
    elif CFG.device == 'TPU':
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=CFG.train_bs, sampler=train_sampler, drop_last=True, num_workers=CFG.num_workers)
        valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=False)
        valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=CFG.valid_bs, sampler=valid_sampler, drop_last=False, num_workers=CFG.num_workers)

    # ====================================================
    # model & optimizer & scheduler & loss
    # ====================================================
    # not checkpoint

    if CFG.multihead:
        model = MultiHeadResNet200D([3, 4, 3, 1], True)
    else:
        model = RSNAClassifier(CFG.model_arch, hidden_dim=256, seq_len=24, pretrained=True)

        
    if CFG.gpu_parallel:    
        num_gpu = torch.cuda.device_count()
        model = DataParallel(model, device_ids=range(num_gpu))
    model.to(device)
    
    
    # optimizer
    if CFG.optimizer == "AdamW":
        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            optimizer = AdamW(model.parameters(), lr=CFG.lr/CFG.warmup_factor, weight_decay=CFG.weight_decay) 
        else:
            optimizer = AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)
    # scheduler
    if CFG.scheduler=='ReduceLROnPlateau':
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)
    elif CFG.scheduler=='CosineAnnealingLR':
        scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
    elif CFG.scheduler=='CosineAnnealingWarmRestarts':
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1)

    scheduler_warmup = GradualWarmupSchedulerV3(optimizer, multiplier=10, total_epoch=CFG.warmup_epo, after_scheduler=scheduler)

    # loss
    if CFG.loss_fn == "BCEWithLogitsLoss":
        criterion = nn.BCEWithLogitsLoss()

    # ====================================================
    # loop
    # ====================================================
    valid_acc_max=0; valid_loss_min=float("inf")
    valid_acc_max_cnt=0; valid_loss_min_cnt=0;
    best_acc_epoch=0;


    for epoch in range(CFG.epochs):
        loginfo(f"***** Epoch {epoch} *****")

        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            loginfo(f"schwarmup_last_epoch:{scheduler_warmup.last_epoch}, schwarmup_lr:{scheduler_warmup.get_last_lr()[0]}")
        if CFG.scheduler=='CosineAnnealingLR':
            loginfo(f"scheduler_last_epoch:{scheduler.last_epoch}, scheduler_lr:{scheduler.get_last_lr()[0]}")
        loginfo(f"optimizer_lr:{optimizer.param_groups[0]['lr']}")

                
        start_time = time.time()
        
        avg_loss, cur_lr = train_one_epoch(train_loader, model, criterion, optimizer, epoch, scheduler, device) # train
        avg_val_loss, preds, _ = valid_one_epoch(valid_loader, model, criterion, device) # valid

        # scoring
        elapsed = time.time() - start_time 

        loginfo(f'Epoch {epoch} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')

        if CFG.scheduler_warmup in ["GradualWarmupSchedulerV2","GradualWarmupSchedulerV3"]:
            scheduler_warmup.step()
        elif CFG.scheduler == "ReduceLROnPlateau":
            scheduler.step(avg_val_loss)
        elif CFG.scheduler in ["CosineAnnealingLR", "CosineAnnealingWarmRestarts"]:
            scheduler.step()

        # early stopping
        if avg_val_loss < valid_loss_min:
            valid_loss_min = avg_val_loss
            valid_loss_min_cnt=0
            best_acc_epoch = epoch
        else:
            valid_loss_min_cnt+=1

        if valid_loss_min_cnt >= CFG.n_early_stopping:
            if CFG.device == 'GPU':
                torch.save({'model': model.state_dict()}, outputdir+f'/{CFG.model_arch}_{CFG.suffix}_fold{fold}_epoch{epoch}.pth')
            elif CFG.device == 'TPU':
                xm.save({'model': model.state_dict()}, outputdir+f'/{CFG.model_arch}_{CFG.suffix}_fold{fold}_epoch{epoch}.pth')
            print("early_stopping")
            break

        if CFG.device == 'GPU':
            torch.save({'model': model.state_dict()}, outputdir+f'/{CFG.model_arch}_{CFG.suffix}_fold{fold}_epoch{epoch}.pth')
        elif CFG.device == 'TPU':
            xm.save({'model': model.state_dict()}, outputdir+f'/{CFG.model_arch}_{CFG.suffix}_fold{fold}_epoch{epoch}.pth')
    
    return valid_folds

In [22]:
def main():
    oof_df = pd.DataFrame()
    for fold, (trn_idx, val_idx) in enumerate(folds):
        if fold in CFG.fold_list:
            train_loop(train_df, fold, trn_idx, val_idx)

In [23]:
train_df

Unnamed: 0,study_cid,StudyInstanceUID,cid,slice_num_list,before_image_size,x0,x1,y0,y1,z0,z1,label
0,1.2.826.0.1.3680043.10001_1,1.2.826.0.1.3680043.10001,1,"[50, 51, 52, 53, 55, 56, 57, 59, 60, 61, 63, 6...",320,0,267,2,273,48,277,0
1,1.2.826.0.1.3680043.10001_2,1.2.826.0.1.3680043.10001,2,"[1, 81, 82, 83, 84, 85, 87, 88, 89, 90, 91, 92...",320,0,267,2,273,48,277,0
2,1.2.826.0.1.3680043.10001_3,1.2.826.0.1.3680043.10001,3,"[108, 109, 110, 111, 112, 113, 114, 115, 116, ...",320,0,267,2,273,48,277,0
3,1.2.826.0.1.3680043.10001_4,1.2.826.0.1.3680043.10001,4,"[129, 130, 131, 132, 133, 134, 135, 136, 137, ...",320,0,267,2,273,48,277,0
4,1.2.826.0.1.3680043.10001_5,1.2.826.0.1.3680043.10001,5,"[154, 155, 156, 157, 158, 159, 160, 161, 162, ...",320,0,267,2,273,48,277,0
...,...,...,...,...,...,...,...,...,...,...,...,...
13555,1.2.826.0.1.3680043.9997_3,1.2.826.0.1.3680043.9997,3,"[93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 10...",320,0,254,2,262,3,252,0
13556,1.2.826.0.1.3680043.9997_4,1.2.826.0.1.3680043.9997,4,"[114, 115, 116, 117, 118, 119, 120, 121, 122, ...",320,0,254,2,262,3,252,0
13557,1.2.826.0.1.3680043.9997_5,1.2.826.0.1.3680043.9997,5,"[136, 137, 138, 139, 140, 141, 142, 144, 145, ...",320,0,254,2,262,3,252,0
13558,1.2.826.0.1.3680043.9997_6,1.2.826.0.1.3680043.9997,6,"[158, 159, 160, 161, 162, 163, 164, 165, 166, ...",320,0,254,2,262,3,252,0


# Main

In [24]:
if __name__ == '__main__':
    print(CFG.suffix)
    if CFG.device == 'TPU':
        def _mp_fn(rank, flags):
            torch.set_default_tensor_type('torch.FloatTensor')
            a = main()
        FLAGS = {}
        xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=CFG.nprocs, start_method='fork')
    elif CFG.device == 'GPU':
        main()



401
init GRU(2048, 256, num_layers=2, batch_first=True, bidirectional=True)


***** Epoch 0 *****
schwarmup_last_epoch:0, schwarmup_lr:0.0001
scheduler_last_epoch:0, scheduler_lr:0.0001
optimizer_lr:0.0001


Epoch: [0][0/2712] Data 5.687 (5.687) Elapsed 0m 18s (remain 830m 54s) Loss: 0.6663(0.6663) Grad: 99198.4688  LR: 0.0001000  
Epoch: [0][100/2712] Data 0.000 (0.057) Elapsed 1m 42s (remain 44m 8s) Loss: 0.0023(0.4971) Grad: 708.3766  LR: 0.0001000  
Epoch: [0][200/2712] Data 0.000 (0.029) Elapsed 3m 5s (remain 38m 41s) Loss: 0.0007(0.5808) Grad: 253.1853  LR: 0.0001000  
Epoch: [0][300/2712] Data 0.000 (0.019) Elapsed 4m 29s (remain 36m 0s) Loss: 0.0020(0.6324) Grad: 633.3929  LR: 0.0001000  
Epoch: [0][400/2712] Data 0.000 (0.014) Elapsed 5m 54s (remain 34m 2s) Loss: 1.7771(0.6241) Grad: 91104.1719  LR: 0.0001000  
Epoch: [0][500/2712] Data 0.000 (0.012) Elapsed 7m 20s (remain 32m 26s) Loss: 0.0008(0.6162) Grad: 274.3682  LR: 0.0001000  
Epoch: [0][600/2712] Data 0.000 (0.010) Elapsed 8m 48s (remain 30m 57s) Loss: 1.7343(0.6199) Grad: 91561.3203  LR: 0.0001000  
Epoch: [0][700/2712] Data 0.000 (0.008) Elapsed 10m 17s (remain 29m 31s) Loss: 3.0664(0.6273) Grad: 160607.9062  LR: 0.00010

Epoch 0 - avg_train_loss: 0.6559  avg_val_loss: 0.6317  time: 2797s


EVAL: [338/339] Data 0.000 (0.020) Elapsed 7m 1s (remain 0m 0s) Loss: 0.0022(0.6317) 


***** Epoch 1 *****
schwarmup_last_epoch:1, schwarmup_lr:0.0001
scheduler_last_epoch:0, scheduler_lr:0.0001
optimizer_lr:0.0001


Epoch: [1][0/2712] Data 6.756 (6.756) Elapsed 0m 7s (remain 345m 20s) Loss: 1.4095(1.4095) Grad: inf  LR: 0.0001000  
Epoch: [1][100/2712] Data 0.000 (0.067) Elapsed 1m 31s (remain 39m 14s) Loss: 3.1451(0.6975) Grad: 175137.8125  LR: 0.0001000  
Epoch: [1][200/2712] Data 0.001 (0.034) Elapsed 2m 55s (remain 36m 33s) Loss: 0.0018(0.6920) Grad: 614.0773  LR: 0.0001000  
Epoch: [1][300/2712] Data 0.000 (0.023) Elapsed 4m 25s (remain 35m 25s) Loss: 0.0007(0.6758) Grad: 278.8288  LR: 0.0001000  
Epoch: [1][400/2712] Data 0.000 (0.017) Elapsed 5m 53s (remain 33m 55s) Loss: 1.5173(0.6611) Grad: 85351.6875  LR: 0.0001000  
Epoch: [1][500/2712] Data 0.000 (0.014) Elapsed 7m 24s (remain 32m 43s) Loss: 1.5618(0.6733) Grad: 89555.5547  LR: 0.0001000  
Epoch: [1][600/2712] Data 0.000 (0.012) Elapsed 8m 53s (remain 31m 13s) Loss: 0.0012(0.6694) Grad: 424.2238  LR: 0.0001000  
Epoch: [1][700/2712] Data 0.001 (0.010) Elapsed 10m 22s (remain 29m 45s) Loss: 1.6067(0.6753) Grad: 93016.6484  LR: 0.0001000

Epoch 1 - avg_train_loss: 0.6381  avg_val_loss: 0.6660  time: 2829s


EVAL: [338/339] Data 0.000 (0.024) Elapsed 6m 59s (remain 0m 0s) Loss: 0.0022(0.6660) 


***** Epoch 2 *****
schwarmup_last_epoch:1, schwarmup_lr:9.953895432879837e-05
scheduler_last_epoch:1, scheduler_lr:9.953895432879837e-05
optimizer_lr:9.953895432879837e-05


Epoch: [2][0/2712] Data 4.790 (4.790) Elapsed 0m 5s (remain 258m 19s) Loss: 3.2175(3.2175) Grad: inf  LR: 0.0000995  
Epoch: [2][100/2712] Data 0.000 (0.048) Elapsed 1m 33s (remain 40m 12s) Loss: 0.0029(0.6697) Grad: 1003.5286  LR: 0.0000995  
Epoch: [2][200/2712] Data 0.000 (0.024) Elapsed 3m 1s (remain 37m 52s) Loss: 0.0007(0.6273) Grad: 292.2397  LR: 0.0000995  
Epoch: [2][300/2712] Data 0.000 (0.016) Elapsed 4m 30s (remain 36m 9s) Loss: 0.0010(0.6075) Grad: 389.9406  LR: 0.0000995  
Epoch: [2][400/2712] Data 0.000 (0.012) Elapsed 5m 59s (remain 34m 33s) Loss: 1.5041(0.6274) Grad: 95035.3203  LR: 0.0000995  
Epoch: [2][500/2712] Data 0.000 (0.010) Elapsed 7m 30s (remain 33m 9s) Loss: 1.5984(0.6375) Grad: 99547.6562  LR: 0.0000995  
Epoch: [2][600/2712] Data 0.000 (0.008) Elapsed 9m 5s (remain 31m 54s) Loss: 0.0013(0.6390) Grad: 492.4354  LR: 0.0000995  
Epoch: [2][700/2712] Data 0.000 (0.007) Elapsed 10m 33s (remain 30m 18s) Loss: 0.0013(0.6512) Grad: 513.3612  LR: 0.0000995  
Epoch

Epoch 2 - avg_train_loss: 0.6162  avg_val_loss: 0.6561  time: 2834s


EVAL: [338/339] Data 0.000 (0.035) Elapsed 6m 30s (remain 0m 0s) Loss: 0.0026(0.6561) 


***** Epoch 3 *****
schwarmup_last_epoch:1, schwarmup_lr:9.816440572371606e-05
scheduler_last_epoch:2, scheduler_lr:9.816440572371606e-05
optimizer_lr:9.816440572371606e-05


Epoch: [3][0/2712] Data 5.334 (5.334) Elapsed 0m 6s (remain 281m 48s) Loss: 0.0017(0.0017) Grad: 2695.4114  LR: 0.0000982  
Epoch: [3][100/2712] Data 0.000 (0.053) Elapsed 1m 29s (remain 38m 40s) Loss: 0.0014(0.7823) Grad: 546.2412  LR: 0.0000982  
Epoch: [3][200/2712] Data 0.000 (0.027) Elapsed 2m 58s (remain 37m 14s) Loss: 0.0013(0.7194) Grad: 516.4739  LR: 0.0000982  
Epoch: [3][300/2712] Data 0.000 (0.018) Elapsed 4m 23s (remain 35m 8s) Loss: 0.0021(0.7114) Grad: 783.5538  LR: 0.0000982  
Epoch: [3][400/2712] Data 0.003 (0.014) Elapsed 5m 50s (remain 33m 40s) Loss: 0.0007(0.6757) Grad: 304.9323  LR: 0.0000982  
Epoch: [3][500/2712] Data 0.000 (0.011) Elapsed 7m 22s (remain 32m 34s) Loss: 0.0013(0.6752) Grad: 496.9733  LR: 0.0000982  
Epoch: [3][600/2712] Data 0.000 (0.009) Elapsed 8m 46s (remain 30m 48s) Loss: 0.0023(0.6554) Grad: 408.0723  LR: 0.0000982  
Epoch: [3][700/2712] Data 0.000 (0.008) Elapsed 10m 10s (remain 29m 10s) Loss: 1.5860(0.6480) Grad: 49812.3438  LR: 0.0000982  

Epoch 3 - avg_train_loss: 0.6093  avg_val_loss: 0.6482  time: 2810s


EVAL: [338/339] Data 0.000 (0.030) Elapsed 6m 58s (remain 0m 0s) Loss: 0.0032(0.6482) 


***** Epoch 4 *****
schwarmup_last_epoch:1, schwarmup_lr:9.590195942451992e-05
scheduler_last_epoch:3, scheduler_lr:9.590195942451992e-05
optimizer_lr:9.590195942451992e-05


Epoch: [4][0/2712] Data 6.348 (6.348) Elapsed 0m 7s (remain 328m 57s) Loss: 0.0027(0.0027) Grad: 3991.7380  LR: 0.0000959  
Epoch: [4][100/2712] Data 0.000 (0.063) Elapsed 1m 30s (remain 39m 2s) Loss: 1.7613(0.7640) Grad: 112354.5547  LR: 0.0000959  
Epoch: [4][200/2712] Data 0.000 (0.032) Elapsed 2m 54s (remain 36m 19s) Loss: 1.2983(0.7483) Grad: 44874.1133  LR: 0.0000959  
Epoch: [4][300/2712] Data 0.001 (0.021) Elapsed 4m 21s (remain 34m 52s) Loss: 0.0039(0.7072) Grad: 706.7645  LR: 0.0000959  
Epoch: [4][400/2712] Data 0.000 (0.016) Elapsed 5m 53s (remain 33m 58s) Loss: 0.0028(0.7180) Grad: 530.8735  LR: 0.0000959  
Epoch: [4][500/2712] Data 0.000 (0.013) Elapsed 7m 28s (remain 33m 0s) Loss: 1.4613(0.6988) Grad: 47154.8594  LR: 0.0000959  
Epoch: [4][600/2712] Data 0.000 (0.011) Elapsed 8m 57s (remain 31m 27s) Loss: 0.0035(0.6901) Grad: 635.4487  LR: 0.0000959  
Epoch: [4][700/2712] Data 0.000 (0.009) Elapsed 10m 26s (remain 29m 56s) Loss: 2.7936(0.6741) Grad: 92029.1250  LR: 0.000

Epoch 4 - avg_train_loss: 0.6091  avg_val_loss: 0.6138  time: 2880s


EVAL: [338/339] Data 0.000 (0.029) Elapsed 7m 8s (remain 0m 0s) Loss: 0.0032(0.6138) 


***** Epoch 5 *****
schwarmup_last_epoch:1, schwarmup_lr:9.279376052505117e-05
scheduler_last_epoch:4, scheduler_lr:9.279376052505117e-05
optimizer_lr:9.279376052505117e-05


Epoch: [5][0/2712] Data 5.234 (5.234) Elapsed 0m 6s (remain 280m 29s) Loss: 0.0035(0.0035) Grad: 5107.8882  LR: 0.0000928  
Epoch: [5][100/2712] Data 0.000 (0.052) Elapsed 1m 34s (remain 40m 38s) Loss: 0.0012(0.6301) Grad: 527.2274  LR: 0.0000928  
Epoch: [5][200/2712] Data 0.000 (0.026) Elapsed 3m 3s (remain 38m 8s) Loss: 0.0036(0.6670) Grad: 686.1704  LR: 0.0000928  
Epoch: [5][300/2712] Data 0.000 (0.018) Elapsed 4m 32s (remain 36m 19s) Loss: 1.4865(0.6259) Grad: 54357.7930  LR: 0.0000928  
Epoch: [5][400/2712] Data 0.000 (0.013) Elapsed 6m 4s (remain 34m 58s) Loss: 0.0042(0.6449) Grad: 759.5908  LR: 0.0000928  
Epoch: [5][500/2712] Data 0.000 (0.011) Elapsed 7m 33s (remain 33m 20s) Loss: 0.0014(0.6024) Grad: 296.0289  LR: 0.0000928  
Epoch: [5][600/2712] Data 0.000 (0.009) Elapsed 9m 10s (remain 32m 14s) Loss: 0.0011(0.5950) Grad: 246.0741  LR: 0.0000928  
Epoch: [5][700/2712] Data 0.000 (0.008) Elapsed 10m 39s (remain 30m 33s) Loss: 0.0019(0.6043) Grad: 407.3957  LR: 0.0000928  
E

Epoch 5 - avg_train_loss: 0.6122  avg_val_loss: 0.6854  time: 2856s


EVAL: [338/339] Data 0.000 (0.028) Elapsed 6m 34s (remain 0m 0s) Loss: 0.0023(0.6854) 


***** Epoch 6 *****
schwarmup_last_epoch:1, schwarmup_lr:8.889770888986878e-05
scheduler_last_epoch:5, scheduler_lr:8.889770888986878e-05
optimizer_lr:8.889770888986878e-05


Epoch: [6][0/2712] Data 6.106 (6.106) Elapsed 0m 7s (remain 316m 56s) Loss: 1.9379(1.9379) Grad: inf  LR: 0.0000889  
Epoch: [6][100/2712] Data 0.000 (0.061) Elapsed 1m 30s (remain 38m 53s) Loss: 0.0013(0.6771) Grad: 568.0807  LR: 0.0000889  
Epoch: [6][200/2712] Data 0.001 (0.031) Elapsed 2m 54s (remain 36m 14s) Loss: 0.0021(0.7136) Grad: 419.6077  LR: 0.0000889  
Epoch: [6][300/2712] Data 0.000 (0.021) Elapsed 4m 23s (remain 35m 11s) Loss: 0.0043(0.6574) Grad: 819.7164  LR: 0.0000889  
Epoch: [6][400/2712] Data 0.000 (0.016) Elapsed 5m 50s (remain 33m 39s) Loss: 3.3744(0.6332) Grad: 114526.8281  LR: 0.0000889  
Epoch: [6][500/2712] Data 0.000 (0.013) Elapsed 7m 19s (remain 32m 18s) Loss: 0.0017(0.6224) Grad: 365.0926  LR: 0.0000889  
Epoch: [6][600/2712] Data 0.000 (0.011) Elapsed 8m 53s (remain 31m 14s) Loss: 0.0018(0.6227) Grad: 386.4292  LR: 0.0000889  
Epoch: [6][700/2712] Data 0.000 (0.009) Elapsed 10m 21s (remain 29m 43s) Loss: 0.0021(0.6208) Grad: 441.9611  LR: 0.0000889  
Epo

Epoch 6 - avg_train_loss: 0.6209  avg_val_loss: 0.6332  time: 2851s


EVAL: [338/339] Data 0.000 (0.050) Elapsed 7m 12s (remain 0m 0s) Loss: 0.0034(0.6332) 


***** Epoch 7 *****
schwarmup_last_epoch:1, schwarmup_lr:8.428638058932337e-05
scheduler_last_epoch:6, scheduler_lr:8.428638058932337e-05
optimizer_lr:8.428638058932337e-05


Epoch: [7][0/2712] Data 4.671 (4.671) Elapsed 0m 5s (remain 253m 37s) Loss: 0.0018(0.0018) Grad: 3064.9773  LR: 0.0000843  
Epoch: [7][100/2712] Data 0.000 (0.047) Elapsed 1m 34s (remain 40m 37s) Loss: 0.0009(0.6666) Grad: 436.1357  LR: 0.0000843  
Epoch: [7][200/2712] Data 0.000 (0.024) Elapsed 3m 2s (remain 38m 0s) Loss: 1.8154(0.6833) Grad: 127537.1641  LR: 0.0000843  
Epoch: [7][300/2712] Data 0.000 (0.016) Elapsed 4m 39s (remain 37m 17s) Loss: 0.0009(0.6377) Grad: 409.5872  LR: 0.0000843  
Epoch: [7][400/2712] Data 0.000 (0.012) Elapsed 6m 7s (remain 35m 17s) Loss: 0.0010(0.6478) Grad: 236.0156  LR: 0.0000843  
Epoch: [7][500/2712] Data 0.000 (0.010) Elapsed 7m 41s (remain 33m 56s) Loss: 0.0020(0.6413) Grad: 427.3482  LR: 0.0000843  
Epoch: [7][600/2712] Data 0.000 (0.008) Elapsed 9m 9s (remain 32m 11s) Loss: 0.0031(0.6362) Grad: 592.2186  LR: 0.0000843  
Epoch: [7][700/2712] Data 0.000 (0.007) Elapsed 10m 38s (remain 30m 32s) Loss: 0.0039(0.6397) Grad: 767.4644  LR: 0.0000843  
E

Epoch 7 - avg_train_loss: 0.5575  avg_val_loss: 0.5184  time: 2827s


EVAL: [338/339] Data 0.000 (0.034) Elapsed 6m 36s (remain 0m 0s) Loss: 0.0122(0.5184) 


***** Epoch 8 *****
schwarmup_last_epoch:1, schwarmup_lr:7.904567594468593e-05
scheduler_last_epoch:7, scheduler_lr:7.904567594468593e-05
optimizer_lr:7.904567594468593e-05


Epoch: [8][0/2712] Data 5.662 (5.662) Elapsed 0m 6s (remain 296m 19s) Loss: 3.0041(3.0041) Grad: inf  LR: 0.0000790  
Epoch: [8][100/2712] Data 0.000 (0.056) Elapsed 1m 31s (remain 39m 31s) Loss: 1.6212(0.6081) Grad: 114909.3047  LR: 0.0000790  
Epoch: [8][200/2712] Data 0.000 (0.028) Elapsed 2m 58s (remain 37m 15s) Loss: 0.0011(0.6411) Grad: 511.6925  LR: 0.0000790  
Epoch: [8][300/2712] Data 0.000 (0.019) Elapsed 4m 27s (remain 35m 45s) Loss: 1.3211(0.6553) Grad: 97393.2891  LR: 0.0000790  
Epoch: [8][400/2712] Data 0.000 (0.014) Elapsed 6m 0s (remain 34m 35s) Loss: 0.0008(0.6075) Grad: 394.3598  LR: 0.0000790  
Epoch: [8][500/2712] Data 0.001 (0.012) Elapsed 7m 32s (remain 33m 15s) Loss: 0.0011(0.6102) Grad: 247.7598  LR: 0.0000790  
Epoch: [8][600/2712] Data 0.000 (0.010) Elapsed 8m 58s (remain 31m 32s) Loss: 0.0021(0.6047) Grad: 456.4678  LR: 0.0000790  
Epoch: [8][700/2712] Data 0.000 (0.008) Elapsed 10m 23s (remain 29m 47s) Loss: 1.4335(0.6090) Grad: 52839.7773  LR: 0.0000790  


Epoch 8 - avg_train_loss: 0.6025  avg_val_loss: 0.6427  time: 2833s


EVAL: [338/339] Data 0.000 (0.031) Elapsed 6m 24s (remain 0m 0s) Loss: 0.0028(0.6427) 


***** Epoch 9 *****
schwarmup_last_epoch:1, schwarmup_lr:7.327321936769202e-05
scheduler_last_epoch:8, scheduler_lr:7.327321936769202e-05
optimizer_lr:7.327321936769202e-05


Epoch: [9][0/2712] Data 6.890 (6.890) Elapsed 0m 7s (remain 354m 14s) Loss: 0.0025(0.0025) Grad: 4339.9404  LR: 0.0000733  
Epoch: [9][100/2712] Data 0.000 (0.068) Elapsed 1m 35s (remain 41m 6s) Loss: 0.0006(0.6393) Grad: 317.6314  LR: 0.0000733  
Epoch: [9][200/2712] Data 0.002 (0.035) Elapsed 3m 5s (remain 38m 39s) Loss: 1.8923(0.6567) Grad: 134585.1562  LR: 0.0000733  
Epoch: [9][300/2712] Data 0.000 (0.023) Elapsed 4m 37s (remain 37m 2s) Loss: 0.0010(0.6341) Grad: 445.4470  LR: 0.0000733  
Epoch: [9][400/2712] Data 0.001 (0.018) Elapsed 6m 6s (remain 35m 14s) Loss: 1.6327(0.6444) Grad: 117080.4062  LR: 0.0000733  
Epoch: [9][500/2712] Data 0.000 (0.014) Elapsed 7m 40s (remain 33m 52s) Loss: 0.0015(0.6387) Grad: 348.5658  LR: 0.0000733  
Epoch: [9][600/2712] Data 0.000 (0.012) Elapsed 9m 9s (remain 32m 9s) Loss: 0.0014(0.6493) Grad: 324.0967  LR: 0.0000733  
Epoch: [9][700/2712] Data 0.001 (0.010) Elapsed 10m 38s (remain 30m 31s) Loss: 1.4078(0.6580) Grad: 51645.5625  LR: 0.0000733 

Epoch 9 - avg_train_loss: 0.6074  avg_val_loss: 0.6418  time: 2844s


EVAL: [338/339] Data 0.000 (0.032) Elapsed 6m 17s (remain 0m 0s) Loss: 0.0028(0.6418) 


***** Epoch 10 *****
schwarmup_last_epoch:1, schwarmup_lr:6.707654080246381e-05
scheduler_last_epoch:9, scheduler_lr:6.707654080246381e-05
optimizer_lr:6.707654080246381e-05


Epoch: [10][0/2712] Data 3.676 (3.676) Elapsed 0m 4s (remain 207m 33s) Loss: 0.0014(0.0014) Grad: 2642.0720  LR: 0.0000671  
Epoch: [10][100/2712] Data 0.000 (0.037) Elapsed 1m 28s (remain 38m 3s) Loss: 3.5081(0.5861) Grad: 121843.9688  LR: 0.0000671  
Epoch: [10][200/2712] Data 0.000 (0.019) Elapsed 2m 53s (remain 36m 7s) Loss: 0.0021(0.6047) Grad: 453.8343  LR: 0.0000671  
Epoch: [10][300/2712] Data 0.000 (0.013) Elapsed 4m 25s (remain 35m 27s) Loss: 0.0042(0.6014) Grad: 809.5601  LR: 0.0000671  
Epoch: [10][400/2712] Data 0.000 (0.010) Elapsed 5m 53s (remain 33m 57s) Loss: 0.0027(0.5841) Grad: 570.4831  LR: 0.0000671  
Epoch: [10][500/2712] Data 0.000 (0.008) Elapsed 7m 27s (remain 32m 54s) Loss: 2.9846(0.5787) Grad: 110222.7344  LR: 0.0000671  
Epoch: [10][600/2712] Data 0.001 (0.007) Elapsed 8m 58s (remain 31m 32s) Loss: 3.4908(0.5711) Grad: 127294.1172  LR: 0.0000671  
Epoch: [10][700/2712] Data 0.000 (0.006) Elapsed 10m 29s (remain 30m 6s) Loss: 0.0032(0.5752) Grad: 650.3676  LR

Epoch 10 - avg_train_loss: 0.6032  avg_val_loss: 0.5984  time: 2778s


EVAL: [338/339] Data 0.000 (0.039) Elapsed 6m 39s (remain 0m 0s) Loss: 0.0047(0.5984) 


***** Epoch 11 *****
schwarmup_last_epoch:1, schwarmup_lr:6.057107264610536e-05
scheduler_last_epoch:10, scheduler_lr:6.057107264610536e-05
optimizer_lr:6.057107264610536e-05


Epoch: [11][0/2712] Data 6.976 (6.976) Elapsed 0m 7s (remain 358m 21s) Loss: 0.0032(0.0032) Grad: 5384.3267  LR: 0.0000606  
Epoch: [11][100/2712] Data 0.000 (0.069) Elapsed 1m 36s (remain 41m 38s) Loss: 0.0018(0.5479) Grad: 744.2637  LR: 0.0000606  
Epoch: [11][200/2712] Data 0.000 (0.035) Elapsed 3m 5s (remain 38m 34s) Loss: 0.0007(0.6083) Grad: 365.3005  LR: 0.0000606  
Epoch: [11][300/2712] Data 0.000 (0.023) Elapsed 4m 33s (remain 36m 31s) Loss: 0.0008(0.6124) Grad: 399.8741  LR: 0.0000606  
Epoch: [11][400/2712] Data 0.000 (0.018) Elapsed 6m 2s (remain 34m 47s) Loss: 0.0031(0.5824) Grad: 1180.0817  LR: 0.0000606  
Epoch: [11][500/2712] Data 0.000 (0.014) Elapsed 7m 30s (remain 33m 9s) Loss: 0.0010(0.5772) Grad: 477.1198  LR: 0.0000606  
Epoch: [11][600/2712] Data 0.002 (0.012) Elapsed 9m 0s (remain 31m 39s) Loss: 1.1021(0.6157) Grad: 44732.3594  LR: 0.0000606  
Epoch: [11][700/2712] Data 0.000 (0.010) Elapsed 10m 33s (remain 30m 16s) Loss: 0.0011(0.5997) Grad: 268.9457  LR: 0.000

Epoch 11 - avg_train_loss: 0.6039  avg_val_loss: 0.5875  time: 2796s


EVAL: [338/339] Data 0.000 (0.032) Elapsed 6m 30s (remain 0m 0s) Loss: 0.0044(0.5875) 
early_stopping


In [25]:
# save as cpu
if CFG.device == 'TPU': 
    for fold in range(CFG.fold_num):
        if fold in CFG.fold_list:
            # best score
            state = torch.load(outputdir+f'{CFG.model_arch}_{CFG.suffix}_fold{fold}_epoch{cur_best_list[4]}.pth')
            torch.save({'model': state['model'].to('cpu').state_dict(), 'preds': state['preds'], 'cur_best_list': state['cur_best_list']}, 
                    outputdir+f'{CFG.model_arch}_{CFG.suffix}_fold{fold}_epoch{cur_best_list[4]}_cpu.pth')