In [1]:
import torch
from torch import optim,nn
from torch.autograd import Variable
import numpy as np
import time
from models.data_loader import DataLoader
from models.retain_bidirectional import RETAIN

In [2]:
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=0


In [3]:
# hyperparameters
epochs = 30
batch_size = 50
max_seq_length = 300
min_seq_length = 5
num_classes = 268
emb_size = 128
hid_size = 128
lr = 0.001
cuda_flag = True

# data loader
D = DataLoader(batch_size=batch_size,
   data_dir='data/batches/',
    mode='train', max_seq_length=max_seq_length, min_seq_length=min_seq_length)

# import model and optimization settings
# model = RETAIN(128,128,268,False)

model = RETAIN(emb_size,hid_size,num_classes,cuda_flag)
model.release = True

criterion = nn.CrossEntropyLoss()
cnt = 0
if cuda_flag:
    model.cuda()
    criterion.cuda()

In [4]:
# train model
lr_list = [0.001, 0.0003, 0.0001, 0.00003, 0.00001, 0.000003, 0.000001]
lr_counter = 0
lr = lr_list[lr_counter]
opt = optim.Adam(model.parameters(), lr=lr)
loss_list = []
loss_mean = 0.0
file_cnt = 0
cnt = 0
loss_dict = dict()

In [8]:
len_train = len(D.train_list)
while file_cnt<(epochs*len(D.train_list)):
    idx = file_cnt%(len_train)
    file = D.train_list[idx]
    if file not in loss_dict:
        loss_dict[file] = []
    print("Epoch %d [%d, %d/%d] - opening file %s" %(((file_cnt+1)/len_train), file_cnt, idx, len_train, D.train_list[idx]))
    file_num = int(file.split('_')[1].split('.')[0])
    D.batch_size = int(20000/file_num)
    D.load_batch_file(file)
    loss_list = []
    for i in range(D.batch_count):
        cnt+=1
        input_list, targets = D.get_batch()
        start = time.time()
        inputs = model.list_to_tensor(input_list)
        outputs = model(inputs)
        targets = Variable(torch.LongTensor(targets)[:,-1]) # to only use last of each sequence
#             targets = Variable(torch.LongTensor(targets)).view(len(inputs),-1)[:,-1] # to only use last of each sequence
        if cuda_flag:
            targets = targets.cuda()
        loss = criterion(outputs.view(-1,num_classes),targets)
        loss_list.append(loss.data[0])
        if cnt%10==0:
            print('[%d] %1.3f' %(cnt,loss.data[0]))
        if cnt%500==0:
            print("Saving model at %dth step" %cnt)
            torch.save(model.state_dict(),'data/saved_weights/retain_bi_%d.pth'%(cnt))
            # create CPU version
            model2 = RETAIN(emb_size,hid_size,num_classes,False)
            if cuda_flag:
                model.cpu()
            model2.load_state_dict(model.state_dict())
            torch.save(model2.state_dict(),'data/saved_weights/retain_bi_%d_cpu.pth'%(cnt))
            if cuda_flag:
                model.cuda()
            print("Saving at %dth step"%cnt)
        # manual loss changes
        if cnt==100:
            lr_counter+=1
            lr = lr_list[lr_counter]
            opt = optim.Adam(model.parameters(),lr=lr)
        if cnt==500:
            lr_counter+=1
            lr = lr_list[lr_counter]
            opt = optim.Adam(model.parameters(),lr=lr)        
        if loss.data[0]>10:
            import sys
            sys.exit()
#             print(loss.data[0])
        loss.backward()
        opt.step()
    print("Loss: %1.3f" %np.mean(loss_list))
    loss_dict[file].append(loss.data[0])
    file_cnt+=1

Epoch 11 [4916, 406/410] - opening file 2014_263.pckl


  outputs1 = self.RNN1(embedded) # [b x seq x 128*2]
  outputs2 = self.RNN2(embedded) # [b x seq x 128]


Loss: 1.196
Epoch 11 [4917, 407/410] - opening file 2014_50.pckl
Loss: 2.579
Epoch 11 [4918, 408/410] - opening file 2014_103.pckl
[14510] 1.709
Loss: 2.161
Epoch 12 [4919, 409/410] - opening file 2015_146.pckl
Loss: 2.043
Epoch 12 [4920, 0/410] - opening file 2015_118.pckl
Loss: 2.309
Epoch 12 [4921, 1/410] - opening file 2015_105.pckl
Loss: 2.469
Epoch 12 [4922, 2/410] - opening file 2015_171.pckl
Loss: 1.424
Epoch 12 [4923, 3/410] - opening file 2014_94.pckl
Loss: 2.649
Epoch 12 [4924, 4/410] - opening file 2014_216.pckl
Loss: 1.213
Epoch 12 [4925, 5/410] - opening file 2014_213.pckl
Loss: 1.184
Epoch 12 [4926, 6/410] - opening file 2014_254.pckl
Loss: 1.634
Epoch 12 [4927, 7/410] - opening file 2015_152.pckl
[14520] 1.473
Loss: 1.473
Epoch 12 [4928, 8/410] - opening file 2014_192.pckl
Loss: 1.700
Epoch 12 [4929, 9/410] - opening file 2014_287.pckl
Loss: 0.016
Epoch 12 [4930, 10/410] - opening file 2015_297.pckl
Loss: 0.459
Epoch 12 [4931, 11/410] - opening file 2014_34.pckl
[14530]

[14960] 2.699
Loss: 2.654
Epoch 12 [5034, 114/410] - opening file 2015_134.pckl
[14970] 2.195
Loss: 2.195
Epoch 12 [5035, 115/410] - opening file 2015_222.pckl
Loss: 0.334
Epoch 12 [5036, 116/410] - opening file 2015_37.pckl
[14980] 2.612
Loss: 2.678
Epoch 12 [5037, 117/410] - opening file 2014_193.pckl
Loss: 2.224
Epoch 12 [5038, 118/410] - opening file 2015_290.pckl
Loss: 0.406
Epoch 12 [5039, 119/410] - opening file 2014_166.pckl
Loss: 1.384
Epoch 12 [5040, 120/410] - opening file 2015_97.pckl
Loss: 2.242
Epoch 12 [5041, 121/410] - opening file 2014_112.pckl
Loss: 2.493
Epoch 12 [5042, 122/410] - opening file 2014_204.pckl
Loss: 0.924
Epoch 12 [5043, 123/410] - opening file 2014_228.pckl
[14990] 1.112
Loss: 1.112
Epoch 12 [5044, 124/410] - opening file 2015_128.pckl
Loss: 2.452
Epoch 12 [5045, 125/410] - opening file 2015_164.pckl
Loss: 0.920
Epoch 12 [5046, 126/410] - opening file 2014_224.pckl
Loss: 2.038
Epoch 12 [5047, 127/410] - opening file 2015_30.pckl
[15000] 2.563
Saving mo

[15300] 2.646
[15310] 2.700
Loss: 2.662
Epoch 12 [5151, 231/410] - opening file 2014_299.pckl
Loss: 0.002
Epoch 12 [5152, 232/410] - opening file 2014_118.pckl
Loss: 2.532
Epoch 12 [5153, 233/410] - opening file 2014_162.pckl
[15320] 1.269
Loss: 1.269
Epoch 12 [5154, 234/410] - opening file 2014_278.pckl
Loss: 0.186
Epoch 12 [5155, 235/410] - opening file 2014_106.pckl
Loss: 2.603
Epoch 12 [5156, 236/410] - opening file 2014_65.pckl
Loss: 2.590
Epoch 12 [5157, 237/410] - opening file 2014_133.pckl
Loss: 2.534
Epoch 12 [5158, 238/410] - opening file 2015_130.pckl
Loss: 2.644
Epoch 12 [5159, 239/410] - opening file 2014_7.pckl
[15330] 3.053
[15340] 2.994
Loss: 2.992
Epoch 12 [5160, 240/410] - opening file 2015_184.pckl
[15350] 1.853
Loss: 1.853
Epoch 12 [5161, 241/410] - opening file 2014_199.pckl
Loss: 0.976
Epoch 12 [5162, 242/410] - opening file 2014_59.pckl
Loss: 2.616
Epoch 12 [5163, 243/410] - opening file 2015_174.pckl
Loss: 1.181
Epoch 12 [5164, 244/410] - opening file 2015_107.p

Loss: 0.456
Epoch 12 [5267, 347/410] - opening file 2014_21.pckl
[15730] 2.641
[15740] 2.640
Loss: 2.641
Epoch 12 [5268, 348/410] - opening file 2015_165.pckl
Loss: 1.050
Epoch 12 [5269, 349/410] - opening file 2014_246.pckl
Loss: 0.832
Epoch 12 [5270, 350/410] - opening file 2014_120.pckl
Loss: 2.511
Epoch 12 [5271, 351/410] - opening file 2014_73.pckl
Loss: 2.522
Epoch 12 [5272, 352/410] - opening file 2014_49.pckl
[15750] 2.732
Loss: 2.683
Epoch 12 [5273, 353/410] - opening file 2014_175.pckl
Loss: 1.298
Epoch 12 [5274, 354/410] - opening file 2014_223.pckl
Loss: 1.484
Epoch 12 [5275, 355/410] - opening file 2014_67.pckl
[15760] 2.710
Loss: 2.696
Epoch 12 [5276, 356/410] - opening file 2014_143.pckl
Loss: 2.514
Epoch 12 [5277, 357/410] - opening file 2015_84.pckl
Loss: 2.797
Epoch 12 [5278, 358/410] - opening file 2015_68.pckl
Loss: 2.196
Epoch 12 [5279, 359/410] - opening file 2015_41.pckl
[15770] 2.654
Loss: 2.612
Epoch 12 [5280, 360/410] - opening file 2015_148.pckl
Loss: 2.023
E

Loss: 0.593
Epoch 13 [5384, 54/410] - opening file 2014_236.pckl
Loss: 2.094
Epoch 13 [5385, 55/410] - opening file 2014_102.pckl
[16130] 2.550
Loss: 2.550
Epoch 13 [5386, 56/410] - opening file 2015_8.pckl
[16140] 2.934
[16150] 2.941
Loss: 2.931
Epoch 13 [5387, 57/410] - opening file 2014_168.pckl
Loss: 1.573
Epoch 13 [5388, 58/410] - opening file 2014_127.pckl
Loss: 2.872
Epoch 13 [5389, 59/410] - opening file 2015_89.pckl
Loss: 2.678
Epoch 13 [5390, 60/410] - opening file 2015_95.pckl
Loss: 2.290
Epoch 13 [5391, 61/410] - opening file 2014_262.pckl
Loss: 1.438
Epoch 13 [5392, 62/410] - opening file 2014_202.pckl
[16160] 0.720
Loss: 0.720
Epoch 13 [5393, 63/410] - opening file 2015_79.pckl
Loss: 2.268
Epoch 13 [5394, 64/410] - opening file 2014_53.pckl
Loss: 2.664
Epoch 13 [5395, 65/410] - opening file 2015_124.pckl
[16170] 2.731
Loss: 2.731
Epoch 13 [5396, 66/410] - opening file 2015_116.pckl
Loss: 2.753
Epoch 13 [5397, 67/410] - opening file 2015_106.pckl
Loss: 2.647
Epoch 13 [5398

Loss: 0.478
Epoch 13 [5500, 170/410] - opening file 2014_47.pckl
[16570] 2.691
Loss: 2.674
Epoch 13 [5501, 171/410] - opening file 2014_72.pckl
Loss: 2.746
Epoch 13 [5502, 172/410] - opening file 2014_51.pckl
[16580] 2.719
Loss: 2.645
Epoch 13 [5503, 173/410] - opening file 2015_178.pckl
Loss: 1.957
Epoch 13 [5504, 174/410] - opening file 2015_197.pckl
Loss: 1.162
Epoch 13 [5505, 175/410] - opening file 2015_298.pckl
Loss: 0.342
Epoch 13 [5506, 176/410] - opening file 2015_81.pckl
Loss: 2.646
Epoch 13 [5507, 177/410] - opening file 2014_231.pckl
[16590] 1.787
Loss: 1.787
Epoch 13 [5508, 178/410] - opening file 2015_291.pckl
Loss: 0.306
Epoch 13 [5509, 179/410] - opening file 2014_156.pckl
Loss: 1.172
Epoch 13 [5510, 180/410] - opening file 2015_277.pckl
Loss: 0.644
Epoch 13 [5511, 181/410] - opening file 2014_161.pckl
Loss: 0.949
Epoch 13 [5512, 182/410] - opening file 2015_101.pckl
Loss: 2.494
Epoch 13 [5513, 183/410] - opening file 2015_195.pckl
Loss: 1.650
Epoch 13 [5514, 184/410] -

KeyboardInterrupt: 

In [7]:
# load new model
model = torch.load('/home/mjc/github/EHRVis/data/saved_weights/retain_bi_14500.pth')
file_cnt = 4916
cnt = 14500

In [6]:
# lower learning rate
lr_counter+=1
lr = lr_list[lr_counter]
opt = optim.Adam(model.parameters(), lr=lr)