In [1]:
import os
import pandas as pd
import numpy as np
from easydict import EasyDict as edict
import pandas as pd
from sklearn.model_selection import StratifiedKFold
import setproctitle
import random
import copy
import logging
import time
import os
from torch.utils.data import DataLoader
from timm.loss.binary_cross_entropy import BinaryCrossEntropy
from tqdm import tqdm
import timm
import random
import torchaudio
import torch.nn as nn
import torch
import sys
import torch
import numpy as np
import warnings
from sklearn.metrics import f1_score, precision_score, recall_score, roc_curve
from sklearn.metrics import confusion_matrix


In [2]:
setproctitle.setproctitle("spike_train")
sys.path.append('.')
warnings.filterwarnings('ignore')

In [3]:
numpy_files_path = "/run/media/kami/SSD/DATASETS/vepiset-dataset/NPY-Files/"
save_csv_file = "./data.csv"

# 1. Configuration

In [4]:
def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # cpu
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


In [5]:
config = edict()

config.TRAIN = edict()

config.TRAIN.process_num = 1

config.TRAIN.batch_size = 128
config.TRAIN.validatiojn_batch_size = config.TRAIN.batch_size
config.TRAIN.accumulation_batch_size = 128
config.TRAIN.log_interval = 10
config.TRAIN.test_interval = 1
config.TRAIN.epoch = 1

config.TRAIN.init_lr = 0.0005
config.TRAIN.lr_scheduler = 'cos'

if config.TRAIN.lr_scheduler == 'ReduceLROnPlateau':
    config.TRAIN.epoch = 1
    config.TRAIN.lr_scheduler_factor = 0.1

config.TRAIN.weight_decay_factor = 1.e-2
config.TRAIN.vis = False

config.TRAIN.warmup_step = 1500
config.TRAIN.opt = 'Adamw'

config.TRAIN.gradient_clip = 5

config.TRAIN.vis_mixcut = False
if config.TRAIN.vis:
    config.TRAIN.mix_precision = False
else:
    config.TRAIN.mix_precision = False

config.MODEL = edict()

config.MODEL.model_path = './trained_models/'

config.DATA = edict()

config.DATA.data_file = save_csv_file

config.DATA.data_root_path = 'utils'

config.MODEL.early_stop = 30

config.MODEL.pretrained_model = None

config.SEED = 10086

seed_everything(config.SEED)
config.is_base = 1


# 2. Create CSV File From NPY Files

### Variables

In [6]:
def find_npy_files(folder_path):
    npy_files = []
    for root, dirs, files in os.walk(folder_path):
        for file in files:
            if file.endswith('.npy'):
                npy_files.append(os.path.join(root, file))

    return npy_files


def get_data(data_dir):
    samples = find_npy_files(data_dir)
    labels = [1 if int(x.split("__")[1].split(".")[0]) > 1 else int(x.split("__")[1].split(".")[0]) for x in samples]
    train_val = [0] * len(samples)
    return samples, labels, train_val


def get_data_files(numpy_files_path):
    #data_dir = npy_val_data
    fns_list = []
    labels_list = []
    train_val_list = []
    samples, labels, train_val = get_data(numpy_files_path)
    fns_list.extend(samples)
    labels_list.extend(labels)
    train_val_list.extend(train_val)
    # print("train_val_list" , train_val_list)
    return fns_list, labels_list, train_val_list


In [7]:
val_fns_list, val_labels_list, val_vals_list = get_data_files(numpy_files_path)
# print(val_fns_list)
# print( val_labels_list)
# print( val_vals_list)
submission = pd.DataFrame({'file_path': val_fns_list,
                           'target': val_labels_list,
                           'train_val': val_vals_list})
# submission
### split train - val = 8:2
indices = submission[submission['train_val'] == 0].index
val_num = len(indices) // 5
indices_to_change = np.random.choice(indices, val_num, replace=False)
submission.loc[indices_to_change, 'train_val'] = 1

# print("fns len:", len(val_fns_list))
# print("label len:", len(val_labels_list))
# print("val len:", val_num)
# print("train:val {0}:{1}".format(8, 2))
submission.to_csv(save_csv_file, index=False)
# submission
# submission.head()


# 3. Train

### Logger

In [8]:
def get_logger(LEVEL, log_file=None):
    head = '[%(asctime)-15s] [%(levelname)s] %(message)s '
    if LEVEL == 'info':
        logging.basicConfig(level=logging.INFO, format=head)
    elif LEVEL == 'debug':
        logging.basicConfig(level=logging.DEBUG, format=head)
    logger = logging.getLogger()

    if log_file != None:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)
    return logger


logger = get_logger('info')


## Models

### AUG

In [9]:
class AUG(nn.Module):
    def __init__(self, ):
        super().__init__()

    def forward(self, x):
        bs = x.size(0)
        for i in range(bs):
            if random.uniform(0, 1) < 0.5:
                x[i, ...] = self.pitch_shift_spectrogram(x[i, ...])
            if random.uniform(0, 1) < 0.0:
                x[i, ...] = self.time_shift_spectrogram(x[i, ...])

        return x

    def do_cut_out(self, x):

        h = 128
        w = 128
        line_width = random.randint(1, 8)

        if random.uniform(0, 1) < 0.5:

            start = random.randint(0, w - line_width)
            x[:, :, start:start + line_width] = 0
        else:
            start = random.randint(0, h - line_width)
            x[:, start:start + line_width, :] = 0

        return x

    def pitch_shift_spectrogram(self, spectrogram):
        """ Shift a spectrogram along the frequency axis in the spectral-domain at
        random
        """
        nb_cols = spectrogram.size(1)
        max_shifts = nb_cols // 50  # around 5% shift
        nb_shifts = random.randint(-max_shifts, max_shifts)

        return torch.roll(spectrogram, nb_shifts, dims=[1])

    def time_shift_spectrogram(self, spectrogram):
        """ Shift a spectrogram along the frequency axis in the spectral-domain at
        random
        """
        nb_cols = spectrogram.size(2)
        max_shifts = nb_cols // 2  # around 100% shift
        nb_shifts = random.randint(-max_shifts, max_shifts)

        return torch.roll(spectrogram, nb_shifts, dims=[2])




### Transform

In [10]:
class Transform(nn.Module):
    def __init__(self, ):
        super().__init__()
        self.wave_transform = torchaudio.transforms.Spectrogram(n_fft=256, hop_length=16, power=1, pad_mode='reflect')

    def forward(self, x):
        image = self.wave_transform(x)
        return image


def weight_init(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)

    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)




### MLP

In [11]:
class MLP(nn.Module):

    def __init__(self, feature_size):
        super().__init__()

        self.linear1 = nn.Linear(feature_size, 6)
        self.relu1 = nn.LeakyReLU()
        self.dropout1 = nn.Dropout(0.3)
        self.linear2 = nn.Linear(6, 24)
        self.relu2 = nn.LeakyReLU()
        self.dropout2 = nn.Dropout(0.3)
        self.linear3 = nn.Linear(24, 24)
        self.relu3 = nn.LeakyReLU()
        self.dropout3 = nn.Dropout(0.3)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu1(x)
        x = self.dropout1(x)

        x = self.linear2(x)
        x = self.relu2(x)
        x = self.dropout2(x)

        x = self.linear3(x)
        x = self.relu3(x)
        x = self.dropout3(x)

        return x




### NET

In [12]:
class Net(nn.Module):
    def __init__(self, num_classes=1, add_channel=0):
        super().__init__()

        self.preprocess = Transform()

        self.model = timm.create_model('vgg16',
                                       pretrained=True,
                                       in_chans=19)

        self.avg_pooling = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(512, num_classes, bias=True)
        self.add_sleep_feature = MLP(1)
        weight_init(self.fc)

    def forward(self, x):
        # do preprocess
        # print("FORWARD 1")
        bs = x.size(0)
        # print("FORWARD 2")
        x = self.preprocess(x)
        # print("FORWARD 3")
        x = self.model.forward_features(x)
        # print("FORWARD 4")
        fm = self.avg_pooling(x)
        # print("FORWARD 5")
        fm = fm.view(bs, -1)
        # print("FORWARD 6")
        feature = self.dropout(fm)
        # print("FORWARD 7")
        x = self.fc(feature)
        # print("FORWARD 8")

        return x


### AlaskaDataIter

In [13]:
class AlaskaDataIter():
    def __init__(self, df,
                 training_flag=True, shuffle=True):

        self.training_flag = training_flag
        self.shuffle = shuffle
        self.raw_data_set_size = None

        self.df = df
        logger.info(' contains%d samples  %d pos' % (len(self.df), np.sum(self.df['target'] == 1)))
        logger.info(' contains%d samples' % len(self.df))

        logger.info(' After filter contains%d samples  %d pos' % (len(self.df), np.sum(self.df['target'] == 1)))
        logger.info(' After filter contains%d samples' % len(self.df))

        self.leads_nm = ['Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2', 'F7', 'F8', 'T3', 'T4', 'T5',
                         'T6',
                         'Fz', 'Cz', 'Pz',
                         'PG1', 'PG2', 'A1', 'A2',
                         'ECG1', 'ECG2', 'EMG1',
                         'EMG2', 'EMG3', 'EMG4']

        self.leads_dict = {value: index for index, value in enumerate(self.leads_nm)}

        self.left_brain = ['Fp1', 'F3', 'C3', 'P3', 'O1', 'T5', 'T3', 'F7']
        self.right_brain = ['Fp2', 'F4', 'C4', 'P4', 'O2', 'T6', 'T4', 'F8']

    def filter(self, df):

        df = copy.deepcopy(df)
        pos_indx = df['target'] == 1
        pos_df = df[pos_indx]

        neg_indx = df['target'] == 0
        neg_df = df[neg_indx]

        neg_df = neg_df.sample(frac=1)

        dst_df = neg_df
        for i in range(1):
            dst_df = dst_df._append(pos_df)
        dst_df.reset_index()

        return dst_df

    def __getitem__(self, item):
        fname = self.df.iloc[item]['file_path']
        label = self.df.iloc[item]['target']
        try:
            fname = fname.strip()
            waves = np.load(fname)
            # print("Init waves (Shape): " , waves.shape)
        except Exception as e:
            print("=====fname====exception:", fname, e)
            waves = np.zeros(shape=[29, 2000])
            label = 0

        waves = self.norm(waves)
        # print("After Norm - waves (Shape): " , waves.shape)

        # Normalize :19 channels
        waves = copy.deepcopy(waves)
        meadn = np.mean(waves[:19, :], axis=0)
        avg_lead = waves[:19, :] - meadn

        # print('avg_lead (Shape):', avg_lead.shape)

        if self.training_flag and random.uniform(0, 1) < 1.:
            waves[:19, :] = self.xshuffle(waves[:19, :])
            avg_lead = self.xshuffle(avg_lead)
        # print("After Shuffle - waves (Shape): " , waves.shape, " avg_lead (Shape): " , avg_lead.shape)
        # waves = np.concatenate([waves, avg_lead], axis=0)

        label = np.expand_dims(label, -1)

        C, L = waves.shape
        # print("After Concat (Shape): " , waves.shape, " avg_lead (Shape): " , avg_lead.shape)

        waves = waves[:19, ...]
        if L < 2000:
            waves = np.pad(waves, ((0, 0), (0, 2000 - L)), 'constant', constant_values=0)
        elif L > 2000:
            waves = waves[:, 2000]
        waves = np.ascontiguousarray(waves)
        # print("Final (Shape): " , waves.shape, " avg_lead (Shape): " , avg_lead.shape)

        return waves, label

    def __len__(self):
        return len(self.df)

    def norm(self, wave):

        wave[:23, ...] = wave[:23, ...] / 1e-3
        wave[23:, ...] = wave[23:, ...] / 1e-2
        return wave

        # heart_wave = wave[23, :] - wave[24, :]
        # muscle_wave1 = wave[25, :] - wave[26, :]
        # muscle_wave2 = wave[27, :] - wave[28, :]
        # heart_muscle = np.stack([heart_wave, muscle_wave1, muscle_wave2], axis=0)
        # wave_26 = np.concatenate([wave[:23, ...], heart_muscle], axis=0)
        # return wave_26

    def xshuffle(self, wave):

        n_channels, n_samples = wave.shape
        channel_indices = np.arange(n_channels)
        np.random.shuffle(channel_indices)
        shuffled_wave = wave[channel_indices]
        return shuffled_wave


### AverageMeter

In [14]:
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class ROCAUCMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):

        self.y_true_11 = None
        self.y_pred_11 = None

    def update(self, y_true, y_pred):
        y_true = y_true.cpu().numpy()

        y_pred = torch.sigmoid(y_pred).data.cpu().numpy()

        if self.y_true_11 is None:
            self.y_true_11 = y_true
            self.y_pred_11 = y_pred
        else:
            self.y_true_11 = np.concatenate((self.y_true_11, y_true), axis=0)
            self.y_pred_11 = np.concatenate((self.y_pred_11, y_pred), axis=0)

        return self.y_true_11, self.y_pred_11

    def fast_auc(self, y_true, y_prob):

        y_true = np.asarray(y_true)
        y_true = y_true[np.argsort(y_prob)]
        cumfalses = np.cumsum(1 - y_true)
        nfalse = cumfalses[-1]
        auc = (y_true * cumfalses).sum()

        auc /= (nfalse * (len(y_true) - nfalse))
        return auc

    @property
    def avg(self):

        self.y_true_11 = self.y_true_11.reshape(-1)
        self.y_pred_11 = self.y_pred_11.reshape(-1)
        score = self.fast_auc(self.y_true_11, self.y_pred_11)

        return score

    def evaluate(y_true, y_pred, digits=4, cutoff='auto'):

        if cutoff == 'auto':
            fpr, tpr, thresholds = roc_curve(y_true, y_pred)
            youden = tpr - fpr
            cutoff = thresholds[np.argmax(youden)]

        return cutoff

    def report(self):

        self.y_true_11 = self.y_true_11.reshape(-1)
        self.y_pred_11 = self.y_pred_11.reshape(-1)

        for score in range(1, 20):
            score = score / 20
            y_pre = self.y_pred_11 > score

            tn, fp, fn, tp = confusion_matrix(self.y_true_11, y_pre).ravel()

            precision = precision_score(self.y_true_11, y_pre)
            recall = recall_score(self.y_true_11, y_pre)
            f1 = f1_score(self.y_true_11, y_pre)

            print('for threshold: %.4f, tn: %d,fp: %d,fn: %d,tp: %d,precision: %.4f, '
                  'recall: %.4f, f1: %.4f' % (score, tn, fp, fn, tp, precision, recall, f1))

        return score





### Train

In [15]:


class Train(object):

    def __init__(self,
                 train_df,
                 val_df,
                 fold):
        self.train_df = train_df

        self.train_generator = AlaskaDataIter(train_df, training_flag=True, shuffle=False)

        self.train_ds = DataLoader(self.train_generator,
                                   config.TRAIN.batch_size,
                                   num_workers=config.TRAIN.process_num, shuffle=True)

        self.val_generator = AlaskaDataIter(val_df, training_flag=False, shuffle=False)

        self.val_ds = DataLoader(self.val_generator,
                                 config.TRAIN.validatiojn_batch_size,
                                 num_workers=config.TRAIN.process_num, shuffle=False)

        self.device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

        self.fold = fold

        self.init_lr = config.TRAIN.init_lr
        self.warmup_step = config.TRAIN.warmup_step
        self.epochs = config.TRAIN.epoch
        self.batch_size = config.TRAIN.batch_size
        self.l2_regularization = config.TRAIN.weight_decay_factor
        self.early_stop = config.MODEL.early_stop
        self.accumulation_step = config.TRAIN.accumulation_batch_size // config.TRAIN.batch_size
        self.gradient_clip = config.TRAIN.gradient_clip
        self.is_base = config.is_base
        self.save_dir = config.MODEL.model_path
        self.fp16 = config.TRAIN.mix_precision

        channel_num = 0
        self.model = Net(add_channel=channel_num).to(self.device)
        self.load_weight()

        if 'Adamw' in config.TRAIN.opt:
            self.optimizer = torch.optim.AdamW(self.model.parameters(),
                                               lr=self.init_lr, eps=1.e-5,
                                               weight_decay=self.l2_regularization)
        else:
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             lr=self.init_lr,
                                             momentum=0.9,
                                             weight_decay=self.l2_regularization)

        self.model = torch.nn.DataParallel(self.model)

        self.iter_num = 0

        if config.TRAIN.lr_scheduler == 'cos':
            logger.info('lr_scheduler.CosineAnnealingLR')
            self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer,
                                                                        self.epochs,
                                                                        eta_min=1.e-7)
        else:
            logger.info('lr_scheduler.ReduceLROnPlateau')
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                                        mode='max',
                                                                        patience=5,
                                                                        min_lr=1e-7,
                                                                        factor=config.TRAIN.lr_scheduler_factor,
                                                                        verbose=True)

        self.criterion = BinaryCrossEntropy(smoothing=0.1, pos_weight=torch.tensor(2.)).to(self.device)

        self.scaler = torch.cuda.amp.GradScaler()

    def custom_loop(self):

        def distributed_train_epoch(epoch_num):

            summary_loss = AverageMeter()
            rocauc_score = ROCAUCMeter()
            self.model.train()

            for images, label in self.train_ds:

                if epoch_num < 10:
                    # excute warm up in the first epochs
                    if self.warmup_step > 0:
                        if self.iter_num < self.warmup_step:
                            for param_group in self.optimizer.param_groups:
                                param_group['lr'] = self.iter_num / float(self.warmup_step) * self.init_lr
                                lr = param_group['lr']

                            logger.info('warm up with learning rate: [%f]' % (lr))

                start = time.time()

                data = images.to(self.device).float()
                label = label.to(self.device).float()

                batch_size = data.shape[0]

                with torch.cuda.amp.autocast(enabled=self.fp16):
                    predictions = self.model(data)
                    current_loss = self.criterion(predictions, label)

                summary_loss.update(current_loss.detach().item(), batch_size)
                rocauc_score.update(label, predictions)
                self.scaler.scale(current_loss).backward()

                if ((self.iter_num + 1) % self.accumulation_step) == 0:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.gradient_clip, norm_type=2)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad()

                self.iter_num += 1
                time_cost_per_batch = time.time() - start

                images_per_sec = config.TRAIN.batch_size / time_cost_per_batch

                if self.iter_num % config.TRAIN.log_interval == 0:
                    log_message = '[fold %d], ' \
                                  'Train Step %d, ' \
                                  'summary_loss: %.6f, ' \
                                  'time: %.6f, ' \
                                  'speed %d images/persec' % (
                                      self.fold,
                                      self.iter_num,
                                      summary_loss.avg,
                                      time.time() - start,
                                      images_per_sec)
                    logger.info(log_message)

            return summary_loss, rocauc_score

        def distributed_test_epoch(epoch_num):

            rocauc_score = ROCAUCMeter()
            summary_loss = AverageMeter()
            self.model.eval()

            with torch.no_grad():
                for (images, labels) in tqdm(self.val_ds):
                    data = images.to(self.device).float()
                    labels = labels.to(self.device).float()

                    batch_size = data.shape[0]

                    predictions = self.model(data)
                    current_loss = self.criterion(predictions, labels)

                    rocauc_score.update(labels, predictions)
                    summary_loss.update(current_loss.detach().item(), batch_size)

            return rocauc_score, summary_loss

        best_distance = 0.
        not_improvement = 0
        for epoch in range(self.epochs):

            for param_group in self.optimizer.param_groups:
                lr = param_group['lr']
            logger.info('learning rate: [%f]' % (lr))
            t = time.time()

            summary_loss, roc_auc_score = distributed_train_epoch(epoch)
            train_epoch_log_message = '[fold %d], ' \
                                      '[RESULT]: TRAIN. Epoch: %d,' \
                                      ' summary_loss: %.5f,' \
                                      ' time:%.5f' % (
                                          self.fold,
                                          epoch,
                                          summary_loss.avg,
                                          (time.time() - t))
            logger.info(train_epoch_log_message)
            roc_auc_score.report()

            if epoch % config.TRAIN.test_interval == 0:
                roc_auc_score, summary_loss = distributed_test_epoch(epoch)

                val_epoch_log_message = '[fold %d], ' \
                                        '[RESULT]: VAL. Epoch: %d,' \
                                        ' val_loss: %.5f,' \
                                        ' val_roc_auc: %.5f,' \
                                        ' time:%.5f' % (
                                            self.fold,
                                            epoch,
                                            summary_loss.avg,
                                            roc_auc_score.avg,
                                            (time.time() - t))
                logger.info(val_epoch_log_message)
                roc_auc_score.report()

            if config.TRAIN.lr_scheduler == 'cos':
                self.scheduler.step()
            else:
                self.scheduler.step(roc_auc_score.avg)

            # save model
            if not os.access(config.MODEL.model_path, os.F_OK):
                os.mkdir(config.MODEL.model_path)

            #### save the model every end of epoch
            current_model_saved_name = self.save_dir + '/fold%d_epoch_%d_val_rocauc_%.6f_loss_%.6f.pth' % (self.fold,
                                                                                                           epoch,
                                                                                                           roc_auc_score.avg,
                                                                                                           summary_loss.avg)

            logger.info('A model saved to %s' % current_model_saved_name)
            torch.save(self.model.module.state_dict(), current_model_saved_name)

            if summary_loss.avg < best_distance:
                best_distance = summary_loss.avg
                logger.info(' best loss value update as %.6f' % (best_distance))
                logger.info(' bestmodel update as %s' % (current_model_saved_name))
                not_improvement = 0

            else:
                not_improvement += 1

            if not_improvement >= self.early_stop:
                logger.info(' best metric score not improvement for %d, break' % (self.early_stop))
                break

            torch.cuda.empty_cache()

    def load_weight(self):
        if config.MODEL.pretrained_model is not None:
            state_dict = torch.load(config.MODEL.pretrained_model, map_location=self.device)
            self.model.load_state_dict(state_dict, strict=False)


### Folding

In [16]:
n_fold = 5


def get_fold(n_fold=n_fold):
    data = pd.read_csv(config.DATA.data_file)

    folds = data.copy()
    Fold = StratifiedKFold(n_splits=n_fold, shuffle=True, random_state=config.SEED)
    for n, (train_index, val_index) in enumerate(Fold.split(folds, folds['target'])):
        folds.loc[val_index, 'fold'] = int(n)
    return folds


data = get_fold(n_fold)
# print(data.columns)
# data.head()




### Training

In [17]:
import gc

# del trainer
gc.collect()
torch.cuda.empty_cache()

for fold in range(1):
    ###build dataset

    train_ind = data[data['train_val'] == 0].index.values
    train_data = data.iloc[train_ind].copy()
    val_ind = data[data['train_val'] == 1].index.values
    val_data = data.iloc[val_ind].copy()
    trainer = Train(train_df=train_data,
                    val_df=val_data,
                    fold=fold)
    print(trainer.train_generator[100][0].shape)
    # print(trainer.train_generator[100][0][0])
    # print(trainer.val_generator[100][0].shape)
    # print(trainer.val_generator[100][0][0])

    # break
    ### train
    trainer.custom_loop()




[2025-05-09 17:34:05,143] [INFO]  contains20360 samples  1979 pos 
[2025-05-09 17:34:05,143] [INFO]  contains20360 samples 
[2025-05-09 17:34:05,144] [INFO]  After filter contains20360 samples  1979 pos 
[2025-05-09 17:34:05,144] [INFO]  After filter contains20360 samples 
[2025-05-09 17:34:05,144] [INFO]  contains5089 samples  537 pos 
[2025-05-09 17:34:05,144] [INFO]  contains5089 samples 
[2025-05-09 17:34:05,145] [INFO]  After filter contains5089 samples  537 pos 
[2025-05-09 17:34:05,145] [INFO]  After filter contains5089 samples 
[2025-05-09 17:34:06,009] [INFO] Loading pretrained weights from Hugging Face hub (timm/vgg16.tv_in1k) 
[2025-05-09 17:34:06,562] [INFO] [timm/vgg16.tv_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors. 
[2025-05-09 17:34:06,563] [INFO] Converted input conv features.0 pretrained weights from 3 to 19 channel(s) 
[2025-05-09 17:34:06,790] [INFO] lr_scheduler.CosineAnnealingLR 
[2025-05-09 1

(19, 2000)


[2025-05-09 17:34:06,864] [INFO] warm up with learning rate: [0.000000] 
[2025-05-09 17:34:07,649] [INFO] warm up with learning rate: [0.000000] 
[2025-05-09 17:34:07,918] [INFO] warm up with learning rate: [0.000001] 
[2025-05-09 17:34:08,198] [INFO] warm up with learning rate: [0.000001] 
[2025-05-09 17:34:08,476] [INFO] warm up with learning rate: [0.000001] 
[2025-05-09 17:34:08,775] [INFO] warm up with learning rate: [0.000002] 
[2025-05-09 17:34:09,063] [INFO] warm up with learning rate: [0.000002] 
[2025-05-09 17:34:09,353] [INFO] warm up with learning rate: [0.000002] 
[2025-05-09 17:34:09,633] [INFO] warm up with learning rate: [0.000003] 
[2025-05-09 17:34:09,918] [INFO] warm up with learning rate: [0.000003] 
[2025-05-09 17:34:10,233] [INFO] [fold 0], Train Step 10, summary_loss: 0.866561, time: 0.314421, speed 407 images/persec 
[2025-05-09 17:34:10,235] [INFO] warm up with learning rate: [0.000003] 
[2025-05-09 17:34:10,559] [INFO] warm up with learning rate: [0.000004] 
[

for threshold: 0.0500, tn: 4765,fp: 13616,fn: 84,tp: 1895,precision: 0.1222, recall: 0.9576, f1: 0.2167
for threshold: 0.1000, tn: 7900,fp: 10481,fn: 213,tp: 1766,precision: 0.1442, recall: 0.8924, f1: 0.2483
for threshold: 0.1500, tn: 10203,fp: 8178,fn: 369,tp: 1610,precision: 0.1645, recall: 0.8135, f1: 0.2736
for threshold: 0.2000, tn: 11989,fp: 6392,fn: 561,tp: 1418,precision: 0.1816, recall: 0.7165, f1: 0.2897
for threshold: 0.2500, tn: 13270,fp: 5111,fn: 707,tp: 1272,precision: 0.1993, recall: 0.6427, f1: 0.3042
for threshold: 0.3000, tn: 14277,fp: 4104,fn: 826,tp: 1153,precision: 0.2193, recall: 0.5826, f1: 0.3187
for threshold: 0.3500, tn: 15053,fp: 3328,fn: 959,tp: 1020,precision: 0.2346, recall: 0.5154, f1: 0.3224
for threshold: 0.4000, tn: 15764,fp: 2617,fn: 1074,tp: 905,precision: 0.2570, recall: 0.4573, f1: 0.3290
for threshold: 0.4500, tn: 16395,fp: 1986,fn: 1155,tp: 824,precision: 0.2932, recall: 0.4164, f1: 0.3441
for threshold: 0.5000, tn: 16948,fp: 1433,fn: 1259,tp: 7

100%|██████████| 40/40 [00:05<00:00,  7.79it/s]
[2025-05-09 17:34:57,894] [INFO] [fold 0], [RESULT]: VAL. Epoch: 0, val_loss: 0.44460, val_roc_auc: 0.90158, time:51.10180 
[2025-05-09 17:34:57,953] [INFO] A model saved to ./trained_models//fold0_epoch_0_val_rocauc_0.901582_loss_0.444598.pth 


for threshold: 0.0500, tn: 3894,fp: 658,fn: 107,tp: 430,precision: 0.3952, recall: 0.8007, f1: 0.5292
for threshold: 0.1000, tn: 4199,fp: 353,fn: 172,tp: 365,precision: 0.5084, recall: 0.6797, f1: 0.5817
for threshold: 0.1500, tn: 4342,fp: 210,fn: 222,tp: 315,precision: 0.6000, recall: 0.5866, f1: 0.5932
for threshold: 0.2000, tn: 4414,fp: 138,fn: 248,tp: 289,precision: 0.6768, recall: 0.5382, f1: 0.5996
for threshold: 0.2500, tn: 4446,fp: 106,fn: 271,tp: 266,precision: 0.7151, recall: 0.4953, f1: 0.5853
for threshold: 0.3000, tn: 4474,fp: 78,fn: 294,tp: 243,precision: 0.7570, recall: 0.4525, f1: 0.5664
for threshold: 0.3500, tn: 4493,fp: 59,fn: 309,tp: 228,precision: 0.7944, recall: 0.4246, f1: 0.5534
for threshold: 0.4000, tn: 4504,fp: 48,fn: 323,tp: 214,precision: 0.8168, recall: 0.3985, f1: 0.5357
for threshold: 0.4500, tn: 4517,fp: 35,fn: 332,tp: 205,precision: 0.8542, recall: 0.3818, f1: 0.5277
for threshold: 0.5000, tn: 4526,fp: 26,fn: 341,tp: 196,precision: 0.8829, recall: 0.36

In [18]:
# exit(0)

# 4. Validation

In [19]:
def get_data_iter(test_path=config.DATA.data_file):
    data = pd.read_csv(test_path)

    val_ind = data[data['train_val'] == 1].index.values
    val_data = data.iloc[val_ind].copy()

    valds = AlaskaDataIter(val_data, training_flag=False, shuffle=False)
    valds = DataLoader(valds,
                       32,
                       num_workers=2,
                       shuffle=False)
    return valds


def get_model(weight, device, is_base=0):
    channel_num = 0
    if is_base == 0:
        channel_num = 128
    model = Net(add_channel=channel_num).to(device)
    state_dict = torch.load(weight, map_location=device)
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    return model


def eval_add_plt(weight_video, weight_base, test_path):
    rocauc_score = ROCAUCMeter()

    base_y_true, base_y_pre = estimated_score(weight_base, test_path, 1)
    video_y_true, video_y_pre = estimated_score(weight_video, test_path, 0)

    print("========= estimated_score base line  ==========", test_path)
    rocauc_score.report_with_recall_precision(base_y_true, base_y_pre)
    #rocauc_score.report_all(base_y_true, base_y_pre)

    print("========= estimated_score add video ==========", test_path)
    rocauc_score.report_with_recall_precision(video_y_true, video_y_pre)
    #rocauc_score.report_all(video_y_true, video_y_pre)

    print("========= precision_recall ==========", test_path)
    img_path_p_r = test_path.split(".")[0] + "_Precision_Recall__Add_Data_Pre" + ".jpg"
    rocauc_score.report_with_recall(video_y_true, video_y_pre, base_y_true, base_y_pre, img_path_p_r)

    print("========= Specificity_Sensitivity ==========", test_path)
    img_path_t_f = test_path.split(".")[0] + "_Specificity_Sensitivity__Add_Data_Pre" + ".jpg"
    rocauc_score.report_tpr_fpr(video_y_true, video_y_pre, base_y_true, base_y_pre, img_path_t_f)


def estimated_score(weight, test_path, is_base):
    # print("========= estimated_score test_path ==========", test_path)

    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    rocauc_score = ROCAUCMeter()
    model = get_model(weight, device, is_base)
    val_ds = get_data_iter(test_path)

    labels_list = []
    y_pre_list = []

    y_true_11 = None
    y_pred_11 = None

    with torch.no_grad():
        print("val_ds:", val_ds)
        for (images, labels, video_feature) in tqdm(val_ds):
            data = images.to(device).float()
            labels = labels.to(device).float()
            labels_list.append(labels)
            # base_feature = base_feature.to(device).float()
            video_feature = video_feature.to(device).float()
            batch_size = data.shape[0]
            predictions = model(data, video_feature, is_base)
            y_pre_list.append(predictions)
            y_true_11, y_pred_11 = rocauc_score.update(labels, predictions)
            #print("=====y_true_11=====", y_true_11)
            #print("=====y_pred_11=====", y_pred_11)
    # save labels_list and y_pre_list
    labels_data = torch.cat(labels_list, dim=0)
    y_pre_data = torch.cat(y_pre_list, dim=0)
    print("labels len:", len(labels_data.tolist()))
    print("predictions len:", len(y_pre_data.tolist()))

    return y_true_11, y_pred_11


def eval(weight, test_path):
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    rocauc_score = ROCAUCMeter()
    model = get_model(weight, device)
    val_ds = get_data_iter(test_path)

    labels_list = []
    y_pre_list = []

    with torch.no_grad():
        print("val_ds:", val_ds)
        for (images, labels) in tqdm(val_ds):
            data = images.to(device).float()
            labels = labels.to(device).float()

            labels_list.append(labels)
            # batch_size = data.shape[0]
            predictions = model(data)
            # intermediate_output = model.intermediate_layer(data)
            y_pre_list.append(predictions)
            rocauc_score.update(labels, predictions)

        labels_data = torch.cat(labels_list, dim=0)
        y_pre_data = torch.cat(y_pre_list, dim=0)
        rocauc_score.report()

    print("labels len:", len(labels_data.tolist()))
    print("predictions len:", len(y_pre_data.tolist()))

    return rocauc_score




In [21]:
weight = os.path.join(config.MODEL.model_path, "fold0_epoch_0_val_rocauc_0.901582_loss_0.444598.pth")
# test_path = "/run/media/kami/SSD/DATASETS/vepiset-dataset/CSV-Files/data.csv"
test_path = config.DATA.data_file
try:
    eval(weight, test_path)
except Exception as e:
    print("=====e=====", e)


[2025-05-09 17:35:17,500] [INFO] Loading pretrained weights from Hugging Face hub (timm/vgg16.tv_in1k) 
[2025-05-09 17:35:17,758] [INFO] [timm/vgg16.tv_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors. 
[2025-05-09 17:35:17,760] [INFO] Converted input conv features.0 pretrained weights from 3 to 19 channel(s) 
[2025-05-09 17:35:18,093] [INFO]  contains5089 samples  537 pos 
[2025-05-09 17:35:18,094] [INFO]  contains5089 samples 
[2025-05-09 17:35:18,094] [INFO]  After filter contains5089 samples  537 pos 
[2025-05-09 17:35:18,094] [INFO]  After filter contains5089 samples 


val_ds: <torch.utils.data.dataloader.DataLoader object at 0x765c536159d0>


100%|██████████| 160/160 [00:02<00:00, 57.19it/s]

for threshold: 0.0500, tn: 3893,fp: 659,fn: 107,tp: 430,precision: 0.3949, recall: 0.8007, f1: 0.5289
for threshold: 0.1000, tn: 4199,fp: 353,fn: 172,tp: 365,precision: 0.5084, recall: 0.6797, f1: 0.5817
for threshold: 0.1500, tn: 4341,fp: 211,fn: 222,tp: 315,precision: 0.5989, recall: 0.5866, f1: 0.5927
for threshold: 0.2000, tn: 4414,fp: 138,fn: 248,tp: 289,precision: 0.6768, recall: 0.5382, f1: 0.5996
for threshold: 0.2500, tn: 4446,fp: 106,fn: 271,tp: 266,precision: 0.7151, recall: 0.4953, f1: 0.5853
for threshold: 0.3000, tn: 4474,fp: 78,fn: 294,tp: 243,precision: 0.7570, recall: 0.4525, f1: 0.5664
for threshold: 0.3500, tn: 4493,fp: 59,fn: 309,tp: 228,precision: 0.7944, recall: 0.4246, f1: 0.5534
for threshold: 0.4000, tn: 4504,fp: 48,fn: 323,tp: 214,precision: 0.8168, recall: 0.3985, f1: 0.5357
for threshold: 0.4500, tn: 4517,fp: 35,fn: 332,tp: 205,precision: 0.8542, recall: 0.3818, f1: 0.5277
for threshold: 0.5000, tn: 4526,fp: 26,fn: 341,tp: 196,precision: 0.8829, recall: 0.36


