## Imports

In [None]:
import numpy as np
import librosa as lb
import librosa.display as lbd
import soundfile as sf
from  soundfile import SoundFile
import pandas as pd
from  IPython.display import Audio
from pathlib import Path

import torch
import librosa
from torchvision import transforms
from torch import nn, optim
from  torch.utils.data import Dataset, DataLoader
from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR
from resnest.torch import resnest50, resnest101
from torch.optim import Adam, SGD, AdamW
from common import *
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold
from matplotlib import pyplot as plt
from torchlibrosa.augmentation import SpecAugmentation
import os, random, gc
import re, time, json
from  ast import literal_eval
import timm
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
from IPython.display import Audio
from sklearn.metrics import label_ranking_average_precision_score

from tqdm.notebook import tqdm
import joblib

In [None]:
from efficientnet_pytorch import EfficientNet
import pretrainedmodels
import resnest.torch as resnest_torch

In [None]:
def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices),
    and configures cuDNN for deterministic behaviour.

    Arguments:
        seed {int} -- Seed value shared by all generators (default: {42})
    """
    random.seed(seed)
    # Exported for child processes; the hash seed of the already running
    # interpreter is fixed at start-up and is not affected by this.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different kernels from run to run.
    torch.backends.cudnn.benchmark = False
seed_everything()

## Config

In [None]:
# Define all the hyperparms with their values
class CFG:
    """Experiment configuration; every setting is a plain class attribute."""

    # --- run control -----------------------------------------------------
    debug = False
    train = True
    save = True          # write a checkpoint after every epoch
    checkpoint = False   # resume from a previously saved checkpoint
    seed = 42
    exp_name = 'augs_mix3_3_col_cor_60'

    # --- model -----------------------------------------------------------
    model_name = 'resnest50'
    pretrained = True
    num_classes = 397

    # --- data ------------------------------------------------------------
    n_mels = 128
    len_check = 281
    augs = ['white_noise', 'pink_noise', 'bandpass_noise', 'upper']
    num_workers = 0
    batch_size = 64

    # --- cross-validation --------------------------------------------------
    n_fold = 10
    trn_fold = [0]

    # --- optimisation ----------------------------------------------------
    epochs = 60
    base_lr = 0.001
    min_lr = 1e-6
    # One of: 'ReduceLROnPlateau', 'CosineAnnealingLR', 'CosineAnnealingWarmRestarts'
    scheduler = 'CosineAnnealingLR'
    T_max = 80  # CosineAnnealingLR
    # Settings for the other schedulers are disabled; re-enable the matching
    # ones before switching `scheduler`.
    # factor = 0.4   # ReduceLROnPlateau
    # patience = 0   # ReduceLROnPlateau
    # eps = 1e-6     # ReduceLROnPlateau
    # T_0 = 20       # CosineAnnealingWarmRestarts

## Paths

In [None]:
# Root folder holding the metadata CSVs, spectrograms and model outputs.
DATA_ROOT = Path("Birdcall")
# Built with the `/` operator so the path means the same on every OS; a
# backslash inside the string is only a separator on Windows.
MODEL_ROOT = DATA_ROOT / "models"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Note: `CFG.trn_fold` is a list, so the last component reads e.g. "fold[0]".
OUTPUT_DIR = f"Birdcall/models/{CFG.model_name}_{CFG.exp_name}/fold{CFG.trn_fold}"
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# Forward slashes work on Windows as well as POSIX systems.
meta = pd.read_csv("Birdcall/meta.csv")
meta = meta.fillna('')

# Whole-value corrections: these background-label strings contain the code
# 'rocpig1', which is rewritten to 'rocpig'. Only cells that match a key
# exactly are changed.
meta['labels_bg'] = meta['labels_bg'].replace({
    'rocpig1 solsan whtdov': 'rocpig solsan whtdov',
    'rocpig1 grtgra': 'rocpig grtgra',
    'rewbla rocpig1 cangoo saypho killde amerob': 'rewbla rocpig cangoo saypho killde amerob',
})

In [None]:
# Sanity check: show any rows whose background labels still contain 'rocpig1'.
meta[meta['labels_bg'].str.contains('rocpig1')]

In [None]:
# Preview the first rows of the metadata table.
meta.head()

In [None]:
# (rows, columns) of the metadata table.
meta.shape

In [None]:
# Forward slashes work on Windows as well as POSIX systems.
df_train = pd.read_csv("Birdcall/train_metadata.csv")

# Species code -> class index, with indices assigned in sorted order of the code.
LABEL_IDS = {label: label_id for label_id, label in enumerate(sorted(df_train["primary_label"].unique()))}
# Class index -> species code (inverse of LABEL_IDS).
INV_LABEL_IDS = {label_id: label for label, label_id in LABEL_IDS.items()}

## K-fold split (`CFG.n_fold` folds)

In [None]:
# Assign every row of `meta` to one validation fold, stratified on `bird`.
folds = meta.copy()
Fold = StratifiedKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)
# Create the column up front as integers so it never holds NaN / floats.
folds['fold'] = -1
fold_column = folds.columns.get_loc('fold')
for n, (train_index, val_index) in enumerate(Fold.split(folds, folds['bird'])):
    # `split` yields row positions, so write them positionally; `.loc` would
    # treat them as index labels and only works with a default RangeIndex.
    folds.iloc[val_index, fold_column] = n
folds['fold'] = folds['fold'].astype(int)
print(folds.groupby(['fold']).size())

In [None]:
def get_model(name, num_classes=CFG.num_classes):
    """
    Loads a pretrained model and replaces its classification head.
    Supports ResNest, ResNext-wsl, ResNext, ResNet, DenseNet (timm),
    EfficientNet (timm `tf_efficientnet_b*` and `efficientnet-b*`) and
    anything else provided by `pretrainedmodels`.

    Arguments:
        name {str} -- Name of the model to load

    Keyword Arguments:
        num_classes {int} -- Number of output classes (default: {CFG.num_classes})

    Returns:
        torch model -- Pretrained model whose final linear layer has
        `num_classes` outputs

    Raises:
        ValueError -- If the loaded model exposes none of the known head
        attributes, so its classifier cannot be replaced
    """
    # "resnest" must be tested before the "resnext"/"resnet" prefixes below.
    if "resnest" in name:
        model = getattr(resnest_torch, name)(pretrained=True)
    elif "wsl" in name:
        model = torch.hub.load("facebookresearch/WSL-Images", name)
    elif name.startswith(("resnext", "resnet")):
        model = torch.hub.load("pytorch/vision:v0.6.0", name, pretrained=True)
    elif name.startswith("densenet"):
        model = getattr(timm.models.densenet, name)(pretrained=True)
    elif name.startswith("tf_efficientnet_b"):
        model = getattr(timm.models.efficientnet, name)(pretrained=True)
    elif "efficientnet-b" in name:
        model = EfficientNet.from_pretrained(name)
    else:
        model = pretrainedmodels.__dict__[name](pretrained='imagenet')

    # The attribute holding the final linear layer differs between libraries;
    # the first one found (in this order) is swapped for a fresh layer.
    for head_name in ("fc", "_fc", "classifier", "last_linear"):
        if hasattr(model, head_name):
            nb_ft = getattr(model, head_name).in_features
            setattr(model, head_name, nn.Linear(nb_ft, num_classes))
            break
    else:
        # Returning the model unchanged would keep its original output size
        # and only fail later, when the loss is computed.
        raise ValueError(
            f"Cannot replace the classification head of '{name}': "
            "no 'fc', '_fc', 'classifier' or 'last_linear' attribute found"
        )

    return model

## Load data into RAM

In [None]:
def load_data(df):
    """
    Reads the `.npy` array of every row of `df` with 4 joblib workers.

    Arguments:
        df {DataFrame} -- Table whose rows provide `file` and `impath`

    Returns:
        dict -- Maps each row's `file` to the array loaded from its `impath`
    """
    def read_entry(record):
        return record.file, np.load(str(record.impath))

    jobs = [joblib.delayed(read_entry)(record) for record in df.itertuples(False)]
    # tqdm wraps the task list so progress is shown while tasks are dispatched.
    loaded_pairs = joblib.Parallel(4)(tqdm(jobs))
    return dict(loaded_pairs)

In [None]:
# Load the array of every row in `meta` into memory, keyed by the row's `file`.
audio_image_store = load_data(meta)
# Number of entries in the store.
len(audio_image_store)

## Augmentations

In [None]:
def random_power(images, power = 1.5, c= 0.7):
    """
    Min-max scales `images` and raises the result to a random exponent.

    The exponent is `random.random() * power + c`, i.e. drawn from
    [c, c + power). The input array is not modified.

    Arguments:
        images {np.ndarray} -- Array to transform

    Keyword Arguments:
        power {float} -- Width of the exponent range (default: {1.5})
        c {float} -- Lower bound of the exponent range (default: {0.7})

    Returns:
        np.ndarray -- Transformed array with values in [0, 1]
    """
    shifted = images - images.min()
    # The small constant keeps a constant-valued input from dividing by zero.
    scaled = shifted / (shifted.max() + 0.0000001)
    exponent = random.random()*power + c
    return scaled**exponent

def mono_to_color(X: np.ndarray,len_check, mean=0.5, std=0.5, eps=1e-6):
    """
    Turns a single-channel array into a 3-channel image tensor of fixed size.

    The array is copied into three identical channels, quantised to uint8
    (values are expected in [0, 1]), resized to `CFG.n_mels` x `len_check`,
    normalised with mean/std 0.5 and then mapped back with `(v + 1) / 2`.

    Arguments:
        X {np.ndarray} -- 2-D input array
        len_check {int} -- Width the image is resized to

    Keyword Arguments:
        mean, std, eps -- Accepted but not used by this function

    Returns:
        torch.Tensor -- Image tensor produced by the transform pipeline
    """
    three_channel = np.repeat(X[..., np.newaxis], 3, axis=-1)
    pixels = (255 * three_channel).astype(np.uint8)

    pipeline = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize([CFG.n_mels, len_check]),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])
    normalised = pipeline(pixels)
    return (normalised + 1) / 2

## Dataset definition

In [None]:
class BirdClefDataset(Dataset):
    """
    Dataset of bird-call spectrogram images with multi-label targets.

    In training mode every item mixes up to three randomly cropped
    recordings, adds a "no call" recording and several random noise /
    contrast augmentations. In validation mode one stored segment of the
    requested recording is used, with only the random contrast change.

    Targets are vectors of length `CFG.num_classes`: 1 for the primary birds
    of the recordings used and 0.3 for their background birds.
    """

    # Folder receiving the (very rare) debug images from both modes.
    DEBUG_IMAGE_DIR = "Birdcall/images"

    def __init__(self, audio_image_store,bird_list,LABEL_IDS,INV_LABEL_IDS,is_train):
        """
        Arguments:
            audio_image_store {dict} -- Maps a row's `file` to its stored array
            bird_list {DataFrame} -- Rows providing `file`, `bird`, `labels_bg`
            LABEL_IDS {dict} -- Species code -> class index
            INV_LABEL_IDS {dict} -- Class index -> species code
            is_train {bool} -- Enables mixing and augmentation when True
        """
        self.audio_image_store = audio_image_store
        self.bird_list = bird_list
        self.n_mels = CFG.n_mels
        # Forward slashes work on Windows as well as POSIX systems.
        self.noise = pd.read_csv("Birdcall/no_call.csv")
        self.len_check = CFG.len_check
        self.count_bird = CFG.num_classes
        self.augs = CFG.augs
        self.BIRD_CODE = LABEL_IDS
        self.INV_BIRD_CODE = INV_LABEL_IDS
        self.stop_border = 0.3 # Probability of stopping mixing
        self.level_noise = 0.05 # level noise
        self.div_coef = 100 # signal amplification during mixing
        self.is_train = is_train

    def __len__(self):
        return len(self.bird_list)

    def __getitem__(self, idx):
        """Returns `(image, target)` for row `idx` of `bird_list`."""
        if self.is_train:
            return self._train_item(idx)
        return self._valid_item(idx)

    def _collect_labels(self, sample, birds, background):
        """Appends the class ids of `sample`'s primary and background birds.

        Ids already present in the respective list are skipped.
        """
        for name in sample.bird.split():
            code = self.BIRD_CODE[name]
            if code not in birds:
                birds.append(code)

        if sample.labels_bg:
            for name in sample.labels_bg.split():
                code = self.BIRD_CODE[name]
                if code not in background:
                    background.append(code)

    def _build_target(self, birds, background):
        """Returns the target vector: background birds 0.3, primary birds 1."""
        y = torch.zeros(self.count_bird)
        for code in background:
            if code < len(y):
                y[code] = 0.3
        # Written last so a bird that is both background and primary gets 1.
        for code in birds:
            y[code] = 1
        return y

    def _maybe_save_debug_image(self, images, birds, tag, file_name):
        """With probability 1e-5, saves `images` as a PNG for inspection."""
        if random.random()<0.00001:
            img = images.numpy()
            img = img - img.min()
            img = img/img.max()
            img = np.moveaxis(img, 0, 2)
            os.makedirs(self.DEBUG_IMAGE_DIR, exist_ok=True)
            plt.imshow(img)
            name = ("_".join(self.INV_BIRD_CODE[x] for x in birds))+tag+file_name+'.png'
            plt.savefig(os.path.join(self.DEBUG_IMAGE_DIR, name))
            # Release the figure so successive debug plots do not pile up.
            plt.close()

    def _train_item(self, idx):
        """Builds one augmented training item; see the class docstring."""
        idx2 = random.randint(0, len(self.bird_list)-1) # Second file
        idx3 = random.randint(0, len(self.bird_list)-1) # Third file

        birds, background = [],[]
        primary_file = None

        # Length of the segment, re-drawn for every item
        self.len_check = random.randint(281-14, 281+21)

        images = np.zeros((self.n_mels, self.len_check)).astype(np.float32)
        for idy in [idx,idx2,idx3]:
            # Choosing a record with a bird
            sample = self.bird_list.iloc[idy]
            if primary_file is None:
                primary_file = sample.file
            # Pick one of the stored segments of that record
            mel = self.audio_image_store[sample.file]
            mel = mel[np.random.choice(len(mel))]

            self._collect_labels(sample, birds, background)

            # Crop a random window, or left-pad with a random amount of zeros
            if mel.shape[1]>self.len_check:
                start = random.randint(0, mel.shape[1] - self.len_check - 1)
                mel = mel[:, start : start + random.randint(self.len_check-14, self.len_check)]
            else:
                len_zero = random.randint(0, self.len_check-mel.shape[1])
                mel = np.concatenate((np.zeros((self.n_mels,len_zero)),mel), axis=1)

            # Right-pad with zeros up to the segment length
            mel = np.concatenate((mel,np.zeros((self.n_mels,self.len_check-mel.shape[1]))), axis=1)

            # Change the contrast
            mel = random_power(mel, power = 3, c= 0.5)

            # Mix the signal with a random gain
            images = images + mel*(random.random() * self.div_coef + 1)

            # Randomly stop mixing further records
            if random.random()<self.stop_border:
                break

        # Add a randomly positioned "no call" recording
        idy = random.randint(0, len(self.noise)-1)
        noise_row = self.noise.loc[idy, :]
        mel = np.load(noise_row.impath)
        mel = mel[np.random.choice(len(mel))]
        mel = np.concatenate((np.zeros((self.n_mels,self.len_check)),mel), axis=1)
        mel = np.concatenate((mel,np.zeros((self.n_mels,self.len_check))), axis=1)
        start = random.randint(0, mel.shape[1] - self.len_check - 1)
        mel = mel[:, start : start + self.len_check]

        mel = random_power(mel)
        images = images + mel/(mel.max()+0.0000001)*(random.random()*1+0.5)*images.max()

        # Add white noise
        if random.random()<0.7 and 'white_noise' in self.augs:
            images = images + (np.random.sample((self.n_mels,self.len_check)).astype(np.float32)+9) * images.mean() * self.level_noise * (np.random.sample() + 0.3)

        # Add pink noise
        if random.random()<0.7 and 'pink_noise' in self.augs:
            r = random.randint(1,self.n_mels)
            # NOTE(review): this frequency profile is computed but never
            # applied, so this branch adds white noise at twice the level.
            # Left unchanged to keep the augmentation as it is -- confirm
            # whether the noise was meant to be multiplied by `pink_noise`.
            pink_noise = np.array([np.concatenate((1 - np.arange(r)/r,np.zeros(self.n_mels-r)))]).T
            images = images + (np.random.sample((self.n_mels,self.len_check)).astype(np.float32)+9) * 2  * images.mean() * self.level_noise * (np.random.sample() + 0.3)

        # Add bandpass noise to a random band of mel bins
        if random.random()<0.7 and 'bandpass_noise' in self.augs:
            a = random.randint(0, self.n_mels//2)
            b = random.randint(a+20, self.n_mels)
            images[a:b,:] = images[a:b,:] + (np.random.sample((b-a,self.len_check)).astype(np.float32)+9) * 0.1 * images.mean() * self.level_noise  * (np.random.sample() + 0.3)

        # Scale the rows with a profile that decreases with the row index
        if random.random()<0.5 and 'upper' in self.augs:
            images = images - images.min()
            r = random.randint(self.n_mels//2,self.n_mels)
            x = random.random()/2
            pink_noise = np.array([np.concatenate((1-np.arange(r)*x/r,np.zeros(self.n_mels-r)-x+1))]).T
            images = images*pink_noise
            images = images/images.max()

        # Change the contrast
        images = random_power(images, power = 2, c= 0.7)

        # Expand to 3 channels and resize to the fixed width
        images = mono_to_color(images,281)

        # Named after the first (requested) recording of the mix.
        self._maybe_save_debug_image(images, birds, '_train_', primary_file)

        return images, self._build_target(birds, background)

    def _valid_item(self, idx):
        """Builds one validation item from a single stored segment."""
        birds, background = [],[]

        sample = self.bird_list.iloc[idx]
        mel = self.audio_image_store[sample.file]
        mel = mel[np.random.choice(len(mel))]

        self._collect_labels(sample, birds, background)

        images = random_power(mel, power = 2, c= 0.7)
        images = mono_to_color(images,281)

        self._maybe_save_debug_image(images, birds, '_valid_', sample.file)

        return images, self._build_target(birds, background)

## Training Utils

In [None]:
@torch.no_grad()
def evaluate(net, criterion, valid_loader):
    """
    Runs `net` over `valid_loader` and computes loss and metrics on all
    batches at once.

    Targets and predictions are binarised at 0.5 (after a sigmoid for the
    predictions) for the metrics; the loss uses the raw logits and targets.
    The network is left in eval mode.

    Returns:
        tuple -- (loss, lrap, f1, recall, precision)
    """
    net.eval()

    logit_batches = []
    target_batches = []
    progress = tqdm(valid_loader, leave = False, total=len(valid_loader))

    for inputs, labels in progress:
        target_batches.append(labels.to(DEVICE))
        logit_batches.append(net(inputs.to(DEVICE)))

    targets = torch.cat(target_batches)
    logits = torch.cat(logit_batches)

    loss_value = criterion(logits, targets).item()

    probs = logits.sigmoid()
    hard_targets = (targets > 0.5)*1.0

    lrap = label_ranking_average_precision_score(hard_targets.cpu().numpy(), probs.cpu().numpy())

    hard_preds = (probs > 0.5)*1.0
    true_positives = (hard_preds*hard_targets).sum()

    # 1e-6 guards against empty prediction / target sets.
    prec = (true_positives/(1e-6 + hard_preds.sum())).item()
    rec = (true_positives/(1e-6 + hard_targets.sum())).item()
    f1 = 2*prec*rec/(1e-6+prec+rec)

    return loss_value, lrap, f1, rec, prec
    

In [None]:
def one_step(x_data, y_data, net, criterion, optimizer, scaler):
    """
    Performs one mixed-precision optimisation step on a single batch.

    Arguments:
        x_data, y_data -- Input batch and its targets (moved to DEVICE here)
        net -- Model being trained
        criterion -- Loss applied to the raw model outputs
        optimizer -- Optimizer updating `net`
        scaler -- `torch.cuda.amp.GradScaler` used for loss scaling

    Returns:
        tuple -- (loss, lrap, f1, recall, precision) for this batch, with
        predictions (after sigmoid) and targets binarised at 0.5
    """
    x_data, y_data = x_data.to(DEVICE), y_data.to(DEVICE)

    optimizer.zero_grad()
    # The forward pass must run inside autocast as well; wrapping only the
    # loss leaves the whole network in full precision.
    with torch.cuda.amp.autocast():
        output = net(x_data)
        loss = criterion(output, y_data)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    with torch.no_grad():
        l = loss.item()

        # Metrics are computed in float32 regardless of the autocast dtype.
        probs = output.float().sigmoid()
        targets = (y_data > 0.5)*1.0
        lrap = label_ranking_average_precision_score(targets.cpu().numpy(), probs.cpu().numpy())

        preds = (probs > 0.5)*1.0

        # 1e-6 guards against empty prediction / target sets.
        prec = (preds*targets).sum()/(1e-6 + preds.sum())
        rec = (preds*targets).sum()/(1e-6 + targets.sum())
        f1 = 2*prec*rec/(1e-6+prec+rec)

    return l, lrap, f1.item(), rec.item(), prec.item()

In [None]:
def one_epoch(epoch,net, criterion, optimizer, scheduler, scaler, train_loader, valid_loader):
    """
    Trains `net` for one pass over `train_loader`, steps the scheduler and
    evaluates on `valid_loader`.

    Arguments:
        epoch -- Epoch number (not used by this function)
        net, criterion, optimizer, scaler -- Forwarded to `one_step`
        scheduler -- Stepped once per epoch; a `ReduceLROnPlateau` scheduler
            is stepped with the validation loss
        train_loader, valid_loader -- Batch iterables for training / validation

    Returns:
        tuple -- Five `(train, valid)` pairs: loss, lrap, f1, recall, precision.
        The training values are averages over the batches.
    """
    net.train()
    l, lrap, prec, rec, f1, icount = 0.,0.,0.,0., 0., 0
    epoch_bar = tqdm(train_loader, leave = False)

    for (xb, yb) in epoch_bar:
        _l, _lrap, _f1, _rec, _prec = one_step(xb, yb, net, criterion, optimizer, scaler)
        l += _l
        lrap += _lrap
        f1 += _f1
        rec += _rec
        prec += _prec

        icount += 1
        # Refresh the running averages every 10 batches.
        if hasattr(epoch_bar, "set_postfix") and not icount%10:
            epoch_bar.set_postfix(
                loss="{:.3f}".format(l/icount),
                lrap="{:.3f}".format(lrap/icount),
                prec="{:.3f}".format(prec/icount),
                rec="{:.3f}".format(rec/icount),
                f1="{:.3f}".format(f1/icount),
                lr="{:.6f}".format(optimizer.param_groups[0]['lr'])
              )

    # ReduceLROnPlateau needs the monitored value, which only exists after
    # validation; every other scheduler is stepped here, as before.
    needs_metric = isinstance(scheduler, ReduceLROnPlateau)
    if not needs_metric:
        scheduler.step()

    # An empty loader leaves the sums at 0; avoid dividing by zero.
    n_batches = max(icount, 1)
    l /= n_batches
    lrap /= n_batches
    f1 /= n_batches
    rec /= n_batches
    prec /= n_batches

    l_val, lrap_val, f1_val, rec_val, prec_val = evaluate(net, criterion, valid_loader)

    if needs_metric:
        scheduler.step(l_val)

    return (l, l_val), (lrap, lrap_val), (f1, f1_val), (rec, rec_val), (prec, prec_val)

## Training loop

In [None]:
# ====================================================
# Train loop
# ====================================================
def train_loop(folds,fold):
    """
    Trains and validates one cross-validation fold.

    Rows of `folds` whose `fold` equals `fold` form the validation set, all
    other rows the training set. The run lasts `CFG.epochs` epochs; when
    `CFG.checkpoint` is set, model, optimizer and scheduler are first
    restored from a saved checkpoint and epoch numbering continues from it.
    After every epoch the metrics are printed and logged and, if `CFG.save`
    is set, a checkpoint is written to OUTPUT_DIR.

    Arguments:
        folds {DataFrame} -- Metadata with a `fold` column
        fold {int} -- Fold held out for validation
    """
    save_root = MODEL_ROOT/f"{CFG.model_name}_{CFG.exp_name}/fold{CFG.trn_fold}"
    save_root.mkdir(exist_ok=True, parents=True)

    trn_idx = folds[folds['fold'] != fold].index
    val_idx = folds[folds['fold'] == fold].index

    train_folds = folds.loc[trn_idx].reset_index(drop=True)
    valid_folds = folds.loc[val_idx].reset_index(drop=True)
    del train_folds['fold']
    del valid_folds['fold']

    log = Logger()
    log.open(OUTPUT_DIR + '/log.train.txt', mode='a')
    log.write('\n--- [START %s] %s\n\n' % (IDENTIFIER, '-' * 64))
    log.write('\t%s\n' % COMMON_STRING)
    log.write('\t__file__ = %s\n' % CFG.exp_name)
    log.write('\tout_dir  = %s\n' % OUTPUT_DIR)
    log.write('\n')
    # ====================================================
    # dataset and loader
    # ====================================================
    train_dataset = BirdClefDataset(audio_image_store,train_folds,LABEL_IDS,INV_LABEL_IDS,is_train=True)
    valid_dataset = BirdClefDataset(audio_image_store,valid_folds,LABEL_IDS,INV_LABEL_IDS,is_train=False)

    train_loader = DataLoader(train_dataset,
                              batch_size=CFG.batch_size,
                              num_workers=CFG.num_workers,
                              shuffle=True,
                              pin_memory=True)

    valid_loader = DataLoader(valid_dataset,
                              batch_size=CFG.batch_size,
                              num_workers=CFG.num_workers,
                              shuffle=False)

    # ====================================================
    # scheduler
    # ====================================================
    def get_scheduler(optimizer):
        if CFG.scheduler=='ReduceLROnPlateau':
            scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)
        elif CFG.scheduler=='CosineAnnealingLR':
            scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max,eta_min=CFG.min_lr,verbose=True)
        elif CFG.scheduler=='CosineAnnealingWarmRestarts':
            scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr,verbose=True)
        else:
            raise ValueError(f"Unknown scheduler: {CFG.scheduler!r}")
        return scheduler

    # ====================================================
    # model & optimizer
    # ====================================================
    log.write('** net setting **\n')

    log.write('** start training here! **\n')
    log.write('   is_mixed_precision = %s \n' % str(True))
    log.write('   batch_size = %d\n' % (CFG.batch_size))
    log.write('   experiment = %s\n' % str(CFG.exp_name))

    ckpt = None
    if CFG.checkpoint:
        print('Loading checkpoint')
        # Point this at the last saved .pth file of the run to continue.
        # map_location lets a GPU checkpoint be opened on a CPU-only machine.
        ckpt = torch.load(
            "Birdcall/models/resnest50_augs_mix3_3_col_cor/fold[0]/resnest50_fold0_e59_augs_mix3_3_col_cor.pth",
            map_location=DEVICE,
        )

    net = get_model(CFG.model_name)
    if ckpt is not None:
        net.load_state_dict(ckpt['net'])
    net.to(DEVICE)

    optimizer = Adam(net.parameters(), lr=CFG.base_lr, amsgrad=False)
    if ckpt is not None:
        # Restored before the scheduler is built, as in the saved run.
        optimizer.load_state_dict(ckpt['optimizer'])

    scheduler = get_scheduler(optimizer)
    if ckpt is not None:
        scheduler.load_state_dict(ckpt['scheduler'])

    criterion = nn.BCEWithLogitsLoss()

    # FP-16
    scaler = torch.cuda.amp.GradScaler()

    if ckpt is not None:
        e = ckpt['epoch']
        print(f"Epoch:{e} - Loss:{ckpt['loss']} - F1:{ckpt['f1']} - F1_val:{ckpt['f1_val']} - LR{ckpt['scheduler']}")
        first_epoch = e + 1
        # Continued runs are marked in the checkpoint file name.
        name_infix = 'cont_'
    else:
        first_epoch = 0
        name_infix = ''

    epochs_bar = tqdm(list(range(CFG.epochs)), leave=False)
    for step in epochs_bar:
        epoch = first_epoch + step
        epochs_bar.set_description(f"[EPOCH {epoch:02d}]")
        net.train()

        (l, l_val), (lrap, lrap_val), (f1, f1_val), (rec, rec_val), (prec, prec_val) = one_epoch(
            epoch,
            net=net,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            scaler=scaler,
            train_loader=train_loader,
            valid_loader=valid_loader,
          )

        # The learning rate is now reported for fresh runs too (previously
        # only when resuming).
        lr_text = "({:.5f})".format(optimizer.param_groups[0]['lr'])

        epochs_bar.set_postfix(
            loss="({:.6f}, {:.6f})".format(l, l_val),
            prec="({:.3f}, {:.3f})".format(prec, prec_val),
            rec="({:.3f}, {:.3f})".format(rec, rec_val),
            f1="({:.3f}, {:.3f})".format(f1, f1_val),
            lrap="({:.3f}, {:.3f})".format(lrap, lrap_val),
            lr=lr_text,
        )

        print(
            "[{epoch:02d}] L: {loss} Lr: {lrap} F1: {f1} R: {rec} P: {prec} LR:{lr}".format(
                epoch=epoch,
                loss="({:.6f}, V:{:.6f})".format(l, l_val),
                prec="({:.3f}, V:{:.3f})".format(prec, prec_val),
                rec="({:.3f}, V:{:.3f})".format(rec, rec_val),
                f1="({:.3f}, V:{:.3f})".format(f1, f1_val),
                lrap="({:.3f}, V:{:.3f})".format(lrap, lrap_val),
                lr=lr_text,
            )
        )

        # One log line per epoch; in this line "Lr" is the LRAP metric.
        log_line = \
            'E:%.1f | ' % (epoch) + \
            'L:%.5f  %.4f | ' % (l, l_val) + \
            'P:%.3f  %.3f | ' % (prec,prec_val) + \
            'R:%.3f  %.3f | ' % (rec,rec_val) + \
            'F:%.3f  %.3f | ' % (f1,f1_val) + \
            'Lr:%.3f  %.3f | ' % (lrap,lrap_val)
        log.write(log_line + '\n')

        if CFG.save:
            torch.save({'net': net.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'epoch': epoch,
                        'loss':l,
                        'f1':f1,
                        'f1_val':f1_val
                       },
                       OUTPUT_DIR+f'/{CFG.model_name}_fold{fold}_e{epoch}_{name_infix}{CFG.exp_name}.pth')

In [None]:
def main():
    """Frees cached memory, then trains each fold listed in `CFG.trn_fold`.

    Does nothing beyond the clean-up when `CFG.train` is false. Folds are
    visited in increasing order and limited to `range(CFG.n_fold)`.
    """
    gc.collect()
    torch.cuda.empty_cache()

    if not CFG.train:
        return

    selected_folds = [fold for fold in range(CFG.n_fold) if fold in CFG.trn_fold]
    for fold in selected_folds:
        train_loop(folds,fold)

In [None]:
# Start training only when this file is executed directly, not when imported.
if __name__ == '__main__':
    main()