In [7]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

import torch

from bpnetlite.io import extract_loci
from bpnetlite.io import PeakGenerator
from bpnetlite import BPNet

root = '/users/myin25/projects/human_proseq/data/coPRO_3prime/K562/'

peaks_train = root + 'peaks_fold1_train.bed.gz'
peaks_val = root + 'peaks_fold1_val.bed.gz'
seqs = '/users/myin25/projects/human_proseq/refs/hg38.fasta'
signals = [root + '3prime.pos.bigWig', root + '3prime.neg.bigWig']
controls = None

training_chroms = ['chr{}'.format(i) for i in range(0, 23)]
valid_chroms = ['chr{}'.format(i) for i in range(0, 23)]

In [8]:
training_data = PeakGenerator(peaks_train, seqs, signals, controls)

In [9]:
X_valid, y_valid = extract_loci(peaks_val, seqs, signals, controls, max_jitter=0)

In [10]:
from datetime import datetime, date
timestamp = str(date.today()) + '_' + str(datetime.now().strftime("%H:%M:%S"))
model_root = '/users/myin25/projects/human_proseq/models/'
model_timestamped = model_root + timestamp

fit_parameters = {
    'n_filters': 256,
    'n_layers': 11,
    'profile_output_bias': True,
    'count_output_bias': True,
    'name': None,
    'batch_size': 64,
    'in_window': 2114,
    'out_window': 1000,
    'max_jitter': 128,
    'reverse_complement': True,
    'max_epochs': 50,
    'validation_iter': 100,
    'lr': 0.0001,
    'alpha': 1000,
    'verbose': False,

    'min_counts': 0,
    'max_counts': 99999999,

    'training_chroms': ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 
        'chr9', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 
        'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'],
    'validation_chroms': ['chr8', 'chr10'],
    'sequences': None,
    'loci': None,
    'signals': None,
    'controls': None,
    'random_state': None
}

model = BPNet(n_filters = fit_parameters['n_filters'], n_layers = fit_parameters['n_layers'], n_outputs = 2,
              alpha = fit_parameters['alpha'], n_control_tracks=0, name = model_timestamped, 
              trimming=(2114 - 1000) // 2).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [11]:
import torch
print(torch.cuda.is_available())

True


In [12]:
model.fit(training_data, optimizer, X_valid=X_valid, X_ctl_valid=None, y_valid=y_valid, early_stopping=60)
print('done')

Epoch	Iteration	Training Time	Validation Time	Training MNLL	Training Count MSE	Validation MNLL	Validation Profile Pearson	Validation Count Pearson	Validation Count MSE	Saved?
0	0	0.1138	10.2774	1629.9883	22.1837	53430.1172	-0.019620266	-0.33764055	2036.4468	True
0	100	41.3504	10.2751	2298.3853	1.3894	1265.7345	0.18079373	0.4641144	2.8327	True
0	200	62.3196	22.4875	2245.8066	1.1717	1349.0001	0.23188366	0.46554136	3.1834	False
0	300	94.042	23.3941	1312.6508	0.8592	1106.0615	0.2613095	0.47022337	2.3922	True
0	400	94.9342	23.3798	2168.0549	1.8982	1243.9458	0.27117676	0.47441107	2.6609	False
0	500	94.6913	23.2182	1912.2162	1.5749	1073.6028	0.2769651	0.4803906	3.1688	False
0	600	93.6747	23.2029	1728.7515	0.9318	1055.4895	0.28886288	0.48753798	2.8336	False
0	700	94.1318	23.3108	2995.1902	1.5558	1077.6399	0.26854447	0.49443826	2.0745	True
1	800	68.2454	23.2653	1571.1868	1.3355	1053.1937	0.2884819	0.49674338	2.9654	False
1	900	93.7005	23.2671	1871.2502	0.874	1035.8353	0.29451373	0.504791	3.0801

14	9900	14.622	10.6633	1296.9573	0.882	792.4056	0.43310001	0.63753104	2.5576	False
14	10000	43.3469	10.7064	1413.5969	0.5756	796.5823	0.42453256	0.63524514	2.8828	False
14	10100	43.3693	10.6564	1477.6412	0.8229	795.1877	0.4325305	0.6376464	2.17	False
14	10200	43.3171	10.6384	1083.115	0.7644	784.0013	0.4350545	0.6374918	3.0002	False
14	10300	43.2481	10.7206	1423.8812	0.9671	786.8156	0.4319311	0.6360481	3.1357	False
14	10400	43.3119	10.6876	1199.0629	0.6371	810.2155	0.43642068	0.6390885	2.6107	False
14	10500	57.2185	11.0758	1171.9788	0.7971	787.9626	0.43301913	0.63889503	2.5103	False
15	10600	13.3105	10.6825	962.5181	0.5583	794.8676	0.428269	0.6355314	2.1175	False
15	10700	43.4583	10.6738	1082.5571	0.5429	798.856	0.42627093	0.6384266	2.3544	False
15	10800	43.3607	10.7067	1073.0514	0.8942	792.9805	0.42941102	0.636675	2.8524	False
15	10900	43.3283	10.6631	1315.835	0.6499	825.8093	0.4183562	0.6394624	2.3857	False
15	11000	43.2845	10.7191	1249.8107	0.5845	796.5284	0.4269421	0.64026225	2.2267