In [1]:
import os
# import glob
import numpy as np
import pandas as pd
import random
# import math
from multiprocessing import Pool
# import gc
import cv2
from tqdm import tqdm
import time
import matplotlib.pyplot as plt
import re
from time import time
tqdm.pandas()
from sklearn.metrics import matthews_corrcoef, accuracy_score
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
from helper import *
# from functools import lru_cache
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from itertools import cycle


from datetime import datetime
from torchvision.models import resnet50

import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [2]:
CFG = {
    'seed': 42,
    'test_size': 40,
    'lr': 1,
    'use_multi': False,
    'num_workers': 8,
    'batch_size': 64,
    'iterations': 1,
    'val_wait': 1,
    'scheduler_patience': 100,
    'saver_mode': 'all',
    'es_patience': 2,
    'rop_factor': 0.5,
    'rop_patience': 100,
}


In [22]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(CFG['seed'])

In [4]:
%%time
df = pd.read_csv('./final_data2.csv')
df['bbox_endzone'] = df['bbox_endzone'].progress_apply(process_bbox)
df['bbox_sideline'] = df['bbox_sideline'].progress_apply(process_bbox)

100%|██████████| 660553/660553 [00:23<00:00, 27622.28it/s]
100%|██████████| 660553/660553 [00:24<00:00, 26620.50it/s]


CPU times: user 53.5 s, sys: 919 ms, total: 54.4 s
Wall time: 54.5 s


In [36]:
np.random.seed(42)
val_plays = np.random.choice(df['game_play'].unique(), size = 2)
val_set =  df.apply(lambda row: row[df['game_play'].isin(val_plays)])
train_set = df.apply(lambda row: row[~df['game_play'].isin(val_plays)])
val_set.shape, train_set.shape, df.shape

((5422, 44), (655131, 44), (660553, 44))

In [6]:
categorical_data_for_fitting = [
    ['home', 'CB', 'home', 'CB', ] ,
    ['home', 'DE', 'home', 'DE', ],
    ['home', 'FS', 'home', 'FS', ],
    ['home', 'TE', 'home', 'TE', ] ,
    ['home', 'ILB', 'home', 'ILB', ],
    ['home', 'OLB', 'home', 'OLB', ],
    ['home', 'T', 'home', 'T', ],
    ['home', 'G', 'home', 'G', ] ,
    ['home', 'C', 'home', 'C', ] ,
    ['home', 'QB', 'home', 'QB', ],
    ['home', 'WR', 'home', 'WR', ],
    ['home', 'RB', 'home', 'RB', ],
    ['home', 'NT', 'home', 'NT', ],
    ['home', 'DT', 'home', 'DT', ],
    ['home', 'MLB', 'home', 'MLB', ],
    ['home', 'SS', 'home', 'SS', ] ,
    ['home', 'OT', 'home', 'OT', ],
    ['home', 'LB', 'home', 'LB', ],
    ['home', 'OG', 'home', 'OG', ] ,
    ['home', 'SAF', 'home', 'SAF', ],
    ['home', 'DB', 'home', 'DB', ] ,
    ['home', 'LS', 'home', 'LS', ] ,
    ['home', 'K', 'home', 'K', ],
    ['home', 'P', 'home', 'P', ],
    ['home', 'FB', 'home', 'FB', ] ,
    ['home', 'S', 'home', 'S', ],
    ['home', 'DL', 'Ground', 'DL', ],
    ['away', 'HB', 'away', 'HB', ],
    ['away', 'HB', 'away', 'Ground', ],
]
    
one_hot = OneHotEncoder()
one_hot.fit(categorical_data_for_fitting)

OneHotEncoder()

In [7]:
one_hot.transform(categorical_data_for_fitting).toarray()

a = np.array(categorical_data_for_fitting)
a.shape

(29, 4)

In [8]:
df_G1 = df.loc[(df['contact'] == 1) & (df['G_flug'] == True)]
df_G0 = df.loc[(df['contact'] == 0) & (df['G_flug'] == True)]
df_P1 = df.loc[(df['contact'] == 1) & (df['G_flug'] == False)]
df_P0 = df.loc[(df['contact'] == 0) & (df['G_flug'] == False)]


In [9]:
random_state = 42

train_G1, test_G1 = train_test_split(
    df_G1, test_size=CFG['test_size']//4, random_state=random_state)
train_G0, test_G0 = train_test_split(
    df_G0, test_size=CFG['test_size']//4, random_state=random_state)
train_P1, test_P1 = train_test_split(
    df_P1, test_size=CFG['test_size']//4, random_state=random_state)
train_P0, test_P0 = train_test_split(
    df_P0, test_size=CFG['test_size']//4, random_state=random_state)


In [10]:
test_df = pd.concat([test_G1,test_G0,test_P1,test_P0,]).reset_index()
test_df.shape

(40, 45)

In [11]:
train_aug = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(p=0.5),
    A.RandomBrightnessContrast(
        brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), p=0.5),
    A.Normalize(mean=[0.], std=[1.]),
    ToTensorV2()
])

valid_aug = A.Compose([
    A.Normalize(mean=[0.], std=[1.]),
    ToTensorV2()
])


In [17]:
def read_image(path, cx, cy, view, aug):
    img_new = np.zeros((256, 256), dtype=np.float32)
    if os.path.isfile(path):
        if view == 'Endzone':
            img = cv2.imread(path, 0)[
                cy-76:cy+180, cx-128:cx+128].copy()
            img_new[:img.shape[0], :img.shape[1]] = img
        else:
            img = cv2.imread(path, 0)[
                cy-128:cy+128, cx-128:cx+128].copy()
            img_new[:img.shape[0], :img.shape[1]] = img
        
    # fin_img = aug(image=img_new.transpose(1,2,0))['image'][0]
    return img_new

In [13]:
class MyDataset(Dataset):
    def __init__(self, df, feature_cols=['rel_pos_x',
                                         'rel_pos_y', 'rel_pos_mag', 'rel_pos_ori', 'rel_speed_x', 'rel_speed_y',
                                         'rel_speed_mag', 'rel_speed_ori', 'rel_acceleration_x',
                                         'rel_acceleration_y', 'rel_acceleration_mag', 'rel_acceleration_ori',
                                         'G_flug', 'orientation_1', 'orientation_2'], aug=train_aug):

        self.df = df
        self.features = feature_cols
        self.aug = aug

    def __len__(self):
        return len(self.df)

    def normalize_features(self, features):
        """
        normalizes the features of the players

       'rel_pos_x',
       'rel_pos_y', 'rel_pos_mag', 'rel_pos_ori', 'rel_speed_x', 'rel_speed_y',
       'rel_speed_mag', 'rel_speed_ori', 'rel_acceleration_x',
       'rel_acceleration_y', 'rel_acceleration_mag', 'rel_acceleration_ori',
       'G_flug', 'orientation_1', 'orientation_2'
        """
        features /= 100
        features[3] /= 3.6
        features[7] /= 3.6
        features[11] /= 3.6
        features[13] /= 3.6
        features[14] /= 3.6
        return features

    def __getitem__(self, idx):
        window = 24
        frames_to_skip = 4
        start = time()
        # print(idx*(CFG['batch_size']//4), (idx+1)*(CFG['batch_size']//4))
        row = self.df.loc[idx*(CFG['batch_size']//4) : (idx+1)*(CFG['batch_size']//4)-1].reset_index()
        # print(f"row shape: {row.shape}")
        # print(row)
        mid_frame = row['frame']
        # print(mid_frame)
        label = np.array(row['contact'])
        args = []
        print(f" Initi: {time()- start}")
        start = time()
        for i in range(CFG['batch_size']//4):
            for view in ['Endzone', 'Sideline']:
                video = row.loc[i]['game_play'] + f'_{view}.mp4'
                cur_mid_frame = mid_frame.loc[i]
                frames = [cur_mid_frame - window +
                        i for i in range(0, 2*window+1, frames_to_skip)]
                # print(f'leng frames: {len(frames)}')

                bbox_col = 'bbox_endzone' if view == 'Endzone' else 'bbox_sideline'
                bboxes = row.loc[i][bbox_col][::frames_to_skip].astype(np.int32)

                if bboxes.sum() <= 0:
                    args += [('dummy', 0,0, view, self.aug)]*len(frames)
                    continue

                for i, frame in enumerate(frames):
                    img_new = np.zeros((256, 256), dtype=np.float32)
                    cx, cy = bboxes[i]
                    path = f"./work/train_frames/{video}_{frame:04d}.jpg"
                    # print(path)
                    args.append((path, cx, cy, view, self.aug))

                # if os.path.isfile(path):
                #     img_new = np.zeros((256, 256), dtype=np.float32)
                #     if view == 'Endzone':
                #         img = cv2.imread(path, 0)[
                #             cy-76:cy+180, cx-128:cx+128].copy()
                #         img_new[:img.shape[0], :img.shape[1]] = img
                #     else:
                #         img = cv2.imread(path, 0)[
                #             cy-128:cy+128, cx-128:cx+128].copy()
                #         img_new[:img.shape[0], :img.shape[1]] = img
                # imgs.append(img_new)
        print(f" args init: {time()- start}")
        start = time()
        with Pool(8) as pool: 
            imgs = list(pool.starmap(read_image, args))
        print(f" imgs load: {time()- start}")
        start = time()
        # img = np.array(imgs).transpose(1, 2, 0)
        # img = self.aug(image=img)["image"]
        print(imgs)
        img = torch.stack(imgs).reshape(CFG['batch_size']//4, 26, 256, 256)
        print(f" img aug: {time()- start}")
        start = time()
        features = np.array(row[self.features], dtype=np.float32)
        # print(f"feature shape: {features.shape}")
        features[np.isnan(features)] = 0
        # print(f"feature shape: {features}")
        """
        rel_pos_x                0
        rel_pos_y                1
        rel_pos_mag              2
        rel_pos_ori              3
        rel_speed_x              4
        rel_speed_y              5
        rel_speed_mag            6
        rel_speed_ori            7
        rel_acceleration_x       8
        rel_acceleration_y       9
        rel_acceleration_mag     10
        rel_acceleration_ori     11 
        """
        for i in range(CFG['batch_size']//4):
            if row.iloc[i]['G_flug']:
                features[i, 6] = row.loc[i]['speed_1']
                features[i, 7] = row.loc[i]['direction_1']
                features[i, 10] = row.loc[i]['acceleration_1']
                features[i, 11] = row.loc[i]['direction_1']

                features[i, 4] = row.loc[i]['speed_1']*np.sin(row.loc[i]['direction_1']*np.pi/180)
                features[i, 5] = row.loc[i]['speed_1']*np.cos(row.loc[i]['direction_1']*np.pi/180)
                features[i, 8] = row.loc[i]['acceleration_1'] * \
                    np.sin(row.loc[i]['direction_1']*np.pi/180)
                features[i, 9] = row.loc[i]['acceleration_1'] * \
                    np.cos(row.loc[i]['direction_1']*np.pi/180)
            
            features[i,: ] = self.normalize_features(features[i])
            # if i==0:
                # print(features)

        team_pos = np.array(
            row[['team_1', 'position_1', 'team_2', 'position_2']].fillna('Ground'))
       
        team_pos = one_hot.transform(
            team_pos
        ).toarray()
        print(f" features compute: {time()- start}")
        
        return img, torch.from_numpy(np.hstack((features, team_pos)).astype(np.float32)), torch.as_tensor(label)


In [20]:
class MyDataset2(Dataset):
    def __init__(self, df1,df2,df3, df4, aug=train_aug, one_hot_transform=one_hot, feature_cols=['rel_pos_x',
                                                                 'rel_pos_y', 'rel_pos_mag', 'rel_pos_ori', 'rel_speed_x', 'rel_speed_y',
                                                                 'rel_speed_mag', 'rel_speed_ori', 'rel_acceleration_x',
                                                                 'rel_acceleration_y', 'rel_acceleration_mag', 'rel_acceleration_ori',
                                                                 'G_flug', 'orientation_1', 'orientation_2']):

        self.df1 = df1
        self.df2 = df2
        self.df3 = df3
        self.df4 = df4
        # self.logger = logger
        self.features = feature_cols
        self.aug = aug
        self.one_hot_transform = one_hot_transform

    def __len__(self):
        return max(len(self.df1),len(self.df2),len(self.df3),len(self.df4))//(CFG['batch_size']//4)

    def get_rows(self, lnum, unum, df_num):
        df_map = {
            1:self.df1,
            2:self.df2,
            3:self.df3,
            4:self.df4,
        }
        # self.logger.debug(f"df{df_num} with lnum:{lnum}, unum:{unum}")
        lnum = lnum%len(df_map[df_num])
        unum = unum%len(df_map[df_num]) 
        # self.logger.debug(f"df{df_num} with lnum:{lnum}, unum:{unum}")

        if lnum  < unum:
            # self.logger.debug(f"return is {df_map[df_num][lnum:unum].shape}")
            return df_map[df_num][lnum:unum]
        else:
            # self.logger.debug(f"return is {pd.concat([df_map[df_num][lnum:], df_map[df_num][:unum]]).shape}")
            return pd.concat([df_map[df_num][lnum:], df_map[df_num][:unum]])

    def normalize_features(self, features):
        """
        normalizes the features of the players

       'rel_pos_x',
       'rel_pos_y', 'rel_pos_mag', 'rel_pos_ori', 'rel_speed_x', 'rel_speed_y',
       'rel_speed_mag', 'rel_speed_ori', 'rel_acceleration_x',
       'rel_acceleration_y', 'rel_acceleration_mag', 'rel_acceleration_ori',
       'G_flug', 'orientation_1', 'orientation_2'
        """
        features /= 100
        features[3] /= 3.6
        features[7] /= 3.6
        features[11] /= 3.6
        features[13] /= 3.6
        features[14] /= 3.6
        return features

    def __getitem__(self, idx):
        window = 24
        frames_to_skip = 4

        row1 = self.get_rows(idx*(CFG['batch_size']//4), (idx+1)*(CFG['batch_size']//4), 1).reset_index()
        row2 = self.get_rows(idx*(CFG['batch_size']//4), (idx+1)*(CFG['batch_size']//4), 2).reset_index()
        row3 = self.get_rows(idx*(CFG['batch_size']//4), (idx+1)*(CFG['batch_size']//4), 3).reset_index()
        row4 = self.get_rows(idx*(CFG['batch_size']//4), (idx+1)*(CFG['batch_size']//4), 4).reset_index()

        row = pd.concat([row1, row2, row3, row4])
        # self.logger.debug(f"row colums: {row.columns}")
        # self.logger.debug(f"Row shape:{row.shape}")
        mid_frame = row['frame']
        # self.logger.debug(f"mid frames shape:{len(mid_frame)}")
        label = np.array(row['contact']).astype(np.float32)
        # self.logger.debug(f"label:{len(label)}")
        args = []
        for i in range(CFG['batch_size']):
            for view in ['Endzone', 'Sideline']:
                video = row.iloc[i]['game_play'] + f'_{view}.mp4'
                cur_mid_frame = mid_frame.iloc[i]
                frames = [cur_mid_frame - window +
                        i for i in range(0, 2*window+1, frames_to_skip)]
                bbox_col = 'bbox_endzone' if view == 'Endzone' else 'bbox_sideline'
                print(f"bbox details:\n{row.iloc[i][bbox_col][::frames_to_skip]}")
                bboxes = row.iloc[i][bbox_col][::frames_to_skip].astype(np.int32)

                if bboxes.sum() <= 0:
                    args += [('dummy', 0,0, view, self.aug)]*len(frames)
                    continue

                for i, frame in enumerate(frames):
                    # img_new = np.zeros((256, 256), dtype=np.float32)
                    cx, cy = bboxes[i]
                    print(video)
                    path = f'./work/train_frames/{video}_{frame:04d}.jpg'
                    args.append((path, cx, cy, view, self.aug))
                    # if os.path.isfile(path):
                    #     if view == 'Endzone':
                    #         img = cv2.imread(path, 0)[
                    #             cy-76:cy+180, cx-128:cx+128].copy()
                    #         img_new[:img.shape[0], :img.shape[1]] = img
                    #     else:
                    #         img = cv2.imread(path, 0)[
                    #             cy-128:cy+128, cx-128:cx+128].copy()
                    #         img_new[:img.shape[0], :img.shape[1]] = img
                    # imgs.append(img_new)

        # self.logger.debug(f"sizeof args:{len(args)}")
        with Pool(8) as pool: 
            imgs = list(pool.starmap(read_image, args))

        # img = torch.stack(imgs).reshape(CFG['batch_size'], 26, 256, 256)
        img = np.array(imgs).transpose(1, 2, 0)
        img = self.aug(image=img)["image"]
        # del imgs
        # self.logger.debug(f"processed imgs:{img.shape}")
        features = np.array(row[self.features], dtype=np.float32)
        features[np.isnan(features)] = 0

        """
        rel_pos_x                0
        rel_pos_y                1
        rel_pos_mag              2
        rel_pos_ori              3
        rel_speed_x              4
        rel_speed_y              5
        rel_speed_mag            6
        rel_speed_ori            7
        rel_acceleration_x       8
        rel_acceleration_y       9
        rel_acceleration_mag     10
        rel_acceleration_ori     11 
        """
        for i in range(CFG['batch_size']):
            if row.iloc[i]['G_flug']:
                features[i, 6] = row.iloc[i]['speed_1']
                features[i, 7] = row.iloc[i]['direction_1']
                features[i, 10] = row.iloc[i]['acceleration_1']
                features[i, 11] = row.iloc[i]['direction_1']

                features[i, 4] = row.iloc[i]['speed_1']*np.sin(row.iloc[i]['direction_1']*np.pi/180)
                features[i, 5] = row.iloc[i]['speed_1']*np.cos(row.iloc[i]['direction_1']*np.pi/180)
                features[i, 8] = row.iloc[i]['acceleration_1'] * \
                    np.sin(row.iloc[i]['direction_1']*np.pi/180)
                features[i, 9] = row.iloc[i]['acceleration_1'] * \
                    np.cos(row.iloc[i]['direction_1']*np.pi/180)
            
            features[i,: ] = self.normalize_features(features[i])
        self.logger.debug(f"processed features:{features.shape}")

        team_pos = np.array(row[['team_1', 'position_1', 'team_2', 'position_2']].fillna('Ground'))
        team_pos = self.one_hot_transform.transform(team_pos).toarray()
        # gc.collect()
        return img, torch.from_numpy(np.hstack((features, team_pos)).astype(np.float32)), torch.as_tensor(label)


In [21]:
train_G1_set = MyDataset2(train_G1[:256].reset_index(), train_G0[:256].reset_index(), train_P1[:256].reset_index(), train_P0[:256].reset_index())
imgs, feats, labels = train_G1_set[2]
# print(imgs.shape, feats, labels)

bbox details:
[[497.5 423. ]
 [497.5 423. ]
 [497.5 423. ]
 [494.  422. ]
 [646.5 305.5]
 [648.5 307. ]
 [644.5 306. ]
 [633.5 302. ]
 [633.  297.5]
 [626.  301. ]
 [597.  276. ]
 [590.5 265.5]
 [579.  259.5]]
58190_000309_Endzone.mp4
58190_000309_Endzone.mp4
58190_000309_Endzone.mp4
58190_000309_Endzone.mp4
58190_000309_Endzone.mp4
58190_000309_Endzone.mp4
58190_000309_Endzone.mp4
58190_000309_Endzone.mp4
58190_000309_Endzone.mp4
58190_000309_Endzone.mp4
58190_000309_Endzone.mp4
58190_000309_Endzone.mp4
58190_000309_Endzone.mp4
bbox details:
[[581.  251. ]
 [583.5 241.5]
 [586.  231. ]
 [588.5 218. ]
 [588.  205.5]
 [587.5 193.5]
 [587.5 184. ]
 [590.5 177. ]
 [592.5 171. ]
 [594.  166. ]
 [594.  160. ]
 [596.  150.5]
 [598.  136.5]]
58401_002419_Sideline.mp4
58401_002419_Sideline.mp4
58401_002419_Sideline.mp4
58401_002419_Sideline.mp4
58401_002419_Sideline.mp4
58401_002419_Sideline.mp4
58401_002419_Sideline.mp4
58401_002419_Sideline.mp4
58401_002419_Sideline.mp4
58401_002419_Sideline

AttributeError: 'MyDataset2' object has no attribute 'logger'

In [15]:
imgs.shape

torch.Size([16, 26, 256, 256])

In [None]:
class EarlyStopping():
    """
    A class which decides to stop early or not based on patience and Mathew Correaltion Metrci

    Use this on validation set only
    """

    def __init__(self, patience):
        self.patience = patience
        self.best_metric = -np.inf
        self.counter = 0

    def get_metric(y_true, y_label):
        return matthews_corrcoef(y_true, y_label)

    def stop(self, metric_val):

        flag = False
        if metric_val > self.best_metric:
            self.best_metric = metric_val
            self.counter = 0
            flag = True
        else:
            self.counter += 1

        if self.counter > self.patience:
            return True, flag
        else:
            return False, flag


class ModelSaver():

    """
    save mode: can be one of {'none', 'best', 'indexed', 'all'}
    """

    def __init__(self, save_mode, name=None):
        self.save_mode = save_mode

        self.path = './model_checkpoints/' + \
            str(datetime.now()).replace(
                " ", '-').replace(':', '-').split('.')[0]
        if name is not None:
            self.path += f"-{name}"
        else:
            self.path += "-model"

        if self.save_mode is not None:
            os.mkdir(self.path)

    def save(self, model, index, best: bool):
        if self.save_mode == 'best':
            if best:
                torch.save(model, self.path + '/best_model.pth')
                return
            else:
                return

        elif self.save_mode == 'indexed':
            torch.save(model, self.path + '/iteration' +
                       str(index).zfill(7) + '.pth')
            return

        elif self.save_mode == 'all':
            torch.save(model, self.path + '/iteration' +
                       str(index).zfill(7) + '.pth')
            if best:
                torch.save(model, self.path + '/best_model.pth')
                return
            else:
                return


class Validator():
    def __init__(self, test_df, use_multi=False, verbose=True):
        self.test_set = MyDataset(test_df)
        self.verbose = verbose
        if not use_multi:
            self.test_loader = DataLoader(
                self.test_set, batch_size=CFG['batch_size'], shuffle=False, pin_memory=True)
        else:
            self.test_loader = DataLoader(
                self.test_set, batch_size=CFG['batch_size'], num_workers=CFG['num_workers'], shuffle=True, pin_memory=True)

    def validate(self, model):
        y_hat = []
        y = []
        loss = 0
        model.eval()
        with torch.no_grad():
            for batch in self.test_loader:
                imgs, features, labels = batch

                imgs = imgs.to(0, non_blocking=True)
                features = features.to(0, non_blocking=True)
                labels = labels.to(0, non_blocking=True)

                preds = model(imgs, features)

                loss += criterion(preds,
                                  labels).cpu().detach().numpy().ravel()[0]
                y.append(labels.cpu().detach().numpy())
                y_hat.append(preds.cpu().detach().numpy())

            y = np.ravel(y)
            y_hat = np.ravel(y_hat)
            loss = loss/CFG['batch_size']
            if self.verbose:
                print(get_stats(loss, y, y_hat))

        return matthews_corrcoef(y, y_hat)


class Callback():
    def __init__(self, args):
        self.valer = Validator(**args['Validator'])
        self.es = EarlyStopping(**args['EarlyStopping'])
        self.ms = ModelSaver(**args['ModelSaver'])

    def callback(self, model, iteration, y_true, y_pred):
        print("Validating Data")
        metric = self.valer.validate(model)
        stop, best = self.es.stop(metric)
        if best:
            print("New Best Model")
        if stop:
            print("Early Stopping Triggered")
        print(f'stop metric = {stop, best}')
        self.ms.save(model, iteration, best)

        return stop


def get_stats(loss, y, y_pred, cur_iter='val', thresh=0.5):
    """
    Gets the stats for a particular batch
    """
    y_hat = (y_pred > thresh)*1.0
    mathew_corr = matthews_corrcoef(y, y_hat)
    acc = accuracy_score(y, y_hat)

    size = len(y_hat)
    classwise_mathew_corr = []
    classwise_acc = []
    for i in range(4):
        y_cl = y[i*(size//4):(i+1)*(size//4)]
        y_hat_cl = y_hat[i*(size//4):(i+1)*(size//4)]
        classwise_mathew_corr.append(matthews_corrcoef(y_cl, y_hat_cl))
        classwise_acc.append(accuracy_score(y_cl, y_hat_cl))

    stats = f'Iteration: {cur_iter} || Loss: {loss:.5f} || mat_corr: {mathew_corr:.5f} || acc: {acc:.5f}'
    stats += f""" 
|| G1_mat_corr: {classwise_mathew_corr[0]:.5f} || G0_mat_corr: {classwise_mathew_corr[1]:.5f} || P1_mat_corr: {classwise_mathew_corr[2]:.5f} || P0_mat_corr: {classwise_mathew_corr[3]:.5f}"""
    stats += f"""
||G1_acc: {classwise_acc[0]:.5f} || G0_acc: {classwise_acc[1]:.5f} || P1_acc: {classwise_acc[2]:.5f} || P0_acc: {classwise_acc[3]:.5f} EOL
    """

    return stats.replace("\n", "")


In [None]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.backbone = timm.create_model('resnet18', pretrained=False, num_classes=250, in_chans=26)
        self.mlp = nn.Sequential(
            nn.Linear(77, 128),
            nn.LayerNorm(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.fc = nn.Linear(64+250, 1)

    def forward(self, img, feature):
        img = self.backbone(img)
        feature = self.mlp(feature)
        y = torch.sigmoid(self.fc(torch.cat([img, feature], dim=1)))
        return y.flatten()

In [None]:
train_G1_set = MyDataset(train_G1[:256].reset_index())
train_G0_set = MyDataset(train_G0[:256].reset_index())
train_P1_set = MyDataset(train_P1[:256].reset_index())
train_P0_set = MyDataset(train_P0[:256].reset_index())

if not CFG['use_multi']:
    train_G1_loader = DataLoader(train_G1_set, batch_size=CFG['batch_size']//4, shuffle=True, pin_memory=True)
    train_G0_loader = DataLoader(train_G0_set, batch_size=CFG['batch_size']//4, shuffle=True, pin_memory=True)
    train_P1_loader = DataLoader(train_P1_set, batch_size=CFG['batch_size']//4, shuffle=True, pin_memory=True)
    train_P0_loader = DataLoader(train_P0_set, batch_size=CFG['batch_size']//4, shuffle=True, pin_memory=True)
else:
    train_G1_loader = DataLoader(train_G1_set, batch_size=CFG['batch_size']//4, num_workers=CFG['num_workers'], shuffle=True, pin_memory=True)
    train_G0_loader = DataLoader(train_G0_set, batch_size=CFG['batch_size']//4, num_workers=CFG['num_workers'], shuffle=True, pin_memory=True)
    train_P1_loader = DataLoader(train_P1_set, batch_size=CFG['batch_size']//4, num_workers=CFG['num_workers'], shuffle=True, pin_memory=True)
    train_P0_loader = DataLoader(train_P0_set, batch_size=CFG['batch_size']//4, num_workers=CFG['num_workers'], shuffle=True, pin_memory=True)

In [None]:
cl_args = {
    "EarlyStopping": {
        'patience': CFG['es_patience']
    },
    "ModelSaver": {
        'save_mode': CFG['saver_mode'],
        'name': "baseline"
    },
    "Validator": {
        "test_df": test_df,
        "use_multi": CFG['use_multi'],
        "verbose": True
    }, 
}

cl = Callback(cl_args)

model = Model()
model.to('cuda')
criterion = nn.BCELoss()
optimizer=  torch.optim.Adam(model.parameters(), lr=CFG['lr'])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min', factor=0.5, patience=CFG['scheduler_patience'], verbose =True
)

In [None]:
%%time

num_iters = CFG['iterations']
validate_wait = CFG['val_wait']
model.train()

for cur_iter, (batch_G1, batch_P1, batch_P0, batch_G0) in enumerate(zip(cycle(train_G1_loader),cycle(train_P1_loader),cycle(train_P0_loader),cycle(train_G0_loader))):
    if num_iters <= cur_iter:
        break
    imgs1, features1, labels1 = batch_G1
    imgs2, features2, labels2 = batch_G0
    imgs3, features3, labels3 = batch_P1
    imgs4, features4, labels4 = batch_P0

    imgs = torch.vstack([imgs1, imgs2, imgs3, imgs4]).to(0, non_blocking=True)
    y = torch.hstack([labels1, labels2, labels3, labels4]).to(0, non_blocking=True)
    feats = torch.vstack([features1, features2, features3, features4]).to(0, non_blocking=True)

    optimizer.zero_grad()

    y_hat = model(imgs, feats)
    loss = criterion(y_hat, y)
    
    loss.backward()
    optimizer.step()
    scheduler.step(loss)
    
    y = y.cpu().detach().numpy()
    y_hat= y_hat.cpu().detach().numpy()
    print(get_stats(loss, y, y_hat, cur_iter=cur_iter))
    
    if cur_iter%validate_wait == 0:
        if cl.callback(model, cur_iter, y, y_hat):
            break

In [None]:
!nvidia-smi