In [None]:
import os
import pandas as pd
import os, gc
import numpy as np
from sklearn.model_selection import KFold

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import neptune
from neptune.utils import stringify_unsupported

In [None]:
# Work around a fastai limitation: let fp16 (mixed-precision) training handle predictions that are dictionaries

import torch
from fastai.vision.all import *
def flatten(o):
    """Lazily yield the leaf items of a (possibly nested) collection.

    - When `o` is a dict, its values are yielded as they are (one per key,
      in iteration order) without descending into them.
    - String items are yielded whole.
    - Any other item is flattened recursively; if iterating it raises
      `TypeError`, the item itself is yielded as a leaf.
    """
    if isinstance(o, dict):
        # Dict values are handed back untouched, never recursed into.
        for key in o:
            yield o[key]
        return

    for element in o:
        if isinstance(element, str):
            yield element
        else:
            try:
                yield from flatten(element)
            except TypeError:
                # Not iterable: treat it as a leaf.
                yield element

from torch.cuda.amp import GradScaler, autocast
@delegates(GradScaler)
class MixedPrecision(Callback):
    """Mixed-precision training callback built on PyTorch `autocast` and `GradScaler`.

    Keyword arguments passed to the constructor are stored and forwarded to
    `GradScaler` when fitting starts. The callback also poses as an optimizer
    (`param_groups` / `step`) so that `GradScaler.step` can tell it whether the
    real optimizer step should go ahead.
    """
    order = 10

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def before_fit(self):
        "Create the autocast context, the gradient scaler and the list of recorded scales."
        self.autocast, self.learn.scaler, self.scales = (
            autocast(),
            GradScaler(**self.kwargs),
            L(),
        )

    def before_batch(self):
        "Enter the autocast context for this batch."
        self.autocast.__enter__()

    def after_pred(self):
        "If the first flattened prediction is float16, pass the predictions through `to_float`."
        first_output = next(flatten(self.pred))
        if first_output.dtype == torch.float16:
            self.learn.pred = to_float(self.pred)

    def after_loss(self):
        "Leave the autocast context once the loss has been computed."
        self.autocast.__exit__(None, None, None)

    def before_backward(self):
        "Replace `loss_grad` with its scaled version before the backward pass."
        self.learn.loss_grad = self.scaler.scale(self.loss_grad)

    def before_step(self):
        "Hand `self` to `GradScaler.step` as a stand-in optimizer; cancel the step if `step` was not called."
        # `skipped` stays True unless the scaler invokes our fake `step` below.
        self.skipped = True
        self.scaler.step(self)
        if self.skipped:
            raise CancelStepException()
        self.scales.append(self.scaler.get_scale())

    def after_step(self):
        "Let the scaler update its scale factor."
        self.learn.scaler.update()

    @property
    def param_groups(self):
        "Expose the real optimizer's parameter groups so `GradScaler` can treat `self` as an optimizer."
        return self.opt.param_groups

    def step(self, *args, **kwargs):
        "Fake optimizer step: being called at all means `GradScaler` did not skip this batch."
        self.skipped = False

    def after_fit(self):
        "Drop the autocast context, the scaler and the recorded scales."
        self.autocast = None
        self.learn.scaler = None
        self.scales = None
        
import fastai
# Swap the class stored on fastai's fp16 module for the MixedPrecision defined
# in this cell, so lookups through `fastai.callback.fp16` resolve to it.
# NOTE(review): presumably this is what makes `Learner.to_fp16()` pick up the
# patched callback — confirm against the installed fastai version.
fastai.callback.fp16.MixedPrecision = MixedPrecision

In [None]:
import sys
import argparse
from copy import copy
import importlib

In [None]:
# Root folder that contains the project sub-folders; an alternative location
# is kept in the trailing comment.
BASEDIR= './'#'../input/asl-fingerspelling-config'
# Put each project sub-folder on sys.path so its modules can be imported by
# bare name. Entries already present are skipped, so re-running this cell
# does not pile up duplicate sys.path entries.
for DIRNAME in 'configs data models postprocess metrics utils repos'.split():
    _search_path = f'{BASEDIR}/{DIRNAME}/'
    if _search_path not in sys.path:
        sys.path.append(_search_path)

In [None]:
# Argument parser for the run's command-line options (config module name, GPU id).
parser = argparse.ArgumentParser(description="")

In [None]:
#importlib.import_module(parser_args.config)

In [None]:
# Build the parser inside this cell so the cell can be re-run safely: calling
# add_argument a second time on the same parser raises a conflicting-option
# ArgumentError.
parser = argparse.ArgumentParser(description="")
parser.add_argument("-C", "--config", help="config filename", default="cfg_0")
parser.add_argument("-G", "--gpu_id", default="", help="GPU ID")
# parse_known_args tolerates unrecognised arguments and returns them in
# other_args. The full sys.argv is passed on purpose, so other_args[0] is the
# program name (sys.argv[0]) and any extra options follow it.
parser_args, other_args = parser.parse_known_args(sys.argv)
#cfg = copy(importlib.import_module(parser_args.config).cfg)

In [None]:
# Load the run configuration: a shallow copy of the `cfg` object exported by
# the `cfg_fastai` module.
# NOTE(review): the module name is hard-coded, so the -C/--config value parsed
# earlier is not used here — confirm this is intended.
cfg = copy(importlib.import_module('cfg_fastai').cfg)

In [None]:
# Training table; its location comes from the config.
df = pd.read_parquet(cfg.train_df)

# Resolve the project components by the module names stored in the config.
# The dataset module provides three classes, so it is looked up once.
_dataset_module = importlib.import_module(cfg.dataset)
BPPs_RNA_Dataset = _dataset_module.BPPs_RNA_Dataset
LenMatchBatchSampler = _dataset_module.LenMatchBatchSampler
DeviceDataLoader = _dataset_module.DeviceDataLoader

Squeezeformer_RNA = importlib.import_module(cfg.model).Squeezeformer_RNA
loss = importlib.import_module(cfg.loss).loss
MAE = importlib.import_module(cfg.metrics).MAE

# Scalar settings copied out of the config for convenience.
OUT = cfg.OUT
SEED = cfg.SEED
nfolds = cfg.nfolds

set_seed = importlib.import_module(cfg.utils).set_seed

In [None]:
#pip install neptune-fastai
#from neptune.integrations.fastai import NeptuneCallback
#neptune_callback = NeptuneCallback(run=neptune_run)

In [None]:
# Call the project's set_seed once with the configured SEED and make sure the
# configured OUT folder exists before training starts.
set_seed(SEED)
os.makedirs(OUT, exist_ok=True)

for fold in [0]: #[0,1,2,3]
    # --- training loader ---
    ds_train = BPPs_RNA_Dataset(df, mode='train', fold=fold, nfolds = nfolds)
    # Second dataset over the same split built with mask_only=True; it is only
    # handed to the sampler, never to the DataLoader itself.
    # NOTE(review): presumably a lightweight view used for length matching —
    # confirm in the cfg.dataset module.
    ds_train_len = BPPs_RNA_Dataset(df, mode='train', fold=fold,
                nfolds=nfolds, mask_only=True)
    sampler_train = torch.utils.data.RandomSampler(ds_train_len)
    len_sampler_train = LenMatchBatchSampler(sampler_train, batch_size=cfg.bs,
                drop_last=True)
    # DataLoader raises ValueError for persistent_workers=True when
    # num_workers == 0, so only request persistent workers if there are any.
    dl_train = DeviceDataLoader(torch.utils.data.DataLoader(ds_train,
                batch_sampler=len_sampler_train, num_workers=cfg.num_workers,
                persistent_workers=cfg.num_workers > 0), cfg.device)

    # --- validation loader: sequential order, incomplete last batch kept ---
    ds_val = BPPs_RNA_Dataset(df, mode='eval', fold=fold, nfolds=nfolds)
    ds_val_len = BPPs_RNA_Dataset(df, mode='eval', fold=fold, nfolds=nfolds,
               mask_only=True)
    sampler_val = torch.utils.data.SequentialSampler(ds_val_len)
    len_sampler_val = LenMatchBatchSampler(sampler_val, batch_size=cfg.bs,
               drop_last=False)
    dl_val= DeviceDataLoader(torch.utils.data.DataLoader(ds_val,
               batch_sampler=len_sampler_val, num_workers=cfg.num_workers), cfg.device)
    gc.collect()

    # --- model and learner ---
    data = DataLoaders(dl_train,dl_val)
    model = Squeezeformer_RNA(cfg)

    #model.load_state_dict(torch.load(m,map_location=torch.device('cpu')))
    model = model.to(cfg.device)
    # Gradient clipping at 3.0, MAE as the reported metric, fp16 via to_fp16().
    learn = Learner(data, model, loss_func=loss,cbs=[GradientClip(3.0)], #neptune_callback
                metrics=[MAE()]).to_fp16()
    learn.fit_one_cycle(cfg.epochs, lr_max=cfg.lr, wd=cfg.weight_decay, pct_start=cfg.pct_start,)#pct_start=0.02

    # --- save the final weights for this fold ---
    # exist_ok=True replaces the separate exists() check (no race, and matches
    # how the OUT folder is created above).
    fold_dir = f"{cfg.output_dir}/fold{fold}/"
    os.makedirs(fold_dir, exist_ok=True)
    torch.save(learn.model.state_dict(),f"{fold_dir}checkpoint_last_SEED{cfg.SEED}.pth")
    gc.collect()