In [1]:
# !pip3 install -q -U openmim
# !mim install -q mmcv-full

In [2]:
# import sys 
# sys.path = [ '../input/dk-1st-data-2/kaggle_data_models', 
#             '../input/dk-1st-data-3/configs'] + sys.path

In [3]:
import os, gc, cv2, sys, random, argparse, importlib

import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm
from types import SimpleNamespace

import torch
import torch.nn as nn
from shutil import copyfile
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler

# https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
#from torchcontrib.optim import SWA

from transformers import get_cosine_schedule_with_warmup
from transformers import get_linear_schedule_with_warmup

from sklearn.metrics import roc_auc_score, accuracy_score, matthews_corrcoef, f1_score
from resnet3d_csn import ResNet3dCSN

#import warnings
from warnings import filterwarnings
filterwarnings("ignore")

from albumentations import ReplayCompose

  'On January 1, 2023, MMCV will release v2.0.0, in which it will remove '
  warn(f"Failed to load image Python extension: {e}")


In [4]:
# !pip3 install -q torchview
# !pip3 install -q -U graphviz
# import graphviz
# from torchview import draw_graph
# graphviz.set_jupyter_format('png')

In [5]:
parser = argparse.ArgumentParser(description="PyTorch")

# Dotted module path of the experiment config; the module must expose `cfg`.
# configs.config_1 -> player-to-ground, configs.config_2 -> player-to-player.
# On Kaggle (configs directory itself on sys.path) pass 'config_2' instead.
parser.add_argument("-C", "--config", default='configs.config_2', type=str, help="config filename")

parser.add_argument("-M", "--mode", default='train', type=str, help="mode type")
# type=int so that `-T 0` yields 0 rather than the (truthy) string '0'.
parser.add_argument("-T", "--tta", default=0, type=int, help="is use tta for inference")

# parse_known_args tolerates extra argv entries (e.g. those a notebook kernel receives).
parser_args, unknown = parser.parse_known_args()

In [6]:
# Load the experiment configuration module named by --config; it must define `cfg`.
config_module = importlib.import_module(parser_args.config)
cfg = config_module.cfg

In [7]:
# Local data locations; these override whatever the config module set.
# For a Kaggle notebook use instead:
#   out_dir='./', train_csv_path='../input/dk-1st-data-3/slicing_not_g.csv',
#   data='../input/dk-1st-data/kaggle_data/trk_pos.npy', path='../input/dk-1st-data-3/'
# train_csv_path options: slicing_g.csv (11_G), slicing_g_next.csv (15_G),
# slicing_not_g.csv (15_all).
for _attr, _value in (
    ('out_dir', './models'),
    ('train_csv_path', 'slicing_not_g.csv'),
    ('data', 'kaggle_data/trk_pos.npy'),
    ('path', './'),
):
    setattr(cfg, _attr, _value)


In [8]:
class SimpleClassSampler(Sampler):
    """Sampler that re-draws a class-rebalanced subset of ``df`` every epoch.

    Each ``__iter__`` draws ``int(cfg.pos_frac * n_pos)`` positive rows
    (``contact == 1``) and ``int(cfg.frac * n_neg)`` negative rows
    (``contact == 0``) and yields their positional indices in random order.
    Indices refer to ``df`` after ``reset_index``, so ``df`` must be row-aligned
    with the dataset this sampler drives.

    A fraction above 1.0 oversamples that class: indices are then drawn with
    replacement (previously ``np.random.choice`` raised ``ValueError``).
    """

    def __init__(self, df, cfg):
        self.cfg = cfg
        self.df = df.reset_index(drop=True)
        self.index_class1 = self.df[self.df.contact == 1].index.to_list()
        self.index_class0 = self.df[self.df.contact == 0].index.to_list()

        self.length = int(self.cfg.pos_frac * len(self.index_class1)) + int(self.cfg.frac * len(self.index_class0))

    @staticmethod
    def _draw(pool, count):
        # Without replacement whenever possible; with replacement only when oversampling.
        return np.random.choice(pool, count, replace=count > len(pool))

    def __iter__(self):
        # The fractions are re-read from cfg on every epoch.
        random_choice1 = self._draw(self.index_class1, int(self.cfg.pos_frac * len(self.index_class1)))
        random_choice0 = self._draw(self.index_class0, int(self.cfg.frac * len(self.index_class0)))

        print('======', len(random_choice0), len(random_choice1))

        all_indexs = list(random_choice0) + list(random_choice1)
        random.shuffle(all_indexs)
        return iter(all_indexs)

    def __len__(self):
        return int(self.length)

class NDataset(Dataset):
    """Clip dataset for contact classification built from pre-sliced ``.npy`` frames.

    For each row of ``df``, two frame stacks are loaded from
    ``cfg.path + row['path'] + '_e.npy'`` and ``... + '_s.npy'``. These are
    presumably the end-zone and sideline views (TODO confirm); each frame is
    ``(H, W, C)``. Each stack is augmented with the same random parameters
    for every frame (``cfg.{train,val}_{e,s}_transform``), and the two views
    are concatenated side by side. When ``cfg.is_G`` is false
    (player-to-player rows), a rendered tracking image (see ``render_trk``)
    is inserted between the two views.

    Each item is ``(clip, target, feat)`` as float32 tensors. ``clip`` is
    ``(C, T, H, W)`` for 3D models, or ``(T*C, H, W)`` when
    ``cfg.model == 'model_25d'``, with pixel values divided by 255.
    """

    def __init__(self, cfg, df, tfms=None, fold_id = 0, is_train = True):
        super().__init__()
        self.df = df.reset_index(drop=True)
        self.cfg = cfg
        self.fold_id = fold_id
        self.transform = tfms  # kept for API compatibility; __getitem__ uses the per-view cfg transforms
        self.is_train = is_train
        self.feat_cols = ['x_position_1', 'y_position_1', 'distance']
        # Half the clip length (frames); set lazily from the first clip loaded.
        self.trk_step = 0
        if self.is_train:
            self.e_transform = self.cfg.train_e_transform
            self.s_transform = self.cfg.train_s_transform
        else:
            self.e_transform = self.cfg.val_e_transform
            self.s_transform = self.cfg.val_s_transform

        # Dict keyed by f'{vid}_{step}' -> {player_id: {'x', 'y', 't'}}.
        # allow_pickle is needed for the dict payload: only load trusted files.
        self.trk_pos = np.load(cfg.data, allow_pickle=True).item()

        print(f'Fold: {fold_id}, is_train: {is_train}, total frame {len(self.df)}')

    @staticmethod
    def _count_empty_frames(frames):
        """Count frames in which more than 90% of the values are below 2 (near-black)."""
        num_empty = 0
        for img in frames:
            h, w, c = img.shape
            if np.sum(img < 2) / (h * w * c) > 0.9:
                num_empty += 1
        return num_empty

    @staticmethod
    def _replay_transform(transform, frames):
        """Apply ``transform`` to every frame with identical random parameters.

        The first frame goes through ``transform`` normally; its recorded
        ``replay`` is then re-applied to the remaining frames via
        ``ReplayCompose.replay``.
        """
        replay = None
        out = []
        for img in frames:
            if replay is None:
                sample = transform(image=img)
                replay = sample["replay"]
            else:
                sample = ReplayCompose.replay(replay, image=img)
            out.append(sample["image"])
        return out

    def __getitem__(self, index):
        row = self.df.loc[index]
        # Example row: path='slicing_g/58188_001358_46171_G_0591_50', fold=2,
        # contact=0, distance=0.274579, step=50, e_empty=0, s_empty=0.
        path = row['path']
        step = row['step']

        # Clip name: the first two '_' tokens form the video id, followed by the
        # player ids ('_ext' clips keep the ids as strings, others use ints).
        parts = path.split('/')[-1].split('_')
        vid = '_'.join(parts[:2])
        is_ext = '_ext' in path
        idx1 = parts[2] if is_ext else int(parts[2])
        idx2 = None
        if not self.cfg.is_G:
            idx2 = parts[3] if is_ext else int(parts[3])

        # Use the dataset's own config (previously this read the global `cfg`).
        path = self.cfg.path + path
        e_images = np.load(f'{path}_e.npy')
        s_images = np.load(f'{path}_s.npy')

        skip = self.cfg.skip_frame
        if skip > 0:
            e_images = e_images[skip:-skip, :, :, :]
            s_images = s_images[skip:-skip, :, :, :]

        # A view with fewer than 2 non-empty frames is replaced by the other view.
        if len(e_images) - self._count_empty_frames(e_images) < 2:
            e_images = s_images.copy()
        if len(s_images) - self._count_empty_frames(s_images) < 2:
            s_images = e_images.copy()

        if self.trk_step == 0:
            self.trk_step = len(e_images) // 2

        if not self.cfg.is_G:
            trk_images = self.render_trk(vid, step, idx1, idx2, self.trk_pos)

        # Same augmentation for every frame of a view.
        e_images_ = self._replay_transform(self.e_transform, e_images)
        e_images = None  # free memory
        s_images_ = self._replay_transform(self.s_transform, s_images)
        s_images = None  # free memory

        # Simple tracking-image / layout augmentation (training only).
        flip_trk_lr = False
        flip_trk_ud = False
        is_swap = False
        if self.is_train:
            flip_trk_lr = random.random() < 0.5
            flip_trk_ud = random.random() < 0.5
            is_swap = random.random() < 0.5

        images = []
        for i in range(len(s_images_)):
            if self.cfg.is_G:
                views = [e_images_[i], s_images_[i]]
            else:
                trk_img = trk_images[i]
                if flip_trk_lr:
                    trk_img = np.fliplr(trk_img)
                if flip_trk_ud:
                    trk_img = np.flipud(trk_img)
                views = [e_images_[i], trk_img, s_images_[i]]
            if is_swap:
                views = views[::-1]
            images.append(np.hstack(views))

        e_images_ = s_images_ = None  # free memory

        img = np.array(images)  # (T, H, W_total, C)
        images = None  # free memory

        if self.cfg.model in ['model_25d']:
            # Stack frames along the channel axis: (T*C, H, W).
            img = np.concatenate(img, axis=2).transpose(2, 0, 1)
        else:
            img = img.transpose(3, 0, 1, 2)  # (C, T, H, W)

        img = torch.from_numpy(img / 255.)

        target = row['contact']

        if self.cfg.use_meta:
            # Row slices are object-dtype; convert explicitly before building a tensor.
            feat = np.asarray(row[self.feat_cols], dtype=np.float32)
        else:
            feat = target

        if self.cfg.use_oof and self.is_train:
            feat = row['pred']

        return img.float(), torch.tensor(target, dtype=torch.float), torch.tensor(feat, dtype=torch.float)

    def render_trk(self, vid, step, idx1, idx2, trk_pos):
        '''
        Render one top-down tracking image per step in
        ``range(step - self.trk_step, step + self.trk_step + 1)``.

        The view is centred on the midpoint of players ``idx1`` and ``idx2`` at
        ``step`` (12 px per tracking-coordinate unit). Each image is an
        ``(cfg.img_size, 128, 3)`` uint8 array with background value 1.

        With ``cfg.trk_type == 1``, channel 0 marks every player, channel 1
        'home' players and channel 2 all other players; the two focus players
        are drawn larger and brighter. With ``cfg.trk_type == 2``, only
        channels 1 and 2 are drawn and focus players are just larger.

        Steps missing from ``trk_pos`` give blank images. During training the
        whole layout is shifted by a random offset (|dx| <= 10, |dy| <= 20 px).
        '''
        if self.is_train:
            shift_x = random.randint(-10, 10)
            shift_y = random.randint(-20, 20)
        else:
            shift_x = 0
            shift_y = 0

        d_x = 5
        scale = 60 / d_x  # 12 px per coordinate unit

        idx = f'{vid}_{step}'
        x1 = trk_pos[idx][idx1]['x']  # idx1 => player_id_1
        y1 = trk_pos[idx][idx1]['y']
        x2 = trk_pos[idx][idx2]['x']  # idx2 => player_id_2
        y2 = trk_pos[idx][idx2]['y']

        # Midpoint of the focus pair; it lands at pixel (d_x*scale, 2*d_x*scale) before shifting.
        xc = 0.5*x1 + 0.5*x2
        yc = 0.5*y1 + 0.5*y2

        images = []
        for st in range(step - self.trk_step, step + self.trk_step + 1):
            this_idx = f'{vid}_{st}'
            img = np.ones((3, self.cfg.img_size, 128), dtype=np.uint8)
            if this_idx in trk_pos:
                for p_id, meta in trk_pos[this_idx].items():
                    x_raw, y_raw, t = meta['x'], meta['y'], meta['t']
                    # int() guarantees integer pixel coordinates for cv2 even if
                    # round() returns a NumPy float.
                    x = int(round((x_raw - xc + d_x)*scale)) + shift_x
                    y = int(round((y_raw - yc + (2*d_x))*scale)) + shift_y

                    if not (0 < x < 128 and 0 < y < self.cfg.img_size):
                        continue

                    is_focus = p_id in [idx1, idx2]
                    team_plane = img[1] if t == 'home' else img[2]
                    if self.cfg.trk_type == 1:
                        radius, val = (5, 255) if is_focus else (3, 125)
                        cv2.circle(img[0], (x, y), radius, val, thickness=-1)
                        cv2.circle(team_plane, (x, y), radius, val, thickness=-1)
                    elif self.cfg.trk_type == 2:
                        radius = 6 if is_focus else 4
                        cv2.circle(team_plane, (x, y), radius, 255, thickness=-1)

            images.append(img.transpose(1, 2, 0))

        return images

    def __len__(self):
        return len(self.df)

In [9]:
# for debugging
# c = NDataset(cfg, a, tfms=b, fold_id = 0, is_train=True)
# k = c[0]

In [10]:
# import matplotlib.pyplot as plt

In [11]:
# # plotter Cid = 7
# plt.rcParams["figure.figsize"] = (20,8)
# fx, arr = plt.subplots(1,1)
# arr.imshow(np.fliplr(k[-1]))

In [12]:
class NModel(nn.Module):
    """ResNet3D-CSN clip classifier producing one logit per clip.

    ``cfg.model_name`` encodes the backbone: characters ``[1:-2]`` give the
    depth and the last two characters the bottleneck mode. For example, a name
    shaped like ``'r152ir'`` gives depth 152 and mode ``'ir'`` (format
    assumed from the slicing; confirm against the configs).

    The outputs of the backbone's last two stages are globally pooled (avg or
    max, per ``cfg.pool_type``), concatenated and classified by one linear
    layer. Those stages are therefore expected to have 2048 and 1024 channels.

    ``forward`` returns ``{'out1': logits, 'emb': logits}`` with logits of
    shape ``(batch, 1)``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.backbone = ResNet3dCSN(
            pretrained2d=False,
            in_channels=3,
            pretrained=None,
            depth=int(cfg.model_name[1:-2]),
            with_pool2=False,
            bottleneck_mode=cfg.model_name[-2:],
            norm_eval=False,
            zero_init_residual=False)

        # 2048 (last stage) + 1024 (second-to-last stage) pooled features.
        self.final = nn.Linear(2048 + 1024, out_features=1)
        if cfg.pool_type == 'avg':
            self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
        else:
            # The attribute keeps the name `avg_pool` even for max pooling.
            self.avg_pool = nn.AdaptiveMaxPool3d((1, 1, 1))
        self.dropout = nn.Dropout(0.0)

    def forward(self, x):
        """Classify clips ``x`` of shape (B, C, T, H, W); a single channel is tiled to 3."""
        if x.size(1) == 1:
            x = x.repeat(1, 3, 1, 1, 1)

        feats = self.backbone(x)  # sequence of stage outputs; only the last two are used
        pooled_last = self.avg_pool(feats[-1])  # (B, 2048, 1, 1, 1)
        pooled_prev = self.avg_pool(feats[-2])  # (B, 1024, 1, 1, 1)

        x = torch.cat((pooled_last, pooled_prev), dim=1)  # (B, 3072, 1, 1, 1)
        x = self.dropout(x)
        x = x.flatten(start_dim=1)  # (B, 3072)
        x = self.final(x)  # (B, 1)

        return {'out1': x, 'emb': x}

In [13]:
#m = NModel(cfg)

In [14]:
#draw_graph(m, input_data = torch.randn(1, 3, 23, 256, 512), expand_nested=True, save_graph=True, device='cpu').visual_graph

In [15]:
# Ask PyTorch not to cache CUDA allocations; set at import time, before any CUDA use.
os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"


def seed_everything(random_seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) RNGs and make cuDNN deterministic.

    Note: PYTHONHASHSEED only affects Python processes started after this
    call; the current interpreter's hash seed is fixed at startup.
    """
    os.environ["PYTHONHASHSEED"] = str(random_seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(random_seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [16]:
# Make sure the "output/" directory exists; a no-op if it is already there.
os.makedirs(name="output/", exist_ok=True)

In [17]:
def intersect_dicts(da, db, exclude=()):
    """Select the entries of state dict ``da`` that can be loaded into ``db``.

    A key is kept when it is present in ``db``, contains none of the
    substrings in ``exclude``, and its value has the same ``.shape`` as
    ``db[key]``. Values are taken from ``da``; insertion order follows ``da``.
    """
    kept = {}
    for key, value in da.items():
        if key not in db:
            continue
        if any(pattern in key for pattern in exclude):
            continue
        if value.shape != db[key].shape:
            continue
        kept[key] = value
    return kept


In [18]:
def logfile(message, lg_path):
    """Print ``message`` and append it, newline-terminated, to the file ``lg_path``.

    Missing parent directories of ``lg_path`` are created first. This way the
    first log call for a fresh ``cfg.out_dir`` (e.g. './models') does not fail
    with FileNotFoundError.
    """
    print(message)
    log_dir = os.path.dirname(lg_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    with open(lg_path, 'a') as logger:
        logger.write(f'{message}\n')

In [19]:
def get_optimizer(cfg, model):
    """Build the optimizer named by ``cfg.optimizer`` over all model parameters.

    All parameters go into a single group with ``lr=cfg.lr`` and
    ``weight_decay=cfg.weight_decay``. Frozen (``requires_grad=False``)
    parameters are included too; they receive no gradient.

    Supported names: 'Adam', 'SGD' (momentum 0.9, Nesterov) and 'AdamW'.

    Raises:
        ValueError: if ``cfg.optimizer`` is not one of the supported names
            (previously this surfaced as an UnboundLocalError).
    """
    params = [{
        "params": [param for _, param in model.named_parameters()],
        "lr": cfg.lr,
    }]
    lr = params[0]["lr"]

    if cfg.optimizer == "Adam":
        return torch.optim.Adam(params, lr=lr, weight_decay=cfg.weight_decay)
    if cfg.optimizer == "SGD":
        return torch.optim.SGD(params, lr=lr, momentum=0.9, nesterov=True, weight_decay=cfg.weight_decay)
    if cfg.optimizer == "AdamW":
        return torch.optim.AdamW(params, lr=lr, weight_decay=cfg.weight_decay)

    raise ValueError(f"Unsupported optimizer: {cfg.optimizer!r} (expected 'Adam', 'SGD' or 'AdamW')")

In [20]:
def get_scheduler(cfg, optimizer, total_steps=0):
    """Build the LR scheduler selected by ``cfg.scheduler``.

    Supported names:
        'steplr'       StepLR(step_size=40000, gamma=0.8)
        'cosine'       CosineAnnealingLR(T_max=cfg.epochs, eta_min=1e-6)
        'linear'       linear decay, no warmup, over ``total_steps // cfg.batch_size`` steps
        'step'         MultiStepLR(milestones=[20, 25, 30], gamma=0.5)
        'cosinewarmup' cosine with ``cfg.warmup`` warmup steps over
                       ``total_steps // cfg.batch_size`` steps
    Any other name yields no scheduler.

    Returns:
        ``(scheduler, iter_update)``. ``iter_update`` is True for the two
        schedules whose length is given in optimizer steps ('linear',
        'cosinewarmup'); presumably the caller then steps them per batch.
    """
    iter_update = False
    if cfg.scheduler == "steplr":
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40000, gamma=0.8)
    elif cfg.scheduler == "cosine":
        # `verbose` is deprecated for torch LR schedulers; the default is already quiet.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.epochs, eta_min=1e-6)
    elif cfg.scheduler == "linear":
        iter_update = True
        num_steps = total_steps // cfg.batch_size
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=0,
            num_training_steps=num_steps,
        )
        print("num_steps", num_steps)
    elif cfg.scheduler == "step":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 25, 30], gamma=0.5)
    elif cfg.scheduler == "cosinewarmup":
        iter_update = True
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=cfg.warmup,
            num_training_steps=(total_steps // cfg.batch_size),
        )
    else:
        scheduler = None

    return scheduler, iter_update

In [21]:
def _balance_train_negatives(train_df, neg_ratio):
    """Keep every positive row plus ``int(neg_ratio * n_pos)`` randomly chosen negatives.

    Both classes are shuffled with pandas' default (global NumPy) RNG,
    positives first. The result is ordered negatives-then-positives, so the
    caller reshuffles it.
    """
    positives = train_df[train_df.contact == 1].sample(frac=1.0)
    negatives = train_df[train_df.contact == 0].sample(frac=1.0)
    # e.g. neg_ratio=3 with 111 positives -> keep 333 negatives
    num_neg = int(neg_ratio * positives.shape[0])
    return pd.concat([negatives.head(num_neg), positives])


def _subsample_val(val_df, val_frac):
    """Deterministically keep 10%*val_frac of positives and 4%*val_frac of negatives, shuffled."""
    positives = val_df[val_df.contact == 1].sample(frac=0.1*val_frac, random_state=42)
    negatives = val_df[val_df.contact == 0].sample(frac=0.04*val_frac, random_state=42)
    return pd.concat([negatives, positives]).sample(frac=1.0, random_state=42)


def get_dataloader(fold_id = 0):
    """Build the train/validation loaders for one fold.

    Training rows are every fold except ``fold_id``, or the out-of-fold
    prediction files when ``cfg.use_oof``. Without ``cfg.sampler``, the
    negatives are cut to ``cfg.frac`` times the positives once; with it,
    ``SimpleClassSampler`` re-draws the class mix every epoch.

    Validation rows come from ``fold_id`` (fold 0 when ``fold_id >= 5``) and
    are deterministically subsampled unless ``cfg.mode == 'val'``.

    Returns:
        ``(train_dataloader, val_dataloader, total_steps, val_df)``, where
        ``total_steps = len(train_dataloader) * cfg.batch_size``.
    """
    train_transform, val_transform = cfg.train_transform, cfg.val_transform

    df = pd.read_csv(cfg.train_csv_path)
    print(df.shape)

    if cfg.use_oof:
        oof_frames = [pd.read_csv(f'{cfg.pl_path}/oof_f{ff}.csv') for ff in range(5) if ff != fold_id]
        train_df = pd.concat(oof_frames)
    else:
        train_df = df[df.fold != fold_id]

    if not cfg.sampler:
        train_df = _balance_train_negatives(train_df, cfg.frac)

    train_df = train_df.sample(frac = 1.0)
    print(train_df.contact.value_counts())

    train_dataset = NDataset(cfg, train_df, tfms=train_transform, fold_id = fold_id, is_train=True)

    val_df = df[df.fold == (fold_id if fold_id < 5 else 0)]

    if cfg.mode not in ['val']:
        val_df = _subsample_val(val_df, cfg.val_frac)
        print(val_df.contact.value_counts())

    val_dataset = NDataset(cfg, val_df, tfms=val_transform,  fold_id = fold_id, is_train = False)

    print('cfg.sampler => ', cfg.sampler)
    loader_kwargs = dict(batch_size=cfg.batch_size, num_workers=cfg.num_workers)
    if cfg.sampler:
        train_dataloader = DataLoader(
            train_dataset,
            sampler=SimpleClassSampler(train_df, cfg),
            drop_last=True,
            **loader_kwargs,
        )
    else:
        train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

    val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    total_steps = len(train_dataloader)*cfg.batch_size

    return train_dataloader, val_dataloader, total_steps, val_df


In [22]:
# a,b = get_dataloader(0)

In [23]:
def valid_func(model, val_loader, tta=int(parser_args.tta)):
    """Evaluate ``model`` on ``val_loader`` and compute validation metrics.

    Args:
        model: network whose output dict holds logits under ``'out1'``.
        val_loader: iterable of ``(images, labels, feat)`` batches.
        tta: if truthy (and ``cfg.use_meta`` is off), average the logits with
            those of the input flipped along its last axis. The default is
            read from the ``--tta`` flag once, when this cell runs.

    Returns:
        ``(val_loss, auc, acc, macro_score, y_preds)``:
        - ``val_loss``: mean per-batch loss.
        - ``auc``: ROC-AUC.
        - ``acc``: Matthews correlation coefficient (despite the name).
        - ``macro_score``: macro F1 at a 0.5 threshold.
        - ``y_preds``: flat array of positive-class probabilities.
        When ``cfg.mode == 'test'`` the three metrics are placeholders equal to 1.
    """
    if cfg.loss_fn == 'bce':
        loss_cls_fn = nn.BCEWithLogitsLoss()
    elif cfg.loss_fn == 'focal':
        # NOTE(review): BCEFocalLoss is not defined in the visible part of this notebook.
        loss_cls_fn = BCEFocalLoss()
    else:
        loss_cls_fn = nn.CrossEntropyLoss()
    is_binary = cfg.loss_fn in ['bce', 'focal']

    pred_chunks = []
    true_chunks = []
    device = cfg.device
    model.eval()
    with torch.no_grad():
        losses = []
        progress = tqdm(val_loader)
        for batch_idx, batch_data in enumerate(progress):
            if cfg.debug and batch_idx > 10:
                break
            images, lb, feat = batch_data
            images = images.float().to(device)

            if cfg.use_meta:
                # NOTE(review): NModel.forward accepts a single input tensor.
                logit = model(images, feat.to(device))['out1']
            else:
                logit = model(images)['out1']
                if tta:
                    flipped_logit = model(images.flip(-1))['out1']
                    logit = 0.5*logit + 0.5*flipped_logit

            target = lb.to(device)
            if is_binary:
                loss = loss_cls_fn(logit, target.unsqueeze(-1))
            else:
                loss = loss_cls_fn(logit, target.long())

            batch_loss = loss.item()
            losses.append(batch_loss)
            smooth_loss = np.mean(losses)
            progress.set_description(f'loss: {batch_loss:.5f}, smth: {smooth_loss:.5f}')

            true_chunks.append(lb.detach().cpu().numpy())
            probs = logit.sigmoid() if is_binary else logit.softmax(-1)[:, 1]
            pred_chunks.append(probs.detach().cpu().numpy())

    y_preds = np.concatenate(pred_chunks).astype(np.float64)
    y_trues = np.concatenate(true_chunks).astype(np.float64)
    print(y_preds.shape, y_trues.shape)

    y_preds = y_preds.reshape(-1)
    y_trues = y_trues.reshape(-1)

    if cfg.mode == 'test':
        acc = auc = macro_score = 1
    else:
        true_pos = y_trues > 0.5
        pred_pos = y_preds > 0.5
        acc = matthews_corrcoef(true_pos, pred_pos)
        auc = roc_auc_score(true_pos, y_preds)
        macro_score = f1_score(true_pos, pred_pos, average='macro')

    val_loss = np.mean(losses)
    return val_loss, auc, acc, macro_score, y_preds

In [24]:
def _make_cls_loss_fn():
    """Build the classification loss selected by ``cfg.loss_fn``.

    'bce' -> ``nn.BCEWithLogitsLoss``, 'focal' -> ``BCEFocalLoss``, anything else ->
    ``nn.CrossEntropyLoss``.
    """
    if cfg.loss_fn == 'bce':
        return nn.BCEWithLogitsLoss()
    if cfg.loss_fn == 'focal':
        return BCEFocalLoss()
    return nn.CrossEntropyLoss()


def _scheduler_lr(scheduler):
    """Return the first param group's current learning rate reported by ``scheduler``.

    Prefers ``get_last_lr()`` (PyTorch's supported accessor; calling ``get_lr()`` outside
    ``step()`` is deprecated and may report a wrong value for some schedulers) and falls
    back to ``get_lr()`` for scheduler objects that do not provide it.
    """
    if hasattr(scheduler, 'get_last_lr'):
        return scheduler.get_last_lr()[0]
    return scheduler.get_lr()[0]


def _build_fold_model(device):
    """Create ``NModel(cfg)``, freeze early backbone parameters when training, move it to ``device``."""
    model = NModel(cfg)

    if cfg.num_freeze > 0 and cfg.mode == 'train':
        for cc, (name, params) in enumerate(model.backbone.named_parameters()):
            # Freeze the first cfg.num_freeze backbone parameter tensors plus every tensor
            # whose name contains 'layer1.'. Variants tried earlier: also freezing
            # 'layer2.'/'layer3.', or freezing by cfg.num_freeze alone.
            if cc < cfg.num_freeze or 'layer1.' in name:
                print(f'layer {cc} {name} is frozen!')
                params.requires_grad = False

    model.to(device)
    return model


def _train_one_fold(model, fold_id, device):
    """Train ``model`` on fold ``fold_id`` for ``cfg.epochs`` epochs, validating after each epoch.

    Checkpoints are written into ``cfg.out_dir``:

    * ``{cfg.config}_best_f{fold_id}.pth`` whenever ``auc - val_loss`` matches or beats the
      best value seen so far;
    * ``{cfg.config}_last_f{fold_id}.pth`` whenever acc (MCC) improves together with AUC or
      macro-F1. Despite its name this is a "best" checkpoint, and it is the one that
      ``'val'`` mode loads.
    """
    log_path = f'{cfg.out_dir}/log_f{fold_id}.txt'
    logfile(f'====== FOLD {fold_id} =======', log_path)

    # Strings of length <= 10 (e.g. '') mean "no pretrained weights".
    if len(cfg.load_weight) > 10:
        if '.pth' in cfg.load_weight:
            load_weight = cfg.load_weight
        else:
            # Otherwise treated as a per-fold checkpoint prefix.
            load_weight = f'{cfg.load_weight}_last_f{fold_id % 5}.pth'

        logfile(f'load pretrained weight {load_weight}!!!', log_path)
        state_dict = torch.load(load_weight, map_location=device)  # load checkpoint
        if 'state_dict' in state_dict.keys():
            state_dict = state_dict['state_dict']
        # NOTE(review): intersect_dicts presumably keeps only keys/shapes present in the model — confirm.
        state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=[])
        model.load_state_dict(state_dict, strict=False)
        logfile('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), load_weight), log_path)  # report
        del state_dict
        gc.collect()

    train_loader, valid_loader, total_steps, _ = get_dataloader(fold_id)
    total_steps = total_steps * cfg.epochs  # batches per epoch -> batches over the whole run

    optimizer = get_optimizer(cfg, model)
    if cfg.use_swa:
        # NOTE(review): SWA comes from torchcontrib, whose import is commented out in the
        # imports cell; enabling cfg.use_swa raises NameError until that import is restored.
        optimizer = SWA(optimizer, swa_start=500, swa_freq=50, swa_lr=None)

    scheduler, iter_update = get_scheduler(cfg, optimizer, total_steps)

    # Loss fn and AMP scaler are loop-invariant: build them once so the scaler's dynamic
    # loss scale carries across epochs instead of re-calibrating from scratch every epoch.
    loss_cls_fn = _make_cls_loss_fn()
    scaler = torch.cuda.amp.GradScaler(enabled=cfg.apex)  # enabled=False disables gradient scaling

    # -inf so the first epoch always produces both checkpoints: MCC (acc) can be <= 0
    # and auc - val_loss can be negative, which a 0 initial value would never beat.
    best_auc = best_acc = best_macro = metric = -float('inf')

    l_acc = [5, 7, 9, 11]  # gradient-accumulation sizes, one per epoch
    for epoch in range(1, cfg.epochs + 1):
        if not iter_update:
            scheduler.step(epoch)

        model.train()
        losses = []
        bar = tqdm(train_loader)

        # Use the last entry, then rotate right: n_accumulate cycles 11, 9, 7, 5, 11, ...
        cfg.n_accumulate = l_acc[-1]
        l_acc = np.roll(l_acc, shift=1)

        # The batch count is rarely a multiple of n_accumulate; drop any partial
        # accumulation from the previous epoch so it doesn't leak into this epoch's first step.
        optimizer.zero_grad()

        for batch_idx, batch_data in enumerate(bar):
            if cfg.debug and batch_idx > 10:
                break
            images, lb, feat = batch_data  # lb: per-sample 'contact' target
            images = images.float().to(device)

            if cfg.use_meta:
                pred = model(images, feat.to(device))
            else:
                pred = model(images)
            logit = pred['out1']

            if cfg.loss_fn in ['bce', 'focal']:
                # BCE-style losses take (B, 1) float targets to match the logits.
                loss = loss_cls_fn(logit, lb.to(device).unsqueeze(-1))
            else:
                loss = loss_cls_fn(logit, lb.to(device).long())

            if cfg.use_oof:
                # NOTE(review): `feat` is used as a secondary (soft) target here — confirm.
                loss1 = loss_cls_fn(logit, feat.to(device).unsqueeze(-1))
                loss = 0.5 * loss + 0.5 * loss1

            scaler.scale(loss / cfg.n_accumulate).backward()

            # Step once every n_accumulate batches (counted from the start of this epoch).
            if (batch_idx + 1) % cfg.n_accumulate == 0:
                scaler.step(optimizer)
                scaler.update()

                if batch_idx % 100 == 0 and batch_idx > 300 and cfg.use_swa:
                    optimizer.update_swa()

                optimizer.zero_grad()

            if iter_update:
                scheduler.step()

            losses.append(loss.item())
            smooth_loss = np.mean(losses[-2000:])
            bar.set_description(f'loss: {loss.item():.5f}, smth: {smooth_loss:.5f}, LR {_scheduler_lr(scheduler):.6f}')

        train_loss = np.mean(losses)

        loss_valid, auc, acc, macro_score, y_preds = valid_func(model, valid_loader)
        logfile(f'[EPOCH] {epoch}, train_loss: {train_loss:.6f},  val loss: {loss_valid:.6f}, auc: {auc:.5f}, acc {acc:.5f}, macro_score {macro_score:.5f}', log_path)

        score = auc - loss_valid
        if metric <= score:
            logfile(f'[EPOCH] {epoch} ===============> best_metric ({metric:.6f} --> {score:.6f}). Saving model .......!!!!\n', log_path)
            torch.save(model.state_dict(), f'{cfg.out_dir}/{cfg.config}_best_f{fold_id}.pth')
            metric = score

        # acc (MCC) is the main target (competition evaluation metric), so it must improve in both branches.
        if (auc > best_auc and acc > best_acc) or (acc > best_acc and macro_score > best_macro):
            torch.save(model.state_dict(), f'{cfg.out_dir}/{cfg.config}_last_f{fold_id}.pth')
            best_auc = auc
            best_acc = acc
            best_macro = macro_score


def _validate_one_fold(model, fold_id):
    """Load fold ``fold_id``'s ``_last_`` checkpoint, predict its validation split and save the OOF csv.

    Returns the fold's validation dataframe with a ``pred`` column added.
    """
    # chpt_path = f'{cfg.out_dir}/best_metric_f{fold_id}.pth'
    chpt_path = f'{cfg.out_dir_val}/{cfg.config}_last_f{fold_id}.pth'

    print(f' load {chpt_path}!')
    checkpoint = torch.load(chpt_path, map_location="cpu")
    model.load_state_dict(checkpoint)
    del checkpoint

    _, valid_loader, _, val_df = get_dataloader(fold_id)
    loss_valid, auc, acc, macro_score, y_preds = valid_func(model, valid_loader)
    print(f'val loss: {loss_valid:.6f}, auc: {auc:.5f}, acc {acc:.5f}, macro_score {macro_score:.5f}')
    val_df['pred'] = y_preds

    suffix = '_tta' if int(parser_args.tta) == 1 else ''
    val_df.to_csv(f'{cfg.out_dir}/oof_{cfg.config}_f{fold_id}{suffix}.csv', index=False)
    return val_df


def main():
    """Train or validate the model for every fold id in ``cfg.folds``, depending on ``cfg.mode``.

    * ``'train'``: build the model, optionally load ``cfg.load_weight``, train for
      ``cfg.epochs`` epochs and save ``_best_`` / ``_last_`` checkpoints per fold into
      ``cfg.out_dir``.
    * ``'val'``: for folds 0-4, load ``{cfg.out_dir_val}/{cfg.config}_last_f{fold}.pth``,
      predict the validation split and write per-fold and concatenated OOF csv files into
      ``cfg.out_dir`` (``_tta`` suffix when ``--tta 1`` was passed).

    Any other mode only builds the model for each fold.
    """
    random_seed = 42
    _ = seed_everything(random_seed)

    f_preds = []
    for fold_id in cfg.folds:
        # Only folds 0-4 are validated; skip before paying for model construction.
        if cfg.mode == 'val' and fold_id > 4:
            continue

        device = cfg.device
        model = _build_fold_model(device)

        if cfg.mode == 'train':
            _train_one_fold(model, fold_id, device)
        elif cfg.mode == 'val':
            f_preds.append(_validate_one_fold(model, fold_id))

        # Drop the model and release cached GPU memory before the next fold.
        model = None
        _ = torch.cuda.empty_cache()
        _ = gc.collect()

    if cfg.mode == 'val':
        if not f_preds:
            print('no folds validated (cfg.folds contains no fold id in 0-4); OOF csv not written')
            return
        oof_df = pd.concat(f_preds)
        suffix = '_tta' if int(parser_args.tta) == 1 else ''
        oof_df.to_csv(f'{cfg.out_dir}/oof_{cfg.config}{suffix}.csv', index=False)

In [25]:
# asdff

In [26]:
## Kaggle-notebook path for the same checkpoint:
# cfg.load_weight = '../input/dk-1st-data-3/pretrained/vmz_ircsn_ig65m_pretrained_r50_32x2x1_58e_kinetics400_rgb_20210617-86d33018.pth'
## Alternative start point (an earlier fold-0 checkpoint):
# cfg.load_weight = 'kaggle/r50ir_csn_c15_m1_d2_all_last_f0.pth'

# --- Training run configuration ---
# Experiment name; alternatives: r50ir_csn_c11_m1_d2_G_all, r50ir_csn_c15_m1_d2_G_all
cfg.config = 'r50ir_csn_c15_m1_d2_all'
# Pretrained backbone weights (ir-CSN R50; IG65M / Kinetics-400 per the filename).
cfg.load_weight = 'pretrained/vmz_ircsn_ig65m_pretrained_r50_32x2x1_58e_kinetics400_rgb_20210617-86d33018.pth'

cfg.epochs = 15
cfg.batch_size = 1
cfg.lr = 1.876e-6  # earlier value tried: 1.876e-5
cfg.device = "cuda"  # or "cpu"


In [27]:
# Train only fold 4 in this session; main() loops over every id listed in cfg.folds.
cfg.folds = [4]
main()

load pretrained weight pretrained/vmz_ircsn_ig65m_pretrained_r50_32x2x1_58e_kinetics400_rgb_20210617-86d33018.pth!!!
Transferred 318/320 items from pretrained/vmz_ircsn_ig65m_pretrained_r50_32x2x1_58e_kinetics400_rgb_20210617-86d33018.pth
(94093, 7)
0    18120
1    10067
Name: contact, dtype: int64
Fold: 4, is_train: True, total frame 28187
0    399
1    356
Name: contact, dtype: int64
Fold: 4, is_train: False, total frame 755
cfg.sampler =>  0
num_steps 422805


loss: 1.26396, smth: 0.42255, LR 0.000002: 100%|█| 28187/28187 [5:01:37<00:00,  
loss: 0.46143, smth: 0.61302: 100%|███████████| 755/755 [03:13<00:00,  3.91it/s]


(755, 1) (755,)
[EPOCH] 1, train_loss: 0.525991,  val loss: 0.613016, auc: 0.83206, acc 0.37562, macro_score 0.57258



loss: 0.25359, smth: 0.36164, LR 0.000002: 100%|█| 28187/28187 [5:00:31<00:00,  
loss: 0.47257, smth: 0.57465: 100%|███████████| 755/755 [03:13<00:00,  3.91it/s]


(755, 1) (755,)
[EPOCH] 2, train_loss: 0.389769,  val loss: 0.574645, auc: 0.84833, acc 0.50777, macro_score 0.69206



loss: 0.17561, smth: 0.33083, LR 0.000002: 100%|█| 28187/28187 [5:02:20<00:00,  
loss: 0.42371, smth: 0.57699: 100%|███████████| 755/755 [03:12<00:00,  3.92it/s]


(755, 1) (755,)
[EPOCH] 3, train_loss: 0.342126,  val loss: 0.576994, auc: 0.83402, acc 0.56277, macro_score 0.73729


loss: 0.08297, smth: 0.31721, LR 0.000001: 100%|█| 28187/28187 [5:02:42<00:00,  
loss: 2.22864, smth: 1.12844: 100%|███████████| 755/755 [03:13<00:00,  3.91it/s]


(755, 1) (755,)
[EPOCH] 4, train_loss: 0.314164,  val loss: 1.128444, auc: 0.54838, acc 0.00000, macro_score 0.32043


loss: 0.14170, smth: 0.28535, LR 0.000001: 100%|█| 28187/28187 [5:01:10<00:00,  
loss: 0.59868, smth: 0.58644: 100%|███████████| 755/755 [03:12<00:00,  3.93it/s]


(755, 1) (755,)
[EPOCH] 5, train_loss: 0.291012,  val loss: 0.586439, auc: 0.81742, acc 0.57562, macro_score 0.75378


loss: 0.74286, smth: 0.26663, LR 0.000001:   3%| | 768/28187 [08:15<4:54:45,  1.


KeyboardInterrupt: 

In [None]:
### 11 g =>
# f2 =>
# [EPOCH] 6, train_loss: 0.131609,  val loss: 0.440705, auc: 0.80264, acc 0.57309, macro_score 0.78258

# f1 =>
# [EPOCH] 5, train_loss: 0.127586,  val loss: 0.392070, auc: 0.90474, acc 0.69466, macro_score 0.84713


# f0 =>
#[EPOCH] 8, train_loss: 0.149220,  val loss: 0.317695, auc: 0.96655, acc 0.81134, macro_score 0.90563

In [None]:
### 15 g =>
# f2 =>
# [EPOCH] 9, train_loss: 0.147380,  val loss: 0.379531, auc: 0.82559, acc 0.60795, macro_score 0.79601

# f1 =>
# [EPOCH] 7, train_loss: 0.172915,  val loss: 0.313531, auc: 0.89974, acc 0.74610, macro_score 0.86937

# f0 =>
# [EPOCH] 3, train_loss: 0.227716,  val loss: 0.300652, auc: 0.91696, acc 0.56272, macro_score 0.75591

In [None]:
## 15 =>
# f2 =>
# [EPOCH] 4, train_loss: 0.254828,  val loss: 0.526849, auc: 0.89088, acc 0.63712, macro_score 0.81127

# f0 =>
#[EPOCH] 4, train_loss: 0.281130,  val loss: 0.492572, auc: 0.91551, acc 0.69613, macro_score 0.84370

# f3 =>
# [EPOCH] 3, train_loss: 0.375842,  val loss: 0.561336, auc: 0.84152, acc 0.53353, macro_score 0.72987

# f4 =>
# [EPOCH] 5, train_loss: 0.291012,  val loss: 0.586439, auc: 0.81742, acc 0.57562, macro_score 0.75378

### Validation

In [27]:
# Switch torch.multiprocessing's tensor-sharing strategy to 'file_system' for the rest of
# the session (e.g. DataLoader worker processes in the validation run below).
# NOTE(review): presumably set to avoid "too many open files" errors seen with the default
# strategy on large validation loaders — confirm. It only takes effect for code run after this cell.
torch.multiprocessing.set_sharing_strategy('file_system')

In [28]:
# --- Validation run: score saved checkpoints on their held-out folds ---
cfg.mode = 'val'
cfg.folds = [3, 4]
cfg.config = 'r50ir_csn_c15_m1_d2_all'  # alternative: r50ir_csn_c11_m1_d2_G_all
cfg.out_dir_val = 'kaggle'  # main() loads {out_dir_val}/{config}_last_f{fold}.pth from here
cfg.batch_size = 6
cfg.device = "cuda"  # or "cpu"
main()

 load kaggle/r50ir_csn_c15_m1_d2_all_last_f3.pth!
(91801, 7)
0    31840
1    17689
Name: contact, dtype: int64
Fold: 3, is_train: True, total frame 49529
Fold: 3, is_train: False, total frame 21646
cfg.sampler =>  0


loss: 0.28222, smth: 0.50754: 100%|███████| 3608/3608 [1:02:42<00:00,  1.04s/it]


(21646, 1) (21646,)
val loss: 0.507535, auc: 0.80863, acc 0.51533, macro_score 0.74140
 load kaggle/r50ir_csn_c15_m1_d2_all_last_f4.pth!
(91801, 7)
0    32932
1    18296
Name: contact, dtype: int64
Fold: 4, is_train: True, total frame 51228
Fold: 4, is_train: False, total frame 17033
cfg.sampler =>  0


loss: 0.48984, smth: 0.55941: 100%|█████████| 2839/2839 [49:10<00:00,  1.04s/it]


(17033, 1) (17033,)
val loss: 0.559410, auc: 0.78100, acc 0.51730, macro_score 0.74129
