In [1]:
import pickle
import random
import numpy as np
from tqdm import trange, tqdm
import torch
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from model import PVNet

In [2]:
# Seed every RNG the notebook touches so runs are repeatable.
SEED = 5
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Explicitly seed every visible GPU (cuda.manual_seed only covers the current
# device). This call is applied lazily, so it is safe on CPU-only machines.
torch.cuda.manual_seed_all(SEED)
# NOTE: GPU runs may still differ bit-for-bit because cuDNN can pick
# nondeterministic kernels; set torch.backends.cudnn.deterministic = True
# (at some speed cost) if exact reproducibility is required.

In [3]:
# Load the pickled training samples and materialise them as a list, which the
# DataLoader cell indexes into.
# NOTE: pickle.load can execute arbitrary code — only open trusted files.
DATASET_PATH = 'data/201212_50_0_step_dataset.pickle'
with open(DATASET_PATH, 'rb') as dataset_file:
    data = list(pickle.load(dataset_file))

In [4]:
# Mini-batch iterator over the loaded samples, reshuffled every epoch.
# With drop_last=False the final batch may be smaller than batch_size.
# NOTE(review): if PVNet contains BatchNorm layers, a final batch of exactly one
# sample would raise in training mode — consider drop_last=True if
# len(data) % 32 == 1.
dataloader = DataLoader(
    dataset=data,
    batch_size=32,
    shuffle=True,
    drop_last=False,
)

In [5]:
# Train on the GPU when one is available; fall back to CPU so the notebook
# still runs (slowly) on machines without CUDA instead of failing here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): the meaning of PVNet's positional arguments is defined in
# model.py (not visible here). The sweep keys further down (e.g.
# 'sgd_18_256_...') suggest 18 and 256 are tunable sizes — TODO name them.
model = PVNet(18, 5, 256, 15).to(device)

# Optional warm start: set to a checkpoint path
# (e.g. 'data/201211_50_0_step_model.pickle') to resume from saved weights.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
RESUME_CHECKPOINT = None
if RESUME_CHECKPOINT is not None:
    model.load_state_dict(torch.load(RESUME_CHECKPOINT, map_location=device))

In [6]:
# Build two optimizer parameter groups: ordinary weights get L2 weight decay,
# while any parameter whose name contains 'bn' or 'bias' (presumably
# normalisation and bias terms) is exempt from decay.
no_decay = ['bn', 'bias']
decay_params, no_decay_params = [], []
for param_name, param in model.named_parameters():
    if any(tag in param_name for tag in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)

model_parameters = [
    {'params': decay_params, 'weight_decay': 1e-4},
    {'params': no_decay_params, 'weight_decay': 0.0},
]
optimizer = torch.optim.SGD(model_parameters, lr=6e-4, momentum=0.9)
# Alternative optimizer tried during tuning:
# optimizer = torch.optim.AdamW(model_parameters, lr=3e-4, eps=1e-6)

In [7]:
# writer = SummaryWriter(f'results/sgd_b64_lr1e-3_wd1e-4')

In [8]:
# Training loop. Each step minimises
#   policy loss = -mean_over_batch( sum_a pi_a * log p_a )   (cross-entropy)
#   value loss  = mean( (v - z)^2 )                           (MSE)
# and reports per-step and running-average losses in the progress bar.
epochs = 50

for epoch in range(epochs):
    # Re-enter training mode every epoch in case another cell switched the
    # model to eval() in between.
    model.train()
    losses = []       # per-step losses of the current epoch, kept for inspection
    loss_sum = 0.0    # running total so the displayed average costs O(1) per step
    # The context manager closes the progress bar even if the run is interrupted.
    with tqdm(dataloader) as p_bar:
        for step, (s, pi, z) in enumerate(p_bar):
            s = s.to(device, dtype=torch.float32)
            pi = pi.to(device, dtype=torch.float32)
            z = z.to(device, dtype=torch.float32)

            optimizer.zero_grad()
            p, v = model(s)

            # Replace p by 1 wherever pi == 0 before taking the log: those terms
            # contribute 0 anyway, but p == 0 there would otherwise produce
            # 0 * log(0) = 0 * -inf = NaN in both the loss and its gradient.
            # Entries with pi != 0 are untouched, so the loss equals the plain
            # pi * log(p) form whenever that form is not NaN.
            p_safe = torch.where(pi != 0, p, torch.ones_like(p))
            p_loss = -(pi * p_safe.log()).sum(dim=-1).mean()

            # Align v with z before subtracting: a (B, 1) prediction against a
            # (B,) target would otherwise broadcast to (B, B) and silently
            # compute the wrong MSE. No-op when the shapes already match.
            v_loss = (v.reshape(z.shape) - z).pow(2).mean()

            loss = v_loss + p_loss
            loss.backward()
            optimizer.step()

            # .item() copies the scalar to the host (a sync on GPU); do it once.
            loss_value = loss.item()
            losses.append(loss_value)
            loss_sum += loss_value
            # writer.add_scalar('loss', loss_value, epoch * len(dataloader) + step)
            p_bar.set_description(
                f"[{epoch+1:2}/{epochs:2}] V: {v_loss.item():.4f}  P: {p_loss.item():.4f}  "
                f"Loss: {loss_value:.4f}  Avg.Loss: {loss_sum / (step + 1):.4f}")

[ 1/50] V: 0.0345  P: 5.4298  Loss: 5.4642  Avg.Loss: 5.7987: 100%|██████████| 784/784 [01:05<00:00, 11.90it/s]
[ 2/50] V: 0.0106  P: 5.4053  Loss: 5.4158  Avg.Loss: 5.5246:   5%|▌         | 40/784 [00:03<01:02, 11.90it/s]

KeyboardInterrupt: 

In [None]:
# Recorded results of a model-configuration sweep, keyed by run name.
# NOTE(review): the two numbers in each key presumably correspond to PVNet
# constructor arguments (cf. PVNet(18, 5, 256, 15)), and each tuple holds two
# loss-like metrics whose exact meaning is not recorded in this notebook —
# TODO confirm and document.
# Stray note from a training run (was pasted inside the dict and broke it):
# Avg.Loss: 5.7983
arch_sweep_results = {
    "sgd_10_128_base": (0.8018, 0.8251),
    "sgd_10_128_nobias": (0.8862, 0.8248),
    "sgd_10_128_bias": (0.5416, 0.8248),
    "sgd_10_128_bias_nodecay": (0.5462, 0.8245),
    "sgd_8_128_bias_nodecay": (0.8034, 0.8261),
    "sgd_12_128_bias_nodecay": (0.7032, 0.8236),
    "sgd_14_128_bias_nodecay": (0.9310, 0.8231),
    "sgd_16_128_bias_nodecay": (0.3344, 0.8229),
    "sgd_18_128_bias_nodecay": (0.5897, 0.8220),
    "sgd_20_128_bias_nodecay": (0.9349, 0.8230),
    "sgd_18_64_bias_nodecay": (1.0397, 0.8431),
    "sgd_18_192_bias_nodecay": (0.7780, 0.8157),
    "sgd_18_256_bias_nodecay": (1.0259, 0.8125),
    "sgd_18_320_bias_nodecay": (1.0259, 0.8114),
}
arch_sweep_results

In [None]:
# Recorded results of an optimizer / hyper-parameter sweep, keyed by run name
# in the same format as the commented-out SummaryWriter run directory
# ('sgd_b64_lr1e-3_wd1e-4'). Fields presumably read optimizer, batch size,
# learning rate and weight decay; the '_bn' / '_bn_bias' suffixes presumably
# mark which parameter groups were exempt from weight decay — TODO confirm.
# NOTE(review): the meaning of the two tuple values is not recorded here.
# Bound to a name so later cells can reuse it instead of relying on Out[N].
optimizer_sweep_results = {
    "sgd_b32_lr1e-3_wd1e-4": (0.6376, 0.8289),
    "sgd_b64_lr1e-3_wd1e-4": (0.8023, 0.8252),
    "sgd_b64_lr1e-3_wd1e-4_bn": (0.7677, 0.8243),
    "sgd_b64_lr1e-3_wd1e-4_bn_bias": (0.8003, 0.8250),
    "sgd_b64_lr1e-3_wd5e-4": (0.8078, 0.8259),
    "sgd_b64_lr6e-4_wd1e-4": (0.8086, 0.8264),
    "sgd_b64_lr2e-3_wd1e-4": (0.8057, 0.8271),
    "sgd_b64_lr3e-3_wd1e-4": (0.8153, 0.8352),
    "sgd_b128_lr4e-3_wd1e-4": (0.9025, 0.8295),
    "adam_b64_lr1e-4": (0.7073, 0.8466),
    "adam_b64_lr2e-4": (1.0097, 0.8473),
    "adam_b64_lr3e-4": (1.0294, 0.8417),
    "adamw_b64_lr3e-4_wd1e-2": (0.8102, 0.8326),
    "adamw_b64_lr3e-4_wd1e-2_bn_bias": (0.8184, 0.8375),
    "adamw_b64_lr3e-4_wd1e-2_bn": (0.8162, 0.8318),
    "adamw_b64_lr1e-4_wd1e-2": (0.8176, 0.8371),
    "adamw_b64_lr3e-4_wd1e-1": (1.0266, 0.8453),
    "adamw_b64_lr3e-4_wd1e-1_bn": (0.5019, 0.8322),
}
optimizer_sweep_results