In [57]:
import numpy as np
import pickle
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils import *
from model import *

seed = 0 
seed_everything(seed)

In [58]:
data_num = 10000

In [59]:
loader_train,loader_validate = make_loader_train(data_num)

In [60]:
model = MLP()
model.cuda()
model

MLP(
  (layers): Sequential(
    (0): Linear(in_features=3, out_features=1000, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1000, out_features=1000, bias=True)
    (3): ReLU()
    (4): Linear(in_features=1000, out_features=1, bias=True)
  )
)

In [61]:
optimizer = optim.Adam(model.parameters(),lr=1e-3)
criterion = nn.MSELoss()
scheduler = ReduceLROnPlateau(optimizer,min_lr=1e-5,patience=100)
early_stopping = EarlyStopping(patience=300, verbose=False)

In [62]:
def loop_train(loader_train,loader_validate) :
    
    losses_train = []
    for x,y in loader_train :
        x = x.cuda(); y = y.cuda()
        
        optimizer.zero_grad()
        y_pred = model(x)
    
        loss = criterion(y_pred, y)
        losses_train.append( loss.item() )
        
        loss.backward()
        optimizer.step()
        
    losses_validate = []
    for x,y in loader_train :
        x = x.cuda(); y = y.cuda()
        
        with torch.no_grad() :
            y_pred = model(x)
    
        loss = criterion(y_pred, y)
        losses_validate.append( loss.item() )
        
    lt = np.mean(losses_train)
    lv = np.mean(losses_validate)
        
    scheduler.step(lv)
    
    return (lt,lv)

def loop_test(loader) :
    
    losses = []
    for x,y in loader :
        x = x.cuda(); y = y.cuda()
        
        with torch.no_grad() :
            y_pred = model(x)
    
        loss = criterion(y_pred, y)
        losses.append( loss.item() )
            
    return np.mean(losses)

In [None]:
total_epoch = 10000
inner_unit = 10
outer_unit = total_epoch//inner_unit

for outer_loop in range(outer_unit):  
    
    losses_train = []; losses_validate = []
    for inner_loop in range(inner_unit) :
        epoch = inner_unit*outer_loop + inner_loop
        loss_train,loss_validate = loop_train(loader_train,loader_validate)
        losses_train.append(loss_train)
        losses_validate.append(loss_validate)
        
        early_stopping(loss_validate.item(), model)
        if early_stopping.early_stop :
            break 
    
    if early_stopping.early_stop :
        print("Early stopping")
        break 
        
    lr = optimizer.param_groups[0]['lr']
    lt = np.mean(losses_train)
    lv = np.mean(losses_validate)
    print(f'{epoch:5d} {lr:5.2e} {lt:8.5e} {lv:8.5e}')

    9 1.00e-03 1.58603e-03 4.99718e-05
   19 1.00e-03 3.04552e-05 4.30907e-05
   29 1.00e-03 2.23293e-05 2.40009e-05
   39 1.00e-03 1.25599e-05 9.83483e-06
   49 1.00e-03 1.76706e-05 6.02618e-05
   59 1.00e-03 4.68971e-06 3.61459e-06
   69 1.00e-03 5.23874e-06 9.04848e-06
   79 1.00e-03 5.95828e-06 8.49881e-06
   89 1.00e-03 4.96361e-06 2.73775e-06
   99 1.00e-03 5.17309e-06 4.21823e-06
  109 1.00e-03 3.69746e-06 3.41508e-06
  119 1.00e-03 4.26606e-06 4.69298e-06
  129 1.00e-03 3.44633e-06 4.48763e-06
  139 1.00e-03 4.86943e-06 5.48748e-06
  149 1.00e-03 3.45432e-06 3.21812e-06
  159 1.00e-03 3.63867e-06 3.18929e-06
  169 1.00e-03 4.19866e-06 3.00793e-06
  179 1.00e-03 3.64614e-06 3.13871e-06
  189 1.00e-03 3.38415e-06 2.32119e-06
  199 1.00e-03 3.19480e-06 2.45983e-06
  209 1.00e-03 2.83904e-06 2.80170e-06
  219 1.00e-03 3.74525e-06 3.33800e-06
  229 1.00e-03 3.17491e-06 2.77624e-06
  239 1.00e-03 2.99713e-06 3.30857e-06
  249 1.00e-03 3.03473e-06 3.47329e-06
  259 1.00e-03 2.87024e-0

 2119 1.00e-05 1.04553e-06 1.02764e-06
 2129 1.00e-05 1.04534e-06 1.02741e-06
 2139 1.00e-05 1.04503e-06 1.02739e-06
 2149 1.00e-05 1.04496e-06 1.02714e-06
 2159 1.00e-05 1.04475e-06 1.02699e-06
 2169 1.00e-05 1.04452e-06 1.02685e-06
 2179 1.00e-05 1.04433e-06 1.02661e-06
 2189 1.00e-05 1.04407e-06 1.02664e-06
 2199 1.00e-05 1.04387e-06 1.02653e-06
 2209 1.00e-05 1.04365e-06 1.02615e-06
 2219 1.00e-05 1.04348e-06 1.02596e-06
 2229 1.00e-05 1.04321e-06 1.02587e-06
 2239 1.00e-05 1.04294e-06 1.02549e-06
 2249 1.00e-05 1.04268e-06 1.02531e-06
 2259 1.00e-05 1.04244e-06 1.02510e-06
 2269 1.00e-05 1.04236e-06 1.02515e-06
 2279 1.00e-05 1.04217e-06 1.02487e-06
 2289 1.00e-05 1.04200e-06 1.02465e-06
 2299 1.00e-05 1.04192e-06 1.02441e-06
 2309 1.00e-05 1.04164e-06 1.02413e-06
 2319 1.00e-05 1.04147e-06 1.02389e-06
 2329 1.00e-05 1.04130e-06 1.02370e-06
 2339 1.00e-05 1.04117e-06 1.02382e-06
 2349 1.00e-05 1.04101e-06 1.02365e-06
 2359 1.00e-05 1.04087e-06 1.02341e-06
 2369 1.00e-05 1.04080e-0

 4229 1.00e-05 1.01005e-06 9.94077e-07
 4239 1.00e-05 1.00990e-06 9.94049e-07
 4249 1.00e-05 1.00980e-06 9.94017e-07
 4259 1.00e-05 1.00960e-06 9.93712e-07
 4269 1.00e-05 1.00956e-06 9.93558e-07
 4279 1.00e-05 1.00950e-06 9.93495e-07
 4289 1.00e-05 1.00930e-06 9.93228e-07
 4299 1.00e-05 1.00918e-06 9.93153e-07
 4309 1.00e-05 1.00894e-06 9.93222e-07
 4319 1.00e-05 1.00874e-06 9.92990e-07
 4329 1.00e-05 1.00853e-06 9.92885e-07
 4339 1.00e-05 1.00840e-06 9.92710e-07
 4349 1.00e-05 1.00835e-06 9.92646e-07
 4359 1.00e-05 1.00828e-06 9.92302e-07
 4369 1.00e-05 1.00823e-06 9.91915e-07
 4379 1.00e-05 1.00805e-06 9.91932e-07
 4389 1.00e-05 1.00799e-06 9.91796e-07
 4399 1.00e-05 1.00773e-06 9.91660e-07
 4409 1.00e-05 1.00739e-06 9.91772e-07
 4419 1.00e-05 1.00730e-06 9.91686e-07
 4429 1.00e-05 1.00717e-06 9.91456e-07
 4439 1.00e-05 1.00712e-06 9.91447e-07
 4449 1.00e-05 1.00711e-06 9.91271e-07
 4459 1.00e-05 1.00694e-06 9.90998e-07
 4469 1.00e-05 1.00659e-06 9.90764e-07
 4479 1.00e-05 1.00643e-0

 6339 1.00e-05 9.81822e-07 9.66669e-07
 6349 1.00e-05 9.81759e-07 9.66521e-07
 6359 1.00e-05 9.81629e-07 9.66374e-07
 6369 1.00e-05 9.81480e-07 9.66118e-07
 6379 1.00e-05 9.81544e-07 9.65953e-07
 6389 1.00e-05 9.81289e-07 9.66044e-07
 6399 1.00e-05 9.81202e-07 9.66028e-07
 6409 1.00e-05 9.81140e-07 9.65938e-07
 6419 1.00e-05 9.81044e-07 9.65681e-07
 6429 1.00e-05 9.80883e-07 9.65402e-07
 6439 1.00e-05 9.80799e-07 9.65318e-07
 6449 1.00e-05 9.80700e-07 9.65145e-07
 6459 1.00e-05 9.80634e-07 9.64976e-07
 6469 1.00e-05 9.80396e-07 9.65071e-07
 6479 1.00e-05 9.80268e-07 9.65034e-07
 6489 1.00e-05 9.80068e-07 9.64967e-07
 6499 1.00e-05 9.80006e-07 9.64726e-07
 6509 1.00e-05 9.79820e-07 9.64718e-07
 6519 1.00e-05 9.79798e-07 9.64575e-07
 6529 1.00e-05 9.79765e-07 9.64583e-07
 6539 1.00e-05 9.79631e-07 9.64509e-07
 6549 1.00e-05 9.79542e-07 9.64164e-07
 6559 1.00e-05 9.79378e-07 9.63923e-07
 6569 1.00e-05 9.79254e-07 9.63762e-07
 6579 1.00e-05 9.79213e-07 9.63810e-07
 6589 1.00e-05 9.79249e-0

 8449 1.00e-05 9.57109e-07 9.42458e-07
 8459 1.00e-05 9.56840e-07 9.42395e-07
 8469 1.00e-05 9.56795e-07 9.42349e-07
 8479 1.00e-05 9.56700e-07 9.42235e-07
 8489 1.00e-05 9.56529e-07 9.42099e-07
 8499 1.00e-05 9.56325e-07 9.42030e-07
 8509 1.00e-05 9.56129e-07 9.41918e-07
 8519 1.00e-05 9.56029e-07 9.41823e-07
 8529 1.00e-05 9.56145e-07 9.41680e-07
 8539 1.00e-05 9.56071e-07 9.41582e-07
 8549 1.00e-05 9.55936e-07 9.41524e-07
 8559 1.00e-05 9.55680e-07 9.41370e-07
 8569 1.00e-05 9.55736e-07 9.41284e-07
 8579 1.00e-05 9.55765e-07 9.41066e-07
 8589 1.00e-05 9.55587e-07 9.41065e-07
 8599 1.00e-05 9.55439e-07 9.40789e-07
 8609 1.00e-05 9.55221e-07 9.40611e-07
 8619 1.00e-05 9.55038e-07 9.40576e-07
 8629 1.00e-05 9.54792e-07 9.40525e-07
 8639 1.00e-05 9.54734e-07 9.40547e-07
 8649 1.00e-05 9.54595e-07 9.40357e-07
 8659 1.00e-05 9.54613e-07 9.40086e-07
 8669 1.00e-05 9.54644e-07 9.40092e-07
 8679 1.00e-05 9.54604e-07 9.39989e-07
 8689 1.00e-05 9.54291e-07 9.39672e-07
 8699 1.00e-05 9.54106e-0

In [None]:
model.load_state_dict( torch.load('checkpoint.pt') )

In [None]:
net_dir = 'net'
if not os.path.exists(net_dir) :
    os.mkdir(net_dir)
torch.save(model.state_dict(),f'net/bs_net_{data_num}.pt')