In [1]:
import os
import gc
import pandas as pd
import torch
import importlib
from importlib import reload

import dataloader
from dataloader import RNA_Dataset, LenMatchBatchSampler, DeviceDataLoader

import model
from model import RNA_Model

import metrics
from metrics import MAE, loss

from fastai.vision.data import  DataLoaders
from fastai.vision.all import Learner, GradientClip, SaveModelCallback

import matplotlib.pyplot as plt
import numpy as np
import pickle

In [None]:
# Pick up on-disk edits to the local helper modules without restarting the
# kernel. Modules are looked up by name instead of through the notebook
# variables because the training cell rebinds `model` to an RNA_Model
# instance, which reload() would reject with a TypeError.
for _module_name in ('dataloader', 'metrics', 'model'):
    reload(importlib.import_module(_module_name))

# reload() re-executes a module in place but does not rebind names that were
# copied into this namespace with `from ... import ...`; re-import them so the
# cells below use the freshly loaded definitions rather than the stale ones.
from dataloader import RNA_Dataset, LenMatchBatchSampler, DeviceDataLoader
from model import RNA_Model
from metrics import MAE, loss

In [2]:
# Run configuration shared by the cells below.
fname = 'example0'                      # prefix of the exported weight file names
PATH = '/scratch/lemercier/WIP_data/'   # directory containing train_data.parquet
OUT = '/scratch/lemercier/'             # directory receiving the exported .pth files
bs, num_workers, nfolds = 128, 2, 4     # batch size, DataLoader workers, nfolds for RNA_Dataset

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [4]:
# Load the training table from the data directory into a DataFrame.
train_parquet = os.path.join(PATH, 'train_data.parquet')
df = pd.read_parquet(train_parquet)

In [26]:
# Train one model per listed fold and export its weights to OUT.
filter_ = True  # forwarded to every RNA_Dataset below as filter_SNR

# Create the export directory up front so the final torch.save cannot fail on
# a missing folder after the whole training run has already completed.
os.makedirs(OUT, exist_ok=True)

for fold in [0]:

    # --- training loader ---------------------------------------------------
    ds_train = RNA_Dataset(df, mode='train', fold=fold, nfolds=nfolds, filter_SNR=filter_)
    # Second view of the same split built with mask_only=True; it is only used
    # as the source for the sampler (presumably a cheaper way for
    # LenMatchBatchSampler to obtain sequence lengths -- TODO confirm in dataloader.py).
    ds_train_len = RNA_Dataset(df, mode='train', fold=fold,
                nfolds=nfolds, filter_SNR=filter_, mask_only=True)
    sampler_train = torch.utils.data.RandomSampler(ds_train_len)
    len_sampler_train = LenMatchBatchSampler(sampler_train, batch_size=bs,
                drop_last=True)
    # persistent_workers is only legal with at least one worker process:
    # PyTorch raises ValueError for persistent_workers=True with num_workers=0,
    # so derive the flag from the configured worker count.
    dl_train = DeviceDataLoader(torch.utils.data.DataLoader(ds_train,
                batch_sampler=len_sampler_train, num_workers=num_workers,
                persistent_workers=num_workers > 0), device)

    # --- validation loader (sequential order, keeps the last partial batch) --
    ds_val = RNA_Dataset(df, mode='eval', fold=fold, nfolds=nfolds, filter_SNR=filter_)
    ds_val_len = RNA_Dataset(df, mode='eval', fold=fold, nfolds=nfolds,
               mask_only=True, filter_SNR=filter_)
    sampler_val = torch.utils.data.SequentialSampler(ds_val_len)
    len_sampler_val = LenMatchBatchSampler(sampler_val, batch_size=bs,
               drop_last=False)
    dl_val= DeviceDataLoader(torch.utils.data.DataLoader(ds_val,
               batch_sampler=len_sampler_val, num_workers=num_workers), device)
    gc.collect()

    data = DataLoaders(dl_train,dl_val)

    # NOTE: this rebinds the notebook-level name `model` from the imported
    # `model` module to the network instance; the name is kept because later
    # cells may rely on it.
    model = RNA_Model()
    model = model.to(device)

    # Gradient clipping at 3.0, MAE reported as metric, mixed-precision training.
    learn = Learner(data, model, loss_func=loss,cbs=[GradientClip(3.0)],
                metrics=[MAE()]).to_fp16()

    # A checkpoint is saved after every epoch under the 'modelAF' name; the
    # name does not include the fold, so several folds would share it.
    learn.fit_one_cycle(100, lr_max=5e-4, wd=0.05, pct_start=0.02, cbs=SaveModelCallback(every_epoch=True, fname='modelAF'))

    # Export the final weights for this fold.
    torch.save(learn.model.state_dict(),os.path.join(OUT,f'{fname}_{fold}_100_epoch_AF2.pth'))
    gc.collect()

epoch,train_loss,valid_loss,mae,time
0,0.246618,0.233732,0.234042,03:10
1,0.214238,0.214183,0.214708,03:19
2,0.190238,0.197791,0.198255,03:25
3,0.179449,0.1858,0.186258,03:23
4,0.169264,0.16915,0.169542,03:24
5,0.161633,0.159653,0.160073,03:24
6,0.157517,0.154601,0.15496,03:29
7,0.153611,0.153381,0.153764,03:25
8,0.151319,0.149788,0.150191,03:26
9,0.14948,0.147531,0.147948,03:27


In [11]:
# Count the number of parameters of a model.
def count_params(model):
    """Return the total number of scalar elements over all parameters of `model`.

    Every tensor yielded by ``model.parameters()`` is counted, whether or not
    it requires gradients. Relies on ``Tensor.numel()`` so that no extra
    imports are needed and 0-dimensional parameters count as one element.

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        int: Sum of the element counts of all parameters; 0 if there are none.
    """
    return sum(p.numel() for p in model.parameters())