# Step 2: Train and infer

## Import libs

In [None]:
import numpy as np
import pandas as pd
import time
import h5py
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from glob import glob
import gc
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from timm.scheduler import CosineLRScheduler

import os
# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on a machine without CUDA instead of failing later at `model.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Config

In [None]:
class CFG:
    """Hyper-parameters shared by the training and inference cells."""

    # reproducibility
    seed = 42

    # data loading
    batch_size = 4
    num_workers = 20
    img_width = 120  # length of the time axis after average pooling

    # optimizer
    lr = 1e-4
    weight_decay = 1e-8

    # schedule (all values are expressed in epochs)
    epochs = 3
    epochs_warmup = 1
    start_eval_epoch = 0

def seed_everything(seed):
    """Seed every random number generator this notebook can draw from.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), then switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Local import: the file does not import `random` at the top.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN auto-tuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed all random number generators once, before any data or model is created.
seed_everything(CFG.seed)

# Root directory of the competition data. The CSV files and the per-split
# HDF5 folders used below are all resolved relative to this path.
DIR = './input/g2net-detecting-continuous-gravitational-waves'

## Dataset

In [None]:
# training set - generated dataset
# Only the file path is needed per row; `target` is kept as an (all-NaN)
# column so the frame has the same columns as before.
# `sorted` makes the row order independent of the filesystem listing order
# (which `glob` does not guarantee), so a fixed seed gives the same shuffle.
# The frame is built in one step instead of one `pd.concat` per file, which
# was quadratic in the number of files.
signal_paths = sorted(glob("./input/generated-data/signal*.npz"))
df_train = pd.DataFrame({"path": signal_paths}, columns=["path", "target"])

# validation set - given train set
df_val = pd.read_csv(DIR + '/train_labels.csv')
df_val = df_val[df_val.target >= 0]  # Remove 3 unknowns (target = -1)

In [None]:
class TrainDataset(Dataset):
    """Generated training set; every item yields one signal and one noise image.

    Each row of ``df`` points (column ``path``) to an ``.npz`` archive holding
    the arrays ``H1`` and ``L1`` (signal) and ``H1_noise`` and ``L1_noise``
    (noise only), each of shape ``(N_ROWS, N_COLS)``. Columns are timestamps.

    ``__getitem__`` returns ``(img, y, img_noise, y_noise)``: two float32
    tensors of shape ``(2, N_ROWS, CFG.img_width)`` with labels 1.0 and 0.0.
    All augmentations draw from NumPy's global random generator.
    """

    N_ROWS = 360                    # rows per spectrogram (presumably frequency bins)
    N_COLS = 5760                   # timestamps per spectrogram before pooling
    MASK_KEEP_RANGE = (4300, 4700)  # [low, high) number of timestamps kept by the mask

    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def _extra_noise(self, shape):
        """Draw chi-square noise scaled by a random per-timestamp factor in [0, 0.5)."""
        return np.random.chisquare(1, size=shape) * np.random.uniform(0, 0.5, size=(1, self.N_COLS))

    @staticmethod
    def _normalize(x):
        """Scale every timestamp (column) to unit mean, then standardize the array."""
        x = x / x.mean(axis=0)
        return (x - x.mean()) / x.std()

    def _random_mask(self, src, dst):
        """Copy a random subset of columns from ``src`` into ``dst``.

        Columns that are not selected keep the zeros already in ``dst``,
        simulating the missing timestamps of the test set.
        """
        low, high = self.MASK_KEEP_RANGE
        n_keep = np.random.randint(low, high, 1)[0]
        cols = np.random.choice(self.N_COLS, n_keep, replace=False)
        dst[:, cols] = src[:, cols]

    def _build_image(self, h1, l1):
        """Normalize, mask and pool one (H1, L1) pair into a 2-channel image."""
        img = np.zeros([2, self.N_ROWS, self.N_COLS], dtype=np.float32)
        self._random_mask(self._normalize(h1), img[0])
        self._random_mask(self._normalize(l1), img[1])
        # avg pooling (reducing the length of the time axis to CFG.img_width)
        return img.reshape(2, self.N_ROWS, CFG.img_width, -1).mean(axis=3)

    @staticmethod
    def _random_flip(img):
        """Flip axis 1 and axis 2 independently, each with probability 0.5."""
        for axis in (1, 2):
            if np.random.rand() <= 0.5:
                img = np.flip(img, axis=axis).copy()
        return img

    def __getitem__(self, i):
        path = self.df.iloc[i]["path"]

        # `np.load` keeps the .npz archive open until it is closed; the
        # context manager releases the file handle as soon as the arrays
        # have been read into memory.
        with np.load(path) as data:
            H1, L1 = data["H1"], data["L1"]
            H1_noise, L1_noise = data["H1_noise"], data["L1_noise"]

        # additional random noise injection (chi-square distribution); the
        # same extra noise is added to the signal and to the noise sample
        H1_noise_add = self._extra_noise(H1.shape)
        L1_noise_add = self._extra_noise(L1.shape)

        img = self._build_image(H1 + H1_noise_add, L1 + L1_noise_add)
        img_noise = self._build_image(H1_noise + H1_noise_add, L1_noise + L1_noise_add)

        # flip augmentation, drawn independently for the two images
        img = self._random_flip(img)
        img_noise = self._random_flip(img_noise)

        # numpy to tensor
        img = torch.from_numpy(img)
        img_noise = torch.from_numpy(img_noise)

        # labels: 1 for the signal image, 0 for the noise image
        y = torch.tensor(1.0, dtype=torch.float32)
        y_noise = torch.tensor(0.0, dtype=torch.float32)

        return img, y, img_noise, y_noise

In [None]:
# some abnormal data (these anomalies are detected based on signal variance analysis)
# Sample ids whose H1 detector data is treated as abnormal. `InferDataset`
# checks membership with `in` and, on a match, replaces the H1 data and
# timestamps with copies of the L1 ones.
# NOTE: 'd809d13d5' is listed twice (first and last entry); harmless for
# membership tests. The misspelled names are kept because they are referenced
# elsewhere in the file.
H1_anomoly = [
    'd809d13d5', '24b4ea622', '907a93301', 'f2eeb89ff', '0fd7c7cee', 'ef9c0beaf', 'd86eea9ac', 'b383cdb39', '1b5941ca4', '3cc6680fb',
    '0153a00c9', '3ab95152d', 'd6828b59a', 'bc1ca0c9f', '9d340855d', '2faf23a2a', '41fe70359', '5ffc75014', 'c0e19f82e', '6735074ac',
    '728fa2106', '804df5390', '3fca6a63f', '4ccda6107', '19a3195b9', 'd2a943911', 'ef1a6ac39', '612200ec4', 'e5f7c1840', '96264d064',
    'e67f75b59', '71b40b311', 'b0990161d', 'c447eb379', '09d7ea37a', '55473f041', 'c103523ac', 'acc728828', '239afe0e3', '5bd485ddd',
    'ddd8c6e90', '9321b08d8', '0542c5ed2', '8e1b55c92', '81d69558d', 'f0900d441', 'ebce3dbfd', '64a563381', '4ef36d800', 'e2256cda9',
    '7c7d5b29c', '711f82733', '400b94859', 'becdfa440', 'd65d8383c', 'db965ed0a', '52186fb05', '24172ff03', 'b69313a43', '126790a29',
    'b3e00a24b', '1d33fd108', 'fe3005e83', '68cfdceb2', 'b86b67f5c', '4ca95032a', 'e08f4e117', '7341591b0', '53f6fb48c', '80b695868',
    'a8b362d98', '040b35321', '5ffa7a8f4', 'f5860cd63', 'e98acc4de', '5e2305d9d', 'b0dcbccc5', 'df7f170d3', 'c5cd03dc9', '81fd533ac',
    'b1ce28bbb', 'a13486580', '895b27680', 'b3bb15de7', 'eb749c00e', '698567d90', 'df65f4148', 'ebeb1ca65', '41154b5f0', 'bbf289347',
    '8da264074', '081ee0aea', '5b3eb27e1', 'b95698658', '87c1d57ad', '6ef79f3e4', 'd9ef85811', 'f2ffd991a', '0fc3c449f', '524b0c283', 
    '55dd4d584', '0bf7da48d', 'a733a9c7e', 'c86841ef0', 'ceafe2326', '7ee0a00f8', 'f3f739a5a', 'e4d6595d9', '9c060f0f4', '4768dd659',
    'c1b7030ac', 'd64a4a759', 'bf09dc4c2', 'fd8617a36', 'd93435874', '71051c674', 'a9e280d75', '34b6b7a85', 'b3f06a2d7', 'dc5116c0e',
    '9f37c3fde', 'a666b93a9', '1e36242e8', '91d35ca0e', '38a84a185', '0a495f928', '2553cfdf2', '56b090eaf', '67e294a77', '2c6117e69',
    'd1da4fc07', 'd809d13d5'
]

# Sample ids whose L1 detector data is treated as abnormal; on a match the L1
# data and timestamps are replaced with copies of the H1 ones.
# NOTE: 'bc1ca0c9f' also appears in `H1_anomoly`. `InferDataset` tests the H1
# list first in an if/elif, so for that id only the H1 replacement is applied.
L1_anomoly = [
    'bc1ca0c9f', 'd75b29ee3', '8b180f74f', '025517630', '63b51b240', 'f00d2044a', '22f3c3f98', '308417080', '542bf15f1', 'd07399f8e', 
    'c6e6c32b9', 'a1f9b8e82', 'dc2aaaee9', '575f47724', 
]

In [None]:
# the timestamps alignment method is from: https://www.kaggle.com/code/laeyoung/g2net-large-kernel-inference
class InferDataset(torch.utils.data.Dataset):
    """Dataset over the competition HDF5 files, used for validation and test.

    Args:
        data_type: Sub-directory of ``DIR`` that holds the ``<id>.hdf5`` files.
        df: DataFrame with the columns ``id`` and ``target``.
        tta: Test-time augmentation mode: 0 = none, 1 = flip axis 1,
            2 = flip axis 2, 3 = flip both axes.

    ``__getitem__`` returns ``(img, y)``: a float32 tensor of shape
    ``(2, 360, CFG.img_width)`` and the row's target as ``np.float32``.
    """

    def __init__(self, data_type, df, tta=0):
        self.data_type, self.df, self.tta = data_type, df, tta

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        y = np.float32(row.target)
        file_id = row.id
        filename = '%s/%s/%s.hdf5' % (DIR, self.data_type, file_id)

        # Read everything needed from the file first, then process in memory.
        with h5py.File(filename, 'r') as f:
            fid = os.path.splitext(os.path.split(filename)[1])[0]
            group = f[fid]
            # GPS timestamps are converted to indices on an 1800-second grid.
            h_idx = (np.asarray(group["H1"]["timestamps_GPS"]) / 1800).round().astype(np.int64)
            l_idx = (np.asarray(group["L1"]["timestamps_GPS"]) / 1800).round().astype(np.int64)
            h_amp = group["H1"]["SFTs"][:] * 1e22
            l_amp = group["L1"]["SFTs"][:] * 1e22

        # Align both detectors on a common origin (earliest timestamp -> 0).
        origin = min(h_idx.min(), l_idx.min())
        h_idx -= origin
        l_idx -= origin

        # Magnitude of the complex SFT values.
        h_amp = np.sqrt(h_amp.real**2 + h_amp.imag**2)
        l_amp = np.sqrt(l_amp.real**2 + l_amp.imag**2)

        # Replace an abnormal detector with a copy of the other one.
        if file_id in H1_anomoly:
            h_amp, h_idx = l_amp.copy(), l_idx.copy()
        elif file_id in L1_anomoly:
            l_amp, l_idx = h_amp.copy(), h_idx.copy()

        # Per-timestamp scaling to unit mean, then global standardization.
        h_amp /= h_amp.mean(axis=0)
        l_amp /= l_amp.mean(axis=0)
        h_amp = (h_amp - h_amp.mean()) / h_amp.std()
        l_amp = (l_amp - l_amp.mean()) / l_amp.std()

        # Scatter the columns onto the fixed time grid; grid slots without a
        # measurement stay 0 and timestamps beyond the grid are dropped.
        img = np.zeros([2, 360, 5760], dtype=np.float32)
        keep = l_idx < 5760
        img[1][:, l_idx[keep]] = l_amp[:, keep]
        keep = h_idx < 5760
        img[0][:, h_idx[keep]] = h_amp[:, keep]

        # Average pooling along the time axis down to CFG.img_width columns.
        img = img.reshape(2, 360, CFG.img_width, -1).mean(axis=3)

        if np.isnan(img).any():
            print('nan in img')

        if self.tta == 1:
            img = np.flip(img, axis=1).copy()
        elif self.tta == 2:
            img = np.flip(img, axis=2).copy()
        elif self.tta == 3:
            img = np.flip(img, axis=(1, 2)).copy()

        return torch.from_numpy(img), y

## Model

In [None]:
# a plain CNN backbone is used as-is; fancy model modification tricks did not work for me in this competition
class Model(nn.Module):
    """Thin wrapper around a timm backbone with 2 input channels and 1 output.

    Args:
        model_name: Architecture name understood by ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``.
    """

    def __init__(self, model_name, pretrained=False):
        super().__init__()
        backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            in_chans=2,     # the input images have two channels
            num_classes=1,  # a single output value per image
        )
        self.model = backbone

    def forward(self, x):
        """Return the backbone output for the batch ``x``."""
        return self.model(x)

## Loss function

In [None]:
# focal loss
class BCEFocalLoss(nn.Module):
    """Binary focal loss computed on raw logits.

    Per element: ``alpha * (1 - p_t) ** gamma * bce`` where ``bce`` is the
    binary cross-entropy with logits and ``p_t = exp(-bce)``.

    Args:
        alpha: Constant factor applied to every element.
        gamma: Exponent of the ``(1 - p_t)`` modulating factor.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-element loss.
    """

    def __init__(self, alpha=0.5, gamma=3, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, input, target):
        bce = F.binary_cross_entropy_with_logits(input, target, reduction='none')
        p_t = torch.exp(-bce)
        modulating = (1 - p_t) ** self.gamma
        per_element = self.alpha * modulating * bce

        if self.reduction == 'mean':
            return per_element.mean()
        if self.reduction == 'sum':
            return per_element.sum()
        return per_element


# Default criterion is plain BCE on logits; the alternative would be
# criterion = BCEFocalLoss()
criterion = nn.BCEWithLogitsLoss()

## Inference function

In [None]:
# from: https://www.kaggle.com/code/leolu1998/g2net-basic-audio-data-augmentation-inference
@torch.no_grad()
def inference(model, loader_val, compute_score=True, pbar=None):
    """Run validation or inference over a loader.

    Uses the module-level ``criterion`` (looked up at call time) and
    ``device``.

    Args:
        model: Module mapping an image batch to one logit per sample.
        loader_val: Iterable yielding ``(img, y)`` batches.
        compute_score: If True, compute the ROC AUC of the predictions
            against ``y``; otherwise ``score`` is None.
        pbar: None for no progress bar, otherwise the total number of
            samples, used as the progress-bar length.

    Returns:
        dict with ``loss`` (sample-weighted mean loss), ``score``, ``y``
        (1-D array of targets), ``y_pred`` (1-D array of sigmoid
        probabilities) and ``time`` (elapsed seconds).
    """
    tb = time.time()
    model.eval()
    loss_sum = 0.0
    n_sum = 0
    y_all = []
    y_pred_all = []

    # NOTE(review): `nrows` is tqdm's screen-height option; `ncols` (bar
    # width) may have been intended. Kept unchanged; confirm.
    progress = None if pbar is None else tqdm(desc='Predict', nrows=78, total=pbar)

    try:
        for img, y in loader_val:
            n = y.size(0)
            img = img.to(device)
            y = y.to(device)
            # Flatten to shape (n,). `.squeeze()` would collapse a final
            # batch of one sample to a 0-d array, which `np.concatenate`
            # rejects.
            logits = model(img).view(-1)
            loss = criterion(logits, y)
            n_sum += n
            loss_sum += n * loss.item()
            y_all.append(y.cpu().numpy())
            y_pred_all.append(logits.sigmoid().cpu().numpy())
            if progress is not None:
                progress.update(len(img))

            gc.collect()
    finally:
        # Always release the progress bar, even if a batch fails.
        if progress is not None:
            progress.close()

    loss_val = loss_sum / n_sum
    y = np.concatenate(y_all)
    y_pred = np.concatenate(y_pred_all)

    score = roc_auc_score(y, y_pred) if compute_score else None

    ret = {'loss': loss_val,
           'score': score,
           'y': y,
           'y_pred': y_pred,
           'time': time.time() - tb}

    return ret

## Train

In [None]:
# Architecture name handed to `Model`, which forwards it to `timm.create_model`.
model_name = 'convnext_base'

In [None]:
# dataset
dataset_train = TrainDataset(df_train) # the generated dataset is used for training
dataset_val = InferDataset('train', df_val) # the given "train data" is used for validation


def _seed_worker(worker_id):
    """Give each DataLoader worker process its own NumPy seed.

    PyTorch assigns every worker a distinct torch seed, but NumPy's global
    generator is copied unchanged into each worker. Without re-seeding, all
    workers (and all epochs) would draw the same `np.random` augmentations
    in `TrainDataset`.
    """
    np.random.seed(torch.initial_seed() % 2**32)


loader_train = DataLoader(dataset_train, batch_size=CFG.batch_size,
                    num_workers=CFG.num_workers, pin_memory=True, shuffle=True, drop_last=True,
                    worker_init_fn=_seed_worker)

loader_val = DataLoader(dataset_val, batch_size=CFG.batch_size,
                    num_workers=CFG.num_workers, pin_memory=True)


model = Model(model_name, pretrained=True)
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay)

# learning rate schedule: linear warmup followed by cosine decay; all
# lengths are counted in optimizer steps, not epochs
n_batch = len(loader_train)
warmup = CFG.epochs_warmup * n_batch
n_steps = CFG.epochs * n_batch
scheduler = CosineLRScheduler(optimizer,
                warmup_t=warmup, warmup_lr_init=0.0, warmup_prefix=True,
                t_initial=(n_steps - warmup), lr_min=1e-6)

time_val = 0.0
lrs = []

tb = time.time()
best_score = 0.0
for iepoch in range(CFG.epochs):

    # use BCEWithLogitsLoss for the first epoch, then use BCEFocalLoss
    # (`criterion` is module-level, so `inference` uses the same loss)
    if iepoch == 0:
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = BCEFocalLoss()

    loss_sum = 0.0
    n_sum = 0

    # train the model using generated training instances
    model.train()
    for ibatch, (img, y, img_noise, y_noise) in enumerate(tqdm(loader_train)):
        n = y.size(0)
        img = img.to(device)
        y = y.to(device)
        img_noise = img_noise.to(device)
        y_noise = y_noise.to(device)
        img = torch.cat([img, img_noise], dim=0) # distinguish signal and noise in the same batch
        y = torch.cat([y, y_noise], dim=0)
        optimizer.zero_grad()
        y_pred = model(img)

        loss = criterion(y_pred.view(-1), y)
        loss_train = loss.item()
        loss_sum += n * loss_train
        n_sum += n
        loss.backward()
        optimizer.step()
        # the scheduler is stepped once per batch with the global step count
        scheduler.step(iepoch * n_batch + ibatch + 1)
        lrs.append(optimizer.param_groups[0]['lr'])

    # evaluate the model using the given "train data"
    val = inference(model, loader_val)
    time_val += val['time']
    loss_train = loss_sum / n_sum
    lr_now = optimizer.param_groups[0]['lr']
    dt = (time.time() - tb) / 60
    print('Epoch %d, lr %.6f, train loss %.4f, val loss %.4f, val score %.4f, time %.2f min' %
            (iepoch, lr_now, loss_train, val['loss'], val['score'], dt))

    # keep the checkpoint with the best validation score
    if iepoch >= CFG.start_eval_epoch and val['score'] > best_score:
        best_score = val['score']
        torch.save(model.state_dict(), 'model_best.pth')

    # save the model every 5 epochs
    if (iepoch + 1) % 5 == 0:
        torch.save(model.state_dict(), 'model_epoch%d.pth' % iepoch)

    gc.collect()

# save the model of the last epoch
torch.save(model.state_dict(), 'model_last.pth')

dt = time.time() - tb
print('time %.2f min, best score %.4f' % (dt / 60, best_score))
print('\n')

## Inference with test time augmentation (TTA)

In [None]:
test = pd.read_csv(DIR + '/sample_submission.csv')
# `InferDataset` reads `target` as the label of every row; the test set has
# no labels, so a constant float placeholder is used while predicting.
test['target'] = 0.0

model = Model(model_name, pretrained=False)
# use 'model_best.pth' instead to load the best-validation checkpoint
filename = 'model_last.pth'
model.to(device)
model.load_state_dict(torch.load(filename, map_location=device))
model.eval()

# employ horizontal and vertical flips in TTA
n_tta = 4  # `InferDataset` implements the tta modes 0..3
# Accumulate the averaged predictions in a separate array, so the `target`
# column read by the dataset is not overwritten with partial predictions
# between TTA passes.
pred = np.zeros(len(test), dtype=np.float64)
for i in range(n_tta):
    dataset_test = InferDataset('test', test, tta=i)
    loader_test = DataLoader(dataset_test, batch_size=CFG.batch_size,
                                            num_workers=CFG.num_workers, pin_memory=True)
    test_ = inference(model, loader_test, compute_score=False, pbar=len(test))
    pred += test_['y_pred'] / n_tta

test['target'] = pred
test.to_csv('submission.csv', index=False)
print('target range [%.2f, %.2f]' % (test['target'].min(), test['target'].max()))